http2: make the Transport write request body data as it's available
Unlike HTTP/1, we now permit streaming the write of a request body as
we read the response body, since HTTP/2's framing makes it possible.
Our behavior however is based on a heuristic: we always begin writing
the request body right away (like previously, and like HTTP/1), but if
we're still writing the request body and the server replies with a
status code over 299 (not 1xx and not 2xx), then we stop writing the
request body, assuming the server doesn't care about it. There is
currently no switch to force-enable this behavior (and hopefully there
will never need to be one). In the case where the server replied with a 1xx/2xx and
we're still writing the request body but the server doesn't want it,
the server can do a RST_STREAM, which we respect as before and stop
sending.
Also in this CL:
* add an h2demo handler at https://http2.golang.org/ECHO to demo it
* fix a potential flow control integer truncation bug
* add the start of a clientTester type, used by the tests in this CL
and similar to the serverTester. It's still a bit cumbersome to write
client tests, though.
* fix a potential deadlock where awaitFlowControl could keep blocking
even after a stream reset arrived. Fix it by moving all checks into
the sync.Cond loop, rather than doing a sync.Cond check followed by
a select. This simplifies the code, too.
* fix two data races in test-only code.
Updates golang/go#13444
Change-Id: Idfda6833a212a89fcd65293cdeb4169d1723724f
Reviewed-on: https://go-review.googlesource.com/17310
Reviewed-by: Blake Mizerany <blake.mizerany@gmail.com>
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 8379157..0c875ac 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -5,21 +5,29 @@
package http2
import (
+ "bufio"
+ "bytes"
"crypto/tls"
+ "errors"
"flag"
"fmt"
"io"
"io/ioutil"
+ "log"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"reflect"
+ "strconv"
"strings"
"sync"
+ "sync/atomic"
"testing"
"time"
+
+ "golang.org/x/net/http2/hpack"
)
var (
@@ -182,6 +190,8 @@
if !ok {
return fmt.Errorf("Conn pool is %T; want *clientConnPool", tr.connPool())
}
+ cp.mu.Lock()
+ defer cp.mu.Unlock()
if len(cp.dialing) != 0 {
return fmt.Errorf("dialing map = %v; want empty", cp.dialing)
}
@@ -456,3 +466,296 @@
t.Errorf("body = %q; want %q", got, want)
}
}
+
+// capitalizeReader wraps r and upper-cases every ASCII lowercase letter
+// ('a' through 'z') in the data it returns. All other bytes pass through
+// untouched.
+type capitalizeReader struct {
+	r io.Reader
+}
+
+// Read fills p from the wrapped reader, rewrites the n bytes it received
+// in place, and returns the wrapped reader's (n, err) unchanged.
+func (cr capitalizeReader) Read(p []byte) (int, error) {
+	n, err := cr.r.Read(p)
+	got := p[:n]
+	for i := range got {
+		if c := got[i]; 'a' <= c && c <= 'z' {
+			got[i] = c - 'a' + 'A'
+		}
+	}
+	return n, err
+}
+
+// flushWriter is an io.Writer that forwards every Write to w and, when w
+// also implements http.Flusher, flushes w right afterwards so nothing the
+// handler wrote stays buffered.
+type flushWriter struct {
+	w io.Writer
+}
+
+// Write passes p to the wrapped writer and then flushes it if possible.
+// The flush is attempted even when the write failed; the write's (n, err)
+// is returned as-is.
+func (fw flushWriter) Write(p []byte) (int, error) {
+	written, werr := fw.w.Write(p)
+	if fl, canFlush := fw.w.(http.Flusher); canFlush {
+		fl.Flush()
+	}
+	return written, werr
+}
+
+// clientTester drives a Transport against a hand-scripted fake server over
+// a loopback TCP connection. run executes the client and server funcs
+// concurrently, each in its own goroutine.
+type clientTester struct {
+	t  *testing.T
+	tr *Transport
+
+	sc net.Conn // server side of the connection
+	cc net.Conn // client side of the connection; returned by tr's DialTLS
+	fr *Framer  // server's framer, reading and writing on sc
+
+	client func() error // client-side test logic
+	server func() error // server-side test logic
+}
+
+// newClientTester returns a clientTester whose Transport is wired to one
+// end of a freshly accepted loopback TCP connection. The Transport's
+// DialTLS hook returns that client-side conn as-is (no TLS handshake is
+// done on it) for the first dial and fails every later dial, presumably so
+// a test notices unexpected redials. The server side of the connection is
+// wrapped in ct.fr so the test can read and write raw HTTP/2 frames.
+func newClientTester(t *testing.T) *clientTester {
+	var dialOnce struct {
+		sync.Mutex
+		dialed bool
+	}
+	ct := &clientTester{
+		t: t,
+	}
+	ct.tr = &Transport{
+		TLSClientConfig: tlsConfigInsecure,
+		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+			dialOnce.Lock()
+			defer dialOnce.Unlock()
+			if dialOnce.dialed {
+				return nil, errors.New("only one dial allowed in test mode")
+			}
+			dialOnce.dialed = true
+			return ct.cc, nil
+		},
+	}
+
+	ln := newLocalListener(t)
+	// The listener is only needed to accept this one connection. Deferring
+	// the close also releases it when t.Fatal fires below.
+	defer ln.Close()
+	cc, err := net.Dial("tcp", ln.Addr().String())
+	if err != nil {
+		t.Fatal(err)
+	}
+	sc, err := ln.Accept()
+	if err != nil {
+		cc.Close()
+		t.Fatal(err)
+	}
+	ct.cc = cc
+	ct.sc = sc
+	ct.fr = NewFramer(sc, sc)
+	return ct
+}
+
+// newLocalListener returns a TCP listener on an OS-chosen port of the
+// loopback interface. It tries IPv4 (127.0.0.1) first and falls back to
+// IPv6 ([::1]); if the IPv6 attempt also fails, it calls t.Fatal with that
+// error.
+func newLocalListener(t *testing.T) net.Listener {
+	if l4, err4 := net.Listen("tcp4", "127.0.0.1:0"); err4 == nil {
+		return l4
+	}
+	l6, err6 := net.Listen("tcp6", "[::1]:0")
+	if err6 != nil {
+		t.Fatal(err6)
+	}
+	return l6
+}
+
+// greet plays the server side of the HTTP/2 connection start on ct.sc.
+// It reads and validates the client preface, then reads the client's first
+// frame, which must be a SETTINGS frame. Finally it writes the server's
+// SETTINGS frame and an ACK for the client's settings.
+//
+// NOTE(review): failures are reported with ct.t.Fatalf, but greet runs
+// inside the server goroutine started by run, not the test goroutine. The
+// testing package documents FailNow as test-goroutine-only. In practice the
+// test is still marked failed and start reports the goroutine as
+// unfinished. Returning an error would be cleaner, but that would change
+// the signature its callers use.
+func (ct *clientTester) greet() {
+	buf := make([]byte, len(ClientPreface))
+	if _, err := io.ReadFull(ct.sc, buf); err != nil {
+		ct.t.Fatalf("reading client preface: %v", err)
+	}
+	if string(buf) != ClientPreface {
+		ct.t.Fatalf("client preface = %q; want %q", buf, ClientPreface)
+	}
+	f, err := ct.fr.ReadFrame()
+	if err != nil {
+		ct.t.Fatalf("Reading client settings frame: %v", err)
+	}
+	if _, ok := f.(*SettingsFrame); !ok {
+		ct.t.Fatalf("Wanted client settings frame; got %v", f)
+	}
+	if err := ct.fr.WriteSettings(); err != nil {
+		ct.t.Fatal(err)
+	}
+	if err := ct.fr.WriteSettingsAck(); err != nil {
+		ct.t.Fatal(err)
+	}
+}
+
+// run starts ct.client and ct.server in separate goroutines and collects
+// both results. The first non-nil error is reported with t.Error and run
+// returns at once, without waiting for the other goroutine. errc has room
+// for both results, so the remaining sender never blocks.
+func (ct *clientTester) run() {
+	const workers = 2
+	errc := make(chan error, workers)
+	ct.start("client", errc, ct.client)
+	ct.start("server", errc, ct.server)
+	for pending := workers; pending > 0; pending-- {
+		err := <-errc
+		if err == nil {
+			continue
+		}
+		ct.t.Error(err)
+		return
+	}
+}
+
+// start runs fn in a new goroutine labeled which. When that goroutine
+// exits, it sends exactly one value on errc:
+//   - nil if fn returned nil;
+//   - fn's error prefixed with "which: " if fn returned an error;
+//   - a "didn't finish" error if the goroutine exited without fn returning,
+//     e.g. because fn called t.Fatal, which calls runtime.Goexit.
+func (ct *clientTester) start(which string, errc chan<- error, fn func() error) {
+	go func() {
+		var (
+			err      error
+			returned bool
+		)
+		defer func() {
+			switch {
+			case !returned:
+				err = fmt.Errorf("%s goroutine didn't finish.", which)
+			case err != nil:
+				err = fmt.Errorf("%s: %v", which, err)
+			}
+			errc <- err
+		}()
+		err = fn()
+		returned = true
+	}()
+}
+
+// countingReader is an endless io.Reader. Every Read fills all of p with
+// the pattern 0, 1, 2, ... (each index truncated to a byte) and atomically
+// adds len(p) to *n, so another goroutine can read the running total with
+// atomic.LoadInt64. Read never returns an error.
+type countingReader struct {
+	n *int64
+}
+
+func (r countingReader) Read(p []byte) (int, error) {
+	size := len(p)
+	for i := 0; i < size; i++ {
+		p[i] = byte(i)
+	}
+	atomic.AddInt64(r.n, int64(size))
+	return size, nil
+}
+
+func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) }
+func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) }
+
+// testTransportReqBodyAfterResponse sends a 10 MiB PUT body through the
+// Transport while a scripted server replies with the given status after
+// receiving the request HEADERS. For a 200 response, the Transport must
+// read the entire body from the request body reader. For any other status,
+// it must read some, but not all, of it before giving up.
+func testTransportReqBodyAfterResponse(t *testing.T, status int) {
+	const bodySize = 10 << 20
+	ct := newClientTester(t)
+	ct.client = func() error {
+		var n int64 // atomic
+		req, err := http.NewRequest("PUT", "https://dummy.tld/", io.LimitReader(countingReader{&n}, bodySize))
+		if err != nil {
+			return err
+		}
+		res, err := ct.tr.RoundTrip(req)
+		if err != nil {
+			return fmt.Errorf("RoundTrip: %v", err)
+		}
+		defer res.Body.Close()
+		if res.StatusCode != status {
+			return fmt.Errorf("status code = %v; want %v", res.StatusCode, status)
+		}
+		slurp, err := ioutil.ReadAll(res.Body)
+		if err != nil {
+			return fmt.Errorf("Slurp: %v", err)
+		}
+		if len(slurp) > 0 {
+			return fmt.Errorf("unexpected body: %q", slurp)
+		}
+		if status == 200 {
+			if got := atomic.LoadInt64(&n); got != bodySize {
+				return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize)
+			}
+		} else {
+			if got := atomic.LoadInt64(&n); got == 0 || got >= bodySize {
+				return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize)
+			}
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+		var buf bytes.Buffer
+		enc := hpack.NewEncoder(&buf)
+		var dataRecv int64
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				return err
+			}
+			switch f := f.(type) {
+			case *WindowUpdateFrame, *SettingsFrame:
+				// Ignored; this script doesn't act on them.
+			case *HeadersFrame:
+				if !f.HeadersEnded() {
+					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+				}
+				if f.StreamEnded() {
+					return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f)
+				}
+				time.Sleep(50 * time.Millisecond) // let client send body
+				if err := enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}); err != nil {
+					return err
+				}
+				if err := ct.fr.WriteHeaders(HeadersFrameParam{
+					StreamID:      f.StreamID,
+					EndHeaders:    true,
+					EndStream:     false,
+					BlockFragment: buf.Bytes(),
+				}); err != nil {
+					return err
+				}
+			case *DataFrame:
+				dataLen := len(f.Data())
+				dataRecv += int64(dataLen)
+				if dataLen > 0 {
+					// Credit the received bytes back to both the
+					// connection (stream 0) and the stream so the client
+					// can keep sending.
+					if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
+						return err
+					}
+					if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
+						return err
+					}
+				}
+				// End the response with an empty END_STREAM DATA frame once
+				// any body byte has arrived (non-200), or once the whole
+				// body has arrived (200). Returning here also ends the
+				// server script, so this fires at most once.
+				if (status != 200 && dataRecv > 0) ||
+					(status == 200 && dataRecv == bodySize) {
+					return ct.fr.WriteData(f.StreamID, true, nil)
+				}
+			default:
+				return fmt.Errorf("Unexpected client frame %v", f)
+			}
+		}
+	}
+	ct.run()
+}
+
+// TestTransportFullDuplex checks that the Transport streams the request
+// body while the response body is being read. Each line written to the
+// request body must come back upper-cased in the response before the next
+// line is written. See golang.org/issue/13444.
+func TestTransportFullDuplex(t *testing.T) {
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(200) // redundant but for clarity
+		w.(http.Flusher).Flush()
+		io.Copy(flushWriter{w}, capitalizeReader{r.Body})
+		fmt.Fprintf(w, "bye.\n")
+	}, optOnlyServer)
+	defer st.Close()
+
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+	defer tr.CloseIdleConnections()
+	c := &http.Client{Transport: tr}
+
+	pr, pw := io.Pipe()
+	// End the request body on every exit path, including early t.Fatal
+	// failures. Closing again after the explicit pw.Close below is harmless.
+	defer pw.Close()
+	req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr))
+	if err != nil {
+		// NOTE(review): t.Fatal would be preferable, since log.Fatal exits
+		// the whole test binary. These two calls are the only uses of the
+		// "log" import added by this change, so switching also needs an
+		// import edit.
+		log.Fatal(err)
+	}
+	res, err := c.Do(req)
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer res.Body.Close()
+	if res.StatusCode != 200 {
+		t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200)
+	}
+	bs := bufio.NewScanner(res.Body)
+	// want reads the next response line and fails unless it equals v.
+	want := func(v string) {
+		if !bs.Scan() {
+			t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err())
+		}
+		if got := bs.Text(); got != v {
+			t.Fatalf("read line %q; want %q", got, v)
+		}
+	}
+	write := func(v string) {
+		_, err := io.WriteString(pw, v)
+		if err != nil {
+			t.Fatalf("pipe write: %v", err)
+		}
+	}
+	write("foo\n")
+	want("FOO")
+	write("bar\n")
+	want("BAR")
+	pw.Close()
+	want("bye.")
+	if err := bs.Err(); err != nil {
+		t.Fatal(err)
+	}
+}