1package imds
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"github.com/aws/aws-sdk-go-v2/aws"
  8	"github.com/aws/smithy-go"
  9	"github.com/aws/smithy-go/logging"
 10	"net/http"
 11	"sync"
 12	"sync/atomic"
 13	"time"
 14
 15	"github.com/aws/smithy-go/middleware"
 16	smithyhttp "github.com/aws/smithy-go/transport/http"
 17)
 18
 19const (
 20	// Headers for Token and TTL
 21	tokenHeader     = "x-aws-ec2-metadata-token"
 22	defaultTokenTTL = 5 * time.Minute
 23)
 24
 25type tokenProvider struct {
 26	client   *Client
 27	tokenTTL time.Duration
 28
 29	token    *apiToken
 30	tokenMux sync.RWMutex
 31
 32	disabled uint32 // Atomic updated
 33}
 34
 35func newTokenProvider(client *Client, ttl time.Duration) *tokenProvider {
 36	return &tokenProvider{
 37		client:   client,
 38		tokenTTL: ttl,
 39	}
 40}
 41
 42// apiToken provides the API token used by all operation calls for th EC2
 43// Instance metadata service.
 44type apiToken struct {
 45	token   string
 46	expires time.Time
 47}
 48
 49var timeNow = time.Now
 50
 51// Expired returns if the token is expired.
 52func (t *apiToken) Expired() bool {
 53	// Calling Round(0) on the current time will truncate the monotonic reading only. Ensures credential expiry
 54	// time is always based on reported wall-clock time.
 55	return timeNow().Round(0).After(t.expires)
 56}
 57
 58func (t *tokenProvider) ID() string { return "APITokenProvider" }
 59
 60// HandleFinalize is the finalize stack middleware, that if the token provider is
 61// enabled, will attempt to add the cached API token to the request. If the API
 62// token is not cached, it will be retrieved in a separate API call, getToken.
 63//
 64// For retry attempts, handler must be added after attempt retryer.
 65//
 66// If request for getToken fails the token provider may be disabled from future
 67// requests, depending on the response status code.
 68func (t *tokenProvider) HandleFinalize(
 69	ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler,
 70) (
 71	out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
 72) {
 73	if t.fallbackEnabled() && !t.enabled() {
 74		// short-circuits to insecure data flow if token provider is disabled.
 75		return next.HandleFinalize(ctx, input)
 76	}
 77
 78	req, ok := input.Request.(*smithyhttp.Request)
 79	if !ok {
 80		return out, metadata, fmt.Errorf("unexpected transport request type %T", input.Request)
 81	}
 82
 83	tok, err := t.getToken(ctx)
 84	if err != nil {
 85		// If the error allows the token to downgrade to insecure flow allow that.
 86		var bypassErr *bypassTokenRetrievalError
 87		if errors.As(err, &bypassErr) {
 88			return next.HandleFinalize(ctx, input)
 89		}
 90
 91		return out, metadata, fmt.Errorf("failed to get API token, %w", err)
 92	}
 93
 94	req.Header.Set(tokenHeader, tok.token)
 95
 96	return next.HandleFinalize(ctx, input)
 97}
 98
 99// HandleDeserialize is the deserialize stack middleware for determining if the
100// operation the token provider is decorating failed because of a 401
101// unauthorized status code. If the operation failed for that reason the token
102// provider needs to be re-enabled so that it can start adding the API token to
103// operation calls.
104func (t *tokenProvider) HandleDeserialize(
105	ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler,
106) (
107	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
108) {
109	out, metadata, err = next.HandleDeserialize(ctx, input)
110	if err == nil {
111		return out, metadata, err
112	}
113
114	resp, ok := out.RawResponse.(*smithyhttp.Response)
115	if !ok {
116		return out, metadata, fmt.Errorf("expect HTTP transport, got %T", out.RawResponse)
117	}
118
119	if resp.StatusCode == http.StatusUnauthorized { // unauthorized
120		t.enable()
121		err = &retryableError{Err: err, isRetryable: true}
122	}
123
124	return out, metadata, err
125}
126
127func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) {
128	if t.fallbackEnabled() && !t.enabled() {
129		return nil, &bypassTokenRetrievalError{
130			Err: fmt.Errorf("cannot get API token, provider disabled"),
131		}
132	}
133
134	t.tokenMux.RLock()
135	tok = t.token
136	t.tokenMux.RUnlock()
137
138	if tok != nil && !tok.Expired() {
139		return tok, nil
140	}
141
142	tok, err = t.updateToken(ctx)
143	if err != nil {
144		return nil, err
145	}
146
147	return tok, nil
148}
149
150func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
151	t.tokenMux.Lock()
152	defer t.tokenMux.Unlock()
153
154	// Prevent multiple requests to update retrieving the token.
155	if t.token != nil && !t.token.Expired() {
156		tok := t.token
157		return tok, nil
158	}
159
160	result, err := t.client.getToken(ctx, &getTokenInput{
161		TokenTTL: t.tokenTTL,
162	})
163	if err != nil {
164		var statusErr interface{ HTTPStatusCode() int }
165		if errors.As(err, &statusErr) {
166			switch statusErr.HTTPStatusCode() {
167			// Disable future get token if failed because of 403, 404, or 405
168			case http.StatusForbidden,
169				http.StatusNotFound,
170				http.StatusMethodNotAllowed:
171
172				if t.fallbackEnabled() {
173					logger := middleware.GetLogger(ctx)
174					logger.Logf(logging.Warn, "falling back to IMDSv1: %v", err)
175					t.disable()
176				}
177
178			// 400 errors are terminal, and need to be upstreamed
179			case http.StatusBadRequest:
180				return nil, err
181			}
182		}
183
184		// Disable if request send failed or timed out getting response
185		var re *smithyhttp.RequestSendError
186		var ce *smithy.CanceledError
187		if errors.As(err, &re) || errors.As(err, &ce) {
188			atomic.StoreUint32(&t.disabled, 1)
189		}
190
191		if !t.fallbackEnabled() {
192			// NOTE: getToken() is an implementation detail of some outer operation
193			// (e.g. GetMetadata). It has its own retries that have already been exhausted.
194			// Mark the underlying error as a terminal error.
195			err = &retryableError{Err: err, isRetryable: false}
196			return nil, err
197		}
198
199		// Token couldn't be retrieved, fallback to IMDSv1 insecure flow for this request
200		// and allow the request to proceed. Future requests _may_ re-attempt fetching a
201		// token if not disabled.
202		return nil, &bypassTokenRetrievalError{Err: err}
203	}
204
205	tok := &apiToken{
206		token:   result.Token,
207		expires: timeNow().Add(result.TokenTTL),
208	}
209	t.token = tok
210
211	return tok, nil
212}
213
214// enabled returns if the token provider is current enabled or not.
215func (t *tokenProvider) enabled() bool {
216	return atomic.LoadUint32(&t.disabled) == 0
217}
218
219// fallbackEnabled returns false if EnableFallback is [aws.FalseTernary], true otherwise
220func (t *tokenProvider) fallbackEnabled() bool {
221	switch t.client.options.EnableFallback {
222	case aws.FalseTernary:
223		return false
224	default:
225		return true
226	}
227}
228
229// disable disables the token provider and it will no longer attempt to inject
230// the token, nor request updates.
231func (t *tokenProvider) disable() {
232	atomic.StoreUint32(&t.disabled, 1)
233}
234
235// enable enables the token provide to start refreshing tokens, and adding them
236// to the pending request.
237func (t *tokenProvider) enable() {
238	t.tokenMux.Lock()
239	t.token = nil
240	t.tokenMux.Unlock()
241	atomic.StoreUint32(&t.disabled, 0)
242}
243
// bypassTokenRetrievalError reports that the API token could not be obtained,
// but the request is allowed to continue without one (the IMDSv1 flow).
type bypassTokenRetrievalError struct {
	Err error
}

// Error returns the message prefixed with "bypass token retrieval, ".
func (e *bypassTokenRetrievalError) Error() string {
	const prefix = "bypass token retrieval, "
	return prefix + fmt.Sprint(e.Err)
}

// Unwrap returns the underlying retrieval error for errors.Is / errors.As.
func (e *bypassTokenRetrievalError) Unwrap() error {
	return e.Err
}
253
// retryableError wraps an error with an explicit retry decision, exposed via
// RetryableError (presumably consumed by the SDK retryer — confirm against the
// retry package).
type retryableError struct {
	Err         error
	isRetryable bool
}

// RetryableError reports whether the wrapped error should be retried.
func (e *retryableError) RetryableError() bool { return e.isRetryable }

// Error returns the wrapped error's message unchanged.
func (e *retryableError) Error() string { return e.Err.Error() }

// Unwrap exposes the wrapped error to errors.Is and errors.As. Without it,
// wrapping in retryableError would hide the underlying cause (e.g. an HTTP
// status error or a cancellation) from callers.
func (e *retryableError) Unwrap() error { return e.Err }