middleware.go

  1package client
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"mime"
	"net/url"

	"github.com/aws/smithy-go"
	smithymiddleware "github.com/aws/smithy-go/middleware"
	smithyhttp "github.com/aws/smithy-go/transport/http"
)
 14
 15type buildEndpoint struct {
 16	Endpoint string
 17}
 18
 19func (b *buildEndpoint) ID() string {
 20	return "BuildEndpoint"
 21}
 22
 23func (b *buildEndpoint) HandleBuild(ctx context.Context, in smithymiddleware.BuildInput, next smithymiddleware.BuildHandler) (
 24	out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error,
 25) {
 26	request, ok := in.Request.(*smithyhttp.Request)
 27	if !ok {
 28		return out, metadata, fmt.Errorf("unknown transport, %T", in.Request)
 29	}
 30
 31	if len(b.Endpoint) == 0 {
 32		return out, metadata, fmt.Errorf("endpoint not provided")
 33	}
 34
 35	parsed, err := url.Parse(b.Endpoint)
 36	if err != nil {
 37		return out, metadata, fmt.Errorf("failed to parse endpoint, %w", err)
 38	}
 39
 40	request.URL = parsed
 41
 42	return next.HandleBuild(ctx, in)
 43}
 44
 45type serializeOpGetCredential struct{}
 46
 47func (s *serializeOpGetCredential) ID() string {
 48	return "OperationSerializer"
 49}
 50
 51func (s *serializeOpGetCredential) HandleSerialize(ctx context.Context, in smithymiddleware.SerializeInput, next smithymiddleware.SerializeHandler) (
 52	out smithymiddleware.SerializeOutput, metadata smithymiddleware.Metadata, err error,
 53) {
 54	request, ok := in.Request.(*smithyhttp.Request)
 55	if !ok {
 56		return out, metadata, fmt.Errorf("unknown transport type, %T", in.Request)
 57	}
 58
 59	params, ok := in.Parameters.(*GetCredentialsInput)
 60	if !ok {
 61		return out, metadata, fmt.Errorf("unknown input parameters, %T", in.Parameters)
 62	}
 63
 64	const acceptHeader = "Accept"
 65	request.Header[acceptHeader] = append(request.Header[acceptHeader][:0], "application/json")
 66
 67	if len(params.AuthorizationToken) > 0 {
 68		const authHeader = "Authorization"
 69		request.Header[authHeader] = append(request.Header[authHeader][:0], params.AuthorizationToken)
 70	}
 71
 72	return next.HandleSerialize(ctx, in)
 73}
 74
 75type deserializeOpGetCredential struct{}
 76
 77func (d *deserializeOpGetCredential) ID() string {
 78	return "OperationDeserializer"
 79}
 80
 81func (d *deserializeOpGetCredential) HandleDeserialize(ctx context.Context, in smithymiddleware.DeserializeInput, next smithymiddleware.DeserializeHandler) (
 82	out smithymiddleware.DeserializeOutput, metadata smithymiddleware.Metadata, err error,
 83) {
 84	out, metadata, err = next.HandleDeserialize(ctx, in)
 85	if err != nil {
 86		return out, metadata, err
 87	}
 88
 89	response, ok := out.RawResponse.(*smithyhttp.Response)
 90	if !ok {
 91		return out, metadata, &smithy.DeserializationError{Err: fmt.Errorf("unknown transport type %T", out.RawResponse)}
 92	}
 93
 94	if response.StatusCode < 200 || response.StatusCode >= 300 {
 95		return out, metadata, deserializeError(response)
 96	}
 97
 98	var shape *GetCredentialsOutput
 99	if err = json.NewDecoder(response.Body).Decode(&shape); err != nil {
100		return out, metadata, &smithy.DeserializationError{Err: fmt.Errorf("failed to deserialize json response, %w", err)}
101	}
102
103	out.Result = shape
104	return out, metadata, err
105}
106
107func deserializeError(response *smithyhttp.Response) error {
108	// we could be talking to anything, json isn't guaranteed
109	// see https://github.com/aws/aws-sdk-go-v2/issues/2316
110	if response.Header.Get("Content-Type") == "application/json" {
111		return deserializeJSONError(response)
112	}
113
114	msg, err := io.ReadAll(response.Body)
115	if err != nil {
116		return &smithy.DeserializationError{
117			Err: fmt.Errorf("read response, %w", err),
118		}
119	}
120
121	return &EndpointError{
122		// no sensible value for Code
123		Message:    string(msg),
124		Fault:      stof(response.StatusCode),
125		statusCode: response.StatusCode,
126	}
127}
128
129func deserializeJSONError(response *smithyhttp.Response) error {
130	var errShape *EndpointError
131	if err := json.NewDecoder(response.Body).Decode(&errShape); err != nil {
132		return &smithy.DeserializationError{
133			Err: fmt.Errorf("failed to decode error message, %w", err),
134		}
135	}
136
137	errShape.Fault = stof(response.StatusCode)
138	errShape.statusCode = response.StatusCode
139	return errShape
140}
141
142// maps HTTP status code to smithy ErrorFault
143func stof(code int) smithy.ErrorFault {
144	if code >= 500 {
145		return smithy.FaultServer
146	}
147	return smithy.FaultClient
148}
149
150func addProtocolFinalizerMiddlewares(stack *smithymiddleware.Stack, options Options, operation string) error {
151	if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, smithymiddleware.Before); err != nil {
152		return fmt.Errorf("add ResolveAuthScheme: %w", err)
153	}
154	if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", smithymiddleware.After); err != nil {
155		return fmt.Errorf("add GetIdentity: %w", err)
156	}
157	if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", smithymiddleware.After); err != nil {
158		return fmt.Errorf("add ResolveEndpointV2: %w", err)
159	}
160	if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", smithymiddleware.After); err != nil {
161		return fmt.Errorf("add Signing: %w", err)
162	}
163	return nil
164}