disk: guard against nil pointer dereferencing

Add guards protecting against dereferencing the pointer receiver of
method calls.

Co-Authored-By: Achilleas Koutsou <achilleas@koutsou.net>
This commit is contained in:
Christian Kellner 2022-02-19 13:28:40 +01:00 committed by Tom Gundersen
parent beaf411628
commit 206e030f2c
6 changed files with 51 additions and 7 deletions

View file

@ -105,6 +105,9 @@ func (bs *BtrfsSubvolume) Clone() Entity {
}
func (bs *BtrfsSubvolume) GetSize() uint64 {
if bs == nil {
return 0
}
return bs.Size
}
@ -117,6 +120,9 @@ func (bs *BtrfsSubvolume) EnsureSize(s uint64) bool {
}
func (bs *BtrfsSubvolume) GetMountpoint() string {
if bs == nil {
return ""
}
return bs.Mountpoint
}
@ -125,6 +131,9 @@ func (bs *BtrfsSubvolume) GetFSType() string {
}
func (bs *BtrfsSubvolume) GetFSSpec() FSSpec {
if bs == nil {
return FSSpec{}
}
return FSSpec{
UUID: bs.UUID,
Label: bs.Name,
@ -132,8 +141,11 @@ func (bs *BtrfsSubvolume) GetFSSpec() FSSpec {
}
func (bs *BtrfsSubvolume) GetFSTabOptions() FSTabOptions {
ops := strings.Join([]string{bs.MntOps, fmt.Sprintf("subvol=%s", bs.Name)}, ",")
if bs == nil {
return FSTabOptions{}
}
ops := strings.Join([]string{bs.MntOps, fmt.Sprintf("subvol=%s", bs.Name)}, ",")
return FSTabOptions{
MntOps: ops,
Freq: 0,

View file

@ -24,8 +24,6 @@ import (
"github.com/google/uuid"
)
// TODO: guard against nil dereferencing in pointer methods
const (
// Default sector size in bytes
DefaultSectorSize = 512

View file

@ -44,14 +44,23 @@ func (fs *Filesystem) Clone() Entity {
}
func (fs *Filesystem) GetMountpoint() string {
if fs == nil {
return ""
}
return fs.Mountpoint
}
func (fs *Filesystem) GetFSType() string {
if fs == nil {
return ""
}
return fs.Type
}
func (fs *Filesystem) GetFSSpec() FSSpec {
if fs == nil {
return FSSpec{}
}
return FSSpec{
UUID: fs.UUID,
Label: fs.Label,
@ -59,6 +68,9 @@ func (fs *Filesystem) GetFSSpec() FSSpec {
}
func (fs *Filesystem) GetFSTabOptions() FSTabOptions {
if fs == nil {
return FSTabOptions{}
}
return FSTabOptions{
MntOps: fs.FSTabOptions,
Freq: fs.FSTabFreq,

View file

@ -66,6 +66,10 @@ func (lc *LUKSContainer) Clone() Entity {
}
func (lc *LUKSContainer) GenUUID(rng *rand.Rand) {
if lc == nil {
return
}
if lc.UUID == "" {
lc.UUID = uuid.Must(newRandomUUIDFromReader(rng)).String()
}

View file

@ -41,14 +41,23 @@ func (vg *LVMVolumeGroup) Clone() Entity {
}
func (vg *LVMVolumeGroup) GetItemCount() uint {
if vg == nil {
return 0
}
return uint(len(vg.LogicalVolumes))
}
func (vg *LVMVolumeGroup) GetChild(n uint) Entity {
if vg == nil {
panic("LVMVolumeGroup.GetChild: nil entity")
}
return &vg.LogicalVolumes[n]
}
func (vg *LVMVolumeGroup) CreateVolume(mountpoint string, size uint64) (Entity, error) {
if vg == nil {
panic("LVMVolumeGroup.CreateVolume: nil entity")
}
filesystem := Filesystem{
Type: "xfs",
Mountpoint: mountpoint,
@ -78,6 +87,9 @@ func (lv *LVMLogicalVolume) IsContainer() bool {
}
func (lv *LVMLogicalVolume) Clone() Entity {
if lv == nil {
return nil
}
return &LVMLogicalVolume{
Name: lv.Name,
Size: lv.Size,
@ -86,24 +98,30 @@ func (lv *LVMLogicalVolume) Clone() Entity {
}
func (lv *LVMLogicalVolume) GetItemCount() uint {
if lv.Payload == nil {
if lv == nil || lv.Payload == nil {
return 0
}
return 1
}
func (lv *LVMLogicalVolume) GetChild(n uint) Entity {
if n != 0 {
if n != 0 || lv == nil {
panic(fmt.Sprintf("invalid child index for LVMLogicalVolume: %d != 0", n))
}
return lv.Payload
}
func (lv *LVMLogicalVolume) GetSize() uint64 {
if lv == nil {
return 0
}
return lv.Size
}
func (lv *LVMLogicalVolume) EnsureSize(s uint64) bool {
if lv == nil {
panic("LVMLogicalVolume.EnsureSize: nil entity")
}
if s > lv.Size {
lv.Size = s
return true

View file

@ -24,7 +24,7 @@ func (p *Partition) IsContainer() bool {
func (p *Partition) Clone() Entity {
if p == nil {
return p
return nil
}
ent := p.Payload.Clone()
@ -48,7 +48,7 @@ func (p *Partition) Clone() Entity {
}
func (pt *Partition) GetItemCount() uint {
if pt.Payload == nil {
if pt == nil || pt.Payload == nil {
return 0
}
return 1