diff --git a/internal/osbuild2/modprobe_stage.go b/internal/osbuild2/modprobe_stage.go index 5fac1dad0..caedbbcbe 100644 --- a/internal/osbuild2/modprobe_stage.go +++ b/internal/osbuild2/modprobe_stage.go @@ -3,8 +3,11 @@ package osbuild2 import ( "encoding/json" "fmt" + "regexp" ) +const modprobeCfgFilenameRegex = "^[\\w.-]{1,250}\\.conf$" + type ModprobeStageOptions struct { Filename string `json:"filename"` Commands ModprobeConfigCmdList `json:"commands"` @@ -12,7 +15,24 @@ type ModprobeStageOptions struct { func (ModprobeStageOptions) isStageOptions() {} +func (o ModprobeStageOptions) validate() error { + if len(o.Commands) == 0 { + return fmt.Errorf("at least one command is required") + } + + nameRegex := regexp.MustCompile(modprobeCfgFilenameRegex) + if !nameRegex.MatchString(o.Filename) { + return fmt.Errorf("modprobe configuration filename %q doesn't conform to schema (%s)", o.Filename, nameRegex.String()) + } + + return nil +} + func NewModprobeStage(options *ModprobeStageOptions) *Stage { + if err := options.validate(); err != nil { + panic(err) + } + return &Stage{ Type: "org.osbuild.modprobe", Options: options, @@ -75,14 +95,6 @@ func (configFile *ModprobeConfigCmdList) UnmarshalJSON(data []byte) error { return nil } -func (o ModprobeConfigCmdList) MarshalJSON() ([]byte, error) { - if len(o) == 0 { - return nil, fmt.Errorf("at least one modprobe command must be specified for a configuration file") - } - var configList []ModprobeConfigCmd = o - return json.Marshal(configList) -} - // ModprobeConfigCmdBlacklist represents the 'blacklist' command in the // modprobe configuration. type ModprobeConfigCmdBlacklist struct { @@ -92,13 +104,27 @@ type ModprobeConfigCmdBlacklist struct { func (ModprobeConfigCmdBlacklist) isModprobeConfigCmd() {} +func (c ModprobeConfigCmdBlacklist) validate() error { + if c.Command != "blacklist" { + return fmt.Errorf("'command' must have 'blacklist' value set") + } + if c.Modulename == "" { + return fmt.Errorf("'modulename' must not be empty") + } + return nil +} + // NewModprobeConfigCmdBlacklist creates a new instance of ModprobeConfigCmdBlacklist // for the provided modulename. func NewModprobeConfigCmdBlacklist(modulename string) *ModprobeConfigCmdBlacklist { - return &ModprobeConfigCmdBlacklist{ + cmd := &ModprobeConfigCmdBlacklist{ Command: "blacklist", Modulename: modulename, } + if err := cmd.validate(); err != nil { + panic(err) + } + return cmd } // ModprobeConfigCmdInstall represents the 'install' command in the @@ -111,12 +137,29 @@ type ModprobeConfigCmdInstall struct { func (ModprobeConfigCmdInstall) isModprobeConfigCmd() {} +func (c ModprobeConfigCmdInstall) validate() error { + if c.Command != "install" { + return fmt.Errorf("'command' must have 'install' value set") + } + if c.Modulename == "" { + return fmt.Errorf("'modulename' must not be empty") + } + if c.Cmdline == "" { + return fmt.Errorf("'cmdline' must not be empty") + } + return nil +} + // NewModprobeConfigCmdInstall creates a new instance of ModprobeConfigCmdInstall // for the provided modulename. func NewModprobeConfigCmdInstall(modulename, cmdline string) *ModprobeConfigCmdInstall { - return &ModprobeConfigCmdInstall{ + cmd := &ModprobeConfigCmdInstall{ Command: "install", Modulename: modulename, Cmdline: cmdline, } + if err := cmd.validate(); err != nil { + panic(err) + } + return cmd } diff --git a/internal/osbuild2/modprobe_stage_test.go b/internal/osbuild2/modprobe_stage_test.go index 18ca6aa27..576cb04b6 100644 --- a/internal/osbuild2/modprobe_stage_test.go +++ b/internal/osbuild2/modprobe_stage_test.go @@ -1,29 +1,36 @@ package osbuild2 import ( - "encoding/json" "testing" "github.com/stretchr/testify/assert" ) func TestNewModprobeStage(t *testing.T) { + stageOptions := &ModprobeStageOptions{ + Filename: "testing.conf", + Commands: ModprobeConfigCmdList{ + NewModprobeConfigCmdBlacklist("testing_module"), + }, + } expectedStage := &Stage{ Type: "org.osbuild.modprobe", - Options: &ModprobeStageOptions{}, + Options: stageOptions, } - actualStage := NewModprobeStage(&ModprobeStageOptions{}) + actualStage := NewModprobeStage(stageOptions) assert.Equal(t, expectedStage, actualStage) } -func TestModprobeStage_MarshalJSON_Invalid(t *testing.T) { +func TestModprobeStageOptionsValidate(t *testing.T) { tests := []struct { name string options ModprobeStageOptions + err bool }{ { name: "empty-options", options: ModprobeStageOptions{}, + err: true, }, { name: "no-commands", @@ -31,12 +38,110 @@ func TestModprobeStage_MarshalJSON_Invalid(t *testing.T) { Filename: "disallow-modules.conf", Commands: ModprobeConfigCmdList{}, }, + err: true, + }, + { + name: "no-filename", + options: ModprobeStageOptions{ + Commands: ModprobeConfigCmdList{NewModprobeConfigCmdBlacklist("module_name")}, + }, + err: true, + }, + { + name: "incorrect-filename", + options: ModprobeStageOptions{ + Filename: "disallow-modules.ccoonnff", + Commands: ModprobeConfigCmdList{NewModprobeConfigCmdBlacklist("module_name")}, + }, + err: true, + }, + { + name: "good-options", + options: ModprobeStageOptions{ + Filename: "disallow-modules.conf", + Commands: ModprobeConfigCmdList{NewModprobeConfigCmdBlacklist("module_name")}, + }, + err: false, }, } for idx, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotBytes, err := json.Marshal(tt.options) - assert.NotNilf(t, err, "json.Marshal() didn't return an error, but: %s [idx: %d]", string(gotBytes), idx) + if tt.err { + assert.Errorf(t, tt.options.validate(), "%q didn't return an error [idx: %d]", tt.name, idx) + assert.Panics(t, func() { NewModprobeStage(&tt.options) }) + } else { + assert.NoErrorf(t, tt.options.validate(), "%q returned an error [idx: %d]", tt.name, idx) + assert.NotPanics(t, func() { NewModprobeStage(&tt.options) }) + } + }) + } +} + +func TestNewModprobeConfigCmdBlacklist(t *testing.T) { + tests := []struct { + name string + modulename string + err bool + }{ + { + name: "empty-modulename", + modulename: "", + err: true, + }, + { + name: "non-empty-modulename", + modulename: "module_name", + err: false, + }, + } + for idx, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.err { + assert.Errorf(t, ModprobeConfigCmdBlacklist{Command: "blacklist", Modulename: tt.modulename}.validate(), "%q didn't return an error [idx: %d]", tt.name, idx) + assert.Panics(t, func() { NewModprobeConfigCmdBlacklist(tt.modulename) }) + } else { + assert.NoErrorf(t, ModprobeConfigCmdBlacklist{Command: "blacklist", Modulename: tt.modulename}.validate(), "%q returned an error [idx: %d]", tt.name, idx) + assert.NotPanics(t, func() { NewModprobeConfigCmdBlacklist(tt.modulename) }) + } + }) + } +} + +func TestNewModprobeConfigCmdInstall(t *testing.T) { + tests := []struct { + name string + modulename string + cmdline string + err bool + }{ + { + name: "empty-modulename", + modulename: "", + cmdline: "/usr/bin/true", + err: true, + }, + { + name: "empty-cmdline", + modulename: "module_name", + cmdline: "", + err: true, + }, + { + name: "non-empty-modulename", + modulename: "module_name", + cmdline: "/usr/bin/true", + err: false, + }, + } + for idx, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.err { + assert.Errorf(t, ModprobeConfigCmdInstall{Command: "install", Modulename: tt.modulename, Cmdline: tt.cmdline}.validate(), "%q didn't return an error [idx: %d]", tt.name, idx) + assert.Panics(t, func() { NewModprobeConfigCmdInstall(tt.modulename, tt.cmdline) }) + } else { + assert.NoErrorf(t, ModprobeConfigCmdInstall{Command: "install", Modulename: tt.modulename, Cmdline: tt.cmdline}.validate(), "%q returned an error [idx: %d]", tt.name, idx) + assert.NotPanics(t, func() { NewModprobeConfigCmdInstall(tt.modulename, tt.cmdline) }) + } }) } }