acme/autocert: let automatic renewal work with short lifetime certs

Fixes golang/go#64997
Fixes golang/go#36548

Change-Id: Idb7a426ad3bfa6ac3b796f4b466da6e3154f1ffa
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/719080
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Mark Freeman <markfreeman@google.com>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/acme/autocert/autocert.go b/acme/autocert/autocert.go
index ccd5b7e..cde9066 100644
--- a/acme/autocert/autocert.go
+++ b/acme/autocert/autocert.go
@@ -134,7 +134,8 @@
 	// RenewBefore optionally specifies how early certificates should
 	// be renewed before they expire.
 	//
-	// If zero, they're renewed 30 days before expiration.
+	// If zero, the lesser of 30 days or 1/3 of the certificate lifetime
+	// is used. Values greater than 30 days are capped at 30 days.
 	RenewBefore time.Duration
 
 	// Client is used to perform low-level operations, such as account registration
@@ -464,7 +465,7 @@
 		leaf: cert.Leaf,
 	}
 	m.state[ck] = s
-	m.startRenew(ck, s.key, s.leaf.NotAfter)
+	m.startRenew(ck, s.key, s.leaf.NotBefore, s.leaf.NotAfter)
 	return cert, nil
 }
 
@@ -610,7 +611,7 @@
 	}
 	state.cert = der
 	state.leaf = leaf
-	m.startRenew(ck, state.key, state.leaf.NotAfter)
+	m.startRenew(ck, state.key, state.leaf.NotBefore, state.leaf.NotAfter)
 	return state.tlscert()
 }
 
@@ -908,7 +909,7 @@
 //
 // The key argument is a certificate private key.
-// The exp argument is the cert expiration time (NotAfter).
+// The notBefore and notAfter arguments are the cert validity bounds.
-func (m *Manager) startRenew(ck certKey, key crypto.Signer, exp time.Time) {
+func (m *Manager) startRenew(ck certKey, key crypto.Signer, notBefore, notAfter time.Time) {
 	m.renewalMu.Lock()
 	defer m.renewalMu.Unlock()
 	if m.renewal[ck] != nil {
@@ -920,7 +921,7 @@
 	}
 	dr := &domainRenewal{m: m, ck: ck, key: key}
 	m.renewal[ck] = dr
-	dr.start(exp)
+	dr.start(notBefore, notAfter)
 }
 
 // stopRenew stops all currently running cert renewal timers.
@@ -1028,13 +1029,6 @@
 	return defaultHostPolicy
 }
 
-func (m *Manager) renewBefore() time.Duration {
-	if m.RenewBefore > renewJitter {
-		return m.RenewBefore
-	}
-	return 720 * time.Hour // 30 days
-}
-
 func (m *Manager) now() time.Time {
 	if m.nowFunc != nil {
 		return m.nowFunc()
diff --git a/acme/autocert/renewal.go b/acme/autocert/renewal.go
index 0df7da7..93984f3 100644
--- a/acme/autocert/renewal.go
+++ b/acme/autocert/renewal.go
@@ -11,9 +11,6 @@
 	"time"
 )
 
-// renewJitter is the maximum deviation from Manager.RenewBefore.
-const renewJitter = time.Hour
-
 // domainRenewal tracks the state used by the periodic timers
 // renewing a single domain's cert.
 type domainRenewal struct {
@@ -30,13 +27,13 @@
-// defined by the certificate expiration time exp.
+// defined by the certificate validity bounds notBefore and notAfter.
 //
 // If the timer is already started, calling start is a noop.
-func (dr *domainRenewal) start(exp time.Time) {
+func (dr *domainRenewal) start(notBefore, notAfter time.Time) {
 	dr.timerMu.Lock()
 	defer dr.timerMu.Unlock()
 	if dr.timer != nil {
 		return
 	}
-	dr.timer = time.AfterFunc(dr.next(exp), dr.renew)
+	dr.timer = time.AfterFunc(dr.next(notBefore, notAfter), dr.renew)
 }
 
 // stop stops the cert renewal timer and waits for any in-flight calls to renew
@@ -79,7 +76,7 @@
 	// TODO: rotate dr.key at some point?
 	next, err := dr.do(ctx)
 	if err != nil {
-		next = renewJitter / 2
+		next = time.Hour / 2
 		next += time.Duration(pseudoRand.int63n(int64(next)))
 	}
 	testDidRenewLoop(next, err)
@@ -107,8 +104,8 @@
 	// a race is likely unavoidable in a distributed environment
 	// but we try nonetheless
 	if tlscert, err := dr.m.cacheGet(ctx, dr.ck); err == nil {
-		next := dr.next(tlscert.Leaf.NotAfter)
-		if next > dr.m.renewBefore()+renewJitter {
+		next := dr.next(tlscert.Leaf.NotBefore, tlscert.Leaf.NotAfter)
+		if next > 0 {
 			signer, ok := tlscert.PrivateKey.(crypto.Signer)
 			if ok {
 				state := &certState{
@@ -139,18 +136,23 @@
 		return 0, err
 	}
 	dr.updateState(state)
-	return dr.next(leaf.NotAfter), nil
+	return dr.next(leaf.NotBefore, leaf.NotAfter), nil
 }
 
-func (dr *domainRenewal) next(expiry time.Time) time.Duration {
-	d := expiry.Sub(dr.m.now()) - dr.m.renewBefore()
-	// add a bit of randomness to renew deadline
-	n := pseudoRand.int63n(int64(renewJitter))
-	d -= time.Duration(n)
-	if d < 0 {
-		return 0
+// next returns how long to wait before starting the next renewal; 0 means now.
+// The renewal threshold is Manager.RenewBefore if set, else 1/3 of the cert
+// lifetime (never negative), capped at 30 days either way. Renewal is then
+// delayed by a random jitter of up to 10% of the threshold, at most 1 hour.
+func (dr *domainRenewal) next(notBefore, notAfter time.Time) time.Duration {
+	threshold := max(0, min(notAfter.Sub(notBefore)/3, 30*24*time.Hour))
+	if dr.m.RenewBefore > 0 {
+		threshold = min(dr.m.RenewBefore, 30*24*time.Hour)
 	}
-	return d
+	maxJitter := max(1, min(threshold/10, time.Hour)) // keep int63n's bound positive
+	jitter := pseudoRand.int63n(int64(maxJitter))
+	renewAt := notAfter.Add(-(threshold - time.Duration(jitter)))
+	renewWait := renewAt.Sub(dr.m.now())
+	return max(0, renewWait)
 }
 
 var testDidRenewLoop = func(next time.Duration, err error) {}
diff --git a/acme/autocert/renewal_test.go b/acme/autocert/renewal_test.go
index ffe4af2..67e2da2 100644
--- a/acme/autocert/renewal_test.go
+++ b/acme/autocert/renewal_test.go
@@ -17,27 +17,60 @@
 
 func TestRenewalNext(t *testing.T) {
 	now := time.Now()
-	man := &Manager{
-		RenewBefore: 7 * 24 * time.Hour,
-		nowFunc:     func() time.Time { return now },
-	}
-	defer man.stopRenew()
+	nowFn := func() time.Time { return now }
 	tt := []struct {
-		expiry   time.Time
-		min, max time.Duration
+		name        string
+		renewBefore time.Duration // arg to Manager
+		// leaf cert validity
+		notBefore time.Time
+		validFor  time.Duration
+		// wait time
+		waitMin, waitMax time.Duration
 	}{
-		{now.Add(90 * 24 * time.Hour), 83*24*time.Hour - renewJitter, 83 * 24 * time.Hour},
-		{now.Add(time.Hour), 0, 1},
-		{now, 0, 1},
-		{now.Add(-time.Hour), 0, 1},
+		{"default renewal, 1h cert, valid",
+			0, now, time.Hour, 40 * time.Minute, 50 * time.Minute},
+		{"default renewal, 1h cert, should renew",
+			0, now.Add(-50 * time.Minute), time.Hour, 0, 0},
+		{"default renewal, 1h cert, expired",
+			0, now.Add(-400 * 24 * time.Hour), time.Hour, 0, 0},
+		{"default renewal, 6d cert, valid",
+			0, now, 6 * 24 * time.Hour, 4 * 24 * time.Hour, (4*24 + 1) * time.Hour},
+		{"default renewal, 6d cert, should renew",
+			0, now.Add(-5 * 24 * time.Hour), 6 * 24 * time.Hour, 0, 0},
+		{"default renewal, 6d cert, expired",
+			0, now.Add(-400 * 24 * time.Hour), 6 * 24 * time.Hour, 0, 0},
+		{"default renewal, 90d cert, valid",
+			0, now, 90 * 24 * time.Hour, 60 * 24 * time.Hour, (60*24 + 1) * time.Hour},
+		{"default renewal, 90d cert, should renew",
+			0, now.Add(-70 * 24 * time.Hour), 90 * 24 * time.Hour, 0, 0},
+		{"default renewal, 90d cert, expired",
+			0, now.Add(-400 * 24 * time.Hour), 90 * 24 * time.Hour, 0, 0},
+		{"default renewal, 398d cert, valid",
+			0, now, 398 * 24 * time.Hour, (368 * 24) * time.Hour, (368*24 + 1) * time.Hour},
+		{"default renewal, 398d cert, should renew",
+			0, now.Add(-378 * 24 * time.Hour), 398 * 24 * time.Hour, 0, 0},
+		{"default renewal, 398d cert, expired",
+			0, now.Add(-400 * 24 * time.Hour), 398 * 24 * time.Hour, 0, 0},
+		{"7d renewal, 90d cert, valid",
+			7 * 24 * time.Hour, now, 90 * 24 * time.Hour, 83 * 24 * time.Hour, (83*24 + 1) * time.Hour},
+		{"7d renewal, 90d cert, should not renew",
+			7 * 24 * time.Hour, now.Add(-70 * 24 * time.Hour), 90 * 24 * time.Hour, 13 * 24 * time.Hour, (13*24 + 1) * time.Hour},
+		{"7d renewal, 90d cert, should renew",
+			7 * 24 * time.Hour, now.Add(-85 * 24 * time.Hour), 90 * 24 * time.Hour, 0, 0},
+		{"7d renewal, 90d cert, expired",
+			7 * 24 * time.Hour, now.Add(-400 * 24 * time.Hour), 90 * 24 * time.Hour, 0, 0},
 	}
 
-	dr := &domainRenewal{m: man}
-	for i, test := range tt {
-		next := dr.next(test.expiry)
-		if next < test.min || test.max < next {
-			t.Errorf("%d: next = %v; want between %v and %v", i, next, test.min, test.max)
-		}
+	for _, test := range tt {
+		t.Run(test.name, func(t *testing.T) {
+			dr := &domainRenewal{m: &Manager{RenewBefore: test.renewBefore, nowFunc: nowFn}}
+			defer dr.m.stopRenew()
+
+			next := dr.next(test.notBefore, test.notBefore.Add(test.validFor))
+			if next < test.waitMin || next > test.waitMax {
+				t.Errorf("expected wait time: %v <= %v <= %v", test.waitMin, next, test.waitMax)
+			}
+		})
 	}
 }
 
@@ -239,7 +272,7 @@
 	}
 
 	// trigger renew
-	man.startRenew(exampleCertKey, s.key, s.leaf.NotAfter)
+	man.startRenew(exampleCertKey, s.key, s.leaf.NotBefore, s.leaf.NotAfter)
 	<-renewed
 	func() {
 		man.renewalMu.Lock()