Add Cacher interface.
diff --git a/oauth2.go b/oauth2.go
index ce166b1..38f4a19 100644
--- a/oauth2.go
+++ b/oauth2.go
@@ -22,6 +22,18 @@
"strings"
)
+// Cacher implementations read and write OAuth 2.0 tokens from a cache.
+type Cacher interface {
+ // Read reads the token from the cache.
+ // If the read is successful, it should return the token and a nil error.
+ // The returned tokens may be expired tokens.
+ // If there is no token in the cache, it should return a nil token and a nil error.
+ // It should return a non-nil error when an unrecoverable failure occurs.
+ Read() (*Token, error)
+ // Write writes the token to the cache.
+ Write(*Token)
+}
+
// Option represents a function that applies some state to
// an Options object.
type Option func(*Options) error
@@ -91,6 +103,16 @@
}
}
+// Cache requires a Cacher implementation. It will initially read
+// the token if the transport is initialized with NewTransportFromCache
+// and will write the refreshed tokens back to the cache.
+func Cache(c Cacher) Option {
+ return func(o *Options) error {
+ o.Cache = c
+ return nil
+ }
+}
+
type Flow struct {
opts Options
}
@@ -109,11 +131,11 @@
case f.opts.TokenFetcherFunc != nil:
return f, nil
case f.opts.AUD != nil:
- // TODO(jbd): Assert required JWT params.
+ // TODO(jbd): Assert the required JWT params.
f.opts.TokenFetcherFunc = makeTwoLeggedFetcher(&f.opts)
return f, nil
case f.opts.AuthURL != nil && f.opts.TokenURL != nil:
- // TODO(jbd): Assert required OAuth2 params.
+ // TODO(jbd): Assert the required OAuth2 params.
f.opts.TokenFetcherFunc = makeThreeLeggedFetcher(&f.opts)
return f, nil
default:
@@ -175,6 +197,23 @@
})
}
+// NewTransportFromCache reads the token from the cache and returns
+// a Transport that is authorized and the authenticated
+// by the returned token.
+func (f *Flow) NewTransportFromCache() (*Transport, error) {
+ if f.opts.Cache == nil {
+ return nil, errors.New("oauth2: no cache is set")
+ }
+ tok, err := f.opts.Cache.Read()
+ if err != nil {
+ return nil, err
+ }
+ if tok == nil {
+ return nil, nil
+ }
+ return f.newTransportFromToken(tok), nil
+}
+
// NewTransportFromCode exchanges the code to retrieve a new access token
// and returns an authorized and authenticated Transport.
func (f *Flow) NewTransportFromCode(code string) (*Transport, error) {
@@ -182,22 +221,22 @@
if err != nil {
return nil, err
}
- return f.NewTransportFromToken(token), nil
-}
-
-// NewTransportFromToken returns a new Transport that is authorized
-// and authenticated with the provided token.
-func (f *Flow) NewTransportFromToken(t *Token) *Transport {
- tr := f.opts.Transport
- if tr == nil {
- tr = http.DefaultTransport
- }
- return newTransport(tr, f.opts.TokenFetcherFunc, t)
+ return f.newTransportFromToken(token), nil
}
// NewTransport returns a Transport.
func (f *Flow) NewTransport() *Transport {
- return f.NewTransportFromToken(nil)
+ return f.newTransportFromToken(nil)
+}
+
+// newTransportFromToken returns a new Transport that is authorized
+// and authenticated with the provided token.
+func (f *Flow) newTransportFromToken(t *Token) *Transport {
+ tr := f.opts.Transport
+ if tr == nil {
+ tr = http.DefaultTransport
+ }
+ return newTransport(tr, &f.opts, t)
}
func makeThreeLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
@@ -255,6 +294,8 @@
// AUD represents the token endpoint required to complete the 2-legged JWT flow.
AUD *url.URL
+ Cache Cacher
+
TokenFetcherFunc func(t *Token) (*Token, error)
Transport http.RoundTripper
diff --git a/oauth2_test.go b/oauth2_test.go
index 92419c6..d356f62 100644
--- a/oauth2_test.go
+++ b/oauth2_test.go
@@ -20,6 +20,19 @@
return t.rt(req)
}
+type mockCache struct {
+ token *Token
+ readErr error
+}
+
+func (c *mockCache) Read() (*Token, error) {
+ return c.token, c.readErr
+}
+
+func (c *mockCache) Write(*Token) {
+ // do nothing
+}
+
func newTestFlow(url string) *Flow {
f, _ := New(
Client("CLIENT_ID", "CLIENT_SECRET"),
@@ -211,7 +224,8 @@
}))
defer ts.Close()
f := newTestFlow(ts.URL)
- tr := f.NewTransportFromToken(&Token{RefreshToken: "REFRESH_TOKEN"})
+ tr := f.NewTransport()
+ tr.SetToken(&Token{RefreshToken: "REFRESH_TOKEN"})
c := http.Client{Transport: tr}
c.Get(ts.URL + "/somethingelse")
}
@@ -235,10 +249,25 @@
}))
defer ts.Close()
f := newTestFlow(ts.URL)
- tr := f.NewTransportFromToken(&Token{})
+ tr := f.NewTransport()
c := http.Client{Transport: tr}
_, err := c.Get(ts.URL + "/somethingelse")
if err == nil {
t.Errorf("Fetch should return an error if no refresh token is set")
}
}
+
+func TestCacheNoToken(t *testing.T) {
+ f, _ := New(
+ Client("CLIENT_ID", "CLIENT_SECRET"),
+ Endpoint("/auth", "/token"),
+ Cache(&mockCache{token: nil, readErr: nil}),
+ )
+ tr, err := f.NewTransportFromCache()
+ if err != nil {
+ t.Errorf("No error expected, %v is found", err)
+ }
+ if tr != nil {
+ t.Errorf("No transport should have been initiated, tr is found to be %v", tr)
+ }
+}
diff --git a/transport.go b/transport.go
index e1a35b0..9df11d8 100644
--- a/transport.go
+++ b/transport.go
@@ -66,8 +66,8 @@
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests.
type Transport struct {
- fetcher func(t *Token) (*Token, error)
- base http.RoundTripper
+ opts *Options
+ base http.RoundTripper
mu sync.RWMutex
token *Token
@@ -76,8 +76,12 @@
// NewTransport creates a new Transport that uses the provided
// token fetcher as token retrieving strategy. It authenticates
// the requests and delegates origTransport to make the actual requests.
-func newTransport(base http.RoundTripper, fn func(t *Token) (*Token, error), token *Token) *Transport {
- return &Transport{base: base, fetcher: fn, token: token}
+func newTransport(base http.RoundTripper, opts *Options, token *Token) *Transport {
+ return &Transport{
+ base: base,
+ opts: opts,
+ token: token,
+ }
}
// RoundTrip authorizes and authenticates the request with an
@@ -94,6 +98,9 @@
return nil, err
}
token = t.Token()
+ if t.opts.Cache != nil {
+ t.opts.Cache.Write(token)
+ }
}
// To set the Authorization header, we must make a copy of the Request
@@ -129,7 +136,7 @@
func (t *Transport) RefreshToken() error {
t.mu.Lock()
defer t.mu.Unlock()
- token, err := t.fetcher(t.token)
+ token, err := t.opts.TokenFetcherFunc(t.token)
if err != nil {
return err
}
diff --git a/transport_test.go b/transport_test.go
index f7cdbc4..5fbccf6 100644
--- a/transport_test.go
+++ b/transport_test.go
@@ -15,10 +15,6 @@
}
}
-func (f *mockTokenFetcher) FetchToken(existing *Token) (*Token, error) {
- return f.token, nil
-}
-
func TestInitialTokenRead(t *testing.T) {
tr := newTransport(http.DefaultTransport, nil, &Token{AccessToken: "abc"})
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
@@ -37,7 +33,7 @@
AccessToken: "abc",
},
}
- tr := newTransport(http.DefaultTransport, fetcher.Fn(), nil)
+ tr := newTransport(http.DefaultTransport, &Options{TokenFetcherFunc: fetcher.Fn()}, nil)
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "Bearer abc" {
t.Errorf("Transport doesn't set the Authorization header from the fetched token")