google/internal/externalaccount: Adding metadata verification

Change-Id: I4d664862b7b287131c1481b238ebd0875f7c233b
GitHub-Last-Rev: 74bcc33f5ed4863c740aaf09ad4ee3ac4366e8e1
GitHub-Pull-Request: golang/oauth2#608
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/449975
Run-TryBot: Cody Oss <codyoss@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go
index e917195..6318a23 100644
--- a/google/internal/externalaccount/aws.go
+++ b/google/internal/externalaccount/aws.go
@@ -267,6 +267,49 @@
 	Headers []awsRequestHeader `json:"headers"`
 }
 
+func (cs awsCredentialSource) validateMetadataServers() error {
+	if err := cs.validateMetadataServer(cs.RegionURL, "region_url"); err != nil {
+		return err
+	}
+	if err := cs.validateMetadataServer(cs.CredVerificationURL, "url"); err != nil {
+		return err
+	}
+	return cs.validateMetadataServer(cs.IMDSv2SessionTokenURL, "imdsv2_session_token_url")
+}
+
+var validHostnames []string = []string{"169.254.169.254", "fd00:ec2::254"}
+
+func (cs awsCredentialSource) isValidMetadataServer(metadataUrl string) bool {
+	if metadataUrl == "" {
+		// Zero value means use default, which is valid.
+		return true
+	}
+
+	u, err := url.Parse(metadataUrl)
+	if err != nil {
+		// Unparseable URL means invalid
+		return false
+	}
+
+	for _, validHostname := range validHostnames {
+		if u.Hostname() == validHostname {
+			// If it's one of the valid hostnames, everything is good
+			return true
+		}
+	}
+
+	// hostname not found in our allowlist, so not valid
+	return false
+}
+
+func (cs awsCredentialSource) validateMetadataServer(metadataUrl, urlName string) error {
+	if !cs.isValidMetadataServer(metadataUrl) {
+		return fmt.Errorf("oauth2/google: invalid hostname %s for %s", metadataUrl, urlName)
+	}
+
+	return nil
+}
+
 func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
 	if cs.client == nil {
 		cs.client = oauth2.NewClient(cs.ctx, nil)
diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go
index 0934389..30a003a 100644
--- a/google/internal/externalaccount/aws_test.go
+++ b/google/internal/externalaccount/aws_test.go
@@ -553,16 +553,25 @@
 func TestAWSCredential_BasicRequest(t *testing.T) {
 	server := createDefaultAwsTestServer()
 	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
 
 	tfc := testFileConfig
 	tfc.CredentialSource = server.getCredentialSource(ts.URL)
 
 	oldGetenv := getenv
-	defer func() { getenv = oldGetenv }()
-	getenv = setEnvironment(map[string]string{})
 	oldNow := now
-	defer func() { now = oldNow }()
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		now = oldNow
+		validHostnames = oldValidHostnames
+	}()
+	getenv = setEnvironment(map[string]string{})
 	now = setTime(defaultTime)
+	validHostnames = []string{tsURL.Hostname()}
 
 	base, err := tfc.parse(context.Background())
 	if err != nil {
@@ -618,16 +627,25 @@
 		validateSessionTokenHeaders,
 	)
 	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
 
 	tfc := testFileConfig
 	tfc.CredentialSource = server.getCredentialSource(ts.URL)
 
 	oldGetenv := getenv
-	defer func() { getenv = oldGetenv }()
-	getenv = setEnvironment(map[string]string{})
 	oldNow := now
-	defer func() { now = oldNow }()
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		now = oldNow
+		validHostnames = oldValidHostnames
+	}()
+	getenv = setEnvironment(map[string]string{})
 	now = setTime(defaultTime)
+	validHostnames = []string{tsURL.Hostname()}
 
 	base, err := tfc.parse(context.Background())
 	if err != nil {
@@ -655,17 +673,26 @@
 func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
 	server := createDefaultAwsTestServer()
 	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
 	delete(server.Credentials, "Token")
 
 	tfc := testFileConfig
 	tfc.CredentialSource = server.getCredentialSource(ts.URL)
 
 	oldGetenv := getenv
-	defer func() { getenv = oldGetenv }()
-	getenv = setEnvironment(map[string]string{})
 	oldNow := now
-	defer func() { now = oldNow }()
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		now = oldNow
+		validHostnames = oldValidHostnames
+	}()
+	getenv = setEnvironment(map[string]string{})
 	now = setTime(defaultTime)
+	validHostnames = []string{tsURL.Hostname()}
 
 	base, err := tfc.parse(context.Background())
 	if err != nil {
@@ -693,20 +720,29 @@
 func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
 	server := createDefaultAwsTestServer()
 	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
 
 	tfc := testFileConfig
 	tfc.CredentialSource = server.getCredentialSource(ts.URL)
 
 	oldGetenv := getenv
-	defer func() { getenv = oldGetenv }()
+	oldNow := now
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		now = oldNow
+		validHostnames = oldValidHostnames
+	}()
 	getenv = setEnvironment(map[string]string{
 		"AWS_ACCESS_KEY_ID":     "AKIDEXAMPLE",
 		"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
 		"AWS_REGION":            "us-west-1",
 	})
-	oldNow := now
-	defer func() { now = oldNow }()
 	now = setTime(defaultTime)
+	validHostnames = []string{tsURL.Hostname()}
 
 	base, err := tfc.parse(context.Background())
 	if err != nil {
@@ -734,20 +770,29 @@
 func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
 	server := createDefaultAwsTestServer()
 	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
 
 	tfc := testFileConfig
 	tfc.CredentialSource = server.getCredentialSource(ts.URL)
 
 	oldGetenv := getenv
-	defer func() { getenv = oldGetenv }()
+	oldNow := now
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		now = oldNow
+		validHostnames = oldValidHostnames
+	}()
 	getenv = setEnvironment(map[string]string{
 		"AWS_ACCESS_KEY_ID":     "AKIDEXAMPLE",
 		"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
-		"AWS_DEFAULT_REGION":    "us-west-1",
+		"AWS_REGION":            "us-west-1",
 	})
-	oldNow := now
-	defer func() { now = oldNow }()
 	now = setTime(defaultTime)
+	validHostnames = []string{tsURL.Hostname()}
 
 	base, err := tfc.parse(context.Background())
 	if err != nil {
@@ -774,21 +819,30 @@
 func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
 	server := createDefaultAwsTestServer()
 	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
 
 	tfc := testFileConfig
 	tfc.CredentialSource = server.getCredentialSource(ts.URL)
 
 	oldGetenv := getenv
-	defer func() { getenv = oldGetenv }()
+	oldNow := now
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		now = oldNow
+		validHostnames = oldValidHostnames
+	}()
 	getenv = setEnvironment(map[string]string{
 		"AWS_ACCESS_KEY_ID":     "AKIDEXAMPLE",
 		"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
 		"AWS_REGION":            "us-west-1",
 		"AWS_DEFAULT_REGION":    "us-east-1",
 	})
-	oldNow := now
-	defer func() { now = oldNow }()
 	now = setTime(defaultTime)
+	validHostnames = []string{tsURL.Hostname()}
 
 	base, err := tfc.parse(context.Background())
 	if err != nil {
@@ -815,16 +869,25 @@
 func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
 	server := createDefaultAwsTestServer()
 	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
 
 	tfc := testFileConfig
 	tfc.CredentialSource = server.getCredentialSource(ts.URL)
 	tfc.CredentialSource.EnvironmentID = "aws3"
 
 	oldGetenv := getenv
-	defer func() { getenv = oldGetenv }()
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		validHostnames = oldValidHostnames
+	}()
 	getenv = setEnvironment(map[string]string{})
+	validHostnames = []string{tsURL.Hostname()}
 
-	_, err := tfc.parse(context.Background())
+	_, err = tfc.parse(context.Background())
 	if err == nil {
 		t.Fatalf("parse() should have failed")
 	}
@@ -836,14 +899,23 @@
 func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
 	server := createDefaultAwsTestServer()
 	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
 
 	tfc := testFileConfig
 	tfc.CredentialSource = server.getCredentialSource(ts.URL)
 	tfc.CredentialSource.RegionURL = ""
 
 	oldGetenv := getenv
-	defer func() { getenv = oldGetenv }()
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		validHostnames = oldValidHostnames
+	}()
 	getenv = setEnvironment(map[string]string{})
+	validHostnames = []string{tsURL.Hostname()}
 
 	base, err := tfc.parse(context.Background())
 	if err != nil {
@@ -863,14 +935,23 @@
 func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
 	server := createDefaultAwsTestServer()
 	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
 	server.WriteRegion = notFound
 
 	tfc := testFileConfig
 	tfc.CredentialSource = server.getCredentialSource(ts.URL)
 
 	oldGetenv := getenv
-	defer func() { getenv = oldGetenv }()
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		validHostnames = oldValidHostnames
+	}()
 	getenv = setEnvironment(map[string]string{})
+	validHostnames = []string{tsURL.Hostname()}
 
 	base, err := tfc.parse(context.Background())
 	if err != nil {
@@ -890,6 +971,10 @@
 func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
 	server := createDefaultAwsTestServer()
 	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
 	server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
 		w.Write([]byte("{}"))
 	}
@@ -898,8 +983,13 @@
 	tfc.CredentialSource = server.getCredentialSource(ts.URL)
 
 	oldGetenv := getenv
-	defer func() { getenv = oldGetenv }()
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		validHostnames = oldValidHostnames
+	}()
 	getenv = setEnvironment(map[string]string{})
+	validHostnames = []string{tsURL.Hostname()}
 
 	base, err := tfc.parse(context.Background())
 	if err != nil {
@@ -919,6 +1009,10 @@
 func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
 	server := createDefaultAwsTestServer()
 	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
 	server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
 		w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
 	}
@@ -927,8 +1021,13 @@
 	tfc.CredentialSource = server.getCredentialSource(ts.URL)
 
 	oldGetenv := getenv
-	defer func() { getenv = oldGetenv }()
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		validHostnames = oldValidHostnames
+	}()
 	getenv = setEnvironment(map[string]string{})
+	validHostnames = []string{tsURL.Hostname()}
 
 	base, err := tfc.parse(context.Background())
 	if err != nil {
@@ -948,14 +1047,23 @@
 func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
 	server := createDefaultAwsTestServer()
 	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
 
 	tfc := testFileConfig
 	tfc.CredentialSource = server.getCredentialSource(ts.URL)
 	tfc.CredentialSource.URL = ""
 
 	oldGetenv := getenv
-	defer func() { getenv = oldGetenv }()
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		validHostnames = oldValidHostnames
+	}()
 	getenv = setEnvironment(map[string]string{})
+	validHostnames = []string{tsURL.Hostname()}
 
 	base, err := tfc.parse(context.Background())
 	if err != nil {
@@ -975,14 +1083,23 @@
 func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
 	server := createDefaultAwsTestServer()
 	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
 	server.WriteRolename = notFound
 
 	tfc := testFileConfig
 	tfc.CredentialSource = server.getCredentialSource(ts.URL)
 
 	oldGetenv := getenv
-	defer func() { getenv = oldGetenv }()
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		validHostnames = oldValidHostnames
+	}()
 	getenv = setEnvironment(map[string]string{})
+	validHostnames = []string{tsURL.Hostname()}
 
 	base, err := tfc.parse(context.Background())
 	if err != nil {
@@ -1002,14 +1119,23 @@
 func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
 	server := createDefaultAwsTestServer()
 	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
 	server.WriteSecurityCredentials = notFound
 
 	tfc := testFileConfig
 	tfc.CredentialSource = server.getCredentialSource(ts.URL)
 
 	oldGetenv := getenv
-	defer func() { getenv = oldGetenv }()
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		validHostnames = oldValidHostnames
+	}()
 	getenv = setEnvironment(map[string]string{})
+	validHostnames = []string{tsURL.Hostname()}
 
 	base, err := tfc.parse(context.Background())
 	if err != nil {
@@ -1025,3 +1151,88 @@
 		t.Errorf("subjectToken = %q, want %q", got, want)
 	}
 }
+
+func TestAWSCredential_Validations(t *testing.T) {
+	var metadataServerValidityTests = []struct {
+		name       string
+		credSource CredentialSource
+		errText    string
+	}{
+		{
+			name: "No Metadata Server URLs",
+			credSource: CredentialSource{
+				EnvironmentID:         "aws1",
+				RegionURL:             "",
+				URL:                   "",
+				IMDSv2SessionTokenURL: "",
+			},
+		}, {
+			name: "IPv4 Metadata Server URLs",
+			credSource: CredentialSource{
+				EnvironmentID:         "aws1",
+				RegionURL:             "http://169.254.169.254/latest/meta-data/placement/availability-zone",
+				URL:                   "http://169.254.169.254/latest/meta-data/iam/security-credentials",
+				IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
+			},
+		}, {
+			name: "IPv6 Metadata Server URLs",
+			credSource: CredentialSource{
+				EnvironmentID:         "aws1",
+				RegionURL:             "http://[fd00:ec2::254]/latest/meta-data/placement/availability-zone",
+				URL:                   "http://[fd00:ec2::254]/latest/meta-data/iam/security-credentials",
+				IMDSv2SessionTokenURL: "http://[fd00:ec2::254]/latest/api/token",
+			},
+		}, {
+			name: "Faulty RegionURL",
+			credSource: CredentialSource{
+				EnvironmentID:         "aws1",
+				RegionURL:             "http://abc.com/latest/meta-data/placement/availability-zone",
+				URL:                   "http://169.254.169.254/latest/meta-data/iam/security-credentials",
+				IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
+			},
+			errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/placement/availability-zone for region_url",
+		}, {
+			name: "Faulty CredVerificationURL",
+			credSource: CredentialSource{
+				EnvironmentID:         "aws1",
+				RegionURL:             "http://169.254.169.254/latest/meta-data/placement/availability-zone",
+				URL:                   "http://abc.com/latest/meta-data/iam/security-credentials",
+				IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
+			},
+			errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/iam/security-credentials for url",
+		}, {
+			name: "Faulty IMDSv2SessionTokenURL",
+			credSource: CredentialSource{
+				EnvironmentID:         "aws1",
+				RegionURL:             "http://169.254.169.254/latest/meta-data/placement/availability-zone",
+				URL:                   "http://169.254.169.254/latest/meta-data/iam/security-credentials",
+				IMDSv2SessionTokenURL: "http://abc.com/latest/api/token",
+			},
+			errText: "oauth2/google: invalid hostname http://abc.com/latest/api/token for imdsv2_session_token_url",
+		},
+	}
+
+	for _, tt := range metadataServerValidityTests {
+		t.Run(tt.name, func(t *testing.T) {
+			tfc := testFileConfig
+			tfc.CredentialSource = tt.credSource
+
+			oldGetenv := getenv
+			defer func() { getenv = oldGetenv }()
+			getenv = setEnvironment(map[string]string{})
+
+			_, err := tfc.parse(context.Background())
+			if err != nil {
+				if tt.errText == "" {
+					t.Errorf("Didn't expect an error, but got %v", err)
+				} else if tt.errText != err.Error() {
+					t.Errorf("Expected %v, but got %v", tt.errText, err)
+				}
+			} else {
+				if tt.errText != "" {
+					t.Errorf("Expected error %v, but got none", tt.errText)
+				}
+			}
+		})
+	}
+}
diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go
index 9fc3553..3eab8df 100644
--- a/google/internal/externalaccount/basecredentials.go
+++ b/google/internal/externalaccount/basecredentials.go
@@ -213,6 +213,10 @@
 				awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL
 			}
 
+			if err := awsCredSource.validateMetadataServers(); err != nil {
+				return nil, err
+			}
+
 			return awsCredSource, nil
 		}
 	} else if c.CredentialSource.File != "" {