1package agent
2
3import (
4 "bytes"
5 "encoding/json"
6 "io"
7 "net/http"
8 "path/filepath"
9 "reflect"
10 "strings"
11 "testing"
12
13 "go.yaml.in/yaml/v4"
14 "gopkg.in/dnaeon/go-vcr.v4/pkg/cassette"
15 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
16)
17
18func newRecorder(t *testing.T) *recorder.Recorder {
19 cassetteName := filepath.Join("testdata", t.Name())
20
21 r, err := recorder.New(
22 cassetteName,
23 recorder.WithMode(recorder.ModeRecordOnce),
24 recorder.WithMatcher(customMatcher(t)),
25 recorder.WithMarshalFunc(marshalFunc),
26 recorder.WithSkipRequestLatency(true), // disable sleep to simulate response time, makes tests faster
27 recorder.WithHook(hookRemoveHeaders, recorder.AfterCaptureHook),
28 )
29 if err != nil {
30 t.Fatalf("recorder: failed to create recorder: %v", err)
31 }
32
33 t.Cleanup(func() {
34 if err := r.Stop(); err != nil {
35 t.Errorf("recorder: failed to stop recorder: %v", err)
36 }
37 })
38
39 return r
40}
41
42func customMatcher(t *testing.T) recorder.MatcherFunc {
43 return func(r *http.Request, i cassette.Request) bool {
44 if r.Body == nil || r.Body == http.NoBody {
45 return cassette.DefaultMatcher(r, i)
46 }
47 if r.Method != i.Method || r.URL.String() != i.URL {
48 return false
49 }
50
51 reqBody, err := io.ReadAll(r.Body)
52 if err != nil {
53 t.Fatalf("recorder: failed to read request body")
54 }
55 r.Body.Close()
56 r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
57
58 // Some providers can sometimes generate JSON requests with keys in
59 // a different order, which means a direct string comparison will fail.
60 // Falling back to deserializing the content if we don't have a match.
61 if string(reqBody) == i.Body { // hot path
62 return true
63 }
64 var content1, content2 any
65 if err := json.Unmarshal(reqBody, &content1); err != nil {
66 return false
67 }
68 if err := json.Unmarshal([]byte(i.Body), &content2); err != nil {
69 return false
70 }
71 return reflect.DeepEqual(content1, content2)
72 }
73}
74
75func marshalFunc(in any) ([]byte, error) {
76 var buff bytes.Buffer
77 enc := yaml.NewEncoder(&buff)
78 enc.SetIndent(2)
79 enc.CompactSeqIndent()
80 if err := enc.Encode(in); err != nil {
81 return nil, err
82 }
83 return buff.Bytes(), nil
84}
85
// headersToKeep is the set of lower-cased HTTP header names that
// hookRemoveHeaders preserves on captured requests and responses. Any header
// not in this set is removed from the interaction.
var headersToKeep = func() map[string]struct{} {
	names := []string{"accept", "content-type", "user-agent"}
	set := make(map[string]struct{}, len(names))
	for _, name := range names {
		set[name] = struct{}{}
	}
	return set
}()
91
92func hookRemoveHeaders(i *cassette.Interaction) error {
93 for k := range i.Request.Headers {
94 if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
95 delete(i.Request.Headers, k)
96 }
97 }
98 for k := range i.Response.Headers {
99 if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
100 delete(i.Response.Headers, k)
101 }
102 }
103 return nil
104}