1package ec2rolecreds
  2
  3import (
  4	"bufio"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"math"
  9	"path"
 10	"strings"
 11	"time"
 12
 13	"github.com/aws/aws-sdk-go-v2/aws"
 14	"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
 15	sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
 16	"github.com/aws/aws-sdk-go-v2/internal/sdk"
 17	"github.com/aws/smithy-go"
 18	"github.com/aws/smithy-go/logging"
 19	"github.com/aws/smithy-go/middleware"
 20)
 21
 22// ProviderName provides a name of EC2Role provider
 23const ProviderName = "EC2RoleProvider"
 24
 25// GetMetadataAPIClient provides the interface for an EC2 IMDS API client for the
 26// GetMetadata operation.
 27type GetMetadataAPIClient interface {
 28	GetMetadata(context.Context, *imds.GetMetadataInput, ...func(*imds.Options)) (*imds.GetMetadataOutput, error)
 29}
 30
 31// A Provider retrieves credentials from the EC2 service, and keeps track if
 32// those credentials are expired.
 33//
 34// The New function must be used to create the with a custom EC2 IMDS client.
 35//
 36//	p := &ec2rolecreds.New(func(o *ec2rolecreds.Options{
 37//	     o.Client = imds.New(imds.Options{/* custom options */})
 38//	})
 39type Provider struct {
 40	options Options
 41}
 42
 43// Options is a list of user settable options for setting the behavior of the Provider.
 44type Options struct {
 45	// The API client that will be used by the provider to make GetMetadata API
 46	// calls to EC2 IMDS.
 47	//
 48	// If nil, the provider will default to the EC2 IMDS client.
 49	Client GetMetadataAPIClient
 50}
 51
 52// New returns an initialized Provider value configured to retrieve
 53// credentials from EC2 Instance Metadata service.
 54func New(optFns ...func(*Options)) *Provider {
 55	options := Options{}
 56
 57	for _, fn := range optFns {
 58		fn(&options)
 59	}
 60
 61	if options.Client == nil {
 62		options.Client = imds.New(imds.Options{})
 63	}
 64
 65	return &Provider{
 66		options: options,
 67	}
 68}
 69
 70// Retrieve retrieves credentials from the EC2 service. Error will be returned
 71// if the request fails, or unable to extract the desired credentials.
 72func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
 73	credsList, err := requestCredList(ctx, p.options.Client)
 74	if err != nil {
 75		return aws.Credentials{Source: ProviderName}, err
 76	}
 77
 78	if len(credsList) == 0 {
 79		return aws.Credentials{Source: ProviderName},
 80			fmt.Errorf("unexpected empty EC2 IMDS role list")
 81	}
 82	credsName := credsList[0]
 83
 84	roleCreds, err := requestCred(ctx, p.options.Client, credsName)
 85	if err != nil {
 86		return aws.Credentials{Source: ProviderName}, err
 87	}
 88
 89	creds := aws.Credentials{
 90		AccessKeyID:     roleCreds.AccessKeyID,
 91		SecretAccessKey: roleCreds.SecretAccessKey,
 92		SessionToken:    roleCreds.Token,
 93		Source:          ProviderName,
 94
 95		CanExpire: true,
 96		Expires:   roleCreds.Expiration,
 97	}
 98
 99	// Cap role credentials Expires to 1 hour so they can be refreshed more
100	// often. Jitter will be applied credentials cache if being used.
101	if anHour := sdk.NowTime().Add(1 * time.Hour); creds.Expires.After(anHour) {
102		creds.Expires = anHour
103	}
104
105	return creds, nil
106}
107
108// HandleFailToRefresh will extend the credentials Expires time if it it is
109// expired. If the credentials will not expire within the minimum time, they
110// will be returned.
111//
112// If the credentials cannot expire, the original error will be returned.
113func (p *Provider) HandleFailToRefresh(ctx context.Context, prevCreds aws.Credentials, err error) (
114	aws.Credentials, error,
115) {
116	if !prevCreds.CanExpire {
117		return aws.Credentials{}, err
118	}
119
120	if prevCreds.Expires.After(sdk.NowTime().Add(5 * time.Minute)) {
121		return prevCreds, nil
122	}
123
124	newCreds := prevCreds
125	randFloat64, err := sdkrand.CryptoRandFloat64()
126	if err != nil {
127		return aws.Credentials{}, fmt.Errorf("failed to get random float, %w", err)
128	}
129
130	// Random distribution of [5,15) minutes.
131	expireOffset := time.Duration(randFloat64*float64(10*time.Minute)) + 5*time.Minute
132	newCreds.Expires = sdk.NowTime().Add(expireOffset)
133
134	logger := middleware.GetLogger(ctx)
135	logger.Logf(logging.Warn, "Attempting credential expiration extension due to a credential service availability issue. A refresh of these credentials will be attempted again in %v minutes.", math.Floor(expireOffset.Minutes()))
136
137	return newCreds, nil
138}
139
140// AdjustExpiresBy will adds the passed in duration to the passed in
141// credential's Expires time, unless the time until Expires is less than 15
142// minutes. Returns the credentials, even if not updated.
143func (p *Provider) AdjustExpiresBy(creds aws.Credentials, dur time.Duration) (
144	aws.Credentials, error,
145) {
146	if !creds.CanExpire {
147		return creds, nil
148	}
149	if creds.Expires.Before(sdk.NowTime().Add(15 * time.Minute)) {
150		return creds, nil
151	}
152
153	creds.Expires = creds.Expires.Add(dur)
154	return creds, nil
155}
156
// ec2RoleCredRespBody is the shape the EC2 IMDS role-credentials response is
// decoded into. The struct has no tags, so encoding/json matches JSON keys to
// the field names case-insensitively (e.g. "AccessKeyId" fills AccessKeyID).
type ec2RoleCredRespBody struct {
	// Expiration is when the role credentials expire.
	Expiration time.Time

	// AccessKeyID, SecretAccessKey, and Token are the role credential values.
	AccessKeyID, SecretAccessKey, Token string

	// Code and Message describe the outcome of the request. requestCred
	// treats any Code other than "Success" (case-insensitive) as a failure.
	Code, Message string
}
170
// iamSecurityCredsPath is the IMDS path used to list the IAM role names.
// A role's credentials are requested at this path joined with the role name.
const (
	iamSecurityCredsPath = "/iam/security-credentials/"
)
172
173// requestCredList requests a list of credentials from the EC2 service. If
174// there are no credentials, or there is an error making or receiving the
175// request
176func requestCredList(ctx context.Context, client GetMetadataAPIClient) ([]string, error) {
177	resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
178		Path: iamSecurityCredsPath,
179	})
180	if err != nil {
181		return nil, fmt.Errorf("no EC2 IMDS role found, %w", err)
182	}
183	defer resp.Content.Close()
184
185	credsList := []string{}
186	s := bufio.NewScanner(resp.Content)
187	for s.Scan() {
188		credsList = append(credsList, s.Text())
189	}
190
191	if err := s.Err(); err != nil {
192		return nil, fmt.Errorf("failed to read EC2 IMDS role, %w", err)
193	}
194
195	return credsList, nil
196}
197
198// requestCred requests the credentials for a specific credentials from the EC2 service.
199//
200// If the credentials cannot be found, or there is an error reading the response
201// and error will be returned.
202func requestCred(ctx context.Context, client GetMetadataAPIClient, credsName string) (ec2RoleCredRespBody, error) {
203	resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
204		Path: path.Join(iamSecurityCredsPath, credsName),
205	})
206	if err != nil {
207		return ec2RoleCredRespBody{},
208			fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
209				credsName, err)
210	}
211	defer resp.Content.Close()
212
213	var respCreds ec2RoleCredRespBody
214	if err := json.NewDecoder(resp.Content).Decode(&respCreds); err != nil {
215		return ec2RoleCredRespBody{},
216			fmt.Errorf("failed to decode %s EC2 IMDS role credentials, %w",
217				credsName, err)
218	}
219
220	if !strings.EqualFold(respCreds.Code, "Success") {
221		// If an error code was returned something failed requesting the role.
222		return ec2RoleCredRespBody{},
223			fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
224				credsName,
225				&smithy.GenericAPIError{Code: respCreds.Code, Message: respCreds.Message})
226	}
227
228	return respCreds, nil
229}