diff --git a/internal/blueprint/customizations.go b/internal/blueprint/customizations.go index e6b00db5e..371aba8f2 100644 --- a/internal/blueprint/customizations.go +++ b/internal/blueprint/customizations.go @@ -1,5 +1,10 @@ package blueprint +import ( + "fmt" + "reflect" +) + type Customizations struct { Hostname *string `json:"hostname,omitempty" toml:"hostname,omitempty"` Kernel *KernelCustomization `json:"kernel,omitempty" toml:"kernel,omitempty"` @@ -78,6 +83,53 @@ func (e *CustomizationError) Error() string { return e.Message } +//CheckCustomizations returns an error of type `CustomizationError` +//if `c` has any customizations not specified in `allowed` +func (c *Customizations) CheckAllowed(allowed ...string) error { + if c == nil { + return nil + } + + allowMap := make(map[string]bool) + + for _, a := range allowed { + allowMap[a] = true + } + + t := reflect.TypeOf(*c) + v := reflect.ValueOf(*c) + + for i := 0; i < t.NumField(); i++ { + + empty := false + field := v.Field(i) + + switch field.Kind() { + case reflect.String: + if field.String() == "" { + empty = true + } + case reflect.Array, reflect.Slice: + if field.Len() == 0 { + empty = true + } + case reflect.Ptr: + if field.IsNil() { + empty = true + } + default: + panic(fmt.Sprintf("unhandled customization field type %s, %s", v.Kind(), t.Field(i).Name)) + + } + + if !empty && !allowMap[t.Field(i).Name] { + return &CustomizationError{fmt.Sprintf("'%s' is not allowed", t.Field(i).Name)} + } + } + + return nil +} + func (c *Customizations) GetHostname() *string { if c == nil { return nil diff --git a/internal/blueprint/customizations_test.go b/internal/blueprint/customizations_test.go index 6bc3f8907..8ab7b91f1 100644 --- a/internal/blueprint/customizations_test.go +++ b/internal/blueprint/customizations_test.go @@ -6,6 +6,48 @@ import ( "github.com/stretchr/testify/assert" ) +func TestCheckAllowed(t *testing.T) { + Desc := "Test descritpion" + Pass := "testpass" + Key := "testkey" + Home := "Home" + Shell := "Shell" + Groups := []string{ + "Group", + } + UID := 123 + GID := 321 + + expectedUsers := []UserCustomization{ + UserCustomization{ + Name: "John", + Description: &Desc, + Password: &Pass, + Key: &Key, + Home: &Home, + Shell: &Shell, + Groups: Groups, + UID: &UID, + GID: &GID, + }, + } + + var expectedHostname = "Hostname" + + x := Customizations{Hostname: &expectedHostname, User: expectedUsers} + + err := x.CheckAllowed("Hostname", "User") + assert.NoError(t, err) + + // "User" not allowed anymore + err = x.CheckAllowed("Hostname") + assert.Error(t, err) + + // "Hostname" not allowed anymore + err = x.CheckAllowed("User") + assert.Error(t, err) +} + func TestGetHostname(t *testing.T) { var expectedHostname = "Hostname"