oauth2: use a JSON struct types instead of empty interface maps

Change-Id: Ifd66ea35c15dbd14acca0c945b533ec755de12e4
Reviewed-on: https://go-review.googlesource.com/1872
Reviewed-by: Burcu Dogan <jbd@google.com>
diff --git a/jwt.go b/jwt.go
index 7c02e3d..d7ec0d3 100644
--- a/jwt.go
+++ b/jwt.go
@@ -123,25 +123,33 @@
 	if c := resp.StatusCode; c < 200 || c > 299 {
 		return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body)
 	}
-	b := make(map[string]interface{})
-	if err := json.Unmarshal(body, &b); err != nil {
+	// tokenRes is the JSON response body.
+	var tokenRes struct {
+		AccessToken string `json:"access_token"`
+		TokenType   string `json:"token_type"`
+		IDToken     string `json:"id_token"`
+		ExpiresIn   int64  `json:"expires_in"` // relative seconds from now
+	}
+	if err := json.Unmarshal(body, &tokenRes); err != nil {
 		return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
 	}
-	token := &Token{}
-	token.AccessToken, _ = b["access_token"].(string)
-	token.TokenType, _ = b["token_type"].(string)
-	token.raw = b
-	if e, ok := b["expires_in"].(float64); ok {
-		token.Expiry = time.Now().Add(time.Duration(e) * time.Second)
+	token := &Token{
+		AccessToken: tokenRes.AccessToken,
+		TokenType:   tokenRes.TokenType,
+		raw:         make(map[string]interface{}),
 	}
-	if idtoken, ok := b["id_token"].(string); ok {
+	json.Unmarshal(body, &token.raw) // no error checks for optional fields
+
+	if secs := tokenRes.ExpiresIn; secs > 0 {
+		token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
+	}
+	if v := tokenRes.IDToken; v != "" {
 		// decode returned id token to get expiry
-		claimSet, err := jws.Decode(idtoken)
+		claimSet, err := jws.Decode(v)
 		if err != nil {
-			return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
+			return nil, fmt.Errorf("oauth2: error decoding JWT token: %v", err)
 		}
 		token.Expiry = time.Unix(claimSet.Exp, 0)
-		return token, nil
 	}
 	return token, nil
 }
diff --git a/jwt_test.go b/jwt_test.go
index 3cc5671..8c2e62e 100644
--- a/jwt_test.go
+++ b/jwt_test.go
@@ -117,10 +117,10 @@
 		TokenURL:   ts.URL,
 	}
 	tok, err := conf.TokenSource(NoContext, nil).Token()
-	if err != nil {
-		t.Fatal(err)
-	}
-	if tok.AccessToken != "" {
-		t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
+	if err == nil {
+		t.Error("got a token; expected error")
+		if tok.AccessToken != "" {
+			t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
+		}
 	}
 }
diff --git a/oauth2.go b/oauth2.go
index cdb836d..3644683 100644
--- a/oauth2.go
+++ b/oauth2.go
@@ -300,8 +300,7 @@
 		return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body)
 	}
 
-	token := &Token{}
-	expires := 0
+	var token *Token
 	content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
 	switch content {
 	case "application/x-www-form-urlencoded", "text/plain":
@@ -309,10 +308,12 @@
 		if err != nil {
 			return nil, err
 		}
-		token.AccessToken = vals.Get("access_token")
-		token.TokenType = vals.Get("token_type")
-		token.RefreshToken = vals.Get("refresh_token")
-		token.raw = vals
+		token = &Token{
+			AccessToken:  vals.Get("access_token"),
+			TokenType:    vals.Get("token_type"),
+			RefreshToken: vals.Get("refresh_token"),
+			raw:          vals,
+		}
 		e := vals.Get("expires_in")
 		if e == "" {
 			// TODO(jbd): Facebook's OAuth2 implementation is broken and
@@ -320,38 +321,52 @@
 			// when Facebook fixes their implementation.
 			e = vals.Get("expires")
 		}
-		expires, _ = strconv.Atoi(e)
+		expires, _ := strconv.Atoi(e)
+		if expires != 0 {
+			token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
+		}
 	default:
-		b := make(map[string]interface{}) // TODO: don't use a map[string]interface{}; make a type
-		if err = json.Unmarshal(body, &b); err != nil {
+		var tj tokenJSON
+		if err = json.Unmarshal(body, &tj); err != nil {
 			return nil, err
 		}
-		token.AccessToken, _ = b["access_token"].(string)
-		token.TokenType, _ = b["token_type"].(string)
-		token.RefreshToken, _ = b["refresh_token"].(string)
-		token.raw = b
-		e, ok := b["expires_in"].(float64)
-		if !ok {
-			// TODO(jbd): Facebook's OAuth2 implementation is broken and
-			// returns expires_in field in expires. Remove the fallback to expires,
-			// when Facebook fixes their implementation.
-			e, _ = b["expires"].(float64)
+		token = &Token{
+			AccessToken:  tj.AccessToken,
+			TokenType:    tj.TokenType,
+			RefreshToken: tj.RefreshToken,
+			Expiry:       tj.expiry(),
+			raw:          make(map[string]interface{}),
 		}
-		expires = int(e)
+		json.Unmarshal(body, &token.raw) // no error checks for optional fields
 	}
 	// Don't overwrite `RefreshToken` with an empty value
 	// if this was a token refreshing request.
 	if token.RefreshToken == "" {
 		token.RefreshToken = v.Get("refresh_token")
 	}
-	if expires == 0 {
-		token.Expiry = time.Time{}
-	} else {
-		token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
-	}
 	return token, nil
 }
 
+// tokenJSON is the struct representing the HTTP response from OAuth2
+// providers returning a token in JSON form.
+type tokenJSON struct {
+	AccessToken  string `json:"access_token"`
+	TokenType    string `json:"token_type"`
+	RefreshToken string `json:"refresh_token"`
+	ExpiresIn    int32  `json:"expires_in"`
+	Expires      int32  `json:"expires"` // broken Facebook spelling of expires_in
+}
+
+func (e *tokenJSON) expiry() (t time.Time) {
+	if v := e.ExpiresIn; v != 0 {
+		return time.Now().Add(time.Duration(v) * time.Second)
+	}
+	if v := e.Expires; v != 0 {
+		return time.Now().Add(time.Duration(v) * time.Second)
+	}
+	return
+}
+
 func condVal(v string) []string {
 	if v == "" {
 		return nil
diff --git a/oauth2_test.go b/oauth2_test.go
index 8159b86..c567c3a 100644
--- a/oauth2_test.go
+++ b/oauth2_test.go
@@ -181,12 +181,9 @@
 	}))
 	defer ts.Close()
 	conf := newConf(ts.URL)
-	tok, err := conf.Exchange(NoContext, "exchange-code")
-	if err != nil {
-		t.Error(err)
-	}
-	if tok.AccessToken != "" {
-		t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
+	_, err := conf.Exchange(NoContext, "exchange-code")
+	if err == nil {
+		t.Error("expected error from invalid access_token type")
 	}
 }