oauth2: close request body if errors occur before base RoundTripper is invoked

Fixes golang/oauth#269

Change-Id: I25eb3273a0868a999a2e98961ae5e4040e44ad7a
Reviewed-on: https://go-review.googlesource.com/114956
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/transport.go b/transport.go
index 92ac7e2..c55bfa0 100644
--- a/transport.go
+++ b/transport.go
@@ -34,6 +34,15 @@
 // access token. If no token exists or token is expired,
 // tries to refresh/fetch a new token.
 func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
+	reqBodyClosed := false
+	if req.Body != nil {
+		defer func() {
+			if !reqBodyClosed {
+				req.Body.Close()
+			}
+		}()
+	}
+
 	if t.Source == nil {
 		return nil, errors.New("oauth2: Transport's Source is nil")
 	}
@@ -46,6 +55,10 @@
 	token.SetAuthHeader(req2)
 	t.setModReq(req, req2)
 	res, err := t.base().RoundTrip(req2)
+
+	// req.Body is assumed to have been closed by the base RoundTripper.
+	reqBodyClosed = true
+
 	if err != nil {
 		t.setModReq(req, nil)
 		return nil, err
diff --git a/transport_test.go b/transport_test.go
index d6e8087..faa87d5 100644
--- a/transport_test.go
+++ b/transport_test.go
@@ -1,6 +1,8 @@
 package oauth2
 
 import (
+	"errors"
+	"io"
 	"net/http"
 	"net/http/httptest"
 	"testing"
@@ -27,6 +29,64 @@
 	}
 }
 
+type readCloseCounter struct {
+	CloseCount int
+	ReadErr    error
+}
+
+func (r *readCloseCounter) Read(b []byte) (int, error) {
+	return 0, r.ReadErr
+}
+
+func (r *readCloseCounter) Close() error {
+	r.CloseCount++
+	return nil
+}
+
+func TestTransportCloseRequestBody(t *testing.T) {
+	tr := &Transport{}
+	server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
+	defer server.Close()
+	client := &http.Client{Transport: tr}
+	body := &readCloseCounter{
+		ReadErr: errors.New("readCloseCounter.Read not implemented"),
+	}
+	resp, err := client.Post(server.URL, "application/json", body)
+	if err == nil {
+		t.Errorf("got no errors, want an error with nil token source")
+	}
+	if resp != nil {
+		t.Errorf("Response = %v; want nil", resp)
+	}
+	if expected := 1; body.CloseCount != expected {
+		t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
+	}
+}
+
+func TestTransportCloseRequestBodySuccess(t *testing.T) {
+	tr := &Transport{
+		Source: StaticTokenSource(&Token{
+			AccessToken: "abc",
+		}),
+	}
+	server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
+	defer server.Close()
+	client := &http.Client{Transport: tr}
+	body := &readCloseCounter{
+		ReadErr: io.EOF,
+	}
+	resp, err := client.Post(server.URL, "application/json", body)
+	if err != nil {
+		t.Errorf("got error %v; expected none", err)
+	}
+	if resp == nil {
+		t.Errorf("Response is nil; expected non-nil")
+	}
+	if expected := 1; body.CloseCount != expected {
+		t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
+	}
+}
+
 func TestTransportTokenSource(t *testing.T) {
 	ts := &tokenSource{
 		token: &Token{