1package imds
  2
  3import (
  4	"bytes"
  5	"context"
  6	"fmt"
  7	"io/ioutil"
  8	"net/url"
  9	"path"
 10	"time"
 11
 12	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
 13	"github.com/aws/aws-sdk-go-v2/aws/retry"
 14	"github.com/aws/smithy-go/middleware"
 15	smithyhttp "github.com/aws/smithy-go/transport/http"
 16)
 17
 18func addAPIRequestMiddleware(stack *middleware.Stack,
 19	options Options,
 20	operation string,
 21	getPath func(interface{}) (string, error),
 22	getOutput func(*smithyhttp.Response) (interface{}, error),
 23) (err error) {
 24	err = addRequestMiddleware(stack, options, "GET", operation, getPath, getOutput)
 25	if err != nil {
 26		return err
 27	}
 28
 29	// Token Serializer build and state management.
 30	if !options.disableAPIToken {
 31		err = stack.Finalize.Insert(options.tokenProvider, (*retry.Attempt)(nil).ID(), middleware.After)
 32		if err != nil {
 33			return err
 34		}
 35
 36		err = stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before)
 37		if err != nil {
 38			return err
 39		}
 40	}
 41
 42	return nil
 43}
 44
 45func addRequestMiddleware(stack *middleware.Stack,
 46	options Options,
 47	method string,
 48	operation string,
 49	getPath func(interface{}) (string, error),
 50	getOutput func(*smithyhttp.Response) (interface{}, error),
 51) (err error) {
 52	err = awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds")(stack)
 53	if err != nil {
 54		return err
 55	}
 56
 57	// Operation timeout
 58	err = stack.Initialize.Add(&operationTimeout{
 59		Disabled:       options.DisableDefaultTimeout,
 60		DefaultTimeout: defaultOperationTimeout,
 61	}, middleware.Before)
 62	if err != nil {
 63		return err
 64	}
 65
 66	// Operation Serializer
 67	err = stack.Serialize.Add(&serializeRequest{
 68		GetPath: getPath,
 69		Method:  method,
 70	}, middleware.After)
 71	if err != nil {
 72		return err
 73	}
 74
 75	// Operation endpoint resolver
 76	err = stack.Serialize.Insert(&resolveEndpoint{
 77		Endpoint:     options.Endpoint,
 78		EndpointMode: options.EndpointMode,
 79	}, "OperationSerializer", middleware.Before)
 80	if err != nil {
 81		return err
 82	}
 83
 84	// Operation Deserializer
 85	err = stack.Deserialize.Add(&deserializeResponse{
 86		GetOutput: getOutput,
 87	}, middleware.After)
 88	if err != nil {
 89		return err
 90	}
 91
 92	err = stack.Deserialize.Add(&smithyhttp.RequestResponseLogger{
 93		LogRequest:          options.ClientLogMode.IsRequest(),
 94		LogRequestWithBody:  options.ClientLogMode.IsRequestWithBody(),
 95		LogResponse:         options.ClientLogMode.IsResponse(),
 96		LogResponseWithBody: options.ClientLogMode.IsResponseWithBody(),
 97	}, middleware.After)
 98	if err != nil {
 99		return err
100	}
101
102	err = addSetLoggerMiddleware(stack, options)
103	if err != nil {
104		return err
105	}
106
107	if err := addProtocolFinalizerMiddlewares(stack, options, operation); err != nil {
108		return fmt.Errorf("add protocol finalizers: %w", err)
109	}
110
111	// Retry support
112	return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{
113		Retryer:          options.Retryer,
114		LogRetryAttempts: options.ClientLogMode.IsRetries(),
115	})
116}
117
118func addSetLoggerMiddleware(stack *middleware.Stack, o Options) error {
119	return middleware.AddSetLoggerMiddleware(stack, o.Logger)
120}
121
// serializeRequest is the Serialize step middleware that applies the URL path
// and HTTP method to the outgoing request.
type serializeRequest struct {
	// GetPath derives the request URL path from the operation parameters.
	GetPath func(interface{}) (string, error)

	// Method is the HTTP method assigned to the request.
	Method string
}
126
127func (*serializeRequest) ID() string {
128	return "OperationSerializer"
129}
130
131func (m *serializeRequest) HandleSerialize(
132	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
133) (
134	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
135) {
136	request, ok := in.Request.(*smithyhttp.Request)
137	if !ok {
138		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
139	}
140
141	reqPath, err := m.GetPath(in.Parameters)
142	if err != nil {
143		return out, metadata, fmt.Errorf("unable to get request URL path, %w", err)
144	}
145
146	request.Request.URL.Path = reqPath
147	request.Request.Method = m.Method
148
149	return next.HandleSerialize(ctx, in)
150}
151
152type deserializeResponse struct {
153	GetOutput func(*smithyhttp.Response) (interface{}, error)
154}
155
156func (*deserializeResponse) ID() string {
157	return "OperationDeserializer"
158}
159
160func (m *deserializeResponse) HandleDeserialize(
161	ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
162) (
163	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
164) {
165	out, metadata, err = next.HandleDeserialize(ctx, in)
166	if err != nil {
167		return out, metadata, err
168	}
169
170	resp, ok := out.RawResponse.(*smithyhttp.Response)
171	if !ok {
172		return out, metadata, fmt.Errorf(
173			"unexpected transport response type, %T, want %T", out.RawResponse, resp)
174	}
175	defer resp.Body.Close()
176
177	// read the full body so that any operation timeouts cleanup will not race
178	// the body being read.
179	body, err := ioutil.ReadAll(resp.Body)
180	if err != nil {
181		return out, metadata, fmt.Errorf("read response body failed, %w", err)
182	}
183	resp.Body = ioutil.NopCloser(bytes.NewReader(body))
184
185	// Anything that's not 200 |< 300 is error
186	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
187		return out, metadata, &smithyhttp.ResponseError{
188			Response: resp,
189			Err:      fmt.Errorf("request to EC2 IMDS failed"),
190		}
191	}
192
193	result, err := m.GetOutput(resp)
194	if err != nil {
195		return out, metadata, fmt.Errorf(
196			"unable to get deserialized result for response, %w", err,
197		)
198	}
199	out.Result = result
200
201	return out, metadata, err
202}
203
204type resolveEndpoint struct {
205	Endpoint     string
206	EndpointMode EndpointModeState
207}
208
209func (*resolveEndpoint) ID() string {
210	return "ResolveEndpoint"
211}
212
213func (m *resolveEndpoint) HandleSerialize(
214	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
215) (
216	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
217) {
218
219	req, ok := in.Request.(*smithyhttp.Request)
220	if !ok {
221		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
222	}
223
224	var endpoint string
225	if len(m.Endpoint) > 0 {
226		endpoint = m.Endpoint
227	} else {
228		switch m.EndpointMode {
229		case EndpointModeStateIPv6:
230			endpoint = defaultIPv6Endpoint
231		case EndpointModeStateIPv4:
232			fallthrough
233		case EndpointModeStateUnset:
234			endpoint = defaultIPv4Endpoint
235		default:
236			return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode")
237		}
238	}
239
240	req.URL, err = url.Parse(endpoint)
241	if err != nil {
242		return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", err)
243	}
244
245	return next.HandleSerialize(ctx, in)
246}
247
// defaultOperationTimeout is the timeout given to the operationTimeout
// middleware when the request middleware is registered.
const defaultOperationTimeout = 5 * time.Second
251
// operationTimeout is an Initialize step middleware that applies a default
// timeout to the stack's Context when that Context carries no deadline of its
// own. Downstream middleware has to finish within the timeout; once the
// middleware returns, the derived context is canceled.
//
// No timeout is applied when Disabled is set or when DefaultTimeout is zero.
//
// Downstream middleware must fully consume any resource that is canceled
// together with the stack's context before returning. Otherwise the timeout
// cleanup races with the resource being consumed upstream.
type operationTimeout struct {
	// Disabled turns the middleware into a pass-through.
	Disabled bool

	// DefaultTimeout is the timeout applied to a Context without a deadline.
	DefaultTimeout time.Duration
}
267
268func (*operationTimeout) ID() string { return "OperationTimeout" }
269
270func (m *operationTimeout) HandleInitialize(
271	ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
272) (
273	output middleware.InitializeOutput, metadata middleware.Metadata, err error,
274) {
275	if m.Disabled {
276		return next.HandleInitialize(ctx, input)
277	}
278
279	if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 {
280		var cancelFn func()
281		ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout)
282		defer cancelFn()
283	}
284
285	return next.HandleInitialize(ctx, input)
286}
287
// appendURIPath joins a URI path component to the existing path with `/`
// separators between the path components. The result is cleaned by path.Join,
// so duplicate separators and `.`/`..` elements are collapsed.
//
// If the path being added ends with a trailing `/`, that slash is maintained
// in the result. When the joined path already ends with `/` (it collapsed to
// the root path), no additional slash is appended, so the result is `/`
// rather than `//`.
//
// If both base and add are empty, the empty string is returned.
func appendURIPath(base, add string) string {
	reqPath := path.Join(base, add)

	// Nothing to restore unless the added component had a trailing slash.
	if len(add) == 0 || add[len(add)-1] != '/' {
		return reqPath
	}

	// path.Join strips trailing slashes except for the root path; avoid
	// doubling the slash in that case.
	if len(reqPath) != 0 && reqPath[len(reqPath)-1] == '/' {
		return reqPath
	}

	return reqPath + "/"
}
298
299func addProtocolFinalizerMiddlewares(stack *middleware.Stack, options Options, operation string) error {
300	if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, middleware.Before); err != nil {
301		return fmt.Errorf("add ResolveAuthScheme: %w", err)
302	}
303	if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", middleware.After); err != nil {
304		return fmt.Errorf("add GetIdentity: %w", err)
305	}
306	if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", middleware.After); err != nil {
307		return fmt.Errorf("add ResolveEndpointV2: %w", err)
308	}
309	if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", middleware.After); err != nil {
310		return fmt.Errorf("add Signing: %w", err)
311	}
312	return nil
313}