1package http
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/aws/smithy-go/middleware"
8)
9
// ComputeContentLength is a Build-step middleware that fills in the request's
// ContentLength from the length of the serialized request body when that
// length has not already been set and can be determined.
type ComputeContentLength struct{}
14
15// AddComputeContentLengthMiddleware adds ComputeContentLength to the middleware
16// stack's Build step.
17func AddComputeContentLengthMiddleware(stack *middleware.Stack) error {
18 return stack.Build.Add(&ComputeContentLength{}, middleware.After)
19}
20
21// ID returns the identifier for the ComputeContentLength.
22func (m *ComputeContentLength) ID() string { return "ComputeContentLength" }
23
24// HandleBuild adds the length of the serialized request to the HTTP header
25// if the length can be determined.
26func (m *ComputeContentLength) HandleBuild(
27 ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
28) (
29 out middleware.BuildOutput, metadata middleware.Metadata, err error,
30) {
31 req, ok := in.Request.(*Request)
32 if !ok {
33 return out, metadata, fmt.Errorf("unknown request type %T", req)
34 }
35
36 // do nothing if request content-length was set to 0 or above.
37 if req.ContentLength >= 0 {
38 return next.HandleBuild(ctx, in)
39 }
40
41 // attempt to compute stream length
42 if n, ok, err := req.StreamLength(); err != nil {
43 return out, metadata, fmt.Errorf(
44 "failed getting length of request stream, %w", err)
45 } else if ok {
46 req.ContentLength = n
47 }
48
49 return next.HandleBuild(ctx, in)
50}
51
// validateContentLength is a Build-step middleware that checks the serialized
// request payload has a valid content length, i.e. a ContentLength of zero or
// greater. Requests with a negative ContentLength are rejected.
type validateContentLength struct{}
55
56// ValidateContentLengthHeader adds middleware that validates request content-length
57// is set to value greater than zero.
58func ValidateContentLengthHeader(stack *middleware.Stack) error {
59 return stack.Build.Add(&validateContentLength{}, middleware.After)
60}
61
62// ID returns the identifier for the ComputeContentLength.
63func (m *validateContentLength) ID() string { return "ValidateContentLength" }
64
65// HandleBuild adds the length of the serialized request to the HTTP header
66// if the length can be determined.
67func (m *validateContentLength) HandleBuild(
68 ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
69) (
70 out middleware.BuildOutput, metadata middleware.Metadata, err error,
71) {
72 req, ok := in.Request.(*Request)
73 if !ok {
74 return out, metadata, fmt.Errorf("unknown request type %T", req)
75 }
76
77 // if request content-length was set to less than 0, return an error
78 if req.ContentLength < 0 {
79 return out, metadata, fmt.Errorf(
80 "content length for payload is required and must be at least 0")
81 }
82
83 return next.HandleBuild(ctx, in)
84}