middleware_capture_request_compression.go

 1package requestcompression
 2
 3import (
 4	"bytes"
 5	"context"
 6	"fmt"
 7	"github.com/aws/smithy-go/middleware"
 8	smithyhttp "github.com/aws/smithy-go/transport/http"
 9	"io"
10	"net/http"
11)
12
// captureUncompressedRequestID is the identifier reported by
// captureUncompressedRequestMiddleware.ID.
const (
	captureUncompressedRequestID = "CaptureUncompressedRequest"
)
14
15// AddCaptureUncompressedRequestMiddleware captures http request before compress encoding for check
16func AddCaptureUncompressedRequestMiddleware(stack *middleware.Stack, buf *bytes.Buffer) error {
17	return stack.Serialize.Insert(&captureUncompressedRequestMiddleware{
18		buf: buf,
19	}, "RequestCompression", middleware.Before)
20}
21
// captureUncompressedRequestMiddleware copies the serialized request payload
// into buf as part of the serialize step.
type captureUncompressedRequestMiddleware struct {
	// req is not read or written anywhere in this file.
	// NOTE(review): confirm it is unused package-wide before removing.
	req *http.Request

	// buf receives a copy of the request payload in HandleSerialize.
	buf *bytes.Buffer

	// bytes is not read or written anywhere in this file.
	// NOTE(review): confirm it is unused package-wide before removing.
	bytes []byte
}
27
28// ID returns id of the captureUncompressedRequestMiddleware
29func (*captureUncompressedRequestMiddleware) ID() string {
30	return captureUncompressedRequestID
31}
32
33// HandleSerialize captures request payload before it is compressed by request compression middleware
34func (m *captureUncompressedRequestMiddleware) HandleSerialize(ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler,
35) (
36	output middleware.SerializeOutput, metadata middleware.Metadata, err error,
37) {
38	request, ok := input.Request.(*smithyhttp.Request)
39	if !ok {
40		return output, metadata, fmt.Errorf("error when retrieving http request")
41	}
42
43	_, err = io.Copy(m.buf, request.GetStream())
44	if err != nil {
45		return output, metadata, fmt.Errorf("error when copying http request stream: %q", err)
46	}
47	if err = request.RewindStream(); err != nil {
48		return output, metadata, fmt.Errorf("error when rewinding request stream: %q", err)
49	}
50
51	return next.HandleSerialize(ctx, input)
52}