google: support AWS 3rd party credentials

Change-Id: I655b38f7fb8023866bb284c7ce80ab9888682e73
GitHub-Last-Rev: 648f0b3d45d94760bb29e6bfe4680351d8e0364d
GitHub-Pull-Request: golang/oauth2#471
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/287752
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Tyler Bui-Palsulich <tbp@google.com>
Trust: Cody Oss <codyoss@google.com>
diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go
index 906d1fe..3725a0f 100644
--- a/google/internal/externalaccount/aws.go
+++ b/google/internal/externalaccount/aws.go
@@ -5,41 +5,54 @@
 package externalaccount
 
 import (
+	"context"
 	"crypto/hmac"
 	"crypto/sha256"
 	"encoding/hex"
+	"encoding/json"
 	"errors"
 	"fmt"
+	"golang.org/x/oauth2"
 	"io"
 	"io/ioutil"
 	"net/http"
+	"os"
 	"path"
 	"sort"
 	"strings"
 	"time"
 )
 
-// RequestSigner is a utility class to sign http requests using a AWS V4 signature.
-type awsRequestSigner struct {
-	RegionName             string
-	AwsSecurityCredentials map[string]string
+type awsSecurityCredentials struct {
+	AccessKeyID     string `json:"AccessKeyID"`
+	SecretAccessKey string `json:"SecretAccessKey"`
+	SecurityToken   string `json:"Token"`
 }
 
+// awsRequestSigner is a utility class to sign http requests using a AWS V4 signature.
+type awsRequestSigner struct {
+	RegionName             string
+	AwsSecurityCredentials awsSecurityCredentials
+}
+
+// getenv aliases os.Getenv for testing
+var getenv = os.Getenv
+
 const (
-// AWS Signature Version 4 signing algorithm identifier.
+	// AWS Signature Version 4 signing algorithm identifier.
 	awsAlgorithm = "AWS4-HMAC-SHA256"
 
-// The termination string for the AWS credential scope value as defined in
-// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
+	// The termination string for the AWS credential scope value as defined in
+	// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
 	awsRequestType = "aws4_request"
 
-// The AWS authorization header name for the security session token if available.
+	// The AWS authorization header name for the security session token if available.
 	awsSecurityTokenHeader = "x-amz-security-token"
 
-// The AWS authorization header name for the auto-generated date.
+	// The AWS authorization header name for the auto-generated date.
 	awsDateHeader = "x-amz-date"
 
-	awsTimeFormatLong = "20060102T150405Z"
+	awsTimeFormatLong  = "20060102T150405Z"
 	awsTimeFormatShort = "20060102"
 )
 
@@ -167,8 +180,8 @@
 
 	signedRequest.Header.Add("host", requestHost(req))
 
-	if securityToken, ok := rs.AwsSecurityCredentials["security_token"]; ok {
-		signedRequest.Header.Add(awsSecurityTokenHeader, securityToken)
+	if rs.AwsSecurityCredentials.SecurityToken != "" {
+		signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SecurityToken)
 	}
 
 	if signedRequest.Header.Get("date") == "" {
@@ -186,15 +199,6 @@
 }
 
 func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
-	secretAccessKey, ok := rs.AwsSecurityCredentials["secret_access_key"]
-	if !ok {
-		return "", errors.New("oauth2/google: missing secret_access_key header")
-	}
-	accessKeyId, ok := rs.AwsSecurityCredentials["access_key_id"]
-	if !ok {
-		return "", errors.New("oauth2/google: missing access_key_id header")
-	}
-
 	canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
 
 	dateStamp := timestamp.Format(awsTimeFormatShort)
@@ -203,28 +207,258 @@
 		serviceName = splitHost[0]
 	}
 
-	credentialScope := fmt.Sprintf("%s/%s/%s/%s",dateStamp, rs.RegionName, serviceName, awsRequestType)
+	credentialScope := fmt.Sprintf("%s/%s/%s/%s", dateStamp, rs.RegionName, serviceName, awsRequestType)
 
 	requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
 	if err != nil {
 		return "", err
 	}
 	requestHash, err := getSha256([]byte(requestString))
-	if err != nil{
+	if err != nil {
 		return "", err
 	}
 
 	stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash)
 
-	signingKey := []byte("AWS4" + secretAccessKey)
+	signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
 	for _, signingInput := range []string{
 		dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
 	} {
 		signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
-		if err != nil{
+		if err != nil {
 			return "", err
 		}
 	}
 
-	return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, accessKeyId, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
+	return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
+}
+
+type awsCredentialSource struct {
+	EnvironmentID               string
+	RegionURL                   string
+	RegionalCredVerificationURL string
+	CredVerificationURL         string
+	TargetResource              string
+	requestSigner               *awsRequestSigner
+	region                      string
+	ctx                         context.Context
+	client                      *http.Client
+}
+
+type awsRequestHeader struct {
+	Key   string `json:"key"`
+	Value string `json:"value"`
+}
+
+type awsRequest struct {
+	URL     string             `json:"url"`
+	Method  string             `json:"method"`
+	Headers []awsRequestHeader `json:"headers"`
+}
+
+func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
+	if cs.client == nil {
+		cs.client = oauth2.NewClient(cs.ctx, nil)
+	}
+	return cs.client.Do(req.WithContext(cs.ctx))
+}
+
+func (cs awsCredentialSource) subjectToken() (string, error) {
+	if cs.requestSigner == nil {
+		awsSecurityCredentials, err := cs.getSecurityCredentials()
+		if err != nil {
+			return "", err
+		}
+
+		if cs.region, err = cs.getRegion(); err != nil {
+			return "", err
+		}
+
+		cs.requestSigner = &awsRequestSigner{
+			RegionName:             cs.region,
+			AwsSecurityCredentials: awsSecurityCredentials,
+		}
+	}
+
+	// Generate the signed request to AWS STS GetCallerIdentity API.
+	// Use the required regional endpoint. Otherwise, the request will fail.
+	req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.region, 1), nil)
+	if err != nil {
+		return "", err
+	}
+	// The full, canonical resource name of the workload identity pool
+	// provider, with or without the HTTPS prefix.
+	// Including this header as part of the signature is recommended to
+	// ensure data integrity.
+	if cs.TargetResource != "" {
+		req.Header.Add("x-goog-cloud-target-resource", cs.TargetResource)
+	}
+	cs.requestSigner.SignRequest(req)
+
+	/*
+	   The GCP STS endpoint expects the headers to be formatted as:
+	   # [
+	   #   {key: 'x-amz-date', value: '...'},
+	   #   {key: 'Authorization', value: '...'},
+	   #   ...
+	   # ]
+	   # And then serialized as:
+	   # quote(json.dumps({
+	   #   url: '...',
+	   #   method: 'POST',
+	   #   headers: [{key: 'x-amz-date', value: '...'}, ...]
+	   # }))
+	*/
+
+	awsSignedReq := awsRequest{
+		URL:    req.URL.String(),
+		Method: "POST",
+	}
+	for headerKey, headerList := range req.Header {
+		for _, headerValue := range headerList {
+			awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
+				Key:   headerKey,
+				Value: headerValue,
+			})
+		}
+	}
+	sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
+		headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
+		if headerCompare == 0 {
+			return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
+		}
+		return headerCompare < 0
+	})
+
+	result, err := json.Marshal(awsSignedReq)
+	if err != nil {
+		return "", err
+	}
+	return string(result), nil
+}
+
+func (cs *awsCredentialSource) getRegion() (string, error) {
+	if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" {
+		return envAwsRegion, nil
+	}
+
+	if cs.RegionURL == "" {
+		return "", errors.New("oauth2/google: unable to determine AWS region")
+	}
+
+	req, err := http.NewRequest("GET", cs.RegionURL, nil)
+	if err != nil {
+		return "", err
+	}
+
+	resp, err := cs.doRequest(req)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
+	if err != nil {
+		return "", err
+	}
+
+	if resp.StatusCode != 200 {
+		return "", fmt.Errorf("oauth2/google: unable to retrieve AWS region - %s", string(respBody))
+	}
+
+	// This endpoint will return the region in format: us-east-2b.
+	// Only the us-east-2 part should be used.
+	respBodyEnd := 0
+	if len(respBody) > 1 {
+		respBodyEnd = len(respBody) - 1
+	}
+	return string(respBody[:respBodyEnd]), nil
+}
+
+func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCredentials, err error) {
+	if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" {
+		if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
+			return awsSecurityCredentials{
+				AccessKeyID:     accessKeyID,
+				SecretAccessKey: secretAccessKey,
+				SecurityToken:   getenv("AWS_SESSION_TOKEN"),
+			}, nil
+		}
+	}
+
+	roleName, err := cs.getMetadataRoleName()
+	if err != nil {
+		return
+	}
+
+	credentials, err := cs.getMetadataSecurityCredentials(roleName)
+	if err != nil {
+		return
+	}
+
+	if credentials.AccessKeyID == "" {
+		return result, errors.New("oauth2/google: missing AccessKeyId credential")
+	}
+
+	if credentials.SecretAccessKey == "" {
+		return result, errors.New("oauth2/google: missing SecretAccessKey credential")
+	}
+
+	return credentials, nil
+}
+
+func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (awsSecurityCredentials, error) {
+	var result awsSecurityCredentials
+
+	req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
+	if err != nil {
+		return result, err
+	}
+	req.Header.Add("Content-Type", "application/json")
+
+	resp, err := cs.doRequest(req)
+	if err != nil {
+		return result, err
+	}
+	defer resp.Body.Close()
+
+	respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
+	if err != nil {
+		return result, err
+	}
+
+	if resp.StatusCode != 200 {
+		return result, fmt.Errorf("oauth2/google: unable to retrieve AWS security credentials - %s", string(respBody))
+	}
+
+	err = json.Unmarshal(respBody, &result)
+	return result, err
+}
+
+func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
+	if cs.CredVerificationURL == "" {
+		return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint")
+	}
+
+	req, err := http.NewRequest("GET", cs.CredVerificationURL, nil)
+	if err != nil {
+		return "", err
+	}
+
+	resp, err := cs.doRequest(req)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
+	if err != nil {
+		return "", err
+	}
+
+	if resp.StatusCode != 200 {
+		return "", fmt.Errorf("oauth2/google: unable to retrieve AWS role name - %s", string(respBody))
+	}
+
+	return string(respBody), nil
 }
diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go
index 206c3a1..1a83a7b 100644
--- a/google/internal/externalaccount/aws_test.go
+++ b/google/internal/externalaccount/aws_test.go
@@ -5,7 +5,11 @@
 package externalaccount
 
 import (
+	"context"
+	"encoding/json"
+	"fmt"
 	"net/http"
+	"net/http/httptest"
 	"reflect"
 	"strings"
 	"testing"
@@ -21,24 +25,33 @@
 	}
 }
 
+func setEnvironment(env map[string]string) func(string) string {
+	return func(key string) string {
+		value, _ := env[key]
+		return value
+	}
+}
+
 var defaultRequestSigner = &awsRequestSigner{
 	RegionName: "us-east-1",
-	AwsSecurityCredentials: map[string]string{
-		"access_key_id":     "AKIDEXAMPLE",
-		"secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
+	AwsSecurityCredentials: awsSecurityCredentials{
+		AccessKeyID:     "AKIDEXAMPLE",
+		SecretAccessKey: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
 	},
 }
 
-const accessKeyId = "ASIARD4OQDT6A77FR3CL"
-const secretAccessKey = "Y8AfSaucF37G4PpvfguKZ3/l7Id4uocLXxX0+VTx"
-const securityToken = "IQoJb3JpZ2luX2VjEIz//////////wEaCXVzLWVhc3QtMiJGMEQCIH7MHX/Oy/OB8OlLQa9GrqU1B914+iMikqWQW7vPCKlgAiA/Lsv8Jcafn14owfxXn95FURZNKaaphj0ykpmS+Ki+CSq0AwhlEAAaDDA3NzA3MTM5MTk5NiIMx9sAeP1ovlMTMKLjKpEDwuJQg41/QUKx0laTZYjPlQvjwSqS3OB9P1KAXPWSLkliVMMqaHqelvMF/WO/glv3KwuTfQsavRNs3v5pcSEm4SPO3l7mCs7KrQUHwGP0neZhIKxEXy+Ls//1C/Bqt53NL+LSbaGv6RPHaX82laz2qElphg95aVLdYgIFY6JWV5fzyjgnhz0DQmy62/Vi8pNcM2/VnxeCQ8CC8dRDSt52ry2v+nc77vstuI9xV5k8mPtnaPoJDRANh0bjwY5Sdwkbp+mGRUJBAQRlNgHUJusefXQgVKBCiyJY4w3Csd8Bgj9IyDV+Azuy1jQqfFZWgP68LSz5bURyIjlWDQunO82stZ0BgplKKAa/KJHBPCp8Qi6i99uy7qh76FQAqgVTsnDuU6fGpHDcsDSGoCls2HgZjZFPeOj8mmRhFk1Xqvkbjuz8V1cJk54d3gIJvQt8gD2D6yJQZecnuGWd5K2e2HohvCc8Fc9kBl1300nUJPV+k4tr/A5R/0QfEKOZL1/k5lf1g9CREnrM8LVkGxCgdYMxLQow1uTL+QU67AHRRSp5PhhGX4Rek+01vdYSnJCMaPhSEgcLqDlQkhk6MPsyT91QMXcWmyO+cAZwUPwnRamFepuP4K8k2KVXs/LIJHLELwAZ0ekyaS7CptgOqS7uaSTFG3U+vzFZLEnGvWQ7y9IPNQZ+Dffgh4p3vF4J68y9049sI6Sr5d5wbKkcbm8hdCDHZcv4lnqohquPirLiFQ3q7B17V9krMPu3mz1cg4Ekgcrn/E09NTsxAqD8NcZ7C7ECom9r+X3zkDOxaajW6hu3Az8hGlyylDaMiFfRbBJpTIlxp7jfa7CxikNgNtEKLH9iCzvuSg2vhA=="
+const (
+	accessKeyID     = "ASIARD4OQDT6A77FR3CL"
+	secretAccessKey = "Y8AfSaucF37G4PpvfguKZ3/l7Id4uocLXxX0+VTx"
+	securityToken   = "IQoJb3JpZ2luX2VjEIz//////////wEaCXVzLWVhc3QtMiJGMEQCIH7MHX/Oy/OB8OlLQa9GrqU1B914+iMikqWQW7vPCKlgAiA/Lsv8Jcafn14owfxXn95FURZNKaaphj0ykpmS+Ki+CSq0AwhlEAAaDDA3NzA3MTM5MTk5NiIMx9sAeP1ovlMTMKLjKpEDwuJQg41/QUKx0laTZYjPlQvjwSqS3OB9P1KAXPWSLkliVMMqaHqelvMF/WO/glv3KwuTfQsavRNs3v5pcSEm4SPO3l7mCs7KrQUHwGP0neZhIKxEXy+Ls//1C/Bqt53NL+LSbaGv6RPHaX82laz2qElphg95aVLdYgIFY6JWV5fzyjgnhz0DQmy62/Vi8pNcM2/VnxeCQ8CC8dRDSt52ry2v+nc77vstuI9xV5k8mPtnaPoJDRANh0bjwY5Sdwkbp+mGRUJBAQRlNgHUJusefXQgVKBCiyJY4w3Csd8Bgj9IyDV+Azuy1jQqfFZWgP68LSz5bURyIjlWDQunO82stZ0BgplKKAa/KJHBPCp8Qi6i99uy7qh76FQAqgVTsnDuU6fGpHDcsDSGoCls2HgZjZFPeOj8mmRhFk1Xqvkbjuz8V1cJk54d3gIJvQt8gD2D6yJQZecnuGWd5K2e2HohvCc8Fc9kBl1300nUJPV+k4tr/A5R/0QfEKOZL1/k5lf1g9CREnrM8LVkGxCgdYMxLQow1uTL+QU67AHRRSp5PhhGX4Rek+01vdYSnJCMaPhSEgcLqDlQkhk6MPsyT91QMXcWmyO+cAZwUPwnRamFepuP4K8k2KVXs/LIJHLELwAZ0ekyaS7CptgOqS7uaSTFG3U+vzFZLEnGvWQ7y9IPNQZ+Dffgh4p3vF4J68y9049sI6Sr5d5wbKkcbm8hdCDHZcv4lnqohquPirLiFQ3q7B17V9krMPu3mz1cg4Ekgcrn/E09NTsxAqD8NcZ7C7ECom9r+X3zkDOxaajW6hu3Az8hGlyylDaMiFfRbBJpTIlxp7jfa7CxikNgNtEKLH9iCzvuSg2vhA=="
+)
 
 var requestSignerWithToken = &awsRequestSigner{
 	RegionName: "us-east-2",
-	AwsSecurityCredentials: map[string]string{
-		"access_key_id":     accessKeyId,
-		"secret_access_key": secretAccessKey,
-		"security_token":    securityToken,
+	AwsSecurityCredentials: awsSecurityCredentials{
+		AccessKeyID:     accessKeyID,
+		SecretAccessKey: secretAccessKey,
+		SecurityToken:   securityToken,
 	},
 }
 
@@ -317,7 +330,7 @@
 	output, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil)
 	output.Header = http.Header{
 		"Host":                 []string{"ec2.us-east-2.amazonaws.com"},
-		"Authorization":        []string{"AWS4-HMAC-SHA256 Credential=" + accessKeyId + "/20200811/us-east-2/ec2/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=631ea80cddfaa545fdadb120dc92c9f18166e38a5c47b50fab9fce476e022855"},
+		"Authorization":        []string{"AWS4-HMAC-SHA256 Credential=" + accessKeyID + "/20200811/us-east-2/ec2/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=631ea80cddfaa545fdadb120dc92c9f18166e38a5c47b50fab9fce476e022855"},
 		"X-Amz-Date":           []string{"20200811T065522Z"},
 		"X-Amz-Security-Token": []string{securityToken},
 	}
@@ -334,7 +347,7 @@
 
 	output, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil)
 	output.Header = http.Header{
-		"Authorization":        []string{"AWS4-HMAC-SHA256 Credential=" + accessKeyId + "/20200811/us-east-2/sts/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=73452984e4a880ffdc5c392355733ec3f5ba310d5e0609a89244440cadfe7a7a"},
+		"Authorization":        []string{"AWS4-HMAC-SHA256 Credential=" + accessKeyID + "/20200811/us-east-2/sts/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=73452984e4a880ffdc5c392355733ec3f5ba310d5e0609a89244440cadfe7a7a"},
 		"Host":                 []string{"sts.us-east-2.amazonaws.com"},
 		"X-Amz-Date":           []string{"20200811T065522Z"},
 		"X-Amz-Security-Token": []string{securityToken},
@@ -355,7 +368,7 @@
 
 	output, _ := http.NewRequest("POST", "https://dynamodb.us-east-2.amazonaws.com/", strings.NewReader(requestParams))
 	output.Header = http.Header{
-		"Authorization":        []string{"AWS4-HMAC-SHA256 Credential=" + accessKeyId + "/20200811/us-east-2/dynamodb/aws4_request, SignedHeaders=content-type;host;x-amz-date;x-amz-security-token;x-amz-target, Signature=fdaa5b9cc9c86b80fe61eaf504141c0b3523780349120f2bd8145448456e0385"},
+		"Authorization":        []string{"AWS4-HMAC-SHA256 Credential=" + accessKeyID + "/20200811/us-east-2/dynamodb/aws4_request, SignedHeaders=content-type;host;x-amz-date;x-amz-security-token;x-amz-target, Signature=fdaa5b9cc9c86b80fe61eaf504141c0b3523780349120f2bd8145448456e0385"},
 		"Host":                 []string{"dynamodb.us-east-2.amazonaws.com"},
 		"X-Amz-Date":           []string{"20200811T065522Z"},
 		"Content-Type":         []string{"application/x-amz-json-1.0"},
@@ -373,9 +386,9 @@
 func TestAwsV4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) {
 	var requestSigner = &awsRequestSigner{
 		RegionName: "us-east-2",
-		AwsSecurityCredentials: map[string]string{
-			"access_key_id":     accessKeyId,
-			"secret_access_key": secretAccessKey,
+		AwsSecurityCredentials: awsSecurityCredentials{
+			AccessKeyID:     accessKeyID,
+			SecretAccessKey: secretAccessKey,
 		},
 	}
 
@@ -383,7 +396,7 @@
 
 	output, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil)
 	output.Header = http.Header{
-		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=" + accessKeyId + "/20200811/us-east-2/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=d095ba304919cd0d5570ba8a3787884ee78b860f268ed040ba23831d55536d56"},
+		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=" + accessKeyID + "/20200811/us-east-2/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=d095ba304919cd0d5570ba8a3787884ee78b860f268ed040ba23831d55536d56"},
 		"Host":          []string{"sts.us-east-2.amazonaws.com"},
 		"X-Amz-Date":    []string{"20200811T065522Z"},
 	}
@@ -394,3 +407,446 @@
 
 	testRequestSigner(t, requestSigner, input, output)
 }
+
+type testAwsServer struct {
+	url                         string
+	securityCredentialURL       string
+	regionURL                   string
+	regionalCredVerificationURL string
+
+	Credentials map[string]string
+
+	WriteRolename            func(http.ResponseWriter)
+	WriteSecurityCredentials func(http.ResponseWriter)
+	WriteRegion              func(http.ResponseWriter)
+}
+
+func createAwsTestServer(url, regionURL, regionalCredVerificationURL, rolename, region string, credentials map[string]string) *testAwsServer {
+	server := &testAwsServer{
+		url:                         url,
+		securityCredentialURL:       fmt.Sprintf("%s/%s", url, rolename),
+		regionURL:                   regionURL,
+		regionalCredVerificationURL: regionalCredVerificationURL,
+		Credentials:                 credentials,
+		WriteRolename: func(w http.ResponseWriter) {
+			w.Write([]byte(rolename))
+		},
+		WriteRegion: func(w http.ResponseWriter) {
+			w.Write([]byte(region))
+		},
+	}
+
+	server.WriteSecurityCredentials = func(w http.ResponseWriter) {
+		jsonCredentials, _ := json.Marshal(server.Credentials)
+		w.Write(jsonCredentials)
+	}
+
+	return server
+}
+
+func createDefaultAwsTestServer() *testAwsServer {
+	return createAwsTestServer(
+		"/latest/meta-data/iam/security-credentials",
+		"/latest/meta-data/placement/availability-zone",
+		"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
+		"gcp-aws-role",
+		"us-east-2b",
+		map[string]string{
+			"SecretAccessKey": secretAccessKey,
+			"AccessKeyId":     accessKeyID,
+			"Token":           securityToken,
+		},
+	)
+}
+
+func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	switch p := r.URL.Path; p {
+	case server.url:
+		server.WriteRolename(w)
+	case server.securityCredentialURL:
+		server.WriteSecurityCredentials(w)
+	case server.regionURL:
+		server.WriteRegion(w)
+	}
+}
+
+func notFound(w http.ResponseWriter) {
+	w.WriteHeader(404)
+	w.Write([]byte("Not Found"))
+}
+
+func (server *testAwsServer) getCredentialSource(url string) CredentialSource {
+	return CredentialSource{
+		EnvironmentID:               "aws1",
+		URL:                         url + server.url,
+		RegionURL:                   url + server.regionURL,
+		RegionalCredVerificationURL: server.regionalCredVerificationURL,
+	}
+}
+
+func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, securityToken string) string {
+	req, _ := http.NewRequest("POST", url, nil)
+	req.Header.Add("x-goog-cloud-target-resource", testFileConfig.Audience)
+	signer := &awsRequestSigner{
+		RegionName: region,
+		AwsSecurityCredentials: awsSecurityCredentials{
+			AccessKeyID:     accessKeyID,
+			SecretAccessKey: secretAccessKey,
+			SecurityToken:   securityToken,
+		},
+	}
+	signer.SignRequest(req)
+
+	result := awsRequest{
+		URL:    url,
+		Method: "POST",
+		Headers: []awsRequestHeader{
+			{
+				Key:   "Authorization",
+				Value: req.Header.Get("Authorization"),
+			}, {
+				Key:   "Host",
+				Value: req.Header.Get("Host"),
+			}, {
+				Key:   "X-Amz-Date",
+				Value: req.Header.Get("X-Amz-Date"),
+			},
+		},
+	}
+
+	if securityToken != "" {
+		result.Headers = append(result.Headers, awsRequestHeader{
+			Key:   "X-Amz-Security-Token",
+			Value: securityToken,
+		})
+	}
+
+	result.Headers = append(result.Headers, awsRequestHeader{
+		Key:   "X-Goog-Cloud-Target-Resource",
+		Value: testFileConfig.Audience,
+	})
+
+	str, _ := json.Marshal(result)
+	return string(str)
+}
+
+func TestAwsCredential_BasicRequest(t *testing.T) {
+	server := createDefaultAwsTestServer()
+	ts := httptest.NewServer(server)
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+
+	oldGetenv := getenv
+	defer func() { getenv = oldGetenv }()
+	getenv = setEnvironment(map[string]string{})
+
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	out, err := base.subjectToken()
+	if err != nil {
+		t.Fatalf("retrieveSubjectToken() failed: %v", err)
+	}
+
+	expected := getExpectedSubjectToken(
+		"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
+		"us-east-2",
+		accessKeyID,
+		secretAccessKey,
+		securityToken,
+	)
+
+	if got, want := out, expected; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = %q, want %q", got, want)
+	}
+}
+
+func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
+	server := createDefaultAwsTestServer()
+	ts := httptest.NewServer(server)
+	delete(server.Credentials, "Token")
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+
+	oldGetenv := getenv
+	defer func() { getenv = oldGetenv }()
+	getenv = setEnvironment(map[string]string{})
+
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	out, err := base.subjectToken()
+	if err != nil {
+		t.Fatalf("retrieveSubjectToken() failed: %v", err)
+	}
+
+	expected := getExpectedSubjectToken(
+		"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
+		"us-east-2",
+		accessKeyID,
+		secretAccessKey,
+		"",
+	)
+
+	if got, want := out, expected; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = %q, want %q", got, want)
+	}
+}
+
+func TestAwsCredential_BasicRequestWithEnv(t *testing.T) {
+	server := createDefaultAwsTestServer()
+	ts := httptest.NewServer(server)
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+
+	oldGetenv := getenv
+	defer func() { getenv = oldGetenv }()
+	getenv = setEnvironment(map[string]string{
+		"AWS_ACCESS_KEY_ID":     "AKIDEXAMPLE",
+		"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
+		"AWS_REGION":            "us-west-1",
+	})
+
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	out, err := base.subjectToken()
+	if err != nil {
+		t.Fatalf("retrieveSubjectToken() failed: %v", err)
+	}
+
+	expected := getExpectedSubjectToken(
+		"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
+		"us-west-1",
+		"AKIDEXAMPLE",
+		"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
+		"",
+	)
+
+	if got, want := out, expected; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = %q, want %q", got, want)
+	}
+}
+
+func TestAwsCredential_RequestWithBadVersion(t *testing.T) {
+	server := createDefaultAwsTestServer()
+	ts := httptest.NewServer(server)
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+	tfc.CredentialSource.EnvironmentID = "aws3"
+
+	oldGetenv := getenv
+	defer func() { getenv = oldGetenv }()
+	getenv = setEnvironment(map[string]string{})
+
+	_, err := tfc.parse(context.Background())
+	if err == nil {
+		t.Fatalf("parse() should have failed")
+	}
+	if got, want := err.Error(), "oauth2/google: aws version '3' is not supported in the current build"; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = %q, want %q", got, want)
+	}
+}
+
+func TestAwsCredential_RequestWithNoRegionURL(t *testing.T) {
+	server := createDefaultAwsTestServer()
+	ts := httptest.NewServer(server)
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+	tfc.CredentialSource.RegionURL = ""
+
+	oldGetenv := getenv
+	defer func() { getenv = oldGetenv }()
+	getenv = setEnvironment(map[string]string{})
+
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	_, err = base.subjectToken()
+	if err == nil {
+		t.Fatalf("retrieveSubjectToken() should have failed")
+	}
+
+	if got, want := err.Error(), "oauth2/google: unable to determine AWS region"; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = %q, want %q", got, want)
+	}
+}
+
+func TestAwsCredential_RequestWithBadRegionURL(t *testing.T) {
+	server := createDefaultAwsTestServer()
+	ts := httptest.NewServer(server)
+	server.WriteRegion = notFound
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+
+	oldGetenv := getenv
+	defer func() { getenv = oldGetenv }()
+	getenv = setEnvironment(map[string]string{})
+
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	_, err = base.subjectToken()
+	if err == nil {
+		t.Fatalf("retrieveSubjectToken() should have failed")
+	}
+
+	if got, want := err.Error(), "oauth2/google: unable to retrieve AWS region - Not Found"; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = %q, want %q", got, want)
+	}
+}
+
+func TestAwsCredential_RequestWithMissingCredential(t *testing.T) {
+	server := createDefaultAwsTestServer()
+	ts := httptest.NewServer(server)
+	server.WriteSecurityCredentials = func(w http.ResponseWriter) {
+		w.Write([]byte("{}"))
+	}
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+
+	oldGetenv := getenv
+	defer func() { getenv = oldGetenv }()
+	getenv = setEnvironment(map[string]string{})
+
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	_, err = base.subjectToken()
+	if err == nil {
+		t.Fatalf("retrieveSubjectToken() should have failed")
+	}
+
+	if got, want := err.Error(), "oauth2/google: missing AccessKeyId credential"; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = %q, want %q", got, want)
+	}
+}
+
+func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) {
+	server := createDefaultAwsTestServer()
+	ts := httptest.NewServer(server)
+	server.WriteSecurityCredentials = func(w http.ResponseWriter) {
+		w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
+	}
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+
+	oldGetenv := getenv
+	defer func() { getenv = oldGetenv }()
+	getenv = setEnvironment(map[string]string{})
+
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	_, err = base.subjectToken()
+	if err == nil {
+		t.Fatalf("retrieveSubjectToken() should have failed")
+	}
+
+	if got, want := err.Error(), "oauth2/google: missing SecretAccessKey credential"; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = %q, want %q", got, want)
+	}
+}
+
+func TestAwsCredential_RequestWithNoCredentialURL(t *testing.T) {
+	server := createDefaultAwsTestServer()
+	ts := httptest.NewServer(server)
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+	tfc.CredentialSource.URL = ""
+
+	oldGetenv := getenv
+	defer func() { getenv = oldGetenv }()
+	getenv = setEnvironment(map[string]string{})
+
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	_, err = base.subjectToken()
+	if err == nil {
+		t.Fatalf("retrieveSubjectToken() should have failed")
+	}
+
+	if got, want := err.Error(), "oauth2/google: unable to determine the AWS metadata server security credentials endpoint"; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = %q, want %q", got, want)
+	}
+}
+
+func TestAwsCredential_RequestWithBadCredentialURL(t *testing.T) {
+	server := createDefaultAwsTestServer()
+	ts := httptest.NewServer(server)
+	server.WriteRolename = notFound
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+
+	oldGetenv := getenv
+	defer func() { getenv = oldGetenv }()
+	getenv = setEnvironment(map[string]string{})
+
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	_, err = base.subjectToken()
+	if err == nil {
+		t.Fatalf("retrieveSubjectToken() should have failed")
+	}
+
+	if got, want := err.Error(), "oauth2/google: unable to retrieve AWS role name - Not Found"; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = %q, want %q", got, want)
+	}
+}
+
+func TestAwsCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
+	server := createDefaultAwsTestServer()
+	ts := httptest.NewServer(server)
+	server.WriteSecurityCredentials = notFound
+
+	tfc := testFileConfig
+	tfc.CredentialSource = server.getCredentialSource(ts.URL)
+
+	oldGetenv := getenv
+	defer func() { getenv = oldGetenv }()
+	getenv = setEnvironment(map[string]string{})
+
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	_, err = base.subjectToken()
+	if err == nil {
+		t.Fatalf("retrieveSubjectToken() should have failed")
+	}
+
+	if got, want := err.Error(), "oauth2/google: unable to retrieve AWS security credentials - Not Found"; !reflect.DeepEqual(got, want) {
+		t.Errorf("subjectToken = %q, want %q", got, want)
+	}
+}
diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go
index deb9deb..57a5870 100644
--- a/google/internal/externalaccount/basecredentials.go
+++ b/google/internal/externalaccount/basecredentials.go
@@ -9,6 +9,7 @@
 	"fmt"
 	"golang.org/x/oauth2"
 	"net/http"
+	"strconv"
 	"time"
 )
 
@@ -77,13 +78,27 @@
 }
 
 // parse determines the type of CredentialSource needed
-func (c *Config) parse(ctx context.Context) baseCredentialSource {
-	if c.CredentialSource.File != "" {
-		return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}
+func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
+	if len(c.CredentialSource.EnvironmentID) > 3 && c.CredentialSource.EnvironmentID[:3] == "aws" {
+		if awsVersion, err := strconv.Atoi(c.CredentialSource.EnvironmentID[3:]); err == nil {
+			if awsVersion != 1 {
+				return nil, fmt.Errorf("oauth2/google: aws version '%d' is not supported in the current build", awsVersion)
+			}
+			return awsCredentialSource{
+				EnvironmentID:               c.CredentialSource.EnvironmentID,
+				RegionURL:                   c.CredentialSource.RegionURL,
+				RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL,
+				CredVerificationURL:         c.CredentialSource.URL,
+				TargetResource:              c.Audience,
+				ctx:                         ctx,
+			}, nil
+		}
+	} else if c.CredentialSource.File != "" {
+		return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil
 	} else if c.CredentialSource.URL != "" {
-		return urlCredentialSource{URL: c.CredentialSource.URL, Format: c.CredentialSource.Format, ctx: ctx}
+		return urlCredentialSource{URL: c.CredentialSource.URL, Format: c.CredentialSource.Format, ctx: ctx}, nil
 	}
-	return nil
+	return nil, fmt.Errorf("oauth2/google: unable to parse credential source")
 }
 
 type baseCredentialSource interface {
@@ -100,11 +115,12 @@
 func (ts tokenSource) Token() (*oauth2.Token, error) {
 	conf := ts.conf
 
-	credSource := conf.parse(ts.ctx)
-	if credSource == nil {
-		return nil, fmt.Errorf("oauth2/google: unable to parse credential source")
+	credSource, err := conf.parse(ts.ctx)
+	if err != nil {
+		return nil, err
 	}
 	subjectToken, err := credSource.subjectToken()
+
 	if err != nil {
 		return nil, err
 	}
diff --git a/google/internal/externalaccount/filecredsource_test.go b/google/internal/externalaccount/filecredsource_test.go
index 56dd71e..ebd2bb7 100644
--- a/google/internal/externalaccount/filecredsource_test.go
+++ b/google/internal/externalaccount/filecredsource_test.go
@@ -56,7 +56,12 @@
 		tfc.CredentialSource = test.cs
 
 		t.Run(test.name, func(t *testing.T) {
-			out, err := tfc.parse(context.Background()).subjectToken()
+			base, err := tfc.parse(context.Background())
+			if err != nil {
+				t.Fatalf("parse() failed %v", err)
+			}
+
+			out, err := base.subjectToken()
 			if err != nil {
 				t.Errorf("Method subjectToken() errored.")
 			} else if test.want != out {
diff --git a/google/internal/externalaccount/urlcredsource_test.go b/google/internal/externalaccount/urlcredsource_test.go
index 592610f..1b78e68 100644
--- a/google/internal/externalaccount/urlcredsource_test.go
+++ b/google/internal/externalaccount/urlcredsource_test.go
@@ -28,7 +28,12 @@
 	tfc := testFileConfig
 	tfc.CredentialSource = cs
 
-	out, err := tfc.parse(context.Background()).subjectToken()
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	out, err := base.subjectToken()
 	if err != nil {
 		t.Fatalf("retrieveSubjectToken() failed: %v", err)
 	}
@@ -51,7 +56,12 @@
 	tfc := testFileConfig
 	tfc.CredentialSource = cs
 
-	out, err := tfc.parse(context.Background()).subjectToken()
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	out, err := base.subjectToken()
 	if err != nil {
 		t.Fatalf("Failed to retrieve URL subject token: %v", err)
 	}
@@ -82,7 +92,12 @@
 	tfc := testFileConfig
 	tfc.CredentialSource = cs
 
-	out, err := tfc.parse(context.Background()).subjectToken()
+	base, err := tfc.parse(context.Background())
+	if err != nil {
+		t.Fatalf("parse() failed %v", err)
+	}
+
+	out, err := base.subjectToken()
 	if err != nil {
 		t.Fatalf("%v", err)
 	}