golang.org/x/oauth2/jwt: Set kid to KeyID of private key

Set the KeyID hint in the token header. This allows remote servers to
identify the key used to sign the message.

Fixes #18307

Change-Id: Ib95398079833aad6b390650b465d7b09b5f53fda
Reviewed-on: https://go-review.googlesource.com/34320
Reviewed-by: Jaana Burcu Dogan <jbd@google.com>
diff --git a/jwt/jwt.go b/jwt/jwt.go
index f4b9523..e016db4 100644
--- a/jwt/jwt.go
+++ b/jwt/jwt.go
@@ -105,7 +105,9 @@
 	if t := js.conf.Expires; t > 0 {
 		claimSet.Exp = time.Now().Add(t).Unix()
 	}
-	payload, err := jws.Encode(defaultHeader, claimSet, pk)
+	h := *defaultHeader
+	h.KeyID = js.conf.PrivateKeyID
+	payload, err := jws.Encode(&h, claimSet, pk)
 	if err != nil {
 		return nil, err
 	}
diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go
index a490af5..9f82c71 100644
--- a/jwt/jwt_test.go
+++ b/jwt/jwt_test.go
@@ -6,9 +6,14 @@
 
 import (
 	"context"
+	"encoding/base64"
+	"encoding/json"
 	"net/http"
 	"net/http/httptest"
+	"strings"
 	"testing"
+
+	"golang.org/x/oauth2/jws"
 )
 
 var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
@@ -131,3 +136,55 @@
 		}
 	}
 }
+
+func TestJWTFetch_Assertion(t *testing.T) {
+	var assertion string
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		r.ParseForm()
+		assertion = r.Form.Get("assertion")
+
+		w.Header().Set("Content-Type", "application/json")
+		w.Write([]byte(`{
+			"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
+			"scope": "user",
+			"token_type": "bearer",
+			"expires_in": 3600
+		}`))
+	}))
+	defer ts.Close()
+
+	conf := &Config{
+		Email:        "aaa@xxx.com",
+		PrivateKey:   dummyPrivateKey,
+		PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
+		TokenURL:     ts.URL,
+	}
+
+	_, err := conf.TokenSource(context.Background()).Token()
+	if err != nil {
+		t.Fatalf("Failed to fetch token: %v", err)
+	}
+
+	parts := strings.Split(assertion, ".")
+	if len(parts) != 3 {
+		t.Fatalf("assertion = %q; want 3 parts", assertion)
+	}
+	gotjson, err := base64.RawURLEncoding.DecodeString(parts[0])
+	if err != nil {
+		t.Fatalf("invalid token header; err = %v", err)
+	}
+
+	got := jws.Header{}
+	if err := json.Unmarshal(gotjson, &got); err != nil {
+		t.Errorf("failed to unmarshal json token header = %q; err = %v", gotjson, err)
+	}
+
+	want := jws.Header{
+		Algorithm: "RS256",
+		Typ:       "JWT",
+		KeyID:     "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
+	}
+	if got != want {
+		t.Errorf("access token header = %q; want %q", got, want)
+	}
+}