1package stscreds
  2
  3import (
  4	"context"
  5	"fmt"
  6	"io/ioutil"
  7	"strconv"
  8	"strings"
  9	"time"
 10
 11	"github.com/aws/aws-sdk-go-v2/aws"
 12	"github.com/aws/aws-sdk-go-v2/aws/retry"
 13	"github.com/aws/aws-sdk-go-v2/internal/sdk"
 14	"github.com/aws/aws-sdk-go-v2/service/sts"
 15	"github.com/aws/aws-sdk-go-v2/service/sts/types"
 16)
 17
 18var invalidIdentityTokenExceptionCode = (&types.InvalidIdentityTokenException{}).ErrorCode()
 19
 20const (
 21	// WebIdentityProviderName is the web identity provider name
 22	WebIdentityProviderName = "WebIdentityCredentials"
 23)
 24
// AssumeRoleWithWebIdentityAPIClient is a client capable of the STS AssumeRoleWithWebIdentity operation.
//
// It is the only STS operation WebIdentityRoleProvider calls, so any value
// implementing this single method (for example a fake in tests) can be used
// as WebIdentityRoleOptions.Client.
type AssumeRoleWithWebIdentityAPIClient interface {
	AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error)
}
 29
// WebIdentityRoleProvider is used to retrieve credentials using
// an OIDC token.
//
// Each Retrieve call obtains a token from the configured
// IdentityTokenRetriever and exchanges it for temporary credentials via the
// STS AssumeRoleWithWebIdentity operation. Construct it with
// NewWebIdentityRoleProvider; all configuration lives in WebIdentityRoleOptions.
type WebIdentityRoleProvider struct {
	options WebIdentityRoleOptions
}
 35
// WebIdentityRoleOptions is a structure of configurable options for
// WebIdentityRoleProvider. NewWebIdentityRoleProvider populates Client,
// RoleARN, and TokenRetriever from its arguments; any field can then be
// adjusted through its optFns.
type WebIdentityRoleOptions struct {
	// Client implementation of the AssumeRoleWithWebIdentity operation. Required.
	Client AssumeRoleWithWebIdentityAPIClient

	// JWT token provider, consulted on every Retrieve call. Required.
	TokenRetriever IdentityTokenRetriever

	// IAM Role ARN to assume. Required.
	RoleARN string

	// Session name, if you wish to uniquely identify this session. When empty,
	// Retrieve uses the current Unix time in nanoseconds as the session name.
	RoleSessionName string

	// Expiry duration of the STS credentials. STS will assign a default expiry
	// duration if this value is unset. This is different from the Duration
	// option of AssumeRoleProvider, which automatically assigns 15 minutes if
	// Duration is unset.
	//
	// Only whole seconds are sent to STS; any sub-second remainder is
	// truncated.
	//
	// See the STS AssumeRoleWithWebIdentity API reference guide for more
	// information on defaults.
	// https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
	Duration time.Duration

	// An IAM policy in JSON format that you want to use as an inline session
	// policy. When nil, no inline policy is sent.
	Policy *string

	// The Amazon Resource Names (ARNs) of the IAM managed policies that you
	// want to use as managed session policies. The policies must exist in the
	// same account as the role.
	PolicyARNs []types.PolicyDescriptorType
}
 68
// IdentityTokenRetriever is an interface for retrieving a JWT.
//
// GetIdentityToken returns the raw token bytes. WebIdentityRoleProvider calls
// it on every Retrieve and sends the returned bytes, converted to a string, as
// the WebIdentityToken of the STS request.
type IdentityTokenRetriever interface {
	GetIdentityToken() ([]byte, error)
}
 73
 74// IdentityTokenFile is for retrieving an identity token from the given file name
 75type IdentityTokenFile string
 76
 77// GetIdentityToken retrieves the JWT token from the file and returns the contents as a []byte
 78func (j IdentityTokenFile) GetIdentityToken() ([]byte, error) {
 79	b, err := ioutil.ReadFile(string(j))
 80	if err != nil {
 81		return nil, fmt.Errorf("unable to read file at %s: %v", string(j), err)
 82	}
 83
 84	return b, nil
 85}
 86
 87// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
 88// provided stsiface.ClientAPI
 89func NewWebIdentityRoleProvider(client AssumeRoleWithWebIdentityAPIClient, roleARN string, tokenRetriever IdentityTokenRetriever, optFns ...func(*WebIdentityRoleOptions)) *WebIdentityRoleProvider {
 90	o := WebIdentityRoleOptions{
 91		Client:         client,
 92		RoleARN:        roleARN,
 93		TokenRetriever: tokenRetriever,
 94	}
 95
 96	for _, fn := range optFns {
 97		fn(&o)
 98	}
 99
100	return &WebIdentityRoleProvider{options: o}
101}
102
103// Retrieve will attempt to assume a role from a token which is located at
104// 'WebIdentityTokenFilePath' specified destination and if that is empty an
105// error will be returned.
106func (p *WebIdentityRoleProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
107	b, err := p.options.TokenRetriever.GetIdentityToken()
108	if err != nil {
109		return aws.Credentials{}, fmt.Errorf("failed to retrieve jwt from provide source, %w", err)
110	}
111
112	sessionName := p.options.RoleSessionName
113	if len(sessionName) == 0 {
114		// session name is used to uniquely identify a session. This simply
115		// uses unix time in nanoseconds to uniquely identify sessions.
116		sessionName = strconv.FormatInt(sdk.NowTime().UnixNano(), 10)
117	}
118	input := &sts.AssumeRoleWithWebIdentityInput{
119		PolicyArns:       p.options.PolicyARNs,
120		RoleArn:          &p.options.RoleARN,
121		RoleSessionName:  &sessionName,
122		WebIdentityToken: aws.String(string(b)),
123	}
124	if p.options.Duration != 0 {
125		// If set use the value, otherwise STS will assign a default expiration duration.
126		input.DurationSeconds = aws.Int32(int32(p.options.Duration / time.Second))
127	}
128	if p.options.Policy != nil {
129		input.Policy = p.options.Policy
130	}
131
132	resp, err := p.options.Client.AssumeRoleWithWebIdentity(ctx, input, func(options *sts.Options) {
133		options.Retryer = retry.AddWithErrorCodes(options.Retryer, invalidIdentityTokenExceptionCode)
134	})
135	if err != nil {
136		return aws.Credentials{}, fmt.Errorf("failed to retrieve credentials, %w", err)
137	}
138
139	var accountID string
140	if resp.AssumedRoleUser != nil {
141		accountID = getAccountID(resp.AssumedRoleUser)
142	}
143
144	// InvalidIdentityToken error is a temporary error that can occur
145	// when assuming an Role with a JWT web identity token.
146
147	value := aws.Credentials{
148		AccessKeyID:     aws.ToString(resp.Credentials.AccessKeyId),
149		SecretAccessKey: aws.ToString(resp.Credentials.SecretAccessKey),
150		SessionToken:    aws.ToString(resp.Credentials.SessionToken),
151		Source:          WebIdentityProviderName,
152		CanExpire:       true,
153		Expires:         *resp.Credentials.Expiration,
154		AccountID:       accountID,
155	}
156	return value, nil
157}
158
159// extract accountID from arn with format "arn:partition:service:region:account-id:[resource-section]"
160func getAccountID(u *types.AssumedRoleUser) string {
161	if u.Arn == nil {
162		return ""
163	}
164	parts := strings.Split(*u.Arn, ":")
165	if len(parts) < 5 {
166		return ""
167	}
168	return parts[4]
169}