google/externalaccount: validate tokenURL and ServiceAccountImpersonationURL

Change-Id: Iab70cc967fd97ac8e349a14760df0f8b02ddf074
GitHub-Last-Rev: ddf4dbd0b7096a0d34677047b9c3992bb6ed359b
GitHub-Pull-Request: golang/oauth2#514
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/340569
Reviewed-by: Patrick Jones <ithuriel@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
Reviewed-by: Chris Broadfoot <cbro@golang.org>
Trust: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
diff --git a/google/google.go b/google/google.go
index 2b631f5..422ff1f 100644
--- a/google/google.go
+++ b/google/google.go
@@ -177,7 +177,7 @@
 			QuotaProjectID:                 f.QuotaProjectID,
 			Scopes:                         params.Scopes,
 		}
-		return cfg.TokenSource(ctx), nil
+		return cfg.TokenSource(ctx)
 	case "":
 		return nil, errors.New("missing 'type' field in credentials")
 	default:
diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go
index 669ba1e..b1f592c 100644
--- a/google/internal/externalaccount/aws_test.go
+++ b/google/internal/externalaccount/aws_test.go
@@ -28,8 +28,7 @@
 
 func setEnvironment(env map[string]string) func(string) string {
 	return func(key string) string {
-		value, _ := env[key]
-		return value
+		return env[key]
 	}
 }
 
@@ -650,7 +649,7 @@
 	getenv = setEnvironment(map[string]string{
 		"AWS_ACCESS_KEY_ID":     "AKIDEXAMPLE",
 		"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
-		"AWS_DEFAULT_REGION":            "us-west-1",
+		"AWS_DEFAULT_REGION":    "us-west-1",
 	})
 
 	base, err := tfc.parse(context.Background())
@@ -688,7 +687,7 @@
 		"AWS_ACCESS_KEY_ID":     "AKIDEXAMPLE",
 		"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
 		"AWS_REGION":            "us-west-1",
-		"AWS_DEFAULT_REGION":            "us-east-1",
+		"AWS_DEFAULT_REGION":    "us-east-1",
 	})
 
 	base, err := tfc.parse(context.Background())
diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go
index a4d45d9..dab917f 100644
--- a/google/internal/externalaccount/basecredentials.go
+++ b/google/internal/externalaccount/basecredentials.go
@@ -7,10 +7,14 @@
 import (
 	"context"
 	"fmt"
-	"golang.org/x/oauth2"
 	"net/http"
+	"net/url"
+	"regexp"
 	"strconv"
+	"strings"
 	"time"
+
+	"golang.org/x/oauth2"
 )
 
 // now aliases time.Now for testing
@@ -22,43 +26,101 @@
 type Config struct {
 	// Audience is the Secure Token Service (STS) audience which contains the resource name for the workload
 	// identity pool or the workforce pool and the provider identifier in that pool.
-	Audience                       string
+	Audience string
 	// SubjectTokenType is the STS token type based on the Oauth2.0 token exchange spec
 	// e.g. `urn:ietf:params:oauth:token-type:jwt`.
-	SubjectTokenType               string
+	SubjectTokenType string
 	// TokenURL is the STS token exchange endpoint.
-	TokenURL                       string
+	TokenURL string
 	// TokenInfoURL is the token_info endpoint used to retrieve the account related information (
 	// user attributes like account identifier, eg. email, username, uid, etc). This is
 	// needed for gCloud session account identification.
-	TokenInfoURL                   string
+	TokenInfoURL string
 	// ServiceAccountImpersonationURL is the URL for the service account impersonation request. This is only
 	// required for workload identity pools when APIs to be accessed have not integrated with UberMint.
 	ServiceAccountImpersonationURL string
 	// ClientSecret is currently only required if token_info endpoint also
 	// needs to be called with the generated GCP access token. When provided, STS will be
 	// called with additional basic authentication using client_id as username and client_secret as password.
-	ClientSecret                   string
+	ClientSecret string
 	// ClientID is only required in conjunction with ClientSecret, as described above.
-	ClientID                       string
+	ClientID string
 	// CredentialSource contains the necessary information to retrieve the token itself, as well
 	// as some environmental information.
-	CredentialSource               CredentialSource
+	CredentialSource CredentialSource
 	// QuotaProjectID is injected by gCloud. If the value is non-empty, the Auth libraries
 	// will set the x-goog-user-project which overrides the project associated with the credentials.
-	QuotaProjectID                 string
+	QuotaProjectID string
 	// Scopes contains the desired scopes for the returned access token.
-	Scopes                         []string
+	Scopes []string
+}
+
+// Each element consists of a list of patterns.  validateURLs checks for matches
+// that include all elements in a given list, in that order.
+
+var (
+	validTokenURLPatterns = []*regexp.Regexp{
+		// The complicated part in the middle matches any number of characters that
+		// aren't period, spaces, or slashes.
+		regexp.MustCompile(`(?i)^[^\.\s\/\\]+\.sts\.googleapis\.com$`),
+		regexp.MustCompile(`(?i)^sts\.googleapis\.com$`),
+		regexp.MustCompile(`(?i)^sts\.[^\.\s\/\\]+\.googleapis\.com$`),
+		regexp.MustCompile(`(?i)^[^\.\s\/\\]+-sts\.googleapis\.com$`),
+	}
+	validImpersonateURLPatterns = []*regexp.Regexp{
+		regexp.MustCompile(`^[^\.\s\/\\]+\.iamcredentials\.googleapis\.com$`),
+		regexp.MustCompile(`^iamcredentials\.googleapis\.com$`),
+		regexp.MustCompile(`^iamcredentials\.[^\.\s\/\\]+\.googleapis\.com$`),
+		regexp.MustCompile(`^[^\.\s\/\\]+-iamcredentials\.googleapis\.com$`),
+	}
+)
+
+func validateURL(input string, patterns []*regexp.Regexp, scheme string) bool {
+	parsed, err := url.Parse(input)
+	if err != nil {
+		return false
+	}
+	if !strings.EqualFold(parsed.Scheme, scheme) {
+		return false
+	}
+	toTest := parsed.Host
+
+	for _, pattern := range patterns {
+
+		if valid := pattern.MatchString(toTest); valid {
+			return true
+		}
+	}
+	return false
 }
 
 // TokenSource Returns an external account TokenSource struct. This is to be called by package google to construct a google.Credentials.
-func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource {
+func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) {
+	return c.tokenSource(ctx, validTokenURLPatterns, validImpersonateURLPatterns, "https")
+}
+
+// tokenSource is a private function that's directly called by some of the tests,
+// because the unit test URLs are mocked, and would otherwise fail the
+// validity check.
+func (c *Config) tokenSource(ctx context.Context, tokenURLValidPats []*regexp.Regexp, impersonateURLValidPats []*regexp.Regexp, scheme string) (oauth2.TokenSource, error) {
+	valid := validateURL(c.TokenURL, tokenURLValidPats, scheme)
+	if !valid {
+		return nil, fmt.Errorf("oauth2/google: invalid TokenURL provided while constructing tokenSource")
+	}
+
+	if c.ServiceAccountImpersonationURL != "" {
+		valid := validateURL(c.ServiceAccountImpersonationURL, impersonateURLValidPats, scheme)
+		if !valid {
+			return nil, fmt.Errorf("oauth2/google: invalid ServiceAccountImpersonationURL provided while constructing tokenSource")
+		}
+	}
+
 	ts := tokenSource{
 		ctx:  ctx,
 		conf: c,
 	}
 	if c.ServiceAccountImpersonationURL == "" {
-		return oauth2.ReuseTokenSource(nil, ts)
+		return oauth2.ReuseTokenSource(nil, ts), nil
 	}
 	scopes := c.Scopes
 	ts.conf.Scopes = []string{"https://www.googleapis.com/auth/cloud-platform"}
@@ -68,7 +130,7 @@
 		scopes: scopes,
 		ts:     oauth2.ReuseTokenSource(nil, ts),
 	}
-	return oauth2.ReuseTokenSource(nil, imp)
+	return oauth2.ReuseTokenSource(nil, imp), nil
 }
 
 // Subject token file types.
@@ -78,9 +140,9 @@
 )
 
 type format struct {
-	// Type is either "text" or "json".  When not provided "text" type is assumed.
+	// Type is either "text" or "json". When not provided "text" type is assumed.
 	Type string `json:"type"`
-	// SubjectTokenFieldName is only required for JSON format.  This would be "access_token" for azure.
+	// SubjectTokenFieldName is only required for JSON format. This would be "access_token" for azure.
 	SubjectTokenFieldName string `json:"subject_token_field_name"`
 }
 
@@ -128,7 +190,7 @@
 	subjectToken() (string, error)
 }
 
-// tokenSource is the source that handles external credentials.  It is used to retrieve Tokens.
+// tokenSource is the source that handles external credentials. It is used to retrieve Tokens.
 type tokenSource struct {
 	ctx  context.Context
 	conf *Config
diff --git a/google/internal/externalaccount/basecredentials_test.go b/google/internal/externalaccount/basecredentials_test.go
index 1ebb227..b1131d6 100644
--- a/google/internal/externalaccount/basecredentials_test.go
+++ b/google/internal/externalaccount/basecredentials_test.go
@@ -9,6 +9,7 @@
 	"io/ioutil"
 	"net/http"
 	"net/http/httptest"
+	"strings"
 	"testing"
 	"time"
 )
@@ -95,3 +96,117 @@
 	}
 
 }
+
+func TestValidateURLTokenURL(t *testing.T) {
+	var urlValidityTests = []struct {
+		tokURL        string
+		expectSuccess bool
+	}{
+		{"https://east.sts.googleapis.com", true},
+		{"https://sts.googleapis.com", true},
+		{"https://sts.asfeasfesef.googleapis.com", true},
+		{"https://us-east-1-sts.googleapis.com", true},
+		{"https://sts.googleapis.com/your/path/here", true},
+		{"https://.sts.googleapis.com", false},
+		{"https://badsts.googleapis.com", false},
+		{"https://sts.asfe.asfesef.googleapis.com", false},
+		{"https://sts..googleapis.com", false},
+		{"https://-sts.googleapis.com", false},
+		{"https://us-ea.st-1-sts.googleapis.com", false},
+		{"https://sts.googleapis.com.evil.com/whatever/path", false},
+		{"https://us-eas\\t-1.sts.googleapis.com", false},
+		{"https:/us-ea/st-1.sts.googleapis.com", false},
+		{"https:/us-east 1.sts.googleapis.com", false},
+		{"https://", false},
+		{"http://us-east-1.sts.googleapis.com", false},
+		{"https://us-east-1.sts.googleapis.comevil.com", false},
+	}
+	ctx := context.Background()
+	for _, tt := range urlValidityTests {
+		t.Run(" "+tt.tokURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability.
+			config := testConfig
+			config.TokenURL = tt.tokURL
+			_, err := config.TokenSource(ctx)
+
+			if tt.expectSuccess && err != nil {
+				t.Errorf("got %v but want nil", err)
+			} else if !tt.expectSuccess && err == nil {
+				t.Errorf("got nil but expected an error")
+			}
+		})
+	}
+	for _, el := range urlValidityTests {
+		el.tokURL = strings.ToUpper(el.tokURL)
+	}
+	for _, tt := range urlValidityTests {
+		t.Run(" "+tt.tokURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability.
+			config := testConfig
+			config.TokenURL = tt.tokURL
+			_, err := config.TokenSource(ctx)
+
+			if tt.expectSuccess && err != nil {
+				t.Errorf("got %v but want nil", err)
+			} else if !tt.expectSuccess && err == nil {
+				t.Errorf("got nil but expected an error")
+			}
+		})
+	}
+}
+
+func TestValidateURLImpersonateURL(t *testing.T) {
+	var urlValidityTests = []struct {
+		impURL        string
+		expectSuccess bool
+	}{
+		{"https://east.iamcredentials.googleapis.com", true},
+		{"https://iamcredentials.googleapis.com", true},
+		{"https://iamcredentials.asfeasfesef.googleapis.com", true},
+		{"https://us-east-1-iamcredentials.googleapis.com", true},
+		{"https://iamcredentials.googleapis.com/your/path/here", true},
+		{"https://.iamcredentials.googleapis.com", false},
+		{"https://badiamcredentials.googleapis.com", false},
+		{"https://iamcredentials.asfe.asfesef.googleapis.com", false},
+		{"https://iamcredentials..googleapis.com", false},
+		{"https://-iamcredentials.googleapis.com", false},
+		{"https://us-ea.st-1-iamcredentials.googleapis.com", false},
+		{"https://iamcredentials.googleapis.com.evil.com/whatever/path", false},
+		{"https://us-eas\\t-1.iamcredentials.googleapis.com", false},
+		{"https:/us-ea/st-1.iamcredentials.googleapis.com", false},
+		{"https:/us-east 1.iamcredentials.googleapis.com", false},
+		{"https://", false},
+		{"http://us-east-1.iamcredentials.googleapis.com", false},
+		{"https://us-east-1.iamcredentials.googleapis.comevil.com", false},
+	}
+	ctx := context.Background()
+	for _, tt := range urlValidityTests {
+		t.Run(" "+tt.impURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability.
+			config := testConfig
+			config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL
+			config.ServiceAccountImpersonationURL = tt.impURL
+			_, err := config.TokenSource(ctx)
+
+			if tt.expectSuccess && err != nil {
+				t.Errorf("got %v but want nil", err)
+			} else if !tt.expectSuccess && err == nil {
+				t.Errorf("got nil but expected an error")
+			}
+		})
+	}
+	for _, el := range urlValidityTests {
+		el.impURL = strings.ToUpper(el.impURL)
+	}
+	for _, tt := range urlValidityTests {
+		t.Run(" "+tt.impURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability.
+			config := testConfig
+			config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL
+			config.ServiceAccountImpersonationURL = tt.impURL
+			_, err := config.TokenSource(ctx)
+
+			if tt.expectSuccess && err != nil {
+				t.Errorf("got %v but want nil", err)
+			} else if !tt.expectSuccess && err == nil {
+				t.Errorf("got nil but expected an error")
+			}
+		})
+	}
+}
diff --git a/google/internal/externalaccount/clientauth.go b/google/internal/externalaccount/clientauth.go
index 62c2e36..99987ce 100644
--- a/google/internal/externalaccount/clientauth.go
+++ b/google/internal/externalaccount/clientauth.go
@@ -6,9 +6,10 @@
 
 import (
 	"encoding/base64"
-	"golang.org/x/oauth2"
 	"net/http"
 	"net/url"
+
+	"golang.org/x/oauth2"
 )
 
 // clientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1.
diff --git a/google/internal/externalaccount/clientauth_test.go b/google/internal/externalaccount/clientauth_test.go
index 38633e3..bfb339d 100644
--- a/google/internal/externalaccount/clientauth_test.go
+++ b/google/internal/externalaccount/clientauth_test.go
@@ -5,11 +5,12 @@
 package externalaccount
 
 import (
-	"golang.org/x/oauth2"
 	"net/http"
 	"net/url"
 	"reflect"
 	"testing"
+
+	"golang.org/x/oauth2"
 )
 
 var clientID = "rbrgnognrhongo3bi4gb9ghg9g"
diff --git a/google/internal/externalaccount/impersonate.go b/google/internal/externalaccount/impersonate.go
index 1f6009b..64edb56 100644
--- a/google/internal/externalaccount/impersonate.go
+++ b/google/internal/externalaccount/impersonate.go
@@ -9,11 +9,12 @@
 	"context"
 	"encoding/json"
 	"fmt"
-	"golang.org/x/oauth2"
 	"io"
 	"io/ioutil"
 	"net/http"
 	"time"
+
+	"golang.org/x/oauth2"
 )
 
 // generateAccesstokenReq is used for service account impersonation
diff --git a/google/internal/externalaccount/impersonate_test.go b/google/internal/externalaccount/impersonate_test.go
index 197fe3c..6fed7b9 100644
--- a/google/internal/externalaccount/impersonate_test.go
+++ b/google/internal/externalaccount/impersonate_test.go
@@ -9,6 +9,7 @@
 	"io/ioutil"
 	"net/http"
 	"net/http/httptest"
+	"regexp"
 	"testing"
 )
 
@@ -76,7 +77,11 @@
 	defer targetServer.Close()
 
 	testImpersonateConfig.TokenURL = targetServer.URL
-	ourTS := testImpersonateConfig.TokenSource(context.Background())
+	allURLs := regexp.MustCompile(".+")
+	ourTS, err := testImpersonateConfig.tokenSource(context.Background(), []*regexp.Regexp{allURLs}, []*regexp.Regexp{allURLs}, "http")
+	if err != nil {
+		t.Fatalf("Failed to create TokenSource: %v", err)
+	}
 
 	oldNow := now
 	defer func() { now = oldNow }()
diff --git a/google/internal/externalaccount/sts_exchange.go b/google/internal/externalaccount/sts_exchange.go
index a8a704b..e6fcae5 100644
--- a/google/internal/externalaccount/sts_exchange.go
+++ b/google/internal/externalaccount/sts_exchange.go
@@ -65,6 +65,9 @@
 	defer resp.Body.Close()
 
 	body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
+	if err != nil {
+		return nil, err
+	}
 	if c := resp.StatusCode; c < 200 || c > 299 {
 		return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body)
 	}
diff --git a/google/internal/externalaccount/sts_exchange_test.go b/google/internal/externalaccount/sts_exchange_test.go
index 3d498c6..df4d5ff 100644
--- a/google/internal/externalaccount/sts_exchange_test.go
+++ b/google/internal/externalaccount/sts_exchange_test.go
@@ -7,12 +7,13 @@
 import (
 	"context"
 	"encoding/json"
-	"golang.org/x/oauth2"
 	"io/ioutil"
 	"net/http"
 	"net/http/httptest"
 	"net/url"
 	"testing"
+
+	"golang.org/x/oauth2"
 )
 
 var auth = clientAuthentication{
@@ -127,6 +128,9 @@
 		}
 		var opts map[string]interface{}
 		err = json.Unmarshal([]byte(strOpts[0]), &opts)
+		if err != nil {
+			t.Fatalf("Couldn't parse received \"options\" field.")
+		}
 		if len(opts) < 2 {
 			t.Errorf("Too few options received.")
 		}
diff --git a/google/internal/externalaccount/urlcredsource.go b/google/internal/externalaccount/urlcredsource.go
index 91b8f20..16dca65 100644
--- a/google/internal/externalaccount/urlcredsource.go
+++ b/google/internal/externalaccount/urlcredsource.go
@@ -9,10 +9,11 @@
 	"encoding/json"
 	"errors"
 	"fmt"
-	"golang.org/x/oauth2"
 	"io"
 	"io/ioutil"
 	"net/http"
+
+	"golang.org/x/oauth2"
 )
 
 type urlCredentialSource struct {