blob: ddd0987c7b6c857f5395b590a6b34ac9054e254b [file] [log] [blame]
Carlos Amedee725674f2020-08-21 11:31:10 -04001// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package cloud
6
7import (
8 "context"
9 "errors"
10 "sync/atomic"
11 "testing"
12
13 "github.com/aws/aws-sdk-go/aws"
14 "github.com/aws/aws-sdk-go/aws/request"
15 "github.com/aws/aws-sdk-go/service/ec2"
16)
17
18var _ rateLimiter = (*fakeRateLimiter)(nil)
19
// rateExceededErr is the error fakeRateLimiter reports once its call budget
// has been used up. Tests compare against it to confirm the interceptor
// propagates limiter failures.
var rateExceededErr = errors.New("rate limit exceeded")

// fakeRateLimiter is a test double for rateLimiter. Instead of delaying, it
// keeps a running total of the tokens requested and fails every request that
// pushes that total past waitCallLimit.
type fakeRateLimiter struct {
	waitCalledCount int64 // running total of tokens requested; accessed atomically
	waitCallLimit   int64 // largest total that still succeeds; set at construction
}

// newFakeRateLimiter returns a fakeRateLimiter that allows a cumulative total
// of limit tokens before WaitN starts returning rateExceededErr.
func newFakeRateLimiter(limit int64) *fakeRateLimiter {
	frl := new(fakeRateLimiter)
	frl.waitCallLimit = limit
	return frl
}

// Wait requests a single token; it is equivalent to WaitN(ctx, 1).
func (frl *fakeRateLimiter) Wait(ctx context.Context) (err error) {
	err = frl.WaitN(ctx, 1)
	return err
}

// WaitN adds n to the running total and returns rateExceededErr if the new
// total exceeds the limit. The total is incremented even when the call fails.
// ctx is not consulted: the fake never blocks.
func (frl *fakeRateLimiter) WaitN(ctx context.Context, n int) (err error) {
	total := atomic.AddInt64(&frl.waitCalledCount, int64(n))
	if total > frl.waitCallLimit {
		err = rateExceededErr
	}
	return err
}

// called reports whether the running token total is positive, i.e. in
// practice whether Wait or WaitN has been invoked.
func (frl *fakeRateLimiter) called() bool {
	return atomic.LoadInt64(&frl.waitCalledCount) > 0
}
49
// noopEC2Client is the stub client placed in EC2RateLimitInterceptor.next by
// these tests. Its methods check that the expected arguments were passed
// down, fail the test through t if not, and otherwise return zero values.
type noopEC2Client struct {
	t *testing.T // test to fail when arguments are not passed down
}
53
54func (f *noopEC2Client) DescribeInstancesPagesWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, fn func(*ec2.DescribeInstancesOutput, bool) bool, opt ...request.Option) error {
55 if ctx == nil || input == nil || fn == nil || len(opt) != 1 {
56 f.t.Fatal("DescribeInstancesPagesWithContext params not passed down")
57 }
58 return nil
59}
60
61func (f *noopEC2Client) DescribeInstancesWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, opt ...request.Option) (*ec2.DescribeInstancesOutput, error) {
62 if ctx == nil || input == nil || len(opt) != 1 {
63 f.t.Fatal("DescribeInstancesWithContext params not passed down")
64 }
65 return nil, nil
66}
67
68func (f *noopEC2Client) RunInstancesWithContext(ctx context.Context, input *ec2.RunInstancesInput, opts ...request.Option) (*ec2.Reservation, error) {
69 if ctx == nil || input == nil || len(opts) != 1 {
70 f.t.Fatal("RunInstancesWithContext params not passed down")
71 }
Carlos Amedeefb867d22020-11-10 00:09:57 -050072 if ctx.Err() != nil {
73 f.t.Fatalf("context.Err() = %s; want no error", ctx.Err())
74 }
Carlos Amedee725674f2020-08-21 11:31:10 -040075 return nil, nil
76}
77
78func (f *noopEC2Client) TerminateInstancesWithContext(ctx context.Context, input *ec2.TerminateInstancesInput, opts ...request.Option) (*ec2.TerminateInstancesOutput, error) {
79 if ctx == nil || input == nil || len(opts) != 1 {
80 f.t.Fatal("TerminateInstancesWithContext params not passed down")
81 }
Carlos Amedeefb867d22020-11-10 00:09:57 -050082 if ctx.Err() != nil {
83 f.t.Fatalf("context.Err() = %s; want no error", ctx.Err())
84 }
Carlos Amedee725674f2020-08-21 11:31:10 -040085 return nil, nil
86}
87
88func (f *noopEC2Client) WaitUntilInstanceRunningWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, opt ...request.WaiterOption) error {
89 if ctx == nil || input == nil || len(opt) != 1 {
90 f.t.Fatal("WaitUntilInstanceRunningWithContext params not passed down")
91 }
92 return nil
93}
94
95func (f *noopEC2Client) DescribeInstanceTypesPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypesInput, fn func(*ec2.DescribeInstanceTypesOutput, bool) bool, opt ...request.Option) error {
96 if ctx == nil || input == nil || fn == nil || len(opt) != 1 {
97 f.t.Fatal("DescribeInstancesPagesWithContext params not passed down")
98 }
99 return nil
100}
101
102func TestEC2RateLimitInterceptorDescribeInstancesPagesWithContext(t *testing.T) {
103 rate := newFakeRateLimiter(1)
104 i := &EC2RateLimitInterceptor{
105 next: &noopEC2Client{t: t},
106 nonMutatingRate: rate,
107 }
108 fn := func() error {
109 return i.DescribeInstancesPagesWithContext(context.Background(), &ec2.DescribeInstancesInput{}, func(*ec2.DescribeInstancesOutput, bool) bool { return true }, request.WithAppendUserAgent("test-agent"))
110 }
111 if err := fn(); err != nil {
112 t.Fatalf("DescribeInstancesPagesWithContext(...) = nil, %s; want no error", err)
113 }
114 if !rate.called() {
115 t.Error("rateLimiter.Wait() was never called")
116 }
117 if err := fn(); err != rateExceededErr {
118 t.Errorf("DescribeInstancesPagesWithContext(...) = %s; want %s", err, rateExceededErr)
119 }
120}
121
122func TestEC2RateLimitInterceptorDescribeInstancesWithContext(t *testing.T) {
123 rate := newFakeRateLimiter(1)
124 i := &EC2RateLimitInterceptor{
125 next: &noopEC2Client{t: t},
126 nonMutatingRate: rate,
127 }
128 fn := func() error {
129 _, err := i.DescribeInstancesWithContext(context.Background(), &ec2.DescribeInstancesInput{}, request.WithAppendUserAgent("test-agent"))
130 return err
131 }
132 if err := fn(); err != nil {
133 t.Fatalf("DescribeInstancesWithContext(...) = nil, %s; want no error", err)
134 }
135 if !rate.called() {
136 t.Errorf("rateLimiter.Wait() was never called")
137 }
138 if err := fn(); err != rateExceededErr {
139 t.Errorf("DescribeInstancesWithContext(...) = nil, %s; want nil, %s", err, rateExceededErr)
140 }
141}
142
143func TestEC2RateLimitInterceptorRunInstancesWithContext(t *testing.T) {
144 rate := newFakeRateLimiter(1)
145 resource := newFakeRateLimiter(1)
146 i := &EC2RateLimitInterceptor{
147 next: &noopEC2Client{t: t},
148 runInstancesRate: rate,
149 runInstancesResource: resource,
150 }
151 fn := func() error {
152 _, err := i.RunInstancesWithContext(context.Background(), &ec2.RunInstancesInput{
153 MaxCount: aws.Int64(1),
154 }, request.WithAppendUserAgent("test-agent"))
155 return err
156 }
157 if err := fn(); err != nil {
158 t.Fatalf("RunInstancesWithContext(...) = nil, %s; want no error", err)
159 }
160 if !rate.called() || !resource.called() {
161 t.Errorf("rateLimiter.Wait() was never called; rate=%t, resource=%t", rate.called(), resource.called())
162 }
163 if err := fn(); err != rateExceededErr {
164 t.Errorf("RunInstancesWithContext(...) = nil, %s; want nil, %s", err, rateExceededErr)
165 }
166}
167
168func TestEC2RateLimitInterceptorTerminateInstancesWithContext(t *testing.T) {
169 rate := newFakeRateLimiter(1)
170 resource := newFakeRateLimiter(1)
171 i := &EC2RateLimitInterceptor{
172 next: &noopEC2Client{t: t},
173 mutatingRate: rate,
174 terminateInstanceResource: resource,
175 }
176 fn := func() error {
177 _, err := i.TerminateInstancesWithContext(context.Background(), &ec2.TerminateInstancesInput{
178 InstanceIds: []*string{aws.String("foo")},
179 }, request.WithAppendUserAgent("test-agent"))
180 return err
181 }
182 if err := fn(); err != nil {
183 t.Fatalf("TerminateInstancesWithContext(...) = nil, %s; want no error", err)
184 }
185 if !rate.called() || !resource.called() {
186 t.Errorf("rateLimiter.Wait() was never called; rate=%t, resource=%t", rate.called(), resource.called())
187 }
188 if err := fn(); err != rateExceededErr {
189 t.Errorf("TerminateInstancesWithContext(...) = nil, %s; want nil, %s", err, rateExceededErr)
190 }
191}
192
193func TestEC2RateLimitInterceptorWaitUntilInstanceRunningWithContext(t *testing.T) {
194 rate := newFakeRateLimiter(1)
195 i := &EC2RateLimitInterceptor{
196 next: &noopEC2Client{t: t},
197 nonMutatingRate: rate,
198 }
199 fn := func() error {
200 return i.WaitUntilInstanceRunningWithContext(context.Background(), &ec2.DescribeInstancesInput{}, request.WithWaiterMaxAttempts(1))
201 }
202 if err := fn(); err != nil {
203 t.Fatalf("WaitUntilInstanceRunningWithContext(...) = nil, %s; want no error", err)
204 }
205 if !rate.called() {
206 t.Errorf("rateLimiter.Wait() was never called")
207 }
208 if err := fn(); err != rateExceededErr {
209 t.Errorf("WaitUntilInstanceRunningWithContext(...) = nil, %s; want nil, %s", err, rateExceededErr)
210 }
211}
212
213func TestEC2RateLimitInterceptorDescribeInstanceTypesPagesWithContext(t *testing.T) {
214 rate := newFakeRateLimiter(1)
215 i := &EC2RateLimitInterceptor{
216 next: &noopEC2Client{t: t},
217 nonMutatingRate: rate,
218 }
219 fn := func() error {
220 return i.DescribeInstanceTypesPagesWithContext(context.Background(), &ec2.DescribeInstanceTypesInput{}, func(*ec2.DescribeInstanceTypesOutput, bool) bool { return true }, request.WithAppendUserAgent("test-agent"))
221 }
222 if err := fn(); err != nil {
223 t.Fatalf("DescribeInstanceTypesPagesWithContext(...) = nil, %s; want no error", err)
224 }
225 if !rate.called() {
226 t.Errorf("rateLimiter.Wait() was never called")
227 }
228 if err := fn(); err != rateExceededErr {
229 t.Errorf("DescribeInstanceTypesPagesWithContext(...) = nil, %s; want nil, %s", err, rateExceededErr)
230 }
231}