acme/autocert: fix races in renewal tests

TestRenewFromCache and TestRenewFromCacheAlreadyRenewed had several
races and API misuses:

1. They called t.Fatalf from a goroutine other than the one invoking
   the Test function, which is explicitly disallowed (see
   https://pkg.go.dev/testing#T).

2. The test did not stop the renewal timers prior to restoring
   test-hook functions, and the process of stopping the renewal timers
   itself did not wait for in-flight calls to complete. That could
   cause data races if one of the renewals failed and triggered a
   retry with a short-enough randomized backoff.
   (One such race was observed in
   https://build.golang.org/log/1a19e22ad826bedeb5a939c6130f368f9979208a.)

3. The testDidRenewLoop hooks accessed the Manager.renewal field
   without locking the Mutex guarding that field.

4. TestGetCertificate_failedAttempt set a testDidRemoveState hook, but
   didn't wait for the timers referring to that hook to complete
   before restoring it, causing races with other timers. I tried
   pulling on that thread a bit, but couldn't untangle the numerous
   untracked goroutines in the package. Instead, I have made a smaller
   and more local change to copy the value of testDidRemoveState into
   a local variable in the timer's closure.

Given the number of untracked goroutines in this package, it is likely
that races and/or deadlocks remain. Notably, so far I have been unable
to spot the actual cause of golang/go#51080.

For golang/go#51080

Change-Id: I7797f6ac34ef3c272f16ca805251dac3aa7f0009
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/384594
Trust: Bryan Mills <bcmills@google.com>
Run-TryBot: Bryan Mills <bcmills@google.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/acme/autocert/autocert.go b/acme/autocert/autocert.go
index ca558e7..1858184 100644
--- a/acme/autocert/autocert.go
+++ b/acme/autocert/autocert.go
@@ -458,7 +458,7 @@
 		leaf: cert.Leaf,
 	}
 	m.state[ck] = s
-	go m.renew(ck, s.key, s.leaf.NotAfter)
+	go m.startRenew(ck, s.key, s.leaf.NotAfter)
 	return cert, nil
 }
 
@@ -584,8 +584,9 @@
 	if err != nil {
 		// Remove the failed state after some time,
 		// making the manager call createCert again on the following TLS hello.
+		didRemove := testDidRemoveState // The lifetime of this timer is untracked, so copy mutable local state to avoid races.
 		time.AfterFunc(createCertRetryAfter, func() {
-			defer testDidRemoveState(ck)
+			defer didRemove(ck)
 			m.stateMu.Lock()
 			defer m.stateMu.Unlock()
 			// Verify the state hasn't changed and it's still invalid
@@ -603,7 +604,7 @@
 	}
 	state.cert = der
 	state.leaf = leaf
-	go m.renew(ck, state.key, state.leaf.NotAfter)
+	go m.startRenew(ck, state.key, state.leaf.NotAfter)
 	return state.tlscert()
 }
 
@@ -893,7 +894,7 @@
 	return path.Base(tokenPath) + "+http-01"
 }
 
-// renew starts a cert renewal timer loop, one per domain.
+// startRenew starts a cert renewal timer loop, one per domain.
 //
 // The loop is scheduled in two cases:
 // - a cert was fetched from cache for the first time (wasn't in m.state)
@@ -901,7 +902,7 @@
 //
 // The key argument is a certificate private key.
 // The exp argument is the cert expiration time (NotAfter).
-func (m *Manager) renew(ck certKey, key crypto.Signer, exp time.Time) {
+func (m *Manager) startRenew(ck certKey, key crypto.Signer, exp time.Time) {
 	m.renewalMu.Lock()
 	defer m.renewalMu.Unlock()
 	if m.renewal[ck] != nil {
diff --git a/acme/autocert/renewal.go b/acme/autocert/renewal.go
index 665f870..0df7da7 100644
--- a/acme/autocert/renewal.go
+++ b/acme/autocert/renewal.go
@@ -21,8 +21,9 @@
 	ck  certKey
 	key crypto.Signer
 
-	timerMu sync.Mutex
-	timer   *time.Timer
+	timerMu    sync.Mutex
+	timer      *time.Timer
+	timerClose chan struct{} // if non-nil, renew closes this channel (and nils out the timer fields) instead of running
 }
 
 // start starts a cert renewal timer at the time
@@ -38,16 +39,28 @@
 	dr.timer = time.AfterFunc(dr.next(exp), dr.renew)
 }
 
-// stop stops the cert renewal timer.
-// If the timer is already stopped, calling stop is a noop.
+// stop stops the cert renewal timer and waits for any in-flight calls to renew
+// to complete. If the timer is already stopped, calling stop is a noop.
 func (dr *domainRenewal) stop() {
 	dr.timerMu.Lock()
 	defer dr.timerMu.Unlock()
-	if dr.timer == nil {
-		return
+	for {
+		if dr.timer == nil {
+			return
+		}
+		if dr.timer.Stop() {
+			dr.timer = nil
+			return
+		} else {
+			// dr.timer fired, and we acquired dr.timerMu before the renew callback did.
+			// (We know this because otherwise the renew callback would have reset dr.timer!)
+			timerClose := make(chan struct{})
+			dr.timerClose = timerClose
+			dr.timerMu.Unlock()
+			<-timerClose
+			dr.timerMu.Lock()
+		}
 	}
-	dr.timer.Stop()
-	dr.timer = nil
 }
 
 // renew is called periodically by a timer.
@@ -55,7 +68,9 @@
 func (dr *domainRenewal) renew() {
 	dr.timerMu.Lock()
 	defer dr.timerMu.Unlock()
-	if dr.timer == nil {
+	if dr.timerClose != nil {
+		close(dr.timerClose)
+		dr.timer, dr.timerClose = nil, nil
 		return
 	}
 
@@ -67,8 +82,8 @@
 		next = renewJitter / 2
 		next += time.Duration(pseudoRand.int63n(int64(next)))
 	}
-	dr.timer = time.AfterFunc(next, dr.renew)
 	testDidRenewLoop(next, err)
+	dr.timer = time.AfterFunc(next, dr.renew)
 }
 
 // updateState locks and replaces the relevant Manager.state item with the given
diff --git a/acme/autocert/renewal_test.go b/acme/autocert/renewal_test.go
index e5f48ff..ffe4af2 100644
--- a/acme/autocert/renewal_test.go
+++ b/acme/autocert/renewal_test.go
@@ -61,11 +61,23 @@
 
 	// verify the renewal happened
 	defer func() {
+		// Stop the timers that read and execute testDidRenewLoop before restoring it.
+		// Otherwise the timer callback may race with the deferred write.
+		man.stopRenew()
 		testDidRenewLoop = func(next time.Duration, err error) {}
 	}()
-	done := make(chan struct{})
+	renewed := make(chan bool, 1)
 	testDidRenewLoop = func(next time.Duration, err error) {
-		defer close(done)
+		defer func() {
+			select {
+			case renewed <- true:
+			default:
+				// The renewal timer uses a random backoff. If the first renewal fails for
+				// some reason, we could end up with multiple calls here before the test
+				// stops the timer.
+			}
+		}()
+
 		if err != nil {
 			t.Errorf("testDidRenewLoop: %v", err)
 		}
@@ -81,7 +93,8 @@
 		after := time.Now().Add(future)
 		tlscert, err := man.cacheGet(context.Background(), exampleCertKey)
 		if err != nil {
-			t.Fatalf("man.cacheGet: %v", err)
+			t.Errorf("man.cacheGet: %v", err)
+			return
 		}
 		if !tlscert.Leaf.NotAfter.After(after) {
 			t.Errorf("cache leaf.NotAfter = %v; want > %v", tlscert.Leaf.NotAfter, after)
@@ -92,11 +105,13 @@
 		defer man.stateMu.Unlock()
 		s := man.state[exampleCertKey]
 		if s == nil {
-			t.Fatalf("m.state[%q] is nil", exampleCertKey)
+			t.Errorf("m.state[%q] is nil", exampleCertKey)
+			return
 		}
 		tlscert, err = s.tlscert()
 		if err != nil {
-			t.Fatalf("s.tlscert: %v", err)
+			t.Errorf("s.tlscert: %v", err)
+			return
 		}
 		if !tlscert.Leaf.NotAfter.After(after) {
 			t.Errorf("state leaf.NotAfter = %v; want > %v", tlscert.Leaf.NotAfter, after)
@@ -108,13 +123,7 @@
 	if _, err := man.GetCertificate(hello); err != nil {
 		t.Fatal(err)
 	}
-
-	// wait for renew loop
-	select {
-	case <-time.After(10 * time.Second):
-		t.Fatal("renew took too long to occur")
-	case <-done:
-	}
+	<-renewed
 }
 
 func TestRenewFromCacheAlreadyRenewed(t *testing.T) {
@@ -159,11 +168,23 @@
 
 	// verify the renewal accepted the newer cached cert
 	defer func() {
+		// Stop the timers that read and execute testDidRenewLoop before restoring it.
+		// Otherwise the timer callback may race with the deferred write.
+		man.stopRenew()
 		testDidRenewLoop = func(next time.Duration, err error) {}
 	}()
-	done := make(chan struct{})
+	renewed := make(chan bool, 1)
 	testDidRenewLoop = func(next time.Duration, err error) {
-		defer close(done)
+		defer func() {
+			select {
+			case renewed <- true:
+			default:
+				// The renewal timer uses a random backoff. If the first renewal fails for
+				// some reason, we could end up with multiple calls here before the test
+				// stops the timer.
+			}
+		}()
+
 		if err != nil {
 			t.Errorf("testDidRenewLoop: %v", err)
 		}
@@ -177,7 +198,8 @@
 		// ensure the cached cert was not modified
 		tlscert, err := man.cacheGet(context.Background(), exampleCertKey)
 		if err != nil {
-			t.Fatalf("man.cacheGet: %v", err)
+			t.Errorf("man.cacheGet: %v", err)
+			return
 		}
 		if !tlscert.Leaf.NotAfter.Equal(newLeaf.NotAfter) {
 			t.Errorf("cache leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newLeaf.NotAfter)
@@ -188,30 +210,22 @@
 		defer man.stateMu.Unlock()
 		s := man.state[exampleCertKey]
 		if s == nil {
-			t.Fatalf("m.state[%q] is nil", exampleCertKey)
+			t.Errorf("m.state[%q] is nil", exampleCertKey)
+			return
 		}
 		stateKey := s.key.Public().(*ecdsa.PublicKey)
 		if !stateKey.Equal(newLeaf.PublicKey) {
-			t.Fatal("state key was not updated from cache")
+			t.Error("state key was not updated from cache")
+			return
 		}
 		tlscert, err = s.tlscert()
 		if err != nil {
-			t.Fatalf("s.tlscert: %v", err)
+			t.Errorf("s.tlscert: %v", err)
+			return
 		}
 		if !tlscert.Leaf.NotAfter.Equal(newLeaf.NotAfter) {
 			t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newLeaf.NotAfter)
 		}
-
-		// verify the private key is replaced in the renewal state
-		r := man.renewal[exampleCertKey]
-		if r == nil {
-			t.Fatalf("m.renewal[%q] is nil", exampleCertKey)
-		}
-		renewalKey := r.key.Public().(*ecdsa.PublicKey)
-		if !renewalKey.Equal(newLeaf.PublicKey) {
-			t.Fatal("renewal private key was not updated from cache")
-		}
-
 	}
 
 	// assert the expiring cert is returned from state
@@ -225,21 +239,31 @@
 	}
 
 	// trigger renew
-	go man.renew(exampleCertKey, s.key, s.leaf.NotAfter)
+	man.startRenew(exampleCertKey, s.key, s.leaf.NotAfter)
+	<-renewed
+	func() {
+		man.renewalMu.Lock()
+		defer man.renewalMu.Unlock()
 
-	// wait for renew loop
-	select {
-	case <-time.After(10 * time.Second):
-		t.Fatal("renew took too long to occur")
-	case <-done:
-		// assert the new cert is returned from state after renew
-		hello := clientHelloInfo(exampleDomain, algECDSA)
-		tlscert, err := man.GetCertificate(hello)
-		if err != nil {
-			t.Fatal(err)
+		// verify the private key is replaced in the renewal state
+		r := man.renewal[exampleCertKey]
+		if r == nil {
+			t.Errorf("m.renewal[%q] is nil", exampleCertKey)
+			return
 		}
-		if !newLeaf.NotAfter.Equal(tlscert.Leaf.NotAfter) {
-			t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newLeaf.NotAfter)
+		renewalKey := r.key.Public().(*ecdsa.PublicKey)
+		if !renewalKey.Equal(newLeaf.PublicKey) {
+			t.Error("renewal private key was not updated from cache")
 		}
+	}()
+
+	// assert the new cert is returned from state after renew
+	hello = clientHelloInfo(exampleDomain, algECDSA)
+	tlscert, err = man.GetCertificate(hello)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !newLeaf.NotAfter.Equal(tlscert.Leaf.NotAfter) {
+		t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newLeaf.NotAfter)
 	}
 }