oauth2, oauth2/google: add, use ReuseTokenSource

Token caching is now done whenever you make a Client, and
ReuseTokenSource is exported from the oauth2 package and used by the
Google TokenSources (Compute and App Engine).

Token.Expired is now Token.Valid, and works on nil receivers.

Some other wording cleanups in the process.

All tests pass. App Engine should pass, but is untested.

Change-Id: Ibe1d2599ac3ccfe9b399b1672f74bb24cfc8d311
Reviewed-on: https://go-review.googlesource.com/2195
Reviewed-by: Burcu Dogan <jbd@google.com>
diff --git a/example_test.go b/example_test.go
index cb4726f..d26a3dc 100644
--- a/example_test.go
+++ b/example_test.go
@@ -50,7 +50,6 @@
 }
 
 func ExampleJWTConfig() {
-	var initialToken *oauth2.Token // nil means no initial token
 	conf := &oauth2.JWTConfig{
 		Email: "xxx@developer.com",
 		// The contents of your RSA private key or your PEM file
@@ -67,6 +66,6 @@
 	}
 	// Initiate an http.Client, the following GET request will be
 	// authorized and authenticated on the behalf of user@example.com.
-	client := conf.Client(oauth2.NoContext, initialToken)
+	client := conf.Client(oauth2.NoContext)
 	client.Get("...")
 }
diff --git a/google/example_test.go b/google/example_test.go
index 6d21d5e..a59cfe9 100644
--- a/google/example_test.go
+++ b/google/example_test.go
@@ -69,7 +69,7 @@
 	// Initiate an http.Client. The following GET request will be
 	// authorized and authenticated on the behalf of
 	// your service account.
-	client := conf.Client(oauth2.NoContext, nil)
+	client := conf.Client(oauth2.NoContext)
 	client.Get("...")
 }
 
@@ -101,7 +101,7 @@
 	}
 	// Initiate an http.Client, the following GET request will be
 	// authorized and authenticated on the behalf of user@example.com.
-	client := conf.Client(oauth2.NoContext, nil)
+	client := conf.Client(oauth2.NoContext)
 	client.Get("...")
 }
 
diff --git a/google/google.go b/google/google.go
index 4890776..eb6c92a 100644
--- a/google/google.go
+++ b/google/google.go
@@ -15,7 +15,6 @@
 
 import (
 	"encoding/json"
-
 	"fmt"
 	"net"
 	"net/http"
@@ -24,6 +23,9 @@
 	"golang.org/x/oauth2"
 )
 
+// TODO(bradfitz,jbd): import "google.golang.org/cloud/compute/metadata" instead of
+// the metaClient and metadata.google.internal stuff below.
+
 // Endpoint is Google's OAuth 2.0 endpoint.
 var Endpoint = oauth2.Endpoint{
 	AuthURL:  "https://accounts.google.com/o/oauth2/auth",
@@ -66,7 +68,7 @@
 // Further information about retrieving access tokens from the GCE metadata
 // server can be found at https://cloud.google.com/compute/docs/authentication.
 func ComputeTokenSource(account string) oauth2.TokenSource {
-	return &computeSource{account: account}
+	return oauth2.ReuseTokenSource(nil, &computeSource{account: account})
 }
 
 type computeSource struct {
diff --git a/google/source_appengine.go b/google/source_appengine.go
index 9b8aa97..d0eb3da 100644
--- a/google/source_appengine.go
+++ b/google/source_appengine.go
@@ -29,13 +29,16 @@
 }
 
 type appEngineTokenSource struct {
-	ctx    oauth2.Context
-	scopes []string
-	key    string // guarded by package-level mutex, aeTokensMu
+	ctx oauth2.Context
 
-	// fetcherFunc makes the actual RPC to fetch a new access token with an expiry time.
-	// Provider of this function is responsible to assert that the given context is valid.
-	fetcherFunc func(ctx oauth2.Context, scope ...string) (string, time.Time, error)
+	// fetcherFunc makes the actual RPC to fetch a new access
+	// token with an expiry time.  Provider of this function is
+	// responsible to assert that the given context is valid.
+	fetcherFunc func(ctx oauth2.Context, scope ...string) (accessToken string, expiry time.Time, err error)
+
+	// scopes and key are guarded by the package-level mutex aeTokensMu
+	scopes []string
+	key    string
 }
 
 func (ts *appEngineTokenSource) Token() (*oauth2.Token, error) {
@@ -53,7 +56,7 @@
 
 	tok.mu.Lock()
 	defer tok.mu.Unlock()
-	if tok.t != nil && !tok.t.Expired() {
+	if tok.t.Valid() {
 		return tok.t, nil
 	}
 	access, exp, err := ts.fetcherFunc(ts.ctx, ts.scopes...)
diff --git a/jwt.go b/jwt.go
index d7ec0d3..9507671 100644
--- a/jwt.go
+++ b/jwt.go
@@ -52,33 +52,21 @@
 
 // TokenSource returns a JWT TokenSource using the configuration
 // in c and the HTTP client from the provided context.
-//
-// The returned TokenSource only does JWT requests when necessary but
-// otherwise returns the same token repeatedly until it expires.
-//
-// The provided initialToken may be nil, in which case the first
-// call to TokenSource will do a new JWT request.
-func (c *JWTConfig) TokenSource(ctx Context, initialToken *Token) TokenSource {
-	return &newWhenNeededSource{
-		t:   initialToken,
-		new: jwtSource{ctx, c},
-	}
+func (c *JWTConfig) TokenSource(ctx Context) TokenSource {
+	return ReuseTokenSource(nil, jwtSource{ctx, c})
 }
 
 // Client returns an HTTP client wrapping the context's
 // HTTP transport and adding Authorization headers with tokens
 // obtained from c.
 //
-// The provided initialToken may be nil, in which case the first
-// call to TokenSource will do a new JWT request.
-//
 // The returned client and its Transport should not be modified.
-func (c *JWTConfig) Client(ctx Context, initialToken *Token) *http.Client {
-	return NewClient(ctx, c.TokenSource(ctx, initialToken))
+func (c *JWTConfig) Client(ctx Context) *http.Client {
+	return NewClient(ctx, c.TokenSource(ctx))
 }
 
 // jwtSource is a source that always does a signed JWT request for a token.
-// It should typically be wrapped with a newWhenNeededSource.
+// It should typically be wrapped with a reuseTokenSource.
 type jwtSource struct {
 	ctx  Context
 	conf *JWTConfig
diff --git a/jwt_test.go b/jwt_test.go
index 8c2e62e..e9a732c 100644
--- a/jwt_test.go
+++ b/jwt_test.go
@@ -55,12 +55,12 @@
 		PrivateKey: dummyPrivateKey,
 		TokenURL:   ts.URL,
 	}
-	tok, err := conf.TokenSource(NoContext, nil).Token()
+	tok, err := conf.TokenSource(NoContext).Token()
 	if err != nil {
 		t.Fatal(err)
 	}
-	if tok.Expired() {
-		t.Errorf("Token shouldn't be expired")
+	if !tok.Valid() {
+		t.Errorf("Token invalid")
 	}
 	if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
 		t.Errorf("Unexpected access token, %#v", tok.AccessToken)
@@ -89,19 +89,25 @@
 		PrivateKey: dummyPrivateKey,
 		TokenURL:   ts.URL,
 	}
-	tok, err := conf.TokenSource(NoContext, nil).Token()
+	tok, err := conf.TokenSource(NoContext).Token()
 	if err != nil {
 		t.Fatal(err)
 	}
-	if tok.AccessToken != "" {
-		t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
+	if tok == nil {
+		t.Fatalf("token is nil")
 	}
-	if tok.TokenType != "bearer" {
-		t.Errorf("Unexpected token type, %#v.", tok.TokenType)
+	if tok.Valid() {
+		t.Errorf("token is valid. want invalid.")
+	}
+	if tok.AccessToken != "" {
+		t.Errorf("Unexpected non-empty access token %q.", tok.AccessToken)
+	}
+	if want := "bearer"; tok.TokenType != want {
+		t.Errorf("TokenType = %q; want %q", tok.TokenType, want)
 	}
 	scope := tok.Extra("scope")
-	if scope != "user" {
-		t.Errorf("Unexpected value for scope: %v", scope)
+	if want := "user"; scope != want {
+		t.Errorf("token scope = %q; want %q", scope, want)
 	}
 }
 
@@ -116,7 +122,7 @@
 		PrivateKey: dummyPrivateKey,
 		TokenURL:   ts.URL,
 	}
-	tok, err := conf.TokenSource(NoContext, nil).Token()
+	tok, err := conf.TokenSource(NoContext).Token()
 	if err == nil {
 		t.Error("got a token; expected error")
 		if tok.AccessToken != "" {
diff --git a/oauth2.go b/oauth2.go
index 3644683..5f2b145 100644
--- a/oauth2.go
+++ b/oauth2.go
@@ -26,7 +26,6 @@
 )
 
 // Context can be an golang.org/x/net.Context, or an App Engine Context.
-// In the future these will be unified.
 // If you don't care and aren't running on App Engine, you may use NoContext.
 type Context interface{}
 
@@ -36,7 +35,7 @@
 var NoContext Context = nil
 
 // Config describes a typical 3-legged OAuth2 flow, with both the
-// client application information and the server's URLs.
+// client application information and the server's endpoint URLs.
 type Config struct {
 	// ClientID is the application's ID.
 	ClientID string
@@ -45,9 +44,9 @@
 	ClientSecret string
 
 	// Endpoint contains the resource server's token endpoint
-	// URLs.  These are supplied by the server and are often
-	// available via site-specific packages (for example,
-	// google.Endpoint or github.Endpoint)
+	// URLs. These are constants specific to each server and are
+	// often available via site-specific packages, such as
+	// google.Endpoint or github.Endpoint.
 	Endpoint Endpoint
 
 	// RedirectURL is the URL to redirect users going through
@@ -61,6 +60,7 @@
 // A TokenSource is anything that can return a token.
 type TokenSource interface {
 	// Token returns a token or an error.
+	// Token must be safe for concurrent use by multiple goroutines.
 	Token() (*Token, error)
 }
 
@@ -208,7 +208,7 @@
 //
 // Most users will use Config.Client instead.
 func (c *Config) TokenSource(ctx Context, t *Token) TokenSource {
-	nwn := &newWhenNeededSource{t: t}
+	nwn := &reuseTokenSource{t: t}
 	nwn.new = tokenRefresher{
 		ctx:      ctx,
 		conf:     c,
@@ -239,13 +239,13 @@
 	})
 }
 
-// newWhenNeededSource is a TokenSource that holds a single token in memory
+// reuseTokenSource is a TokenSource that holds a single token in memory
 // and validates its expiry before each call to retrieve it with
 // Token. If it's expired, it will be auto-refreshed using the
 // new TokenSource.
 //
 // The first call to TokenRefresher must be SetToken.
-type newWhenNeededSource struct {
+type reuseTokenSource struct {
 	new TokenSource // called when t is expired.
 
 	mu sync.Mutex // guards t
@@ -255,10 +255,10 @@
 // Token returns the current token if it's still valid, else will
 // refresh the current token (using r.Context for HTTP client
 // information) and return the new one.
-func (s *newWhenNeededSource) Token() (*Token, error) {
+func (s *reuseTokenSource) Token() (*Token, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	if s.t != nil && !s.t.Expired() {
+	if s.t.Valid() {
 		return s.t, nil
 	}
 	t, err := s.new.Token()
@@ -410,12 +410,41 @@
 type contextKey struct{}
 
 // NewClient creates an *http.Client from a Context and TokenSource.
-// The client's lifetime does not extend beyond the lifetime of the context.
+// The returned client is not valid beyond the lifetime of the context.
 func NewClient(ctx Context, src TokenSource) *http.Client {
 	return &http.Client{
 		Transport: &Transport{
 			Base:   contextTransport(ctx),
-			Source: src,
+			Source: ReuseTokenSource(nil, src),
 		},
 	}
 }
+
+// ReuseTokenSource returns a TokenSource which repeatedly returns the
+// same token as long as it's valid, starting with t.
+// When its cached token is invalid, a new token is obtained from src.
+//
+// ReuseTokenSource is typically used to reuse tokens from a cache
+// (such as a file on disk) between runs of a program, rather than
+// obtaining new tokens unnecessarily.
+//
+// The initial token t may be nil, in which case the TokenSource is
+// wrapped in a caching version if it isn't one already. This also
+// means it's always safe to wrap ReuseTokenSource around any other
+// TokenSource without adverse effects.
+func ReuseTokenSource(t *Token, src TokenSource) TokenSource {
+	// Don't wrap a reuseTokenSource in itself. That would work,
+	// but cause an unnecessary number of mutex operations.
+	// Just build the equivalent one.
+	if rt, ok := src.(*reuseTokenSource); ok {
+		if t == nil {
+			// Just use it directly.
+			return rt
+		}
+		src = rt.new
+	}
+	return &reuseTokenSource{
+		t:   t,
+		new: src,
+	}
+}
diff --git a/oauth2_test.go b/oauth2_test.go
index c567c3a..804098a 100644
--- a/oauth2_test.go
+++ b/oauth2_test.go
@@ -99,8 +99,8 @@
 	if err != nil {
 		t.Error(err)
 	}
-	if tok.Expired() {
-		t.Errorf("Token shouldn't be expired.")
+	if !tok.Valid() {
+		t.Fatalf("Token invalid. Got: %#v", tok)
 	}
 	if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
 		t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
@@ -143,8 +143,8 @@
 	if err != nil {
 		t.Error(err)
 	}
-	if tok.Expired() {
-		t.Errorf("Token shouldn't be expired.")
+	if !tok.Valid() {
+		t.Fatalf("Token invalid. Got: %#v", tok)
 	}
 	if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
 		t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
diff --git a/token.go b/token.go
index 0c52888..6aa0b41 100644
--- a/token.go
+++ b/token.go
@@ -74,14 +74,16 @@
 	return ""
 }
 
-// Expired returns true if there is no access token or the
-// access token is expired.
-func (t *Token) Expired() bool {
-	if t.AccessToken == "" {
-		return true
-	}
+// expired reports whether the token is expired.
+// t must be non-nil.
+func (t *Token) expired() bool {
 	if t.Expiry.IsZero() {
 		return false
 	}
 	return t.Expiry.Before(time.Now())
 }
+
+// Valid reports whether t is non-nil, has an AccessToken, and is not expired.
+func (t *Token) Valid() bool {
+	return t != nil && t.AccessToken != "" && !t.expired()
+}
diff --git a/transport_test.go b/transport_test.go
index b3414e3..efb8232 100644
--- a/transport_test.go
+++ b/transport_test.go
@@ -32,10 +32,10 @@
 	client.Get(server.URL)
 }
 
-func TestExpiredWithNoAccessToken(t *testing.T) {
+func TestTokenValidNoAccessToken(t *testing.T) {
 	token := &Token{}
-	if !token.Expired() {
-		t.Errorf("Token should be expired if no access token is provided")
+	if token.Valid() {
+		t.Errorf("Token should not be valid with no access token")
 	}
 }
 
@@ -43,8 +43,8 @@
 	token := &Token{
 		Expiry: time.Now().Add(-5 * time.Hour),
 	}
-	if !token.Expired() {
-		t.Errorf("Token should be expired if no access token is provided")
+	if token.Valid() {
+		t.Errorf("Token should not be valid if it expired in the past")
 	}
 }