1package imds
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "github.com/aws/aws-sdk-go-v2/aws"
8 "github.com/aws/smithy-go"
9 "github.com/aws/smithy-go/logging"
10 "net/http"
11 "sync"
12 "sync/atomic"
13 "time"
14
15 "github.com/aws/smithy-go/middleware"
16 smithyhttp "github.com/aws/smithy-go/transport/http"
17)
18
const (
	// tokenHeader is the HTTP request header in which the IMDS API token is
	// sent on metadata requests.
	tokenHeader = "x-aws-ec2-metadata-token"

	// defaultTokenTTL is the default lifetime to request for an API token.
	// NOTE(review): not referenced in this file; presumably used by the client
	// when no TTL is configured — confirm.
	defaultTokenTTL = 5 * time.Minute
)
24
// tokenProvider is a middleware that attaches the IMDS API token to outgoing
// requests. It caches the token between calls and refreshes it once it has
// expired.
type tokenProvider struct {
	// client is used to retrieve new tokens.
	client *Client
	// tokenTTL is the lifetime requested for each newly retrieved token.
	tokenTTL time.Duration

	// token is the cached token (nil when none is cached); guarded by tokenMux.
	token    *apiToken
	tokenMux sync.RWMutex

	// disabled is 1 when token retrieval is disabled and 0 when enabled.
	// In this file it is only read/written through sync/atomic.
	disabled uint32 // Atomic updated
}
34
35func newTokenProvider(client *Client, ttl time.Duration) *tokenProvider {
36 return &tokenProvider{
37 client: client,
38 tokenTTL: ttl,
39 }
40}
41
// apiToken provides the API token used by all operation calls for the EC2
// Instance Metadata Service.
type apiToken struct {
	// token is the opaque token value sent in the tokenHeader request header.
	token string
	// expires is the time after which Expired reports the token as expired.
	expires time.Time
}
48
// timeNow returns the current time and is used both to stamp token expiry and
// to check it. It is a package-level variable rather than a direct time.Now
// call. NOTE(review): presumably so tests can substitute the clock — confirm.
var timeNow = time.Now
50
51// Expired returns if the token is expired.
52func (t *apiToken) Expired() bool {
53 // Calling Round(0) on the current time will truncate the monotonic reading only. Ensures credential expiry
54 // time is always based on reported wall-clock time.
55 return timeNow().Round(0).After(t.expires)
56}
57
58func (t *tokenProvider) ID() string { return "APITokenProvider" }
59
60// HandleFinalize is the finalize stack middleware, that if the token provider is
61// enabled, will attempt to add the cached API token to the request. If the API
62// token is not cached, it will be retrieved in a separate API call, getToken.
63//
64// For retry attempts, handler must be added after attempt retryer.
65//
66// If request for getToken fails the token provider may be disabled from future
67// requests, depending on the response status code.
68func (t *tokenProvider) HandleFinalize(
69 ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler,
70) (
71 out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
72) {
73 if t.fallbackEnabled() && !t.enabled() {
74 // short-circuits to insecure data flow if token provider is disabled.
75 return next.HandleFinalize(ctx, input)
76 }
77
78 req, ok := input.Request.(*smithyhttp.Request)
79 if !ok {
80 return out, metadata, fmt.Errorf("unexpected transport request type %T", input.Request)
81 }
82
83 tok, err := t.getToken(ctx)
84 if err != nil {
85 // If the error allows the token to downgrade to insecure flow allow that.
86 var bypassErr *bypassTokenRetrievalError
87 if errors.As(err, &bypassErr) {
88 return next.HandleFinalize(ctx, input)
89 }
90
91 return out, metadata, fmt.Errorf("failed to get API token, %w", err)
92 }
93
94 req.Header.Set(tokenHeader, tok.token)
95
96 return next.HandleFinalize(ctx, input)
97}
98
99// HandleDeserialize is the deserialize stack middleware for determining if the
100// operation the token provider is decorating failed because of a 401
101// unauthorized status code. If the operation failed for that reason the token
102// provider needs to be re-enabled so that it can start adding the API token to
103// operation calls.
104func (t *tokenProvider) HandleDeserialize(
105 ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler,
106) (
107 out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
108) {
109 out, metadata, err = next.HandleDeserialize(ctx, input)
110 if err == nil {
111 return out, metadata, err
112 }
113
114 resp, ok := out.RawResponse.(*smithyhttp.Response)
115 if !ok {
116 return out, metadata, fmt.Errorf("expect HTTP transport, got %T", out.RawResponse)
117 }
118
119 if resp.StatusCode == http.StatusUnauthorized { // unauthorized
120 t.enable()
121 err = &retryableError{Err: err, isRetryable: true}
122 }
123
124 return out, metadata, err
125}
126
127func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) {
128 if t.fallbackEnabled() && !t.enabled() {
129 return nil, &bypassTokenRetrievalError{
130 Err: fmt.Errorf("cannot get API token, provider disabled"),
131 }
132 }
133
134 t.tokenMux.RLock()
135 tok = t.token
136 t.tokenMux.RUnlock()
137
138 if tok != nil && !tok.Expired() {
139 return tok, nil
140 }
141
142 tok, err = t.updateToken(ctx)
143 if err != nil {
144 return nil, err
145 }
146
147 return tok, nil
148}
149
150func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
151 t.tokenMux.Lock()
152 defer t.tokenMux.Unlock()
153
154 // Prevent multiple requests to update retrieving the token.
155 if t.token != nil && !t.token.Expired() {
156 tok := t.token
157 return tok, nil
158 }
159
160 result, err := t.client.getToken(ctx, &getTokenInput{
161 TokenTTL: t.tokenTTL,
162 })
163 if err != nil {
164 var statusErr interface{ HTTPStatusCode() int }
165 if errors.As(err, &statusErr) {
166 switch statusErr.HTTPStatusCode() {
167 // Disable future get token if failed because of 403, 404, or 405
168 case http.StatusForbidden,
169 http.StatusNotFound,
170 http.StatusMethodNotAllowed:
171
172 if t.fallbackEnabled() {
173 logger := middleware.GetLogger(ctx)
174 logger.Logf(logging.Warn, "falling back to IMDSv1: %v", err)
175 t.disable()
176 }
177
178 // 400 errors are terminal, and need to be upstreamed
179 case http.StatusBadRequest:
180 return nil, err
181 }
182 }
183
184 // Disable if request send failed or timed out getting response
185 var re *smithyhttp.RequestSendError
186 var ce *smithy.CanceledError
187 if errors.As(err, &re) || errors.As(err, &ce) {
188 atomic.StoreUint32(&t.disabled, 1)
189 }
190
191 if !t.fallbackEnabled() {
192 // NOTE: getToken() is an implementation detail of some outer operation
193 // (e.g. GetMetadata). It has its own retries that have already been exhausted.
194 // Mark the underlying error as a terminal error.
195 err = &retryableError{Err: err, isRetryable: false}
196 return nil, err
197 }
198
199 // Token couldn't be retrieved, fallback to IMDSv1 insecure flow for this request
200 // and allow the request to proceed. Future requests _may_ re-attempt fetching a
201 // token if not disabled.
202 return nil, &bypassTokenRetrievalError{Err: err}
203 }
204
205 tok := &apiToken{
206 token: result.Token,
207 expires: timeNow().Add(result.TokenTTL),
208 }
209 t.token = tok
210
211 return tok, nil
212}
213
214// enabled returns if the token provider is current enabled or not.
215func (t *tokenProvider) enabled() bool {
216 return atomic.LoadUint32(&t.disabled) == 0
217}
218
219// fallbackEnabled returns false if EnableFallback is [aws.FalseTernary], true otherwise
220func (t *tokenProvider) fallbackEnabled() bool {
221 switch t.client.options.EnableFallback {
222 case aws.FalseTernary:
223 return false
224 default:
225 return true
226 }
227}
228
229// disable disables the token provider and it will no longer attempt to inject
230// the token, nor request updates.
231func (t *tokenProvider) disable() {
232 atomic.StoreUint32(&t.disabled, 1)
233}
234
235// enable enables the token provide to start refreshing tokens, and adding them
236// to the pending request.
237func (t *tokenProvider) enable() {
238 t.tokenMux.Lock()
239 t.token = nil
240 t.tokenMux.Unlock()
241 atomic.StoreUint32(&t.disabled, 0)
242}
243
// bypassTokenRetrievalError signals that token retrieval failed in a way that
// allows the request to continue without a token (the insecure flow).
type bypassTokenRetrievalError struct {
	// Err is the underlying retrieval failure.
	Err error
}

// Error describes the bypass, including the underlying failure.
func (e *bypassTokenRetrievalError) Error() string {
	const prefix = "bypass token retrieval, %v"
	return fmt.Sprintf(prefix, e.Err)
}

// Unwrap exposes the underlying failure to errors.Is and errors.As.
func (e *bypassTokenRetrievalError) Unwrap() error {
	return e.Err
}
253
// retryableError wraps an error together with an explicit retry decision.
// The retryer can query that decision through the RetryableError method.
type retryableError struct {
	// Err is the wrapped error whose message is reported verbatim.
	Err error
	// isRetryable is the value reported by RetryableError.
	isRetryable bool
}

// RetryableError reports whether the wrapped error should be retried.
func (e *retryableError) RetryableError() bool { return e.isRetryable }

// Error returns the wrapped error's message unchanged.
func (e *retryableError) Error() string { return e.Err.Error() }

// Unwrap exposes the wrapped error so errors.Is and errors.As can inspect the
// underlying cause. Examples are an HTTP status error or an operation error.
// Without it, wrapping in retryableError hides the cause from callers.
func (e *retryableError) Unwrap() error { return e.Err }