recorder_test.go

 1package providertests
 2
 3import (
 4	"bytes"
 5	"io"
 6	"net/http"
 7	"path/filepath"
 8	"strings"
 9	"testing"
10
11	"go.yaml.in/yaml/v4"
12	"gopkg.in/dnaeon/go-vcr.v4/pkg/cassette"
13	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
14)
15
16func newRecorder(t *testing.T) *recorder.Recorder {
17	cassetteName := filepath.Join("testdata", t.Name())
18
19	r, err := recorder.New(
20		cassetteName,
21		recorder.WithMode(recorder.ModeRecordOnce),
22		recorder.WithMatcher(customMatcher(t)),
23		recorder.WithMarshalFunc(marshalFunc),
24		recorder.WithSkipRequestLatency(true), // disable sleep to simulate response time, makes tests faster
25		recorder.WithHook(hookRemoveHeaders, recorder.AfterCaptureHook),
26	)
27	if err != nil {
28		t.Fatalf("recorder: failed to create recorder: %v", err)
29	}
30
31	t.Cleanup(func() {
32		if err := r.Stop(); err != nil {
33			t.Errorf("recorder: failed to stop recorder: %v", err)
34		}
35	})
36
37	return r
38}
39
40func customMatcher(t *testing.T) recorder.MatcherFunc {
41	return func(r *http.Request, i cassette.Request) bool {
42		if r.Body == nil || r.Body == http.NoBody {
43			return cassette.DefaultMatcher(r, i)
44		}
45
46		var reqBody []byte
47		var err error
48		reqBody, err = io.ReadAll(r.Body)
49		if err != nil {
50			t.Fatalf("recorder: failed to read request body")
51		}
52		r.Body.Close()
53		r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
54
55		return r.Method == i.Method && r.URL.String() == i.URL && string(reqBody) == i.Body
56	}
57}
58
59func marshalFunc(in any) ([]byte, error) {
60	var buff bytes.Buffer
61	enc := yaml.NewEncoder(&buff)
62	enc.CompactSeqIndent()
63	if err := enc.Encode(in); err != nil {
64		return nil, err
65	}
66	return buff.Bytes(), nil
67}
68
// headersToKeep is the set of HTTP header names, in lower case, that
// hookRemoveHeaders leaves in recorded request and response headers; every
// other header is deleted. Lookups must use the lower-cased header name.
var headersToKeep = map[string]struct{}{
	"user-agent":   {},
	"content-type": {},
	"accept":       {},
}
74
75func hookRemoveHeaders(i *cassette.Interaction) error {
76	for k := range i.Request.Headers {
77		if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
78			delete(i.Request.Headers, k)
79		}
80	}
81	for k := range i.Response.Headers {
82		if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
83			delete(i.Response.Headers, k)
84		}
85	}
86	return nil
87}