| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package http2 |
| |
| import ( |
| "bytes" |
| "io" |
| "net/http" |
| "os" |
| "reflect" |
| "slices" |
| "testing" |
| |
| "golang.org/x/net/http2/hpack" |
| ) |
| |
// testConnFramer bundles a Framer with the testing.TB used to report
// failures and the hpack.Decoder used to decode header blocks read from it.
// Its methods are test helpers that read or write frames through fr and
// report errors through t.
type testConnFramer struct {
	t   testing.TB     // receives the Helper, Errorf, Fatal, and Fatalf calls made by the helper methods
	fr  *Framer        // frames are read from and written to this framer
	dec *hpack.Decoder // decodes header block fragments in wantHeaders and decodeHeader
}
| |
| // readFrame reads the next frame. |
| // It returns nil if the conn is closed or no frames are available. |
| func (tf *testConnFramer) readFrame() Frame { |
| tf.t.Helper() |
| fr, err := tf.fr.ReadFrame() |
| if err == io.EOF || err == os.ErrDeadlineExceeded { |
| return nil |
| } |
| if err != nil { |
| tf.t.Fatalf("ReadFrame: %v", err) |
| } |
| return fr |
| } |
| |
// readFramer is the source of frames consumed by the generic readFrame
// helper. readFrame treats a nil Frame from this method as "no frame
// available". *testConnFramer provides this method.
type readFramer interface {
	readFrame() Frame
}
| |
| // readFrame reads a frame of a specific type. |
| func readFrame[T any](t testing.TB, framer readFramer) T { |
| t.Helper() |
| var v T |
| fr := framer.readFrame() |
| if fr == nil { |
| t.Fatalf("got no frame, want frame %T", v) |
| } |
| v, ok := fr.(T) |
| if !ok { |
| t.Fatalf("got frame %T, want %T", fr, v) |
| } |
| return v |
| } |
| |
| // wantFrameType reads the next frame. |
| // It produces an error if the frame type is not the expected value. |
| func (tf *testConnFramer) wantFrameType(want FrameType) { |
| tf.t.Helper() |
| fr := tf.readFrame() |
| if fr == nil { |
| tf.t.Fatalf("got no frame, want frame %v", want) |
| } |
| if got := fr.Header().Type; got != want { |
| tf.t.Fatalf("got frame %v, want %v", got, want) |
| } |
| } |
| |
| // wantUnorderedFrames reads frames until every condition in want has been satisfied. |
| // |
| // want is a list of func(*SomeFrame) bool. |
| // wantUnorderedFrames will call each func with frames of the appropriate type |
| // until the func returns true. |
| // It calls t.Fatal if an unexpected frame is received (no func has that frame type, |
| // or all funcs with that type have returned true), or if the framer runs out of frames |
| // with unsatisfied funcs. |
| // |
| // Example: |
| // |
| // // Read a SETTINGS frame, and any number of DATA frames for a stream. |
| // // The SETTINGS frame may appear anywhere in the sequence. |
| // // The last DATA frame must indicate the end of the stream. |
| // tf.wantUnorderedFrames( |
| // func(f *SettingsFrame) bool { |
| // return true |
| // }, |
| // func(f *DataFrame) bool { |
| // return f.StreamEnded() |
| // }, |
| // ) |
| func (tf *testConnFramer) wantUnorderedFrames(want ...any) { |
| tf.t.Helper() |
| want = slices.Clone(want) |
| seen := 0 |
| frame: |
| for seen < len(want) && !tf.t.Failed() { |
| fr := tf.readFrame() |
| if fr == nil { |
| break |
| } |
| for i, f := range want { |
| if f == nil { |
| continue |
| } |
| typ := reflect.TypeOf(f) |
| if typ.Kind() != reflect.Func || |
| typ.NumIn() != 1 || |
| typ.NumOut() != 1 || |
| typ.Out(0) != reflect.TypeOf(true) { |
| tf.t.Fatalf("expected func(*SomeFrame) bool, got %T", f) |
| } |
| if typ.In(0) == reflect.TypeOf(fr) { |
| out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)}) |
| if out[0].Bool() { |
| want[i] = nil |
| seen++ |
| } |
| continue frame |
| } |
| } |
| tf.t.Errorf("got unexpected frame type %T", fr) |
| } |
| if seen < len(want) { |
| for _, f := range want { |
| if f == nil { |
| continue |
| } |
| tf.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0)) |
| } |
| tf.t.Fatalf("did not see %v expected frame types", len(want)-seen) |
| } |
| } |
| |
// wantHeader describes the HEADERS frame expected by wantHeaders.
type wantHeader struct {
	streamID  uint32      // expected StreamID of the HEADERS frame
	endStream bool        // expected result of the HEADERS frame's StreamEnded()
	header    http.Header // each key listed here must decode to exactly these values; decoded headers not listed are not checked
}
| |
| // wantHeaders reads a HEADERS frame and potential CONTINUATION frames, |
| // and asserts that they contain the expected headers. |
| func (tf *testConnFramer) wantHeaders(want wantHeader) { |
| tf.t.Helper() |
| |
| hf := readFrame[*HeadersFrame](tf.t, tf) |
| if got, want := hf.StreamID, want.streamID; got != want { |
| tf.t.Fatalf("got stream ID %v, want %v", got, want) |
| } |
| if got, want := hf.StreamEnded(), want.endStream; got != want { |
| tf.t.Fatalf("got stream ended %v, want %v", got, want) |
| } |
| |
| gotHeader := make(http.Header) |
| tf.dec.SetEmitFunc(func(hf hpack.HeaderField) { |
| gotHeader[hf.Name] = append(gotHeader[hf.Name], hf.Value) |
| }) |
| defer tf.dec.SetEmitFunc(nil) |
| if _, err := tf.dec.Write(hf.HeaderBlockFragment()); err != nil { |
| tf.t.Fatalf("decoding HEADERS frame: %v", err) |
| } |
| headersEnded := hf.HeadersEnded() |
| for !headersEnded { |
| cf := readFrame[*ContinuationFrame](tf.t, tf) |
| if cf == nil { |
| tf.t.Fatalf("got end of frames, want CONTINUATION") |
| } |
| if _, err := tf.dec.Write(cf.HeaderBlockFragment()); err != nil { |
| tf.t.Fatalf("decoding CONTINUATION frame: %v", err) |
| } |
| headersEnded = cf.HeadersEnded() |
| } |
| if err := tf.dec.Close(); err != nil { |
| tf.t.Fatalf("hpack decoding error: %v", err) |
| } |
| |
| for k, v := range want.header { |
| if !reflect.DeepEqual(v, gotHeader[k]) { |
| tf.t.Fatalf("got header %q = %q; want %q", k, v, gotHeader[k]) |
| } |
| } |
| } |
| |
| // decodeHeader supports some older server tests. |
| // TODO: rewrite those tests to use newer, more convenient test APIs. |
| func (tf *testConnFramer) decodeHeader(headerBlock []byte) (pairs [][2]string) { |
| tf.dec.SetEmitFunc(func(hf hpack.HeaderField) { |
| if hf.Name == "date" { |
| return |
| } |
| pairs = append(pairs, [2]string{hf.Name, hf.Value}) |
| }) |
| defer tf.dec.SetEmitFunc(nil) |
| if _, err := tf.dec.Write(headerBlock); err != nil { |
| tf.t.Fatalf("hpack decoding error: %v", err) |
| } |
| if err := tf.dec.Close(); err != nil { |
| tf.t.Fatalf("hpack decoding error: %v", err) |
| } |
| return pairs |
| } |
| |
// wantData describes the DATA frames expected by the wantData method.
type wantData struct {
	streamID  uint32 // NOTE(review): not consulted by the wantData method in this file — confirm whether a check is intended
	endStream bool   // whether the DATA frames read are expected to end with one that has StreamEnded() set
	size      int    // expected total payload bytes across the frames read; overwritten with len(data) when data is non-nil
	data      []byte // if non-nil, the concatenated payload must equal this exactly
	multiple  bool   // data may be spread across multiple DATA frames
}
| |
| // wantData reads zero or more DATA frames, and asserts that they match the expectation. |
| func (tf *testConnFramer) wantData(want wantData) { |
| tf.t.Helper() |
| gotSize := 0 |
| gotEndStream := false |
| if want.data != nil { |
| want.size = len(want.data) |
| } |
| var gotData []byte |
| for { |
| fr := tf.readFrame() |
| if fr == nil { |
| break |
| } |
| data, ok := fr.(*DataFrame) |
| if !ok { |
| tf.t.Fatalf("got frame %T, want DataFrame", fr) |
| } |
| if want.data != nil { |
| gotData = append(gotData, data.Data()...) |
| } |
| gotSize += len(data.Data()) |
| if data.StreamEnded() { |
| gotEndStream = true |
| break |
| } |
| if !want.endStream && gotSize >= want.size { |
| break |
| } |
| if !want.multiple { |
| break |
| } |
| } |
| if gotSize != want.size { |
| tf.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size) |
| } |
| if gotEndStream != want.endStream { |
| tf.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream) |
| } |
| if want.data != nil && !bytes.Equal(gotData, want.data) { |
| tf.t.Fatalf("got data %q, want %q", gotData, want.data) |
| } |
| } |
| |
| func (tf *testConnFramer) wantRSTStream(streamID uint32, code ErrCode) { |
| tf.t.Helper() |
| fr := readFrame[*RSTStreamFrame](tf.t, tf) |
| if fr.StreamID != streamID || fr.ErrCode != code { |
| tf.t.Fatalf("got %v, want RST_STREAM StreamID=%v, code=%v", summarizeFrame(fr), streamID, code) |
| } |
| } |
| |
| func (tf *testConnFramer) wantSettings(want map[SettingID]uint32) { |
| fr := readFrame[*SettingsFrame](tf.t, tf) |
| if fr.Header().Flags.Has(FlagSettingsAck) { |
| tf.t.Errorf("got SETTINGS frame with ACK set, want no ACK") |
| } |
| for wantID, wantVal := range want { |
| gotVal, ok := fr.Value(wantID) |
| if !ok { |
| tf.t.Errorf("SETTINGS: %v is not set, want %v", wantID, wantVal) |
| } else if gotVal != wantVal { |
| tf.t.Errorf("SETTINGS: %v is %v, want %v", wantID, gotVal, wantVal) |
| } |
| } |
| if tf.t.Failed() { |
| tf.t.Fatalf("%v", fr) |
| } |
| } |
| |
| func (tf *testConnFramer) wantSettingsAck() { |
| tf.t.Helper() |
| fr := readFrame[*SettingsFrame](tf.t, tf) |
| if !fr.Header().Flags.Has(FlagSettingsAck) { |
| tf.t.Fatal("Settings Frame didn't have ACK set") |
| } |
| } |
| |
| func (tf *testConnFramer) wantGoAway(maxStreamID uint32, code ErrCode) { |
| tf.t.Helper() |
| fr := readFrame[*GoAwayFrame](tf.t, tf) |
| if fr.LastStreamID != maxStreamID || fr.ErrCode != code { |
| tf.t.Fatalf("got %v, want GOAWAY LastStreamID=%v, code=%v", summarizeFrame(fr), maxStreamID, code) |
| } |
| } |
| |
| func (tf *testConnFramer) wantWindowUpdate(streamID, incr uint32) { |
| tf.t.Helper() |
| wu := readFrame[*WindowUpdateFrame](tf.t, tf) |
| if wu.FrameHeader.StreamID != streamID { |
| tf.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID) |
| } |
| if wu.Increment != incr { |
| tf.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr) |
| } |
| } |
| |
| func (tf *testConnFramer) wantClosed() { |
| tf.t.Helper() |
| fr, err := tf.fr.ReadFrame() |
| if err == nil { |
| tf.t.Fatalf("got unexpected frame (want closed connection): %v", fr) |
| } |
| if err == os.ErrDeadlineExceeded { |
| tf.t.Fatalf("connection is not closed; want it to be") |
| } |
| } |
| |
| func (tf *testConnFramer) wantIdle() { |
| tf.t.Helper() |
| fr, err := tf.fr.ReadFrame() |
| if err == nil { |
| tf.t.Fatalf("got unexpected frame (want idle connection): %v", fr) |
| } |
| if err != os.ErrDeadlineExceeded { |
| tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err) |
| } |
| } |
| |
| func (tf *testConnFramer) writeSettings(settings ...Setting) { |
| tf.t.Helper() |
| if err := tf.fr.WriteSettings(settings...); err != nil { |
| tf.t.Fatal(err) |
| } |
| } |
| |
| func (tf *testConnFramer) writeSettingsAck() { |
| tf.t.Helper() |
| if err := tf.fr.WriteSettingsAck(); err != nil { |
| tf.t.Fatal(err) |
| } |
| } |
| |
| func (tf *testConnFramer) writeData(streamID uint32, endStream bool, data []byte) { |
| tf.t.Helper() |
| if err := tf.fr.WriteData(streamID, endStream, data); err != nil { |
| tf.t.Fatal(err) |
| } |
| } |
| |
| func (tf *testConnFramer) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) { |
| tf.t.Helper() |
| if err := tf.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil { |
| tf.t.Fatal(err) |
| } |
| } |
| |
| func (tf *testConnFramer) writeHeaders(p HeadersFrameParam) { |
| tf.t.Helper() |
| if err := tf.fr.WriteHeaders(p); err != nil { |
| tf.t.Fatal(err) |
| } |
| } |
| |
| // writeHeadersMode writes header frames, as modified by mode: |
| // |
| // - noHeader: Don't write the header. |
| // - oneHeader: Write a single HEADERS frame. |
| // - splitHeader: Write a HEADERS frame and CONTINUATION frame. |
| func (tf *testConnFramer) writeHeadersMode(mode headerType, p HeadersFrameParam) { |
| tf.t.Helper() |
| switch mode { |
| case noHeader: |
| case oneHeader: |
| tf.writeHeaders(p) |
| case splitHeader: |
| if len(p.BlockFragment) < 2 { |
| panic("too small") |
| } |
| contData := p.BlockFragment[1:] |
| contEnd := p.EndHeaders |
| p.BlockFragment = p.BlockFragment[:1] |
| p.EndHeaders = false |
| tf.writeHeaders(p) |
| tf.writeContinuation(p.StreamID, contEnd, contData) |
| default: |
| panic("bogus mode") |
| } |
| } |
| |
| func (tf *testConnFramer) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) { |
| tf.t.Helper() |
| if err := tf.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil { |
| tf.t.Fatal(err) |
| } |
| } |
| |
| func (tf *testConnFramer) writePriority(id uint32, p PriorityParam) { |
| if err := tf.fr.WritePriority(id, p); err != nil { |
| tf.t.Fatal(err) |
| } |
| } |
| |
| func (tf *testConnFramer) writeRSTStream(streamID uint32, code ErrCode) { |
| tf.t.Helper() |
| if err := tf.fr.WriteRSTStream(streamID, code); err != nil { |
| tf.t.Fatal(err) |
| } |
| } |
| |
| func (tf *testConnFramer) writePing(ack bool, data [8]byte) { |
| tf.t.Helper() |
| if err := tf.fr.WritePing(ack, data); err != nil { |
| tf.t.Fatal(err) |
| } |
| } |
| |
| func (tf *testConnFramer) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) { |
| tf.t.Helper() |
| if err := tf.fr.WriteGoAway(maxStreamID, code, debugData); err != nil { |
| tf.t.Fatal(err) |
| } |
| } |
| |
| func (tf *testConnFramer) writeWindowUpdate(streamID, incr uint32) { |
| tf.t.Helper() |
| if err := tf.fr.WriteWindowUpdate(streamID, incr); err != nil { |
| tf.t.Fatal(err) |
| } |
| } |