From 221cdedebc9cf60deece0b7ad512f846d6909551 Mon Sep 17 00:00:00 2001 From: Achilleas Koutsou Date: Tue, 8 Feb 2022 12:22:46 +0100 Subject: [PATCH] disk: Entity types must implement Clone() All disk.Entitity types now implement Clone() which should return a deep copy of the same object. Add the Clone() method to the entity interface. The return type is Entity, but callers can assume it's safe to convert back to the original type. Co-Authored-By: Christian Kellner --- internal/disk/btrfs.go | 39 ++++++++++++++++++++++++ internal/disk/customizations.go | 5 ++- internal/disk/disk.go | 3 ++ internal/disk/filesystem.go | 2 +- internal/disk/luks.go | 21 +++++++++++++ internal/disk/lvm.go | 35 +++++++++++++++++++++ internal/disk/partition.go | 25 +++++++++++++++ internal/disk/partition_table.go | 52 ++++++++++++++++++-------------- 8 files changed, 158 insertions(+), 24 deletions(-) diff --git a/internal/disk/btrfs.go b/internal/disk/btrfs.go index bf5c33637..44d5b2ea0 100644 --- a/internal/disk/btrfs.go +++ b/internal/disk/btrfs.go @@ -16,6 +16,30 @@ func (b *Btrfs) IsContainer() bool { return true } +func (b *Btrfs) Clone() Entity { + if b == nil { + return nil + } + + clone := &Btrfs{ + UUID: b.UUID, + Label: b.Label, + Mountpoint: b.Mountpoint, + Subvolumes: make([]BtrfsSubvolume, len(b.Subvolumes)), + } + + for idx, subvol := range b.Subvolumes { + entClone := subvol.Clone() + svClone, cloneOk := entClone.(*BtrfsSubvolume) + if !cloneOk { + panic("BtrfsSubvolume.Clone() returned an Entity that cannot be converted to *BtrfsSubvolume; this is a programming error") + } + clone.Subvolumes[idx] = *svClone + } + + return clone +} + func (b *Btrfs) GetItemCount() uint { return uint(len(b.Subvolumes)) } @@ -56,6 +80,21 @@ func (subvol *BtrfsSubvolume) IsContainer() bool { return false } +func (bs *BtrfsSubvolume) Clone() Entity { + if bs == nil { + return nil + } + + return &BtrfsSubvolume{ + Name: bs.Name, + Size: bs.Size, + Mountpoint: bs.Mountpoint, + GroupID: bs.GroupID, + MntOps: bs.MntOps, + UUID: bs.UUID, + } +} + func (bs *BtrfsSubvolume) GetSize() uint64 { return bs.Size } diff --git a/internal/disk/customizations.go b/internal/disk/customizations.go index abb892805..c5c56c96e 100644 --- a/internal/disk/customizations.go +++ b/internal/disk/customizations.go @@ -30,7 +30,10 @@ func CreatePartitionTable( // we are modifying the contents of the base partition table, // including the file systems, which are shared among shallow // copies of the partition table, so make a copy first - table := basePartitionTable.Clone() + table, cloneOk := basePartitionTable.Clone().(*PartitionTable) + if !cloneOk { + panic("PartitionTable.Clone() returned an Entity that cannot be converted to *PartitionTable; this is a programming error") + } for _, m := range mountpoints { // if we already have a partition ensure that the diff --git a/internal/disk/disk.go b/internal/disk/disk.go index 19c6335f1..ff5bcafd1 100644 --- a/internal/disk/disk.go +++ b/internal/disk/disk.go @@ -35,6 +35,9 @@ type Entity interface { // IsContainer indicates if the implementing type can // contain any other entities. IsContainer() bool + + // Clone returns a deep copy of the entity. + Clone() Entity } // Container is the interface for entities that can contain other entities. diff --git a/internal/disk/filesystem.go b/internal/disk/filesystem.go index 3f527ba9f..97f5f4bc7 100644 --- a/internal/disk/filesystem.go +++ b/internal/disk/filesystem.go @@ -23,7 +23,7 @@ func (fs *Filesystem) IsContainer() bool { } // Clone the filesystem structure -func (fs *Filesystem) Clone() *Filesystem { +func (fs *Filesystem) Clone() Entity { if fs == nil { return nil } diff --git a/internal/disk/luks.go b/internal/disk/luks.go index 37e6b54e0..194db94ca 100644 --- a/internal/disk/luks.go +++ b/internal/disk/luks.go @@ -38,3 +38,24 @@ func (lc *LUKSContainer) GetChild(n uint) Entity { } return lc.Payload } + +func (lc *LUKSContainer) Clone() Entity { + if lc == nil { + return nil + } + + return &LUKSContainer{ + Passphrase: lc.Passphrase, + UUID: lc.UUID, + Cipher: lc.Cipher, + Label: lc.Label, + Subsystem: lc.Subsystem, + SectorSize: lc.SectorSize, + PBKDF: Argon2id{ + Iterations: lc.PBKDF.Iterations, + Memory: lc.PBKDF.Memory, + Parallelism: lc.PBKDF.Parallelism, + }, + Payload: lc.Payload, + } +} diff --git a/internal/disk/lvm.go b/internal/disk/lvm.go index 99c9223e0..526d2de28 100644 --- a/internal/disk/lvm.go +++ b/internal/disk/lvm.go @@ -13,6 +13,33 @@ func (vg *LVMVolumeGroup) IsContainer() bool { return true } +func (vg *LVMVolumeGroup) Clone() Entity { + if vg == nil { + return nil + } + + clone := &LVMVolumeGroup{ + Name: vg.Name, + Description: vg.Description, + LogicalVolumes: make([]LVMLogicalVolume, len(vg.LogicalVolumes)), + } + + for idx, lv := range vg.LogicalVolumes { + ent := lv.Clone() + var lv *LVMLogicalVolume + if ent != nil { + lvEnt, cloneOk := ent.(*LVMLogicalVolume) + if !cloneOk { + panic("LVMLogicalVolume.Clone() returned an Entity that cannot be converted to *LVMLogicalVolume; this is a programming error") + } + lv = lvEnt + } + clone.LogicalVolumes[idx] = *lv + } + + return clone +} + func (vg *LVMVolumeGroup) GetItemCount() uint { return uint(len(vg.LogicalVolumes)) } @@ -50,6 +77,14 @@ func (lv *LVMLogicalVolume) IsContainer() bool { return true } +func (lv *LVMLogicalVolume) Clone() Entity { + return &LVMLogicalVolume{ + Name: lv.Name, + Size: lv.Size, + Payload: lv.Payload, + } +} + func (lv *LVMLogicalVolume) GetItemCount() uint { if lv.Payload == nil { return 0 diff --git a/internal/disk/partition.go b/internal/disk/partition.go index 6467293cb..d02f81dca 100644 --- a/internal/disk/partition.go +++ b/internal/disk/partition.go @@ -22,6 +22,31 @@ func (p *Partition) IsContainer() bool { return true } +func (p *Partition) Clone() Entity { + if p == nil { + return p + } + + ent := p.Filesystem.Clone() + var fs *Filesystem + if ent != nil { + fsEnt, cloneOk := ent.(*Filesystem) + if !cloneOk { + panic("Filesystem.Clone() returned an Entity that cannot be converted to *Filesystem; this is a programming error") + } + fs = fsEnt + } + + return &Partition{ + Start: p.Start, + Size: p.Size, + Type: p.Type, + Bootable: p.Bootable, + UUID: p.UUID, + Filesystem: fs, + } +} + // Converts Partition to osbuild.QEMUPartition that encodes the same partition. func (p *Partition) QEMUPartition() osbuild.QEMUPartition { var fs *osbuild.QEMUFilesystem diff --git a/internal/disk/partition_table.go b/internal/disk/partition_table.go index a7cf63431..9b5fc65d1 100644 --- a/internal/disk/partition_table.go +++ b/internal/disk/partition_table.go @@ -25,6 +25,36 @@ func (pt *PartitionTable) IsContainer() bool { return true } +func (pt *PartitionTable) Clone() Entity { + if pt == nil { + return nil + } + + clone := &PartitionTable{ + Size: pt.Size, + UUID: pt.UUID, + Type: pt.Type, + Partitions: make([]Partition, len(pt.Partitions)), + SectorSize: pt.SectorSize, + ExtraPadding: pt.ExtraPadding, + } + + for idx, partition := range pt.Partitions { + ent := partition.Clone() + var part *Partition + + if ent != nil { + pEnt, cloneOk := ent.(*Partition) + if !cloneOk { + panic("PartitionTable.Clone() returned an Entity that cannot be converted to *PartitionTable; this is a programming error") + } + part = pEnt + } + clone.Partitions[idx] = *part + } + return clone +} + // AlignUp will align the given bytes to next aligned grain if not already // aligned func (pt *PartitionTable) AlignUp(size uint64) uint64 { @@ -54,28 +84,6 @@ func (pt *PartitionTable) SectorsToBytes(size uint64) uint64 { return size * sectorSize } -// Clone the partition table (deep copy). -func (pt *PartitionTable) Clone() *PartitionTable { - if pt == nil { - return nil - } - - var partitions []Partition - for _, p := range pt.Partitions { - p.Filesystem = p.Filesystem.Clone() - partitions = append(partitions, p) - } - return &PartitionTable{ - Size: pt.Size, - UUID: pt.UUID, - Type: pt.Type, - Partitions: partitions, - - SectorSize: pt.SectorSize, - ExtraPadding: pt.ExtraPadding, - } -} - // Converts PartitionTable to osbuild.QEMUAssemblerOptions that encode // the same partition table. func (pt *PartitionTable) QEMUAssemblerOptions() osbuild.QEMUAssemblerOptions {