1package imds
2
3import (
4 "bytes"
5 "context"
6 "fmt"
7 "io/ioutil"
8 "net/url"
9 "path"
10 "time"
11
12 awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
13 "github.com/aws/aws-sdk-go-v2/aws/retry"
14 "github.com/aws/smithy-go/middleware"
15 smithyhttp "github.com/aws/smithy-go/transport/http"
16)
17
18func addAPIRequestMiddleware(stack *middleware.Stack,
19 options Options,
20 operation string,
21 getPath func(interface{}) (string, error),
22 getOutput func(*smithyhttp.Response) (interface{}, error),
23) (err error) {
24 err = addRequestMiddleware(stack, options, "GET", operation, getPath, getOutput)
25 if err != nil {
26 return err
27 }
28
29 // Token Serializer build and state management.
30 if !options.disableAPIToken {
31 err = stack.Finalize.Insert(options.tokenProvider, (*retry.Attempt)(nil).ID(), middleware.After)
32 if err != nil {
33 return err
34 }
35
36 err = stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before)
37 if err != nil {
38 return err
39 }
40 }
41
42 return nil
43}
44
45func addRequestMiddleware(stack *middleware.Stack,
46 options Options,
47 method string,
48 operation string,
49 getPath func(interface{}) (string, error),
50 getOutput func(*smithyhttp.Response) (interface{}, error),
51) (err error) {
52 err = awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds")(stack)
53 if err != nil {
54 return err
55 }
56
57 // Operation timeout
58 err = stack.Initialize.Add(&operationTimeout{
59 Disabled: options.DisableDefaultTimeout,
60 DefaultTimeout: defaultOperationTimeout,
61 }, middleware.Before)
62 if err != nil {
63 return err
64 }
65
66 // Operation Serializer
67 err = stack.Serialize.Add(&serializeRequest{
68 GetPath: getPath,
69 Method: method,
70 }, middleware.After)
71 if err != nil {
72 return err
73 }
74
75 // Operation endpoint resolver
76 err = stack.Serialize.Insert(&resolveEndpoint{
77 Endpoint: options.Endpoint,
78 EndpointMode: options.EndpointMode,
79 }, "OperationSerializer", middleware.Before)
80 if err != nil {
81 return err
82 }
83
84 // Operation Deserializer
85 err = stack.Deserialize.Add(&deserializeResponse{
86 GetOutput: getOutput,
87 }, middleware.After)
88 if err != nil {
89 return err
90 }
91
92 err = stack.Deserialize.Add(&smithyhttp.RequestResponseLogger{
93 LogRequest: options.ClientLogMode.IsRequest(),
94 LogRequestWithBody: options.ClientLogMode.IsRequestWithBody(),
95 LogResponse: options.ClientLogMode.IsResponse(),
96 LogResponseWithBody: options.ClientLogMode.IsResponseWithBody(),
97 }, middleware.After)
98 if err != nil {
99 return err
100 }
101
102 err = addSetLoggerMiddleware(stack, options)
103 if err != nil {
104 return err
105 }
106
107 if err := addProtocolFinalizerMiddlewares(stack, options, operation); err != nil {
108 return fmt.Errorf("add protocol finalizers: %w", err)
109 }
110
111 // Retry support
112 return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{
113 Retryer: options.Retryer,
114 LogRetryAttempts: options.ClientLogMode.IsRetries(),
115 })
116}
117
118func addSetLoggerMiddleware(stack *middleware.Stack, o Options) error {
119 return middleware.AddSetLoggerMiddleware(stack, o.Logger)
120}
121
122type serializeRequest struct {
123 GetPath func(interface{}) (string, error)
124 Method string
125}
126
127func (*serializeRequest) ID() string {
128 return "OperationSerializer"
129}
130
131func (m *serializeRequest) HandleSerialize(
132 ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
133) (
134 out middleware.SerializeOutput, metadata middleware.Metadata, err error,
135) {
136 request, ok := in.Request.(*smithyhttp.Request)
137 if !ok {
138 return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
139 }
140
141 reqPath, err := m.GetPath(in.Parameters)
142 if err != nil {
143 return out, metadata, fmt.Errorf("unable to get request URL path, %w", err)
144 }
145
146 request.Request.URL.Path = reqPath
147 request.Request.Method = m.Method
148
149 return next.HandleSerialize(ctx, in)
150}
151
152type deserializeResponse struct {
153 GetOutput func(*smithyhttp.Response) (interface{}, error)
154}
155
156func (*deserializeResponse) ID() string {
157 return "OperationDeserializer"
158}
159
160func (m *deserializeResponse) HandleDeserialize(
161 ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
162) (
163 out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
164) {
165 out, metadata, err = next.HandleDeserialize(ctx, in)
166 if err != nil {
167 return out, metadata, err
168 }
169
170 resp, ok := out.RawResponse.(*smithyhttp.Response)
171 if !ok {
172 return out, metadata, fmt.Errorf(
173 "unexpected transport response type, %T, want %T", out.RawResponse, resp)
174 }
175 defer resp.Body.Close()
176
177 // read the full body so that any operation timeouts cleanup will not race
178 // the body being read.
179 body, err := ioutil.ReadAll(resp.Body)
180 if err != nil {
181 return out, metadata, fmt.Errorf("read response body failed, %w", err)
182 }
183 resp.Body = ioutil.NopCloser(bytes.NewReader(body))
184
185 // Anything that's not 200 |< 300 is error
186 if resp.StatusCode < 200 || resp.StatusCode >= 300 {
187 return out, metadata, &smithyhttp.ResponseError{
188 Response: resp,
189 Err: fmt.Errorf("request to EC2 IMDS failed"),
190 }
191 }
192
193 result, err := m.GetOutput(resp)
194 if err != nil {
195 return out, metadata, fmt.Errorf(
196 "unable to get deserialized result for response, %w", err,
197 )
198 }
199 out.Result = result
200
201 return out, metadata, err
202}
203
204type resolveEndpoint struct {
205 Endpoint string
206 EndpointMode EndpointModeState
207}
208
209func (*resolveEndpoint) ID() string {
210 return "ResolveEndpoint"
211}
212
213func (m *resolveEndpoint) HandleSerialize(
214 ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
215) (
216 out middleware.SerializeOutput, metadata middleware.Metadata, err error,
217) {
218
219 req, ok := in.Request.(*smithyhttp.Request)
220 if !ok {
221 return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
222 }
223
224 var endpoint string
225 if len(m.Endpoint) > 0 {
226 endpoint = m.Endpoint
227 } else {
228 switch m.EndpointMode {
229 case EndpointModeStateIPv6:
230 endpoint = defaultIPv6Endpoint
231 case EndpointModeStateIPv4:
232 fallthrough
233 case EndpointModeStateUnset:
234 endpoint = defaultIPv4Endpoint
235 default:
236 return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode")
237 }
238 }
239
240 req.URL, err = url.Parse(endpoint)
241 if err != nil {
242 return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", err)
243 }
244
245 return next.HandleSerialize(ctx, in)
246}
247
// defaultOperationTimeout is the DefaultTimeout given to the operationTimeout
// middleware, applied when the caller's context carries no deadline.
const defaultOperationTimeout = 5 * time.Second
251
252// operationTimeout adds a timeout on the middleware stack if the Context the
253// stack was called with does not have a deadline. The next middleware must
254// complete before the timeout, or the context will be canceled.
255//
256// If DefaultTimeout is zero, no default timeout will be used if the Context
257// does not have a timeout.
258//
259// The next middleware must also ensure that any resources that are also
260// canceled by the stack's context are completely consumed before returning.
261// Otherwise the timeout cleanup will race the resource being consumed
262// upstream.
263type operationTimeout struct {
264 Disabled bool
265 DefaultTimeout time.Duration
266}
267
268func (*operationTimeout) ID() string { return "OperationTimeout" }
269
270func (m *operationTimeout) HandleInitialize(
271 ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
272) (
273 output middleware.InitializeOutput, metadata middleware.Metadata, err error,
274) {
275 if m.Disabled {
276 return next.HandleInitialize(ctx, input)
277 }
278
279 if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 {
280 var cancelFn func()
281 ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout)
282 defer cancelFn()
283 }
284
285 return next.HandleInitialize(ctx, input)
286}
287
// appendURIPath joins a URI path component to the existing path with `/`
// separators between the path components. If the path being added ends with a
// trailing `/` that slash will be maintained. The slash is not doubled when
// the joined path already ends with one, which happens when it is the root
// "/".
func appendURIPath(base, add string) string {
	reqPath := path.Join(base, add)

	// path.Join cleans away trailing slashes except on the root path, so only
	// re-append the slash when the joined result does not already have it.
	wantSlash := len(add) != 0 && add[len(add)-1] == '/'
	hasSlash := len(reqPath) != 0 && reqPath[len(reqPath)-1] == '/'
	if wantSlash && !hasSlash {
		reqPath += "/"
	}
	return reqPath
}
298
299func addProtocolFinalizerMiddlewares(stack *middleware.Stack, options Options, operation string) error {
300 if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, middleware.Before); err != nil {
301 return fmt.Errorf("add ResolveAuthScheme: %w", err)
302 }
303 if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", middleware.After); err != nil {
304 return fmt.Errorf("add GetIdentity: %w", err)
305 }
306 if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", middleware.After); err != nil {
307 return fmt.Errorf("add ResolveEndpointV2: %w", err)
308 }
309 if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", middleware.After); err != nil {
310 return fmt.Errorf("add Signing: %w", err)
311 }
312 return nil
313}