| // Copyright 2021 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package externalaccount |
| |
| import ( |
| "bytes" |
| "context" |
| "crypto/hmac" |
| "crypto/sha256" |
| "encoding/hex" |
| "encoding/json" |
| "errors" |
| "fmt" |
| "io" |
| "io/ioutil" |
| "net/http" |
| "net/url" |
| "os" |
| "path" |
| "sort" |
| "strings" |
| "time" |
| |
| "golang.org/x/oauth2" |
| ) |
| |
| type awsSecurityCredentials struct { |
| AccessKeyID string `json:"AccessKeyID"` |
| SecretAccessKey string `json:"SecretAccessKey"` |
| SecurityToken string `json:"Token"` |
| } |
| |
| // awsRequestSigner is a utility class to sign http requests using a AWS V4 signature. |
| type awsRequestSigner struct { |
| RegionName string |
| AwsSecurityCredentials awsSecurityCredentials |
| } |
| |
| // getenv aliases os.Getenv for testing |
| var getenv = os.Getenv |
| |
| const ( |
| // AWS Signature Version 4 signing algorithm identifier. |
| awsAlgorithm = "AWS4-HMAC-SHA256" |
| |
| // The termination string for the AWS credential scope value as defined in |
| // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html |
| awsRequestType = "aws4_request" |
| |
| // The AWS authorization header name for the security session token if available. |
| awsSecurityTokenHeader = "x-amz-security-token" |
| |
| // The name of the header containing the session token for metadata endpoint calls |
| awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token" |
| |
| awsIMDSv2SessionTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds" |
| |
| awsIMDSv2SessionTtl = "300" |
| |
| // The AWS authorization header name for the auto-generated date. |
| awsDateHeader = "x-amz-date" |
| |
| awsTimeFormatLong = "20060102T150405Z" |
| awsTimeFormatShort = "20060102" |
| ) |
| |
| func getSha256(input []byte) (string, error) { |
| hash := sha256.New() |
| if _, err := hash.Write(input); err != nil { |
| return "", err |
| } |
| return hex.EncodeToString(hash.Sum(nil)), nil |
| } |
| |
| func getHmacSha256(key, input []byte) ([]byte, error) { |
| hash := hmac.New(sha256.New, key) |
| if _, err := hash.Write(input); err != nil { |
| return nil, err |
| } |
| return hash.Sum(nil), nil |
| } |
| |
| func cloneRequest(r *http.Request) *http.Request { |
| r2 := new(http.Request) |
| *r2 = *r |
| if r.Header != nil { |
| r2.Header = make(http.Header, len(r.Header)) |
| |
| // Find total number of values. |
| headerCount := 0 |
| for _, headerValues := range r.Header { |
| headerCount += len(headerValues) |
| } |
| copiedHeaders := make([]string, headerCount) // shared backing array for headers' values |
| |
| for headerKey, headerValues := range r.Header { |
| headerCount = copy(copiedHeaders, headerValues) |
| r2.Header[headerKey] = copiedHeaders[:headerCount:headerCount] |
| copiedHeaders = copiedHeaders[headerCount:] |
| } |
| } |
| return r2 |
| } |
| |
| func canonicalPath(req *http.Request) string { |
| result := req.URL.EscapedPath() |
| if result == "" { |
| return "/" |
| } |
| return path.Clean(result) |
| } |
| |
| func canonicalQuery(req *http.Request) string { |
| queryValues := req.URL.Query() |
| for queryKey := range queryValues { |
| sort.Strings(queryValues[queryKey]) |
| } |
| return queryValues.Encode() |
| } |
| |
| func canonicalHeaders(req *http.Request) (string, string) { |
| // Header keys need to be sorted alphabetically. |
| var headers []string |
| lowerCaseHeaders := make(http.Header) |
| for k, v := range req.Header { |
| k := strings.ToLower(k) |
| if _, ok := lowerCaseHeaders[k]; ok { |
| // include additional values |
| lowerCaseHeaders[k] = append(lowerCaseHeaders[k], v...) |
| } else { |
| headers = append(headers, k) |
| lowerCaseHeaders[k] = v |
| } |
| } |
| sort.Strings(headers) |
| |
| var fullHeaders bytes.Buffer |
| for _, header := range headers { |
| headerValue := strings.Join(lowerCaseHeaders[header], ",") |
| fullHeaders.WriteString(header) |
| fullHeaders.WriteRune(':') |
| fullHeaders.WriteString(headerValue) |
| fullHeaders.WriteRune('\n') |
| } |
| |
| return strings.Join(headers, ";"), fullHeaders.String() |
| } |
| |
| func requestDataHash(req *http.Request) (string, error) { |
| var requestData []byte |
| if req.Body != nil { |
| requestBody, err := req.GetBody() |
| if err != nil { |
| return "", err |
| } |
| defer requestBody.Close() |
| |
| requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20)) |
| if err != nil { |
| return "", err |
| } |
| } |
| |
| return getSha256(requestData) |
| } |
| |
| func requestHost(req *http.Request) string { |
| if req.Host != "" { |
| return req.Host |
| } |
| return req.URL.Host |
| } |
| |
| func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) { |
| dataHash, err := requestDataHash(req) |
| if err != nil { |
| return "", err |
| } |
| |
| return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil |
| } |
| |
| // SignRequest adds the appropriate headers to an http.Request |
| // or returns an error if something prevented this. |
| func (rs *awsRequestSigner) SignRequest(req *http.Request) error { |
| signedRequest := cloneRequest(req) |
| timestamp := now() |
| |
| signedRequest.Header.Add("host", requestHost(req)) |
| |
| if rs.AwsSecurityCredentials.SecurityToken != "" { |
| signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SecurityToken) |
| } |
| |
| if signedRequest.Header.Get("date") == "" { |
| signedRequest.Header.Add(awsDateHeader, timestamp.Format(awsTimeFormatLong)) |
| } |
| |
| authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp) |
| if err != nil { |
| return err |
| } |
| signedRequest.Header.Set("Authorization", authorizationCode) |
| |
| req.Header = signedRequest.Header |
| return nil |
| } |
| |
| func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) { |
| canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req) |
| |
| dateStamp := timestamp.Format(awsTimeFormatShort) |
| serviceName := "" |
| if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 { |
| serviceName = splitHost[0] |
| } |
| |
| credentialScope := fmt.Sprintf("%s/%s/%s/%s", dateStamp, rs.RegionName, serviceName, awsRequestType) |
| |
| requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData) |
| if err != nil { |
| return "", err |
| } |
| requestHash, err := getSha256([]byte(requestString)) |
| if err != nil { |
| return "", err |
| } |
| |
| stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash) |
| |
| signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey) |
| for _, signingInput := range []string{ |
| dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign, |
| } { |
| signingKey, err = getHmacSha256(signingKey, []byte(signingInput)) |
| if err != nil { |
| return "", err |
| } |
| } |
| |
| return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil |
| } |
| |
| type awsCredentialSource struct { |
| EnvironmentID string |
| RegionURL string |
| RegionalCredVerificationURL string |
| CredVerificationURL string |
| IMDSv2SessionTokenURL string |
| TargetResource string |
| requestSigner *awsRequestSigner |
| region string |
| ctx context.Context |
| client *http.Client |
| } |
| |
| type awsRequestHeader struct { |
| Key string `json:"key"` |
| Value string `json:"value"` |
| } |
| |
| type awsRequest struct { |
| URL string `json:"url"` |
| Method string `json:"method"` |
| Headers []awsRequestHeader `json:"headers"` |
| } |
| |
| func (cs awsCredentialSource) validateMetadataServers() error { |
| if err := cs.validateMetadataServer(cs.RegionURL, "region_url"); err != nil { |
| return err |
| } |
| if err := cs.validateMetadataServer(cs.CredVerificationURL, "url"); err != nil { |
| return err |
| } |
| return cs.validateMetadataServer(cs.IMDSv2SessionTokenURL, "imdsv2_session_token_url") |
| } |
| |
| var validHostnames []string = []string{"169.254.169.254", "fd00:ec2::254"} |
| |
| func (cs awsCredentialSource) isValidMetadataServer(metadataUrl string) bool { |
| if metadataUrl == "" { |
| // Zero value means use default, which is valid. |
| return true |
| } |
| |
| u, err := url.Parse(metadataUrl) |
| if err != nil { |
| // Unparseable URL means invalid |
| return false |
| } |
| |
| for _, validHostname := range validHostnames { |
| if u.Hostname() == validHostname { |
| // If it's one of the valid hostnames, everything is good |
| return true |
| } |
| } |
| |
| // hostname not found in our allowlist, so not valid |
| return false |
| } |
| |
| func (cs awsCredentialSource) validateMetadataServer(metadataUrl, urlName string) error { |
| if !cs.isValidMetadataServer(metadataUrl) { |
| return fmt.Errorf("oauth2/google: invalid hostname %s for %s", metadataUrl, urlName) |
| } |
| |
| return nil |
| } |
| |
| func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) { |
| if cs.client == nil { |
| cs.client = oauth2.NewClient(cs.ctx, nil) |
| } |
| return cs.client.Do(req.WithContext(cs.ctx)) |
| } |
| |
| func (cs awsCredentialSource) subjectToken() (string, error) { |
| if cs.requestSigner == nil { |
| awsSessionToken, err := cs.getAWSSessionToken() |
| if err != nil { |
| return "", err |
| } |
| |
| headers := make(map[string]string) |
| if awsSessionToken != "" { |
| headers[awsIMDSv2SessionTokenHeader] = awsSessionToken |
| } |
| |
| awsSecurityCredentials, err := cs.getSecurityCredentials(headers) |
| if err != nil { |
| return "", err |
| } |
| |
| if cs.region, err = cs.getRegion(headers); err != nil { |
| return "", err |
| } |
| |
| cs.requestSigner = &awsRequestSigner{ |
| RegionName: cs.region, |
| AwsSecurityCredentials: awsSecurityCredentials, |
| } |
| } |
| |
| // Generate the signed request to AWS STS GetCallerIdentity API. |
| // Use the required regional endpoint. Otherwise, the request will fail. |
| req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.region, 1), nil) |
| if err != nil { |
| return "", err |
| } |
| // The full, canonical resource name of the workload identity pool |
| // provider, with or without the HTTPS prefix. |
| // Including this header as part of the signature is recommended to |
| // ensure data integrity. |
| if cs.TargetResource != "" { |
| req.Header.Add("x-goog-cloud-target-resource", cs.TargetResource) |
| } |
| cs.requestSigner.SignRequest(req) |
| |
| /* |
| The GCP STS endpoint expects the headers to be formatted as: |
| # [ |
| # {key: 'x-amz-date', value: '...'}, |
| # {key: 'Authorization', value: '...'}, |
| # ... |
| # ] |
| # And then serialized as: |
| # quote(json.dumps({ |
| # url: '...', |
| # method: 'POST', |
| # headers: [{key: 'x-amz-date', value: '...'}, ...] |
| # })) |
| */ |
| |
| awsSignedReq := awsRequest{ |
| URL: req.URL.String(), |
| Method: "POST", |
| } |
| for headerKey, headerList := range req.Header { |
| for _, headerValue := range headerList { |
| awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{ |
| Key: headerKey, |
| Value: headerValue, |
| }) |
| } |
| } |
| sort.Slice(awsSignedReq.Headers, func(i, j int) bool { |
| headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key) |
| if headerCompare == 0 { |
| return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0 |
| } |
| return headerCompare < 0 |
| }) |
| |
| result, err := json.Marshal(awsSignedReq) |
| if err != nil { |
| return "", err |
| } |
| return url.QueryEscape(string(result)), nil |
| } |
| |
| func (cs *awsCredentialSource) getAWSSessionToken() (string, error) { |
| if cs.IMDSv2SessionTokenURL == "" { |
| return "", nil |
| } |
| |
| req, err := http.NewRequest("PUT", cs.IMDSv2SessionTokenURL, nil) |
| if err != nil { |
| return "", err |
| } |
| |
| req.Header.Add(awsIMDSv2SessionTtlHeader, awsIMDSv2SessionTtl) |
| |
| resp, err := cs.doRequest(req) |
| if err != nil { |
| return "", err |
| } |
| defer resp.Body.Close() |
| |
| respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) |
| if err != nil { |
| return "", err |
| } |
| |
| if resp.StatusCode != 200 { |
| return "", fmt.Errorf("oauth2/google: unable to retrieve AWS session token - %s", string(respBody)) |
| } |
| |
| return string(respBody), nil |
| } |
| |
| func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) { |
| if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" { |
| return envAwsRegion, nil |
| } |
| if envAwsRegion := getenv("AWS_DEFAULT_REGION"); envAwsRegion != "" { |
| return envAwsRegion, nil |
| } |
| |
| if cs.RegionURL == "" { |
| return "", errors.New("oauth2/google: unable to determine AWS region") |
| } |
| |
| req, err := http.NewRequest("GET", cs.RegionURL, nil) |
| if err != nil { |
| return "", err |
| } |
| |
| for name, value := range headers { |
| req.Header.Add(name, value) |
| } |
| |
| resp, err := cs.doRequest(req) |
| if err != nil { |
| return "", err |
| } |
| defer resp.Body.Close() |
| |
| respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) |
| if err != nil { |
| return "", err |
| } |
| |
| if resp.StatusCode != 200 { |
| return "", fmt.Errorf("oauth2/google: unable to retrieve AWS region - %s", string(respBody)) |
| } |
| |
| // This endpoint will return the region in format: us-east-2b. |
| // Only the us-east-2 part should be used. |
| respBodyEnd := 0 |
| if len(respBody) > 1 { |
| respBodyEnd = len(respBody) - 1 |
| } |
| return string(respBody[:respBodyEnd]), nil |
| } |
| |
| func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) { |
| if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" { |
| if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" { |
| return awsSecurityCredentials{ |
| AccessKeyID: accessKeyID, |
| SecretAccessKey: secretAccessKey, |
| SecurityToken: getenv("AWS_SESSION_TOKEN"), |
| }, nil |
| } |
| } |
| |
| roleName, err := cs.getMetadataRoleName(headers) |
| if err != nil { |
| return |
| } |
| |
| credentials, err := cs.getMetadataSecurityCredentials(roleName, headers) |
| if err != nil { |
| return |
| } |
| |
| if credentials.AccessKeyID == "" { |
| return result, errors.New("oauth2/google: missing AccessKeyId credential") |
| } |
| |
| if credentials.SecretAccessKey == "" { |
| return result, errors.New("oauth2/google: missing SecretAccessKey credential") |
| } |
| |
| return credentials, nil |
| } |
| |
| func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (awsSecurityCredentials, error) { |
| var result awsSecurityCredentials |
| |
| req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil) |
| if err != nil { |
| return result, err |
| } |
| req.Header.Add("Content-Type", "application/json") |
| |
| for name, value := range headers { |
| req.Header.Add(name, value) |
| } |
| |
| resp, err := cs.doRequest(req) |
| if err != nil { |
| return result, err |
| } |
| defer resp.Body.Close() |
| |
| respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) |
| if err != nil { |
| return result, err |
| } |
| |
| if resp.StatusCode != 200 { |
| return result, fmt.Errorf("oauth2/google: unable to retrieve AWS security credentials - %s", string(respBody)) |
| } |
| |
| err = json.Unmarshal(respBody, &result) |
| return result, err |
| } |
| |
| func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) { |
| if cs.CredVerificationURL == "" { |
| return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint") |
| } |
| |
| req, err := http.NewRequest("GET", cs.CredVerificationURL, nil) |
| if err != nil { |
| return "", err |
| } |
| |
| for name, value := range headers { |
| req.Header.Add(name, value) |
| } |
| |
| resp, err := cs.doRequest(req) |
| if err != nil { |
| return "", err |
| } |
| defer resp.Body.Close() |
| |
| respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) |
| if err != nil { |
| return "", err |
| } |
| |
| if resp.StatusCode != 200 { |
| return "", fmt.Errorf("oauth2/google: unable to retrieve AWS role name - %s", string(respBody)) |
| } |
| |
| return string(respBody), nil |
| } |