1package http
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/aws/smithy-go/middleware"
8)
9
// isContentTypeAutoSet is the private context key under which the
// "content-type was defaulted" flag is stored. Being an unexported struct
// type, it cannot collide with keys defined by other packages.
type isContentTypeAutoSet struct{}

// SetIsContentTypeDefaultValue returns a derived Context recording whether the
// request's content-type header was filled in with a default value.
func SetIsContentTypeDefaultValue(ctx context.Context, isDefault bool) context.Context {
	key := isContentTypeAutoSet{}
	return context.WithValue(ctx, key, isDefault)
}

// GetIsContentTypeDefaultValue reports whether the content-type HTTP header on
// the request holds a default value auto-assigned by an operation serializer.
// Middleware running after serialization uses this to tell a defaulted
// content-type apart from one chosen explicitly.
//
// It returns false when the flag was never recorded on ctx (or holds a
// non-bool value).
func GetIsContentTypeDefaultValue(ctx context.Context) bool {
	if isDefault, found := ctx.Value(isContentTypeAutoSet{}).(bool); found {
		return isDefault
	}
	return false
}
29
30// AddNoPayloadDefaultContentTypeRemover Adds the DefaultContentTypeRemover
31// middleware to the stack after the operation serializer. This middleware will
32// remove the content-type header from the request if it was set as a default
33// value, and no request payload is present.
34//
35// Returns error if unable to add the middleware.
36func AddNoPayloadDefaultContentTypeRemover(stack *middleware.Stack) (err error) {
37 err = stack.Serialize.Insert(removeDefaultContentType{},
38 "OperationSerializer", middleware.After)
39 if err != nil {
40 return fmt.Errorf("failed to add %s serialize middleware, %w",
41 removeDefaultContentType{}.ID(), err)
42 }
43
44 return nil
45}
46
47// RemoveNoPayloadDefaultContentTypeRemover removes the
48// DefaultContentTypeRemover middleware from the stack. Returns an error if
49// unable to remove the middleware.
50func RemoveNoPayloadDefaultContentTypeRemover(stack *middleware.Stack) (err error) {
51 _, err = stack.Serialize.Remove(removeDefaultContentType{}.ID())
52 if err != nil {
53 return fmt.Errorf("failed to remove %s serialize middleware, %w",
54 removeDefaultContentType{}.ID(), err)
55
56 }
57 return nil
58}
59
60// removeDefaultContentType provides after serialization middleware that will
61// remove the content-type header from an HTTP request if the header was set as
62// a default value by the operation serializer, and there is no request payload.
63type removeDefaultContentType struct{}
64
65// ID returns the middleware ID
66func (removeDefaultContentType) ID() string { return "RemoveDefaultContentType" }
67
68// HandleSerialize implements the serialization middleware.
69func (removeDefaultContentType) HandleSerialize(
70 ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler,
71) (
72 out middleware.SerializeOutput, meta middleware.Metadata, err error,
73) {
74 req, ok := input.Request.(*Request)
75 if !ok {
76 return out, meta, fmt.Errorf(
77 "unexpected request type %T for removeDefaultContentType middleware",
78 input.Request)
79 }
80
81 if GetIsContentTypeDefaultValue(ctx) && req.GetStream() == nil {
82 req.Header.Del("Content-Type")
83 input.Request = req
84 }
85
86 return next.HandleSerialize(ctx, input)
87}
88
// headerValue is one queued header mutation that headerValueHelper applies to
// a request.
type headerValue struct {
	header, value string // header name and the value to write
	append        bool   // true: Header.Add (keep existing values); false: Header.Set (replace)
}
94
95type headerValueHelper struct {
96 headerValues []headerValue
97}
98
99func (h *headerValueHelper) addHeaderValue(value headerValue) {
100 h.headerValues = append(h.headerValues, value)
101}
102
103func (h *headerValueHelper) ID() string {
104 return "HTTPHeaderHelper"
105}
106
107func (h *headerValueHelper) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (out middleware.BuildOutput, metadata middleware.Metadata, err error) {
108 req, ok := in.Request.(*Request)
109 if !ok {
110 return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
111 }
112
113 for _, value := range h.headerValues {
114 if value.append {
115 req.Header.Add(value.header, value.value)
116 } else {
117 req.Header.Set(value.header, value.value)
118 }
119 }
120
121 return next.HandleBuild(ctx, in)
122}
123
124func getOrAddHeaderValueHelper(stack *middleware.Stack) (*headerValueHelper, error) {
125 id := (*headerValueHelper)(nil).ID()
126 m, ok := stack.Build.Get(id)
127 if !ok {
128 m = &headerValueHelper{}
129 err := stack.Build.Add(m, middleware.After)
130 if err != nil {
131 return nil, err
132 }
133 }
134
135 requestUserAgent, ok := m.(*headerValueHelper)
136 if !ok {
137 return nil, fmt.Errorf("%T for %s middleware did not match expected type", m, id)
138 }
139
140 return requestUserAgent, nil
141}
142
143// AddHeaderValue returns a stack mutator that adds the header value pair to header.
144// Appends to any existing values if present.
145func AddHeaderValue(header string, value string) func(stack *middleware.Stack) error {
146 return func(stack *middleware.Stack) error {
147 helper, err := getOrAddHeaderValueHelper(stack)
148 if err != nil {
149 return err
150 }
151 helper.addHeaderValue(headerValue{header: header, value: value, append: true})
152 return nil
153 }
154}
155
156// SetHeaderValue returns a stack mutator that adds the header value pair to header.
157// Replaces any existing values if present.
158func SetHeaderValue(header string, value string) func(stack *middleware.Stack) error {
159 return func(stack *middleware.Stack) error {
160 helper, err := getOrAddHeaderValueHelper(stack)
161 if err != nil {
162 return err
163 }
164 helper.addHeaderValue(headerValue{header: header, value: value, append: false})
165 return nil
166 }
167}