diff --git a/internal/cloud/awscloud/secure-instance.go b/internal/cloud/awscloud/secure-instance.go index eb6b069a7..6a66b7e11 100644 --- a/internal/cloud/awscloud/secure-instance.go +++ b/internal/cloud/awscloud/secure-instance.go @@ -78,6 +78,15 @@ func (a *AWS) RunSecureInstance(iamProfile, keyName, cloudWatchGroup, hostname s } }() + previousSI, err := a.terminatePreviousSI(identity.InstanceID) + if err != nil { + logrus.Errorf("Unable to terminate previous secure instance %s (parent instance ID: %s): %v", previousSI, identity.InstanceID, err) + return nil, fmt.Errorf("Unable to terminate previous secure instance %s (parent instance ID: %s): %v", previousSI, identity.InstanceID, err) + } + if previousSI != "" { + logrus.Warningf("Previous instance (%s) terminated by parent instance (%s)", previousSI, identity.InstanceID) + } + sgID, err := a.createOrReplaceSG(identity.InstanceID, identity.PrivateIP, vpcID) if sgID != "" { secureInstance.SGID = sgID @@ -194,6 +203,42 @@ func (a *AWS) TerminateSecureInstance(si *SecureInstance) error { return nil } +func (a *AWS) terminatePreviousSI(hostInstanceID string) (string, error) { + descrInstancesOutput, err := a.ec2.DescribeInstances(&ec2.DescribeInstancesInput{ + Filters: []*ec2.Filter{ + &ec2.Filter{ + Name: aws.String("tag:parent"), + Values: []*string{aws.String(hostInstanceID)}, + }, + }, + }) + if err != nil { + return "", err + } + if len(descrInstancesOutput.Reservations) == 0 || len(descrInstancesOutput.Reservations[0].Instances) == 0 { + return "", nil + } + + if *descrInstancesOutput.Reservations[0].Instances[0].State.Name == ec2.InstanceStateNameTerminated { + return "", nil + } + + instanceID := descrInstancesOutput.Reservations[0].Instances[0].InstanceId + _, err = a.ec2.TerminateInstances(&ec2.TerminateInstancesInput{ + InstanceIds: []*string{instanceID}, + }) + if err != nil { + return *instanceID, err + } + err = a.ec2.WaitUntilInstanceTerminated(&ec2.DescribeInstancesInput{ + InstanceIds: []*string{instanceID}, + }) + if err != nil { + return *instanceID, err + } + return *instanceID, nil +} + func isInvalidGroupNotFoundErr(err error) bool { if awsErr, ok := err.(awserr.Error); ok { if awsErr.Code() == "InvalidGroup.NotFound" {