credential_cache.go

  1package aws
  2
  3import (
  4	"context"
  5	"fmt"
  6	"sync/atomic"
  7	"time"
  8
  9	sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
 10	"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
 11)
 12
// CredentialsCacheOptions configures how a CredentialsCache adjusts the
// expiration of the credentials it caches. The options are populated through
// the functional option arguments of NewCredentialsCache, which normalizes
// out-of-range values as described on each field.
type CredentialsCacheOptions struct {

	// ExpiryWindow allows credentials to trigger a refresh before they
	// actually expire. This helps avoid race conditions where credentials
	// expire while a request is in flight, causing the request to fail
	// unexpectedly with an ExpiredTokenException.
	//
	// The cache shifts the Expires of cached credentials earlier by this
	// window (only for credentials whose CanExpire is true). For example, an
	// ExpiryWindow of 10s causes the cached credentials to be reported as
	// expired 10 seconds before they actually expire. This can increase the
	// number of requests made to refresh the credentials.
	//
	// If ExpiryWindow is 0 or less it is ignored.
	ExpiryWindow time.Duration

	// ExpiryWindowJitterFrac randomizes, by a random fraction, how much of the
	// configured ExpiryWindow is applied to the credentials' expiration.
	// Valid values are between 0.0 and 1.0.
	//
	// For example, if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac
	// is 0.5, the credentials will be set to expire between 30 and 60 seconds
	// before their actual expiration time.
	//
	// If ExpiryWindow is 0 or less then ExpiryWindowJitterFrac is ignored.
	// If ExpiryWindowJitterFrac is 0 then no randomization is applied to the window.
	// If ExpiryWindowJitterFrac < 0 the value is treated as 0.
	// If ExpiryWindowJitterFrac > 1 the value is treated as 1.
	ExpiryWindowJitterFrac float64
}
 42
 43// CredentialsCache provides caching and concurrency safe credentials retrieval
 44// via the provider's retrieve method.
 45//
 46// CredentialsCache will look for optional interfaces on the Provider to adjust
 47// how the credential cache handles credentials caching.
 48//
 49//   - HandleFailRefreshCredentialsCacheStrategy - Allows provider to handle
 50//     credential refresh failures. This could return an updated Credentials
 51//     value, or attempt another means of retrieving credentials.
 52//
 53//   - AdjustExpiresByCredentialsCacheStrategy - Allows provider to adjust how
 54//     credentials Expires is modified. This could modify how the Credentials
 55//     Expires is adjusted based on the CredentialsCache ExpiryWindow option.
 56//     Such as providing a floor not to reduce the Expires below.
 57type CredentialsCache struct {
 58	provider CredentialsProvider
 59
 60	options CredentialsCacheOptions
 61	creds   atomic.Value
 62	sf      singleflight.Group
 63}
 64
 65// NewCredentialsCache returns a CredentialsCache that wraps provider. Provider
 66// is expected to not be nil. A variadic list of one or more functions can be
 67// provided to modify the CredentialsCache configuration. This allows for
 68// configuration of credential expiry window and jitter.
 69func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *CredentialsCacheOptions)) *CredentialsCache {
 70	options := CredentialsCacheOptions{}
 71
 72	for _, fn := range optFns {
 73		fn(&options)
 74	}
 75
 76	if options.ExpiryWindow < 0 {
 77		options.ExpiryWindow = 0
 78	}
 79
 80	if options.ExpiryWindowJitterFrac < 0 {
 81		options.ExpiryWindowJitterFrac = 0
 82	} else if options.ExpiryWindowJitterFrac > 1 {
 83		options.ExpiryWindowJitterFrac = 1
 84	}
 85
 86	return &CredentialsCache{
 87		provider: provider,
 88		options:  options,
 89	}
 90}
 91
 92// Retrieve returns the credentials. If the credentials have already been
 93// retrieved, and not expired the cached credentials will be returned. If the
 94// credentials have not been retrieved yet, or expired the provider's Retrieve
 95// method will be called.
 96//
 97// Returns and error if the provider's retrieve method returns an error.
 98func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) {
 99	if creds, ok := p.getCreds(); ok && !creds.Expired() {
100		return creds, nil
101	}
102
103	resCh := p.sf.DoChan("", func() (interface{}, error) {
104		return p.singleRetrieve(&suppressedContext{ctx})
105	})
106	select {
107	case res := <-resCh:
108		return res.Val.(Credentials), res.Err
109	case <-ctx.Done():
110		return Credentials{}, &RequestCanceledError{Err: ctx.Err()}
111	}
112}
113
114func (p *CredentialsCache) singleRetrieve(ctx context.Context) (interface{}, error) {
115	currCreds, ok := p.getCreds()
116	if ok && !currCreds.Expired() {
117		return currCreds, nil
118	}
119
120	newCreds, err := p.provider.Retrieve(ctx)
121	if err != nil {
122		handleFailToRefresh := defaultHandleFailToRefresh
123		if cs, ok := p.provider.(HandleFailRefreshCredentialsCacheStrategy); ok {
124			handleFailToRefresh = cs.HandleFailToRefresh
125		}
126		newCreds, err = handleFailToRefresh(ctx, currCreds, err)
127		if err != nil {
128			return Credentials{}, fmt.Errorf("failed to refresh cached credentials, %w", err)
129		}
130	}
131
132	if newCreds.CanExpire && p.options.ExpiryWindow > 0 {
133		adjustExpiresBy := defaultAdjustExpiresBy
134		if cs, ok := p.provider.(AdjustExpiresByCredentialsCacheStrategy); ok {
135			adjustExpiresBy = cs.AdjustExpiresBy
136		}
137
138		randFloat64, err := sdkrand.CryptoRandFloat64()
139		if err != nil {
140			return Credentials{}, fmt.Errorf("failed to get random provider, %w", err)
141		}
142
143		var jitter time.Duration
144		if p.options.ExpiryWindowJitterFrac > 0 {
145			jitter = time.Duration(randFloat64 *
146				p.options.ExpiryWindowJitterFrac * float64(p.options.ExpiryWindow))
147		}
148
149		newCreds, err = adjustExpiresBy(newCreds, -(p.options.ExpiryWindow - jitter))
150		if err != nil {
151			return Credentials{}, fmt.Errorf("failed to adjust credentials expires, %w", err)
152		}
153	}
154
155	p.creds.Store(&newCreds)
156	return newCreds, nil
157}
158
159// getCreds returns the currently stored credentials and true. Returning false
160// if no credentials were stored.
161func (p *CredentialsCache) getCreds() (Credentials, bool) {
162	v := p.creds.Load()
163	if v == nil {
164		return Credentials{}, false
165	}
166
167	c := v.(*Credentials)
168	if c == nil || !c.HasKeys() {
169		return Credentials{}, false
170	}
171
172	return *c, true
173}
174
175// Invalidate will invalidate the cached credentials. The next call to Retrieve
176// will cause the provider's Retrieve method to be called.
177func (p *CredentialsCache) Invalidate() {
178	p.creds.Store((*Credentials)(nil))
179}
180
181// IsCredentialsProvider returns whether credential provider wrapped by CredentialsCache
182// matches the target provider type.
183func (p *CredentialsCache) IsCredentialsProvider(target CredentialsProvider) bool {
184	return IsCredentialsProvider(p.provider, target)
185}
186
187// HandleFailRefreshCredentialsCacheStrategy is an interface for
188// CredentialsCache to allow CredentialsProvider  how failed to refresh
189// credentials is handled.
190type HandleFailRefreshCredentialsCacheStrategy interface {
191	// Given the previously cached Credentials, if any, and refresh error, may
192	// returns new or modified set of Credentials, or error.
193	//
194	// Credential caches may use default implementation if nil.
195	HandleFailToRefresh(context.Context, Credentials, error) (Credentials, error)
196}
197
198// defaultHandleFailToRefresh returns the passed in error.
199func defaultHandleFailToRefresh(ctx context.Context, _ Credentials, err error) (Credentials, error) {
200	return Credentials{}, err
201}
202
203// AdjustExpiresByCredentialsCacheStrategy is an interface for CredentialCache
204// to allow CredentialsProvider to intercept adjustments to Credentials expiry
205// based on expectations and use cases of CredentialsProvider.
206//
207// Credential caches may use default implementation if nil.
208type AdjustExpiresByCredentialsCacheStrategy interface {
209	// Given a Credentials as input, applying any mutations and
210	// returning the potentially updated Credentials, or error.
211	AdjustExpiresBy(Credentials, time.Duration) (Credentials, error)
212}
213
214// defaultAdjustExpiresBy adds the duration to the passed in credentials Expires,
215// and returns the updated credentials value. If Credentials value's CanExpire
216// is false, the passed in credentials are returned unchanged.
217func defaultAdjustExpiresBy(creds Credentials, dur time.Duration) (Credentials, error) {
218	if !creds.CanExpire {
219		return creds, nil
220	}
221
222	creds.Expires = creds.Expires.Add(dur)
223	return creds, nil
224}