jwt: use RetrieveError for invalid status code errors

CL 84156 added oauth2.RetrieveError to the oauth2 and clientcredentials
packages, but missed using it in jwt.

Change-Id: I06d77cd18667526bfc869ebc1b5cc2bcbabc03a6
Reviewed-on: https://go-review.googlesource.com/85457
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/jwt/jwt.go b/jwt/jwt.go
index e016db4..e08f315 100644
--- a/jwt/jwt.go
+++ b/jwt/jwt.go
@@ -124,7 +124,10 @@
 		return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
 	}
 	if c := resp.StatusCode; c < 200 || c > 299 {
-		return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body)
+		return nil, &oauth2.RetrieveError{
+			Response: resp,
+			Body:     body,
+		}
 	}
 	// tokenRes is the JSON response body.
 	var tokenRes struct {
diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go
index 9f82c71..1fbb9aa 100644
--- a/jwt/jwt_test.go
+++ b/jwt/jwt_test.go
@@ -8,11 +8,13 @@
 	"context"
 	"encoding/base64"
 	"encoding/json"
+	"fmt"
 	"net/http"
 	"net/http/httptest"
 	"strings"
 	"testing"
 
+	"golang.org/x/oauth2"
 	"golang.org/x/oauth2/jws"
 )
 
@@ -188,3 +190,32 @@
 		t.Errorf("access token header = %q; want %q", got, want)
 	}
 }
+
+func TestTokenRetrieveError(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-type", "application/json")
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte(`{"error": "invalid_grant"}`))
+	}))
+	defer ts.Close()
+
+	conf := &Config{
+		Email:      "aaa@xxx.com",
+		PrivateKey: dummyPrivateKey,
+		TokenURL:   ts.URL,
+	}
+
+	_, err := conf.TokenSource(context.Background()).Token()
+	if err == nil {
+		t.Fatalf("got no error, expected one")
+	}
+	_, ok := err.(*oauth2.RetrieveError)
+	if !ok {
+		t.Fatalf("got %T error, expected *RetrieveError", err)
+	}
+	// Test error string for backwards compatibility
+	expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`)
+	if errStr := err.Error(); errStr != expected {
+		t.Fatalf("got %#v, expected %#v", errStr, expected)
+	}
+}