aws_provider.go

  1// Copyright 2023 Google LLC
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package externalaccount
 16
 17import (
 18	"bytes"
 19	"context"
 20	"crypto/hmac"
 21	"crypto/sha256"
 22	"encoding/hex"
 23	"encoding/json"
 24	"errors"
 25	"fmt"
 26	"log/slog"
 27	"net/http"
 28	"net/url"
 29	"os"
 30	"path"
 31	"sort"
 32	"strings"
 33	"time"
 34
 35	"cloud.google.com/go/auth/internal"
 36	"github.com/googleapis/gax-go/v2/internallog"
 37)
 38
 39var (
 40	// getenv aliases os.Getenv for testing
 41	getenv = os.Getenv
 42)
 43
 44const (
 45	// AWS Signature Version 4 signing algorithm identifier.
 46	awsAlgorithm = "AWS4-HMAC-SHA256"
 47
 48	// The termination string for the AWS credential scope value as defined in
 49	// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
 50	awsRequestType = "aws4_request"
 51
 52	// The AWS authorization header name for the security session token if available.
 53	awsSecurityTokenHeader = "x-amz-security-token"
 54
 55	// The name of the header containing the session token for metadata endpoint calls
 56	awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"
 57
 58	awsIMDSv2SessionTTLHeader = "X-aws-ec2-metadata-token-ttl-seconds"
 59
 60	awsIMDSv2SessionTTL = "300"
 61
 62	// The AWS authorization header name for the auto-generated date.
 63	awsDateHeader = "x-amz-date"
 64
 65	defaultRegionalCredentialVerificationURL = "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"
 66
 67	// Supported AWS configuration environment variables.
 68	awsAccessKeyIDEnvVar     = "AWS_ACCESS_KEY_ID"
 69	awsDefaultRegionEnvVar   = "AWS_DEFAULT_REGION"
 70	awsRegionEnvVar          = "AWS_REGION"
 71	awsSecretAccessKeyEnvVar = "AWS_SECRET_ACCESS_KEY"
 72	awsSessionTokenEnvVar    = "AWS_SESSION_TOKEN"
 73
 74	awsTimeFormatLong  = "20060102T150405Z"
 75	awsTimeFormatShort = "20060102"
 76	awsProviderType    = "aws"
 77)
 78
 79type awsSubjectProvider struct {
 80	EnvironmentID               string
 81	RegionURL                   string
 82	RegionalCredVerificationURL string
 83	CredVerificationURL         string
 84	IMDSv2SessionTokenURL       string
 85	TargetResource              string
 86	requestSigner               *awsRequestSigner
 87	region                      string
 88	securityCredentialsProvider AwsSecurityCredentialsProvider
 89	reqOpts                     *RequestOptions
 90
 91	Client *http.Client
 92	logger *slog.Logger
 93}
 94
 95func (sp *awsSubjectProvider) subjectToken(ctx context.Context) (string, error) {
 96	// Set Defaults
 97	if sp.RegionalCredVerificationURL == "" {
 98		sp.RegionalCredVerificationURL = defaultRegionalCredentialVerificationURL
 99	}
100	headers := make(map[string]string)
101	if sp.shouldUseMetadataServer() {
102		awsSessionToken, err := sp.getAWSSessionToken(ctx)
103		if err != nil {
104			return "", err
105		}
106
107		if awsSessionToken != "" {
108			headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
109		}
110	}
111
112	awsSecurityCredentials, err := sp.getSecurityCredentials(ctx, headers)
113	if err != nil {
114		return "", err
115	}
116	if sp.region, err = sp.getRegion(ctx, headers); err != nil {
117		return "", err
118	}
119	sp.requestSigner = &awsRequestSigner{
120		RegionName:             sp.region,
121		AwsSecurityCredentials: awsSecurityCredentials,
122	}
123
124	// Generate the signed request to AWS STS GetCallerIdentity API.
125	// Use the required regional endpoint. Otherwise, the request will fail.
126	req, err := http.NewRequestWithContext(ctx, "POST", strings.Replace(sp.RegionalCredVerificationURL, "{region}", sp.region, 1), nil)
127	if err != nil {
128		return "", err
129	}
130	// The full, canonical resource name of the workload identity pool
131	// provider, with or without the HTTPS prefix.
132	// Including this header as part of the signature is recommended to
133	// ensure data integrity.
134	if sp.TargetResource != "" {
135		req.Header.Set("x-goog-cloud-target-resource", sp.TargetResource)
136	}
137	sp.requestSigner.signRequest(req)
138
139	/*
140	   The GCP STS endpoint expects the headers to be formatted as:
141	   # [
142	   #   {key: 'x-amz-date', value: '...'},
143	   #   {key: 'Authorization', value: '...'},
144	   #   ...
145	   # ]
146	   # And then serialized as:
147	   # quote(json.dumps({
148	   #   url: '...',
149	   #   method: 'POST',
150	   #   headers: [{key: 'x-amz-date', value: '...'}, ...]
151	   # }))
152	*/
153
154	awsSignedReq := awsRequest{
155		URL:    req.URL.String(),
156		Method: "POST",
157	}
158	for headerKey, headerList := range req.Header {
159		for _, headerValue := range headerList {
160			awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
161				Key:   headerKey,
162				Value: headerValue,
163			})
164		}
165	}
166	sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
167		headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
168		if headerCompare == 0 {
169			return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
170		}
171		return headerCompare < 0
172	})
173
174	result, err := json.Marshal(awsSignedReq)
175	if err != nil {
176		return "", err
177	}
178	return url.QueryEscape(string(result)), nil
179}
180
181func (sp *awsSubjectProvider) providerType() string {
182	if sp.securityCredentialsProvider != nil {
183		return programmaticProviderType
184	}
185	return awsProviderType
186}
187
188func (sp *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, error) {
189	if sp.IMDSv2SessionTokenURL == "" {
190		return "", nil
191	}
192	req, err := http.NewRequestWithContext(ctx, "PUT", sp.IMDSv2SessionTokenURL, nil)
193	if err != nil {
194		return "", err
195	}
196	req.Header.Set(awsIMDSv2SessionTTLHeader, awsIMDSv2SessionTTL)
197
198	sp.logger.DebugContext(ctx, "aws session token request", "request", internallog.HTTPRequest(req, nil))
199	resp, body, err := internal.DoRequest(sp.Client, req)
200	if err != nil {
201		return "", err
202	}
203	sp.logger.DebugContext(ctx, "aws session token response", "response", internallog.HTTPResponse(resp, body))
204	if resp.StatusCode != http.StatusOK {
205		return "", fmt.Errorf("credentials: unable to retrieve AWS session token: %s", body)
206	}
207	return string(body), nil
208}
209
210func (sp *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]string) (string, error) {
211	if sp.securityCredentialsProvider != nil {
212		return sp.securityCredentialsProvider.AwsRegion(ctx, sp.reqOpts)
213	}
214	if canRetrieveRegionFromEnvironment() {
215		if envAwsRegion := getenv(awsRegionEnvVar); envAwsRegion != "" {
216			return envAwsRegion, nil
217		}
218		return getenv(awsDefaultRegionEnvVar), nil
219	}
220
221	if sp.RegionURL == "" {
222		return "", errors.New("credentials: unable to determine AWS region")
223	}
224
225	req, err := http.NewRequestWithContext(ctx, "GET", sp.RegionURL, nil)
226	if err != nil {
227		return "", err
228	}
229
230	for name, value := range headers {
231		req.Header.Add(name, value)
232	}
233	sp.logger.DebugContext(ctx, "aws region request", "request", internallog.HTTPRequest(req, nil))
234	resp, body, err := internal.DoRequest(sp.Client, req)
235	if err != nil {
236		return "", err
237	}
238	sp.logger.DebugContext(ctx, "aws region response", "response", internallog.HTTPResponse(resp, body))
239	if resp.StatusCode != http.StatusOK {
240		return "", fmt.Errorf("credentials: unable to retrieve AWS region - %s", body)
241	}
242
243	// This endpoint will return the region in format: us-east-2b.
244	// Only the us-east-2 part should be used.
245	bodyLen := len(body)
246	if bodyLen == 0 {
247		return "", nil
248	}
249	return string(body[:bodyLen-1]), nil
250}
251
252func (sp *awsSubjectProvider) getSecurityCredentials(ctx context.Context, headers map[string]string) (result *AwsSecurityCredentials, err error) {
253	if sp.securityCredentialsProvider != nil {
254		return sp.securityCredentialsProvider.AwsSecurityCredentials(ctx, sp.reqOpts)
255	}
256	if canRetrieveSecurityCredentialFromEnvironment() {
257		return &AwsSecurityCredentials{
258			AccessKeyID:     getenv(awsAccessKeyIDEnvVar),
259			SecretAccessKey: getenv(awsSecretAccessKeyEnvVar),
260			SessionToken:    getenv(awsSessionTokenEnvVar),
261		}, nil
262	}
263
264	roleName, err := sp.getMetadataRoleName(ctx, headers)
265	if err != nil {
266		return
267	}
268	credentials, err := sp.getMetadataSecurityCredentials(ctx, roleName, headers)
269	if err != nil {
270		return
271	}
272
273	if credentials.AccessKeyID == "" {
274		return result, errors.New("credentials: missing AccessKeyId credential")
275	}
276	if credentials.SecretAccessKey == "" {
277		return result, errors.New("credentials: missing SecretAccessKey credential")
278	}
279
280	return credentials, nil
281}
282
283func (sp *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context, roleName string, headers map[string]string) (*AwsSecurityCredentials, error) {
284	var result *AwsSecurityCredentials
285
286	req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/%s", sp.CredVerificationURL, roleName), nil)
287	if err != nil {
288		return result, err
289	}
290	for name, value := range headers {
291		req.Header.Add(name, value)
292	}
293	sp.logger.DebugContext(ctx, "aws security credential request", "request", internallog.HTTPRequest(req, nil))
294	resp, body, err := internal.DoRequest(sp.Client, req)
295	if err != nil {
296		return result, err
297	}
298	sp.logger.DebugContext(ctx, "aws security credential response", "response", internallog.HTTPResponse(resp, body))
299	if resp.StatusCode != http.StatusOK {
300		return result, fmt.Errorf("credentials: unable to retrieve AWS security credentials - %s", body)
301	}
302	if err := json.Unmarshal(body, &result); err != nil {
303		return nil, err
304	}
305	return result, nil
306}
307
308func (sp *awsSubjectProvider) getMetadataRoleName(ctx context.Context, headers map[string]string) (string, error) {
309	if sp.CredVerificationURL == "" {
310		return "", errors.New("credentials: unable to determine the AWS metadata server security credentials endpoint")
311	}
312	req, err := http.NewRequestWithContext(ctx, "GET", sp.CredVerificationURL, nil)
313	if err != nil {
314		return "", err
315	}
316	for name, value := range headers {
317		req.Header.Add(name, value)
318	}
319
320	sp.logger.DebugContext(ctx, "aws metadata role request", "request", internallog.HTTPRequest(req, nil))
321	resp, body, err := internal.DoRequest(sp.Client, req)
322	if err != nil {
323		return "", err
324	}
325	sp.logger.DebugContext(ctx, "aws metadata role response", "response", internallog.HTTPResponse(resp, body))
326	if resp.StatusCode != http.StatusOK {
327		return "", fmt.Errorf("credentials: unable to retrieve AWS role name - %s", body)
328	}
329	return string(body), nil
330}
331
332// awsRequestSigner is a utility class to sign http requests using a AWS V4 signature.
333type awsRequestSigner struct {
334	RegionName             string
335	AwsSecurityCredentials *AwsSecurityCredentials
336}
337
338// signRequest adds the appropriate headers to an http.Request
339// or returns an error if something prevented this.
340func (rs *awsRequestSigner) signRequest(req *http.Request) error {
341	// req is assumed non-nil
342	signedRequest := cloneRequest(req)
343	timestamp := Now()
344	signedRequest.Header.Set("host", requestHost(req))
345	if rs.AwsSecurityCredentials.SessionToken != "" {
346		signedRequest.Header.Set(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SessionToken)
347	}
348	if signedRequest.Header.Get("date") == "" {
349		signedRequest.Header.Set(awsDateHeader, timestamp.Format(awsTimeFormatLong))
350	}
351	authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
352	if err != nil {
353		return err
354	}
355	signedRequest.Header.Set("Authorization", authorizationCode)
356	req.Header = signedRequest.Header
357	return nil
358}
359
360func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
361	canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
362	dateStamp := timestamp.Format(awsTimeFormatShort)
363	serviceName := ""
364
365	if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
366		serviceName = splitHost[0]
367	}
368	credentialScope := strings.Join([]string{dateStamp, rs.RegionName, serviceName, awsRequestType}, "/")
369	requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
370	if err != nil {
371		return "", err
372	}
373	requestHash, err := getSha256([]byte(requestString))
374	if err != nil {
375		return "", err
376	}
377
378	stringToSign := strings.Join([]string{awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash}, "\n")
379	signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
380	for _, signingInput := range []string{
381		dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
382	} {
383		signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
384		if err != nil {
385			return "", err
386		}
387	}
388
389	return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
390}
391
// getSha256 returns the lowercase hex encoding of the SHA-256 digest of
// input. The error result is always nil.
func getSha256(input []byte) (string, error) {
	digest := sha256.Sum256(input)
	return hex.EncodeToString(digest[:]), nil
}
399
// getHmacSha256 returns the raw HMAC-SHA256 of input keyed with key.
func getHmacSha256(key, input []byte) ([]byte, error) {
	mac := hmac.New(sha256.New, key)
	_, err := mac.Write(input)
	if err != nil {
		return nil, err
	}
	sum := mac.Sum(nil)
	return sum, nil
}
407
// cloneRequest returns a shallow copy of r whose Header is an independent
// deep copy, so header changes on the copy never show up on r. All other
// fields (URL, Body, ...) are shared with r. A nil Header stays nil.
func cloneRequest(r *http.Request) *http.Request {
	r2 := new(http.Request)
	*r2 = *r
	// Header.Clone deep-copies the map and its value slices, and returns nil
	// for a nil Header.
	r2.Header = r.Header.Clone()
	return r2
}
429
// canonicalPath returns the request's escaped URL path after path.Clean, or
// "/" when the path is empty.
func canonicalPath(req *http.Request) string {
	if escaped := req.URL.EscapedPath(); escaped != "" {
		return path.Clean(escaped)
	}
	return "/"
}
437
// canonicalQuery returns the request's query string re-encoded with the
// values of each key sorted; Encode emits the keys in sorted order.
func canonicalQuery(req *http.Request) string {
	params := req.URL.Query()
	for _, values := range params {
		sort.Strings(values)
	}
	return params.Encode()
}
445
// canonicalHeaders returns two strings describing the request's headers:
//
//   - the lower-cased header names, sorted and joined with ";"
//   - one "name:value\n" line per header in the same order, with multiple
//     values of a header joined with ","
//
// Header names that differ only by case are merged into one entry; their
// values are combined in the byte-wise order of the original names, so the
// output does not depend on map iteration order. The request's own header
// slices are never modified.
func canonicalHeaders(req *http.Request) (string, string) {
	// Visit the original keys in sorted order so that merging case variants
	// is deterministic.
	rawKeys := make([]string, 0, len(req.Header))
	for k := range req.Header {
		rawKeys = append(rawKeys, k)
	}
	sort.Strings(rawKeys)

	// Header keys need to be sorted alphabetically.
	var headers []string
	lowerCaseHeaders := make(map[string][]string, len(rawKeys))
	for _, k := range rawKeys {
		lower := strings.ToLower(k)
		if _, ok := lowerCaseHeaders[lower]; !ok {
			headers = append(headers, lower)
		}
		// Always append onto the map's own slice (initially nil) so the
		// request's backing arrays are never shared or written to.
		lowerCaseHeaders[lower] = append(lowerCaseHeaders[lower], req.Header[k]...)
	}
	sort.Strings(headers)

	var fullHeaders bytes.Buffer
	for _, header := range headers {
		fullHeaders.WriteString(header)
		fullHeaders.WriteRune(':')
		fullHeaders.WriteString(strings.Join(lowerCaseHeaders[header], ","))
		fullHeaders.WriteRune('\n')
	}

	return strings.Join(headers, ";"), fullHeaders.String()
}
473
474func requestDataHash(req *http.Request) (string, error) {
475	var requestData []byte
476	if req.Body != nil {
477		requestBody, err := req.GetBody()
478		if err != nil {
479			return "", err
480		}
481		defer requestBody.Close()
482
483		requestData, err = internal.ReadAll(requestBody)
484		if err != nil {
485			return "", err
486		}
487	}
488
489	return getSha256(requestData)
490}
491
// requestHost returns the request's Host field, falling back to the host of
// its URL when that field is empty.
func requestHost(req *http.Request) string {
	host := req.Host
	if host == "" {
		host = req.URL.Host
	}
	return host
}
498
499func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
500	dataHash, err := requestDataHash(req)
501	if err != nil {
502		return "", err
503	}
504	return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil
505}
506
// awsRequestHeader is a single header of the signed request, serialized as a
// JSON object with "key" and "value" members.
type awsRequestHeader struct {
	// Key is the header name.
	Key string `json:"key"`
	// Value is one value of that header.
	Value string `json:"value"`
}
511
512type awsRequest struct {
513	URL     string             `json:"url"`
514	Method  string             `json:"method"`
515	Headers []awsRequestHeader `json:"headers"`
516}
517
518// The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is
519// required.
520func canRetrieveRegionFromEnvironment() bool {
521	return getenv(awsRegionEnvVar) != "" || getenv(awsDefaultRegionEnvVar) != ""
522}
523
524// Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available.
525func canRetrieveSecurityCredentialFromEnvironment() bool {
526	return getenv(awsAccessKeyIDEnvVar) != "" && getenv(awsSecretAccessKeyEnvVar) != ""
527}
528
529func (sp *awsSubjectProvider) shouldUseMetadataServer() bool {
530	return sp.securityCredentialsProvider == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment())
531}