internal/cloud, cmd/coordinator: add a rate limiter for the AWS client

This change adds an optional rate limiter which implements both rate
and resource limits. This change should enable the coordinator to add
additional builders hosted on EC2. This change will reduce the chance
that we encounter one of the limits.

The rate limiter is essential for any calls which mutate cloud
resources. While the EC2 client.DefaultRetryer should be sufficient
to handle the non-mutating requests, client-side rate limiting for
those types of requests is also a best practice.

Amazon EC2 throttles EC2 API requests:
https://docs.aws.amazon.com/AWSEC2/latest/APIReference/throttling.html

Fixes golang/go#40950

Change-Id: Ib80bcf7c5ab0b0483d5beb11f3581cdb1d0174fe
Reviewed-on: https://go-review.googlesource.com/c/build/+/267901
Trust: Carlos Amedee <carlos@golang.org>
Run-TryBot: Carlos Amedee <carlos@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Alexander Rakoczy <alex@golang.org>
diff --git a/cmd/coordinator/coordinator.go b/cmd/coordinator/coordinator.go
index c98e094..9e43d72 100644
--- a/cmd/coordinator/coordinator.go
+++ b/cmd/coordinator/coordinator.go
@@ -4081,7 +4081,7 @@
 		log.Fatalf("unable to retrieve secret %q: %s", secret.NameAWSAccessKey, err)
 	}
 
-	awsClient, err := cloud.NewAWSClient(buildenv.Production.AWSRegion, awsKeyID, awsAccessKey)
+	awsClient, err := cloud.NewAWSClient(buildenv.Production.AWSRegion, awsKeyID, awsAccessKey, cloud.WithRateLimiter(cloud.DefaultEC2LimitConfig))
 	if err != nil {
 		log.Fatalf("unable to create AWS client: %s", err)
 	}
diff --git a/internal/cloud/aws.go b/internal/cloud/aws.go
index 008660a..e5aef99 100644
--- a/internal/cloud/aws.go
+++ b/internal/cloud/aws.go
@@ -118,8 +118,11 @@
 	quotaClient quotaClient
 }
 
+// AWSOpt is an optional configuration setting for the AWSClient.
+type AWSOpt func(*AWSClient)
+
 // NewAWSClient creates a new AWS client.
-func NewAWSClient(region, keyID, accessKey string) (*AWSClient, error) {
+func NewAWSClient(region, keyID, accessKey string, opts ...AWSOpt) (*AWSClient, error) {
 	s, err := session.NewSession(&aws.Config{
 		Region:      aws.String(region),
 		Credentials: credentials.NewStaticCredentials(keyID, accessKey, ""), // Token is only required for STS
@@ -127,10 +130,14 @@
 	if err != nil {
 		return nil, fmt.Errorf("failed to create AWS session: %v", err)
 	}
-	return &AWSClient{
+	c := &AWSClient{
 		ec2Client:   ec2.New(s),
 		quotaClient: servicequotas.New(s),
-	}, nil
+	}
+	for _, opt := range opts {
+		opt(c)
+	}
+	return c, nil
 }
 
 // Instance retrieves an EC2 instance by instance ID.
diff --git a/internal/cloud/aws_interceptor.go b/internal/cloud/aws_interceptor.go
new file mode 100644
index 0000000..d3f9763
--- /dev/null
+++ b/internal/cloud/aws_interceptor.go
@@ -0,0 +1,162 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cloud
+
+import (
+	"context"
+	"errors"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/request"
+	"github.com/aws/aws-sdk-go/service/ec2"
+	"golang.org/x/sync/errgroup"
+	"golang.org/x/time/rate"
+)
+
+// rateLimiter is the subset of limiter behavior the interceptor depends on; it exists mainly so tests can supply a fake.
+type rateLimiter interface {
+	Wait(ctx context.Context) (err error)
+	WaitN(ctx context.Context, n int) (err error)
+}
+
+// DefaultEC2LimitConfig holds the default refill rates and bucket sizes, based on the limits defined in
+// https://docs.aws.amazon.com/AWSEC2/latest/APIReference/throttling.html
+var DefaultEC2LimitConfig = &EC2LimitConfig{
+	MutatingRate:                    5,
+	MutatingRateBucket:              200,
+	NonMutatingRate:                 20,
+	NonMutatingRateBucket:           100,
+	RunInstanceRate:                 2,
+	RunInstanceRateBucket:           5,
+	RunInstanceResource:             2,
+	RunInstanceResourceBucket:       1000,
+	TerminateInstanceResource:       20,
+	TerminateInstanceResourceBucket: 1000,
+}
+
+// EC2LimitConfig contains the refill rate and bucket size for each limiter created by WithRateLimiter.
+type EC2LimitConfig struct {
+	// MutatingRate sets the refill rate of the limiter for mutating requests.
+	MutatingRate float64
+	// MutatingRateBucket sets the bucket size of the limiter for mutating requests.
+	MutatingRateBucket int
+	// NonMutatingRate sets the refill rate of the limiter for non-mutating requests.
+	NonMutatingRate float64
+	// NonMutatingRateBucket sets the bucket size of the limiter for non-mutating requests.
+	NonMutatingRateBucket int
+	// RunInstanceRate sets the refill rate of the request limiter for run instances calls.
+	RunInstanceRate float64
+	// RunInstanceRateBucket sets the bucket size of the request limiter for run instances calls.
+	RunInstanceRateBucket int
+	// RunInstanceResource sets the refill rate of the resource limiter for run instances calls.
+	RunInstanceResource float64
+	// RunInstanceResourceBucket sets the bucket size of the resource limiter for run instances calls.
+	RunInstanceResourceBucket int
+	// TerminateInstanceResource sets the refill rate of the resource limiter for terminate instances calls.
+	TerminateInstanceResource float64
+	// TerminateInstanceResourceBucket sets the bucket size of the resource limiter for terminate instances calls.
+	TerminateInstanceResourceBucket int
+}
+
+// WithRateLimiter returns an AWSOpt that wraps the client's current EC2 client in an EC2RateLimitInterceptor whose limiters are built from config. config must be non-nil; the option dereferences it when applied.
+func WithRateLimiter(config *EC2LimitConfig) AWSOpt {
+	return func(c *AWSClient) {
+		c.ec2Client = &EC2RateLimitInterceptor{
+			next:                      c.ec2Client,
+			mutatingRate:              rate.NewLimiter(rate.Limit(config.MutatingRate), config.MutatingRateBucket),
+			nonMutatingRate:           rate.NewLimiter(rate.Limit(config.NonMutatingRate), config.NonMutatingRateBucket),
+			runInstancesRate:          rate.NewLimiter(rate.Limit(config.RunInstanceRate), config.RunInstanceRateBucket),
+			runInstancesResource:      rate.NewLimiter(rate.Limit(config.RunInstanceResource), config.RunInstanceResourceBucket),
+			terminateInstanceResource: rate.NewLimiter(rate.Limit(config.TerminateInstanceResource), config.TerminateInstanceResourceBucket),
+		}
+	}
+}
+
+var _ vmClient = (*EC2RateLimitInterceptor)(nil)
+
+// EC2RateLimitInterceptor is a vmClient that waits on the matching rate limiters
+// before forwarding each call to the wrapped client.
+type EC2RateLimitInterceptor struct {
+	// next is the client called once the limiters have allowed the request.
+	next vmClient
+	// mutatingRate is the rate limiter for mutating requests; terminate instances calls wait on it.
+	mutatingRate rateLimiter
+	// nonMutatingRate is the rate limiter for non-mutating requests; the describe and waiter calls wait on it.
+	nonMutatingRate rateLimiter
+	// runInstancesRate is the rate limiter for run instances requests; one token is taken per call.
+	runInstancesRate rateLimiter
+	// runInstancesResource is the resource limiter for run instances; MaxCount tokens are taken per call.
+	runInstancesResource rateLimiter
+	// terminateInstanceResource is the resource limiter for terminate instances; one token is taken per instance ID.
+	terminateInstanceResource rateLimiter
+}
+
+// DescribeInstancesPagesWithContext takes one token from the non-mutating limiter and then forwards the call to the next client. The limiter returns an error if the request exceeds the bucket size, the Context is canceled, or the expected wait time exceeds the Context's Deadline.
+func (i *EC2RateLimitInterceptor) DescribeInstancesPagesWithContext(ctx context.Context, in *ec2.DescribeInstancesInput, fn func(*ec2.DescribeInstancesOutput, bool) bool, opts ...request.Option) error {
+	if waitErr := i.nonMutatingRate.Wait(ctx); waitErr != nil {
+		return waitErr
+	}
+	return i.next.DescribeInstancesPagesWithContext(ctx, in, fn, opts...)
+}
+
+// DescribeInstancesWithContext takes one token from the non-mutating limiter and then forwards the call to the next client. The limiter returns an error if the request exceeds the bucket size, the Context is canceled, or the expected wait time exceeds the Context's Deadline.
+func (i *EC2RateLimitInterceptor) DescribeInstancesWithContext(ctx context.Context, in *ec2.DescribeInstancesInput, opts ...request.Option) (*ec2.DescribeInstancesOutput, error) {
+	if waitErr := i.nonMutatingRate.Wait(ctx); waitErr != nil {
+		return nil, waitErr
+	}
+	return i.next.DescribeInstancesWithContext(ctx, in, opts...)
+}
+
+// RunInstancesWithContext waits concurrently on the run instances rate limiter (one token) and resource limiter (MaxCount tokens), then forwards the call to the next client using the caller's Context. A limiter returns an error if the request exceeds the bucket size, the Context is canceled, or the expected wait time exceeds the Context's Deadline. An error is returned if either the rate or resource limiter returns an error.
+func (i *EC2RateLimitInterceptor) RunInstancesWithContext(ctx context.Context, in *ec2.RunInstancesInput, opts ...request.Option) (*ec2.Reservation, error) {
+	// The errgroup context is canceled as soon as Wait returns, so it is only used for the limiters and never passed to next.
+	g, gctx := errgroup.WithContext(ctx)
+	g.Go(func() error {
+		return i.runInstancesRate.Wait(gctx)
+	})
+	g.Go(func() error {
+		numInst := aws.Int64Value(in.MaxCount)
+		if int64(int(numInst)) != numInst {
+			return errors.New("unable to convert max count to int")
+		}
+		return i.runInstancesResource.WaitN(gctx, int(numInst))
+	})
+	if err := g.Wait(); err != nil {
+		return nil, err
+	}
+	return i.next.RunInstancesWithContext(ctx, in, opts...)
+}
+
+// TerminateInstancesWithContext waits concurrently on the mutating rate limiter (one token) and the terminate resource limiter (one token per instance ID), then forwards the call to the next client using the caller's Context. A limiter returns an error if the request exceeds the bucket size, the Context is canceled, or the expected wait time exceeds the Context's Deadline. An error is returned if either the rate or resource limiter returns an error.
+func (i *EC2RateLimitInterceptor) TerminateInstancesWithContext(ctx context.Context, in *ec2.TerminateInstancesInput, opts ...request.Option) (*ec2.TerminateInstancesOutput, error) {
+	// The errgroup context is canceled as soon as Wait returns, so it is only used for the limiters and never passed to next.
+	g, gctx := errgroup.WithContext(ctx)
+	g.Go(func() error {
+		return i.mutatingRate.Wait(gctx)
+	})
+	g.Go(func() error {
+		return i.terminateInstanceResource.WaitN(gctx, len(in.InstanceIds))
+	})
+	if err := g.Wait(); err != nil {
+		return nil, err
+	}
+	return i.next.TerminateInstancesWithContext(ctx, in, opts...)
+}
+
+// WaitUntilInstanceRunningWithContext takes one token from the non-mutating limiter and then forwards the call to the next client. The limiter returns an error if the request exceeds the bucket size, the Context is canceled, or the expected wait time exceeds the Context's Deadline.
+func (i *EC2RateLimitInterceptor) WaitUntilInstanceRunningWithContext(ctx context.Context, in *ec2.DescribeInstancesInput, opts ...request.WaiterOption) error {
+	if waitErr := i.nonMutatingRate.Wait(ctx); waitErr != nil {
+		return waitErr
+	}
+	return i.next.WaitUntilInstanceRunningWithContext(ctx, in, opts...)
+}
+
+// DescribeInstanceTypesPagesWithContext takes one token from the non-mutating limiter and then forwards the call to the next client. The limiter returns an error if the request exceeds the bucket size, the Context is canceled, or the expected wait time exceeds the Context's Deadline.
+func (i *EC2RateLimitInterceptor) DescribeInstanceTypesPagesWithContext(ctx context.Context, in *ec2.DescribeInstanceTypesInput, fn func(*ec2.DescribeInstanceTypesOutput, bool) bool, opts ...request.Option) error {
+	if waitErr := i.nonMutatingRate.Wait(ctx); waitErr != nil {
+		return waitErr
+	}
+	return i.next.DescribeInstanceTypesPagesWithContext(ctx, in, fn, opts...)
+}
diff --git a/internal/cloud/aws_interceptor_test.go b/internal/cloud/aws_interceptor_test.go
new file mode 100644
index 0000000..d77aceb
--- /dev/null
+++ b/internal/cloud/aws_interceptor_test.go
@@ -0,0 +1,225 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cloud
+
+import (
+	"context"
+	"errors"
+	"sync/atomic"
+	"testing"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/request"
+	"github.com/aws/aws-sdk-go/service/ec2"
+)
+
+var _ rateLimiter = (*fakeRateLimiter)(nil)
+
+var rateExceededErr = errors.New("rate limit exceeded")
+
+// fakeRateLimiter is a rateLimiter that never blocks and fails once more than waitCallLimit tokens have been requested.
+type fakeRateLimiter struct {
+	waitCalledCount int64 // running total of requested tokens, accessed atomically
+	waitCallLimit   int64 // highest total that is still allowed
+}
+
+// newFakeRateLimiter returns a fakeRateLimiter that permits up to limit tokens in total.
+func newFakeRateLimiter(limit int64) *fakeRateLimiter {
+	return &fakeRateLimiter{waitCallLimit: limit}
+}
+
+func (frl *fakeRateLimiter) Wait(ctx context.Context) error {
+	return frl.WaitN(ctx, 1)
+}
+
+// WaitN adds n to the running total and returns rateExceededErr once the total exceeds the limit.
+func (frl *fakeRateLimiter) WaitN(ctx context.Context, n int) error {
+	if total := atomic.AddInt64(&frl.waitCalledCount, int64(n)); total > frl.waitCallLimit {
+		return rateExceededErr
+	}
+	return nil
+}
+
+// called reports whether any tokens have been requested.
+func (frl *fakeRateLimiter) called() bool {
+	return atomic.LoadInt64(&frl.waitCalledCount) > 0
+}
+
+type noopEC2Client struct {
+	t *testing.T // fails the test when an argument is not forwarded by the interceptor
+}
+
+func (f *noopEC2Client) DescribeInstancesPagesWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, fn func(*ec2.DescribeInstancesOutput, bool) bool, opt ...request.Option) error {
+	if ctx == nil || input == nil || fn == nil || len(opt) != 1 {
+		f.t.Fatal("DescribeInstancesPagesWithContext params not passed down")
+	}
+	return nil
+}
+
+func (f *noopEC2Client) DescribeInstancesWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, opt ...request.Option) (*ec2.DescribeInstancesOutput, error) {
+	if ctx == nil || input == nil || len(opt) != 1 {
+		f.t.Fatal("DescribeInstancesWithContext params not passed down")
+	}
+	return nil, nil
+}
+
+func (f *noopEC2Client) RunInstancesWithContext(ctx context.Context, input *ec2.RunInstancesInput, opts ...request.Option) (*ec2.Reservation, error) {
+	if ctx == nil || input == nil || len(opts) != 1 {
+		f.t.Fatal("RunInstancesWithContext params not passed down")
+	}
+	return nil, nil
+}
+
+func (f *noopEC2Client) TerminateInstancesWithContext(ctx context.Context, input *ec2.TerminateInstancesInput, opts ...request.Option) (*ec2.TerminateInstancesOutput, error) {
+	if ctx == nil || input == nil || len(opts) != 1 {
+		f.t.Fatal("TerminateInstancesWithContext params not passed down")
+	}
+	return nil, nil
+}
+
+func (f *noopEC2Client) WaitUntilInstanceRunningWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, opt ...request.WaiterOption) error {
+	if ctx == nil || input == nil || len(opt) != 1 {
+		f.t.Fatal("WaitUntilInstanceRunningWithContext params not passed down")
+	}
+	return nil
+}
+
+func (f *noopEC2Client) DescribeInstanceTypesPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypesInput, fn func(*ec2.DescribeInstanceTypesOutput, bool) bool, opt ...request.Option) error {
+	if ctx == nil || input == nil || fn == nil || len(opt) != 1 {
+		f.t.Fatal("DescribeInstanceTypesPagesWithContext params not passed down") // name this method, not DescribeInstancesPagesWithContext
+	}
+	return nil
+}
+
+func TestEC2RateLimitInterceptorDescribeInstancesPagesWithContext(t *testing.T) {
+	rate := newFakeRateLimiter(1) // the fake limiter permits a single token
+	i := &EC2RateLimitInterceptor{
+		next:            &noopEC2Client{t: t},
+		nonMutatingRate: rate,
+	}
+	fn := func() error {
+		return i.DescribeInstancesPagesWithContext(context.Background(), &ec2.DescribeInstancesInput{}, func(*ec2.DescribeInstancesOutput, bool) bool { return true }, request.WithAppendUserAgent("test-agent"))
+	}
+	if err := fn(); err != nil {
+		t.Fatalf("DescribeInstancesPagesWithContext(...) = %v; want no error", err)
+	}
+	if !rate.called() {
+		t.Error("rateLimiter.Wait() was never called")
+	}
+	if err := fn(); err != rateExceededErr { // the second call exceeds the limit; %v keeps a nil error readable
+		t.Errorf("DescribeInstancesPagesWithContext(...) = %v; want %v", err, rateExceededErr)
+	}
+}
+
+func TestEC2RateLimitInterceptorDescribeInstancesWithContext(t *testing.T) {
+	rate := newFakeRateLimiter(1) // the fake limiter permits a single token
+	i := &EC2RateLimitInterceptor{
+		next:            &noopEC2Client{t: t},
+		nonMutatingRate: rate,
+	}
+	fn := func() error {
+		_, err := i.DescribeInstancesWithContext(context.Background(), &ec2.DescribeInstancesInput{}, request.WithAppendUserAgent("test-agent"))
+		return err
+	}
+	if err := fn(); err != nil {
+		t.Fatalf("DescribeInstancesWithContext(...) = nil, %v; want no error", err)
+	}
+	if !rate.called() {
+		t.Error("rateLimiter.Wait() was never called")
+	}
+	if err := fn(); err != rateExceededErr { // the second call exceeds the limit; %v keeps a nil error readable
+		t.Errorf("DescribeInstancesWithContext(...) = nil, %v; want nil, %v", err, rateExceededErr)
+	}
+}
+
+func TestEC2RateLimitInterceptorRunInstancesWithContext(t *testing.T) {
+	rate := newFakeRateLimiter(1)
+	resource := newFakeRateLimiter(1) // each fake limiter permits a single token
+	i := &EC2RateLimitInterceptor{
+		next:                 &noopEC2Client{t: t},
+		runInstancesRate:     rate,
+		runInstancesResource: resource,
+	}
+	fn := func() error {
+		_, err := i.RunInstancesWithContext(context.Background(), &ec2.RunInstancesInput{
+			MaxCount: aws.Int64(1),
+		}, request.WithAppendUserAgent("test-agent"))
+		return err
+	}
+	if err := fn(); err != nil {
+		t.Fatalf("RunInstancesWithContext(...) = nil, %v; want no error", err)
+	}
+	if !rate.called() || !resource.called() {
+		t.Errorf("rateLimiter.Wait() was never called; rate=%t, resource=%t", rate.called(), resource.called())
+	}
+	if err := fn(); err != rateExceededErr { // the second call exceeds both limits; %v keeps a nil error readable
+		t.Errorf("RunInstancesWithContext(...) = nil, %v; want nil, %v", err, rateExceededErr)
+	}
+}
+
+func TestEC2RateLimitInterceptorTerminateInstancesWithContext(t *testing.T) {
+	rate := newFakeRateLimiter(1)
+	resource := newFakeRateLimiter(1) // each fake limiter permits a single token
+	i := &EC2RateLimitInterceptor{
+		next:                      &noopEC2Client{t: t},
+		mutatingRate:              rate,
+		terminateInstanceResource: resource,
+	}
+	fn := func() error {
+		_, err := i.TerminateInstancesWithContext(context.Background(), &ec2.TerminateInstancesInput{
+			InstanceIds: []*string{aws.String("foo")},
+		}, request.WithAppendUserAgent("test-agent"))
+		return err
+	}
+	if err := fn(); err != nil {
+		t.Fatalf("TerminateInstancesWithContext(...) = nil, %v; want no error", err)
+	}
+	if !rate.called() || !resource.called() {
+		t.Errorf("rateLimiter.Wait() was never called; rate=%t, resource=%t", rate.called(), resource.called())
+	}
+	if err := fn(); err != rateExceededErr { // the second call exceeds both limits; %v keeps a nil error readable
+		t.Errorf("TerminateInstancesWithContext(...) = nil, %v; want nil, %v", err, rateExceededErr)
+	}
+}
+
+func TestEC2RateLimitInterceptorWaitUntilInstanceRunningWithContext(t *testing.T) {
+	rate := newFakeRateLimiter(1) // the fake limiter permits a single token
+	i := &EC2RateLimitInterceptor{
+		next:            &noopEC2Client{t: t},
+		nonMutatingRate: rate,
+	}
+	fn := func() error {
+		return i.WaitUntilInstanceRunningWithContext(context.Background(), &ec2.DescribeInstancesInput{}, request.WithWaiterMaxAttempts(1))
+	}
+	if err := fn(); err != nil {
+		t.Fatalf("WaitUntilInstanceRunningWithContext(...) = %v; want no error", err)
+	}
+	if !rate.called() {
+		t.Error("rateLimiter.Wait() was never called")
+	}
+	if err := fn(); err != rateExceededErr { // the second call exceeds the limit; %v keeps a nil error readable
+		t.Errorf("WaitUntilInstanceRunningWithContext(...) = %v; want %v", err, rateExceededErr)
+	}
+}
+
+func TestEC2RateLimitInterceptorDescribeInstanceTypesPagesWithContext(t *testing.T) {
+	rate := newFakeRateLimiter(1) // the fake limiter permits a single token
+	i := &EC2RateLimitInterceptor{
+		next:            &noopEC2Client{t: t},
+		nonMutatingRate: rate,
+	}
+	fn := func() error {
+		return i.DescribeInstanceTypesPagesWithContext(context.Background(), &ec2.DescribeInstanceTypesInput{}, func(*ec2.DescribeInstanceTypesOutput, bool) bool { return true }, request.WithAppendUserAgent("test-agent"))
+	}
+	if err := fn(); err != nil {
+		t.Fatalf("DescribeInstanceTypesPagesWithContext(...) = %v; want no error", err)
+	}
+	if !rate.called() {
+		t.Error("rateLimiter.Wait() was never called")
+	}
+	if err := fn(); err != rateExceededErr { // the second call exceeds the limit; %v keeps a nil error readable
+		t.Errorf("DescribeInstanceTypesPagesWithContext(...) = %v; want %v", err, rateExceededErr)
+	}
+}