blob: 2c4532571a48b571a7aeccb4ade4f4ad0bfdefcb [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"bytes"
"io"
"net/http"
"os"
"reflect"
"slices"
"testing"
"golang.org/x/net/http2/hpack"
)
type testConnFramer struct {
t testing.TB
fr *Framer
dec *hpack.Decoder
}
// readFrame reads the next frame.
// It returns nil if the conn is closed or no frames are available.
func (tf *testConnFramer) readFrame() Frame {
tf.t.Helper()
fr, err := tf.fr.ReadFrame()
if err == io.EOF || err == os.ErrDeadlineExceeded {
return nil
}
if err != nil {
tf.t.Fatalf("ReadFrame: %v", err)
}
return fr
}
type readFramer interface {
readFrame() Frame
}
// readFrame reads a frame of a specific type.
func readFrame[T any](t testing.TB, framer readFramer) T {
t.Helper()
var v T
fr := framer.readFrame()
if fr == nil {
t.Fatalf("got no frame, want frame %T", v)
}
v, ok := fr.(T)
if !ok {
t.Fatalf("got frame %T, want %T", fr, v)
}
return v
}
// wantFrameType reads the next frame.
// It produces an error if the frame type is not the expected value.
func (tf *testConnFramer) wantFrameType(want FrameType) {
tf.t.Helper()
fr := tf.readFrame()
if fr == nil {
tf.t.Fatalf("got no frame, want frame %v", want)
}
if got := fr.Header().Type; got != want {
tf.t.Fatalf("got frame %v, want %v", got, want)
}
}
// wantUnorderedFrames reads frames until every condition in want has been satisfied.
//
// want is a list of func(*SomeFrame) bool.
// wantUnorderedFrames will call each func with frames of the appropriate type
// until the func returns true.
// It calls t.Fatal if an unexpected frame is received (no func has that frame type,
// or all funcs with that type have returned true), or if the framer runs out of frames
// with unsatisfied funcs.
//
// Example:
//
// // Read a SETTINGS frame, and any number of DATA frames for a stream.
// // The SETTINGS frame may appear anywhere in the sequence.
// // The last DATA frame must indicate the end of the stream.
// tf.wantUnorderedFrames(
// func(f *SettingsFrame) bool {
// return true
// },
// func(f *DataFrame) bool {
// return f.StreamEnded()
// },
// )
func (tf *testConnFramer) wantUnorderedFrames(want ...any) {
tf.t.Helper()
want = slices.Clone(want)
seen := 0
frame:
for seen < len(want) && !tf.t.Failed() {
fr := tf.readFrame()
if fr == nil {
break
}
for i, f := range want {
if f == nil {
continue
}
typ := reflect.TypeOf(f)
if typ.Kind() != reflect.Func ||
typ.NumIn() != 1 ||
typ.NumOut() != 1 ||
typ.Out(0) != reflect.TypeOf(true) {
tf.t.Fatalf("expected func(*SomeFrame) bool, got %T", f)
}
if typ.In(0) == reflect.TypeOf(fr) {
out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)})
if out[0].Bool() {
want[i] = nil
seen++
}
continue frame
}
}
tf.t.Errorf("got unexpected frame type %T", fr)
}
if seen < len(want) {
for _, f := range want {
if f == nil {
continue
}
tf.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0))
}
tf.t.Fatalf("did not see %v expected frame types", len(want)-seen)
}
}
type wantHeader struct {
streamID uint32
endStream bool
header http.Header
}
// wantHeaders reads a HEADERS frame and potential CONTINUATION frames,
// and asserts that they contain the expected headers.
func (tf *testConnFramer) wantHeaders(want wantHeader) {
tf.t.Helper()
hf := readFrame[*HeadersFrame](tf.t, tf)
if got, want := hf.StreamID, want.streamID; got != want {
tf.t.Fatalf("got stream ID %v, want %v", got, want)
}
if got, want := hf.StreamEnded(), want.endStream; got != want {
tf.t.Fatalf("got stream ended %v, want %v", got, want)
}
gotHeader := make(http.Header)
tf.dec.SetEmitFunc(func(hf hpack.HeaderField) {
gotHeader[hf.Name] = append(gotHeader[hf.Name], hf.Value)
})
defer tf.dec.SetEmitFunc(nil)
if _, err := tf.dec.Write(hf.HeaderBlockFragment()); err != nil {
tf.t.Fatalf("decoding HEADERS frame: %v", err)
}
headersEnded := hf.HeadersEnded()
for !headersEnded {
cf := readFrame[*ContinuationFrame](tf.t, tf)
if cf == nil {
tf.t.Fatalf("got end of frames, want CONTINUATION")
}
if _, err := tf.dec.Write(cf.HeaderBlockFragment()); err != nil {
tf.t.Fatalf("decoding CONTINUATION frame: %v", err)
}
headersEnded = cf.HeadersEnded()
}
if err := tf.dec.Close(); err != nil {
tf.t.Fatalf("hpack decoding error: %v", err)
}
for k, v := range want.header {
if !reflect.DeepEqual(v, gotHeader[k]) {
tf.t.Fatalf("got header %q = %q; want %q", k, v, gotHeader[k])
}
}
}
// decodeHeader supports some older server tests.
// TODO: rewrite those tests to use newer, more convenient test APIs.
func (tf *testConnFramer) decodeHeader(headerBlock []byte) (pairs [][2]string) {
tf.dec.SetEmitFunc(func(hf hpack.HeaderField) {
if hf.Name == "date" {
return
}
pairs = append(pairs, [2]string{hf.Name, hf.Value})
})
defer tf.dec.SetEmitFunc(nil)
if _, err := tf.dec.Write(headerBlock); err != nil {
tf.t.Fatalf("hpack decoding error: %v", err)
}
if err := tf.dec.Close(); err != nil {
tf.t.Fatalf("hpack decoding error: %v", err)
}
return pairs
}
type wantData struct {
streamID uint32
endStream bool
size int
data []byte
multiple bool // data may be spread across multiple DATA frames
}
// wantData reads zero or more DATA frames, and asserts that they match the expectation.
func (tf *testConnFramer) wantData(want wantData) {
tf.t.Helper()
gotSize := 0
gotEndStream := false
if want.data != nil {
want.size = len(want.data)
}
var gotData []byte
for {
fr := tf.readFrame()
if fr == nil {
break
}
data, ok := fr.(*DataFrame)
if !ok {
tf.t.Fatalf("got frame %T, want DataFrame", fr)
}
if want.data != nil {
gotData = append(gotData, data.Data()...)
}
gotSize += len(data.Data())
if data.StreamEnded() {
gotEndStream = true
break
}
if !want.endStream && gotSize >= want.size {
break
}
if !want.multiple {
break
}
}
if gotSize != want.size {
tf.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size)
}
if gotEndStream != want.endStream {
tf.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream)
}
if want.data != nil && !bytes.Equal(gotData, want.data) {
tf.t.Fatalf("got data %q, want %q", gotData, want.data)
}
}
func (tf *testConnFramer) wantRSTStream(streamID uint32, code ErrCode) {
tf.t.Helper()
fr := readFrame[*RSTStreamFrame](tf.t, tf)
if fr.StreamID != streamID || fr.ErrCode != code {
tf.t.Fatalf("got %v, want RST_STREAM StreamID=%v, code=%v", summarizeFrame(fr), streamID, code)
}
}
func (tf *testConnFramer) wantSettings(want map[SettingID]uint32) {
fr := readFrame[*SettingsFrame](tf.t, tf)
if fr.Header().Flags.Has(FlagSettingsAck) {
tf.t.Errorf("got SETTINGS frame with ACK set, want no ACK")
}
for wantID, wantVal := range want {
gotVal, ok := fr.Value(wantID)
if !ok {
tf.t.Errorf("SETTINGS: %v is not set, want %v", wantID, wantVal)
} else if gotVal != wantVal {
tf.t.Errorf("SETTINGS: %v is %v, want %v", wantID, gotVal, wantVal)
}
}
if tf.t.Failed() {
tf.t.Fatalf("%v", fr)
}
}
func (tf *testConnFramer) wantSettingsAck() {
tf.t.Helper()
fr := readFrame[*SettingsFrame](tf.t, tf)
if !fr.Header().Flags.Has(FlagSettingsAck) {
tf.t.Fatal("Settings Frame didn't have ACK set")
}
}
func (tf *testConnFramer) wantGoAway(maxStreamID uint32, code ErrCode) {
tf.t.Helper()
fr := readFrame[*GoAwayFrame](tf.t, tf)
if fr.LastStreamID != maxStreamID || fr.ErrCode != code {
tf.t.Fatalf("got %v, want GOAWAY LastStreamID=%v, code=%v", summarizeFrame(fr), maxStreamID, code)
}
}
func (tf *testConnFramer) wantWindowUpdate(streamID, incr uint32) {
tf.t.Helper()
wu := readFrame[*WindowUpdateFrame](tf.t, tf)
if wu.FrameHeader.StreamID != streamID {
tf.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID)
}
if wu.Increment != incr {
tf.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr)
}
}
func (tf *testConnFramer) wantClosed() {
tf.t.Helper()
fr, err := tf.fr.ReadFrame()
if err == nil {
tf.t.Fatalf("got unexpected frame (want closed connection): %v", fr)
}
if err == os.ErrDeadlineExceeded {
tf.t.Fatalf("connection is not closed; want it to be")
}
}
func (tf *testConnFramer) wantIdle() {
tf.t.Helper()
fr, err := tf.fr.ReadFrame()
if err == nil {
tf.t.Fatalf("got unexpected frame (want idle connection): %v", fr)
}
if err != os.ErrDeadlineExceeded {
tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err)
}
}
func (tf *testConnFramer) writeSettings(settings ...Setting) {
tf.t.Helper()
if err := tf.fr.WriteSettings(settings...); err != nil {
tf.t.Fatal(err)
}
}
func (tf *testConnFramer) writeSettingsAck() {
tf.t.Helper()
if err := tf.fr.WriteSettingsAck(); err != nil {
tf.t.Fatal(err)
}
}
func (tf *testConnFramer) writeData(streamID uint32, endStream bool, data []byte) {
tf.t.Helper()
if err := tf.fr.WriteData(streamID, endStream, data); err != nil {
tf.t.Fatal(err)
}
}
func (tf *testConnFramer) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
tf.t.Helper()
if err := tf.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
tf.t.Fatal(err)
}
}
func (tf *testConnFramer) writeHeaders(p HeadersFrameParam) {
tf.t.Helper()
if err := tf.fr.WriteHeaders(p); err != nil {
tf.t.Fatal(err)
}
}
// writeHeadersMode writes header frames, as modified by mode:
//
// - noHeader: Don't write the header.
// - oneHeader: Write a single HEADERS frame.
// - splitHeader: Write a HEADERS frame and CONTINUATION frame.
func (tf *testConnFramer) writeHeadersMode(mode headerType, p HeadersFrameParam) {
tf.t.Helper()
switch mode {
case noHeader:
case oneHeader:
tf.writeHeaders(p)
case splitHeader:
if len(p.BlockFragment) < 2 {
panic("too small")
}
contData := p.BlockFragment[1:]
contEnd := p.EndHeaders
p.BlockFragment = p.BlockFragment[:1]
p.EndHeaders = false
tf.writeHeaders(p)
tf.writeContinuation(p.StreamID, contEnd, contData)
default:
panic("bogus mode")
}
}
func (tf *testConnFramer) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) {
tf.t.Helper()
if err := tf.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil {
tf.t.Fatal(err)
}
}
func (tf *testConnFramer) writePriority(id uint32, p PriorityParam) {
if err := tf.fr.WritePriority(id, p); err != nil {
tf.t.Fatal(err)
}
}
func (tf *testConnFramer) writeRSTStream(streamID uint32, code ErrCode) {
tf.t.Helper()
if err := tf.fr.WriteRSTStream(streamID, code); err != nil {
tf.t.Fatal(err)
}
}
func (tf *testConnFramer) writePing(ack bool, data [8]byte) {
tf.t.Helper()
if err := tf.fr.WritePing(ack, data); err != nil {
tf.t.Fatal(err)
}
}
func (tf *testConnFramer) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) {
tf.t.Helper()
if err := tf.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {
tf.t.Fatal(err)
}
}
func (tf *testConnFramer) writeWindowUpdate(streamID, incr uint32) {
tf.t.Helper()
if err := tf.fr.WriteWindowUpdate(streamID, incr); err != nil {
tf.t.Fatal(err)
}
}