google: support scopes for JWT access token

Change-Id: I11acd87a56cd003fdb68a5a687e37df450c400d1
GitHub-Last-Rev: efb2e8a08a8db0dc654298b90b814b3b7cb4d83d
GitHub-Pull-Request: golang/oauth2#504
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/327929
Trust: Shin Fan <shinfan@google.com>
Trust: Cody Oss <codyoss@google.com>
Run-TryBot: Shin Fan <shinfan@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
diff --git a/google/jwt.go b/google/jwt.go
index b0fdb3a..67d97b9 100644
--- a/google/jwt.go
+++ b/google/jwt.go
@@ -7,6 +7,7 @@
 import (
 	"crypto/rsa"
 	"fmt"
+	"strings"
 	"time"
 
 	"golang.org/x/oauth2"
@@ -24,6 +25,28 @@
 // optimization supported by a few Google services.
 // Unless you know otherwise, you should use JWTConfigFromJSON instead.
 func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.TokenSource, error) {
+	return newJWTSource(jsonKey, audience, nil)
+}
+
+// JWTAccessTokenSourceWithScope uses a Google Developers service account JSON
+// key file to read the credentials that authorize and authenticate the
+// requests, and returns a TokenSource that does not use any OAuth2 flow but
+// instead creates a JWT and sends that as the access token.
+// The scope is typically a list of URLs that specifies the scope of the
+// credentials.
+//
+// Note that this is not a standard OAuth flow, but rather an
+// optimization supported by a few Google services.
+// Unless you know otherwise, you should use JWTConfigFromJSON instead.
+func JWTAccessTokenSourceWithScope(jsonKey []byte, scope ...string) (oauth2.TokenSource, error) {
+	return newJWTSource(jsonKey, "", scope)
+}
+
+func newJWTSource(jsonKey []byte, audience string, scopes []string) (oauth2.TokenSource, error) {
+	if len(scopes) == 0 && audience == "" {
+		return nil, fmt.Errorf("google: missing scope/audience for JWT access token")
+	}
+
 	cfg, err := JWTConfigFromJSON(jsonKey)
 	if err != nil {
 		return nil, fmt.Errorf("google: could not parse JSON key: %v", err)
@@ -35,6 +58,7 @@
 	ts := &jwtAccessTokenSource{
 		email:    cfg.Email,
 		audience: audience,
+		scopes:   scopes,
 		pk:       pk,
 		pkID:     cfg.PrivateKeyID,
 	}
@@ -47,6 +71,7 @@
 
 type jwtAccessTokenSource struct {
 	email, audience string
+	scopes          []string
 	pk              *rsa.PrivateKey
 	pkID            string
 }
@@ -54,12 +79,14 @@
 func (ts *jwtAccessTokenSource) Token() (*oauth2.Token, error) {
 	iat := time.Now()
 	exp := iat.Add(time.Hour)
+	scope := strings.Join(ts.scopes, " ")
 	cs := &jws.ClaimSet{
-		Iss: ts.email,
-		Sub: ts.email,
-		Aud: ts.audience,
-		Iat: iat.Unix(),
-		Exp: exp.Unix(),
+		Iss:   ts.email,
+		Sub:   ts.email,
+		Aud:   ts.audience,
+		Scope: scope,
+		Iat:   iat.Unix(),
+		Exp:   exp.Unix(),
 	}
 	hdr := &jws.Header{
 		Algorithm: "RS256",
diff --git a/google/jwt_test.go b/google/jwt_test.go
index f844436..043f445 100644
--- a/google/jwt_test.go
+++ b/google/jwt_test.go
@@ -13,29 +13,21 @@
 	"encoding/json"
 	"encoding/pem"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
 	"golang.org/x/oauth2/jws"
 )
 
-func TestJWTAccessTokenSourceFromJSON(t *testing.T) {
-	// Generate a key we can use in the test data.
-	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
-	if err != nil {
-		t.Fatal(err)
-	}
+var (
+	privateKey *rsa.PrivateKey
+	jsonKey    []byte
+	once       sync.Once
+)
 
-	// Encode the key and substitute into our example JSON.
-	enc := pem.EncodeToMemory(&pem.Block{
-		Type:  "PRIVATE KEY",
-		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
-	})
-	enc, err = json.Marshal(string(enc))
-	if err != nil {
-		t.Fatalf("json.Marshal: %v", err)
-	}
-	jsonKey := bytes.Replace(jwtJSONKey, []byte(`"super secret key"`), enc, 1)
+func TestJWTAccessTokenSourceFromJSON(t *testing.T) {
+	setupDummyKey(t)
 
 	ts, err := JWTAccessTokenSourceFromJSON(jsonKey, "audience")
 	if err != nil {
@@ -89,3 +81,80 @@
 		t.Errorf("Header KeyID = %q, want %q", got, want)
 	}
 }
+
+func TestJWTAccessTokenSourceWithScope(t *testing.T) {
+	setupDummyKey(t)
+
+	ts, err := JWTAccessTokenSourceWithScope(jsonKey, "scope1", "scope2")
+	if err != nil {
+		t.Fatalf("JWTAccessTokenSourceWithScope: %v\nJSON: %s", err, string(jsonKey))
+	}
+
+	tok, err := ts.Token()
+	if err != nil {
+		t.Fatalf("Token: %v", err)
+	}
+
+	if got, want := tok.TokenType, "Bearer"; got != want {
+		t.Errorf("TokenType = %q, want %q", got, want)
+	}
+	if got := tok.Expiry; tok.Expiry.Before(time.Now()) {
+		t.Errorf("Expiry = %v, should not be expired", got)
+	}
+
+	err = jws.Verify(tok.AccessToken, &privateKey.PublicKey)
+	if err != nil {
+		t.Errorf("jws.Verify on AccessToken: %v", err)
+	}
+
+	claim, err := jws.Decode(tok.AccessToken)
+	if err != nil {
+		t.Fatalf("jws.Decode on AccessToken: %v", err)
+	}
+
+	if got, want := claim.Iss, "gopher@developer.gserviceaccount.com"; got != want {
+		t.Errorf("Iss = %q, want %q", got, want)
+	}
+	if got, want := claim.Sub, "gopher@developer.gserviceaccount.com"; got != want {
+		t.Errorf("Sub = %q, want %q", got, want)
+	}
+	if got, want := claim.Scope, "scope1 scope2"; got != want {
+		t.Errorf("Aud = %q, want %q", got, want)
+	}
+
+	// Finally, check the header private key.
+	parts := strings.Split(tok.AccessToken, ".")
+	hdrJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
+	if err != nil {
+		t.Fatalf("base64 DecodeString: %v\nString: %q", err, parts[0])
+	}
+	var hdr jws.Header
+	if err := json.Unmarshal([]byte(hdrJSON), &hdr); err != nil {
+		t.Fatalf("json.Unmarshal: %v (%q)", err, hdrJSON)
+	}
+
+	if got, want := hdr.KeyID, "268f54e43a1af97cfc71731688434f45aca15c8b"; got != want {
+		t.Errorf("Header KeyID = %q, want %q", got, want)
+	}
+}
+
+func setupDummyKey(t *testing.T) {
+	once.Do(func() {
+		// Generate a key we can use in the test data.
+		pk, err := rsa.GenerateKey(rand.Reader, 2048)
+		if err != nil {
+			t.Fatal(err)
+		}
+		privateKey = pk
+		// Encode the key and substitute into our example JSON.
+		enc := pem.EncodeToMemory(&pem.Block{
+			Type:  "PRIVATE KEY",
+			Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
+		})
+		enc, err = json.Marshal(string(enc))
+		if err != nil {
+			t.Fatalf("json.Marshal: %v", err)
+		}
+		jsonKey = bytes.Replace(jwtJSONKey, []byte(`"super secret key"`), enc, 1)
+	})
+}