http2: another Transport body-writing bug fix, and more tests
Change-Id: I700832c477a38ab11da39a382186bdc7d3d3186e
Reviewed-on: https://go-review.googlesource.com/16445
Reviewed-by: Blake Mizerany <blake.mizerany@gmail.com>
diff --git a/http2/transport.go b/http2/transport.go
index 0db6a4e..6f0d125 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -542,14 +542,15 @@
func (cs *clientStream) writeRequestBody(body io.Reader, gotResHeaders <-chan struct{}) error {
cc := cs.cc
- done := false
+ sentEnd := false // whether we sent the final DATA frame w/ END_STREAM
buf := cc.frameScratchBuffer()
defer cc.putFrameScratchBuffer(buf)
- for !done {
+ for !sentEnd {
+ var sawEOF bool
n, err := io.ReadFull(body, buf)
if err == io.ErrUnexpectedEOF {
- done = true
+ sawEOF = true
err = nil
} else if err == io.EOF {
break
@@ -572,8 +573,10 @@
case <-cs.peerReset:
err = cs.resetErr
default:
- err = cc.fr.WriteData(cs.ID, done, toWrite[:allowed])
+ data := toWrite[:allowed]
toWrite = toWrite[allowed:]
+ sentEnd = sawEOF && len(toWrite) == 0
+ err = cc.fr.WriteData(cs.ID, sentEnd, data)
}
cc.wmu.Unlock()
}
@@ -585,7 +588,7 @@
var err error
cc.wmu.Lock()
- if !done {
+ if !sentEnd {
err = cc.fr.WriteData(cs.ID, true, nil)
}
if ferr := cc.bw.Flush(); ferr != nil && err == nil {
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 790a061..6b23563 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -7,8 +7,10 @@
import (
"crypto/tls"
"flag"
+ "fmt"
"io"
"io/ioutil"
+ "math/rand"
"net"
"net/http"
"net/url"
@@ -206,6 +208,15 @@
}
}
+func randString(n int) string {
+ rnd := rand.New(rand.NewSource(int64(n)))
+ b := make([]byte, n)
+ for i := range b {
+ b[i] = byte(rnd.Intn(256))
+ }
+ return string(b)
+}
+
var bodyTests = []struct {
body string
noContentLen bool
@@ -216,6 +227,15 @@
{body: "", noContentLen: true},
{body: strings.Repeat("a", 1<<20), noContentLen: true},
{body: strings.Repeat("a", 1<<20)},
+ {body: randString(16<<10 - 1)},
+ {body: randString(16 << 10)},
+ {body: randString(16<<10 + 1)},
+ {body: randString(512<<10 - 1)},
+ {body: randString(512 << 10)},
+ {body: randString(512<<10 + 1)},
+ {body: randString(1<<20 - 1)},
+ {body: randString(1 << 20)},
+ {body: randString(1<<20 + 2)},
}
func TestTransportBody(t *testing.T) {
@@ -227,7 +247,6 @@
gotc <- err
} else {
gotc <- string(slurp)
-
}
},
optOnlyServer,
@@ -256,11 +275,20 @@
if err, ok := got.(error); ok {
t.Fatalf("#%d: %v", i, err)
} else if got.(string) != tt.body {
- t.Errorf("#%d: Read body = %q; want %q", i, got, tt.body)
+ got := got.(string)
+ t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body))
}
}
}
+func shortString(v string) string {
+ const maxLen = 100
+ if len(v) <= maxLen {
+ return v
+ }
+ return fmt.Sprintf("%v[...%d bytes omitted...]%v", v[:maxLen/2], len(v)-maxLen, v[len(v)-maxLen/2:])
+}
+
func TestTransportDialTLS(t *testing.T) {
var mu sync.Mutex // guards following
var gotReq, didDial bool