internal/http3: add basic request stream handling for the server

Add support for the server to handle basic request streams: reading the
headers and body of a request, and writing the headers and body of a
response.

More sophisticated behavior, such as handling 1xx responses and trailers,
will be added in follow-up changes.

For golang/go#70914

Change-Id: I6a00435d72a118efe2ab76e4c7c53b9dca1d63f6
Reviewed-on: https://go-review.googlesource.com/c/net/+/735820
Reviewed-by: Nicholas Husin <husin@google.com>
Auto-Submit: Nicholas Husin <husin@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/http3/conn_test.go b/internal/http3/conn_test.go
index 78f19c4..b5c823f 100644
--- a/internal/http3/conn_test.go
+++ b/internal/http3/conn_test.go
@@ -147,7 +147,7 @@
 		f(t, tc.testQUICConn)
 	})
 	synctestSubtest(t, "server", func(t *testing.T) {
-		ts := newTestServer(t)
+		ts := newTestServer(t, nil)
 		tc := ts.connect()
 		f(t, tc.testQUICConn)
 	})
diff --git a/internal/http3/server.go b/internal/http3/server.go
index bdb3479..0bd8e8c 100644
--- a/internal/http3/server.go
+++ b/internal/http3/server.go
@@ -7,8 +7,12 @@
 import (
 	"context"
 	"net/http"
+	"net/url"
+	"strconv"
+	"strings"
 	"sync"
 
+	"golang.org/x/net/http/httpguts"
 	"golang.org/x/net/quic"
 )
 
 
@@ -57,7 +59,7 @@
 		if err != nil {
 			return err
 		}
-		go newServerConn(qconn)
+		go newServerConn(qconn, s.Handler)
 	}
 }
 
@@ -67,11 +69,13 @@
 	genericConn // for handleUnidirectionalStream
 	enc         qpackEncoder
 	dec         qpackDecoder
+	handler     http.Handler
 }
 
-func newServerConn(qconn *quic.Conn) {
+func newServerConn(qconn *quic.Conn, handler http.Handler) {
 	sc := &serverConn{
-		qconn: qconn,
+		qconn:   qconn,
+		handler: handler,
 	}
 	sc.enc.init()
 
@@ -152,8 +156,97 @@
 	}
 }
 
+// parseRequest reads the HEADERS frame that opens a request stream and
+// converts it into an *http.Request whose Body reads from st.
+//
+// Pseudo-header fields (RFC 9114, Section 4.3.1) populate the matching
+// http.Request fields and are never copied into Request.Header.
+func parseRequest(st *stream) (*http.Request, error) {
+	ftype, err := st.readFrameHeader()
+	if err != nil {
+		return nil, err
+	}
+	if ftype != frameTypeHeaders {
+		// "Receipt of an invalid sequence of frames MUST be treated as a
+		// connection error of type H3_FRAME_UNEXPECTED."
+		// https://www.rfc-editor.org/rfc/rfc9114.html#section-4.1
+		// TODO: skip unknown/reserved frame types instead of rejecting them.
+		return nil, errH3FrameUnexpected
+	}
+	var method, scheme, authority, path string
+	header := make(http.Header)
+	// NOTE: a per-request decoder carries no QPACK dynamic-table state.
+	var dec qpackDecoder
+	if err := dec.decode(st, func(_ indexType, name, value string) error {
+		switch name {
+		case ":method":
+			method = value
+		case ":scheme":
+			scheme = value
+		case ":path":
+			path = value
+		case ":authority":
+			authority = value
+		default:
+			header.Add(name, value)
+		}
+		return nil
+	}); err != nil {
+		return nil, err
+	}
+	if err := st.endFrame(); err != nil {
+		return nil, err
+	}
+	if authority == "" {
+		// Without :authority, the Host header field names the target host.
+		authority = header.Get("Host")
+	}
+	// As with net/http, Host is promoted to Request.Host, not kept in Header.
+	delete(header, "Host")
+	u := &url.URL{}
+	if path != "" {
+		// :path is the request target and may include a query string.
+		if u, err = url.ParseRequestURI(path); err != nil {
+			return nil, errH3MessageError
+		}
+	}
+	u.Scheme = scheme
+	u.Host = authority
+	return &http.Request{
+		Method:     method,
+		URL:        u,
+		Proto:      "HTTP/3.0",
+		ProtoMajor: 3,
+		Header:     header,
+		Host:       authority,
+		RequestURI: path,
+		Body: &bodyReader{
+			st:     st,
+			remain: -1,
+		},
+	}, nil
+}
+
+// handleRequestStream serves one request stream: it parses the request,
+// runs the server's handler, and closes both the request body and the
+// response stream when the handler returns.
 func (sc *serverConn) handleRequestStream(st *stream) error {
-	// TODO
+	req, err := parseRequest(st)
+	if err != nil {
+		return err
+	}
+	defer req.Body.Close()
+
+	rw := sc.newResponseWriter(st)
+	defer rw.close()
+
+	handler := sc.handler
+	if handler == nil {
+		// As with net/http.Server, a nil Handler means http.DefaultServeMux.
+		handler = http.DefaultServeMux
+	}
+	// TODO: handle panic coming from the HTTP handler.
+	handler.ServeHTTP(rw, req)
 	return nil
 }
 
@@ -168,3 +220,110 @@
 		sc.qconn.Abort(err)
 	}
 }
+
+// responseWriter implements http.ResponseWriter for a server request stream.
+// The response head is sent as a single HEADERS frame on st; the response
+// body is written through bw.
+type responseWriter struct {
+	st      *stream
+	bw      *bodyWriter
+	mu      sync.Mutex // serializes WriteHeader, Write, and close
+	headers http.Header
+	// TODO: support 1xx status
+	wroteHeader bool // Non-1xx header has been (logically) written.
+}
+
+var _ http.ResponseWriter = (*responseWriter)(nil)
+
+func (sc *serverConn) newResponseWriter(st *stream) *responseWriter {
+	rw := &responseWriter{
+		st:      st,
+		headers: make(http.Header),
+		bw: &bodyWriter{
+			st:     st,
+			remain: -1,
+			flush:  false,
+			name:   "response",
+		},
+	}
+	return rw
+}
+
+// Header returns the header map that will be sent by WriteHeader.
+// Changes made after the header has been written have no effect.
+func (rw *responseWriter) Header() http.Header {
+	return rw.headers
+}
+
+// writeHeaderLocked sends the response HEADERS frame carrying statusCode and
+// the fields in rw.headers. It does nothing if the header was already sent.
+// Fields with invalid names or values, and connection-specific fields, are
+// dropped; field names are lowercased as HTTP/3 requires.
+//
+// Caller must hold rw.mu.
+func (rw *responseWriter) writeHeaderLocked(statusCode int) {
+	// TODO: support trailer header.
+	if rw.wroteHeader {
+		return
+	}
+	enc := &qpackEncoder{}
+	enc.init()
+	encHeaders := enc.encode(func(f func(itype indexType, name, value string)) {
+		f(mayIndex, ":status", strconv.Itoa(statusCode))
+		for name, values := range rw.headers {
+			if !httpguts.ValidHeaderFieldName(name) {
+				continue
+			}
+			// "Characters in field names MUST be converted to lowercase
+			// prior to their encoding."
+			// https://www.rfc-editor.org/rfc/rfc9114.html#section-4.2
+			name = strings.ToLower(name)
+			switch name {
+			case "connection", "keep-alive", "proxy-connection", "transfer-encoding", "upgrade":
+				// HTTP/3 endpoints MUST NOT generate connection-specific
+				// fields (RFC 9114, Section 4.2).
+				continue
+			}
+			for _, val := range values {
+				if !httpguts.ValidHeaderFieldValue(val) {
+					continue
+				}
+				// Issue #71374: Consider supporting never-indexed fields.
+				f(mayIndex, name, val)
+			}
+		}
+	})
+	rw.st.writeVarint(int64(frameTypeHeaders))
+	rw.st.writeVarint(int64(len(encHeaders)))
+	rw.st.Write(encHeaders)
+	rw.wroteHeader = true
+}
+
+// WriteHeader sends the response header with statusCode. Calls made after
+// the header has been written are ignored.
+func (rw *responseWriter) WriteHeader(statusCode int) {
+	rw.mu.Lock()
+	defer rw.mu.Unlock()
+	rw.writeHeaderLocked(statusCode)
+}
+
+// Write writes b as part of the response body, first sending an implicit
+// "200 OK" header if none has been written yet.
+func (rw *responseWriter) Write(b []byte) (int, error) {
+	rw.mu.Lock()
+	defer rw.mu.Unlock()
+	if !rw.wroteHeader {
+		rw.writeHeaderLocked(http.StatusOK)
+	}
+	return rw.bw.Write(b)
+}
+
+// close finishes the response after the handler returns. As in net/http, a
+// handler that never called WriteHeader or Write still produces an implicit
+// "200 OK" header; rw.st.stream is closed afterwards.
+func (rw *responseWriter) close() error {
+	rw.mu.Lock()
+	rw.writeHeaderLocked(http.StatusOK)
+	rw.mu.Unlock()
+	return rw.st.stream.Close()
+}
diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go
index 1e0cba9..3d29dba 100644
--- a/internal/http3/server_test.go
+++ b/internal/http3/server_test.go
@@ -7,6 +7,8 @@
 package http3
 
 import (
+	"io"
+	"net/http"
 	"net/netip"
 	"testing"
 	"testing/synctest"
@@ -20,7 +22,7 @@
 	// this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR."
 	// https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2.2-3
 	synctest.Test(t, func(t *testing.T) {
-		ts := newTestServer(t)
+		ts := newTestServer(t, nil)
 		tc := ts.connect()
 		tc.newStream(streamTypePush)
 		tc.wantClosed("invalid client-created push stream", errH3StreamCreationError)
@@ -29,7 +31,7 @@
 
 func TestServerCancelPushForUnsentPromise(t *testing.T) {
 	synctest.Test(t, func(t *testing.T) {
-		ts := newTestServer(t)
+		ts := newTestServer(t, nil)
 		tc := ts.connect()
 		tc.greet()
 
@@ -43,6 +45,114 @@
 	})
 }
 
+// TestServerHeader verifies that request header fields reach the handler and
+// that header fields set by the handler are sent back to the client.
+func TestServerHeader(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			header := w.Header()
+			for key, values := range r.Header {
+				for _, value := range values {
+					header.Add(key, value)
+				}
+			}
+			w.WriteHeader(204)
+		}))
+		tc := ts.connect()
+		tc.greet()
+
+		reqStream := tc.newStream(streamTypeRequest)
+		reqStream.writeHeaders(http.Header{
+			"header-from-client": {"that", "should", "be", "echoed"},
+		})
+		synctest.Wait()
+		reqStream.wantHeaders(map[string][]string{
+			":status":            {"204"},
+			"Header-From-Client": {"that", "should", "be", "echoed"},
+		})
+	})
+}
+
+// TestServerPseudoHeader verifies that request pseudo-header fields populate
+// http.Request fields, and that handlers cannot set response pseudo-headers.
+func TestServerPseudoHeader(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			// Pseudo-headers from client request should populate a specific
+			// field in http.Request, and should not be part of http.Request.Header.
+			if r.Header.Get(":method") != "" || r.Method != "GET" {
+				t.Error("want pseudo-headers from client to be reflected in appropriate fields in http.Request, not in http.Request.Header")
+			}
+			// Conversely, server should not be able to set pseudo-headers by
+			// writing to the ResponseWriter's Header.
+			header := w.Header()
+			header.Add(":status", "123")
+			w.WriteHeader(321)
+		}))
+		tc := ts.connect()
+		tc.greet()
+
+		reqStream := tc.newStream(streamTypeRequest)
+		reqStream.writeHeaders(http.Header{":method": {"GET"}})
+		synctest.Wait()
+		reqStream.wantHeaders(map[string][]string{":status": {"321"}})
+	})
+}
+
+// TestServerInvalidHeader verifies that response header fields with invalid
+// names or values are dropped instead of being sent.
+func TestServerInvalidHeader(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.Header().Add("valid-name", "valid value")
+			// Invalid headers are skipped.
+			w.Header().Add("invalid name with spaces", "some value")
+			w.Header().Add("some-name", "invalid value with \n")
+			w.Header().Add("valid-name-2", "valid value 2")
+			w.WriteHeader(200)
+		}))
+		tc := ts.connect()
+		tc.greet()
+
+		reqStream := tc.newStream(streamTypeRequest)
+		reqStream.writeHeaders(http.Header{})
+		synctest.Wait()
+		reqStream.wantHeaders(map[string][]string{
+			":status":      {"200"},
+			"Valid-Name":   {"valid value"},
+			"Valid-Name-2": {"valid value 2"},
+		})
+	})
+}
+
+// TestServerBody verifies that the handler can read the request body and
+// that bytes it writes arrive at the client as the response body.
+func TestServerBody(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			body, err := io.ReadAll(r.Body)
+			if err != nil {
+				// The handler runs on a server goroutine, not the test
+				// goroutine, so t.Fatal must not be called here.
+				t.Errorf("reading request body: %v", err)
+				return
+			}
+			w.Write(body) // Implicitly calls w.WriteHeader(200).
+		}))
+		tc := ts.connect()
+		tc.greet()
+
+		reqStream := tc.newStream(streamTypeRequest)
+		reqStream.writeHeaders(http.Header{})
+		bodyContent := []byte("some body content that should be echoed")
+		reqStream.writeData(bodyContent)
+		reqStream.stream.stream.CloseWrite()
+		synctest.Wait()
+		reqStream.wantHeaders(http.Header{":status": {"200"}})
+		reqStream.wantData(bodyContent)
+	})
+}
+
 type testServer struct {
 	t  testing.TB
 	s  *Server
@@ -57,9 +156,6 @@
 	e *quic.Endpoint
 }
 
-func (te *testQUICEndpoint) dial() {
-}
-
 type testServerConn struct {
 	ts *testServer
 
@@ -67,7 +163,7 @@
 	control *testQUICStream
 }
 
-func newTestServer(t testing.TB) *testServer {
+func newTestServer(t testing.TB, handler http.Handler) *testServer {
 	t.Helper()
 	ts := &testServer{
 		t: t,
@@ -75,6 +171,7 @@
 			Config: &quic.Config{
 				TLSConfig: testTLSConfig,
 			},
+			Handler: handler,
 		},
 	}
 	e := ts.tn.newQUICEndpoint(t, ts.s.Config)
diff --git a/internal/http3/transport_test.go b/internal/http3/transport_test.go
index 0b7134a..8d12ecf 100644
--- a/internal/http3/transport_test.go
+++ b/internal/http3/transport_test.go
@@ -229,6 +229,19 @@
 	}
 }
 
+// writeData writes b to the stream as a single DATA frame and flushes it.
+func (ts *testQUICStream) writeData(b []byte) {
+	ts.t.Helper()
+	ts.writeVarint(int64(frameTypeData))
+	ts.writeVarint(int64(len(b)))
+	if _, err := ts.Write(b); err != nil {
+		ts.t.Fatalf("writing DATA frame payload: %v", err)
+	}
+	if err := ts.Flush(); err != nil {
+		ts.t.Fatalf("flushing DATA frame: %v", err)
+	}
+}
+
 func (ts *testQUICStream) wantData(want []byte) {
 	ts.t.Helper()
 	synctest.Wait()