token_cache.go

  1package bearer
  2
  3import (
  4	"context"
  5	"fmt"
  6	"sync/atomic"
  7	"time"
  8
  9	smithycontext "github.com/aws/smithy-go/context"
 10	"github.com/aws/smithy-go/internal/sync/singleflight"
 11)
 12
 13// package variable that can be override in unit tests.
 14var timeNow = time.Now
 15
// TokenCacheOptions provides a set of optional configuration options for the
// TokenCache TokenProvider.
type TokenCacheOptions struct {
	// RefreshBeforeExpires is how long before the cached token expires that the
	// TokenCache starts refreshing it. A cached token is due for refresh once
	// it would be expired at now+RefreshBeforeExpires. If 0 (the default), the
	// token is only refreshed once it has actually expired.
	//
	// If DisableAsyncRefresh is true, RetrieveBearerToken calls made inside
	// this window block on the refresh. Otherwise the refresh runs in the
	// background and the still-valid cached token is returned immediately.
	//
	// Asynchronous refreshes are deduplicated, and only one will be in-flight
	// at a time. If the token expires while an asynchronous refresh is in
	// flight, the next call to RetrieveBearerToken will block on that refresh
	// to return.
	RefreshBeforeExpires time.Duration

	// RetrieveBearerTokenTimeout is the timeout the underlying TokenProvider's
	// RetrieveBearerToken call must return within, or its context will be
	// canceled. Defaults to 0, no timeout.
	//
	// With a 0 timeout it is possible for the underlying TokenProvider's
	// RetrieveBearerToken call to block forever, preventing subsequent
	// TokenCache attempts to refresh the token.
	//
	// If this timeout is reached, the context passed to the provider is
	// canceled. Assuming the provider returns an error in response, all
	// pending deduplicated calls to the TokenCache's RetrieveBearerToken fail
	// with that error.
	RetrieveBearerTokenTimeout time.Duration

	// AsyncRefreshMinimumDelay is the minimum duration to wait after a failed
	// asynchronous refresh attempt before another asynchronous refresh is
	// started. While inside this delay, RetrieveBearerToken returns the
	// current cached token (which has not yet expired) without starting a new
	// refresh. A successful asynchronous refresh clears the delay.
	//
	// An attempt also counts as failed if the context of the
	// RetrieveBearerToken call that triggered it is canceled before the
	// refresh completes.
	//
	// The asynchronous refresh is deduplicated across concurrent calls to
	// RetrieveBearerToken. It is not a periodic task: it is only started when
	// RetrieveBearerToken is called, the cached token has not yet expired, and
	// the token is within the RefreshBeforeExpires window.
	//
	// If 0 (the default), there is no minimum delay between asynchronous
	// refresh attempts.
	//
	// If DisableAsyncRefresh is true, this option is ignored.
	AsyncRefreshMinimumDelay time.Duration

	// DisableAsyncRefresh sets whether the TokenCache refreshes the token in
	// the foreground instead of in the background. When true, a token inside
	// the RefreshBeforeExpires window is refreshed with a blocking call.
	//
	// The first call to RetrieveBearerToken is always blocking, because there
	// is no cached token yet.
	DisableAsyncRefresh bool
}
 65
// TokenCache provides a utility to cache Bearer Authentication tokens from a
// wrapped TokenProvider. The TokenCache has options to configure the cache's
// early and asynchronous refresh of the token.
type TokenCache struct {
	options  TokenCacheOptions
	provider TokenProvider

	// cachedToken holds the most recently retrieved token, stored as a *Token.
	cachedToken atomic.Value

	// lastRefreshAttemptTime holds a time.Time: when the last asynchronous
	// refresh failed, or the zero time after one succeeded. It is only written
	// when AsyncRefreshMinimumDelay is non-zero.
	lastRefreshAttemptTime atomic.Value

	// sfGroup deduplicates concurrent work under two keys: "refresh-token"
	// for the provider call and "async-refresh" for the background refresh.
	sfGroup singleflight.Group
}
 77
 78// NewTokenCache returns a initialized TokenCache that implements the
 79// TokenProvider interface. Wrapping the provider passed in. Also taking a set
 80// of optional functional option parameters to configure the token cache.
 81func NewTokenCache(provider TokenProvider, optFns ...func(*TokenCacheOptions)) *TokenCache {
 82	var options TokenCacheOptions
 83	for _, fn := range optFns {
 84		fn(&options)
 85	}
 86
 87	return &TokenCache{
 88		options:  options,
 89		provider: provider,
 90	}
 91}
 92
 93// RetrieveBearerToken returns the token if it could be obtained, or error if a
 94// valid token could not be retrieved.
 95//
 96// The passed in Context's cancel/deadline/timeout will impacting only this
 97// individual retrieve call and not any other already queued up calls. This
 98// means underlying provider's RetrieveBearerToken calls could block for ever,
 99// and not be canceled with the Context. Set RetrieveBearerTokenTimeout to
100// provide a timeout, preventing the underlying TokenProvider blocking forever.
101//
102// By default, if the passed in Context is canceled, all of its values will be
103// considered expired. The wrapped TokenProvider will not be able to lookup the
104// values from the Context once it is expired. This is done to protect against
105// expired values no longer being valid. To disable this behavior, use
106// smithy-go's context.WithPreserveExpiredValues to add a value to the Context
107// before calling RetrieveBearerToken to enable support for expired values.
108//
109// Without RetrieveBearerTokenTimeout there is the potential for a underlying
110// Provider's RetrieveBearerToken call to sit forever. Blocking in subsequent
111// attempts at refreshing the token.
112func (p *TokenCache) RetrieveBearerToken(ctx context.Context) (Token, error) {
113	cachedToken, ok := p.getCachedToken()
114	if !ok || cachedToken.Expired(timeNow()) {
115		return p.refreshBearerToken(ctx)
116	}
117
118	// Check if the token should be refreshed before it expires.
119	refreshToken := cachedToken.Expired(timeNow().Add(p.options.RefreshBeforeExpires))
120	if !refreshToken {
121		return cachedToken, nil
122	}
123
124	if p.options.DisableAsyncRefresh {
125		return p.refreshBearerToken(ctx)
126	}
127
128	p.tryAsyncRefresh(ctx)
129
130	return cachedToken, nil
131}
132
133// tryAsyncRefresh attempts to asynchronously refresh the token returning the
134// already cached token. If it AsyncRefreshMinimumDelay option is not zero, and
135// the duration since the last refresh is less than that value, nothing will be
136// done.
137func (p *TokenCache) tryAsyncRefresh(ctx context.Context) {
138	if p.options.AsyncRefreshMinimumDelay != 0 {
139		var lastRefreshAttempt time.Time
140		if v := p.lastRefreshAttemptTime.Load(); v != nil {
141			lastRefreshAttempt = v.(time.Time)
142		}
143
144		if timeNow().Before(lastRefreshAttempt.Add(p.options.AsyncRefreshMinimumDelay)) {
145			return
146		}
147	}
148
149	// Ignore the returned channel so this won't be blocking, and limit the
150	// number of additional goroutines created.
151	p.sfGroup.DoChan("async-refresh", func() (interface{}, error) {
152		res, err := p.refreshBearerToken(ctx)
153		if p.options.AsyncRefreshMinimumDelay != 0 {
154			var refreshAttempt time.Time
155			if err != nil {
156				refreshAttempt = timeNow()
157			}
158			p.lastRefreshAttemptTime.Store(refreshAttempt)
159		}
160
161		return res, err
162	})
163}
164
165func (p *TokenCache) refreshBearerToken(ctx context.Context) (Token, error) {
166	resCh := p.sfGroup.DoChan("refresh-token", func() (interface{}, error) {
167		ctx := smithycontext.WithSuppressCancel(ctx)
168		if v := p.options.RetrieveBearerTokenTimeout; v != 0 {
169			var cancel func()
170			ctx, cancel = context.WithTimeout(ctx, v)
171			defer cancel()
172		}
173		return p.singleRetrieve(ctx)
174	})
175
176	select {
177	case res := <-resCh:
178		return res.Val.(Token), res.Err
179	case <-ctx.Done():
180		return Token{}, fmt.Errorf("retrieve bearer token canceled, %w", ctx.Err())
181	}
182}
183
184func (p *TokenCache) singleRetrieve(ctx context.Context) (interface{}, error) {
185	token, err := p.provider.RetrieveBearerToken(ctx)
186	if err != nil {
187		return Token{}, fmt.Errorf("failed to retrieve bearer token, %w", err)
188	}
189
190	p.cachedToken.Store(&token)
191	return token, nil
192}
193
194// getCachedToken returns the currently cached token and true if found. Returns
195// false if no token is cached.
196func (p *TokenCache) getCachedToken() (Token, bool) {
197	v := p.cachedToken.Load()
198	if v == nil {
199		return Token{}, false
200	}
201
202	t := v.(*Token)
203	if t == nil || t.Value == "" {
204		return Token{}, false
205	}
206
207	return *t, true
208}