http2: make Transport send a Content-Length
Same policy and logic (and comments) as the net/http.Transport.
Updates golang/go#14003
Change-Id: I5744140fed16c00b0dc9a4bc74631b7df7d8241c
Reviewed-on: https://go-review.googlesource.com/18709
Reviewed-by: Andrew Gerrand <adg@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index fc50240..6d12725 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -551,6 +551,28 @@
}
hasTrailers := trailers != ""
+ var body io.Reader = req.Body
+ contentLen := req.ContentLength
+ if req.Body != nil && contentLen == 0 {
+ // Test to see if it's actually zero or just unset.
+ var buf [1]byte
+ n, rerr := io.ReadFull(body, buf[:])
+ if rerr != nil && rerr != io.EOF {
+ contentLen = -1
+ body = errorReader{rerr}
+ } else if n == 1 {
+ // Oh, guess there is data in this Body Reader after all.
+ // The ContentLength field just wasn't set.
+ // Stich the Body back together again, re-attaching our
+ // consumed byte.
+ contentLen = -1
+ body = io.MultiReader(bytes.NewReader(buf[:]), body)
+ } else {
+ // Body is actually empty.
+ body = nil
+ }
+ }
+
cc.mu.Lock()
if cc.closed || !cc.canTakeNewRequestLocked() {
cc.mu.Unlock()
@@ -559,7 +581,7 @@
cs := cc.newStream()
cs.req = req
- hasBody := req.Body != nil
+ hasBody := body != nil
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
if !cc.t.disableCompression() &&
@@ -584,7 +606,7 @@
// we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is
// sent by writeRequestBody below, along with any Trailers,
// again in form HEADERS{1}, CONTINUATION{0,})
- hdrs := cc.encodeHeaders(req, cs.requestedGzip, trailers)
+ hdrs := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen)
cc.wmu.Lock()
endStream := !hasBody && !hasTrailers
werr := cc.writeHeaders(cs.ID, endStream, hdrs)
@@ -605,7 +627,7 @@
if hasBody {
bodyCopyErrc = make(chan error, 1)
go func() {
- bodyCopyErrc <- cs.writeRequestBody(req.Body)
+ bodyCopyErrc <- cs.writeRequestBody(body, req.Body)
}()
}
@@ -705,7 +727,7 @@
// It doesn't escape to callers.
var errAbortReqBodyWrite = errors.New("http2: aborting request body write")
-func (cs *clientStream) writeRequestBody(body io.ReadCloser) (err error) {
+func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) {
cc := cs.cc
sentEnd := false // whether we sent the final DATA frame w/ END_STREAM
buf := cc.frameScratchBuffer()
@@ -716,7 +738,7 @@
// Request.Body is closed by the Transport,
// and in multiple cases: server replies <=299 and >299
// while still writing request body
- cerr := body.Close()
+ cerr := bodyCloser.Close()
if err == nil {
err = cerr
}
@@ -829,7 +851,7 @@
func (e *badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) }
// requires cc.mu be held.
-func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string) []byte {
+func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) []byte {
cc.hbuf.Reset()
host := req.Host
@@ -855,7 +877,7 @@
var didUA bool
for k, vv := range req.Header {
lowKey := strings.ToLower(k)
- if lowKey == "host" {
+ if lowKey == "host" || lowKey == "content-length" {
continue
}
if lowKey == "user-agent" {
@@ -876,6 +898,9 @@
cc.writeHeader(lowKey, v)
}
}
+ if contentLength >= 0 {
+ cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
+ }
if addGzipHeader {
cc.writeHeader("accept-encoding", "gzip")
}
@@ -1605,3 +1630,7 @@
func (gz *gzipReader) Close() error {
return gz.body.Close()
}
+
+type errorReader struct{ err error }
+
+func (r errorReader) Read(p []byte) (int, error) { return 0, r.err }
diff --git a/http2/transport_test.go b/http2/transport_test.go
index cd6f3df..dab483a 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -332,14 +332,19 @@
}
func TestTransportBody(t *testing.T) {
- gotc := make(chan interface{}, 1)
+ type reqInfo struct {
+ req *http.Request
+ slurp []byte
+ err error
+ }
+ gotc := make(chan reqInfo, 1)
st := newServerTester(t,
func(w http.ResponseWriter, r *http.Request) {
slurp, err := ioutil.ReadAll(r.Body)
if err != nil {
- gotc <- err
+ gotc <- reqInfo{err: err}
} else {
- gotc <- string(slurp)
+ gotc <- reqInfo{req: r, slurp: slurp}
}
},
optOnlyServer,
@@ -364,13 +369,21 @@
t.Fatalf("#%d: %v", i, err)
}
defer res.Body.Close()
- got := <-gotc
- if err, ok := got.(error); ok {
- t.Fatalf("#%d: %v", i, err)
- } else if got.(string) != tt.body {
- got := got.(string)
+ ri := <-gotc
+ if ri.err != nil {
+ t.Errorf("%#d: read error: %v", i, ri.err)
+ continue
+ }
+ if got := string(ri.slurp); got != tt.body {
t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body))
}
+ wantLen := int64(len(tt.body))
+ if tt.noContentLen && tt.body != "" {
+ wantLen = -1
+ }
+ if ri.req.ContentLength != wantLen {
+ t.Errorf("#%d. handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen)
+ }
}
}
@@ -735,6 +748,7 @@
if err != nil {
log.Fatal(err)
}
+ req.ContentLength = -1
res, err := c.Do(req)
if err != nil {
log.Fatal(err)