blueprint: add CheckAllowed helper

New helper to check if a blueprint containts only a set of allowed
customization. If not an error is returned.
This commit is contained in:
Christian Kellner 2021-08-25 19:28:44 +02:00 committed by Tom Gundersen
parent 36084fba20
commit 6d08418107
2 changed files with 94 additions and 0 deletions

View file

@ -1,5 +1,10 @@
package blueprint
import (
"fmt"
"reflect"
)
type Customizations struct {
Hostname *string `json:"hostname,omitempty" toml:"hostname,omitempty"`
Kernel *KernelCustomization `json:"kernel,omitempty" toml:"kernel,omitempty"`
@ -78,6 +83,53 @@ func (e *CustomizationError) Error() string {
return e.Message
}
//CheckCustomizations returns an error of type `CustomizationError`
//if `c` has any customizations not specified in `allowed`
func (c *Customizations) CheckAllowed(allowed ...string) error {
if c == nil {
return nil
}
allowMap := make(map[string]bool)
for _, a := range allowed {
allowMap[a] = true
}
t := reflect.TypeOf(*c)
v := reflect.ValueOf(*c)
for i := 0; i < t.NumField(); i++ {
empty := false
field := v.Field(i)
switch field.Kind() {
case reflect.String:
if field.String() == "" {
empty = true
}
case reflect.Array, reflect.Slice:
if field.Len() == 0 {
empty = true
}
case reflect.Ptr:
if field.IsNil() {
empty = true
}
default:
panic(fmt.Sprintf("unhandled customization field type %s, %s", v.Kind(), t.Field(i).Name))
}
if !empty && !allowMap[t.Field(i).Name] {
return &CustomizationError{fmt.Sprintf("'%s' is not allowed", t.Field(i).Name)}
}
}
return nil
}
func (c *Customizations) GetHostname() *string {
if c == nil {
return nil

View file

@ -6,6 +6,48 @@ import (
"github.com/stretchr/testify/assert"
)
func TestCheckAllowed(t *testing.T) {
Desc := "Test descritpion"
Pass := "testpass"
Key := "testkey"
Home := "Home"
Shell := "Shell"
Groups := []string{
"Group",
}
UID := 123
GID := 321
expectedUsers := []UserCustomization{
UserCustomization{
Name: "John",
Description: &Desc,
Password: &Pass,
Key: &Key,
Home: &Home,
Shell: &Shell,
Groups: Groups,
UID: &UID,
GID: &GID,
},
}
var expectedHostname = "Hostname"
x := Customizations{Hostname: &expectedHostname, User: expectedUsers}
err := x.CheckAllowed("Hostname", "User")
assert.NoError(t, err)
// "User" not allowed anymore
err = x.CheckAllowed("Hostname")
assert.Error(t, err)
// "Hostname" not allowed anymore
err = x.CheckAllowed("User")
assert.Error(t, err)
}
func TestGetHostname(t *testing.T) {
var expectedHostname = "Hostname"