internal/cloud: add AWS client libary implementation
This creates an AWS client libary which contains a subset of the
AWS client library. This package provides a fake implementation of
the client library facilitating testing throughout the repository.
Updates golang/go#36841
Change-Id: Id91cd778ee794a6e38a3273c250505a97c6a0b02
Reviewed-on: https://go-review.googlesource.com/c/build/+/236299
Run-TryBot: Carlos Amedee <carlos@golang.org>
Reviewed-by: Alexander Rakoczy <alex@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/internal/cloud/aws.go b/internal/cloud/aws.go
new file mode 100644
index 0000000..2cdca87
--- /dev/null
+++ b/internal/cloud/aws.go
@@ -0,0 +1,307 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cloud
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log"
+ "time"
+
+ "github.com/aws/aws-sdk-go/aws"
+ "github.com/aws/aws-sdk-go/aws/credentials"
+ "github.com/aws/aws-sdk-go/aws/request"
+ "github.com/aws/aws-sdk-go/aws/session"
+ "github.com/aws/aws-sdk-go/service/ec2"
+)
+
+const (
+ // tagName denotes the text used for Name tags.
+ tagName = "Name"
+ // tagDescription denotes the text used for Description tags.
+ tagDescription = "Description"
+)
+
+// vmClient defines the interface used to call the backing EC2 service. This is a partial interface
+// based on the EC2 package defined at `github.com/aws/aws-sdk-go/service/ec2`.
+type vmClient interface {
+ DescribeInstancesPagesWithContext(context.Context, *ec2.DescribeInstancesInput, func(*ec2.DescribeInstancesOutput, bool) bool, ...request.Option) error
+ DescribeInstancesWithContext(context.Context, *ec2.DescribeInstancesInput, ...request.Option) (*ec2.DescribeInstancesOutput, error)
+ RunInstancesWithContext(context.Context, *ec2.RunInstancesInput, ...request.Option) (*ec2.Reservation, error)
+ TerminateInstancesWithContext(context.Context, *ec2.TerminateInstancesInput, ...request.Option) (*ec2.TerminateInstancesOutput, error)
+ WaitUntilInstanceRunningWithContext(context.Context, *ec2.DescribeInstancesInput, ...request.WaiterOption) error
+}
+
+// EC2VMConfiguration is the configuration needed for an EC2 instance.
+type EC2VMConfiguration struct {
+ // Description is a user defined description of the instance. It is displayed
+ // on the AWS UI. It is an optional field.
+ Description string
+ // ImageID is the ID of the image used to launch the instance. It is a required field.
+ ImageID string
+ // Name is a user defined name for the instance. It is displayed on the AWS UI. It is
+ // is an optional field.
+ Name string
+ // SSHKeyID is the name of the SSH key pair to use for access. It is a required field.
+ SSHKeyID string
+ // SecurityGroups contains the names of the security groups to be applied to the VM. If none
+ // are provided the default security group will be used.
+ SecurityGroups []string
+ // Tags the tags to apply to the resources during launch.
+ Tags map[string]string
+ // Type is the type of instance.
+ Type string
+ // UserData is the user data to make available to the instance. This data is available
+ // on the VM via the metadata endpoints. It must be a base64-encoded string. User
+ // data is limited to 16 KB.
+ UserData string
+ // Zone the Availability Zone of the instance.
+ Zone string
+}
+
+// Instance is a virtual machine.
+type Instance struct {
+ // CPUCount is the number of VCPUs the instance is configured with.
+ CPUCount int64
+ // CreatedAt is the time when the instance was launched.
+ CreatedAt time.Time
+ // Description is a user defined descripton of the instance.
+ Description string
+ // ID is the instance ID.
+ ID string
+ // IPAddressExternal is the public IPv4 address assigned to the instance.
+ IPAddressExternal string
+ // IPAddressInternal is the private IPv4 address assigned to the instance.
+ IPAddressInternal string
+ // ImageID is The ID of the AMI(image) used to launch the instance.
+ ImageID string
+ // Name is a user defined name for the instance instance.
+ Name string
+ // SSHKeyID is the name of the SSH key pair to use for access. It is a required field.
+ SSHKeyID string
+ // SecurityGroups is the security groups for the instance.
+ SecurityGroups []string
+ // State contains the state of the instance.
+ State string
+ // Tags contains tags assigned to the instance.
+ Tags map[string]string
+ // Type is the name of instance type.
+ Type string
+ // Zone is the availability zone where the instance is deployed.
+ Zone string
+}
+
+// AWSClient is a client for AWS services.
+type AWSClient struct {
+ ec2Client vmClient
+}
+
+// NewAWSClient creates a new AWS client.
+func NewAWSClient(region, keyID, accessKey string) (*AWSClient, error) {
+ s, err := session.NewSession(&aws.Config{
+ Region: aws.String(region),
+ Credentials: credentials.NewStaticCredentials(keyID, accessKey, ""), // Token is only required for STS
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to create AWS session: %v", err)
+ }
+ return &AWSClient{
+ ec2Client: ec2.New(s),
+ }, nil
+}
+
+// Instance retrieves an EC2 instance by instance ID.
+func (ac *AWSClient) Instance(ctx context.Context, instID string) (*Instance, error) {
+ dio, err := ac.ec2Client.DescribeInstancesWithContext(ctx, &ec2.DescribeInstancesInput{
+ InstanceIds: []*string{aws.String(instID)},
+ })
+ if err != nil {
+ return nil, fmt.Errorf("unable to retrieve instance %q information: %w", instID, err)
+ }
+
+ if dio == nil || len(dio.Reservations) != 1 || len(dio.Reservations[0].Instances) != 1 {
+ return nil, errors.New("describe instances output does not contain a valid instance")
+ }
+ ec2Inst := dio.Reservations[0].Instances[0]
+ return ec2ToInstance(ec2Inst), err
+}
+
+// RunningInstances retrieves all EC2 instances in a region which have not been terminated or stopped.
+func (ac *AWSClient) RunningInstances(ctx context.Context) ([]*Instance, error) {
+ instances := make([]*Instance, 0)
+
+ fn := func(page *ec2.DescribeInstancesOutput, lastPage bool) bool {
+ for _, res := range page.Reservations {
+ for _, inst := range res.Instances {
+ instances = append(instances, ec2ToInstance(inst))
+ }
+ }
+ return true
+ }
+ err := ac.ec2Client.DescribeInstancesPagesWithContext(ctx, &ec2.DescribeInstancesInput{
+ Filters: []*ec2.Filter{
+ &ec2.Filter{
+ Name: aws.String("instance-state-name"),
+ Values: []*string{aws.String(ec2.InstanceStateNameRunning), aws.String(ec2.InstanceStateNamePending)},
+ },
+ },
+ }, fn)
+ if err != nil {
+ return nil, err
+ }
+ return instances, nil
+}
+
+// CreateInstance creates an EC2 VM instance.
+func (ac *AWSClient) CreateInstance(ctx context.Context, config *EC2VMConfiguration) (*Instance, error) {
+ if config == nil {
+ return nil, errors.New("unable to create a VM with a nil instance")
+ }
+ runResult, err := ac.ec2Client.RunInstancesWithContext(ctx, vmConfig(config))
+ if err != nil {
+ return nil, fmt.Errorf("unable to create instance: %w", err)
+ }
+ if runResult == nil || len(runResult.Instances) != 1 {
+ return nil, fmt.Errorf("unexpected number of instances. want 1; got %d", len(runResult.Instances))
+ }
+ return ec2ToInstance(runResult.Instances[0]), nil
+}
+
+// DestroyInstances terminates EC2 VM instances.
+func (ac *AWSClient) DestroyInstances(ctx context.Context, instIDs ...string) error {
+ ids := aws.StringSlice(instIDs)
+ _, err := ac.ec2Client.TerminateInstancesWithContext(ctx, &ec2.TerminateInstancesInput{
+ InstanceIds: ids,
+ })
+ if err != nil {
+ return fmt.Errorf("unable to destroy vm: %w", err)
+ }
+ return err
+}
+
+// WaitUntilInstanceRunning waits until a stopping condition is met. The stopping conditions are:
+// - The requested instance state is `running`.
+// - The passed in context is cancelled or the deadline expires.
+// - 40 requests are made made with a 15 second delay between each request.
+func (ac *AWSClient) WaitUntilInstanceRunning(ctx context.Context, instID string) error {
+ err := ac.ec2Client.WaitUntilInstanceRunningWithContext(ctx, &ec2.DescribeInstancesInput{
+ InstanceIds: []*string{aws.String(instID)},
+ })
+ if err != nil {
+ return fmt.Errorf("failed waiting for vm instance: %w", err)
+ }
+ return err
+}
+
+// ec2ToInstance converts an `ec2.Instance` to an `Instance`
+func ec2ToInstance(inst *ec2.Instance) *Instance {
+ secGroup := make([]string, 0, len(inst.SecurityGroups))
+ for _, sg := range inst.SecurityGroups {
+ secGroup = append(secGroup, aws.StringValue(sg.GroupId))
+ }
+ i := &Instance{
+ CreatedAt: aws.TimeValue(inst.LaunchTime),
+ ID: *inst.InstanceId,
+ IPAddressExternal: aws.StringValue(inst.PublicIpAddress),
+ IPAddressInternal: aws.StringValue(inst.PrivateIpAddress),
+ ImageID: aws.StringValue(inst.ImageId),
+ SSHKeyID: aws.StringValue(inst.KeyName),
+ SecurityGroups: secGroup,
+ State: aws.StringValue(inst.State.Name),
+ Tags: make(map[string]string),
+ Type: aws.StringValue(inst.InstanceType),
+ }
+ if inst.Placement != nil {
+ i.Zone = aws.StringValue(inst.Placement.AvailabilityZone)
+ }
+ if inst.CpuOptions != nil {
+ i.CPUCount = aws.Int64Value(inst.CpuOptions.CoreCount)
+ }
+ for _, tag := range inst.Tags {
+ switch *tag.Key {
+ case tagName:
+ i.Name = *tag.Value
+ case tagDescription:
+ i.Description = *tag.Value
+ default:
+ i.Tags[*tag.Key] = *tag.Value
+ }
+ }
+ return i
+}
+
+// vmConfig converts a configuration into a request to create an instance.
+func vmConfig(config *EC2VMConfiguration) *ec2.RunInstancesInput {
+ ri := &ec2.RunInstancesInput{
+ ImageId: aws.String(config.ImageID),
+ InstanceType: aws.String(config.Type),
+ MinCount: aws.Int64(1),
+ MaxCount: aws.Int64(1),
+ Placement: &ec2.Placement{
+ AvailabilityZone: aws.String(config.Zone),
+ },
+ KeyName: aws.String(config.SSHKeyID),
+ InstanceInitiatedShutdownBehavior: aws.String(ec2.ShutdownBehaviorTerminate),
+ TagSpecifications: []*ec2.TagSpecification{
+ &ec2.TagSpecification{
+ ResourceType: aws.String("instance"),
+ Tags: []*ec2.Tag{
+ &ec2.Tag{
+ Key: aws.String(tagName),
+ Value: aws.String(config.Name),
+ },
+ &ec2.Tag{
+ Key: aws.String(tagDescription),
+ Value: aws.String(config.Description),
+ },
+ },
+ },
+ },
+ SecurityGroups: aws.StringSlice(config.SecurityGroups),
+ UserData: aws.String(config.UserData),
+ }
+ for k, v := range config.Tags {
+ ri.TagSpecifications[0].Tags = append(ri.TagSpecifications[0].Tags, &ec2.Tag{
+ Key: aws.String(k),
+ Value: aws.String(v),
+ })
+ }
+ return ri
+}
+
+// EC2UserData is stored in the user data for each EC2 instance. This is
+// used to store metadata about the running instance. The buildlet will retrieve
+// this on EC2 instances before allowing connections from the coordinator.
+type EC2UserData struct {
+ // BuildletBinaryURL is the url to the buildlet binary stored on GCS.
+ BuildletBinaryURL string `json:"buildlet_binary_url,omitempty"`
+ // BuildletHostType is the host type used by the buildlet. For example, `host-linux-arm64-aws`.
+ BuildletHostType string `json:"buildlet_host_type,omitempty"`
+ // BuildletImageURL is the url for the buildlet container image.
+ BuildletImageURL string `json:"buildlet_image_url,omitempty"`
+ // BuildletName is the name which should be passed onto the buildlet.
+ BuildletName string `json:"buildlet_name,omitempty"`
+ // Metadata provides a location for arbitrary metadata to be stored.
+ Metadata map[string]string `json:"metadata,omitempty"`
+ // TLSCert is the TLS certificate used by the buildlet.
+ TLSCert string `json:"tls_cert,omitempty"`
+ // TLSKey is the TLS key used by the buildlet.
+ TLSKey string `json:"tls_key,omitempty"`
+ // TLSPassword contains the SHA1 of the TLS key used by the buildlet for basic authentication.
+ TLSPassword string `json:"tls_password,omitempty"`
+}
+
+// EncodedString converts `EC2UserData` into JSON which is base64 encoded.
+// User data must be base64 encoded upon creation.
+func (ud *EC2UserData) EncodedString() string {
+ jsonUserData, err := json.Marshal(ud)
+ if err != nil {
+ log.Printf("unable to marshal user data: %v", err)
+ }
+ return base64.StdEncoding.EncodeToString([]byte(jsonUserData))
+}
diff --git a/internal/cloud/aws_test.go b/internal/cloud/aws_test.go
new file mode 100644
index 0000000..a90a73e
--- /dev/null
+++ b/internal/cloud/aws_test.go
@@ -0,0 +1,692 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cloud
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/aws/aws-sdk-go/aws"
+ "github.com/aws/aws-sdk-go/aws/request"
+ "github.com/aws/aws-sdk-go/service/ec2"
+ "github.com/google/go-cmp/cmp"
+)
+
+var _ vmClient = (*fakeEC2Client)(nil)
+
+type fakeEC2Client struct {
+ mu sync.RWMutex
+ // instances map of instanceId -> *ec2.Instance
+ instances map[string]*ec2.Instance
+}
+
+func newFakeEC2Client() *fakeEC2Client {
+ return &fakeEC2Client{
+ instances: make(map[string]*ec2.Instance),
+ }
+}
+
+// filterFunc represents a function used to filter out instances.
+type filterFunc func(*ec2.Instance) bool
+
+// createFilter returns filtering functions for a subset of `ec2.Filter`.
+// The response in the function returned indicates whether the instance
+// should be included.
+func createFilter(f *ec2.Filter) filterFunc {
+ if *f.Name == "instance-state-name" {
+ states := aws.StringValueSlice(f.Values)
+ return func(i *ec2.Instance) bool {
+ for _, s := range states {
+ if *i.State.Name == s {
+ return true
+ }
+ }
+ return false
+ }
+ }
+ // return noop filter for unsupported filters
+ return func(i *ec2.Instance) bool { return true }
+}
+
+// createFilters creates a filtering function for a subset of `ec2.Filter`.
+// The response for the returned function indicates whether the instance
+// should be included after all of the supplied filters have been evaluated.
+func createFilters(fs []*ec2.Filter) filterFunc {
+ if len(fs) == 0 {
+ // return noop filter for unsupported filters
+ return func(i *ec2.Instance) bool { return true }
+ }
+ filters := make([]filterFunc, 0, len(fs))
+ for _, f := range fs {
+ filters = append(filters, createFilter(f))
+ }
+ return func(i *ec2.Instance) bool {
+ for _, fn := range filters {
+ if !fn(i) {
+ return false
+ }
+ }
+ return true
+ }
+}
+
+func (f *fakeEC2Client) DescribeInstancesPagesWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, fn func(*ec2.DescribeInstancesOutput, bool) bool, opt ...request.Option) error {
+ if input == nil || fn == nil {
+ return errors.New("invalid input")
+ }
+ filters := createFilters(input.Filters)
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ insts := make([]*ec2.Instance, 0, len(f.instances))
+ for _, inst := range f.instances {
+ if !filters(inst) {
+ continue
+ }
+ insts = append(insts, inst)
+ }
+ for it, inst := range insts {
+ fn(&ec2.DescribeInstancesOutput{
+ Reservations: []*ec2.Reservation{
+ &ec2.Reservation{
+ Instances: []*ec2.Instance{
+ inst,
+ },
+ },
+ },
+ }, it == len(insts)-1)
+ }
+ return nil
+}
+
+func (f *fakeEC2Client) DescribeInstancesWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, opt ...request.Option) (*ec2.DescribeInstancesOutput, error) {
+ if ctx == nil || input == nil || len(input.InstanceIds) == 0 {
+ return nil, request.ErrInvalidParams{}
+ }
+ filters := createFilters(input.Filters)
+ instances := make([]*ec2.Instance, 0, len(input.InstanceIds))
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ for _, id := range aws.StringValueSlice(input.InstanceIds) {
+ inst, ok := f.instances[id]
+ if !ok {
+ return nil, errors.New("instance not found")
+ }
+ if !filters(inst) {
+ continue
+ }
+ instances = append(instances, inst)
+ }
+ return &ec2.DescribeInstancesOutput{
+ Reservations: []*ec2.Reservation{
+ &ec2.Reservation{
+ Instances: instances,
+ },
+ },
+ }, nil
+}
+
+func (f *fakeEC2Client) RunInstancesWithContext(ctx context.Context, input *ec2.RunInstancesInput, opts ...request.Option) (*ec2.Reservation, error) {
+ if ctx == nil || input == nil {
+ return nil, request.ErrInvalidParams{}
+ }
+ if input.ImageId == nil || aws.StringValue(input.ImageId) == "" ||
+ input.InstanceType == nil || aws.StringValue(input.InstanceType) == "" ||
+ input.MinCount == nil || aws.Int64Value(input.MinCount) == 0 ||
+ input.Placement == nil || aws.StringValue(input.Placement.AvailabilityZone) == "" {
+ return nil, errors.New("invalid instance configuration")
+ }
+ instCount := int(aws.Int64Value(input.MaxCount))
+ instances := make([]*ec2.Instance, 0, instCount)
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ for i := 0; i < instCount; i++ {
+ inst := &ec2.Instance{
+ CpuOptions: &ec2.CpuOptions{
+ CoreCount: aws.Int64(4),
+ },
+ ImageId: input.ImageId,
+ InstanceType: input.InstanceType,
+ InstanceId: aws.String(fmt.Sprintf("instance-%s", randHex(10))),
+ Placement: input.Placement,
+ PrivateIpAddress: aws.String(randIPv4()),
+ PublicIpAddress: aws.String(randIPv4()),
+ State: &ec2.InstanceState{
+ Name: aws.String("running"),
+ },
+ Tags: []*ec2.Tag{},
+ KeyName: input.KeyName,
+ SecurityGroups: []*ec2.GroupIdentifier{},
+ LaunchTime: aws.Time(time.Now()),
+ }
+ for _, id := range input.SecurityGroups {
+ inst.SecurityGroups = append(inst.SecurityGroups, &ec2.GroupIdentifier{
+ GroupId: id,
+ })
+ }
+ for _, tagSpec := range input.TagSpecifications {
+ for _, tag := range tagSpec.Tags {
+ inst.Tags = append(inst.Tags, tag)
+ }
+ }
+ f.instances[*inst.InstanceId] = inst
+ instances = append(instances, inst)
+ }
+ return &ec2.Reservation{
+ Instances: instances,
+ ReservationId: aws.String(fmt.Sprintf("reservation-%s", randHex(10))),
+ }, nil
+}
+
+func (f *fakeEC2Client) TerminateInstancesWithContext(ctx context.Context, input *ec2.TerminateInstancesInput, opts ...request.Option) (*ec2.TerminateInstancesOutput, error) {
+ if ctx == nil || input == nil || len(input.InstanceIds) == 0 {
+ return nil, request.ErrInvalidParams{}
+ }
+ isc := make([]*ec2.InstanceStateChange, 0, len(input.InstanceIds))
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ for _, id := range input.InstanceIds {
+ if *id == "" {
+ return nil, errors.New("invalid instance id")
+ }
+ var prevState string
+ inst, ok := f.instances[*id]
+ if !ok {
+ return nil, errors.New("instance not found")
+ }
+ prevState = *inst.State.Name
+ inst.State.Name = aws.String(ec2.InstanceStateNameTerminated)
+ isc = append(isc, &ec2.InstanceStateChange{
+ CurrentState: &ec2.InstanceState{
+ Name: aws.String(prevState),
+ },
+ InstanceId: id,
+ PreviousState: &ec2.InstanceState{
+ Code: nil,
+ Name: aws.String(ec2.InstanceStateNameTerminated),
+ },
+ })
+ }
+ return &ec2.TerminateInstancesOutput{
+ TerminatingInstances: isc,
+ }, nil
+}
+
+func (f *fakeEC2Client) WaitUntilInstanceRunningWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, opt ...request.WaiterOption) error {
+ if ctx == nil || input == nil || len(input.InstanceIds) == 0 {
+ return request.ErrInvalidParams{}
+ }
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ for _, id := range input.InstanceIds {
+ inst, ok := f.instances[*id]
+ if !ok {
+ return fmt.Errorf("instance %s not found", *id)
+ }
+ inst.State = &ec2.InstanceState{
+ Name: aws.String("running"),
+ }
+ }
+ return nil
+}
+
+func fakeClient() *AWSClient {
+ return &AWSClient{
+ ec2Client: newFakeEC2Client(),
+ }
+}
+
+func fakeClientWithInstances(t *testing.T, count int) (*AWSClient, []*Instance) {
+ c := fakeClient()
+ ctx := context.Background()
+ insts := make([]*Instance, 0, count)
+ for i := 0; i < count; i++ {
+ inst, err := c.CreateInstance(ctx, randomVMConfig())
+ if err != nil {
+ t.Fatalf("unable to create instance: %s", err)
+ }
+ insts = append(insts, inst)
+ }
+ return c, insts
+}
+
+func randomVMConfig() *EC2VMConfiguration {
+ return &EC2VMConfiguration{
+ Description: fmt.Sprintf("description-" + randHex(4)),
+ ImageID: fmt.Sprintf("image-" + randHex(4)),
+ Name: fmt.Sprintf("name-" + randHex(4)),
+ SSHKeyID: fmt.Sprintf("ssh-key-id-" + randHex(4)),
+ SecurityGroups: []string{fmt.Sprintf("sg-" + randHex(4))},
+ Tags: map[string]string{
+ fmt.Sprintf("tag-key-" + randHex(4)): fmt.Sprintf("tag-value-" + randHex(4)),
+ },
+ Type: fmt.Sprintf("type-" + randHex(4)),
+ UserData: fmt.Sprintf("user-data-" + randHex(4)),
+ Zone: fmt.Sprintf("zone-" + randHex(4)),
+ }
+}
+
+func TestRunningInstances(t *testing.T) {
+ t.Run("query-all-instances", func(t *testing.T) {
+ c, wantInsts := fakeClientWithInstances(t, 4)
+ gotInsts, gotErr := c.RunningInstances(context.Background())
+ if gotErr != nil {
+ t.Fatalf("Instances(ctx) = %+v, %s; want nil, nil", gotInsts, gotErr)
+ }
+ if len(gotInsts) != len(wantInsts) {
+ t.Errorf("got instance count %d: want %d", len(gotInsts), len(wantInsts))
+ }
+ })
+ t.Run("query-with-a-terminated-instance", func(t *testing.T) {
+ ctx := context.Background()
+ c, wantInsts := fakeClientWithInstances(t, 4)
+ gotErr := c.DestroyInstances(ctx, wantInsts[0].ID)
+ if gotErr != nil {
+ t.Fatalf("unable to destroy instance: %s", gotErr)
+ }
+ gotInsts, gotErr := c.RunningInstances(ctx)
+ if gotErr != nil {
+ t.Fatalf("Instances(ctx) = %+v, %s; want nil, nil", gotInsts, gotErr)
+ }
+ if len(gotInsts) != len(wantInsts)-1 {
+ t.Errorf("got instance count %d: want %d", len(gotInsts), len(wantInsts)-1)
+ }
+ })
+}
+
+func TestInstance(t *testing.T) {
+ t.Run("query-instance", func(t *testing.T) {
+ c, wantInsts := fakeClientWithInstances(t, 1)
+ wantInst := wantInsts[0]
+ gotInst, gotErr := c.Instance(context.Background(), wantInst.ID)
+ if gotErr != nil || gotInst == nil || gotInst.ID != wantInst.ID {
+ t.Errorf("Instance(ctx, %s) = %+v, %s; want no error", wantInst.ID, gotInst, gotErr)
+ }
+ })
+ t.Run("query-terminated-instance", func(t *testing.T) {
+ c, wantInsts := fakeClientWithInstances(t, 1)
+ wantInst := wantInsts[0]
+ ctx := context.Background()
+ gotErr := c.DestroyInstances(ctx, wantInst.ID)
+ if gotErr != nil {
+ t.Fatalf("unable to destroy instance: %s", gotErr)
+ }
+ gotInst, gotErr := c.Instance(ctx, wantInst.ID)
+ if gotErr != nil || gotInst == nil || gotInst.ID != wantInst.ID {
+ t.Errorf("Instance(ctx, %s) = %+v, %s; want no error", wantInst.ID, gotInst, gotErr)
+ }
+ })
+}
+
+func TestCreateInstance(t *testing.T) {
+ ud := &EC2UserData{
+ BuildletBinaryURL: "b-url",
+ BuildletHostType: "b-host-type",
+ BuildletImageURL: "b-image-url",
+ BuildletName: "b-name",
+ Metadata: map[string]string{
+ "tag-a": "value-b",
+ },
+ TLSCert: "cert-a",
+ TLSKey: "key-a",
+ TLSPassword: "pass-a",
+ }
+ config := &EC2VMConfiguration{
+ Description: "description-a",
+ ImageID: "my-image",
+ Name: "my-instance",
+ SSHKeyID: "my-key",
+ SecurityGroups: []string{"test-key"},
+ Tags: map[string]string{
+ "tag-1": "value-1",
+ },
+ Type: "xby.large",
+ UserData: ud.EncodedString(),
+ Zone: "us-west-14",
+ }
+ c := fakeClient()
+ gotInst, gotErr := c.CreateInstance(context.Background(), config)
+ if gotErr != nil {
+ t.Errorf("CreateInstance(ctx, %v) = %+v, %s; want no error", config, gotInst, gotErr)
+ }
+ if gotInst.Description != config.Description {
+ t.Errorf("Instance.Description = %s; want %s", gotInst.Description, config.Description)
+ }
+ if gotInst.ImageID != config.ImageID {
+ t.Errorf("Instance.ImageID = %s; want %s", gotInst.ImageID, config.ImageID)
+ }
+ if gotInst.Name != config.Name {
+ t.Errorf("Instance.Name = %s; want %s", gotInst.Name, config.Name)
+ }
+ if gotInst.SSHKeyID != config.SSHKeyID {
+ t.Errorf("Instance.SSHKeyID = %s; want %s", gotInst.SSHKeyID, config.SSHKeyID)
+ }
+ if !cmp.Equal(gotInst.SecurityGroups, config.SecurityGroups) {
+ t.Errorf("Instance.SecruityGroups = %v; want %v", gotInst.SecurityGroups, config.SecurityGroups)
+ }
+ if !cmp.Equal(gotInst.Tags, config.Tags) {
+ t.Errorf("Instance.Tags = %v want %v", gotInst.Tags, config.Tags)
+ }
+ if gotInst.Type != config.Type {
+ t.Errorf("Instance.Type = %s; want %s", gotInst.Type, config.Type)
+ }
+ if gotInst.Zone != config.Zone {
+ t.Errorf("Instance.Zone = %s; want %s", gotInst.Zone, config.Zone)
+ }
+}
+
+func TestCreateInstanceError(t *testing.T) {
+ testCases := []struct {
+ desc string
+ vmConfig *EC2VMConfiguration
+ }{
+ {
+ desc: "missing-vmConfig",
+ vmConfig: nil,
+ },
+ {
+ desc: "missing-image-type",
+ vmConfig: &EC2VMConfiguration{
+ Type: "type-a",
+ Zone: "eu-15",
+ },
+ },
+ {
+ desc: "missing-vm-type",
+ vmConfig: &EC2VMConfiguration{
+ ImageID: "ami-15",
+ Zone: "eu-15",
+ },
+ },
+ {
+ desc: "missing-zone",
+ vmConfig: &EC2VMConfiguration{
+ ImageID: "ami-15",
+ Type: "abc.large",
+ },
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ c := fakeClient()
+ gotInst, gotErr := c.CreateInstance(context.Background(), tc.vmConfig)
+ if gotErr == nil || gotInst != nil {
+ t.Errorf("CreateInstance(ctx, %+v) = %+v, %s; want error", tc.vmConfig, gotInst, gotErr)
+ }
+ })
+ }
+}
+
+func TestDestroyInstances(t *testing.T) {
+ testCases := []struct {
+ desc string
+ ctx context.Context
+ vmCount int
+ wantErr bool
+ }{
+ {"baseline request", context.Background(), 1, false},
+ {"nil context", nil, 1, true},
+ {"missing vmID", context.Background(), 0, true},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ c, insts := fakeClientWithInstances(t, tc.vmCount)
+ instIDs := make([]string, 0, tc.vmCount)
+ for _, inst := range insts {
+ instIDs = append(instIDs, inst.ID)
+ }
+ gotErr := c.DestroyInstances(tc.ctx, instIDs...)
+ if (gotErr != nil) != tc.wantErr {
+ t.Errorf("DestroyVM(%v, %+v) = %v; want error %t", tc.ctx, instIDs, gotErr, tc.wantErr)
+ }
+ })
+ }
+}
+
+func TestWaitUntilInstanceRunning(t *testing.T) {
+ c, wantInsts := fakeClientWithInstances(t, 1)
+ wantInst := wantInsts[0]
+ ctx := context.Background()
+ gotErr := c.WaitUntilInstanceRunning(ctx, wantInst.ID)
+ if gotErr != nil {
+ t.Errorf("WaitUntilVMExists(%v, %v) failed with error %s", ctx, wantInst.ID, gotErr)
+ }
+}
+
+func TestWaitUntilInstanceRunningErr(t *testing.T) {
+ testCases := []struct {
+ desc string
+ ctx context.Context
+ vmCount int
+ }{
+ {"nil-context", nil, 1},
+ {"missing vmID", context.Background(), 0},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ c, wantInsts := fakeClientWithInstances(t, tc.vmCount)
+ ctx := context.Background()
+ wantID := ""
+ if len(wantInsts) > 0 {
+ wantID = wantInsts[0].ID
+ }
+ gotErr := c.WaitUntilInstanceRunning(tc.ctx, wantID)
+ if gotErr == nil {
+ t.Errorf("WaitUntilVMExists(%v, %v) = %s: want error", ctx, wantID, gotErr)
+ }
+ })
+ }
+}
+
+func TestEC2ToInstance(t *testing.T) {
+ wantCreationTime := time.Unix(1, 1)
+ wantDescription := "my-desc"
+ wantID := "inst-55"
+ wantIPExt := "1.1.1.1"
+ wantIPInt := "2.2.2.2"
+ wantImage := "ami-56"
+ wantKey := "my-key"
+ wantName := "my-name"
+ wantSecurityGroup := "22"
+ wantTagKey := "tag1"
+ wantTagValue := "taggy1"
+ wantType := "type-1"
+ wantZone := "us-east-22"
+ wantState := "running"
+ var wantCPUCount int64 = 66
+
+ ei := &ec2.Instance{
+ CpuOptions: &ec2.CpuOptions{
+ CoreCount: aws.Int64(wantCPUCount),
+ },
+ ImageId: aws.String(wantImage),
+ InstanceId: aws.String(wantID),
+ InstanceType: aws.String(wantType),
+ KeyName: aws.String(wantKey),
+ LaunchTime: aws.Time(wantCreationTime),
+ Placement: &ec2.Placement{
+ AvailabilityZone: aws.String(wantZone),
+ },
+ PrivateIpAddress: aws.String(wantIPInt),
+ PublicIpAddress: aws.String(wantIPExt),
+ SecurityGroups: []*ec2.GroupIdentifier{
+ &ec2.GroupIdentifier{
+ GroupId: aws.String(wantSecurityGroup),
+ },
+ },
+ State: &ec2.InstanceState{
+ Name: aws.String(wantState),
+ },
+ Tags: []*ec2.Tag{
+ &ec2.Tag{
+ Key: aws.String(tagName),
+ Value: aws.String(wantName),
+ },
+ &ec2.Tag{
+ Key: aws.String(tagDescription),
+ Value: aws.String(wantDescription),
+ },
+ &ec2.Tag{
+ Key: aws.String(wantTagKey),
+ Value: aws.String(wantTagValue),
+ },
+ },
+ }
+ gotInst := ec2ToInstance(ei)
+ if gotInst.CPUCount != wantCPUCount {
+ t.Errorf("CPUCount %d; want %d", gotInst.CPUCount, wantCPUCount)
+ }
+ if gotInst.CreatedAt != wantCreationTime {
+ t.Errorf("CreatedAt %s; want %s", gotInst.CreatedAt, wantCreationTime)
+ }
+ if gotInst.Description != wantDescription {
+ t.Errorf("Description %s; want %s", gotInst.Description, wantDescription)
+ }
+ if gotInst.ID != wantID {
+ t.Errorf("ID %s; want %s", gotInst.ID, wantID)
+ }
+ if gotInst.IPAddressExternal != wantIPExt {
+ t.Errorf("IPAddressExternal %s; want %s", gotInst.IPAddressExternal, wantIPExt)
+ }
+ if gotInst.IPAddressInternal != wantIPInt {
+ t.Errorf("IPAddressInternal %s; want %s", gotInst.IPAddressInternal, wantIPInt)
+ }
+ if gotInst.ImageID != wantImage {
+ t.Errorf("Image %s; want %s", gotInst.ImageID, wantImage)
+ }
+ if gotInst.Name != wantName {
+ t.Errorf("Name %s; want %s", gotInst.Name, wantName)
+ }
+ if gotInst.SSHKeyID != wantKey {
+ t.Errorf("SSHKeyID %s; want %s", gotInst.SSHKeyID, wantKey)
+ }
+ found := false
+ for _, sg := range gotInst.SecurityGroups {
+ if sg == wantSecurityGroup {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("SecurityGroups not found")
+ }
+ if gotInst.State != wantState {
+ t.Errorf("State %s; want %s", gotInst.State, wantState)
+ }
+ if gotInst.Type != wantType {
+ t.Errorf("Type %s; want %s", gotInst.Type, wantType)
+ }
+ if gotInst.Zone != wantZone {
+ t.Errorf("Zone %s; want %s", gotInst.Zone, wantZone)
+ }
+ gotValue, ok := gotInst.Tags[wantTagKey]
+ if !ok || gotValue != wantTagValue {
+ t.Errorf("Tags[%s] = %s, %t; want %s, %t", wantTagKey, gotValue, ok, wantTagValue, true)
+ }
+}
+
+func TestVMConfig(t *testing.T) {
+ wantDescription := "desc"
+ wantImage := "ami-56"
+ wantName := "my-instance"
+ wantKey := "my-key"
+ wantSecurityGroups := []string{"22"}
+ wantTags := map[string]string{
+ "tag1": "taggy1",
+ "tag2": "taggy2",
+ }
+ wantType := "type-1"
+ wantUserData := "user-data-x"
+ wantZone := "us-east-22"
+
+ rii := vmConfig(&EC2VMConfiguration{
+ Description: wantDescription,
+ ImageID: wantImage,
+ Name: wantName,
+ SSHKeyID: wantKey,
+ SecurityGroups: wantSecurityGroups,
+ Tags: wantTags,
+ Type: wantType,
+ UserData: wantUserData,
+ Zone: wantZone,
+ })
+
+ if *rii.ImageId != wantImage {
+ t.Errorf("image id %s; want %s", *rii.ImageId, wantImage)
+ }
+ if *rii.InstanceType != wantType {
+ t.Errorf("image id %s; want %s", *rii.ImageId, wantImage)
+ }
+ if *rii.MinCount != 1 {
+ t.Errorf("MinCount %d; want %d", *rii.MinCount, 1)
+ }
+ if *rii.MaxCount != 1 {
+ t.Errorf("MaxCount %d; want %d", *rii.MaxCount, 1)
+ }
+ if *rii.Placement.AvailabilityZone != wantZone {
+ t.Errorf("AvailabilityZone %s; want %s", *rii.Placement.AvailabilityZone, wantZone)
+ }
+ if !cmp.Equal(*rii.KeyName, wantKey) {
+ t.Errorf("SSHKeyID %+v; want %+v", *rii.KeyName, wantKey)
+ }
+ if *rii.InstanceInitiatedShutdownBehavior != ec2.ShutdownBehaviorTerminate {
+ t.Errorf("Shutdown Behavior %s; want %s", *rii.InstanceInitiatedShutdownBehavior, ec2.ShutdownBehaviorTerminate)
+ }
+ if *rii.UserData != wantUserData {
+ t.Errorf("UserData %s; want %s", *rii.UserData, wantUserData)
+ }
+ contains := func(tagSpec []*ec2.TagSpecification, key, value string) bool {
+ for _, ts := range tagSpec {
+ for _, t := range ts.Tags {
+ if *t.Key == key && *t.Value == value {
+ return true
+ }
+ }
+ }
+ return false
+ }
+ if !contains(rii.TagSpecifications, tagName, wantName) {
+ t.Errorf("want Tag Key: %s, Value: %s", tagName, wantName)
+ }
+ if !contains(rii.TagSpecifications, tagDescription, wantDescription) {
+ t.Errorf("want Tag Key: %s, Value: %s", tagDescription, wantDescription)
+ }
+ for k, v := range wantTags {
+ if !contains(rii.TagSpecifications, k, v) {
+ t.Errorf("want Tag Key: %s, Value: %s", k, v)
+ }
+ }
+ if !cmp.Equal(aws.StringValueSlice(rii.SecurityGroups), wantSecurityGroups) {
+ t.Errorf("SecurityGroups %v; want %v", aws.StringValueSlice(rii.SecurityGroups), wantSecurityGroups)
+ }
+}
+
+func TestEncodedString(t *testing.T) {
+ ud := EC2UserData{
+ BuildletBinaryURL: "binary_url_b",
+ BuildletHostType: "host_type_a",
+ BuildletImageURL: "image_url_c",
+ BuildletName: "name_d",
+ Metadata: map[string]string{
+ "key": "value",
+ },
+ TLSCert: "x",
+ TLSKey: "y",
+ TLSPassword: "z",
+ }
+ jsonUserData, err := json.Marshal(ud)
+ if err != nil {
+ t.Fatalf("unable to marshal user data to json: %s", err)
+ }
+ wantUD := base64.StdEncoding.EncodeToString([]byte(jsonUserData))
+ if ud.EncodedString() != wantUD {
+ t.Errorf("EncodedString() = %s; want %s", ud.EncodedString(), wantUD)
+ }
+}
diff --git a/internal/cloud/fake_aws.go b/internal/cloud/fake_aws.go
new file mode 100644
index 0000000..2a47b2c
--- /dev/null
+++ b/internal/cloud/fake_aws.go
@@ -0,0 +1,181 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cloud
+
+import (
+ "context"
+ "crypto/rand"
+ "errors"
+ "fmt"
+ mrand "math/rand"
+ "sync"
+ "time"
+
+ "github.com/aws/aws-sdk-go/service/ec2"
+)
+
+func init() { mrand.Seed(time.Now().UnixNano()) }
+
+// FakeAWSClient provides a fake AWS Client used to test the AWS client
+// functionality.
+type FakeAWSClient struct {
+ mu sync.RWMutex
+ instances map[string]*Instance
+}
+
+// NewFakeAWSClient crates a fake AWS client.
+func NewFakeAWSClient() *FakeAWSClient {
+ return &FakeAWSClient{
+ instances: make(map[string]*Instance),
+ }
+}
+
+// Instance returns the `Instance` record for the rquested instance. The instance record will
+// return records for recently terminated instances. If an instance is not found an error will
+// be returned.
+func (f *FakeAWSClient) Instance(ctx context.Context, instID string) (*Instance, error) {
+ if ctx == nil || instID == "" {
+ return nil, errors.New("invalid params")
+ }
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+
+ inst, ok := f.instances[instID]
+ if !ok {
+ return nil, errors.New("instance not found")
+ }
+ return copyInstance(inst), nil
+}
+
+// Instances retrieves all EC2 instances in a region which have not been terminated or stopped.
+func (f *FakeAWSClient) RunningInstances(ctx context.Context) ([]*Instance, error) {
+ if ctx == nil {
+ return nil, errors.New("invalid params")
+ }
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+
+ instances := make([]*Instance, 0, len(f.instances))
+ for _, inst := range f.instances {
+ if inst.State != ec2.InstanceStateNameRunning && inst.State != ec2.InstanceStateNamePending {
+ continue
+ }
+ instances = append(instances, copyInstance(inst))
+ }
+ return instances, nil
+}
+
+// CreateInstance creates an EC2 VM instance.
+func (f *FakeAWSClient) CreateInstance(ctx context.Context, config *EC2VMConfiguration) (*Instance, error) {
+ if ctx == nil || config == nil {
+ return nil, errors.New("invalid params")
+ }
+ if config.ImageID == "" {
+ return nil, errors.New("invalid Image ID")
+ }
+ if config.Type == "" {
+ return nil, errors.New("invalid Type")
+ }
+ if config.Zone == "" {
+ return nil, errors.New("invalid Zone")
+ }
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ inst := &Instance{
+ CPUCount: 4,
+ CreatedAt: time.Now(),
+ Description: config.Description,
+ ID: fmt.Sprintf("instance-%s", randHex(10)),
+ IPAddressExternal: randIPv4(),
+ IPAddressInternal: randIPv4(),
+ ImageID: config.ImageID,
+ Name: config.Name,
+ SSHKeyID: config.SSHKeyID,
+ SecurityGroups: config.SecurityGroups,
+ State: ec2.InstanceStateNameRunning,
+ Tags: make(map[string]string),
+ Type: config.Type,
+ Zone: config.Zone,
+ }
+ for k, v := range config.Tags {
+ inst.Tags[k] = v
+ }
+ f.instances[inst.ID] = inst
+ return copyInstance(inst), nil
+}
+
+// DestroyInstances terminates EC2 VM instances.
+func (f *FakeAWSClient) DestroyInstances(ctx context.Context, instIDs ...string) error {
+ if ctx == nil || len(instIDs) == 0 {
+ return errors.New("invalid params")
+ }
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ for _, id := range instIDs {
+ inst, ok := f.instances[id]
+ if !ok {
+ return errors.New("instance not found")
+ }
+ inst.State = ec2.InstanceStateNameTerminated
+ }
+ return nil
+}
+
+// WaitUntilInstanceRunning returns when an instance has transitioned into the running state.
+func (f *FakeAWSClient) WaitUntilInstanceRunning(ctx context.Context, instID string) error {
+ if ctx == nil || instID == "" {
+ return errors.New("invalid params")
+ }
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+
+ inst, ok := f.instances[instID]
+ if !ok {
+ return errors.New("instance not found")
+ }
+ if inst.State != ec2.InstanceStateNameRunning {
+ return errors.New("timed out waiting for instance to enter running state")
+ }
+ return nil
+}
+
+// copyInstance copies the contents of a pointer to an instance and returns a newly created
+// instance with the same data as the original instance.
+func copyInstance(inst *Instance) *Instance {
+ i := &Instance{
+ CPUCount: inst.CPUCount,
+ CreatedAt: inst.CreatedAt,
+ Description: inst.Description,
+ ID: inst.ID,
+ IPAddressExternal: inst.IPAddressExternal,
+ IPAddressInternal: inst.IPAddressInternal,
+ ImageID: inst.ImageID,
+ Name: inst.Name,
+ SSHKeyID: inst.SSHKeyID,
+ SecurityGroups: inst.SecurityGroups,
+ State: inst.State,
+ Tags: make(map[string]string),
+ Type: inst.Type,
+ Zone: inst.Zone,
+ }
+ for k, v := range inst.Tags {
+ i.Tags[k] = v
+ }
+ return i
+}
+
+// randHex creates a random hex string of length n.
+func randHex(n int) string {
+ buf := make([]byte, n/2+1)
+ _, _ = rand.Read(buf)
+ return fmt.Sprintf("%x", buf)[:n]
+}
+
+// randIPv4 creates a random IPv4 address.
+func randIPv4() string {
+ return fmt.Sprintf("%d.%d.%d.%d", mrand.Intn(255), mrand.Intn(255), mrand.Intn(255), mrand.Intn(255))
+}
diff --git a/internal/cloud/fake_aws_test.go b/internal/cloud/fake_aws_test.go
new file mode 100644
index 0000000..5d78399
--- /dev/null
+++ b/internal/cloud/fake_aws_test.go
@@ -0,0 +1,336 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cloud
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+)
+
+func TestFakeAWSClientInstance(t *testing.T) {
+ t.Run("invalid-params", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ inst, gotErr := f.CreateInstance(ctx, generateVMConfig())
+ if gotErr != nil {
+ t.Fatalf("unable to create instance: %s", gotErr)
+ }
+ if gotInst, gotErr := f.Instance(nil, inst.ID); gotErr == nil {
+ t.Errorf("Instance(nil, %s) = %+v, nil, want error", inst.ID, gotInst)
+ }
+ if gotInst, gotErr := f.Instance(ctx, ""); gotErr == nil {
+ t.Errorf("Instance(ctx, %s) = %+v, nil, want error", "", gotInst)
+ }
+ })
+ t.Run("existing-instance", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ inst, gotErr := f.CreateInstance(ctx, generateVMConfig())
+ if gotErr != nil {
+ t.Fatalf("unable to create instance")
+ }
+ gotInst, gotErr := f.Instance(ctx, inst.ID)
+ if gotErr != nil || gotInst == nil || gotInst.ID != inst.ID {
+ t.Errorf("Instance(ctx, %s) = %v, %s, want %+v, nil", inst.ID, gotInst, gotErr, inst)
+ }
+ })
+ t.Run("non-existing-instance", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ instID := "instance-random"
+ gotInst, gotErr := f.Instance(ctx, instID)
+ if gotErr == nil || gotInst != nil {
+ t.Errorf("Instance(ctx, %s) = %v, %s, want error", instID, gotInst, gotErr)
+ }
+ })
+ t.Run("terminated-instance", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ inst, gotErr := f.CreateInstance(ctx, generateVMConfig())
+ if gotErr != nil {
+ t.Fatalf("unable to create instance")
+ }
+ if gotErr := f.DestroyInstances(ctx, inst.ID); gotErr != nil {
+ t.Fatalf("unable to destroy instance")
+ }
+ gotInst, gotErr := f.Instance(ctx, inst.ID)
+ if gotErr != nil || gotInst == nil || gotInst.ID != inst.ID {
+ t.Errorf("Instance(ctx, %s) = %v, %s, want %+v, nil", inst.ID, gotInst, gotErr, inst)
+ }
+ })
+}
+
+func TestFakeAWSClientRunningInstances(t *testing.T) {
+ t.Run("invalid-params", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ _, gotErr := f.CreateInstance(ctx, generateVMConfig())
+ if gotErr != nil {
+ t.Fatalf("unable to create instance: %s", gotErr)
+ }
+ if gotInst, gotErr := f.RunningInstances(nil); gotErr == nil {
+ t.Errorf("RunningInstances(nil) = %+v, nil, want error", gotInst)
+ }
+ })
+ t.Run("no-instances", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ inst, gotErr := f.CreateInstance(ctx, generateVMConfig())
+ if gotErr != nil {
+ t.Fatalf("unable to create instance")
+ }
+ gotInsts, gotErr := f.RunningInstances(ctx)
+ if gotErr != nil {
+ t.Errorf("RunningInstances() error = %v, no error", gotErr)
+ }
+ if !cmp.Equal(gotInsts, []*Instance{inst}) {
+ t.Errorf("RunningInstances() = %+v, %s; want %+v", gotInsts, gotErr, []*Instance{inst})
+ }
+ })
+ t.Run("single-instance", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ inst, gotErr := f.CreateInstance(ctx, generateVMConfig())
+ if gotErr != nil {
+ t.Fatalf("unable to create instance")
+ }
+ gotInsts, gotErr := f.RunningInstances(ctx)
+ if gotErr != nil {
+ t.Errorf("RunningInstances() error = %v, no error", gotErr)
+ }
+ if !cmp.Equal(gotInsts, []*Instance{inst}) {
+ t.Errorf("RunningInstances() = %+v, %s; want %+v", gotInsts, gotErr, []*Instance{inst})
+ }
+ })
+ t.Run("multiple-instances", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ create := []*EC2VMConfiguration{
+ generateVMConfig(),
+ generateVMConfig(),
+ generateVMConfig(),
+ }
+ insts := make([]*Instance, 0, len(create))
+ for _, config := range create {
+ inst, gotErr := f.CreateInstance(ctx, config)
+ if gotErr != nil {
+ t.Fatalf("unable to create instance")
+ }
+ insts = append(insts, inst)
+ }
+ gotInsts, gotErr := f.RunningInstances(ctx)
+ if gotErr != nil {
+ t.Errorf("RunningInstances() error = %v, no error", gotErr)
+ }
+ opt := cmpopts.SortSlices(func(i, j *Instance) bool { return i.ID < j.ID })
+ if !cmp.Equal(gotInsts, insts, opt) {
+ t.Errorf("RunningInstances() = %+v, %s; want %+v", gotInsts, gotErr, insts)
+ }
+ })
+
+ t.Run("multiple-instances-with-one-termination", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ create := []*EC2VMConfiguration{
+ generateVMConfig(),
+ generateVMConfig(),
+ generateVMConfig(),
+ }
+ insts := make([]*Instance, 0, len(create))
+ for _, config := range create {
+ inst, gotErr := f.CreateInstance(ctx, config)
+ if gotErr != nil {
+ t.Fatalf("unable to create instance")
+ }
+ insts = append(insts, inst)
+ }
+ if gotErr := f.DestroyInstances(ctx, insts[0].ID); gotErr != nil {
+ t.Fatalf("unable to destroy instance")
+ }
+ gotInsts, gotErr := f.RunningInstances(ctx)
+ if gotErr != nil {
+ t.Errorf("RunningInstances() error = %v, no error", gotErr)
+ }
+ opt := cmpopts.SortSlices(func(i, j *Instance) bool { return i.ID < j.ID })
+ if !cmp.Equal(gotInsts, insts[1:], opt) {
+ t.Errorf("RunningInstances() = %+v, %s; want %+v", gotInsts, gotErr, insts[1:])
+ }
+ })
+}
+
+func TestFakeAWSClientCreateInstance(t *testing.T) {
+ t.Run("create-instance", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ ud := &EC2UserData{}
+ config := &EC2VMConfiguration{
+ Description: "desc",
+ ImageID: "id-44",
+ Name: "name-22",
+ SSHKeyID: "key-43",
+ SecurityGroups: []string{"sg-1", "sg-2"},
+ Tags: map[string]string{
+ "key-1": "value-1",
+ },
+ Type: "ami-44",
+ UserData: ud.EncodedString(),
+ Zone: "zone-14",
+ }
+ gotInst, gotErr := f.CreateInstance(ctx, config)
+ if gotErr != nil {
+ t.Fatalf("CreateInstance(ctx, %+v) = %+v, %s; want no error", config, gotInst, gotErr)
+ }
+ // generated fields
+ if gotInst.CPUCount <= 0 {
+ t.Errorf("Instance. is not set")
+ }
+ if gotInst.ID == "" {
+ t.Errorf("Instance.ID is not set")
+ }
+ if gotInst.IPAddressExternal == "" {
+ t.Errorf("Instance.IPAddressExternal is not set")
+ }
+ if gotInst.IPAddressInternal == "" {
+ t.Errorf("Instance.IPAddressInternal is not set")
+ }
+ if gotInst.State == "" {
+ t.Errorf("Instance.State is not set")
+ }
+ // config fields
+ if gotInst.Description != config.Description {
+ t.Errorf("Instance.Description = %s, want %s", gotInst.Description, config.Description)
+ }
+ if gotInst.ImageID != config.ImageID {
+ t.Errorf("Instance.ImageID = %s, want %s", gotInst.ImageID, config.ImageID)
+ }
+ if gotInst.Name != config.Name {
+ t.Errorf("Instance.Name = %s, want %s", gotInst.Name, config.Name)
+ }
+ if gotInst.SSHKeyID != config.SSHKeyID {
+ t.Errorf("Instance.SSHKeyID = %s, want %s", gotInst.SSHKeyID, config.SSHKeyID)
+ }
+ if !cmp.Equal(gotInst.SecurityGroups, config.SecurityGroups) {
+ t.Errorf("Instance.SecurityGroups = %s, want %s", gotInst.SecurityGroups, config.SecurityGroups)
+ }
+ if !cmp.Equal(gotInst.Tags, config.Tags) {
+ t.Errorf("Instance.Tags = %+v, want %+v", gotInst.Tags, config.Tags)
+ }
+ if gotInst.Type != config.Type {
+ t.Errorf("Instance.Type = %s, want %s", gotInst.Type, config.Type)
+ }
+ if gotInst.Zone != config.Zone {
+ t.Errorf("Instance.Zone = %s, want %s", gotInst.Zone, config.Zone)
+ }
+ })
+}
+
+func TestFakeAWSClientDestroyInstances(t *testing.T) {
+ t.Run("invalid-params", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ inst, gotErr := f.CreateInstance(ctx, generateVMConfig())
+ if gotErr != nil {
+ t.Fatalf("unable to create instance: %s", gotErr)
+ }
+ if gotErr := f.DestroyInstances(nil, inst.ID); gotErr == nil {
+ t.Errorf("DestroyInstances(nil, %s) = nil, want error", inst.ID)
+ }
+ if gotErr := f.DestroyInstances(ctx); gotErr == nil {
+ t.Error("DestroyInstances(ctx) = nil, want error")
+ }
+ })
+ t.Run("destroy-existing-instance", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ inst, gotErr := f.CreateInstance(ctx, generateVMConfig())
+ if gotErr != nil {
+ t.Fatalf("unable to create instance")
+ }
+ if gotErr = f.DestroyInstances(ctx, inst.ID); gotErr != nil {
+ t.Errorf("DestroyInstances(ctx, %s) = %s; want no error", inst.ID, gotErr)
+ }
+ })
+ t.Run("destroy-existing-instances", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ inst1, gotErr := f.CreateInstance(ctx, generateVMConfig())
+ if gotErr != nil {
+ t.Fatalf("unable to create instance")
+ }
+ inst2, gotErr := f.CreateInstance(ctx, generateVMConfig())
+ if gotErr != nil {
+ t.Fatalf("unable to create instance")
+ }
+ if gotErr = f.DestroyInstances(ctx, inst1.ID, inst2.ID); gotErr != nil {
+ t.Errorf("DestroyInstances(ctx, %s, %s) = %s; want no error", inst1.ID, inst2.ID, gotErr)
+ }
+ })
+ t.Run("destroy-non-existing-instance", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ instID := "instance-random"
+ if gotErr := f.DestroyInstances(ctx, instID); gotErr == nil {
+ t.Errorf("DestroyInstances(ctx, %s) = %s; want error", instID, gotErr)
+ }
+ })
+}
+
+func TestFakeAWSClientWaitUntilInstanceRunning(t *testing.T) {
+ t.Run("invalid-params", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ inst, gotErr := f.CreateInstance(ctx, generateVMConfig())
+ if gotErr != nil {
+ t.Fatalf("unable to create instance: %s", gotErr)
+ }
+ if gotErr := f.WaitUntilInstanceRunning(nil, inst.ID); gotErr == nil {
+ t.Errorf("WaitUntilInstanceRunning(nil, %s) = nil, want error", inst.ID)
+ }
+ if gotErr := f.WaitUntilInstanceRunning(ctx, ""); gotErr == nil {
+ t.Errorf("WaitUntilInstanceRunning(ctx, %s) = nil, want error", "")
+ }
+ })
+ t.Run("wait-for-existing-instance", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ inst, gotErr := f.CreateInstance(ctx, generateVMConfig())
+ if gotErr != nil {
+ t.Fatalf("unable to create instance")
+ }
+ if gotErr = f.WaitUntilInstanceRunning(ctx, inst.ID); gotErr != nil {
+ t.Errorf("WaitUntilInstanceRunning(ctx, %s) = %s; want no error", inst.ID, gotErr)
+ }
+ })
+ t.Run("wait-for-non-existing-instance", func(t *testing.T) {
+ ctx := context.Background()
+ f := NewFakeAWSClient()
+ instID := "instance-random"
+ if gotErr := f.WaitUntilInstanceRunning(ctx, instID); gotErr == nil {
+ t.Errorf("WaitUntilInstanceRunning(ctx, %s) = %s; want error", instID, gotErr)
+ }
+ })
+}
+
+func TestRandIPv4(t *testing.T) {
+ got := randIPv4()
+ gotIP := net.ParseIP(got)
+ if gotIP == nil {
+ t.Errorf("randIPv4() = %v, want conforment IPv4 address", got)
+ }
+}
+
+func generateVMConfig() *EC2VMConfiguration {
+ return &EC2VMConfiguration{
+ ImageID: fmt.Sprintf("ami-%s", randHex(4)),
+ SSHKeyID: fmt.Sprintf("key-%s", randHex(4)),
+ Type: fmt.Sprintf("type-%s", randHex(4)),
+ Zone: fmt.Sprintf("zone-%s", randHex(4)),
+ }
+}