golang.org/x/oauth2/jwt: Set kid to KeyID of private key
Set the KeyID hint in the token header. This allows remote servers to
identify the key used to sign the message.
Fixes #18307
Change-Id: Ib95398079833aad6b390650b465d7b09b5f53fda
Reviewed-on: https://go-review.googlesource.com/34320
Reviewed-by: Jaana Burcu Dogan <jbd@google.com>
diff --git a/jwt/jwt.go b/jwt/jwt.go
index f4b9523..e016db4 100644
--- a/jwt/jwt.go
+++ b/jwt/jwt.go
@@ -105,7 +105,9 @@
if t := js.conf.Expires; t > 0 {
claimSet.Exp = time.Now().Add(t).Unix()
}
- payload, err := jws.Encode(defaultHeader, claimSet, pk)
+ h := *defaultHeader
+ h.KeyID = js.conf.PrivateKeyID
+ payload, err := jws.Encode(&h, claimSet, pk)
if err != nil {
return nil, err
}
diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go
index a490af5..9f82c71 100644
--- a/jwt/jwt_test.go
+++ b/jwt/jwt_test.go
@@ -6,9 +6,14 @@
import (
"context"
+ "encoding/base64"
+ "encoding/json"
"net/http"
"net/http/httptest"
+ "strings"
"testing"
+
+ "golang.org/x/oauth2/jws"
)
var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
@@ -131,3 +136,55 @@
}
}
}
+
+func TestJWTFetch_Assertion(t *testing.T) {
+ var assertion string
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ r.ParseForm()
+ assertion = r.Form.Get("assertion")
+
+ w.Header().Set("Content-Type", "application/json")
+ w.Write([]byte(`{
+ "access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
+ "scope": "user",
+ "token_type": "bearer",
+ "expires_in": 3600
+ }`))
+ }))
+ defer ts.Close()
+
+ conf := &Config{
+ Email: "aaa@xxx.com",
+ PrivateKey: dummyPrivateKey,
+ PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
+ TokenURL: ts.URL,
+ }
+
+ _, err := conf.TokenSource(context.Background()).Token()
+ if err != nil {
+ t.Fatalf("Failed to fetch token: %v", err)
+ }
+
+ parts := strings.Split(assertion, ".")
+ if len(parts) != 3 {
+ t.Fatalf("assertion = %q; want 3 parts", assertion)
+ }
+ gotjson, err := base64.RawURLEncoding.DecodeString(parts[0])
+ if err != nil {
+ t.Fatalf("invalid token header; err = %v", err)
+ }
+
+ got := jws.Header{}
+ if err := json.Unmarshal(gotjson, &got); err != nil {
+ t.Errorf("failed to unmarshal json token header = %q; err = %v", gotjson, err)
+ }
+
+ want := jws.Header{
+ Algorithm: "RS256",
+ Typ: "JWT",
+ KeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
+ }
+ if got != want {
+ t.Errorf("access token header = %q; want %q", got, want)
+ }
+}