| // Copyright 2021 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package externalaccount |
| |
| import ( |
| "crypto/hmac" |
| "crypto/sha256" |
| "encoding/hex" |
| "errors" |
| "fmt" |
| "io" |
| "io/ioutil" |
| "net/http" |
| "path" |
| "sort" |
| "strings" |
| "time" |
| ) |
| |
// awsRequestSigner signs http requests using an AWS V4 signature.
type awsRequestSigner struct {
	// RegionName is the region component placed in the credential scope and
	// fed into the derivation of the signing key.
	RegionName string
	// AwsSecurityCredentials holds the credentials used for signing. The keys
	// read by this file are "access_key_id", "secret_access_key" and,
	// optionally, "security_token".
	AwsSecurityCredentials map[string]string
}
| |
const (
	// AWS Signature Version 4 signing algorithm identifier.
	awsAlgorithm = "AWS4-HMAC-SHA256"

	// The termination string for the AWS credential scope value as defined in
	// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
	awsRequestType = "aws4_request"

	// The AWS authorization header name for the security session token if available.
	awsSecurityTokenHeader = "x-amz-security-token"

	// The AWS authorization header name for the auto-generated date.
	awsDateHeader = "x-amz-date"

	// Go reference-time layout for the full timestamp written to the
	// x-amz-date header and embedded in the string to sign.
	awsTimeFormatLong = "20060102T150405Z"
	// Go reference-time layout for the date stamp used in the credential
	// scope and as the first input of the signing-key derivation.
	awsTimeFormatShort = "20060102"
)
| |
// getSha256 returns the lowercase hexadecimal encoding of the SHA-256 digest
// of input. A nil or empty input yields the digest of the empty message.
// The error result is kept for signature compatibility and is always nil.
func getSha256(input []byte) (string, error) {
	digest := sha256.Sum256(input)
	return hex.EncodeToString(digest[:]), nil
}
| |
// getHmacSha256 computes the HMAC-SHA256 of input keyed with key and returns
// the raw (not hex-encoded) 32-byte MAC. Any error reported while feeding
// input to the MAC is returned with a nil slice.
func getHmacSha256(key, input []byte) ([]byte, error) {
	mac := hmac.New(sha256.New, key)
	_, err := mac.Write(input)
	if err != nil {
		return nil, err
	}
	digest := mac.Sum(nil)
	return digest, nil
}
| |
// cloneRequest returns a shallow copy of r whose Header map and header value
// slices are freshly allocated, so headers can be added to or changed on the
// copy without affecting r. All other fields (Body, URL, ...) are shared with
// r. A nil Header stays nil in the copy.
func cloneRequest(r *http.Request) *http.Request {
	clone := *r
	if r.Header == nil {
		return &clone
	}

	// Count every header value so one allocation can back all value slices.
	total := 0
	for _, values := range r.Header {
		total += len(values)
	}
	pool := make([]string, total)

	header := make(http.Header, len(r.Header))
	offset := 0
	for name, values := range r.Header {
		end := offset + copy(pool[offset:], values)
		// Cap each slice at its own length so an append on one header
		// reallocates instead of overwriting its neighbour in the pool.
		header[name] = pool[offset:end:end]
		offset = end
	}
	clone.Header = header
	return &clone
}
| |
// canonicalPath returns the path component used in the canonical request.
//
// The escaped URL path is normalized with path.Clean (duplicate slashes and
// "." / ".." segments are removed). An empty path becomes "/". A trailing
// slash present in the request path is preserved: path.Clean strips it, but
// it is part of the resource identifier that is sent on the wire, so dropping
// it would make the signed path differ from the requested one.
func canonicalPath(req *http.Request) string {
	escaped := req.URL.EscapedPath()
	if escaped == "" {
		return "/"
	}
	cleaned := path.Clean(escaped)
	if strings.HasSuffix(escaped, "/") && !strings.HasSuffix(cleaned, "/") {
		cleaned += "/"
	}
	return cleaned
}
| |
// canonicalQuery returns the canonical query string for req: parameters are
// sorted by name, repeated parameters are sorted by value, and names and
// values are percent-encoded.
//
// url.Values.Encode follows form encoding and writes a space as "+", whereas
// Signature Version 4 requires "%20". A literal "+" in a name or value is
// already emitted as "%2B" by Encode, so every "+" left in the encoded string
// stands for a space and can be rewritten safely.
func canonicalQuery(req *http.Request) string {
	queryValues := req.URL.Query()
	for _, values := range queryValues {
		sort.Strings(values)
	}
	return strings.Replace(queryValues.Encode(), "+", "%20", -1)
}
| |
// canonicalHeaders builds the two header-derived parts of the canonical
// request from req.Header.
//
// The first result is the list of lower-cased header names, sorted and joined
// with ";". The second result holds one "name:value\n" line per header in the
// same order; multiple values of a header are joined with ",". Header names
// that differ only by case are merged into a single entry.
func canonicalHeaders(req *http.Request) (string, string) {
	merged := make(map[string][]string)
	var names []string
	for name, values := range req.Header {
		lower := strings.ToLower(name)
		existing, seen := merged[lower]
		if !seen {
			names = append(names, lower)
			merged[lower] = values
			continue
		}
		// Same name in a different case: fold the extra values in.
		merged[lower] = append(existing, values...)
	}
	// Header names must appear in alphabetical order.
	sort.Strings(names)

	var data strings.Builder
	for _, name := range names {
		data.WriteString(name + ":" + strings.Join(merged[name], ",") + "\n")
	}

	return strings.Join(names, ";"), data.String()
}
| |
| func requestDataHash(req *http.Request) (string, error) { |
| var requestData []byte |
| if req.Body != nil { |
| requestBody, err := req.GetBody() |
| if err != nil { |
| return "", err |
| } |
| defer requestBody.Close() |
| |
| requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20)) |
| if err != nil { |
| return "", err |
| } |
| } |
| |
| return getSha256(requestData) |
| } |
| |
// requestHost returns the host the request is addressed to: req.Host when it
// is set, otherwise the host taken from the request URL.
func requestHost(req *http.Request) string {
	if req.Host == "" {
		return req.URL.Host
	}
	return req.Host
}
| |
| func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) { |
| dataHash, err := requestDataHash(req) |
| if err != nil { |
| return "", err |
| } |
| |
| return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil |
| } |
| |
| // SignRequest adds the appropriate headers to an http.Request |
| // or returns an error if something prevented this. |
| func (rs *awsRequestSigner) SignRequest(req *http.Request) error { |
| signedRequest := cloneRequest(req) |
| timestamp := now() |
| |
| signedRequest.Header.Add("host", requestHost(req)) |
| |
| if securityToken, ok := rs.AwsSecurityCredentials["security_token"]; ok { |
| signedRequest.Header.Add(awsSecurityTokenHeader, securityToken) |
| } |
| |
| if signedRequest.Header.Get("date") == "" { |
| signedRequest.Header.Add(awsDateHeader, timestamp.Format(awsTimeFormatLong)) |
| } |
| |
| authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp) |
| if err != nil { |
| return err |
| } |
| signedRequest.Header.Set("Authorization", authorizationCode) |
| |
| req.Header = signedRequest.Header |
| return nil |
| } |
| |
| func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) { |
| secretAccessKey, ok := rs.AwsSecurityCredentials["secret_access_key"] |
| if !ok { |
| return "", errors.New("oauth2/google: missing secret_access_key header") |
| } |
| accessKeyId, ok := rs.AwsSecurityCredentials["access_key_id"] |
| if !ok { |
| return "", errors.New("oauth2/google: missing access_key_id header") |
| } |
| |
| canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req) |
| |
| dateStamp := timestamp.Format(awsTimeFormatShort) |
| serviceName := "" |
| if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 { |
| serviceName = splitHost[0] |
| } |
| |
| credentialScope := fmt.Sprintf("%s/%s/%s/%s",dateStamp, rs.RegionName, serviceName, awsRequestType) |
| |
| requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData) |
| if err != nil { |
| return "", err |
| } |
| requestHash, err := getSha256([]byte(requestString)) |
| if err != nil{ |
| return "", err |
| } |
| |
| stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash) |
| |
| signingKey := []byte("AWS4" + secretAccessKey) |
| for _, signingInput := range []string{ |
| dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign, |
| } { |
| signingKey, err = getHmacSha256(signingKey, []byte(signingInput)) |
| if err != nil{ |
| return "", err |
| } |
| } |
| |
| return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, accessKeyId, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil |
| } |