1package aws
2
3import (
4 "context"
5 "fmt"
6 "sync/atomic"
7 "time"
8
9 sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
10 "github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
11)
12
// CredentialsCacheOptions are the options that configure a CredentialsCache.
// They are handed to each functional option passed to NewCredentialsCache,
// which then normalizes out-of-range values as described on each field.
type CredentialsCacheOptions struct {
	// ExpiryWindow will allow the credentials to trigger refreshing prior to
	// the credentials actually expiring. This is beneficial so race conditions
	// with expiring credentials do not cause requests to fail unexpectedly
	// due to ExpiredTokenException exceptions.
	//
	// An ExpiryWindow of 10s would cause calls to Credentials.Expired to
	// return true 10 seconds before the credentials are actually expired.
	// This can cause an increased number of requests to refresh the
	// credentials to occur.
	//
	// If ExpiryWindow is 0 or less it will be ignored.
	ExpiryWindow time.Duration

	// ExpiryWindowJitterFrac provides a mechanism for randomizing the
	// expiration of credentials within the configured ExpiryWindow by a random
	// percentage. Valid values are between 0.0 and 1.0.
	//
	// As an example if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac
	// is 0.5 then credentials will be set to expire between 30 to 60 seconds
	// prior to their actual expiration time.
	//
	// If ExpiryWindow is 0 or less then ExpiryWindowJitterFrac is ignored.
	// If ExpiryWindowJitterFrac is 0 then no randomization will be applied to the window.
	// If ExpiryWindowJitterFrac < 0 the value will be treated as 0.
	// If ExpiryWindowJitterFrac > 1 the value will be treated as 1.
	ExpiryWindowJitterFrac float64
}
42
// CredentialsCache provides caching and concurrency safe credentials retrieval
// via the provider's Retrieve method.
//
// CredentialsCache will look for optional interfaces on the provider to adjust
// how the credential cache handles credentials caching:
//
//   - HandleFailRefreshCredentialsCacheStrategy - Allows the provider to handle
//     credential refresh failures. This could return an updated Credentials
//     value, or attempt another means of retrieving credentials.
//
//   - AdjustExpiresByCredentialsCacheStrategy - Allows the provider to adjust
//     how the credentials' Expires is modified based on the CredentialsCache
//     ExpiryWindow option, such as providing a floor below which Expires is
//     not reduced.
//
// Construct a CredentialsCache with NewCredentialsCache, which applies and
// normalizes the CredentialsCacheOptions.
type CredentialsCache struct {
	// provider is the wrapped CredentialsProvider whose Retrieve is called
	// when no usable credentials are cached.
	provider CredentialsProvider

	// options are the settings normalized by NewCredentialsCache.
	options CredentialsCacheOptions

	// creds holds a *Credentials. Invalidate stores a typed nil pointer so
	// that every Store uses the same concrete type.
	creds atomic.Value

	// sf is the singleflight group Retrieve uses so concurrent refreshes
	// share a single call to singleRetrieve.
	sf singleflight.Group
}
64
65// NewCredentialsCache returns a CredentialsCache that wraps provider. Provider
66// is expected to not be nil. A variadic list of one or more functions can be
67// provided to modify the CredentialsCache configuration. This allows for
68// configuration of credential expiry window and jitter.
69func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *CredentialsCacheOptions)) *CredentialsCache {
70 options := CredentialsCacheOptions{}
71
72 for _, fn := range optFns {
73 fn(&options)
74 }
75
76 if options.ExpiryWindow < 0 {
77 options.ExpiryWindow = 0
78 }
79
80 if options.ExpiryWindowJitterFrac < 0 {
81 options.ExpiryWindowJitterFrac = 0
82 } else if options.ExpiryWindowJitterFrac > 1 {
83 options.ExpiryWindowJitterFrac = 1
84 }
85
86 return &CredentialsCache{
87 provider: provider,
88 options: options,
89 }
90}
91
92// Retrieve returns the credentials. If the credentials have already been
93// retrieved, and not expired the cached credentials will be returned. If the
94// credentials have not been retrieved yet, or expired the provider's Retrieve
95// method will be called.
96//
97// Returns and error if the provider's retrieve method returns an error.
98func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) {
99 if creds, ok := p.getCreds(); ok && !creds.Expired() {
100 return creds, nil
101 }
102
103 resCh := p.sf.DoChan("", func() (interface{}, error) {
104 return p.singleRetrieve(&suppressedContext{ctx})
105 })
106 select {
107 case res := <-resCh:
108 return res.Val.(Credentials), res.Err
109 case <-ctx.Done():
110 return Credentials{}, &RequestCanceledError{Err: ctx.Err()}
111 }
112}
113
114func (p *CredentialsCache) singleRetrieve(ctx context.Context) (interface{}, error) {
115 currCreds, ok := p.getCreds()
116 if ok && !currCreds.Expired() {
117 return currCreds, nil
118 }
119
120 newCreds, err := p.provider.Retrieve(ctx)
121 if err != nil {
122 handleFailToRefresh := defaultHandleFailToRefresh
123 if cs, ok := p.provider.(HandleFailRefreshCredentialsCacheStrategy); ok {
124 handleFailToRefresh = cs.HandleFailToRefresh
125 }
126 newCreds, err = handleFailToRefresh(ctx, currCreds, err)
127 if err != nil {
128 return Credentials{}, fmt.Errorf("failed to refresh cached credentials, %w", err)
129 }
130 }
131
132 if newCreds.CanExpire && p.options.ExpiryWindow > 0 {
133 adjustExpiresBy := defaultAdjustExpiresBy
134 if cs, ok := p.provider.(AdjustExpiresByCredentialsCacheStrategy); ok {
135 adjustExpiresBy = cs.AdjustExpiresBy
136 }
137
138 randFloat64, err := sdkrand.CryptoRandFloat64()
139 if err != nil {
140 return Credentials{}, fmt.Errorf("failed to get random provider, %w", err)
141 }
142
143 var jitter time.Duration
144 if p.options.ExpiryWindowJitterFrac > 0 {
145 jitter = time.Duration(randFloat64 *
146 p.options.ExpiryWindowJitterFrac * float64(p.options.ExpiryWindow))
147 }
148
149 newCreds, err = adjustExpiresBy(newCreds, -(p.options.ExpiryWindow - jitter))
150 if err != nil {
151 return Credentials{}, fmt.Errorf("failed to adjust credentials expires, %w", err)
152 }
153 }
154
155 p.creds.Store(&newCreds)
156 return newCreds, nil
157}
158
159// getCreds returns the currently stored credentials and true. Returning false
160// if no credentials were stored.
161func (p *CredentialsCache) getCreds() (Credentials, bool) {
162 v := p.creds.Load()
163 if v == nil {
164 return Credentials{}, false
165 }
166
167 c := v.(*Credentials)
168 if c == nil || !c.HasKeys() {
169 return Credentials{}, false
170 }
171
172 return *c, true
173}
174
175// Invalidate will invalidate the cached credentials. The next call to Retrieve
176// will cause the provider's Retrieve method to be called.
177func (p *CredentialsCache) Invalidate() {
178 p.creds.Store((*Credentials)(nil))
179}
180
181// IsCredentialsProvider returns whether credential provider wrapped by CredentialsCache
182// matches the target provider type.
183func (p *CredentialsCache) IsCredentialsProvider(target CredentialsProvider) bool {
184 return IsCredentialsProvider(p.provider, target)
185}
186
187// HandleFailRefreshCredentialsCacheStrategy is an interface for
188// CredentialsCache to allow CredentialsProvider how failed to refresh
189// credentials is handled.
190type HandleFailRefreshCredentialsCacheStrategy interface {
191 // Given the previously cached Credentials, if any, and refresh error, may
192 // returns new or modified set of Credentials, or error.
193 //
194 // Credential caches may use default implementation if nil.
195 HandleFailToRefresh(context.Context, Credentials, error) (Credentials, error)
196}
197
198// defaultHandleFailToRefresh returns the passed in error.
199func defaultHandleFailToRefresh(ctx context.Context, _ Credentials, err error) (Credentials, error) {
200 return Credentials{}, err
201}
202
203// AdjustExpiresByCredentialsCacheStrategy is an interface for CredentialCache
204// to allow CredentialsProvider to intercept adjustments to Credentials expiry
205// based on expectations and use cases of CredentialsProvider.
206//
207// Credential caches may use default implementation if nil.
208type AdjustExpiresByCredentialsCacheStrategy interface {
209 // Given a Credentials as input, applying any mutations and
210 // returning the potentially updated Credentials, or error.
211 AdjustExpiresBy(Credentials, time.Duration) (Credentials, error)
212}
213
214// defaultAdjustExpiresBy adds the duration to the passed in credentials Expires,
215// and returns the updated credentials value. If Credentials value's CanExpire
216// is false, the passed in credentials are returned unchanged.
217func defaultAdjustExpiresBy(creds Credentials, dur time.Duration) (Credentials, error) {
218 if !creds.CanExpire {
219 return creds, nil
220 }
221
222 creds.Expires = creds.Expires.Add(dur)
223 return creds, nil
224}