http2: with Go 1.7 set Request.Context in ServeHTTP handlers
And act the same as HTTP/1.x in Go 1.7.
Updates golang/go#15134
Change-Id: Ib64dd82cc5f8dd60e1680525f664d5b72be11fc6
Reviewed-on: https://go-review.googlesource.com/23220
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Andrew Gerrand <adg@golang.org>
diff --git a/http2/go17.go b/http2/go17.go
index 2e2eabd..3d3c71e 100644
--- a/http2/go17.go
+++ b/http2/go17.go
@@ -8,11 +8,33 @@
import (
"context"
+ "net"
"net/http"
"net/http/httptrace"
"time"
)
+type contextContext interface {
+ context.Context
+}
+
+func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx contextContext, cancel func()) {
+ ctx, cancel = context.WithCancel(context.Background())
+ ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr())
+ if hs := opts.baseConfig(); hs != nil {
+ ctx = context.WithValue(ctx, http.ServerContextKey, hs)
+ }
+ return
+}
+
+func contextWithCancel(ctx contextContext) (_ contextContext, cancel func()) {
+ return context.WithCancel(ctx)
+}
+
+func requestWithContext(req *http.Request, ctx contextContext) *http.Request {
+ return req.WithContext(ctx)
+}
+
type clientTrace httptrace.ClientTrace
func reqContext(r *http.Request) context.Context { return r.Context() }
diff --git a/http2/not_go17.go b/http2/not_go17.go
index deffe68..077db39 100644
--- a/http2/not_go17.go
+++ b/http2/not_go17.go
@@ -6,7 +6,12 @@
package http2
-import "net/http"
+import (
+ "net"
+ "net/http"
+)
+
+type contextContext interface{}
type fakeContext struct{}
@@ -28,3 +33,17 @@
func traceFirstResponseByte(*clientTrace) {}
func traceWroteHeaders(*clientTrace) {}
func traceWroteRequest(*clientTrace, error) {}
+
+func nop() {}
+
+func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx contextContext, cancel func()) {
+ return nil, nop
+}
+
+func contextWithCancel(ctx contextContext) (_ contextContext, cancel func()) {
+ return ctx, nop
+}
+
+func requestWithContext(req *http.Request, ctx contextContext) *http.Request {
+ return req
+}
diff --git a/http2/server.go b/http2/server.go
index 3a46db6..a2b6c4b 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -250,10 +250,14 @@
//
// The opts parameter is optional. If nil, default values are used.
func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
+ baseCtx, cancel := serverConnBaseContext(c, opts)
+ defer cancel()
+
sc := &serverConn{
srv: s,
hs: opts.baseConfig(),
conn: c,
+ baseCtx: baseCtx,
remoteAddrStr: c.RemoteAddr().String(),
bw: newBufferedWriter(c),
handler: opts.handler(),
@@ -272,6 +276,7 @@
serveG: newGoroutineLock(),
pushEnabled: true,
}
+
sc.flow.add(initialWindowSize)
sc.inflow.add(initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
@@ -373,6 +378,7 @@
conn net.Conn
bw *bufferedWriter // writing to conn
handler http.Handler
+ baseCtx contextContext
framer *Framer
doneServing chan struct{} // closed when serverConn.serve ends
readFrameCh chan readFrameResult // written by serverConn.readFrames
@@ -436,10 +442,12 @@
// responseWriter's state field.
type stream struct {
// immutable:
- sc *serverConn
- id uint32
- body *pipe // non-nil if expecting DATA frames
- cw closeWaiter // closed wait stream transitions to closed state
+ sc *serverConn
+ id uint32
+ body *pipe // non-nil if expecting DATA frames
+ cw closeWaiter // closed wait stream transitions to closed state
+ ctx contextContext
+ cancelCtx func()
// owned by serverConn's serve loop:
bodyBytes int64 // body bytes seen so far
@@ -1157,6 +1165,7 @@
}
if st != nil {
st.gotReset = true
+ st.cancelCtx()
sc.closeStream(st, StreamError{f.StreamID, f.ErrCode})
}
return nil
@@ -1380,10 +1389,13 @@
}
sc.maxStreamID = id
+ ctx, cancelCtx := contextWithCancel(sc.baseCtx)
st = &stream{
- sc: sc,
- id: id,
- state: stateOpen,
+ sc: sc,
+ id: id,
+ state: stateOpen,
+ ctx: ctx,
+ cancelCtx: cancelCtx,
}
if f.StreamEnded() {
st.state = stateHalfClosedRemote
@@ -1617,6 +1629,7 @@
Body: body,
Trailer: trailer,
}
+ req = requestWithContext(req, st.ctx)
if bodyOpen {
// Disabled, per golang.org/issue/14960:
// st.reqBuf = sc.getRequestBodyBuf()
@@ -1661,6 +1674,7 @@
func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
didPanic := true
defer func() {
+ rw.rws.stream.cancelCtx()
if didPanic {
e := recover()
// Same as net/http: