1package ec2rolecreds
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "fmt"
8 "math"
9 "path"
10 "strings"
11 "time"
12
13 "github.com/aws/aws-sdk-go-v2/aws"
14 "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
15 sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
16 "github.com/aws/aws-sdk-go-v2/internal/sdk"
17 "github.com/aws/smithy-go"
18 "github.com/aws/smithy-go/logging"
19 "github.com/aws/smithy-go/middleware"
20)
21
// ProviderName is the name of the EC2 role credentials provider. It is set as
// the Source of every aws.Credentials value returned by Provider.Retrieve.
const ProviderName = "EC2RoleProvider"
24
25// GetMetadataAPIClient provides the interface for an EC2 IMDS API client for the
26// GetMetadata operation.
27type GetMetadataAPIClient interface {
28 GetMetadata(context.Context, *imds.GetMetadataInput, ...func(*imds.Options)) (*imds.GetMetadataOutput, error)
29}
30
// A Provider retrieves role credentials from the EC2 Instance Metadata Service
// (IMDS) and reports when those credentials expire.
//
// Create a Provider with New, which fills in a default IMDS client when none
// is supplied. To use a custom EC2 IMDS client, pass a functional option:
//
//	p := ec2rolecreds.New(func(o *ec2rolecreds.Options) {
//		o.Client = imds.New(imds.Options{/* custom options */})
//	})
type Provider struct {
	// options holds the resolved settings, including the IMDS client that
	// Retrieve uses for its requests.
	options Options
}
42
// Options is the set of user-settable options that configure the behavior of
// the Provider.
type Options struct {
	// The API client that will be used by the provider to make GetMetadata API
	// calls to EC2 IMDS.
	//
	// If nil, New sets it to a client created with imds.New(imds.Options{}).
	Client GetMetadataAPIClient
}
51
52// New returns an initialized Provider value configured to retrieve
53// credentials from EC2 Instance Metadata service.
54func New(optFns ...func(*Options)) *Provider {
55 options := Options{}
56
57 for _, fn := range optFns {
58 fn(&options)
59 }
60
61 if options.Client == nil {
62 options.Client = imds.New(imds.Options{})
63 }
64
65 return &Provider{
66 options: options,
67 }
68}
69
70// Retrieve retrieves credentials from the EC2 service. Error will be returned
71// if the request fails, or unable to extract the desired credentials.
72func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
73 credsList, err := requestCredList(ctx, p.options.Client)
74 if err != nil {
75 return aws.Credentials{Source: ProviderName}, err
76 }
77
78 if len(credsList) == 0 {
79 return aws.Credentials{Source: ProviderName},
80 fmt.Errorf("unexpected empty EC2 IMDS role list")
81 }
82 credsName := credsList[0]
83
84 roleCreds, err := requestCred(ctx, p.options.Client, credsName)
85 if err != nil {
86 return aws.Credentials{Source: ProviderName}, err
87 }
88
89 creds := aws.Credentials{
90 AccessKeyID: roleCreds.AccessKeyID,
91 SecretAccessKey: roleCreds.SecretAccessKey,
92 SessionToken: roleCreds.Token,
93 Source: ProviderName,
94
95 CanExpire: true,
96 Expires: roleCreds.Expiration,
97 }
98
99 // Cap role credentials Expires to 1 hour so they can be refreshed more
100 // often. Jitter will be applied credentials cache if being used.
101 if anHour := sdk.NowTime().Add(1 * time.Hour); creds.Expires.After(anHour) {
102 creds.Expires = anHour
103 }
104
105 return creds, nil
106}
107
108// HandleFailToRefresh will extend the credentials Expires time if it it is
109// expired. If the credentials will not expire within the minimum time, they
110// will be returned.
111//
112// If the credentials cannot expire, the original error will be returned.
113func (p *Provider) HandleFailToRefresh(ctx context.Context, prevCreds aws.Credentials, err error) (
114 aws.Credentials, error,
115) {
116 if !prevCreds.CanExpire {
117 return aws.Credentials{}, err
118 }
119
120 if prevCreds.Expires.After(sdk.NowTime().Add(5 * time.Minute)) {
121 return prevCreds, nil
122 }
123
124 newCreds := prevCreds
125 randFloat64, err := sdkrand.CryptoRandFloat64()
126 if err != nil {
127 return aws.Credentials{}, fmt.Errorf("failed to get random float, %w", err)
128 }
129
130 // Random distribution of [5,15) minutes.
131 expireOffset := time.Duration(randFloat64*float64(10*time.Minute)) + 5*time.Minute
132 newCreds.Expires = sdk.NowTime().Add(expireOffset)
133
134 logger := middleware.GetLogger(ctx)
135 logger.Logf(logging.Warn, "Attempting credential expiration extension due to a credential service availability issue. A refresh of these credentials will be attempted again in %v minutes.", math.Floor(expireOffset.Minutes()))
136
137 return newCreds, nil
138}
139
140// AdjustExpiresBy will adds the passed in duration to the passed in
141// credential's Expires time, unless the time until Expires is less than 15
142// minutes. Returns the credentials, even if not updated.
143func (p *Provider) AdjustExpiresBy(creds aws.Credentials, dur time.Duration) (
144 aws.Credentials, error,
145) {
146 if !creds.CanExpire {
147 return creds, nil
148 }
149 if creds.Expires.Before(sdk.NowTime().Add(15 * time.Minute)) {
150 return creds, nil
151 }
152
153 creds.Expires = creds.Expires.Add(dur)
154 return creds, nil
155}
156
// ec2RoleCredRespBody provides the shape for unmarshaling credential
// request responses. requestCred decodes the JSON response body into this
// type.
type ec2RoleCredRespBody struct {
	// Success state: the credential values that Provider.Retrieve copies into
	// aws.Credentials (Token becomes the SessionToken).
	Expiration      time.Time
	AccessKeyID     string
	SecretAccessKey string
	Token           string

	// Error state: requestCred treats any Code other than "Success"
	// (compared case-insensitively) as a failure and reports Code and Message
	// in the returned error.
	Code    string
	Message string
}
170
// iamSecurityCredsPath is the EC2 IMDS metadata path used to list role names
// and, with a role name appended, to fetch that role's credentials.
const iamSecurityCredsPath = "/iam/security-credentials/"
172
173// requestCredList requests a list of credentials from the EC2 service. If
174// there are no credentials, or there is an error making or receiving the
175// request
176func requestCredList(ctx context.Context, client GetMetadataAPIClient) ([]string, error) {
177 resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
178 Path: iamSecurityCredsPath,
179 })
180 if err != nil {
181 return nil, fmt.Errorf("no EC2 IMDS role found, %w", err)
182 }
183 defer resp.Content.Close()
184
185 credsList := []string{}
186 s := bufio.NewScanner(resp.Content)
187 for s.Scan() {
188 credsList = append(credsList, s.Text())
189 }
190
191 if err := s.Err(); err != nil {
192 return nil, fmt.Errorf("failed to read EC2 IMDS role, %w", err)
193 }
194
195 return credsList, nil
196}
197
198// requestCred requests the credentials for a specific credentials from the EC2 service.
199//
200// If the credentials cannot be found, or there is an error reading the response
201// and error will be returned.
202func requestCred(ctx context.Context, client GetMetadataAPIClient, credsName string) (ec2RoleCredRespBody, error) {
203 resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
204 Path: path.Join(iamSecurityCredsPath, credsName),
205 })
206 if err != nil {
207 return ec2RoleCredRespBody{},
208 fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
209 credsName, err)
210 }
211 defer resp.Content.Close()
212
213 var respCreds ec2RoleCredRespBody
214 if err := json.NewDecoder(resp.Content).Decode(&respCreds); err != nil {
215 return ec2RoleCredRespBody{},
216 fmt.Errorf("failed to decode %s EC2 IMDS role credentials, %w",
217 credsName, err)
218 }
219
220 if !strings.EqualFold(respCreds.Code, "Success") {
221 // If an error code was returned something failed requesting the role.
222 return ec2RoleCredRespBody{},
223 fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
224 credsName,
225 &smithy.GenericAPIError{Code: respCreds.Code, Message: respCreds.Message})
226 }
227
228 return respCreds, nil
229}