1// Copyright 2023 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package externalaccount
16
17import (
18 "bytes"
19 "context"
20 "crypto/hmac"
21 "crypto/sha256"
22 "encoding/hex"
23 "encoding/json"
24 "errors"
25 "fmt"
26 "log/slog"
27 "net/http"
28 "net/url"
29 "os"
30 "path"
31 "sort"
32 "strings"
33 "time"
34
35 "cloud.google.com/go/auth/internal"
36 "github.com/googleapis/gax-go/v2/internallog"
37)
38
// getenv is the environment-variable lookup used throughout this file. It is
// a package-level variable, not a direct os.Getenv call, so tests can swap in
// a fake environment.
var getenv = os.Getenv
43
// AWS Signature Version 4 (SigV4) signing parameters.
const (
	// awsAlgorithm identifies the SigV4 HMAC-SHA256 signing algorithm.
	awsAlgorithm = "AWS4-HMAC-SHA256"

	// awsRequestType terminates the SigV4 credential scope string, as defined in
	// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
	awsRequestType = "aws4_request"

	// awsSecurityTokenHeader carries the session token, when one is available.
	awsSecurityTokenHeader = "x-amz-security-token"

	// awsDateHeader carries the auto-generated request timestamp.
	awsDateHeader = "x-amz-date"

	// Timestamp layouts used while signing: full timestamp and date-only stamp.
	awsTimeFormatLong  = "20060102T150405Z"
	awsTimeFormatShort = "20060102"
)

// AWS metadata server (IMDSv2) session settings.
const (
	// awsIMDSv2SessionTokenHeader carries the session token on metadata endpoint calls.
	awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"

	// awsIMDSv2SessionTTLHeader names the header used to request a session token lifetime.
	awsIMDSv2SessionTTLHeader = "X-aws-ec2-metadata-token-ttl-seconds"

	// awsIMDSv2SessionTTL is the requested session token lifetime, in seconds.
	awsIMDSv2SessionTTL = "300"
)

// Endpoint defaults, environment configuration and provider identification.
const (
	// defaultRegionalCredentialVerificationURL is the STS GetCallerIdentity
	// endpoint template; "{region}" is replaced with the resolved AWS region.
	defaultRegionalCredentialVerificationURL = "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"

	// Environment variables consulted for the AWS region and credentials.
	awsAccessKeyIDEnvVar     = "AWS_ACCESS_KEY_ID"
	awsDefaultRegionEnvVar   = "AWS_DEFAULT_REGION"
	awsRegionEnvVar          = "AWS_REGION"
	awsSecretAccessKeyEnvVar = "AWS_SECRET_ACCESS_KEY"
	awsSessionTokenEnvVar    = "AWS_SESSION_TOKEN"

	// awsProviderType is the provider type reported when no programmatic
	// credentials provider is configured.
	awsProviderType = "aws"
)
78
// awsSubjectProvider produces the subject token for AWS workloads: the
// URL-escaped JSON description of a SigV4-signed STS GetCallerIdentity
// request. Credentials and region come from securityCredentialsProvider when
// it is set, and otherwise from environment variables or the AWS metadata
// server.
type awsSubjectProvider struct {
	// EnvironmentID is the configured AWS environment identifier.
	// NOTE(review): not read in this file — TODO confirm where it is used.
	EnvironmentID string
	// RegionURL is the metadata endpoint queried for the region; the final
	// character of its response is dropped (e.g. "us-east-2b" -> "us-east-2").
	RegionURL string
	// RegionalCredVerificationURL is the STS URL template in which "{region}"
	// is replaced by the resolved region. subjectToken fills in
	// defaultRegionalCredentialVerificationURL when it is empty.
	RegionalCredVerificationURL string
	// CredVerificationURL is the metadata endpoint that returns the role name;
	// "<CredVerificationURL>/<role>" returns that role's security credentials.
	CredVerificationURL string
	// IMDSv2SessionTokenURL, when non-empty, is used (via PUT) to obtain a
	// session token that is forwarded on subsequent metadata requests.
	IMDSv2SessionTokenURL string
	// TargetResource, when non-empty, is sent and signed as the
	// x-goog-cloud-target-resource header.
	TargetResource string
	// requestSigner is (re)built by subjectToken from the resolved region and
	// credentials.
	requestSigner *awsRequestSigner
	// region holds the region most recently resolved by subjectToken.
	region string
	// securityCredentialsProvider, when set, supplies the region and
	// credentials and bypasses environment and metadata-server lookups.
	securityCredentialsProvider AwsSecurityCredentialsProvider
	// reqOpts is passed through to securityCredentialsProvider calls.
	reqOpts *RequestOptions

	// Client performs the metadata server requests.
	Client *http.Client
	// logger receives debug logs of metadata requests and responses.
	logger *slog.Logger
}
94
95func (sp *awsSubjectProvider) subjectToken(ctx context.Context) (string, error) {
96 // Set Defaults
97 if sp.RegionalCredVerificationURL == "" {
98 sp.RegionalCredVerificationURL = defaultRegionalCredentialVerificationURL
99 }
100 headers := make(map[string]string)
101 if sp.shouldUseMetadataServer() {
102 awsSessionToken, err := sp.getAWSSessionToken(ctx)
103 if err != nil {
104 return "", err
105 }
106
107 if awsSessionToken != "" {
108 headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
109 }
110 }
111
112 awsSecurityCredentials, err := sp.getSecurityCredentials(ctx, headers)
113 if err != nil {
114 return "", err
115 }
116 if sp.region, err = sp.getRegion(ctx, headers); err != nil {
117 return "", err
118 }
119 sp.requestSigner = &awsRequestSigner{
120 RegionName: sp.region,
121 AwsSecurityCredentials: awsSecurityCredentials,
122 }
123
124 // Generate the signed request to AWS STS GetCallerIdentity API.
125 // Use the required regional endpoint. Otherwise, the request will fail.
126 req, err := http.NewRequestWithContext(ctx, "POST", strings.Replace(sp.RegionalCredVerificationURL, "{region}", sp.region, 1), nil)
127 if err != nil {
128 return "", err
129 }
130 // The full, canonical resource name of the workload identity pool
131 // provider, with or without the HTTPS prefix.
132 // Including this header as part of the signature is recommended to
133 // ensure data integrity.
134 if sp.TargetResource != "" {
135 req.Header.Set("x-goog-cloud-target-resource", sp.TargetResource)
136 }
137 sp.requestSigner.signRequest(req)
138
139 /*
140 The GCP STS endpoint expects the headers to be formatted as:
141 # [
142 # {key: 'x-amz-date', value: '...'},
143 # {key: 'Authorization', value: '...'},
144 # ...
145 # ]
146 # And then serialized as:
147 # quote(json.dumps({
148 # url: '...',
149 # method: 'POST',
150 # headers: [{key: 'x-amz-date', value: '...'}, ...]
151 # }))
152 */
153
154 awsSignedReq := awsRequest{
155 URL: req.URL.String(),
156 Method: "POST",
157 }
158 for headerKey, headerList := range req.Header {
159 for _, headerValue := range headerList {
160 awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
161 Key: headerKey,
162 Value: headerValue,
163 })
164 }
165 }
166 sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
167 headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
168 if headerCompare == 0 {
169 return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
170 }
171 return headerCompare < 0
172 })
173
174 result, err := json.Marshal(awsSignedReq)
175 if err != nil {
176 return "", err
177 }
178 return url.QueryEscape(string(result)), nil
179}
180
181func (sp *awsSubjectProvider) providerType() string {
182 if sp.securityCredentialsProvider != nil {
183 return programmaticProviderType
184 }
185 return awsProviderType
186}
187
188func (sp *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, error) {
189 if sp.IMDSv2SessionTokenURL == "" {
190 return "", nil
191 }
192 req, err := http.NewRequestWithContext(ctx, "PUT", sp.IMDSv2SessionTokenURL, nil)
193 if err != nil {
194 return "", err
195 }
196 req.Header.Set(awsIMDSv2SessionTTLHeader, awsIMDSv2SessionTTL)
197
198 sp.logger.DebugContext(ctx, "aws session token request", "request", internallog.HTTPRequest(req, nil))
199 resp, body, err := internal.DoRequest(sp.Client, req)
200 if err != nil {
201 return "", err
202 }
203 sp.logger.DebugContext(ctx, "aws session token response", "response", internallog.HTTPResponse(resp, body))
204 if resp.StatusCode != http.StatusOK {
205 return "", fmt.Errorf("credentials: unable to retrieve AWS session token: %s", body)
206 }
207 return string(body), nil
208}
209
210func (sp *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]string) (string, error) {
211 if sp.securityCredentialsProvider != nil {
212 return sp.securityCredentialsProvider.AwsRegion(ctx, sp.reqOpts)
213 }
214 if canRetrieveRegionFromEnvironment() {
215 if envAwsRegion := getenv(awsRegionEnvVar); envAwsRegion != "" {
216 return envAwsRegion, nil
217 }
218 return getenv(awsDefaultRegionEnvVar), nil
219 }
220
221 if sp.RegionURL == "" {
222 return "", errors.New("credentials: unable to determine AWS region")
223 }
224
225 req, err := http.NewRequestWithContext(ctx, "GET", sp.RegionURL, nil)
226 if err != nil {
227 return "", err
228 }
229
230 for name, value := range headers {
231 req.Header.Add(name, value)
232 }
233 sp.logger.DebugContext(ctx, "aws region request", "request", internallog.HTTPRequest(req, nil))
234 resp, body, err := internal.DoRequest(sp.Client, req)
235 if err != nil {
236 return "", err
237 }
238 sp.logger.DebugContext(ctx, "aws region response", "response", internallog.HTTPResponse(resp, body))
239 if resp.StatusCode != http.StatusOK {
240 return "", fmt.Errorf("credentials: unable to retrieve AWS region - %s", body)
241 }
242
243 // This endpoint will return the region in format: us-east-2b.
244 // Only the us-east-2 part should be used.
245 bodyLen := len(body)
246 if bodyLen == 0 {
247 return "", nil
248 }
249 return string(body[:bodyLen-1]), nil
250}
251
252func (sp *awsSubjectProvider) getSecurityCredentials(ctx context.Context, headers map[string]string) (result *AwsSecurityCredentials, err error) {
253 if sp.securityCredentialsProvider != nil {
254 return sp.securityCredentialsProvider.AwsSecurityCredentials(ctx, sp.reqOpts)
255 }
256 if canRetrieveSecurityCredentialFromEnvironment() {
257 return &AwsSecurityCredentials{
258 AccessKeyID: getenv(awsAccessKeyIDEnvVar),
259 SecretAccessKey: getenv(awsSecretAccessKeyEnvVar),
260 SessionToken: getenv(awsSessionTokenEnvVar),
261 }, nil
262 }
263
264 roleName, err := sp.getMetadataRoleName(ctx, headers)
265 if err != nil {
266 return
267 }
268 credentials, err := sp.getMetadataSecurityCredentials(ctx, roleName, headers)
269 if err != nil {
270 return
271 }
272
273 if credentials.AccessKeyID == "" {
274 return result, errors.New("credentials: missing AccessKeyId credential")
275 }
276 if credentials.SecretAccessKey == "" {
277 return result, errors.New("credentials: missing SecretAccessKey credential")
278 }
279
280 return credentials, nil
281}
282
283func (sp *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context, roleName string, headers map[string]string) (*AwsSecurityCredentials, error) {
284 var result *AwsSecurityCredentials
285
286 req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/%s", sp.CredVerificationURL, roleName), nil)
287 if err != nil {
288 return result, err
289 }
290 for name, value := range headers {
291 req.Header.Add(name, value)
292 }
293 sp.logger.DebugContext(ctx, "aws security credential request", "request", internallog.HTTPRequest(req, nil))
294 resp, body, err := internal.DoRequest(sp.Client, req)
295 if err != nil {
296 return result, err
297 }
298 sp.logger.DebugContext(ctx, "aws security credential response", "response", internallog.HTTPResponse(resp, body))
299 if resp.StatusCode != http.StatusOK {
300 return result, fmt.Errorf("credentials: unable to retrieve AWS security credentials - %s", body)
301 }
302 if err := json.Unmarshal(body, &result); err != nil {
303 return nil, err
304 }
305 return result, nil
306}
307
308func (sp *awsSubjectProvider) getMetadataRoleName(ctx context.Context, headers map[string]string) (string, error) {
309 if sp.CredVerificationURL == "" {
310 return "", errors.New("credentials: unable to determine the AWS metadata server security credentials endpoint")
311 }
312 req, err := http.NewRequestWithContext(ctx, "GET", sp.CredVerificationURL, nil)
313 if err != nil {
314 return "", err
315 }
316 for name, value := range headers {
317 req.Header.Add(name, value)
318 }
319
320 sp.logger.DebugContext(ctx, "aws metadata role request", "request", internallog.HTTPRequest(req, nil))
321 resp, body, err := internal.DoRequest(sp.Client, req)
322 if err != nil {
323 return "", err
324 }
325 sp.logger.DebugContext(ctx, "aws metadata role response", "response", internallog.HTTPResponse(resp, body))
326 if resp.StatusCode != http.StatusOK {
327 return "", fmt.Errorf("credentials: unable to retrieve AWS role name - %s", body)
328 }
329 return string(body), nil
330}
331
// awsRequestSigner is a utility that signs http requests using an AWS
// Signature Version 4 (SigV4) signature.
type awsRequestSigner struct {
	// RegionName is the AWS region placed in the SigV4 credential scope.
	RegionName string
	// AwsSecurityCredentials supplies the access key ID and secret access key
	// used for signing, plus a session token that is added to the request
	// when non-empty.
	AwsSecurityCredentials *AwsSecurityCredentials
}
337
338// signRequest adds the appropriate headers to an http.Request
339// or returns an error if something prevented this.
340func (rs *awsRequestSigner) signRequest(req *http.Request) error {
341 // req is assumed non-nil
342 signedRequest := cloneRequest(req)
343 timestamp := Now()
344 signedRequest.Header.Set("host", requestHost(req))
345 if rs.AwsSecurityCredentials.SessionToken != "" {
346 signedRequest.Header.Set(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SessionToken)
347 }
348 if signedRequest.Header.Get("date") == "" {
349 signedRequest.Header.Set(awsDateHeader, timestamp.Format(awsTimeFormatLong))
350 }
351 authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
352 if err != nil {
353 return err
354 }
355 signedRequest.Header.Set("Authorization", authorizationCode)
356 req.Header = signedRequest.Header
357 return nil
358}
359
360func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
361 canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
362 dateStamp := timestamp.Format(awsTimeFormatShort)
363 serviceName := ""
364
365 if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
366 serviceName = splitHost[0]
367 }
368 credentialScope := strings.Join([]string{dateStamp, rs.RegionName, serviceName, awsRequestType}, "/")
369 requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
370 if err != nil {
371 return "", err
372 }
373 requestHash, err := getSha256([]byte(requestString))
374 if err != nil {
375 return "", err
376 }
377
378 stringToSign := strings.Join([]string{awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash}, "\n")
379 signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
380 for _, signingInput := range []string{
381 dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
382 } {
383 signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
384 if err != nil {
385 return "", err
386 }
387 }
388
389 return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
390}
391
// getSha256 returns the lowercase hex encoding of the SHA-256 digest of input.
// The error result is always nil because SHA-256 hashing cannot fail; the
// two-value signature is kept for existing callers.
func getSha256(input []byte) (string, error) {
	digest := sha256.Sum256(input)
	return hex.EncodeToString(digest[:]), nil
}
399
// getHmacSha256 returns the raw 32-byte HMAC-SHA256 of input, keyed with key.
func getHmacSha256(key, input []byte) ([]byte, error) {
	mac := hmac.New(sha256.New, key)
	_, err := mac.Write(input)
	if err != nil {
		return nil, err
	}
	sum := mac.Sum(nil)
	return sum, nil
}
407
// cloneRequest returns a shallow copy of r whose Header map and header value
// slices are private copies. The clone's headers can therefore be changed
// without affecting r. All other fields (URL, Body, ...) are shared with r.
// A nil r.Header stays nil in the clone.
func cloneRequest(r *http.Request) *http.Request {
	clone := *r
	if r.Header == nil {
		return &clone
	}

	total := 0
	for _, vals := range r.Header {
		total += len(vals)
	}
	// One backing array holds every copied value. Each key gets a
	// capacity-capped window, so an append on one key reallocates instead of
	// overwriting its neighbour's values.
	pool := make([]string, total)
	headers := make(http.Header, len(r.Header))
	for key, vals := range r.Header {
		n := copy(pool, vals)
		headers[key], pool = pool[:n:n], pool[n:]
	}
	clone.Header = headers
	return &clone
}
429
// canonicalPath returns the SigV4 canonical URI for req: the escaped URL path
// normalized with path.Clean, or "/" when the path is empty.
func canonicalPath(req *http.Request) string {
	if p := req.URL.EscapedPath(); p != "" {
		return path.Clean(p)
	}
	return "/"
}
437
// canonicalQuery returns req's query string in canonical order. The values of
// each key are sorted here; url.Values.Encode then sorts by key and
// serializes the result.
func canonicalQuery(req *http.Request) string {
	values := req.URL.Query()
	for _, vs := range values {
		sort.Strings(vs)
	}
	return values.Encode()
}
445
// canonicalHeaders builds the SigV4 signed-headers list and the canonical
// header block for req.
//
// Header names are lowercased, and the values of names that collide after
// lowercasing are merged. The first result is the sorted names joined with
// ";". The second has one "name:v1,v2\n" line per name, in the same order.
func canonicalHeaders(req *http.Request) (string, string) {
	merged := make(http.Header)
	var names []string
	for key, vals := range req.Header {
		lower := strings.ToLower(key)
		if prev, seen := merged[lower]; seen {
			merged[lower] = append(prev, vals...)
			continue
		}
		names = append(names, lower)
		merged[lower] = vals
	}
	sort.Strings(names)

	var block bytes.Buffer
	for _, name := range names {
		block.WriteString(name + ":" + strings.Join(merged[name], ",") + "\n")
	}
	return strings.Join(names, ";"), block.String()
}
473
474func requestDataHash(req *http.Request) (string, error) {
475 var requestData []byte
476 if req.Body != nil {
477 requestBody, err := req.GetBody()
478 if err != nil {
479 return "", err
480 }
481 defer requestBody.Close()
482
483 requestData, err = internal.ReadAll(requestBody)
484 if err != nil {
485 return "", err
486 }
487 }
488
489 return getSha256(requestData)
490}
491
// requestHost returns the host that req is addressed to. An explicit req.Host
// takes precedence; otherwise the URL's host is used.
func requestHost(req *http.Request) string {
	if host := req.Host; host != "" {
		return host
	}
	return req.URL.Host
}
498
499func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
500 dataHash, err := requestDataHash(req)
501 if err != nil {
502 return "", err
503 }
504 return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil
505}
506
// awsRequestHeader is one header value of the serialized signed request,
// encoded in JSON as {"key": ..., "value": ...}.
type awsRequestHeader struct {
	Key   string `json:"key"`
	Value string `json:"value"`
}
511
// awsRequest is the JSON shape of the signed STS GetCallerIdentity request.
// Once marshaled and URL-escaped, it forms the AWS subject token.
type awsRequest struct {
	// URL is the full URL of the signed request.
	URL string `json:"url"`
	// Method is the HTTP method of the signed request.
	Method string `json:"method"`
	// Headers holds one entry per header value of the signed request.
	Headers []awsRequestHeader `json:"headers"`
}
517
518// The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is
519// required.
520func canRetrieveRegionFromEnvironment() bool {
521 return getenv(awsRegionEnvVar) != "" || getenv(awsDefaultRegionEnvVar) != ""
522}
523
524// Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available.
525func canRetrieveSecurityCredentialFromEnvironment() bool {
526 return getenv(awsAccessKeyIDEnvVar) != "" && getenv(awsSecretAccessKeyEnvVar) != ""
527}
528
529func (sp *awsSubjectProvider) shouldUseMetadataServer() bool {
530 return sp.securityCredentialsProvider == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment())
531}