net/http: also use Server.ReadHeaderTimeout for TLS handshake deadline

Fixes #48120

Change-Id: I72e89af8aaf3310e348d8ab639925ce0bf84204d
Reviewed-on: https://go-review.googlesource.com/c/go/+/355870
Trust: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/src/net/http/server.go b/src/net/http/server.go
index 55fd4ae..e9b0b4d 100644
--- a/src/net/http/server.go
+++ b/src/net/http/server.go
@@ -865,6 +865,28 @@
 	return int64(srv.maxHeaderBytes()) + 4096 // bufio slop
 }
 
+// tlsHandshakeTimeout returns the time limit permitted for the TLS
+// handshake, or zero for unlimited.
+//
+// It returns the minimum of any positive ReadHeaderTimeout,
+// ReadTimeout, or WriteTimeout.
+func (srv *Server) tlsHandshakeTimeout() time.Duration {
+	var ret time.Duration
+	for _, v := range [...]time.Duration{
+		srv.ReadHeaderTimeout,
+		srv.ReadTimeout,
+		srv.WriteTimeout,
+	} {
+		if v <= 0 {
+			continue
+		}
+		if ret == 0 || v < ret {
+			ret = v
+		}
+	}
+	return ret
+}
+
 // wrapper around io.ReadCloser which on first read, sends an
 // HTTP/1.1 100 Continue header
 type expectContinueReader struct {
@@ -1816,11 +1838,11 @@
 	}()
 
 	if tlsConn, ok := c.rwc.(*tls.Conn); ok {
-		if d := c.server.ReadTimeout; d > 0 {
-			c.rwc.SetReadDeadline(time.Now().Add(d))
-		}
-		if d := c.server.WriteTimeout; d > 0 {
-			c.rwc.SetWriteDeadline(time.Now().Add(d))
+		tlsTO := c.server.tlsHandshakeTimeout()
+		if tlsTO > 0 {
+			dl := time.Now().Add(tlsTO)
+			c.rwc.SetReadDeadline(dl)
+			c.rwc.SetWriteDeadline(dl)
 		}
 		if err := tlsConn.HandshakeContext(ctx); err != nil {
 			// If the handshake failed due to the client not speaking
@@ -1834,6 +1856,11 @@
 			c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
 			return
 		}
+		// Restore Conn-level deadlines.
+		if tlsTO > 0 {
+			c.rwc.SetReadDeadline(time.Time{})
+			c.rwc.SetWriteDeadline(time.Time{})
+		}
 		c.tlsState = new(tls.ConnectionState)
 		*c.tlsState = tlsConn.ConnectionState()
 		if proto := c.tlsState.NegotiatedProtocol; validNextProto(proto) {
diff --git a/src/net/http/server_test.go b/src/net/http/server_test.go
index 0132f3b..d17c5c1 100644
--- a/src/net/http/server_test.go
+++ b/src/net/http/server_test.go
@@ -9,8 +9,61 @@
 import (
 	"fmt"
 	"testing"
+	"time"
 )
 
+func TestServerTLSHandshakeTimeout(t *testing.T) {
+	tests := []struct {
+		s    *Server
+		want time.Duration
+	}{
+		{
+			s:    &Server{},
+			want: 0,
+		},
+		{
+			s: &Server{
+				ReadTimeout: -1,
+			},
+			want: 0,
+		},
+		{
+			s: &Server{
+				ReadTimeout: 5 * time.Second,
+			},
+			want: 5 * time.Second,
+		},
+		{
+			s: &Server{
+				ReadTimeout:  5 * time.Second,
+				WriteTimeout: -1,
+			},
+			want: 5 * time.Second,
+		},
+		{
+			s: &Server{
+				ReadTimeout:  5 * time.Second,
+				WriteTimeout: 4 * time.Second,
+			},
+			want: 4 * time.Second,
+		},
+		{
+			s: &Server{
+				ReadTimeout:       5 * time.Second,
+				ReadHeaderTimeout: 2 * time.Second,
+				WriteTimeout:      4 * time.Second,
+			},
+			want: 2 * time.Second,
+		},
+	}
+	for i, tt := range tests {
+		got := tt.s.tlsHandshakeTimeout()
+		if got != tt.want {
+			t.Errorf("%d. got %v; want %v", i, got, tt.want)
+		}
+	}
+}
+
 func BenchmarkServerMatch(b *testing.B) {
 	fn := func(w ResponseWriter, r *Request) {
 		fmt.Fprintf(w, "OK")