oauth2: rewrite google package, fix the broken build
Change-Id: I2753a88d7be483bdbc0cac09a1beccc4806ea4bc
Reviewed-on: https://go-review.googlesource.com/1361
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Andrew Gerrand <adg@golang.org>
diff --git a/example_test.go b/example_test.go
index fb8dd8e..cb4726f 100644
--- a/example_test.go
+++ b/example_test.go
@@ -7,7 +7,6 @@
import (
"fmt"
"log"
- "net/http"
"testing"
"golang.org/x/oauth2"
@@ -17,23 +16,20 @@
// Related to https://codereview.appspot.com/107320046
func TestA(t *testing.T) {}
-func Example_regular() {
- opts, err := oauth2.New(
- oauth2.Client("YOUR_CLIENT_ID", "YOUR_CLIENT_SECRET"),
- oauth2.RedirectURL("YOUR_REDIRECT_URL"),
- oauth2.Scope("SCOPE1", "SCOPE2"),
- oauth2.Endpoint(
- "https://provider.com/o/oauth2/auth",
- "https://provider.com/o/oauth2/token",
- ),
- )
- if err != nil {
- log.Fatal(err)
+func ExampleConfig() {
+ conf := &oauth2.Config{
+ ClientID: "YOUR_CLIENT_ID",
+ ClientSecret: "YOUR_CLIENT_SECRET",
+ Scopes: []string{"SCOPE1", "SCOPE2"},
+ Endpoint: oauth2.Endpoint{
+ AuthURL: "https://provider.com/o/oauth2/auth",
+ TokenURL: "https://provider.com/o/oauth2/token",
+ },
}
// Redirect user to consent page to ask for permission
// for the scopes specified above.
- url := opts.AuthCodeURL("state", "online", "auto")
+ url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline)
fmt.Printf("Visit the URL for the auth dialog: %v", url)
// Use the authorization code that is pushed to the redirect URL.
@@ -41,22 +37,22 @@
// an access token and initiate a Transport that is
// authorized and authenticated by the retrieved token.
var code string
- if _, err = fmt.Scan(&code); err != nil {
+ if _, err := fmt.Scan(&code); err != nil {
log.Fatal(err)
}
- t, err := opts.NewTransportFromCode(code)
+ tok, err := conf.Exchange(oauth2.NoContext, code)
if err != nil {
log.Fatal(err)
}
- // You can use t to initiate a new http.Client and
- // start making authenticated requests.
- client := http.Client{Transport: t}
+ client := conf.Client(oauth2.NoContext, tok)
client.Get("...")
}
-func Example_jWT() {
- opts, err := oauth2.New(
+func ExampleJWTConfig() {
+ var initialToken *oauth2.Token // nil means no initial token
+ conf := &oauth2.JWTConfig{
+ Email: "xxx@developer.com",
// The contents of your RSA private key or your PEM file
// that contains a private key.
// If you have a p12 file instead, you
@@ -65,23 +61,12 @@
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
//
// It only supports PEM containers with no passphrase.
- oauth2.JWTClient(
- "xxx@developer.gserviceaccount.com",
- []byte("-----BEGIN RSA PRIVATE KEY-----...")),
- oauth2.Scope("SCOPE1", "SCOPE2"),
- oauth2.JWTEndpoint("https://provider.com/o/oauth2/token"),
- // If you would like to impersonate a user, you can
- // create a transport with a subject. The following GET
- // request will be made on the behalf of user@example.com.
- // Subject is optional.
- oauth2.Subject("user@example.com"),
- )
- if err != nil {
- log.Fatal(err)
+ PrivateKey: []byte("-----BEGIN RSA PRIVATE KEY-----..."),
+ Subject: "user@example.com",
+ TokenURL: "https://provider.com/o/oauth2/token",
}
-
// Initiate an http.Client, the following GET request will be
// authorized and authenticated on the behalf of user@example.com.
- client := http.Client{Transport: opts.NewTransport()}
+ client := conf.Client(oauth2.NoContext, initialToken)
client.Get("...")
}
diff --git a/google/appengine.go b/google/appengine.go
index 0502693..c6213d9 100644
--- a/google/appengine.go
+++ b/google/appengine.go
@@ -7,108 +7,31 @@
package google
import (
- "net/http"
- "strings"
- "sync"
"time"
- "golang.org/x/oauth2"
-
"appengine"
- "appengine/memcache"
- "appengine/urlfetch"
+
+ "golang.org/x/oauth2"
)
-var (
- // memcacheGob enables mocking of the memcache.Gob calls for unit testing.
- memcacheGob memcacher = &aeMemcache{}
-
- // accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing.
- accessTokenFunc = appengine.AccessToken
-
- // mu protects multiple threads from attempting to fetch a token at the same time.
- mu sync.Mutex
-
- // tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls.
- tokens map[string]*oauth2.Token
-)
-
-// safetyMargin is used to avoid clock-skew problems.
-// 5 minutes is conservative because tokens are valid for 60 minutes.
-const safetyMargin = 5 * time.Minute
-
-func init() {
- tokens = make(map[string]*oauth2.Token)
-}
-
-// AppEngineContext requires an App Engine request context.
-func AppEngineContext(ctx appengine.Context) oauth2.Option {
- return func(opts *oauth2.Options) error {
- opts.TokenFetcherFunc = makeAppEngineTokenFetcher(ctx, opts)
- opts.Client = &http.Client{
- Transport: &urlfetch.Transport{Context: ctx},
- }
- return nil
+// AppEngineTokenSource returns a token source that fetches tokens
+// issued to the current App Engine application's service account.
+// If you are implementing a 3-legged OAuth 2.0 flow on App Engine
+// that involves user accounts, see oauth2.Config instead.
+//
+// You are required to provide a valid appengine.Context as context.
+func AppEngineTokenSource(ctx appengine.Context, scope ...string) oauth2.TokenSource {
+ return &appEngineTokenSource{
+ ctx: ctx,
+ scopes: scope,
+ fetcherFunc: aeFetcherFunc,
}
}
-// FetchToken fetches a new access token for the provided scopes.
-// Tokens are cached locally and also with Memcache so that the app can scale
-// without hitting quota limits by calling appengine.AccessToken too frequently.
-func makeAppEngineTokenFetcher(ctx appengine.Context, opts *oauth2.Options) func(*oauth2.Token) (*oauth2.Token, error) {
- return func(existing *oauth2.Token) (*oauth2.Token, error) {
- mu.Lock()
- defer mu.Unlock()
-
- key := ":" + strings.Join(opts.Scopes, "_")
- now := time.Now().Add(safetyMargin)
- if t, ok := tokens[key]; ok && !t.Expiry.Before(now) {
- return t, nil
- }
- delete(tokens, key)
-
- // Attempt to get token from Memcache
- tok := new(oauth2.Token)
- _, err := memcacheGob.Get(ctx, key, tok)
- if err == nil && !tok.Expiry.Before(now) {
- tokens[key] = tok // Save token locally
- return tok, nil
- }
-
- token, expiry, err := accessTokenFunc(ctx, opts.Scopes...)
- if err != nil {
- return nil, err
- }
- t := &oauth2.Token{
- AccessToken: token,
- Expiry: expiry,
- }
- tokens[key] = t
- // Also back up token in Memcache
- if err = memcacheGob.Set(ctx, &memcache.Item{
- Key: key,
- Value: []byte{},
- Object: *t,
- Expiration: expiry.Sub(now),
- }); err != nil {
- ctx.Errorf("unexpected memcache.Set error: %v", err)
- }
- return t, nil
+var aeFetcherFunc = func(ctx oauth2.Context, scope ...string) (string, time.Time, error) {
+ c, ok := ctx.(appengine.Context)
+ if !ok {
+ return "", time.Time{}, errInvalidContext
}
-}
-
-// aeMemcache wraps the needed Memcache functionality to make it easy to mock
-type aeMemcache struct{}
-
-func (m *aeMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) {
- return memcache.Gob.Get(c, key, tok)
-}
-
-func (m *aeMemcache) Set(c appengine.Context, item *memcache.Item) error {
- return memcache.Gob.Set(c, item)
-}
-
-type memcacher interface {
- Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error)
- Set(c appengine.Context, item *memcache.Item) error
+ return appengine.AccessToken(c, scope...)
}
diff --git a/google/appengine_test.go b/google/appengine_test.go
deleted file mode 100644
index 2c07ce4..0000000
--- a/google/appengine_test.go
+++ /dev/null
@@ -1,266 +0,0 @@
-// Copyright 2014 The oauth2 Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build appengine,!appenginevm
-
-package google
-
-import (
- "fmt"
- "log"
- "net/http"
- "sync"
- "testing"
- "time"
-
- "golang.org/x/oauth2"
-
- "appengine"
- "appengine/memcache"
-)
-
-type tokMap map[string]*oauth2.Token
-
-type mockMemcache struct {
- mu sync.RWMutex
- vals tokMap
- getCount, setCount int
-}
-
-func (m *mockMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.getCount++
- v, ok := m.vals[key]
- if !ok {
- return nil, fmt.Errorf("unexpected test error: key %q not found", key)
- }
- *tok = *v
- return nil, nil // memcache.Item is ignored anyway - return nil
-}
-
-func (m *mockMemcache) Set(c appengine.Context, item *memcache.Item) error {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.setCount++
- tok, ok := item.Object.(oauth2.Token)
- if !ok {
- log.Fatalf("unexpected test error: item.Object is not an oauth2.Token: %#v", item)
- }
- m.vals[item.Key] = &tok
- return nil
-}
-
-var accessTokenCount = 0
-
-func mockAccessToken(c appengine.Context, scopes ...string) (token string, expiry time.Time, err error) {
- accessTokenCount++
- return "mytoken", time.Now(), nil
-}
-
-const (
- testScope = "myscope"
- testScopeKey = ":" + testScope
-)
-
-func init() {
- accessTokenFunc = mockAccessToken
-}
-
-func TestFetchTokenLocalCacheMiss(t *testing.T) {
- m := &mockMemcache{vals: make(tokMap)}
- memcacheGob = m
- accessTokenCount = 0
- delete(tokens, testScopeKey) // clear local cache
- f, err := oauth2.New(
- AppEngineContext(nil),
- oauth2.Scope(testScope),
- )
- if err != nil {
- t.Error(err)
- }
- tr := f.NewTransport()
- c := http.Client{Transport: tr}
- c.Get("server")
- if w := 1; m.getCount != w {
- t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
- }
- if w := 1; accessTokenCount != w {
- t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
- }
- if w := 1; m.setCount != w {
- t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
- }
- // Make sure local cache has been populated
- _, ok := tokens[testScopeKey]
- if !ok {
- t.Errorf("local cache not populated!")
- }
-}
-
-func TestFetchTokenLocalCacheHit(t *testing.T) {
- m := &mockMemcache{vals: make(tokMap)}
- memcacheGob = m
- accessTokenCount = 0
- // Pre-populate the local cache
- tokens[testScopeKey] = &oauth2.Token{
- AccessToken: "mytoken",
- Expiry: time.Now().Add(1 * time.Hour),
- }
- f, err := oauth2.New(
- AppEngineContext(nil),
- oauth2.Scope(testScope),
- )
- if err != nil {
- t.Error(err)
- }
- tr := f.NewTransport()
- c := http.Client{Transport: tr}
- c.Get("server")
- if err != nil {
- t.Errorf("unable to FetchToken: %v", err)
- }
- if w := 0; m.getCount != w {
- t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
- }
- if w := 0; accessTokenCount != w {
- t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
- }
- if w := 0; m.setCount != w {
- t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
- }
- // Make sure local cache remains populated
- _, ok := tokens[testScopeKey]
- if !ok {
- t.Errorf("local cache not populated!")
- }
-}
-
-func TestFetchTokenMemcacheHit(t *testing.T) {
- m := &mockMemcache{vals: make(tokMap)}
- memcacheGob = m
- accessTokenCount = 0
- delete(tokens, testScopeKey) // clear local cache
- // Pre-populate the memcache
- tok := &oauth2.Token{
- AccessToken: "mytoken",
- Expiry: time.Now().Add(1 * time.Hour),
- }
- m.Set(nil, &memcache.Item{
- Key: testScopeKey,
- Object: *tok,
- Expiration: 1 * time.Hour,
- })
- m.setCount = 0
-
- f, err := oauth2.New(
- AppEngineContext(nil),
- oauth2.Scope(testScope),
- )
- if err != nil {
- t.Error(err)
- }
- c := http.Client{Transport: f.NewTransport()}
- c.Get("server")
- if w := 1; m.getCount != w {
- t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
- }
- if w := 0; accessTokenCount != w {
- t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
- }
- if w := 0; m.setCount != w {
- t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
- }
- // Make sure local cache has been populated
- _, ok := tokens[testScopeKey]
- if !ok {
- t.Errorf("local cache not populated!")
- }
-}
-
-func TestFetchTokenLocalCacheExpired(t *testing.T) {
- m := &mockMemcache{vals: make(tokMap)}
- memcacheGob = m
- accessTokenCount = 0
- // Pre-populate the local cache
- tokens[testScopeKey] = &oauth2.Token{
- AccessToken: "mytoken",
- Expiry: time.Now().Add(-1 * time.Hour),
- }
- // Pre-populate the memcache
- tok := &oauth2.Token{
- AccessToken: "mytoken",
- Expiry: time.Now().Add(1 * time.Hour),
- }
- m.Set(nil, &memcache.Item{
- Key: testScopeKey,
- Object: *tok,
- Expiration: 1 * time.Hour,
- })
- m.setCount = 0
- f, err := oauth2.New(
- AppEngineContext(nil),
- oauth2.Scope(testScope),
- )
- if err != nil {
- t.Error(err)
- }
- c := http.Client{Transport: f.NewTransport()}
- c.Get("server")
- if w := 1; m.getCount != w {
- t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
- }
- if w := 0; accessTokenCount != w {
- t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
- }
- if w := 0; m.setCount != w {
- t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
- }
- // Make sure local cache remains populated
- _, ok := tokens[testScopeKey]
- if !ok {
- t.Errorf("local cache not populated!")
- }
-}
-
-func TestFetchTokenMemcacheExpired(t *testing.T) {
- m := &mockMemcache{vals: make(tokMap)}
- memcacheGob = m
- accessTokenCount = 0
- delete(tokens, testScopeKey) // clear local cache
- // Pre-populate the memcache
- tok := &oauth2.Token{
- AccessToken: "mytoken",
- Expiry: time.Now().Add(-1 * time.Hour),
- }
- m.Set(nil, &memcache.Item{
- Key: testScopeKey,
- Object: *tok,
- Expiration: -1 * time.Hour,
- })
- m.setCount = 0
- f, err := oauth2.New(
- AppEngineContext(nil),
- oauth2.Scope(testScope),
- )
- if err != nil {
- t.Error(err)
- }
- c := http.Client{Transport: f.NewTransport()}
- c.Get("server")
- if w := 1; m.getCount != w {
- t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
- }
- if w := 1; accessTokenCount != w {
- t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
- }
- if w := 1; m.setCount != w {
- t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
- }
- // Make sure local cache has been populated
- _, ok := tokens[testScopeKey]
- if !ok {
- t.Errorf("local cache not populated!")
- }
-}
diff --git a/google/appenginevm.go b/google/appenginevm.go
index ce2b1bd..12af742 100644
--- a/google/appenginevm.go
+++ b/google/appenginevm.go
@@ -7,102 +7,30 @@
package google
import (
- "strings"
- "sync"
"time"
"golang.org/x/oauth2"
"google.golang.org/appengine"
- "google.golang.org/appengine/memcache"
)
-var (
- // memcacheGob enables mocking of the memcache.Gob calls for unit testing.
- memcacheGob memcacher = &aeMemcache{}
-
- // accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing.
- accessTokenFunc = appengine.AccessToken
-
- // mu protects multiple threads from attempting to fetch a token at the same time.
- mu sync.Mutex
-
- // tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls.
- tokens map[string]*oauth2.Token
-)
-
-// safetyMargin is used to avoid clock-skew problems.
-// 5 minutes is conservative because tokens are valid for 60 minutes.
-const safetyMargin = 5 * time.Minute
-
-func init() {
- tokens = make(map[string]*oauth2.Token)
-}
-
-// AppEngineContext requires an App Engine request context.
-func AppEngineContext(ctx appengine.Context) oauth2.Option {
- return func(opts *oauth2.Options) error {
- opts.TokenFetcherFunc = makeAppEngineTokenFetcher(ctx, opts)
- return nil
+// AppEngineTokenSource returns a token source that fetches tokens
+// issued to the current App Engine application's service account.
+// If you are implementing a 3-legged OAuth 2.0 flow on App Engine
+// that involves user accounts, see oauth2.Config instead.
+//
+// You are required to provide a valid appengine.Context as context.
+func AppEngineTokenSource(ctx appengine.Context, scope ...string) oauth2.TokenSource {
+ return &appEngineTokenSource{
+ ctx: ctx,
+ scopes: scope,
+ fetcherFunc: aeVMFetcherFunc,
}
}
-// FetchToken fetches a new access token for the provided scopes.
-// Tokens are cached locally and also with Memcache so that the app can scale
-// without hitting quota limits by calling appengine.AccessToken too frequently.
-func makeAppEngineTokenFetcher(ctx appengine.Context, opts *oauth2.Options) func(*oauth2.Token) (*oauth2.Token, error) {
- return func(existing *oauth2.Token) (*oauth2.Token, error) {
- mu.Lock()
- defer mu.Unlock()
-
- key := ":" + strings.Join(opts.Scopes, "_")
- now := time.Now().Add(safetyMargin)
- if t, ok := tokens[key]; ok && !t.Expiry.Before(now) {
- return t, nil
- }
- delete(tokens, key)
-
- // Attempt to get token from Memcache
- tok := new(oauth2.Token)
- _, err := memcacheGob.Get(ctx, key, tok)
- if err == nil && !tok.Expiry.Before(now) {
- tokens[key] = tok // Save token locally
- return tok, nil
- }
-
- token, expiry, err := accessTokenFunc(ctx, opts.Scopes...)
- if err != nil {
- return nil, err
- }
- t := &oauth2.Token{
- AccessToken: token,
- Expiry: expiry,
- }
- tokens[key] = t
- // Also back up token in Memcache
- if err = memcacheGob.Set(ctx, &memcache.Item{
- Key: key,
- Value: []byte{},
- Object: *t,
- Expiration: expiry.Sub(now),
- }); err != nil {
- ctx.Errorf("unexpected memcache.Set error: %v", err)
- }
- return t, nil
+var aeVMFetcherFunc = func(ctx oauth2.Context, scope ...string) (string, time.Time, error) {
+ c, ok := ctx.(appengine.Context)
+ if !ok {
+ return "", time.Time{}, errInvalidContext
}
-}
-
-// aeMemcache wraps the needed Memcache functionality to make it easy to mock
-type aeMemcache struct{}
-
-func (m *aeMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) {
- return memcache.Gob.Get(c, key, tok)
-}
-
-func (m *aeMemcache) Set(c appengine.Context, item *memcache.Item) error {
- return memcache.Gob.Set(c, item)
-}
-
-type memcacher interface {
- Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error)
- Set(c appengine.Context, item *memcache.Item) error
+ return appengine.AccessToken(c, scope...)
}
diff --git a/google/appenginevm_test.go b/google/appenginevm_test.go
deleted file mode 100644
index 3ca4b0d..0000000
--- a/google/appenginevm_test.go
+++ /dev/null
@@ -1,265 +0,0 @@
-// Copyright 2014 The oauth2 Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build appenginevm !appengine
-
-package google
-
-import (
- "fmt"
- "log"
- "net/http"
- "sync"
- "testing"
- "time"
-
- "golang.org/x/oauth2"
- "google.golang.org/appengine"
- "google.golang.org/appengine/memcache"
-)
-
-type tokMap map[string]*oauth2.Token
-
-type mockMemcache struct {
- mu sync.RWMutex
- vals tokMap
- getCount, setCount int
-}
-
-func (m *mockMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.getCount++
- v, ok := m.vals[key]
- if !ok {
- return nil, fmt.Errorf("unexpected test error: key %q not found", key)
- }
- *tok = *v
- return nil, nil // memcache.Item is ignored anyway - return nil
-}
-
-func (m *mockMemcache) Set(c appengine.Context, item *memcache.Item) error {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.setCount++
- tok, ok := item.Object.(oauth2.Token)
- if !ok {
- log.Fatalf("unexpected test error: item.Object is not an oauth2.Token: %#v", item)
- }
- m.vals[item.Key] = &tok
- return nil
-}
-
-var accessTokenCount = 0
-
-func mockAccessToken(c appengine.Context, scopes ...string) (token string, expiry time.Time, err error) {
- accessTokenCount++
- return "mytoken", time.Now(), nil
-}
-
-const (
- testScope = "myscope"
- testScopeKey = ":" + testScope
-)
-
-func init() {
- accessTokenFunc = mockAccessToken
-}
-
-func TestFetchTokenLocalCacheMiss(t *testing.T) {
- m := &mockMemcache{vals: make(tokMap)}
- memcacheGob = m
- accessTokenCount = 0
- delete(tokens, testScopeKey) // clear local cache
- f, err := oauth2.New(
- AppEngineContext(nil),
- oauth2.Scope(testScope),
- )
- if err != nil {
- t.Error(err)
- }
- tr := f.NewTransport()
- c := http.Client{Transport: tr}
- c.Get("server")
- if w := 1; m.getCount != w {
- t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
- }
- if w := 1; accessTokenCount != w {
- t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
- }
- if w := 1; m.setCount != w {
- t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
- }
- // Make sure local cache has been populated
- _, ok := tokens[testScopeKey]
- if !ok {
- t.Errorf("local cache not populated!")
- }
-}
-
-func TestFetchTokenLocalCacheHit(t *testing.T) {
- m := &mockMemcache{vals: make(tokMap)}
- memcacheGob = m
- accessTokenCount = 0
- // Pre-populate the local cache
- tokens[testScopeKey] = &oauth2.Token{
- AccessToken: "mytoken",
- Expiry: time.Now().Add(1 * time.Hour),
- }
- f, err := oauth2.New(
- AppEngineContext(nil),
- oauth2.Scope(testScope),
- )
- if err != nil {
- t.Error(err)
- }
- tr := f.NewTransport()
- c := http.Client{Transport: tr}
- c.Get("server")
- if err != nil {
- t.Errorf("unable to FetchToken: %v", err)
- }
- if w := 0; m.getCount != w {
- t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
- }
- if w := 0; accessTokenCount != w {
- t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
- }
- if w := 0; m.setCount != w {
- t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
- }
- // Make sure local cache remains populated
- _, ok := tokens[testScopeKey]
- if !ok {
- t.Errorf("local cache not populated!")
- }
-}
-
-func TestFetchTokenMemcacheHit(t *testing.T) {
- m := &mockMemcache{vals: make(tokMap)}
- memcacheGob = m
- accessTokenCount = 0
- delete(tokens, testScopeKey) // clear local cache
- // Pre-populate the memcache
- tok := &oauth2.Token{
- AccessToken: "mytoken",
- Expiry: time.Now().Add(1 * time.Hour),
- }
- m.Set(nil, &memcache.Item{
- Key: testScopeKey,
- Object: *tok,
- Expiration: 1 * time.Hour,
- })
- m.setCount = 0
-
- f, err := oauth2.New(
- AppEngineContext(nil),
- oauth2.Scope(testScope),
- )
- if err != nil {
- t.Error(err)
- }
- c := http.Client{Transport: f.NewTransport()}
- c.Get("server")
- if w := 1; m.getCount != w {
- t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
- }
- if w := 0; accessTokenCount != w {
- t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
- }
- if w := 0; m.setCount != w {
- t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
- }
- // Make sure local cache has been populated
- _, ok := tokens[testScopeKey]
- if !ok {
- t.Errorf("local cache not populated!")
- }
-}
-
-func TestFetchTokenLocalCacheExpired(t *testing.T) {
- m := &mockMemcache{vals: make(tokMap)}
- memcacheGob = m
- accessTokenCount = 0
- // Pre-populate the local cache
- tokens[testScopeKey] = &oauth2.Token{
- AccessToken: "mytoken",
- Expiry: time.Now().Add(-1 * time.Hour),
- }
- // Pre-populate the memcache
- tok := &oauth2.Token{
- AccessToken: "mytoken",
- Expiry: time.Now().Add(1 * time.Hour),
- }
- m.Set(nil, &memcache.Item{
- Key: testScopeKey,
- Object: *tok,
- Expiration: 1 * time.Hour,
- })
- m.setCount = 0
- f, err := oauth2.New(
- AppEngineContext(nil),
- oauth2.Scope(testScope),
- )
- if err != nil {
- t.Error(err)
- }
- c := http.Client{Transport: f.NewTransport()}
- c.Get("server")
- if w := 1; m.getCount != w {
- t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
- }
- if w := 0; accessTokenCount != w {
- t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
- }
- if w := 0; m.setCount != w {
- t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
- }
- // Make sure local cache remains populated
- _, ok := tokens[testScopeKey]
- if !ok {
- t.Errorf("local cache not populated!")
- }
-}
-
-func TestFetchTokenMemcacheExpired(t *testing.T) {
- m := &mockMemcache{vals: make(tokMap)}
- memcacheGob = m
- accessTokenCount = 0
- delete(tokens, testScopeKey) // clear local cache
- // Pre-populate the memcache
- tok := &oauth2.Token{
- AccessToken: "mytoken",
- Expiry: time.Now().Add(-1 * time.Hour),
- }
- m.Set(nil, &memcache.Item{
- Key: testScopeKey,
- Object: *tok,
- Expiration: -1 * time.Hour,
- })
- m.setCount = 0
- f, err := oauth2.New(
- AppEngineContext(nil),
- oauth2.Scope(testScope),
- )
- if err != nil {
- t.Error(err)
- }
- c := http.Client{Transport: f.NewTransport()}
- c.Get("server")
- if w := 1; m.getCount != w {
- t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
- }
- if w := 1; accessTokenCount != w {
- t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
- }
- if w := 1; m.setCount != w {
- t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
- }
- // Make sure local cache has been populated
- _, ok := tokens[testScopeKey]
- if !ok {
- t.Errorf("local cache not populated!")
- }
-}
diff --git a/google/example_test.go b/google/example_test.go
index 9fec175..31ff67a 100644
--- a/google/example_test.go
+++ b/google/example_test.go
@@ -8,6 +8,7 @@
import (
"fmt"
+ "io/ioutil"
"log"
"net/http"
"testing"
@@ -15,6 +16,7 @@
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/appengine"
+ "google.golang.org/appengine/urlfetch"
)
// Remove after Go 1.4.
@@ -24,33 +26,31 @@
func Example_webServer() {
// Your credentials should be obtained from the Google
// Developer Console (https://console.developers.google.com).
- opts, err := oauth2.New(
- oauth2.Client("YOUR_CLIENT_ID", "YOUR_CLIENT_SECRET"),
- oauth2.RedirectURL("YOUR_REDIRECT_URL"),
- oauth2.Scope(
+ conf := &oauth2.Config{
+ ClientID: "YOUR_CLIENT_ID",
+ ClientSecret: "YOUR_CLIENT_SECRET",
+ RedirectURL: "YOUR_REDIRECT_URL",
+ Scopes: []string{
"https://www.googleapis.com/auth/bigquery",
"https://www.googleapis.com/auth/blogger",
- ),
- google.Endpoint(),
- )
- if err != nil {
- log.Fatal(err)
+ },
+ Endpoint: google.Endpoint,
}
// Redirect user to Google's consent page to ask for permission
// for the scopes specified above.
- url := opts.AuthCodeURL("state", "online", "auto")
+ url := conf.AuthCodeURL("state")
fmt.Printf("Visit the URL for the auth dialog: %v", url)
- // Handle the exchange code to initiate a transport
- t, err := opts.NewTransportFromCode("exchange-code")
+ // Handle the exchange code to initiate a transport.
+ tok, err := conf.Exchange(oauth2.NoContext, "authorization-code")
if err != nil {
log.Fatal(err)
}
- client := http.Client{Transport: t}
+ client := conf.Client(oauth2.NoContext, tok)
client.Get("...")
}
-func Example_serviceAccountsJSON() {
+func ExampleJWTConfigFromJSON() {
// Your credentials should be obtained from the Google
// Developer Console (https://console.developers.google.com).
// Navigate to your project, then see the "Credentials" page
@@ -58,27 +58,26 @@
// To create a service account client, click "Create new Client ID",
// select "Service Account", and click "Create Client ID". A JSON
// key file will then be downloaded to your computer.
- opts, err := oauth2.New(
- google.ServiceAccountJSONKey("/path/to/your-project-key.json"),
- oauth2.Scope(
- "https://www.googleapis.com/auth/bigquery",
- "https://www.googleapis.com/auth/blogger",
- ),
- )
+ data, err := ioutil.ReadFile("/path/to/your-project-key.json")
+ if err != nil {
+ log.Fatal(err)
+ }
+ conf, err := google.JWTConfigFromJSON(oauth2.NoContext, data, "https://www.googleapis.com/auth/bigquery")
if err != nil {
log.Fatal(err)
}
// Initiate an http.Client. The following GET request will be
// authorized and authenticated on the behalf of
// your service account.
- client := http.Client{Transport: opts.NewTransport()}
+ client := conf.Client(oauth2.NoContext, nil)
client.Get("...")
}
-func Example_serviceAccounts() {
+func Example_serviceAccount() {
// Your credentials should be obtained from the Google
// Developer Console (https://console.developers.google.com).
- opts, err := oauth2.New(
+ conf := &oauth2.JWTConfig{
+ Email: "xxx@developer.gserviceaccount.com",
// The contents of your RSA private key or your PEM file
// that contains a private key.
// If you have a p12 file instead, you
@@ -87,58 +86,46 @@
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
//
// It only supports PEM containers with no passphrase.
- oauth2.JWTClient(
- "xxx@developer.gserviceaccount.com",
- []byte("-----BEGIN RSA PRIVATE KEY-----...")),
- oauth2.Scope(
+ PrivateKey: []byte("-----BEGIN RSA PRIVATE KEY-----..."),
+ Scopes: []string{
"https://www.googleapis.com/auth/bigquery",
"https://www.googleapis.com/auth/blogger",
- ),
- google.JWTEndpoint(),
+ },
+ TokenURL: google.JWTTokenURL,
// If you would like to impersonate a user, you can
// create a transport with a subject. The following GET
// request will be made on the behalf of user@example.com.
- // Subject is optional.
- oauth2.Subject("user@example.com"),
- )
- if err != nil {
- log.Fatal(err)
+ // Optional.
+ Subject: "user@example.com",
}
-
// Initiate an http.Client, the following GET request will be
// authorized and authenticated on the behalf of user@example.com.
- client := http.Client{Transport: opts.NewTransport()}
+ client := conf.Client(oauth2.NoContext, nil)
client.Get("...")
}
-func Example_appEngine() {
- ctx := appengine.NewContext(nil)
- opts, err := oauth2.New(
- google.AppEngineContext(ctx),
- oauth2.Scope(
- "https://www.googleapis.com/auth/bigquery",
- "https://www.googleapis.com/auth/blogger",
- ),
- )
- if err != nil {
- log.Fatal(err)
+func ExampleAppEngineTokenSource() {
+ var req *http.Request // from the ServeHTTP handler
+ ctx := appengine.NewContext(req)
+ client := &http.Client{
+ Transport: &oauth2.Transport{
+ Source: google.AppEngineTokenSource(ctx, "https://www.googleapis.com/auth/bigquery"),
+ Base: &urlfetch.Transport{
+ Context: ctx,
+ },
+ },
}
- // The following client will be authorized by the App Engine
- // app's service account for the provided scopes.
- client := http.Client{Transport: opts.NewTransport()}
client.Get("...")
}
-func Example_computeEngine() {
- opts, err := oauth2.New(
- // Query Google Compute Engine's metadata server to retrieve
- // an access token for the provided account.
- // If no account is specified, "default" is used.
- google.ComputeEngineAccount(""),
- )
- if err != nil {
- log.Fatal(err)
+func ExampleComputeTokenSource() {
+ client := &http.Client{
+ Transport: &oauth2.Transport{
+ // Fetch from Google Compute Engine's metadata server to retrieve
+ // an access token for the provided account.
+ // If no account is specified, "default" is used.
+ Source: google.ComputeTokenSource(""),
+ },
}
- client := http.Client{Transport: opts.NewTransport()}
client.Get("...")
}
diff --git a/google/google.go b/google/google.go
index 8256e2c..4890776 100644
--- a/google/google.go
+++ b/google/google.go
@@ -17,19 +17,41 @@
"encoding/json"
"fmt"
- "io/ioutil"
+ "net"
"net/http"
- "net/url"
"time"
"golang.org/x/oauth2"
- "golang.org/x/oauth2/internal"
)
-var (
- uriGoogleAuth, _ = url.Parse("https://accounts.google.com/o/oauth2/auth")
- uriGoogleToken, _ = url.Parse("https://accounts.google.com/o/oauth2/token")
-)
+// Endpoint is Google's OAuth 2.0 endpoint.
+var Endpoint = oauth2.Endpoint{
+ AuthURL: "https://accounts.google.com/o/oauth2/auth",
+ TokenURL: "https://accounts.google.com/o/oauth2/token",
+}
+
+// JWTTokenURL is Google's OAuth 2.0 token URL to use with the JWT flow.
+const JWTTokenURL = "https://accounts.google.com/o/oauth2/token"
+
+// JWTConfigFromJSON uses a Google Developers service account JSON key file to read
+// the credentials that authorize and authenticate the requests.
+// Create a service account on "Credentials" page under "APIs & Auth" for your
+// project at https://console.developers.google.com to download a JSON key file.
+func JWTConfigFromJSON(ctx oauth2.Context, jsonKey []byte, scope ...string) (*oauth2.JWTConfig, error) {
+ var key struct {
+ Email string `json:"client_email"`
+ PrivateKey string `json:"private_key"`
+ }
+ if err := json.Unmarshal(jsonKey, &key); err != nil {
+ return nil, err
+ }
+ return &oauth2.JWTConfig{
+ Email: key.Email,
+ PrivateKey: []byte(key.PrivateKey),
+ Scopes: scope,
+ TokenURL: JWTTokenURL,
+ }, nil
+}
type metaTokenRespBody struct {
AccessToken string `json:"access_token"`
@@ -37,93 +59,57 @@
TokenType string `json:"token_type"`
}
-// JWTEndpoint adds the endpoints required to complete the 2-legged service account flow.
-func JWTEndpoint() oauth2.Option {
- return func(opts *oauth2.Options) error {
- opts.AUD = uriGoogleToken
- return nil
- }
+// ComputeTokenSource returns a token source that fetches access tokens
+// from Google Compute Engine (GCE)'s metadata server. It's only valid to use
+// this token source if your program is running on a GCE instance.
+// If no account is specified, "default" is used.
+// Further information about retrieving access tokens from the GCE metadata
+// server can be found at https://cloud.google.com/compute/docs/authentication.
+func ComputeTokenSource(account string) oauth2.TokenSource {
+ return &computeSource{account: account}
}
-// Endpoint adds the endpoints required to do the 3-legged Web server flow.
-func Endpoint() oauth2.Option {
- return func(opts *oauth2.Options) error {
- opts.AuthURL = uriGoogleAuth
- opts.TokenURL = uriGoogleToken
- return nil
- }
+type computeSource struct {
+ account string
}
-// ComputeEngineAccount uses the specified account to retrieve an access
-// token from the Google Compute Engine's metadata server. If no user is
-// provided, "default" is being used.
-func ComputeEngineAccount(account string) oauth2.Option {
- return func(opts *oauth2.Options) error {
- if account == "" {
- account = "default"
- }
- opts.TokenFetcherFunc = makeComputeFetcher(opts, account)
- return nil
- }
+var metaClient = &http.Client{
+ Transport: &http.Transport{
+ Dial: (&net.Dialer{
+ Timeout: 750 * time.Millisecond,
+ KeepAlive: 30 * time.Second,
+ }).Dial,
+ ResponseHeaderTimeout: 750 * time.Millisecond,
+ },
}
-// ServiceAccountJSONKey uses the provided Google Developers
-// JSON key file to authorize the user. See the "Credentials" page under
-// "APIs & Auth" for your project at https://console.developers.google.com
-// to download a JSON key file.
-func ServiceAccountJSONKey(filename string) oauth2.Option {
- return func(opts *oauth2.Options) error {
- b, err := ioutil.ReadFile(filename)
- if err != nil {
- return err
- }
- var key struct {
- Email string `json:"client_email"`
- PrivateKey string `json:"private_key"`
- }
- if err := json.Unmarshal(b, &key); err != nil {
- return err
- }
- pk, err := internal.ParseKey([]byte(key.PrivateKey))
- if err != nil {
- return err
- }
- opts.Email = key.Email
- opts.PrivateKey = pk
- opts.AUD = uriGoogleToken
- return nil
+func (cs *computeSource) Token() (*oauth2.Token, error) {
+ acct := cs.account
+ if acct == "" {
+ acct = "default"
}
-}
-
-func makeComputeFetcher(opts *oauth2.Options, account string) func(*oauth2.Token) (*oauth2.Token, error) {
- return func(t *oauth2.Token) (*oauth2.Token, error) {
- u := "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/" + account + "/token"
- req, err := http.NewRequest("GET", u, nil)
- if err != nil {
- return nil, err
- }
- req.Header.Add("X-Google-Metadata-Request", "True")
- c := &http.Client{}
- if opts.Client != nil {
- c = opts.Client
- }
- resp, err := c.Do(req)
- if err != nil {
- return nil, err
- }
- defer resp.Body.Close()
- if resp.StatusCode < 200 || resp.StatusCode > 299 {
- return nil, fmt.Errorf("oauth2: can't retrieve a token from metadata server, status code: %d", resp.StatusCode)
- }
- var tokenResp metaTokenRespBody
- err = json.NewDecoder(resp.Body).Decode(&tokenResp)
- if err != nil {
- return nil, err
- }
- return &oauth2.Token{
- AccessToken: tokenResp.AccessToken,
- TokenType: tokenResp.TokenType,
- Expiry: time.Now().Add(tokenResp.ExpiresIn * time.Second),
- }, nil
+ u := "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/" + acct + "/token"
+ req, err := http.NewRequest("GET", u, nil)
+ if err != nil {
+ return nil, err
}
+ req.Header.Add("X-Google-Metadata-Request", "True")
+ resp, err := metaClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode < 200 || resp.StatusCode > 299 {
+ return nil, fmt.Errorf("oauth2: can't retrieve a token from metadata server, status code: %d", resp.StatusCode)
+ }
+ var tokenResp metaTokenRespBody
+ err = json.NewDecoder(resp.Body).Decode(&tokenResp)
+ if err != nil {
+ return nil, err
+ }
+ return &oauth2.Token{
+ AccessToken: tokenResp.AccessToken,
+ TokenType: tokenResp.TokenType,
+ Expiry: time.Now().Add(tokenResp.ExpiresIn * time.Second),
+ }, nil
}
diff --git a/google/source_appengine.go b/google/source_appengine.go
new file mode 100644
index 0000000..9b8aa97
--- /dev/null
+++ b/google/source_appengine.go
@@ -0,0 +1,68 @@
+// Copyright 2014 The oauth2 Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package google
+
+import (
+ "errors"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ "golang.org/x/oauth2"
+)
+
+var (
+ aeTokensMu sync.Mutex // guards aeTokens and appEngineTokenSource.key
+
+ // aeTokens helps the fetched tokens to be reused until their expiration.
+ aeTokens = make(map[string]*tokenLock) // key is '\0'-separated scopes
+)
+
+var errInvalidContext = errors.New("oauth2: a valid appengine.Context is required")
+
+type tokenLock struct {
+ mu sync.Mutex // guards t; held while updating t
+ t *oauth2.Token
+}
+
+type appEngineTokenSource struct {
+ ctx oauth2.Context
+ scopes []string
+ key string // guarded by package-level mutex, aeTokensMu
+
+ // fetcherFunc makes the actual RPC to fetch a new access token with an expiry time.
+ // Provider of this function is responsible to assert that the given context is valid.
+ fetcherFunc func(ctx oauth2.Context, scope ...string) (string, time.Time, error)
+}
+
+func (ts *appEngineTokenSource) Token() (*oauth2.Token, error) {
+ aeTokensMu.Lock()
+ if ts.key == "" {
+ sort.Sort(sort.StringSlice(ts.scopes))
+ ts.key = strings.Join(ts.scopes, string(0))
+ }
+ tok, ok := aeTokens[ts.key]
+ if !ok {
+ tok = &tokenLock{}
+ aeTokens[ts.key] = tok
+ }
+ aeTokensMu.Unlock()
+
+ tok.mu.Lock()
+ defer tok.mu.Unlock()
+ if tok.t != nil && !tok.t.Expired() {
+ return tok.t, nil
+ }
+ access, exp, err := ts.fetcherFunc(ts.ctx, ts.scopes...)
+ if err != nil {
+ return nil, err
+ }
+ tok.t = &oauth2.Token{
+ AccessToken: access,
+ Expiry: exp,
+ }
+ return tok.t, nil
+}
diff --git a/internal/oauth2.go b/internal/oauth2.go
index b91b662..47c8f14 100644
--- a/internal/oauth2.go
+++ b/internal/oauth2.go
@@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// Package internal contains support packages for oauth2 package.
package internal
import (
diff --git a/jwt.go b/jwt.go
index d861e93..eedbfc1 100644
--- a/jwt.go
+++ b/jwt.go
@@ -5,7 +5,6 @@
package oauth2
import (
- "crypto/rsa"
"encoding/json"
"fmt"
"io"
@@ -15,6 +14,7 @@
"strings"
"time"
+ "golang.org/x/oauth2/internal"
"golang.org/x/oauth2/jws"
)
@@ -38,7 +38,7 @@
//
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
//
- PrivateKey *rsa.PrivateKey
+ PrivateKey []byte
// Subject is the optional user to impersonate.
Subject string
@@ -76,8 +76,8 @@
func (c *JWTConfig) Client(ctx Context, initialToken *Token) *http.Client {
return &http.Client{
Transport: &Transport{
- Source: c.TokenSource(ctx, initialToken),
Base: contextTransport(ctx),
+ Source: c.TokenSource(ctx, initialToken),
},
}
}
@@ -90,6 +90,10 @@
}
func (js jwtSource) Token() (*Token, error) {
+ pk, err := internal.ParseKey(js.conf.PrivateKey)
+ if err != nil {
+ return nil, err
+ }
hc, err := contextClient(js.ctx)
if err != nil {
return nil, err
@@ -105,7 +109,7 @@
// to be compatible with legacy OAuth 2.0 providers.
claimSet.Prn = subject
}
- payload, err := jws.Encode(defaultHeader, claimSet, js.conf.PrivateKey)
+ payload, err := jws.Encode(defaultHeader, claimSet, pk)
if err != nil {
return nil, err
}
diff --git a/jwt_test.go b/jwt_test.go
index b51c702..2fe371b 100644
--- a/jwt_test.go
+++ b/jwt_test.go
@@ -48,18 +48,16 @@
}`))
}))
defer ts.Close()
- f, err := New(
- JWTClient("aaa@xxx.com", dummyPrivateKey),
- JWTEndpoint(ts.URL),
- )
- if err != nil {
- t.Error(err)
- }
- tr := f.NewTransport()
- c := http.Client{Transport: tr}
- c.Get(ts.URL)
- tok := tr.Token()
+ conf := &JWTConfig{
+ Email: "aaa@xxx.com",
+ PrivateKey: dummyPrivateKey,
+ TokenURL: ts.URL,
+ }
+ tok, err := conf.TokenSource(NoContext, nil).Token()
+ if err != nil {
+ t.Fatal(err)
+ }
if tok.Expired() {
t.Errorf("Token shouldn't be expired.")
}
@@ -81,19 +79,15 @@
w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
}))
defer ts.Close()
- f, err := New(
- JWTClient("aaa@xxx.com", dummyPrivateKey),
- JWTEndpoint(ts.URL),
- )
- if err != nil {
- t.Error(err)
+
+ conf := &JWTConfig{
+ Email: "aaa@xxx.com",
+ PrivateKey: dummyPrivateKey,
+ TokenURL: ts.URL,
}
- tr := f.NewTransport()
- c := http.Client{Transport: tr}
- c.Get(ts.URL)
- tok := tr.Token()
+ tok, err := conf.TokenSource(NoContext, nil).Token()
if err != nil {
- t.Errorf("Failed retrieving token: %s.", err)
+ t.Fatal(err)
}
if tok.AccessToken != "" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
@@ -113,19 +107,14 @@
w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
}))
defer ts.Close()
- f, err := New(
- JWTClient("aaa@xxx.com", dummyPrivateKey),
- JWTEndpoint(ts.URL),
- )
- if err != nil {
- t.Error(err)
+ conf := &JWTConfig{
+ Email: "aaa@xxx.com",
+ PrivateKey: dummyPrivateKey,
+ TokenURL: ts.URL,
}
- tr := f.NewTransport()
- c := http.Client{Transport: tr}
- c.Get(ts.URL)
- tok := tr.Token()
+ tok, err := conf.TokenSource(NoContext, nil).Token()
if err != nil {
- t.Errorf("Failed retrieving token: %s.", err)
+ t.Fatal(err)
}
if tok.AccessToken != "" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
diff --git a/oauth2.go b/oauth2.go
index 88121f3..753aa60 100644
--- a/oauth2.go
+++ b/oauth2.go
@@ -27,9 +27,14 @@
// Context can be an golang.org/x/net.Context, or an App Engine Context.
// In the future these will be unified.
-// If you don't care and aren't running on App Engine, you may use nil.
+// If you don't care and aren't running on App Engine, you may use NoContext.
type Context interface{}
+// NoContext is the default context. If you're not running this code
+// on App Engine or not using golang.org/x/net.Context to provide a custom
+// HTTP client, you should use NoContext.
+var NoContext Context = nil
+
// Config describes a typical 3-legged OAuth2 flow, with both the
// client application information and the server's URLs.
type Config struct {
@@ -272,8 +277,8 @@
func (c *Config) Client(ctx Context, t *Token) *http.Client {
return &http.Client{
Transport: &Transport{
- Source: c.TokenSource(ctx, t),
Base: contextTransport(ctx),
+ Source: c.TokenSource(ctx, t),
},
}
}
diff --git a/oauth2_test.go b/oauth2_test.go
index 6c21043..8159b86 100644
--- a/oauth2_test.go
+++ b/oauth2_test.go
@@ -10,6 +10,8 @@
"net/http"
"net/http/httptest"
"testing"
+
+ "golang.org/x/net/context"
)
type mockTransport struct {
@@ -33,31 +35,37 @@
// do nothing
}
-func newOpts(url string) *Options {
- opts, _ := New(
- Client("CLIENT_ID", "CLIENT_SECRET"),
- RedirectURL("REDIRECT_URL"),
- Scope("scope1", "scope2"),
- Endpoint(url+"/auth", url+"/token"),
- )
- return opts
+func newConf(url string) *Config {
+ return &Config{
+ ClientID: "CLIENT_ID",
+ ClientSecret: "CLIENT_SECRET",
+ RedirectURL: "REDIRECT_URL",
+ Scopes: []string{"scope1", "scope2"},
+ Endpoint: Endpoint{
+ AuthURL: url + "/auth",
+ TokenURL: url + "/token",
+ },
+ }
}
func TestAuthCodeURL(t *testing.T) {
- opts := newOpts("server")
- url := opts.AuthCodeURL("foo", "offline", "force")
+ conf := newConf("server")
+ url := conf.AuthCodeURL("foo", AccessTypeOffline, ApprovalForce)
if url != "server/auth?access_type=offline&approval_prompt=force&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=foo" {
t.Errorf("Auth code URL doesn't match the expected, found: %v", url)
}
}
func TestAuthCodeURL_Optional(t *testing.T) {
- opts, _ := New(
- Client("CLIENT_ID", ""),
- Endpoint("auth-url", "token-token"),
- )
- url := opts.AuthCodeURL("", "", "")
- if url != "auth-url?client_id=CLIENT_ID&response_type=code" {
+ conf := &Config{
+ ClientID: "CLIENT_ID",
+ Endpoint: Endpoint{
+ AuthURL: "/auth-url",
+ TokenURL: "/token-url",
+ },
+ }
+ url := conf.AuthCodeURL("")
+ if url != "/auth-url?client_id=CLIENT_ID&response_type=code" {
t.Fatalf("Auth code URL doesn't match the expected, found: %v", url)
}
}
@@ -86,12 +94,11 @@
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
}))
defer ts.Close()
- opts := newOpts(ts.URL)
- tr, err := opts.NewTransportFromCode("exchange-code")
+ conf := newConf(ts.URL)
+ tok, err := conf.Exchange(NoContext, "exchange-code")
if err != nil {
t.Error(err)
}
- tok := tr.Token()
if tok.Expired() {
t.Errorf("Token shouldn't be expired.")
}
@@ -131,15 +138,11 @@
w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer", "expires_in": 86400}`))
}))
defer ts.Close()
- opts := newOpts(ts.URL)
- tr, err := opts.NewTransportFromCode("exchange-code")
+ conf := newConf(ts.URL)
+ tok, err := conf.Exchange(NoContext, "exchange-code")
if err != nil {
t.Error(err)
}
- tok := tr.Token()
- if tok.Expiry.IsZero() {
- t.Errorf("Token expiry should not be zero.")
- }
if tok.Expired() {
t.Errorf("Token shouldn't be expired.")
}
@@ -161,12 +164,11 @@
w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
}))
defer ts.Close()
- opts := newOpts(ts.URL)
- tr, err := opts.NewTransportFromCode("exchange-code")
+ conf := newConf(ts.URL)
+ tok, err := conf.Exchange(NoContext, "code")
if err != nil {
- t.Error(err)
+ t.Fatal(err)
}
- tok := tr.Token()
if tok.AccessToken != "" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
}
@@ -178,12 +180,11 @@
w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
}))
defer ts.Close()
- opts := newOpts(ts.URL)
- tr, err := opts.NewTransportFromCode("exchange-code")
+ conf := newConf(ts.URL)
+ tok, err := conf.Exchange(NoContext, "exchange-code")
if err != nil {
t.Error(err)
}
- tok := tr.Token()
if tok.AccessToken != "" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
}
@@ -200,15 +201,16 @@
},
}
c := &http.Client{Transport: tr}
- opts, err := New(
- Client("CLIENT_ID", ""),
- Endpoint("https://accounts.google.com/auth", "https://accounts.google.com/token"),
- HTTPClient(c),
- )
- if err != nil {
- t.Error(err)
+ conf := &Config{
+ ClientID: "CLIENT_ID",
+ Endpoint: Endpoint{
+ AuthURL: "https://accounts.google.com/auth",
+ TokenURL: "https://accounts.google.com/token",
+ },
}
- opts.NewTransportFromCode("code")
+
+ ctx := context.WithValue(context.Background(), HTTPClient, c)
+ conf.Exchange(ctx, "code")
}
func TestTokenRefreshRequest(t *testing.T) {
@@ -229,10 +231,8 @@
}
}))
defer ts.Close()
- opts := newOpts(ts.URL)
- tr := opts.NewTransport()
- tr.token = &Token{RefreshToken: "REFRESH_TOKEN"}
- c := http.Client{Transport: tr}
+ conf := newConf(ts.URL)
+ c := conf.Client(NoContext, &Token{RefreshToken: "REFRESH_TOKEN"})
c.Get(ts.URL + "/somethingelse")
}
@@ -254,28 +254,10 @@
}
}))
defer ts.Close()
- opts := newOpts(ts.URL)
- tr := opts.NewTransport()
- c := http.Client{Transport: tr}
+ conf := newConf(ts.URL)
+ c := conf.Client(NoContext, nil)
_, err := c.Get(ts.URL + "/somethingelse")
if err == nil {
t.Errorf("Fetch should return an error if no refresh token is set")
}
}
-
-func TestCacheNoToken(t *testing.T) {
- opts, err := New(
- Client("CLIENT_ID", "CLIENT_SECRET"),
- Endpoint("/auth", "/token"),
- )
- if err != nil {
- t.Error(err)
- }
- tr, err := opts.NewTransportFromTokenStore(&mockCache{token: nil, readErr: nil})
- if err != nil {
- t.Errorf("No error expected, %v is found", err)
- }
- if tr != nil {
- t.Errorf("No transport should have been initiated, tr is found to be %v", tr)
- }
-}
diff --git a/transport_test.go b/transport_test.go
index 5fbccf6..b3414e3 100644
--- a/transport_test.go
+++ b/transport_test.go
@@ -7,45 +7,29 @@
"time"
)
-type mockTokenFetcher struct{ token *Token }
+type tokenSource struct{ token *Token }
-func (f *mockTokenFetcher) Fn() func(*Token) (*Token, error) {
- return func(*Token) (*Token, error) {
- return f.token, nil
- }
+func (t *tokenSource) Token() (*Token, error) {
+ return t.token, nil
}
-func TestInitialTokenRead(t *testing.T) {
- tr := newTransport(http.DefaultTransport, nil, &Token{AccessToken: "abc"})
- server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
- if r.Header.Get("Authorization") != "Bearer abc" {
- t.Errorf("Transport doesn't set the Authorization header from the initial token")
- }
- })
- defer server.Close()
- client := http.Client{Transport: tr}
- client.Get(server.URL)
-}
-
-func TestTokenFetch(t *testing.T) {
- fetcher := &mockTokenFetcher{
+func TestTransportTokenSource(t *testing.T) {
+ ts := &tokenSource{
token: &Token{
AccessToken: "abc",
},
}
- tr := newTransport(http.DefaultTransport, &Options{TokenFetcherFunc: fetcher.Fn()}, nil)
+ tr := &Transport{
+ Source: ts,
+ }
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "Bearer abc" {
t.Errorf("Transport doesn't set the Authorization header from the fetched token")
}
})
defer server.Close()
-
client := http.Client{Transport: tr}
client.Get(server.URL)
- if tr.Token().AccessToken != "abc" {
- t.Errorf("New token is not set, found %v", tr.Token())
- }
}
func TestExpiredWithNoAccessToken(t *testing.T) {