http2: make Transport send a Content-Length

Same policy and logic (and comments) as the net/http.Transport.

Updates golang/go#14003

Change-Id: I5744140fed16c00b0dc9a4bc74631b7df7d8241c
Reviewed-on: https://go-review.googlesource.com/18709
Reviewed-by: Andrew Gerrand <adg@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index fc50240..6d12725 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -551,6 +551,28 @@
 	}
 	hasTrailers := trailers != ""
 
+	var body io.Reader = req.Body
+	contentLen := req.ContentLength
+	if req.Body != nil && contentLen == 0 {
+		// Test to see if it's actually zero or just unset.
+		var buf [1]byte
+		n, rerr := io.ReadFull(body, buf[:])
+		if rerr != nil && rerr != io.EOF {
+			contentLen = -1
+			body = errorReader{rerr}
+		} else if n == 1 {
+			// Oh, guess there is data in this Body Reader after all.
+			// The ContentLength field just wasn't set.
+			// Stich the Body back together again, re-attaching our
+			// consumed byte.
+			contentLen = -1
+			body = io.MultiReader(bytes.NewReader(buf[:]), body)
+		} else {
+			// Body is actually empty.
+			body = nil
+		}
+	}
+
 	cc.mu.Lock()
 	if cc.closed || !cc.canTakeNewRequestLocked() {
 		cc.mu.Unlock()
@@ -559,7 +581,7 @@
 
 	cs := cc.newStream()
 	cs.req = req
-	hasBody := req.Body != nil
+	hasBody := body != nil
 
 	// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
 	if !cc.t.disableCompression() &&
@@ -584,7 +606,7 @@
 	// we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is
 	// sent by writeRequestBody below, along with any Trailers,
 	// again in form HEADERS{1}, CONTINUATION{0,})
-	hdrs := cc.encodeHeaders(req, cs.requestedGzip, trailers)
+	hdrs := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen)
 	cc.wmu.Lock()
 	endStream := !hasBody && !hasTrailers
 	werr := cc.writeHeaders(cs.ID, endStream, hdrs)
@@ -605,7 +627,7 @@
 	if hasBody {
 		bodyCopyErrc = make(chan error, 1)
 		go func() {
-			bodyCopyErrc <- cs.writeRequestBody(req.Body)
+			bodyCopyErrc <- cs.writeRequestBody(body, req.Body)
 		}()
 	}
 
@@ -705,7 +727,7 @@
 // It doesn't escape to callers.
 var errAbortReqBodyWrite = errors.New("http2: aborting request body write")
 
-func (cs *clientStream) writeRequestBody(body io.ReadCloser) (err error) {
+func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) {
 	cc := cs.cc
 	sentEnd := false // whether we sent the final DATA frame w/ END_STREAM
 	buf := cc.frameScratchBuffer()
@@ -716,7 +738,7 @@
 		// Request.Body is closed by the Transport,
 		// and in multiple cases: server replies <=299 and >299
 		// while still writing request body
-		cerr := body.Close()
+		cerr := bodyCloser.Close()
 		if err == nil {
 			err = cerr
 		}
@@ -829,7 +851,7 @@
 func (e *badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) }
 
 // requires cc.mu be held.
-func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string) []byte {
+func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) []byte {
 	cc.hbuf.Reset()
 
 	host := req.Host
@@ -855,7 +877,7 @@
 	var didUA bool
 	for k, vv := range req.Header {
 		lowKey := strings.ToLower(k)
-		if lowKey == "host" {
+		if lowKey == "host" || lowKey == "content-length" {
 			continue
 		}
 		if lowKey == "user-agent" {
@@ -876,6 +898,9 @@
 			cc.writeHeader(lowKey, v)
 		}
 	}
+	if contentLength >= 0 {
+		cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
+	}
 	if addGzipHeader {
 		cc.writeHeader("accept-encoding", "gzip")
 	}
@@ -1605,3 +1630,7 @@
 func (gz *gzipReader) Close() error {
 	return gz.body.Close()
 }
+
+type errorReader struct{ err error }
+
+func (r errorReader) Read(p []byte) (int, error) { return 0, r.err }
diff --git a/http2/transport_test.go b/http2/transport_test.go
index cd6f3df..dab483a 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -332,14 +332,19 @@
 }
 
 func TestTransportBody(t *testing.T) {
-	gotc := make(chan interface{}, 1)
+	type reqInfo struct {
+		req   *http.Request
+		slurp []byte
+		err   error
+	}
+	gotc := make(chan reqInfo, 1)
 	st := newServerTester(t,
 		func(w http.ResponseWriter, r *http.Request) {
 			slurp, err := ioutil.ReadAll(r.Body)
 			if err != nil {
-				gotc <- err
+				gotc <- reqInfo{err: err}
 			} else {
-				gotc <- string(slurp)
+				gotc <- reqInfo{req: r, slurp: slurp}
 			}
 		},
 		optOnlyServer,
@@ -364,13 +369,21 @@
 			t.Fatalf("#%d: %v", i, err)
 		}
 		defer res.Body.Close()
-		got := <-gotc
-		if err, ok := got.(error); ok {
-			t.Fatalf("#%d: %v", i, err)
-		} else if got.(string) != tt.body {
-			got := got.(string)
+		ri := <-gotc
+		if ri.err != nil {
+			t.Errorf("%#d: read error: %v", i, ri.err)
+			continue
+		}
+		if got := string(ri.slurp); got != tt.body {
 			t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body))
 		}
+		wantLen := int64(len(tt.body))
+		if tt.noContentLen && tt.body != "" {
+			wantLen = -1
+		}
+		if ri.req.ContentLength != wantLen {
+			t.Errorf("#%d. handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen)
+		}
 	}
 }
 
@@ -735,6 +748,7 @@
 	if err != nil {
 		log.Fatal(err)
 	}
+	req.ContentLength = -1
 	res, err := c.Do(req)
 	if err != nil {
 		log.Fatal(err)