diff --git a/internal/cloud/awscloud/mocks_test.go b/internal/cloud/awscloud/mocks_test.go
index ba0035ccd..3631d89e3 100644
--- a/internal/cloud/awscloud/mocks_test.go
+++ b/internal/cloud/awscloud/mocks_test.go
@@ -101,6 +101,7 @@ type ec2mock struct {
 	snapshotId string
 
 	calledFn map[string]int
+	failFn   map[string]error // forced failures, keyed by mocked method name
 }
 
 func newEc2Mock(t *testing.T) *ec2mock {
@@ -110,6 +111,7 @@ func newEc2Mock(t *testing.T) *ec2mock {
 		imageName:  "image-name",
 		snapshotId: "snapshot-id",
 		calledFn:   make(map[string]int),
+		failFn:     make(map[string]error),
 	}
 }
 
@@ -136,6 +138,11 @@ func (m *ec2mock) AuthorizeSecurityGroupIngress(ctx context.Context, input *ec2.
 
 func (m *ec2mock) CreateSecurityGroup(ctx context.Context, input *ec2.CreateSecurityGroupInput, optfns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error) {
 	m.calledFn["CreateSecurityGroup"] += 1
+
+	if err, ok := m.failFn["CreateSecurityGroup"]; ok {
+		return nil, err
+	}
+
 	return &ec2.CreateSecurityGroupOutput{
 		GroupId: aws.String("sg-id"),
 	}, nil
@@ -186,6 +193,11 @@ func (m *ec2mock) DescribeSubnets(ctx context.Context, input *ec2.DescribeSubnet
 
 func (m *ec2mock) CreateLaunchTemplate(ctx context.Context, input *ec2.CreateLaunchTemplateInput, optfns ...func(*ec2.Options)) (*ec2.CreateLaunchTemplateOutput, error) {
 	m.calledFn["CreateLaunchTemplate"] += 1
+
+	if err, ok := m.failFn["CreateLaunchTemplate"]; ok {
+		return nil, err
+	}
+
 	return &ec2.CreateLaunchTemplateOutput{
 		LaunchTemplate: &ec2types.LaunchTemplate{
 			LaunchTemplateId: aws.String("lt-id"),
@@ -273,6 +285,24 @@ func (m *ec2mock) TerminateInstances(ctx context.Context, input *ec2.TerminateIn
 
 func (m *ec2mock) CreateFleet(ctx context.Context, input *ec2.CreateFleetInput, optfns ...func(*ec2.Options)) (*ec2.CreateFleetOutput, error) {
 	m.calledFn["CreateFleet"] += 1
+
+	// A failFn entry forces a failure. A non-nil error is returned as-is; a
+	// nil entry makes the call itself succeed but report UnfillableCapacity
+	// in the output's Errors.
+	if err, ok := m.failFn["CreateFleet"]; ok {
+		if err != nil {
+			return nil, err
+		}
+		return &ec2.CreateFleetOutput{
+			Errors: []ec2types.CreateFleetError{
+				{
+					ErrorCode:    aws.String("UnfillableCapacity"),
+					ErrorMessage: aws.String("Msg"),
+				},
+			},
+		}, nil
+	}
+
 	return &ec2.CreateFleetOutput{
 		FleetId: aws.String("fleet-id"),
 		Instances: []ec2types.CreateFleetInstance{
diff --git a/internal/cloud/awscloud/secure-instance.go b/internal/cloud/awscloud/secure-instance.go
index 034c1b677..437b021a1 100644
--- a/internal/cloud/awscloud/secure-instance.go
+++ b/internal/cloud/awscloud/secure-instance.go
@@ -547,7 +547,9 @@ func (a *AWS) createFleet(input *ec2.CreateFleetInput) (*ec2.CreateFleetOutput,
 		return nil, fmt.Errorf("Unable to create spot fleet: %w", err)
 	}
 
-	if len(createFleetOutput.Errors) > 0 && createFleetOutput.Errors[0].ErrorCode == aws.String("UnfillableCapacity") {
+	// ErrorCode is a *string: compare the string it points to, not the pointer.
+	// aws.ToString maps a nil pointer to "", so a missing code cannot panic.
+	if len(createFleetOutput.Errors) > 0 && aws.ToString(createFleetOutput.Errors[0].ErrorCode) == "UnfillableCapacity" {
 		logrus.Warn("Received UnfillableCapacity from CreateFleet, retrying CreateFleet with OnDemand instance")
 		input.SpotOptions = nil
 		createFleetOutput, err = a.ec2.CreateFleet(context.Background(), input)
diff --git a/internal/cloud/awscloud/secure-instance_test.go b/internal/cloud/awscloud/secure-instance_test.go
index 824335c3c..6d8fc6338 100644
--- a/internal/cloud/awscloud/secure-instance_test.go
+++ b/internal/cloud/awscloud/secure-instance_test.go
@@ -104,3 +104,69 @@ func TestSITerminateSecureInstance(t *testing.T) {
 	require.Equal(t, 1, m.calledFn["DeleteLaunchTemplate"])
 	require.Equal(t, 2, m.calledFn["DescribeInstances"])
 }
+
+// TestSICreateSGFailures checks that RunSecureInstance returns an error when
+// CreateSecurityGroup fails, without creating a launch template or a fleet.
+func TestSICreateSGFailures(t *testing.T) {
+	m := newEc2Mock(t)
+	aws := awscloud.NewForTest(m, &ec2imdsmock{t, "instance-id", "region1"}, nil, nil, nil)
+	require.NotNil(t, aws)
+
+	m.failFn["CreateSecurityGroup"] = fmt.Errorf("some-error")
+	si, err := aws.RunSecureInstance("iam-profile", "key-name", "cw-group", "hostname")
+	require.Error(t, err)
+	require.Nil(t, si)
+	require.Equal(t, 1, m.calledFn["CreateSecurityGroup"])
+	require.Equal(t, 1, m.calledFn["DeleteSecurityGroup"])
+	require.Equal(t, 0, m.calledFn["CreateFleet"])
+	require.Equal(t, 0, m.calledFn["CreateLaunchTemplate"])
+	require.Equal(t, 0, m.calledFn["DeleteLaunchTemplate"])
+}
+
+// TestSICreateLTFailures checks that RunSecureInstance returns an error when
+// CreateLaunchTemplate fails, without ever calling CreateFleet.
+func TestSICreateLTFailures(t *testing.T) {
+	m := newEc2Mock(t)
+	aws := awscloud.NewForTest(m, &ec2imdsmock{t, "instance-id", "region1"}, nil, nil, nil)
+	require.NotNil(t, aws)
+
+	m.failFn["CreateLaunchTemplate"] = fmt.Errorf("some-error")
+	si, err := aws.RunSecureInstance("iam-profile", "key-name", "cw-group", "hostname")
+	require.Error(t, err)
+	require.Nil(t, si)
+	require.Equal(t, 1, m.calledFn["CreateSecurityGroup"])
+	require.Equal(t, 2, m.calledFn["DeleteSecurityGroup"])
+	require.Equal(t, 1, m.calledFn["CreateLaunchTemplate"])
+	require.Equal(t, 1, m.calledFn["DeleteLaunchTemplate"])
+	require.Equal(t, 0, m.calledFn["CreateFleet"])
+}
+
+// TestSICreateFleetFailures checks both CreateFleet failure modes: an
+// UnfillableCapacity result is retried once, any other error is not retried.
+func TestSICreateFleetFailures(t *testing.T) {
+	m := newEc2Mock(t)
+	aws := awscloud.NewForTest(m, &ec2imdsmock{t, "instance-id", "region1"}, nil, nil, nil)
+	require.NotNil(t, aws)
+
+	// unfillable capacity should call create fleet twice
+	m.failFn["CreateFleet"] = nil
+	si, err := aws.RunSecureInstance("iam-profile", "key-name", "cw-group", "hostname")
+	require.Error(t, err)
+	require.Nil(t, si)
+	require.Equal(t, 2, m.calledFn["CreateFleet"])
+	require.Equal(t, 1, m.calledFn["CreateSecurityGroup"])
+	require.Equal(t, 1, m.calledFn["CreateLaunchTemplate"])
+	require.Equal(t, 2, m.calledFn["DeleteSecurityGroup"])
+	require.Equal(t, 2, m.calledFn["DeleteLaunchTemplate"])
+
+	// other errors should just fail immediately; counters below are cumulative
+	m.failFn["CreateFleet"] = fmt.Errorf("random error")
+	si, err = aws.RunSecureInstance("iam-profile", "key-name", "cw-group", "hostname")
+	require.Error(t, err)
+	require.Nil(t, si)
+	require.Equal(t, 3, m.calledFn["CreateFleet"])
+	require.Equal(t, 2, m.calledFn["CreateSecurityGroup"])
+	require.Equal(t, 2, m.calledFn["CreateLaunchTemplate"])
+	require.Equal(t, 4, m.calledFn["DeleteSecurityGroup"])
+	require.Equal(t, 4, m.calledFn["DeleteLaunchTemplate"])
+}