google/internal/externalaccount: Added check for aws region and security credential environment variables before aws metadata call

Adds check for aws values in environment variables before the metadata server is called to prevent unnecessary off box calls. See https://github.com/googleapis/google-auth-library-java/pull/1100 for same change in java library.

Change-Id: Ie86a899be88c38d3fcbbe377f9bf30a7a66530c0
GitHub-Last-Rev: bcab69572cb0dca4c7c6426203d4232e6e89d8db
GitHub-Pull-Request: golang/oauth2#612
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/453715
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Cody Oss <codyoss@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go
index 6318a23..2bf3202 100644
--- a/google/internal/externalaccount/aws.go
+++ b/google/internal/externalaccount/aws.go
@@ -62,6 +62,13 @@
 	// The AWS authorization header name for the auto-generated date.
 	awsDateHeader = "x-amz-date"
 
+	// Supported AWS configuration environment variables.
+	awsAccessKeyId     = "AWS_ACCESS_KEY_ID"
+	awsDefaultRegion   = "AWS_DEFAULT_REGION"
+	awsRegion          = "AWS_REGION"
+	awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY"
+	awsSessionToken    = "AWS_SESSION_TOKEN"
+
 	awsTimeFormatLong  = "20060102T150405Z"
 	awsTimeFormatShort = "20060102"
 )
@@ -317,16 +324,33 @@
 	return cs.client.Do(req.WithContext(cs.ctx))
 }
 
+func canRetrieveRegionFromEnvironment() bool {
+	// The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is
+	// required.
+	return getenv(awsRegion) != "" || getenv(awsDefaultRegion) != ""
+}
+
+func canRetrieveSecurityCredentialFromEnvironment() bool {
+	// Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available.
+	return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != ""
+}
+
+func shouldUseMetadataServer() bool {
+	return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment()
+}
+
 func (cs awsCredentialSource) subjectToken() (string, error) {
 	if cs.requestSigner == nil {
-		awsSessionToken, err := cs.getAWSSessionToken()
-		if err != nil {
-			return "", err
-		}
-
 		headers := make(map[string]string)
-		if awsSessionToken != "" {
-			headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
+		if shouldUseMetadataServer() {
+			awsSessionToken, err := cs.getAWSSessionToken()
+			if err != nil {
+				return "", err
+			}
+
+			if awsSessionToken != "" {
+				headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
+			}
 		}
 
 		awsSecurityCredentials, err := cs.getSecurityCredentials(headers)
@@ -432,11 +456,11 @@
 }
 
 func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) {
-	if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" {
-		return envAwsRegion, nil
-	}
-	if envAwsRegion := getenv("AWS_DEFAULT_REGION"); envAwsRegion != "" {
-		return envAwsRegion, nil
+	if canRetrieveRegionFromEnvironment() {
+		if envAwsRegion := getenv(awsRegion); envAwsRegion != "" {
+			return envAwsRegion, nil
+		}
+		return getenv("AWS_DEFAULT_REGION"), nil
 	}
 
 	if cs.RegionURL == "" {
@@ -477,14 +501,12 @@
 }
 
 func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) {
-	if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" {
-		if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
-			return awsSecurityCredentials{
-				AccessKeyID:     accessKeyID,
-				SecretAccessKey: secretAccessKey,
-				SecurityToken:   getenv("AWS_SESSION_TOKEN"),
-			}, nil
-		}
+	if canRetrieveSecurityCredentialFromEnvironment() {
+		return awsSecurityCredentials{
+			AccessKeyID:     getenv(awsAccessKeyId),
+			SecretAccessKey: getenv(awsSecretAccessKey),
+			SecurityToken:   getenv(awsSessionToken),
+		}, nil
 	}
 
 	roleName, err := cs.getMetadataRoleName(headers)
diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go
index 30a003a..058b004 100644
--- a/google/internal/externalaccount/aws_test.go
+++ b/google/internal/externalaccount/aws_test.go
@@ -474,6 +474,38 @@
 	)
 }
 
+func createDefaultAwsTestServerWithImdsv2(t *testing.T) *testAwsServer {
+	validateSessionTokenHeaders := func(r *http.Request) {
+		if r.URL.Path == "/latest/api/token" {
+			headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader)
+			if headerValue != awsIMDSv2SessionTtl {
+				t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl)
+			}
+		} else {
+			headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader)
+			if headerValue != "sessiontoken" {
+				t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken")
+			}
+		}
+	}
+
+	return createAwsTestServer(
+		"/latest/meta-data/iam/security-credentials",
+		"/latest/meta-data/placement/availability-zone",
+		"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
+		"/latest/api/token",
+		"gcp-aws-role",
+		"us-east-2b",
+		map[string]string{
+			"SecretAccessKey": secretAccessKey,
+			"AccessKeyId":     accessKeyID,
+			"Token":           securityToken,
+		},
+		"sessiontoken",
+		validateSessionTokenHeaders,
+	)
+}
+
 func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	switch p := r.URL.Path; p {
 	case server.url:
@@ -597,35 +629,7 @@
 }
 
 func TestAWSCredential_IMDSv2(t *testing.T) {
-	validateSessionTokenHeaders := func(r *http.Request) {
-		if r.URL.Path == "/latest/api/token" {
-			headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader)
-			if headerValue != awsIMDSv2SessionTtl {
-				t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl)
-			}
-		} else {
-			headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader)
-			if headerValue != "sessiontoken" {
-				t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken")
-			}
-		}
-	}
-
-	server := createAwsTestServer(
-		"/latest/meta-data/iam/security-credentials",
-		"/latest/meta-data/placement/availability-zone",
-		"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
-		"/latest/api/token",
-		"gcp-aws-role",
-		"us-east-2b",
-		map[string]string{
-			"SecretAccessKey": secretAccessKey,
-			"AccessKeyId":     accessKeyID,
-			"Token":           securityToken,
-		},
-		"sessiontoken",
-		validateSessionTokenHeaders,
-	)
+	server := createDefaultAwsTestServerWithImdsv2(t)
 	ts := httptest.NewServer(server)
 	tsURL, err := neturl.Parse(ts.URL)
 	if err != nil {
@@ -1152,6 +1156,208 @@
 	}
 }
 
+func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) {
+	server := createDefaultAwsTestServer()
+	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
+
+	metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		t.Error("Metadata server should not have been called.")
+	}))
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+	tfc.CredentialSource.IMDSv2SessionTokenURL = metadataTs.URL
+
+	oldGetenv := getenv
+	oldNow := now
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		now = oldNow
+		validHostnames = oldValidHostnames
+	}()
+	getenv = setEnvironment(map[string]string{
+		"AWS_ACCESS_KEY_ID":     "AKIDEXAMPLE",
+		"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
+		"AWS_REGION":            "us-west-1",
+	})
+	now = setTime(defaultTime)
+	validHostnames = []string{tsURL.Hostname()}
+
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	out, err := base.subjectToken()
+	if err != nil {
+		t.Fatalf("retrieveSubjectToken() failed: %v", err)
+	}
+
+	expected := getExpectedSubjectToken(
+		"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
+		"us-west-1",
+		"AKIDEXAMPLE",
+		"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
+		"",
+	)
+
+	if got, want := out, expected; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
+	}
+}
+
+func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
+	server := createDefaultAwsTestServerWithImdsv2(t)
+	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+
+	oldGetenv := getenv
+	oldNow := now
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		now = oldNow
+		validHostnames = oldValidHostnames
+	}()
+	getenv = setEnvironment(map[string]string{
+		"AWS_ACCESS_KEY_ID":     accessKeyID,
+		"AWS_SECRET_ACCESS_KEY": secretAccessKey,
+	})
+	now = setTime(defaultTime)
+	validHostnames = []string{tsURL.Hostname()}
+
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	out, err := base.subjectToken()
+	if err != nil {
+		t.Fatalf("retrieveSubjectToken() failed: %v", err)
+	}
+
+	expected := getExpectedSubjectToken(
+		"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
+		"us-east-2",
+		accessKeyID,
+		secretAccessKey,
+		"",
+	)
+
+	if got, want := out, expected; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
+	}
+}
+
+func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
+	server := createDefaultAwsTestServerWithImdsv2(t)
+	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+
+	oldGetenv := getenv
+	oldNow := now
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		now = oldNow
+		validHostnames = oldValidHostnames
+	}()
+	getenv = setEnvironment(map[string]string{
+		"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
+		"AWS_REGION":            "us-west-1",
+	})
+	now = setTime(defaultTime)
+	validHostnames = []string{tsURL.Hostname()}
+
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	out, err := base.subjectToken()
+	if err != nil {
+		t.Fatalf("retrieveSubjectToken() failed: %v", err)
+	}
+
+	expected := getExpectedSubjectToken(
+		"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
+		"us-west-1",
+		accessKeyID,
+		secretAccessKey,
+		securityToken,
+	)
+
+	if got, want := out, expected; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
+	}
+}
+
+func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) {
+	server := createDefaultAwsTestServerWithImdsv2(t)
+	ts := httptest.NewServer(server)
+	tsURL, err := neturl.Parse(ts.URL)
+	if err != nil {
+		t.Fatalf("couldn't parse httptest servername")
+	}
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+
+	oldGetenv := getenv
+	oldNow := now
+	oldValidHostnames := validHostnames
+	defer func() {
+		getenv = oldGetenv
+		now = oldNow
+		validHostnames = oldValidHostnames
+	}()
+	getenv = setEnvironment(map[string]string{
+		"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
+		"AWS_REGION":        "us-west-1",
+	})
+	now = setTime(defaultTime)
+	validHostnames = []string{tsURL.Hostname()}
+
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	out, err := base.subjectToken()
+	if err != nil {
+		t.Fatalf("retrieveSubjectToken() failed: %v", err)
+	}
+
+	expected := getExpectedSubjectToken(
+		"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
+		"us-west-1",
+		accessKeyID,
+		secretAccessKey,
+		securityToken,
+	)
+
+	if got, want := out, expected; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
+	}
+}
+
 func TestAWSCredential_Validations(t *testing.T) {
 	var metadataServerValidityTests = []struct {
 		name       string