oauth2: close request body if errors occur before base RoundTripper is invoked
Fixes golang/oauth#269
Change-Id: I25eb3273a0868a999a2e98961ae5e4040e44ad7a
Reviewed-on: https://go-review.googlesource.com/114956
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/transport.go b/transport.go
index 92ac7e2..c55bfa0 100644
--- a/transport.go
+++ b/transport.go
@@ -34,6 +34,15 @@
// access token. If no token exists or token is expired,
// tries to refresh/fetch a new token.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
+ reqBodyClosed := false
+ if req.Body != nil {
+ defer func() {
+ if !reqBodyClosed {
+ req.Body.Close()
+ }
+ }()
+ }
+
if t.Source == nil {
return nil, errors.New("oauth2: Transport's Source is nil")
}
@@ -46,6 +55,10 @@
token.SetAuthHeader(req2)
t.setModReq(req, req2)
res, err := t.base().RoundTrip(req2)
+
+ // req.Body is assumed to have been closed by the base RoundTripper.
+ reqBodyClosed = true
+
if err != nil {
t.setModReq(req, nil)
return nil, err
diff --git a/transport_test.go b/transport_test.go
index d6e8087..faa87d5 100644
--- a/transport_test.go
+++ b/transport_test.go
@@ -1,6 +1,8 @@
package oauth2
import (
+ "errors"
+ "io"
"net/http"
"net/http/httptest"
"testing"
@@ -27,6 +29,64 @@
}
}
+type readCloseCounter struct {
+ CloseCount int
+ ReadErr error
+}
+
+func (r *readCloseCounter) Read(b []byte) (int, error) {
+ return 0, r.ReadErr
+}
+
+func (r *readCloseCounter) Close() error {
+ r.CloseCount++
+ return nil
+}
+
+func TestTransportCloseRequestBody(t *testing.T) {
+ tr := &Transport{}
+ server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
+ defer server.Close()
+ client := &http.Client{Transport: tr}
+ body := &readCloseCounter{
+ ReadErr: errors.New("readCloseCounter.Read not implemented"),
+ }
+ resp, err := client.Post(server.URL, "application/json", body)
+ if err == nil {
+ t.Errorf("got no errors, want an error with nil token source")
+ }
+ if resp != nil {
+ t.Errorf("Response = %v; want nil", resp)
+ }
+ if expected := 1; body.CloseCount != expected {
+ t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
+ }
+}
+
+func TestTransportCloseRequestBodySuccess(t *testing.T) {
+ tr := &Transport{
+ Source: StaticTokenSource(&Token{
+ AccessToken: "abc",
+ }),
+ }
+ server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
+ defer server.Close()
+ client := &http.Client{Transport: tr}
+ body := &readCloseCounter{
+ ReadErr: io.EOF,
+ }
+ resp, err := client.Post(server.URL, "application/json", body)
+ if err != nil {
+ t.Errorf("got error %v; expected none", err)
+ }
+ if resp == nil {
+ t.Errorf("Response is nil; expected non-nil")
+ }
+ if expected := 1; body.CloseCount != expected {
+ t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
+ }
+}
+
func TestTransportTokenSource(t *testing.T) {
ts := &tokenSource{
token: &Token{