oauth2: support device flow
Device Authorization Grant following RFC 8628 https://datatracker.ietf.org/doc/html/rfc8628
Tested with GitHub
Fixes #418
Fixes golang/go#58126
Co-authored-by: cmP <centimitr@gmail.com>
Change-Id: Id588867110c6a5289bf1026da5d7ead88f9c7d14
GitHub-Last-Rev: 9a126d7b534532c7d18fb8d6796ad673b95fc09f
GitHub-Pull-Request: golang/oauth2#609
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/450155
Commit-Queue: Bryan Mills <bcmills@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Than McIntosh <thanm@google.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Run-TryBot: Bryan Mills <bcmills@google.com>
diff --git a/deviceauth.go b/deviceauth.go
new file mode 100644
index 0000000..f3ea99e
--- /dev/null
+++ b/deviceauth.go
@@ -0,0 +1,188 @@
+package oauth2
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "golang.org/x/oauth2/internal"
+)
+
+// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
+const (
+ errAuthorizationPending = "authorization_pending"
+ errSlowDown = "slow_down"
+ errAccessDenied = "access_denied"
+ errExpiredToken = "expired_token"
+)
+
+// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
+// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
+type DeviceAuthResponse struct {
+ // DeviceCode
+ DeviceCode string `json:"device_code"`
+ // UserCode is the code the user should enter at the verification uri
+ UserCode string `json:"user_code"`
+ // VerificationURI is where user should enter the user code
+ VerificationURI string `json:"verification_uri"`
+ // VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
+ VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
+ // Expiry is when the device code and user code expire
+ Expiry time.Time `json:"expires_in,omitempty"`
+ // Interval is the duration in seconds that Poll should wait between requests
+ Interval int64 `json:"interval,omitempty"`
+}
+
+func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
+ type Alias DeviceAuthResponse
+ var expiresIn int64
+ if !d.Expiry.IsZero() {
+ expiresIn = int64(time.Until(d.Expiry).Seconds())
+ }
+ return json.Marshal(&struct {
+ ExpiresIn int64 `json:"expires_in,omitempty"`
+ *Alias
+ }{
+ ExpiresIn: expiresIn,
+ Alias: (*Alias)(&d),
+ })
+
+}
+
+func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
+ type Alias DeviceAuthResponse
+ aux := &struct {
+ ExpiresIn int64 `json:"expires_in"`
+ *Alias
+ }{
+ Alias: (*Alias)(c),
+ }
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+ if aux.ExpiresIn != 0 {
+ c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
+ }
+ return nil
+}
+
+// DeviceAuth returns a device auth struct which contains a device code
+// and authorization information provided for users to enter on another device.
+func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
+ // https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
+ v := url.Values{
+ "client_id": {c.ClientID},
+ }
+ if len(c.Scopes) > 0 {
+ v.Set("scope", strings.Join(c.Scopes, " "))
+ }
+ for _, opt := range opts {
+ opt.setValue(v)
+ }
+ return retrieveDeviceAuth(ctx, c, v)
+}
+
+func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
+ req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Set("Accept", "application/json")
+
+ t := time.Now()
+ r, err := internal.ContextClient(ctx).Do(req)
+ if err != nil {
+ return nil, err
+ }
+
+ body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
+ if err != nil {
+ return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
+ }
+ if code := r.StatusCode; code < 200 || code > 299 {
+ return nil, &RetrieveError{
+ Response: r,
+ Body: body,
+ }
+ }
+
+ da := &DeviceAuthResponse{}
+ err = json.Unmarshal(body, &da)
+ if err != nil {
+ return nil, fmt.Errorf("unmarshal %s", err)
+ }
+
+ if !da.Expiry.IsZero() {
+ // Make a small adjustment to account for time taken by the request
+ da.Expiry = da.Expiry.Add(-time.Since(t))
+ }
+
+ return da, nil
+}
+
+// DeviceAccessToken polls the server to exchange a device code for a token.
+func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
+ if !da.Expiry.IsZero() {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithDeadline(ctx, da.Expiry)
+ defer cancel()
+ }
+
+ // https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
+ v := url.Values{
+ "client_id": {c.ClientID},
+ "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
+ "device_code": {da.DeviceCode},
+ }
+ if len(c.Scopes) > 0 {
+ v.Set("scope", strings.Join(c.Scopes, " "))
+ }
+ for _, opt := range opts {
+ opt.setValue(v)
+ }
+
+ // "If no value is provided, clients MUST use 5 as the default."
+ // https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
+ interval := da.Interval
+ if interval == 0 {
+ interval = 5
+ }
+
+ ticker := time.NewTicker(time.Duration(interval) * time.Second)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-ticker.C:
+ tok, err := retrieveToken(ctx, c, v)
+ if err == nil {
+ return tok, nil
+ }
+
+ e, ok := err.(*RetrieveError)
+ if !ok {
+ return nil, err
+ }
+ switch e.ErrorCode {
+ case errSlowDown:
+ // https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
+ // "the interval MUST be increased by 5 seconds for this and all subsequent requests"
+ interval += 5
+ ticker.Reset(time.Duration(interval) * time.Second)
+ case errAuthorizationPending:
+ // Do nothing.
+ case errAccessDenied, errExpiredToken:
+ fallthrough
+ default:
+ return tok, err
+ }
+ }
+ }
+}
diff --git a/deviceauth_test.go b/deviceauth_test.go
new file mode 100644
index 0000000..3b99620
--- /dev/null
+++ b/deviceauth_test.go
@@ -0,0 +1,97 @@
+package oauth2
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+)
+
+func TestDeviceAuthResponseMarshalJson(t *testing.T) {
+ tests := []struct {
+ name string
+ response DeviceAuthResponse
+ want string
+ }{
+ {
+ name: "empty",
+ response: DeviceAuthResponse{},
+ want: `{"device_code":"","user_code":"","verification_uri":""}`,
+ },
+ {
+ name: "soon",
+ response: DeviceAuthResponse{
+ Expiry: time.Now().Add(100*time.Second + 999*time.Millisecond),
+ },
+ want: `{"expires_in":100,"device_code":"","user_code":"","verification_uri":""}`,
+ },
+ }
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ begin := time.Now()
+ gotBytes, err := json.Marshal(tc.response)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if strings.Contains(tc.want, "expires_in") && time.Since(begin) > 999*time.Millisecond {
+ t.Skip("test ran too slowly to compare `expires_in`")
+ }
+ got := string(gotBytes)
+ if got != tc.want {
+ t.Errorf("want=%s, got=%s", tc.want, got)
+ }
+ })
+ }
+}
+
+func TestDeviceAuthResponseUnmarshalJson(t *testing.T) {
+ tests := []struct {
+ name string
+ data string
+ want DeviceAuthResponse
+ }{
+ {
+ name: "empty",
+ data: `{}`,
+ want: DeviceAuthResponse{},
+ },
+ {
+ name: "soon",
+ data: `{"expires_in":100}`,
+ want: DeviceAuthResponse{Expiry: time.Now().UTC().Add(100 * time.Second)},
+ },
+ }
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ begin := time.Now()
+ got := DeviceAuthResponse{}
+ err := json.Unmarshal([]byte(tc.data), &got)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second+time.Since(begin))) {
+ t.Errorf("want=%#v, got=%#v", tc.want, got)
+ }
+ })
+ }
+}
+
+func ExampleConfig_DeviceAuth() {
+ var config Config
+ ctx := context.Background()
+ response, err := config.DeviceAuth(ctx)
+ if err != nil {
+ panic(err)
+ }
+ fmt.Printf("please enter code %s at %s\n", response.UserCode, response.VerificationURI)
+ token, err := config.DeviceAccessToken(ctx, response)
+ if err != nil {
+ panic(err)
+ }
+ fmt.Println(token)
+}
diff --git a/endpoints/endpoints.go b/endpoints/endpoints.go
index 7cc37c8..7fb3314 100644
--- a/endpoints/endpoints.go
+++ b/endpoints/endpoints.go
@@ -55,8 +55,9 @@
// GitHub is the endpoint for Github.
var GitHub = oauth2.Endpoint{
- AuthURL: "https://github.com/login/oauth/authorize",
- TokenURL: "https://github.com/login/oauth/access_token",
+ AuthURL: "https://github.com/login/oauth/authorize",
+ TokenURL: "https://github.com/login/oauth/access_token",
+ DeviceAuthURL: "https://github.com/login/device/code",
}
// GitLab is the endpoint for GitLab.
@@ -69,6 +70,7 @@
var Google = oauth2.Endpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
+ DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
}
// Heroku is the endpoint for Heroku.
diff --git a/github/github.go b/github/github.go
index f297801..725a7c4 100644
--- a/github/github.go
+++ b/github/github.go
@@ -6,11 +6,8 @@
package github // import "golang.org/x/oauth2/github"
import (
- "golang.org/x/oauth2"
+ "golang.org/x/oauth2/endpoints"
)
// Endpoint is Github's OAuth 2.0 endpoint.
-var Endpoint = oauth2.Endpoint{
- AuthURL: "https://github.com/login/oauth/authorize",
- TokenURL: "https://github.com/login/oauth/access_token",
-}
+var Endpoint = endpoints.GitHub
diff --git a/google/google.go b/google/google.go
index cc12238..846683e 100644
--- a/google/google.go
+++ b/google/google.go
@@ -23,6 +23,7 @@
var Endpoint = oauth2.Endpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
+ DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
AuthStyle: oauth2.AuthStyleInParams,
}
diff --git a/oauth2.go b/oauth2.go
index cc7c98c..86a70e7 100644
--- a/oauth2.go
+++ b/oauth2.go
@@ -75,8 +75,9 @@
// Endpoint represents an OAuth 2.0 provider's authorization and token
// endpoint URLs.
type Endpoint struct {
- AuthURL string
- TokenURL string
+ AuthURL string
+ DeviceAuthURL string
+ TokenURL string
// AuthStyle optionally specifies how the endpoint wants the
// client ID & client secret sent. The zero value means to