oauth2: support device flow

Device Authorization Grant following RFC 8628 https://datatracker.ietf.org/doc/html/rfc8628

Tested with GitHub

Fixes #418

Fixes golang/go#58126

Co-authored-by: cmP <centimitr@gmail.com>

Change-Id: Id588867110c6a5289bf1026da5d7ead88f9c7d14
GitHub-Last-Rev: 9a126d7b534532c7d18fb8d6796ad673b95fc09f
GitHub-Pull-Request: golang/oauth2#609
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/450155
Commit-Queue: Bryan Mills <bcmills@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Than McIntosh <thanm@google.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Run-TryBot: Bryan Mills <bcmills@google.com>
diff --git a/deviceauth.go b/deviceauth.go
new file mode 100644
index 0000000..f3ea99e
--- /dev/null
+++ b/deviceauth.go
@@ -0,0 +1,188 @@
+package oauth2
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"strings"
+	"time"
+
+	"golang.org/x/oauth2/internal"
+)
+
+// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
+const (
+	errAuthorizationPending = "authorization_pending"
+	errSlowDown             = "slow_down"
+	errAccessDenied         = "access_denied"
+	errExpiredToken         = "expired_token"
+)
+
+// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
+// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
+type DeviceAuthResponse struct {
+	// DeviceCode
+	DeviceCode string `json:"device_code"`
+	// UserCode is the code the user should enter at the verification uri
+	UserCode string `json:"user_code"`
+	// VerificationURI is where user should enter the user code
+	VerificationURI string `json:"verification_uri"`
+	// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
+	VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
+	// Expiry is when the device code and user code expire
+	Expiry time.Time `json:"expires_in,omitempty"`
+	// Interval is the duration in seconds that Poll should wait between requests
+	Interval int64 `json:"interval,omitempty"`
+}
+
+func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
+	type Alias DeviceAuthResponse
+	var expiresIn int64
+	if !d.Expiry.IsZero() {
+		expiresIn = int64(time.Until(d.Expiry).Seconds())
+	}
+	return json.Marshal(&struct {
+		ExpiresIn int64 `json:"expires_in,omitempty"`
+		*Alias
+	}{
+		ExpiresIn: expiresIn,
+		Alias:     (*Alias)(&d),
+	})
+
+}
+
+func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
+	type Alias DeviceAuthResponse
+	aux := &struct {
+		ExpiresIn int64 `json:"expires_in"`
+		*Alias
+	}{
+		Alias: (*Alias)(c),
+	}
+	if err := json.Unmarshal(data, &aux); err != nil {
+		return err
+	}
+	if aux.ExpiresIn != 0 {
+		c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
+	}
+	return nil
+}
+
+// DeviceAuth returns a device auth struct which contains a device code
+// and authorization information provided for users to enter on another device.
+func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
+	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
+	v := url.Values{
+		"client_id": {c.ClientID},
+	}
+	if len(c.Scopes) > 0 {
+		v.Set("scope", strings.Join(c.Scopes, " "))
+	}
+	for _, opt := range opts {
+		opt.setValue(v)
+	}
+	return retrieveDeviceAuth(ctx, c, v)
+}
+
+func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
+	req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	req.Header.Set("Accept", "application/json")
+
+	t := time.Now()
+	r, err := internal.ContextClient(ctx).Do(req)
+	if err != nil {
+		return nil, err
+	}
+
+	body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
+	if err != nil {
+		return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
+	}
+	if code := r.StatusCode; code < 200 || code > 299 {
+		return nil, &RetrieveError{
+			Response: r,
+			Body:     body,
+		}
+	}
+
+	da := &DeviceAuthResponse{}
+	err = json.Unmarshal(body, &da)
+	if err != nil {
+		return nil, fmt.Errorf("unmarshal %s", err)
+	}
+
+	if !da.Expiry.IsZero() {
+		// Make a small adjustment to account for time taken by the request
+		da.Expiry = da.Expiry.Add(-time.Since(t))
+	}
+
+	return da, nil
+}
+
+// DeviceAccessToken polls the server to exchange a device code for a token.
+func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
+	if !da.Expiry.IsZero() {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithDeadline(ctx, da.Expiry)
+		defer cancel()
+	}
+
+	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
+	v := url.Values{
+		"client_id":   {c.ClientID},
+		"grant_type":  {"urn:ietf:params:oauth:grant-type:device_code"},
+		"device_code": {da.DeviceCode},
+	}
+	if len(c.Scopes) > 0 {
+		v.Set("scope", strings.Join(c.Scopes, " "))
+	}
+	for _, opt := range opts {
+		opt.setValue(v)
+	}
+
+	// "If no value is provided, clients MUST use 5 as the default."
+	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
+	interval := da.Interval
+	if interval == 0 {
+		interval = 5
+	}
+
+	ticker := time.NewTicker(time.Duration(interval) * time.Second)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-ticker.C:
+			tok, err := retrieveToken(ctx, c, v)
+			if err == nil {
+				return tok, nil
+			}
+
+			e, ok := err.(*RetrieveError)
+			if !ok {
+				return nil, err
+			}
+			switch e.ErrorCode {
+			case errSlowDown:
+				// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
+				// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
+				interval += 5
+				ticker.Reset(time.Duration(interval) * time.Second)
+			case errAuthorizationPending:
+				// Do nothing.
+			case errAccessDenied, errExpiredToken:
+				fallthrough
+			default:
+				return tok, err
+			}
+		}
+	}
+}
diff --git a/deviceauth_test.go b/deviceauth_test.go
new file mode 100644
index 0000000..3b99620
--- /dev/null
+++ b/deviceauth_test.go
@@ -0,0 +1,97 @@
+package oauth2
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
+)
+
+func TestDeviceAuthResponseMarshalJson(t *testing.T) {
+	tests := []struct {
+		name     string
+		response DeviceAuthResponse
+		want     string
+	}{
+		{
+			name:     "empty",
+			response: DeviceAuthResponse{},
+			want:     `{"device_code":"","user_code":"","verification_uri":""}`,
+		},
+		{
+			name: "soon",
+			response: DeviceAuthResponse{
+				Expiry: time.Now().Add(100*time.Second + 999*time.Millisecond),
+			},
+			want: `{"expires_in":100,"device_code":"","user_code":"","verification_uri":""}`,
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			begin := time.Now()
+			gotBytes, err := json.Marshal(tc.response)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if strings.Contains(tc.want, "expires_in") && time.Since(begin) > 999*time.Millisecond {
+				t.Skip("test ran too slowly to compare `expires_in`")
+			}
+			got := string(gotBytes)
+			if got != tc.want {
+				t.Errorf("want=%s, got=%s", tc.want, got)
+			}
+		})
+	}
+}
+
+func TestDeviceAuthResponseUnmarshalJson(t *testing.T) {
+	tests := []struct {
+		name string
+		data string
+		want DeviceAuthResponse
+	}{
+		{
+			name: "empty",
+			data: `{}`,
+			want: DeviceAuthResponse{},
+		},
+		{
+			name: "soon",
+			data: `{"expires_in":100}`,
+			want: DeviceAuthResponse{Expiry: time.Now().UTC().Add(100 * time.Second)},
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			begin := time.Now()
+			got := DeviceAuthResponse{}
+			err := json.Unmarshal([]byte(tc.data), &got)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second+time.Since(begin))) {
+				t.Errorf("want=%#v, got=%#v", tc.want, got)
+			}
+		})
+	}
+}
+
+func ExampleConfig_DeviceAuth() {
+	var config Config
+	ctx := context.Background()
+	response, err := config.DeviceAuth(ctx)
+	if err != nil {
+		panic(err)
+	}
+	fmt.Printf("please enter code %s at %s\n", response.UserCode, response.VerificationURI)
+	token, err := config.DeviceAccessToken(ctx, response)
+	if err != nil {
+		panic(err)
+	}
+	fmt.Println(token)
+}
diff --git a/endpoints/endpoints.go b/endpoints/endpoints.go
index 7cc37c8..7fb3314 100644
--- a/endpoints/endpoints.go
+++ b/endpoints/endpoints.go
@@ -55,8 +55,9 @@
 
 // GitHub is the endpoint for Github.
 var GitHub = oauth2.Endpoint{
-	AuthURL:  "https://github.com/login/oauth/authorize",
-	TokenURL: "https://github.com/login/oauth/access_token",
+	AuthURL:       "https://github.com/login/oauth/authorize",
+	TokenURL:      "https://github.com/login/oauth/access_token",
+	DeviceAuthURL: "https://github.com/login/device/code",
 }
 
 // GitLab is the endpoint for GitLab.
@@ -69,6 +70,7 @@
 var Google = oauth2.Endpoint{
 	AuthURL:  "https://accounts.google.com/o/oauth2/auth",
 	TokenURL: "https://oauth2.googleapis.com/token",
+	DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
 }
 
 // Heroku is the endpoint for Heroku.
diff --git a/github/github.go b/github/github.go
index f297801..725a7c4 100644
--- a/github/github.go
+++ b/github/github.go
@@ -6,11 +6,8 @@
 package github // import "golang.org/x/oauth2/github"
 
 import (
-	"golang.org/x/oauth2"
+	"golang.org/x/oauth2/endpoints"
 )
 
 // Endpoint is Github's OAuth 2.0 endpoint.
-var Endpoint = oauth2.Endpoint{
-	AuthURL:  "https://github.com/login/oauth/authorize",
-	TokenURL: "https://github.com/login/oauth/access_token",
-}
+var Endpoint = endpoints.GitHub
diff --git a/google/google.go b/google/google.go
index cc12238..846683e 100644
--- a/google/google.go
+++ b/google/google.go
@@ -23,6 +23,7 @@
 var Endpoint = oauth2.Endpoint{
 	AuthURL:   "https://accounts.google.com/o/oauth2/auth",
 	TokenURL:  "https://oauth2.googleapis.com/token",
+	DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
 	AuthStyle: oauth2.AuthStyleInParams,
 }
 
diff --git a/oauth2.go b/oauth2.go
index cc7c98c..86a70e7 100644
--- a/oauth2.go
+++ b/oauth2.go
@@ -75,8 +75,9 @@
 // Endpoint represents an OAuth 2.0 provider's authorization and token
 // endpoint URLs.
 type Endpoint struct {
-	AuthURL  string
-	TokenURL string
+	AuthURL       string
+	DeviceAuthURL string
+	TokenURL      string
 
 	// AuthStyle optionally specifies how the endpoint wants the
 	// client ID & client secret sent. The zero value means to