1// Package requestcompression implements runtime support for smithy-modeled
2// request compression.
3//
4// This package is designated as private and is intended for use only by the
5// smithy client runtime. The exported API therein is not considered stable and
6// is subject to breaking changes without notice.
7package requestcompression
8
9import (
10 "bytes"
11 "context"
12 "fmt"
13 "github.com/aws/smithy-go/middleware"
14 "github.com/aws/smithy-go/transport/http"
15 "io"
16)
17
// MaxRequestMinCompressSizeBytes is the largest value (inclusive) accepted for
// the minimum request size that triggers compression: 10 MiB.
const MaxRequestMinCompressSizeBytes = 10 << 20

// Enumeration values for supported compress Algorithms. Each value is also the
// token written into the Content-Encoding header when it is applied.
const (
	// GZIP identifies gzip compression.
	GZIP = "gzip"
)
24
25type compressFunc func(io.Reader) ([]byte, error)
26
27var allowedAlgorithms = map[string]compressFunc{
28 GZIP: gzipCompress,
29}
30
31// AddRequestCompression add requestCompression middleware to op stack
32func AddRequestCompression(stack *middleware.Stack, disabled bool, minBytes int64, algorithms []string) error {
33 return stack.Serialize.Add(&requestCompression{
34 disableRequestCompression: disabled,
35 requestMinCompressSizeBytes: minBytes,
36 compressAlgorithms: algorithms,
37 }, middleware.After)
38}
39
// requestCompression is the serialize-step middleware that optionally
// compresses the outgoing request stream.
type requestCompression struct {
	// disableRequestCompression, when true, makes the middleware a pass-through.
	disableRequestCompression bool
	// requestMinCompressSizeBytes is the smallest known stream length, in bytes,
	// that will be compressed; it must lie within
	// [0, MaxRequestMinCompressSizeBytes].
	requestMinCompressSizeBytes int64
	// compressAlgorithms lists candidate algorithm names in preference order.
	compressAlgorithms []string
}

// ID returns the name that identifies this middleware in the stack.
func (requestCompression) ID() string {
	return "RequestCompression"
}
50
51// HandleSerialize gzip compress the request's stream/body if enabled by config fields
52func (m requestCompression) HandleSerialize(
53 ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
54) (
55 out middleware.SerializeOutput, metadata middleware.Metadata, err error,
56) {
57 if m.disableRequestCompression {
58 return next.HandleSerialize(ctx, in)
59 }
60 // still need to check requestMinCompressSizeBytes in case it is out of range after service client config
61 if m.requestMinCompressSizeBytes < 0 || m.requestMinCompressSizeBytes > MaxRequestMinCompressSizeBytes {
62 return out, metadata, fmt.Errorf("invalid range for min request compression size bytes %d, must be within 0 and 10485760 inclusively", m.requestMinCompressSizeBytes)
63 }
64
65 req, ok := in.Request.(*http.Request)
66 if !ok {
67 return out, metadata, fmt.Errorf("unknown request type %T", req)
68 }
69
70 for _, algorithm := range m.compressAlgorithms {
71 compressFunc := allowedAlgorithms[algorithm]
72 if compressFunc != nil {
73 if stream := req.GetStream(); stream != nil {
74 size, found, err := req.StreamLength()
75 if err != nil {
76 return out, metadata, fmt.Errorf("error while finding request stream length, %v", err)
77 } else if !found || size < m.requestMinCompressSizeBytes {
78 return next.HandleSerialize(ctx, in)
79 }
80
81 compressedBytes, err := compressFunc(stream)
82 if err != nil {
83 return out, metadata, fmt.Errorf("failed to compress request stream, %v", err)
84 }
85
86 var newReq *http.Request
87 if newReq, err = req.SetStream(bytes.NewReader(compressedBytes)); err != nil {
88 return out, metadata, fmt.Errorf("failed to set request stream, %v", err)
89 }
90 *req = *newReq
91
92 if val := req.Header.Get("Content-Encoding"); val != "" {
93 req.Header.Set("Content-Encoding", fmt.Sprintf("%s, %s", val, algorithm))
94 } else {
95 req.Header.Set("Content-Encoding", algorithm)
96 }
97 }
98 break
99 }
100 }
101
102 return next.HandleSerialize(ctx, in)
103}