quic: don't send CONNECTION_CLOSE after stateless reset

After receiving a stateless reset, we must enter the draining
state and send no further packets (including CONNECTION_CLOSE).
We were sending one last CONNECTION_CLOSE after the user
closed the Conn; fix this.

RFC 9000, Section 10.3.1.

Change-Id: I6a9cc6019470a25476df518022a32eefe0c50fcd
Reviewed-on: https://go-review.googlesource.com/c/net/+/540117
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/quic/conn_close.go b/internal/quic/conn_close.go
index daf425b..a9ef0db 100644
--- a/internal/quic/conn_close.go
+++ b/internal/quic/conn_close.go
@@ -62,7 +62,7 @@
 	c.lifetime.drainEndTime = time.Time{}
 	if c.lifetime.finalErr == nil {
 		// The peer never responded to our CONNECTION_CLOSE.
-		c.enterDraining(errNoPeerResponse)
+		c.enterDraining(now, errNoPeerResponse)
 	}
 	return true
 }
@@ -152,11 +152,17 @@
 }
 
 // enterDraining enters the draining state.
-func (c *Conn) enterDraining(err error) {
+func (c *Conn) enterDraining(now time.Time, err error) {
 	if c.isDraining() {
 		return
 	}
-	if e, ok := c.lifetime.localErr.(localTransportError); ok && e.code != errNo {
+	if err == errStatelessReset {
+		// If we've received a stateless reset, then we must not send a CONNECTION_CLOSE.
+		// Setting connCloseSentTime here prevents us from doing so.
+		c.lifetime.finalErr = errStatelessReset
+		c.lifetime.localErr = errStatelessReset
+		c.lifetime.connCloseSentTime = now
+	} else if e, ok := c.lifetime.localErr.(localTransportError); ok && e.code != errNo {
 		// If we've terminated the connection due to a peer protocol violation,
 		// record the final error on the connection as our reason for termination.
 		c.lifetime.finalErr = c.lifetime.localErr
@@ -239,14 +245,14 @@
 // The connection does not send a CONNECTION_CLOSE, and skips the draining period.
 func (c *Conn) abortImmediately(now time.Time, err error) {
 	c.abort(now, err)
-	c.enterDraining(err)
+	c.enterDraining(now, err)
 	c.exited = true
 }
 
 // exit fully terminates a connection immediately.
 func (c *Conn) exit() {
 	c.sendMsg(func(now time.Time, c *Conn) {
-		c.enterDraining(errors.New("connection closed"))
+		c.enterDraining(now, errors.New("connection closed"))
 		c.exited = true
 	})
 }
diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go
index 8fa3a39..896c6d7 100644
--- a/internal/quic/conn_recv.go
+++ b/internal/quic/conn_recv.go
@@ -56,7 +56,7 @@
 			if len(buf) == len(dgram.b) && len(buf) > statelessResetTokenLen {
 				var token statelessResetToken
 				copy(token[:], buf[len(buf)-len(token):])
-				c.handleStatelessReset(token)
+				c.handleStatelessReset(now, token)
 			}
 			// Invalid data at the end of a datagram is ignored.
 			break
@@ -525,7 +525,7 @@
 	if n < 0 {
 		return -1
 	}
-	c.enterDraining(peerTransportError{code: code, reason: reason})
+	c.enterDraining(now, peerTransportError{code: code, reason: reason})
 	return n
 }
 
@@ -534,7 +534,7 @@
 	if n < 0 {
 		return -1
 	}
-	c.enterDraining(&ApplicationError{Code: code, Reason: reason})
+	c.enterDraining(now, &ApplicationError{Code: code, Reason: reason})
 	return n
 }
 
@@ -556,9 +556,9 @@
 
 var errStatelessReset = errors.New("received stateless reset")
 
-func (c *Conn) handleStatelessReset(resetToken statelessResetToken) {
+func (c *Conn) handleStatelessReset(now time.Time, resetToken statelessResetToken) {
 	if !c.connIDState.isValidStatelessResetToken(resetToken) {
 		return
 	}
-	c.enterDraining(errStatelessReset)
+	c.enterDraining(now, errStatelessReset)
 }
diff --git a/internal/quic/listener.go b/internal/quic/listener.go
index 8b31dcb..ca8f9b2 100644
--- a/internal/quic/listener.go
+++ b/internal/quic/listener.go
@@ -253,12 +253,18 @@
 	if len(m.b) < minimumValidPacketSize {
 		return
 	}
+	var now time.Time
+	if l.testHooks != nil {
+		now = l.testHooks.timeNow()
+	} else {
+		now = time.Now()
+	}
 	// Check to see if this is a stateless reset.
 	var token statelessResetToken
 	copy(token[:], m.b[len(m.b)-len(token):])
 	if c := l.connsMap.byResetToken[token]; c != nil {
 		c.sendMsg(func(now time.Time, c *Conn) {
-			c.handleStatelessReset(token)
+			c.handleStatelessReset(now, token)
 		})
 		return
 	}
@@ -290,12 +296,6 @@
 		// https://www.rfc-editor.org/rfc/rfc9000#section-10.3-16
 		return
 	}
-	var now time.Time
-	if l.testHooks != nil {
-		now = l.testHooks.timeNow()
-	} else {
-		now = time.Now()
-	}
 	cids := newServerConnIDs{
 		srcConnID: p.srcConnID,
 		dstConnID: p.dstConnID,
diff --git a/internal/quic/stateless_reset_test.go b/internal/quic/stateless_reset_test.go
index b12e975..8a16597 100644
--- a/internal/quic/stateless_reset_test.go
+++ b/internal/quic/stateless_reset_test.go
@@ -14,6 +14,7 @@
 	"errors"
 	"net/netip"
 	"testing"
+	"time"
 )
 
 func TestStatelessResetClientSendsStatelessResetTokenTransportParameter(t *testing.T) {
@@ -154,7 +155,9 @@
 	if err := tc.conn.Wait(canceledContext()); !errors.Is(err, errStatelessReset) {
 		t.Errorf("conn.Wait() = %v, want errStatelessReset", err)
 	}
-	tc.wantIdle("closed connection is idle")
+	tc.wantIdle("closed connection is idle in draining")
+	tc.advance(1 * time.Second) // long enough to exit the draining state
+	tc.wantIdle("closed connection is idle after draining")
 }
 
 func TestStatelessResetSuccessfulTransportParameter(t *testing.T) {