1package providertests
2
3import (
4 "bytes"
5 "io"
6 "net/http"
7 "path/filepath"
8 "strings"
9 "testing"
10
11 "go.yaml.in/yaml/v4"
12 "gopkg.in/dnaeon/go-vcr.v4/pkg/cassette"
13 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
14)
15
16func newRecorder(t *testing.T) *recorder.Recorder {
17 cassetteName := filepath.Join("testdata", t.Name())
18
19 r, err := recorder.New(
20 cassetteName,
21 recorder.WithMode(recorder.ModeRecordOnce),
22 recorder.WithMatcher(customMatcher(t)),
23 recorder.WithMarshalFunc(marshalFunc),
24 recorder.WithSkipRequestLatency(true), // disable sleep to simulate response time, makes tests faster
25 recorder.WithHook(hookRemoveHeaders, recorder.AfterCaptureHook),
26 )
27 if err != nil {
28 t.Fatalf("recorder: failed to create recorder: %v", err)
29 }
30
31 t.Cleanup(func() {
32 if err := r.Stop(); err != nil {
33 t.Errorf("recorder: failed to stop recorder: %v", err)
34 }
35 })
36
37 return r
38}
39
40func customMatcher(t *testing.T) recorder.MatcherFunc {
41 return func(r *http.Request, i cassette.Request) bool {
42 if r.Body == nil || r.Body == http.NoBody {
43 return cassette.DefaultMatcher(r, i)
44 }
45
46 var reqBody []byte
47 var err error
48 reqBody, err = io.ReadAll(r.Body)
49 if err != nil {
50 t.Fatalf("recorder: failed to read request body")
51 }
52 r.Body.Close()
53 r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
54
55 return r.Method == i.Method && r.URL.String() == i.URL && string(reqBody) == i.Body
56 }
57}
58
59func marshalFunc(in any) ([]byte, error) {
60 var buff bytes.Buffer
61 enc := yaml.NewEncoder(&buff)
62 enc.SetIndent(2)
63 enc.CompactSeqIndent()
64 if err := enc.Encode(in); err != nil {
65 return nil, err
66 }
67 return buff.Bytes(), nil
68}
69
70var headersToKeep = map[string]struct{}{
71 "accept": {},
72 "content-type": {},
73 "user-agent": {},
74}
75
76func hookRemoveHeaders(i *cassette.Interaction) error {
77 for k := range i.Request.Headers {
78 if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
79 delete(i.Request.Headers, k)
80 }
81 }
82 for k := range i.Response.Headers {
83 if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
84 delete(i.Response.Headers, k)
85 }
86 }
87 return nil
88}