middleware_content_length.go

 1package http
 2
 3import (
 4	"context"
 5	"fmt"
 6
 7	"github.com/aws/smithy-go/middleware"
 8)
 9
// ComputeContentLength is a Build step middleware that records the length of
// the serialized request stream as the request's content length when one has
// not already been set.
type ComputeContentLength struct{}
14
15// AddComputeContentLengthMiddleware adds ComputeContentLength to the middleware
16// stack's Build step.
17func AddComputeContentLengthMiddleware(stack *middleware.Stack) error {
18	return stack.Build.Add(&ComputeContentLength{}, middleware.After)
19}
20
21// ID returns the identifier for the ComputeContentLength.
22func (m *ComputeContentLength) ID() string { return "ComputeContentLength" }
23
24// HandleBuild adds the length of the serialized request to the HTTP header
25// if the length can be determined.
26func (m *ComputeContentLength) HandleBuild(
27	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
28) (
29	out middleware.BuildOutput, metadata middleware.Metadata, err error,
30) {
31	req, ok := in.Request.(*Request)
32	if !ok {
33		return out, metadata, fmt.Errorf("unknown request type %T", req)
34	}
35
36	// do nothing if request content-length was set to 0 or above.
37	if req.ContentLength >= 0 {
38		return next.HandleBuild(ctx, in)
39	}
40
41	// attempt to compute stream length
42	if n, ok, err := req.StreamLength(); err != nil {
43		return out, metadata, fmt.Errorf(
44			"failed getting length of request stream, %w", err)
45	} else if ok {
46		req.ContentLength = n
47	}
48
49	return next.HandleBuild(ctx, in)
50}
51
// validateContentLength provides a middleware to validate that the
// content-length of the serialized request payload is valid (zero or
// greater).
type validateContentLength struct{}
55
56// ValidateContentLengthHeader adds middleware that validates request content-length
57// is set to value greater than zero.
58func ValidateContentLengthHeader(stack *middleware.Stack) error {
59	return stack.Build.Add(&validateContentLength{}, middleware.After)
60}
61
62// ID returns the identifier for the ComputeContentLength.
63func (m *validateContentLength) ID() string { return "ValidateContentLength" }
64
65// HandleBuild adds the length of the serialized request to the HTTP header
66// if the length can be determined.
67func (m *validateContentLength) HandleBuild(
68	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
69) (
70	out middleware.BuildOutput, metadata middleware.Metadata, err error,
71) {
72	req, ok := in.Request.(*Request)
73	if !ok {
74		return out, metadata, fmt.Errorf("unknown request type %T", req)
75	}
76
77	// if request content-length was set to less than 0, return an error
78	if req.ContentLength < 0 {
79		return out, metadata, fmt.Errorf(
80			"content length for payload is required and must be at least 0")
81	}
82
83	return next.HandleBuild(ctx, in)
84}