diff --git a/internal/disk/disk_test.go b/internal/disk/disk_test.go index 499f70fe5..9b5fa451a 100644 --- a/internal/disk/disk_test.go +++ b/internal/disk/disk_test.go @@ -471,3 +471,105 @@ func TestClone(t *testing.T) { } } } + +func TestFindDirectoryPartition(t *testing.T) { + assert := assert.New(t) + usr := Partition{ + Type: FilesystemDataGUID, + UUID: RootPartitionUUID, + Payload: &Filesystem{ + Type: "xfs", + Label: "root", + Mountpoint: "/usr", + FSTabOptions: "defaults", + FSTabFreq: 0, + FSTabPassNo: 0, + }, + } + + { + pt := testPartitionTables["plain"] + assert.Equal("/", pt.findDirectoryEntityPath("/opt")[0].(Mountable).GetMountpoint()) + assert.Equal("/boot/efi", pt.findDirectoryEntityPath("/boot/efi/Linux")[0].(Mountable).GetMountpoint()) + assert.Equal("/boot", pt.findDirectoryEntityPath("/boot/loader")[0].(Mountable).GetMountpoint()) + assert.Equal("/boot", pt.findDirectoryEntityPath("/boot")[0].(Mountable).GetMountpoint()) + + ptMod := pt.Clone().(*PartitionTable) + ptMod.Partitions = append(ptMod.Partitions, usr) + assert.Equal("/", ptMod.findDirectoryEntityPath("/opt")[0].(Mountable).GetMountpoint()) + assert.Equal("/usr", ptMod.findDirectoryEntityPath("/usr")[0].(Mountable).GetMountpoint()) + assert.Equal("/usr", ptMod.findDirectoryEntityPath("/usr/bin")[0].(Mountable).GetMountpoint()) + + // invalid dir should return nil + assert.Nil(pt.findDirectoryEntityPath("invalid")) + } + + { + pt := testPartitionTables["plain-noboot"] + assert.Equal("/", pt.findDirectoryEntityPath("/opt")[0].(Mountable).GetMountpoint()) + assert.Equal("/", pt.findDirectoryEntityPath("/boot")[0].(Mountable).GetMountpoint()) + assert.Equal("/", pt.findDirectoryEntityPath("/boot/loader")[0].(Mountable).GetMountpoint()) + + ptMod := pt.Clone().(*PartitionTable) + ptMod.Partitions = append(ptMod.Partitions, usr) + assert.Equal("/", ptMod.findDirectoryEntityPath("/opt")[0].(Mountable).GetMountpoint()) + assert.Equal("/usr", ptMod.findDirectoryEntityPath("/usr")[0].(Mountable).GetMountpoint()) + assert.Equal("/usr", ptMod.findDirectoryEntityPath("/usr/bin")[0].(Mountable).GetMountpoint()) + + // invalid dir should return nil + assert.Nil(pt.findDirectoryEntityPath("invalid")) + } + + { + pt := testPartitionTables["luks"] + assert.Equal("/", pt.findDirectoryEntityPath("/opt")[0].(Mountable).GetMountpoint()) + assert.Equal("/boot", pt.findDirectoryEntityPath("/boot")[0].(Mountable).GetMountpoint()) + assert.Equal("/boot", pt.findDirectoryEntityPath("/boot/loader")[0].(Mountable).GetMountpoint()) + + ptMod := pt.Clone().(*PartitionTable) + ptMod.Partitions = append(ptMod.Partitions, usr) + assert.Equal("/", ptMod.findDirectoryEntityPath("/opt")[0].(Mountable).GetMountpoint()) + assert.Equal("/usr", ptMod.findDirectoryEntityPath("/usr")[0].(Mountable).GetMountpoint()) + assert.Equal("/usr", ptMod.findDirectoryEntityPath("/usr/bin")[0].(Mountable).GetMountpoint()) + + // invalid dir should return nil + assert.Nil(pt.findDirectoryEntityPath("invalid")) + } + + { + pt := testPartitionTables["luks+lvm"] + assert.Equal("/", pt.findDirectoryEntityPath("/opt")[0].(Mountable).GetMountpoint()) + assert.Equal("/boot", pt.findDirectoryEntityPath("/boot")[0].(Mountable).GetMountpoint()) + assert.Equal("/boot", pt.findDirectoryEntityPath("/boot/loader")[0].(Mountable).GetMountpoint()) + + ptMod := pt.Clone().(*PartitionTable) + ptMod.Partitions = append(ptMod.Partitions, usr) + assert.Equal("/", ptMod.findDirectoryEntityPath("/opt")[0].(Mountable).GetMountpoint()) + assert.Equal("/usr", ptMod.findDirectoryEntityPath("/usr")[0].(Mountable).GetMountpoint()) + assert.Equal("/usr", ptMod.findDirectoryEntityPath("/usr/bin")[0].(Mountable).GetMountpoint()) + + // invalid dir should return nil + assert.Nil(pt.findDirectoryEntityPath("invalid")) + } + + { + pt := testPartitionTables["btrfs"] + assert.Equal("/", pt.findDirectoryEntityPath("/opt")[0].(Mountable).GetMountpoint()) + assert.Equal("/boot", pt.findDirectoryEntityPath("/boot")[0].(Mountable).GetMountpoint()) + assert.Equal("/boot", pt.findDirectoryEntityPath("/boot/loader")[0].(Mountable).GetMountpoint()) + + ptMod := pt.Clone().(*PartitionTable) + ptMod.Partitions = append(ptMod.Partitions, usr) + assert.Equal("/", ptMod.findDirectoryEntityPath("/opt")[0].(Mountable).GetMountpoint()) + assert.Equal("/usr", ptMod.findDirectoryEntityPath("/usr")[0].(Mountable).GetMountpoint()) + assert.Equal("/usr", ptMod.findDirectoryEntityPath("/usr/bin")[0].(Mountable).GetMountpoint()) + + // invalid dir should return nil + assert.Nil(pt.findDirectoryEntityPath("invalid")) + } + + { + pt := PartitionTable{} // pt with no root should return nil + assert.Nil(pt.findDirectoryEntityPath("/var")) + } +} diff --git a/internal/disk/partition_table.go b/internal/disk/partition_table.go index 88ad5f82a..9a94bcad2 100644 --- a/internal/disk/partition_table.go +++ b/internal/disk/partition_table.go @@ -3,6 +3,7 @@ package disk import ( "fmt" "math/rand" + "path/filepath" "github.com/google/uuid" "github.com/osbuild/osbuild-composer/internal/blueprint" @@ -161,6 +162,21 @@ func (pt *PartitionTable) EnsureSize(s uint64) bool { return false } +func (pt *PartitionTable) findDirectoryEntityPath(dir string) []Entity { + if path := entityPath(pt, dir); path != nil { + return path + } + + parent := filepath.Dir(dir) + if dir == parent { + // invalid dir or pt has no root + return nil + } + + // move up the directory path and check again + return pt.findDirectoryEntityPath(parent) +} + func (pt *PartitionTable) CreateMountpoint(mountpoint string, size uint64) (Entity, error) { filesystem := Filesystem{ Type: "xfs",