http2: Discard data reads on HEAD requests

If a server returns a DATA frame while processing a HEAD request, the
client will discard the data.

Fixes golang/go#22376

Change-Id: Ief9c17ddfe51cc17f7f6326c87330ac9d8b9d3ff
Reviewed-on: https://go-review.googlesource.com/72551
Run-TryBot: Tom Bergan <tombergan@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Tom Bergan <tombergan@google.com>
Reviewed-on: https://go-review.googlesource.com/88655
Run-TryBot: Andrew Bonventre <andybons@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 850d7ae..b6a5c7d 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -1703,6 +1703,14 @@
 		return nil
 	}
 	if f.Length > 0 {
+		if cs.req.Method == "HEAD" && len(data) > 0 {
+			cc.logf("protocol error: received DATA on a HEAD request")
+			rl.endStreamError(cs, StreamError{
+				StreamID: f.StreamID,
+				Code:     ErrCodeProtocol,
+			})
+			return nil
+		}
 		// Check connection-level flow control.
 		cc.mu.Lock()
 		if cs.inflow.available() >= int32(f.Length) {
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 15dfa07..eff2103 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -2026,6 +2026,60 @@
 	ct.run()
 }
 
+func TestTransportReadHeadResponseWithBody(t *testing.T) { // HEAD round trip must yield an empty body even though the server sends a DATA frame
+	response := "redirecting to /elsewhere" // payload the fake server writes in a DATA frame
+	ct := newClientTester(t)
+	clientDone := make(chan struct{}) // closed when the client func returns; the server func blocks on it
+	ct.client = func() error {
+		defer close(clientDone)
+		req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) // NewRequest's error is discarded here
+		res, err := ct.tr.RoundTrip(req)
+		if err != nil {
+			return err
+		}
+		if res.ContentLength != int64(len(response)) { // server advertises content-length = len(response) below
+			return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response))
+		}
+		slurp, err := ioutil.ReadAll(res.Body)
+		if err != nil {
+			return fmt.Errorf("ReadAll: %v", err)
+		}
+		if len(slurp) > 0 { // any bytes here mean the DATA payload reached the caller
+			return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp)
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet() // NOTE(review): presumably performs the initial connection handshake — confirm in clientTester
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				t.Logf("ReadFrame: %v", err) // read errors are only logged; the server func still returns nil
+				return nil
+			}
+			hf, ok := f.(*HeadersFrame)
+			if !ok {
+				continue // ignore every frame until a HEADERS frame arrives
+			}
+			var buf bytes.Buffer
+			enc := hpack.NewEncoder(&buf)
+			enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+			enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))})
+			ct.fr.WriteHeaders(HeadersFrameParam{
+				StreamID:      hf.StreamID,
+				EndHeaders:    true,
+				EndStream:     false, // headers do not end the stream, so the DATA frame below can follow
+				BlockFragment: buf.Bytes(),
+			})
+			ct.fr.WriteData(hf.StreamID, true, []byte(response)) // DATA reply to the HEAD request (true: presumably endStream — confirm); write errors are not checked
+
+			<-clientDone // keep the server func alive until the client func has returned
+			return nil
+		}
+	}
+	ct.run()
+}
+
 type neverEnding byte
 
 func (b neverEnding) Read(p []byte) (int, error) {