checksum_middleware.go

 1package http
 2
 3import (
 4	"context"
 5	"fmt"
 6
 7	"github.com/aws/smithy-go/middleware"
 8)
 9
// contentMD5Header is the name of the HTTP header that carries the
// content-md5 checksum of the request payload.
const (
	contentMD5Header = "Content-Md5"
)
11
// contentMD5Checksum is a build-step middleware that computes the
// content-md5 checksum of an http request and sets it as a header.
// It has no fields.
type contentMD5Checksum struct{}
16
17// AddContentChecksumMiddleware adds checksum middleware to middleware's
18// build step.
19func AddContentChecksumMiddleware(stack *middleware.Stack) error {
20	// This middleware must be executed before request body is set.
21	return stack.Build.Add(&contentMD5Checksum{}, middleware.Before)
22}
23
24// ID returns the identifier for the checksum middleware
25func (m *contentMD5Checksum) ID() string { return "ContentChecksum" }
26
27// HandleBuild adds behavior to compute md5 checksum and add content-md5 header
28// on http request
29func (m *contentMD5Checksum) HandleBuild(
30	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
31) (
32	out middleware.BuildOutput, metadata middleware.Metadata, err error,
33) {
34	req, ok := in.Request.(*Request)
35	if !ok {
36		return out, metadata, fmt.Errorf("unknown request type %T", req)
37	}
38
39	// if Content-MD5 header is already present, return
40	if v := req.Header.Get(contentMD5Header); len(v) != 0 {
41		return next.HandleBuild(ctx, in)
42	}
43
44	// fetch the request stream.
45	stream := req.GetStream()
46	// compute checksum if payload is explicit
47	if stream != nil {
48		if !req.IsStreamSeekable() {
49			return out, metadata, fmt.Errorf(
50				"unseekable stream is not supported for computing md5 checksum")
51		}
52
53		v, err := computeMD5Checksum(stream)
54		if err != nil {
55			return out, metadata, fmt.Errorf("error computing md5 checksum, %w", err)
56		}
57
58		// reset the request stream
59		if err := req.RewindStream(); err != nil {
60			return out, metadata, fmt.Errorf(
61				"error rewinding request stream after computing md5 checksum, %w", err)
62		}
63
64		// set the 'Content-MD5' header
65		req.Header.Set(contentMD5Header, string(v))
66	}
67
68	// set md5 header value
69	return next.HandleBuild(ctx, in)
70}