x/time/rate: handle zero limit correctly
Fixes golang/go#39984
Change-Id: Iee82550fd6f141b22afcf96aea41ec2ff3e98e9a
Reviewed-on: https://go-review.googlesource.com/c/time/+/323429
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Trust: Carlos Amedee <carlos@golang.org>
diff --git a/rate/rate.go b/rate/rate.go
index 0cfcc84..b0b982e 100644
--- a/rate/rate.go
+++ b/rate/rate.go
@@ -306,15 +306,27 @@
// reserveN returns Reservation, not *Reservation, to avoid allocation in AllowN and WaitN.
func (lim *Limiter) reserveN(now time.Time, n int, maxFutureReserve time.Duration) Reservation {
lim.mu.Lock()
+ defer lim.mu.Unlock()
if lim.limit == Inf {
- lim.mu.Unlock()
return Reservation{
ok: true,
lim: lim,
tokens: n,
timeToAct: now,
}
+ } else if lim.limit == 0 {
+ var ok bool
+ if lim.burst >= n {
+ ok = true
+ lim.burst -= n
+ }
+ return Reservation{
+ ok: ok,
+ lim: lim,
+ tokens: lim.burst,
+ timeToAct: now,
+ }
}
now, last, tokens := lim.advance(now)
@@ -351,7 +363,6 @@
lim.last = last
}
- lim.mu.Unlock()
return r
}
@@ -377,6 +388,9 @@
// durationFromTokens is a unit conversion function from the number of tokens to the duration
// of time it takes to accumulate them at a rate of limit tokens per second.
func (limit Limit) durationFromTokens(tokens float64) time.Duration {
+ if limit <= 0 {
+ return InfDuration
+ }
seconds := tokens / float64(limit)
return time.Duration(float64(time.Second) * seconds)
}
@@ -384,5 +398,8 @@
// tokensFromDuration is a unit conversion function from a time duration to the number of tokens
// which could be accumulated during that duration at a rate of limit tokens per second.
func (limit Limit) tokensFromDuration(d time.Duration) float64 {
+ if limit <= 0 {
+ return 0
+ }
return d.Seconds() * float64(limit)
}
diff --git a/rate/rate_test.go b/rate/rate_test.go
index 1c5e9e7..93a3378 100644
--- a/rate/rate_test.go
+++ b/rate/rate_test.go
@@ -480,3 +480,13 @@
lim.WaitN(ctx, 1)
}
}
+
+func TestZeroLimit(t *testing.T) {
+ r := NewLimiter(0, 1)
+ if !r.Allow() {
+ t.Errorf("Limit(0, 1) want true when first used")
+ }
+ if r.Allow() {
+ t.Errorf("Limit(0, 1) want false when already used")
+ }
+}