internal/http3: add basic request stream handling for the server
Add support for the server to handle basic request streams.
Specifically, the server now reads the headers and body of requests and
writes the headers and body of responses.
More sophisticated behaviors, such as handling 1xx status codes and
trailers, will be implemented in the future.
For golang/go#70914
Change-Id: I6a00435d72a118efe2ab76e4c7c53b9dca1d63f6
Reviewed-on: https://go-review.googlesource.com/c/net/+/735820
Reviewed-by: Nicholas Husin <husin@google.com>
Auto-Submit: Nicholas Husin <husin@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/http3/conn_test.go b/internal/http3/conn_test.go
index 78f19c4..b5c823f 100644
--- a/internal/http3/conn_test.go
+++ b/internal/http3/conn_test.go
@@ -147,7 +147,7 @@
f(t, tc.testQUICConn)
})
synctestSubtest(t, "server", func(t *testing.T) {
- ts := newTestServer(t)
+ ts := newTestServer(t, nil)
tc := ts.connect()
f(t, tc.testQUICConn)
})
diff --git a/internal/http3/server.go b/internal/http3/server.go
index bdb3479..0bd8e8c 100644
--- a/internal/http3/server.go
+++ b/internal/http3/server.go
@@ -7,8 +7,10 @@
import (
"context"
+	"errors"
"net/http"
+	"net/url"
+	"strconv"
+	"strings"
"sync"
+	"golang.org/x/net/http/httpguts"
"golang.org/x/net/quic"
)
@@ -57,7 +59,7 @@
if err != nil {
return err
}
- go newServerConn(qconn)
+ go newServerConn(qconn, s.Handler)
}
}
@@ -67,11 +69,13 @@
genericConn // for handleUnidirectionalStream
enc qpackEncoder
dec qpackDecoder
+ handler http.Handler
}
-func newServerConn(qconn *quic.Conn) {
+func newServerConn(qconn *quic.Conn, handler http.Handler) {
sc := &serverConn{
- qconn: qconn,
+ qconn: qconn,
+ handler: handler,
}
sc.enc.init()
@@ -152,8 +156,56 @@
}
}
+// parseRequest reads the HEADERS frame at the start of a request stream and
+// builds an *http.Request from it.
+//
+// The pseudo-header fields ":method", ":scheme", ":path" and ":authority"
+// populate the corresponding Request fields. Every other field is added to
+// req.Header. req.Body is a bodyReader over st for the rest of the stream.
+//
+// A request that omits pseudo-headers is still accepted; its URL is left
+// empty apart from whatever scheme/authority was supplied.
+func parseRequest(st *stream) (*http.Request, error) {
+	ftype, err := st.readFrameHeader()
+	if err != nil {
+		return nil, err
+	}
+	if ftype != frameTypeHeaders {
+		// Returning (nil, nil) here would make the caller dereference a nil
+		// request.
+		// NOTE(review): RFC 9114 Section 4.1 calls for a connection error of
+		// type H3_FRAME_UNEXPECTED here (and unknown frame types should be
+		// skipped rather than rejected); consider mapping to that code.
+		return nil, errors.New("http3: request stream does not begin with a HEADERS frame")
+	}
+	req := &http.Request{
+		Proto:      "HTTP/3.0",
+		ProtoMajor: 3,
+		Header:     make(http.Header),
+		// URL must be non-nil: the pseudo-header handling below writes
+		// into it, and handlers expect r.URL to be set.
+		URL: &url.URL{},
+	}
+	// Pseudo-header values are collected first and applied after decoding,
+	// because :path must be parsed into a fresh *url.URL regardless of
+	// whether :scheme/:authority arrive before or after it.
+	var scheme, path, authority string
+	// NOTE(review): a fresh decoder rather than the connection's decoder;
+	// presumably fine while only the static table is in use — confirm.
+	var dec qpackDecoder
+	if err := dec.decode(st, func(_ indexType, name, value string) error {
+		switch name {
+		case ":method":
+			req.Method = value
+		case ":scheme":
+			scheme = value
+		case ":path":
+			path = value
+		case ":authority":
+			authority = value
+		default:
+			req.Header.Add(name, value)
+		}
+		return nil
+	}); err != nil {
+		return nil, err
+	}
+	if err := st.endFrame(); err != nil {
+		return nil, err
+	}
+	if path != "" {
+		// :path carries both the path and the query component
+		// (RFC 9114 Section 4.3.1), so parse it instead of storing it
+		// verbatim in URL.Path.
+		u, err := url.ParseRequestURI(path)
+		if err != nil {
+			return nil, err
+		}
+		req.URL = u
+		req.RequestURI = path
+	}
+	req.URL.Scheme = scheme
+	req.URL.Host = authority
+	req.Host = authority
+	req.Body = &bodyReader{
+		st:     st,
+		remain: -1,
+	}
+	return req, nil
+}
+
func (sc *serverConn) handleRequestStream(st *stream) error {
- // TODO
+	req, err := parseRequest(st)
+	if err != nil {
+		return err
+	}
+	defer req.Body.Close()
+
+	// Named rw rather than responseWriter to avoid shadowing the type.
+	rw := sc.newResponseWriter(st)
+	defer rw.close()
+
+	// Like net/http.Server, fall back to http.DefaultServeMux when no
+	// handler is configured; calling a nil handler would panic.
+	handler := sc.handler
+	if handler == nil {
+		handler = http.DefaultServeMux
+	}
+
+	// TODO: handle panic coming from the HTTP handler.
+	handler.ServeHTTP(rw, req)
return nil
}
@@ -168,3 +220,78 @@
sc.qconn.Abort(err)
}
}
+
+// responseWriter is the http.ResponseWriter handed to the handler for one
+// request stream.
+type responseWriter struct {
+	st      *stream     // stream the response is written to
+	bw      *bodyWriter // receives the bytes passed to Write
+	headers http.Header // returned by Header; encoded when the header is written
+
+	// mu is held by WriteHeader and Write while they write to the stream.
+	mu sync.Mutex
+	// TODO: support 1xx status
+	wroteHeader bool // Non-1xx header has been (logically) written.
+}
+
+// newResponseWriter returns a responseWriter with an empty header map whose
+// header and body are both written to st.
+func (sc *serverConn) newResponseWriter(st *stream) *responseWriter {
+	body := &bodyWriter{
+		st:     st,
+		remain: -1,
+		flush:  false,
+		name:   "response",
+	}
+	return &responseWriter{
+		st:      st,
+		headers: make(http.Header),
+		bw:      body,
+	}
+}
+
+// Header returns the response header map. It is read when the header is
+// written, so changes made after WriteHeader or the first Write are not sent.
+func (rw *responseWriter) Header() http.Header {
+	h := rw.headers
+	return h
+}
+
+// writeHeaderLocked writes a HEADERS frame containing a :status field for
+// statusCode followed by the fields in rw.headers. It does nothing if the
+// header has already been written.
+//
+// Fields with an invalid name or value are dropped, as are connection-specific
+// fields, which HTTP/3 forbids (RFC 9114 Section 4.2). Field names are
+// lowercased, as HTTP/3 requires.
+//
+// Errors from writing to the stream are not reported here.
+//
+// Caller must hold rw.mu.
+func (rw *responseWriter) writeHeaderLocked(statusCode int) {
+	// TODO: support trailer header.
+	if rw.wroteHeader {
+		return
+	}
+	enc := &qpackEncoder{}
+	enc.init()
+	encHeaders := enc.encode(func(f func(itype indexType, name, value string)) {
+		f(mayIndex, ":status", strconv.Itoa(statusCode))
+		for name, values := range rw.headers {
+			// ':' is not a valid field-name character, so this also keeps a
+			// handler from injecting pseudo-headers such as ":status".
+			if !httpguts.ValidHeaderFieldName(name) {
+				continue
+			}
+			// Go canonicalizes keys ("Content-Type"), but HTTP/3 field names
+			// must be lowercase on the wire (RFC 9114 Section 4.2).
+			lowerName := strings.ToLower(name)
+			switch lowerName {
+			case "connection", "keep-alive", "proxy-connection", "transfer-encoding", "upgrade":
+				// Connection-specific fields make an HTTP/3 message
+				// malformed (RFC 9114 Section 4.2), so never send them.
+				continue
+			}
+			for _, val := range values {
+				if !httpguts.ValidHeaderFieldValue(val) {
+					continue
+				}
+				// Issue #71374: Consider supporting never-indexed fields.
+				f(mayIndex, lowerName, val)
+			}
+		}
+	})
+	rw.st.writeVarint(int64(frameTypeHeaders))
+	rw.st.writeVarint(int64(len(encHeaders)))
+	rw.st.Write(encHeaders)
+	rw.wroteHeader = true
+}
+
+// WriteHeader sends the response header with the given status code. Only the
+// first call that reaches the stream has any effect; later calls are no-ops.
+func (rw *responseWriter) WriteHeader(statusCode int) {
+	mu := &rw.mu
+	mu.Lock()
+	defer mu.Unlock()
+	rw.writeHeaderLocked(statusCode)
+}
+
+// Write appends b to the response body. If no header has been written yet,
+// a 200 OK header is written first, as with net/http.
+func (rw *responseWriter) Write(b []byte) (int, error) {
+	rw.mu.Lock()
+	defer rw.mu.Unlock()
+	// writeHeaderLocked returns immediately once the header has been
+	// written, so this only sends the implicit 200 on the first Write.
+	rw.writeHeaderLocked(http.StatusOK)
+	return rw.bw.Write(b)
+}
+
+// close finishes the response and closes the underlying stream.
+//
+// A response must begin with a HEADERS frame (RFC 9114 Section 4.1). If the
+// handler returned without calling WriteHeader or Write, send the implicit
+// 200 OK header first, as net/http does, instead of closing the stream with
+// no response at all.
+func (rw *responseWriter) close() error {
+	rw.mu.Lock()
+	rw.writeHeaderLocked(http.StatusOK) // no-op if a header was already sent
+	rw.mu.Unlock()
+	return rw.st.stream.Close()
+}
diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go
index 1e0cba9..3d29dba 100644
--- a/internal/http3/server_test.go
+++ b/internal/http3/server_test.go
@@ -7,6 +7,8 @@
package http3
import (
+ "io"
+ "net/http"
"net/netip"
"testing"
"testing/synctest"
@@ -20,7 +22,7 @@
// this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR."
// https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2.2-3
synctest.Test(t, func(t *testing.T) {
- ts := newTestServer(t)
+ ts := newTestServer(t, nil)
tc := ts.connect()
tc.newStream(streamTypePush)
tc.wantClosed("invalid client-created push stream", errH3StreamCreationError)
@@ -29,7 +31,7 @@
func TestServerCancelPushForUnsentPromise(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
- ts := newTestServer(t)
+ ts := newTestServer(t, nil)
tc := ts.connect()
tc.greet()
@@ -43,6 +45,103 @@
})
}
+// TestServerHeader checks that request fields reach the handler via r.Header
+// and that fields set on w.Header() are sent back in the response.
+func TestServerHeader(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		echo := func(w http.ResponseWriter, r *http.Request) {
+			out := w.Header()
+			for name, vals := range r.Header {
+				for _, v := range vals {
+					out.Add(name, v)
+				}
+			}
+			w.WriteHeader(http.StatusNoContent)
+		}
+		ts := newTestServer(t, http.HandlerFunc(echo))
+		tc := ts.connect()
+		tc.greet()
+
+		rs := tc.newStream(streamTypeRequest)
+		rs.writeHeaders(http.Header{
+			"header-from-client": {"that", "should", "be", "echoed"},
+		})
+		synctest.Wait()
+		rs.wantHeaders(map[string][]string{
+			":status":            {"204"},
+			"Header-From-Client": {"that", "should", "be", "echoed"},
+		})
+	})
+}
+
+// TestServerPseudoHeader checks that request pseudo-headers are mapped onto
+// http.Request fields rather than r.Header, and that a handler cannot add a
+// pseudo-header to the response through w.Header().
+func TestServerPseudoHeader(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		handler := func(w http.ResponseWriter, r *http.Request) {
+			// Pseudo-headers from client request should populate a specific
+			// field in http.Request, and should not be part of http.Request.Header.
+			leaked := r.Header.Get(":method") != ""
+			if leaked || r.Method != "GET" {
+				t.Error("want pseudo-headers from client to be reflected in appropriate fields in http.Request, not in http.Request.Header")
+			}
+			// Conversely, server should not be able to set pseudo-headers by
+			// writing to the ResponseWriter's Header.
+			w.Header().Add(":status", "123")
+			w.WriteHeader(321)
+		}
+		ts := newTestServer(t, http.HandlerFunc(handler))
+		tc := ts.connect()
+		tc.greet()
+
+		rs := tc.newStream(streamTypeRequest)
+		rs.writeHeaders(http.Header{":method": {"GET"}})
+		synctest.Wait()
+		rs.wantHeaders(map[string][]string{":status": {"321"}})
+	})
+}
+
+// TestServerInvalidHeader checks that response fields with an invalid name or
+// value are dropped while valid fields are still sent.
+func TestServerInvalidHeader(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		handler := func(w http.ResponseWriter, r *http.Request) {
+			h := w.Header()
+			// Invalid headers are skipped.
+			for _, kv := range [][2]string{
+				{"valid-name", "valid value"},
+				{"invalid name with spaces", "some value"},
+				{"some-name", "invalid value with \n"},
+				{"valid-name-2", "valid value 2"},
+			} {
+				h.Add(kv[0], kv[1])
+			}
+			w.WriteHeader(http.StatusOK)
+		}
+		ts := newTestServer(t, http.HandlerFunc(handler))
+		tc := ts.connect()
+		tc.greet()
+
+		rs := tc.newStream(streamTypeRequest)
+		rs.writeHeaders(http.Header{})
+		synctest.Wait()
+		rs.wantHeaders(map[string][]string{
+			":status":      {"200"},
+			"Valid-Name":   {"valid value"},
+			"Valid-Name-2": {"valid value 2"},
+		})
+	})
+}
+
+// TestServerBody checks that the handler can read the request body and that
+// bytes passed to w.Write are returned as the response body after an
+// implicit 200 header.
+func TestServerBody(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			body, err := io.ReadAll(r.Body)
+			if err != nil {
+				// The handler runs on a server goroutine, not the test
+				// goroutine, so t.Fatal (FailNow) must not be called here.
+				t.Errorf("reading request body: %v", err)
+				return
+			}
+			w.Write(body) // Implicitly calls w.WriteHeader(200).
+		}))
+		tc := ts.connect()
+		tc.greet()
+
+		reqStream := tc.newStream(streamTypeRequest)
+		reqStream.writeHeaders(http.Header{})
+		bodyContent := []byte("some body content that should be echoed")
+		reqStream.writeData(bodyContent)
+		reqStream.stream.stream.CloseWrite()
+		synctest.Wait()
+		reqStream.wantHeaders(http.Header{":status": {"200"}})
+		reqStream.wantData(bodyContent)
+	})
+}
+
type testServer struct {
t testing.TB
s *Server
@@ -57,9 +156,6 @@
e *quic.Endpoint
}
-func (te *testQUICEndpoint) dial() {
-}
-
type testServerConn struct {
ts *testServer
@@ -67,7 +163,7 @@
control *testQUICStream
}
-func newTestServer(t testing.TB) *testServer {
+func newTestServer(t testing.TB, handler http.Handler) *testServer {
t.Helper()
ts := &testServer{
t: t,
@@ -75,6 +171,7 @@
Config: &quic.Config{
TLSConfig: testTLSConfig,
},
+ Handler: handler,
},
}
e := ts.tn.newQUICEndpoint(t, ts.s.Config)
diff --git a/internal/http3/transport_test.go b/internal/http3/transport_test.go
index 0b7134a..8d12ecf 100644
--- a/internal/http3/transport_test.go
+++ b/internal/http3/transport_test.go
@@ -229,6 +229,16 @@
}
}
+// writeData writes b to the stream as a single DATA frame and flushes it,
+// failing the test if the flush fails.
+func (ts *testQUICStream) writeData(b []byte) {
+	ts.t.Helper()
+	// Frame layout: type varint, length varint, then the payload.
+	for _, v := range []int64{int64(frameTypeData), int64(len(b))} {
+		ts.writeVarint(v)
+	}
+	ts.Write(b)
+	err := ts.Flush()
+	if err != nil {
+		ts.t.Fatalf("flushing DATA frame: %v", err)
+	}
+}
+
func (ts *testQUICStream) wantData(want []byte) {
ts.t.Helper()
synctest.Wait()