net/http: also use Server.ReadHeaderTimeout for TLS handshake deadline
Fixes #48120
Change-Id: I72e89af8aaf3310e348d8ab639925ce0bf84204d
Reviewed-on: https://go-review.googlesource.com/c/go/+/355870
Trust: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/src/net/http/server.go b/src/net/http/server.go
index 55fd4ae..e9b0b4d 100644
--- a/src/net/http/server.go
+++ b/src/net/http/server.go
@@ -865,6 +865,28 @@
return int64(srv.maxHeaderBytes()) + 4096 // bufio slop
}
+// tlsHandshakeTimeout returns the time limit permitted for the TLS
+// handshake, or zero for unlimited.
+//
+// It returns the minimum of any positive ReadHeaderTimeout,
+// ReadTimeout, or WriteTimeout.
+func (srv *Server) tlsHandshakeTimeout() time.Duration {
+ var ret time.Duration
+ for _, v := range [...]time.Duration{
+ srv.ReadHeaderTimeout,
+ srv.ReadTimeout,
+ srv.WriteTimeout,
+ } {
+ if v <= 0 {
+ continue
+ }
+ if ret == 0 || v < ret {
+ ret = v
+ }
+ }
+ return ret
+}
+
// wrapper around io.ReadCloser which on first read, sends an
// HTTP/1.1 100 Continue header
type expectContinueReader struct {
@@ -1816,11 +1838,11 @@
}()
if tlsConn, ok := c.rwc.(*tls.Conn); ok {
- if d := c.server.ReadTimeout; d > 0 {
- c.rwc.SetReadDeadline(time.Now().Add(d))
- }
- if d := c.server.WriteTimeout; d > 0 {
- c.rwc.SetWriteDeadline(time.Now().Add(d))
+ tlsTO := c.server.tlsHandshakeTimeout()
+ if tlsTO > 0 {
+ dl := time.Now().Add(tlsTO)
+ c.rwc.SetReadDeadline(dl)
+ c.rwc.SetWriteDeadline(dl)
}
if err := tlsConn.HandshakeContext(ctx); err != nil {
// If the handshake failed due to the client not speaking
@@ -1834,6 +1856,11 @@
c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
return
}
+ // Restore Conn-level deadlines.
+ if tlsTO > 0 {
+ c.rwc.SetReadDeadline(time.Time{})
+ c.rwc.SetWriteDeadline(time.Time{})
+ }
c.tlsState = new(tls.ConnectionState)
*c.tlsState = tlsConn.ConnectionState()
if proto := c.tlsState.NegotiatedProtocol; validNextProto(proto) {
diff --git a/src/net/http/server_test.go b/src/net/http/server_test.go
index 0132f3b..d17c5c1 100644
--- a/src/net/http/server_test.go
+++ b/src/net/http/server_test.go
@@ -9,8 +9,61 @@
import (
"fmt"
"testing"
+ "time"
)
+func TestServerTLSHandshakeTimeout(t *testing.T) {
+ tests := []struct {
+ s *Server
+ want time.Duration
+ }{
+ {
+ s: &Server{},
+ want: 0,
+ },
+ {
+ s: &Server{
+ ReadTimeout: -1,
+ },
+ want: 0,
+ },
+ {
+ s: &Server{
+ ReadTimeout: 5 * time.Second,
+ },
+ want: 5 * time.Second,
+ },
+ {
+ s: &Server{
+ ReadTimeout: 5 * time.Second,
+ WriteTimeout: -1,
+ },
+ want: 5 * time.Second,
+ },
+ {
+ s: &Server{
+ ReadTimeout: 5 * time.Second,
+ WriteTimeout: 4 * time.Second,
+ },
+ want: 4 * time.Second,
+ },
+ {
+ s: &Server{
+ ReadTimeout: 5 * time.Second,
+ ReadHeaderTimeout: 2 * time.Second,
+ WriteTimeout: 4 * time.Second,
+ },
+ want: 2 * time.Second,
+ },
+ }
+ for i, tt := range tests {
+ got := tt.s.tlsHandshakeTimeout()
+ if got != tt.want {
+ t.Errorf("%d. got %v; want %v", i, got, tt.want)
+ }
+ }
+}
+
func BenchmarkServerMatch(b *testing.B) {
fn := func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "OK")