request_compression.go

  1// Package requestcompression implements runtime support for smithy-modeled
  2// request compression.
  3//
  4// This package is designated as private and is intended for use only by the
  5// smithy client runtime. The exported API therein is not considered stable and
  6// is subject to breaking changes without notice.
  7package requestcompression
  8
  9import (
 10	"bytes"
 11	"context"
 12	"fmt"
 13	"github.com/aws/smithy-go/middleware"
 14	"github.com/aws/smithy-go/transport/http"
 15	"io"
 16)
 17
 18const MaxRequestMinCompressSizeBytes = 10485760
 19
 20// Enumeration values for supported compress Algorithms.
 21const (
 22	GZIP = "gzip"
 23)
 24
 25type compressFunc func(io.Reader) ([]byte, error)
 26
 27var allowedAlgorithms = map[string]compressFunc{
 28	GZIP: gzipCompress,
 29}
 30
 31// AddRequestCompression add requestCompression middleware to op stack
 32func AddRequestCompression(stack *middleware.Stack, disabled bool, minBytes int64, algorithms []string) error {
 33	return stack.Serialize.Add(&requestCompression{
 34		disableRequestCompression:   disabled,
 35		requestMinCompressSizeBytes: minBytes,
 36		compressAlgorithms:          algorithms,
 37	}, middleware.After)
 38}
 39
 40type requestCompression struct {
 41	disableRequestCompression   bool
 42	requestMinCompressSizeBytes int64
 43	compressAlgorithms          []string
 44}
 45
 46// ID returns the ID of the middleware
 47func (m requestCompression) ID() string {
 48	return "RequestCompression"
 49}
 50
 51// HandleSerialize gzip compress the request's stream/body if enabled by config fields
 52func (m requestCompression) HandleSerialize(
 53	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
 54) (
 55	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
 56) {
 57	if m.disableRequestCompression {
 58		return next.HandleSerialize(ctx, in)
 59	}
 60	// still need to check requestMinCompressSizeBytes in case it is out of range after service client config
 61	if m.requestMinCompressSizeBytes < 0 || m.requestMinCompressSizeBytes > MaxRequestMinCompressSizeBytes {
 62		return out, metadata, fmt.Errorf("invalid range for min request compression size bytes %d, must be within 0 and 10485760 inclusively", m.requestMinCompressSizeBytes)
 63	}
 64
 65	req, ok := in.Request.(*http.Request)
 66	if !ok {
 67		return out, metadata, fmt.Errorf("unknown request type %T", req)
 68	}
 69
 70	for _, algorithm := range m.compressAlgorithms {
 71		compressFunc := allowedAlgorithms[algorithm]
 72		if compressFunc != nil {
 73			if stream := req.GetStream(); stream != nil {
 74				size, found, err := req.StreamLength()
 75				if err != nil {
 76					return out, metadata, fmt.Errorf("error while finding request stream length, %v", err)
 77				} else if !found || size < m.requestMinCompressSizeBytes {
 78					return next.HandleSerialize(ctx, in)
 79				}
 80
 81				compressedBytes, err := compressFunc(stream)
 82				if err != nil {
 83					return out, metadata, fmt.Errorf("failed to compress request stream, %v", err)
 84				}
 85
 86				var newReq *http.Request
 87				if newReq, err = req.SetStream(bytes.NewReader(compressedBytes)); err != nil {
 88					return out, metadata, fmt.Errorf("failed to set request stream, %v", err)
 89				}
 90				*req = *newReq
 91
 92				if val := req.Header.Get("Content-Encoding"); val != "" {
 93					req.Header.Set("Content-Encoding", fmt.Sprintf("%s, %s", val, algorithm))
 94				} else {
 95					req.Header.Set("Content-Encoding", algorithm)
 96				}
 97			}
 98			break
 99		}
100	}
101
102	return next.HandleSerialize(ctx, in)
103}