http2: make the Transport write request body data as it's available

Unlike HTTP/1, we now permit streaming the write of a request body as
we read the response body, since HTTP/2's framing makes it possible.
Our behavior however is based on a heuristic: we always begin writing
the request body right away (like previously, and like HTTP/1), but if
we're still writing the request body and the server replies with a
status code over 299 (not 1xx and not 2xx), then we stop writing the
request body, assuming the server doesn't care about it. There is
currently no switch (and hopefully won't be) to force enable this
behavior. In the case where the server replied with a 1xx/2xx and
we're still writing the request body but the server doesn't want it,
the server can do a RST_STREAM, which we respect as before and stop
sending.

Also in this CL:

* adds an h2demo handler at https://http2.golang.org/ECHO to demo it

* fixes a potential flow control integer truncation bug

* start of clientTester type used for the tests in this CL, similar
  to the serverTester. It's still a bit cumbersome to write client
  tests, though.

* fix potential deadlock where awaitFlowControl could stay blocked
  waiting for flow control after a stream reset arrived. fix it by
  moving all checks into the sync.Cond loop, rather than having a
  sync.Cond check followed by a select. simplifies code, too.

* fix two data races in test-only code.

Updates golang/go#13444

Change-Id: Idfda6833a212a89fcd65293cdeb4169d1723724f
Reviewed-on: https://go-review.googlesource.com/17310
Reviewed-by: Blake Mizerany <blake.mizerany@gmail.com>
diff --git a/http2/h2demo/h2demo.go b/http2/h2demo/h2demo.go
index 8d5e4fd..15ef52f 100644
--- a/http2/h2demo/h2demo.go
+++ b/http2/h2demo/h2demo.go
@@ -91,6 +91,7 @@
   <li>GET <a href="/redirect">/redirect</a> to redirect back to / (this page)</li>
   <li>GET <a href="/goroutines">/goroutines</a> to see all active goroutines in this server</li>
   <li>PUT something to <a href="/crc32">/crc32</a> to get a count of number of bytes and its CRC-32</li>
+  <li>PUT something to <a href="/ECHO">/ECHO</a> and it will be streamed back to you capitalized</li>
 </ul>
 
 </body></html>`)
@@ -124,6 +125,40 @@
 	}
 }
 
+type capitalizeReader struct {
+	r io.Reader // underlying reader whose output gets upper-cased by Read
+}
+
+func (cr capitalizeReader) Read(p []byte) (n int, err error) { // Read fills p from cr.r, then upper-cases ASCII a-z in place.
+	n, err = cr.r.Read(p)
+	for i, b := range p[:n] { // only the n bytes actually read are touched
+		if b >= 'a' && b <= 'z' {
+			p[i] = b - ('a' - 'A') // shift a lowercase ASCII letter to its uppercase counterpart
+		}
+	}
+	return // n and err are passed through unchanged from cr.r.Read
+}
+
+type flushWriter struct {
+	w io.Writer // destination; flushed after every Write when it implements http.Flusher
+}
+
+func (fw flushWriter) Write(p []byte) (n int, err error) { // Write forwards p to fw.w and then flushes fw.w if it can be flushed.
+	n, err = fw.w.Write(p)
+	if f, ok := fw.w.(http.Flusher); ok { // the flush happens even when the Write above returned an error
+		f.Flush()
+	}
+	return
+}
+
+func echoCapitalHandler(w http.ResponseWriter, r *http.Request) { // streams a PUT body back upper-cased, flushing after each write
+	if r.Method != "PUT" {
+		http.Error(w, "PUT required.", 400) // every other method is answered with a 400
+		return
+	}
+	io.Copy(flushWriter{w}, capitalizeReader{r.Body}) // io.Copy's byte count and error are discarded
+}
+
 var (
 	fsGrp   singleflight.Group
 	fsMu    sync.Mutex // guards fsCache
@@ -217,6 +252,7 @@
 	mux2.Handle("/file/go.src.tar.gz", fileServer("https://storage.googleapis.com/golang/go1.4.1.src.tar.gz"))
 	mux2.HandleFunc("/reqinfo", reqInfoHandler)
 	mux2.HandleFunc("/crc32", crcHandler)
+	mux2.HandleFunc("/ECHO", echoCapitalHandler)
 	mux2.HandleFunc("/clockstream", clockStreamHandler)
 	mux2.Handle("/gophertiles", tiles)
 	mux2.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) {
diff --git a/http2/server_test.go b/http2/server_test.go
index 7a42051..7e8eb7e 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -2213,6 +2213,9 @@
 		t.Skip("skipping curl test in short mode")
 	}
 	requireCurl(t)
+	var gotConn int32
+	testHookOnConn = func() { atomic.StoreInt32(&gotConn, 1) }
+
 	const msg = "Hello from curl!\n"
 	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Foo", "Bar")
@@ -2226,9 +2229,6 @@
 	ts.StartTLS()
 	defer ts.Close()
 
-	var gotConn int32
-	testHookOnConn = func() { atomic.StoreInt32(&gotConn, 1) }
-
 	t.Logf("Running test server for curl to hit at: %s", ts.URL)
 	container := curl(t, "--silent", "--http2", "--insecure", "-v", ts.URL)
 	defer kill(container)
diff --git a/http2/transport.go b/http2/transport.go
index 320bf67..9d9e09c 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -155,6 +155,7 @@
 	inflow      flow  // guarded by cc.mu
 	bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read
 	readErr     error // sticky read error; owned by transportResponseBody.Read
+	stopReqBody bool  // stop writing req body; guarded by cc.mu
 
 	peerReset chan struct{} // closed on peer reset
 	resetErr  error         // populated before peerReset is closed
@@ -171,6 +172,14 @@
 	}
 }
 
+func (cs *clientStream) abortRequestBodyWrite() { // flags the stream so that writing of the request body stops
+	cc := cs.cc
+	cc.mu.Lock()
+	cs.stopReqBody = true // checked by awaitFlowControl, which then returns errAbortReqBodyWrite
+	cc.cond.Broadcast()
+	cc.mu.Unlock() // NOTE(review): the Broadcast presumably wakes a writer parked in awaitFlowControl; its Wait call is outside this view
+}
+
 type stickyErrWriter struct {
 	w   io.Writer
 	err *error
@@ -516,26 +525,33 @@
 		return nil, werr
 	}
 
-	var bodyCopyErrc chan error
-	var gotResHeaders chan struct{} // closed on resheaders
+	var bodyCopyErrc chan error // result of body copy
 	if hasBody {
 		bodyCopyErrc = make(chan error, 1)
-		gotResHeaders = make(chan struct{})
 		go func() {
-			bodyCopyErrc <- cs.writeRequestBody(req.Body, gotResHeaders)
+			bodyCopyErrc <- cs.writeRequestBody(req.Body)
 		}()
 	}
 
 	for {
 		select {
 		case re := <-cs.resc:
-			if gotResHeaders != nil {
-				close(gotResHeaders)
+			res := re.res
+			if re.err != nil || res.StatusCode > 299 {
+				// On error or status code 3xx, 4xx, 5xx, etc abort any
+				// ongoing write, assuming that the server doesn't care
+				// about our request body. If the server replied with 1xx or
+				// 2xx, however, then assume the server DOES potentially
+				// want our body (e.g. full-duplex streaming:
+				// golang.org/issue/13444). If it turns out the server
+				// doesn't, they'll RST_STREAM us soon enough.  This is a
+				// heuristic to avoid adding knobs to Transport.  Hopefully
+				// we can keep it.
+				cs.abortRequestBodyWrite()
 			}
 			if re.err != nil {
 				return nil, re.err
 			}
-			res := re.res
 			res.Request = req
 			res.TLS = cc.tlsState
 			return res, nil
@@ -547,45 +563,56 @@
 	}
 }
 
-var errServerResponseBeforeRequestBody = errors.New("http2: server sent response while still writing request body")
+// errAbortReqBodyWrite is an internal error value.
+// It doesn't escape to callers.
+var errAbortReqBodyWrite = errors.New("http2: aborting request body write")
 
-func (cs *clientStream) writeRequestBody(body io.Reader, gotResHeaders <-chan struct{}) error {
+func (cs *clientStream) writeRequestBody(body io.ReadCloser) (err error) {
 	cc := cs.cc
 	sentEnd := false // whether we sent the final DATA frame w/ END_STREAM
 	buf := cc.frameScratchBuffer()
 	defer cc.putFrameScratchBuffer(buf)
 
-	for !sentEnd {
-		var sawEOF bool
-		n, err := io.ReadFull(body, buf)
-		if err == io.ErrUnexpectedEOF {
+	defer func() {
+		// TODO: write h12Compare test showing whether
+		// Request.Body is closed by the Transport,
+		// and in multiple cases: server replies <=299 and >299
+		// while still writing request body
+		cerr := body.Close()
+		if err == nil {
+			err = cerr
+		}
+	}()
+
+	var sawEOF bool
+	for !sawEOF {
+		n, err := body.Read(buf)
+		if err == io.EOF {
 			sawEOF = true
 			err = nil
-		} else if err == io.EOF {
-			break
 		} else if err != nil {
 			return err
 		}
 
-		toWrite := buf[:n]
-		for len(toWrite) > 0 && err == nil {
+		remain := buf[:n]
+		for len(remain) > 0 && err == nil {
 			var allowed int32
-			allowed, err = cs.awaitFlowControl(int32(len(toWrite)))
+			allowed, err = cs.awaitFlowControl(len(remain))
 			if err != nil {
 				return err
 			}
-
 			cc.wmu.Lock()
-			select {
-			case <-gotResHeaders:
-				err = errServerResponseBeforeRequestBody
-			case <-cs.peerReset:
-				err = cs.resetErr
-			default:
-				data := toWrite[:allowed]
-				toWrite = toWrite[allowed:]
-				sentEnd = sawEOF && len(toWrite) == 0
-				err = cc.fr.WriteData(cs.ID, sentEnd, data)
+			data := remain[:allowed]
+			remain = remain[allowed:]
+			sentEnd = sawEOF && len(remain) == 0
+			err = cc.fr.WriteData(cs.ID, sentEnd, data)
+			if err == nil {
+				// TODO(bradfitz): this flush is for latency, not bandwidth.
+				// Most requests won't need this. Make this opt-in or opt-out?
+				// Use some heuristic on the body type? Nagle-like timers?
+				// Based on 'n'? Only last chunk of this for loop, unless flow control
+				// tokens are low? For now, always:
+				err = cc.bw.Flush()
 			}
 			cc.wmu.Unlock()
 		}
@@ -594,8 +621,6 @@
 		}
 	}
 
-	var err error
-
 	cc.wmu.Lock()
 	if !sentEnd {
 		err = cc.fr.WriteData(cs.ID, true, nil)
@@ -612,7 +637,7 @@
 // control tokens from the server.
 // It returns either the non-zero number of tokens taken or an error
 // if the stream is dead.
-func (cs *clientStream) awaitFlowControl(maxBytes int32) (taken int32, err error) {
+func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) {
 	cc := cs.cc
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
@@ -620,13 +645,17 @@
 		if cc.closed {
 			return 0, errClientConnClosed
 		}
+		if cs.stopReqBody {
+			return 0, errAbortReqBodyWrite
+		}
 		if err := cs.checkReset(); err != nil {
 			return 0, err
 		}
 		if a := cs.flow.available(); a > 0 {
 			take := a
-			if take > maxBytes {
-				take = maxBytes
+			if int(take) > maxBytes {
+
+				take = int32(maxBytes) // can't truncate int; take is int32
 			}
 			if take > int32(cc.maxFrameSize) {
 				take = int32(cc.maxFrameSize)
@@ -1092,6 +1121,7 @@
 		cs.resetErr = err
 		close(cs.peerReset)
 		cs.bufPipe.CloseWithError(err)
+		cs.cc.cond.Broadcast() // wake up checkReset via clientStream.awaitFlowControl
 	}
 	delete(rl.activeRes, cs.ID)
 	return nil
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 8379157..0c875ac 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -5,21 +5,29 @@
 package http2
 
 import (
+	"bufio"
+	"bytes"
 	"crypto/tls"
+	"errors"
 	"flag"
 	"fmt"
 	"io"
 	"io/ioutil"
+	"log"
 	"math/rand"
 	"net"
 	"net/http"
 	"net/url"
 	"os"
 	"reflect"
+	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
+
+	"golang.org/x/net/http2/hpack"
 )
 
 var (
@@ -182,6 +190,8 @@
 		if !ok {
 			return fmt.Errorf("Conn pool is %T; want *clientConnPool", tr.connPool())
 		}
+		cp.mu.Lock()
+		defer cp.mu.Unlock()
 		if len(cp.dialing) != 0 {
 			return fmt.Errorf("dialing map = %v; want empty", cp.dialing)
 		}
@@ -456,3 +466,296 @@
 		t.Errorf("body = %q; want %q", got, want)
 	}
 }
+
+type capitalizeReader struct {
+	r io.Reader // underlying reader whose output gets upper-cased by Read
+}
+
+func (cr capitalizeReader) Read(p []byte) (n int, err error) { // Read fills p from cr.r, then upper-cases ASCII a-z in place.
+	n, err = cr.r.Read(p)
+	for i, b := range p[:n] { // only the n bytes actually read are touched
+		if b >= 'a' && b <= 'z' {
+			p[i] = b - ('a' - 'A') // shift a lowercase ASCII letter to its uppercase counterpart
+		}
+	}
+	return // n and err are passed through unchanged from cr.r.Read
+}
+
+type flushWriter struct {
+	w io.Writer // destination; flushed after every Write when it implements http.Flusher
+}
+
+func (fw flushWriter) Write(p []byte) (n int, err error) { // Write forwards p to fw.w and then flushes fw.w if it can be flushed.
+	n, err = fw.w.Write(p)
+	if f, ok := fw.w.(http.Flusher); ok { // the flush happens even when the Write above returned an error
+		f.Flush()
+	}
+	return
+}
+
+type clientTester struct {
+	t      *testing.T   // test that owns this tester
+	tr     *Transport   // Transport under test; newClientTester makes its DialTLS return cc once
+	sc, cc net.Conn     // server and client ends of one local TCP connection
+	fr     *Framer      // server's framer, built over sc
+	client func() error // client half of the test; started in its own goroutine by run
+	server func() error // server half of the test; started in its own goroutine by run
+}
+
+func newClientTester(t *testing.T) *clientTester { // builds a Transport whose only dial yields one end of a local TCP connection
+	var dialOnce struct {
+		sync.Mutex
+		dialed bool
+	}
+	ct := &clientTester{
+		t: t,
+	}
+	ct.tr = &Transport{
+		TLSClientConfig: tlsConfigInsecure,
+		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+			dialOnce.Lock()
+			defer dialOnce.Unlock()
+			if dialOnce.dialed { // a second dial attempt is reported as an error
+				return nil, errors.New("only one dial allowed in test mode")
+			}
+			dialOnce.dialed = true
+			return ct.cc, nil // ct.cc is assigned further down, before newClientTester returns
+		},
+	}
+
+	ln := newLocalListener(t)
+	cc, err := net.Dial("tcp", ln.Addr().String())
+	if err != nil {
+		t.Fatal(err)
+
+	}
+	sc, err := ln.Accept()
+	if err != nil {
+		t.Fatal(err)
+	}
+	ln.Close() // both ends are connected, so the listener is no longer needed
+	ct.cc = cc
+	ct.sc = sc
+	ct.fr = NewFramer(sc, sc) // the server side frames directly over sc
+	return ct
+}
+
+func newLocalListener(t *testing.T) net.Listener { // listens on an ephemeral loopback port, trying IPv4 first
+	ln, err := net.Listen("tcp4", "127.0.0.1:0")
+	if err == nil {
+		return ln
+	}
+	ln, err = net.Listen("tcp6", "[::1]:0") // fall back to the IPv6 loopback when the IPv4 listen failed
+	if err != nil {
+		t.Fatal(err)
+	}
+	return ln
+}
+
+func (ct *clientTester) greet() { // performs the server half of the connection start-up exchange
+	buf := make([]byte, len(ClientPreface))
+	_, err := io.ReadFull(ct.sc, buf) // consumes len(ClientPreface) bytes; their contents are not compared
+	if err != nil {
+		ct.t.Fatalf("reading client preface: %v", err)
+	}
+	f, err := ct.fr.ReadFrame()
+	if err != nil {
+		ct.t.Fatalf("Reading client settings frame: %v", err)
+	}
+	if sf, ok := f.(*SettingsFrame); !ok { // the first frame after the preface must be SETTINGS
+		ct.t.Fatalf("Wanted client settings frame; got %v", f)
+		_ = sf // stash it away?
+	}
+	if err := ct.fr.WriteSettings(); err != nil { // server SETTINGS, written with no setting values
+		ct.t.Fatal(err)
+	}
+	if err := ct.fr.WriteSettingsAck(); err != nil { // acknowledge the client's SETTINGS
+		ct.t.Fatal(err)
+	}
+}
+
+func (ct *clientTester) run() { // starts ct.client and ct.server and reports the first error either one sends
+	errc := make(chan error, 2) // capacity 2, so both goroutines can send even after run has returned
+	ct.start("client", errc, ct.client)
+	ct.start("server", errc, ct.server)
+	for i := 0; i < 2; i++ {
+		if err := <-errc; err != nil {
+			ct.t.Error(err)
+			return // stop at the first failure without waiting for the other goroutine
+		}
+	}
+}
+
+func (ct *clientTester) start(which string, errc chan<- error, fn func() error) { // runs fn in a new goroutine and sends its labeled result on errc
+	go func() {
+		finished := false // stays false when fn exits without returning normally
+		var err error
+		defer func() {
+			if !finished {
+				err = fmt.Errorf("%s goroutine didn't finish.", which)
+			} else if err != nil {
+				err = fmt.Errorf("%s: %v", which, err)
+			}
+			errc <- err // exactly one value is sent per call; nil when fn returned nil
+		}()
+		err = fn()
+		finished = true
+	}()
+}
+
+type countingReader struct {
+	n *int64 // running total of bytes produced; updated with atomic.AddInt64
+}
+
+func (r countingReader) Read(p []byte) (n int, err error) { // Read always fills all of p and never reports an error or EOF.
+	for i := range p {
+		p[i] = byte(i)
+	}
+	atomic.AddInt64(r.n, int64(len(p)))
+	return len(p), err // err is the never-assigned named result, so it is always nil
+}
+
+func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } // 200: whole body expected
+func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } // 403: body cut short
+
+func testTransportReqBodyAfterResponse(t *testing.T, status int) { // server replies with status while the client is still sending a 10 MiB body
+	const bodySize = 10 << 20
+	ct := newClientTester(t)
+	ct.client = func() error {
+		var n int64 // atomic; counts the bytes read from the request body via countingReader
+		req, err := http.NewRequest("PUT", "https://dummy.tld/", io.LimitReader(countingReader{&n}, bodySize))
+		if err != nil {
+			return err
+		}
+		res, err := ct.tr.RoundTrip(req)
+		if err != nil {
+			return fmt.Errorf("RoundTrip: %v", err)
+		}
+		defer res.Body.Close()
+		if res.StatusCode != status {
+			return fmt.Errorf("status code = %v; want %v", res.StatusCode, status)
+		}
+		slurp, err := ioutil.ReadAll(res.Body)
+		if err != nil {
+			return fmt.Errorf("Slurp: %v", err)
+		}
+		if len(slurp) > 0 {
+			return fmt.Errorf("unexpected body: %q", slurp)
+		}
+		if status == 200 { // for 200 the whole request body must have been consumed
+			if got := atomic.LoadInt64(&n); got != bodySize {
+				return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize)
+			}
+		} else { // otherwise body consumption must have started but stopped before the end
+			if got := atomic.LoadInt64(&n); got == 0 || got >= bodySize {
+				return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize)
+			}
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+		var buf bytes.Buffer
+		enc := hpack.NewEncoder(&buf)
+		var dataRecv int64 // total DATA payload bytes received so far
+		var closed bool
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				return err
+			}
+			//println(fmt.Sprintf("server got frame: %v", f))
+			switch f := f.(type) {
+			case *WindowUpdateFrame, *SettingsFrame: // accepted and ignored
+			case *HeadersFrame:
+				if !f.HeadersEnded() {
+					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+				}
+				if f.StreamEnded() {
+					return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f)
+				}
+				time.Sleep(50 * time.Millisecond) // let client send body
+				enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
+				ct.fr.WriteHeaders(HeadersFrameParam{
+					StreamID:      f.StreamID,
+					EndHeaders:    true,
+					EndStream:     false,
+					BlockFragment: buf.Bytes(),
+				}) // NOTE(review): the errors from WriteField and WriteHeaders are not checked
+			case *DataFrame:
+				dataLen := len(f.Data())
+				dataRecv += int64(dataLen)
+				if dataLen > 0 { // hand the received byte count back as window updates on stream 0 and on the request stream
+					if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
+						return err
+					}
+					if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
+						return err
+					}
+				}
+				if !closed && ((status != 200 && dataRecv > 0) ||
+					(status == 200 && dataRecv == bodySize)) {
+					closed = true
+					if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil { // empty DATA frame with the end-of-stream flag set
+						return err
+					}
+					return nil
+				}
+			default:
+				return fmt.Errorf("Unexpected client frame %v", f)
+			}
+		}
+		return nil // NOTE(review): unreachable; the loop above only exits via return
+	}
+	ct.run()
+}
+
+// TestTransportFullDuplex reads response lines while the request body is still being written. See golang.org/issue/13444
+func TestTransportFullDuplex(t *testing.T) {
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(200) // redundant but for clarity
+		w.(http.Flusher).Flush()
+		io.Copy(flushWriter{w}, capitalizeReader{r.Body}) // echo the body back upper-cased, flushing after every write
+		fmt.Fprintf(w, "bye.\n")
+	}, optOnlyServer)
+	defer st.Close()
+
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+	defer tr.CloseIdleConnections()
+	c := &http.Client{Transport: tr}
+
+	pr, pw := io.Pipe()
+	req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr))
+	if err != nil {
+		log.Fatal(err) // NOTE(review): log.Fatal exits the whole test binary; t.Fatal is preferable but would orphan the "log" import
+	}
+	res, err := c.Do(req)
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer res.Body.Close()
+	if res.StatusCode != 200 {
+		t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200)
+	}
+	bs := bufio.NewScanner(res.Body)
+	want := func(v string) { // reads the next response line and requires it to equal v
+		if !bs.Scan() || bs.Text() != v {
+			t.Fatalf("wanted to read %q; got %q, Scan err = %v", v, bs.Text(), bs.Err())
+		}
+	}
+	write := func(v string) { // sends v as the next piece of the request body
+		_, err := io.WriteString(pw, v)
+		if err != nil {
+			t.Fatalf("pipe write: %v", err)
+		}
+	}
+	write("foo\n")
+	want("FOO")
+	write("bar\n")
+	want("BAR")
+	pw.Close() // ends the request body, which lets the handler's io.Copy return
+	want("bye.")
+	if err := bs.Err(); err != nil {
+		t.Fatal(err)
+	}
+}