1package agent
2
3import (
4 "bytes"
5 "encoding/json"
6 "io"
7 "net/http"
8 "path/filepath"
9 "reflect"
10 "strings"
11 "testing"
12
13 "go.yaml.in/yaml/v4"
14 "gopkg.in/dnaeon/go-vcr.v4/pkg/cassette"
15 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
16)
17
18func newRecorder(t *testing.T) *recorder.Recorder {
19 cassetteName := filepath.Join("testdata", t.Name())
20
21 r, err := recorder.New(
22 cassetteName,
23 recorder.WithMode(recorder.ModeRecordOnce),
24 recorder.WithMatcher(customMatcher(t)),
25 recorder.WithMarshalFunc(marshalFunc),
26 recorder.WithSkipRequestLatency(true), // disable sleep to simulate response time, makes tests faster
27 recorder.WithHook(hookRemoveHeaders, recorder.AfterCaptureHook),
28 )
29 if err != nil {
30 t.Fatalf("recorder: failed to create recorder: %v", err)
31 }
32
33 t.Cleanup(func() {
34 if err := r.Stop(); err != nil {
35 t.Errorf("recorder: failed to stop recorder: %v", err)
36 }
37 })
38
39 return r
40}
41
42func customMatcher(t *testing.T) recorder.MatcherFunc {
43 return func(r *http.Request, i cassette.Request) bool {
44 if r.Body == nil || r.Body == http.NoBody {
45 return cassette.DefaultMatcher(r, i)
46 }
47 if r.Method != i.Method || r.URL.String() != i.URL {
48 return false
49 }
50
51 reqBody, err := io.ReadAll(r.Body)
52 if err != nil {
53 t.Fatalf("recorder: failed to read request body")
54 }
55 r.Body.Close()
56 r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
57
58 // Some providers can sometimes generate JSON requests with keys in
59 // a different order, which means a direct string comparison will fail.
60 // Falling back to deserializing the content if we don't have a match.
61 requestContent := normalizeLineEndings(reqBody)
62 cassetteContent := normalizeLineEndings(i.Body)
63 if requestContent == cassetteContent {
64 return true
65 }
66 var content1, content2 any
67 if err := json.Unmarshal([]byte(requestContent), &content1); err != nil {
68 return false
69 }
70 if err := json.Unmarshal([]byte(cassetteContent), &content2); err != nil {
71 return false
72 }
73 return reflect.DeepEqual(content1, content2)
74 }
75}
76
77func marshalFunc(in any) ([]byte, error) {
78 var buff bytes.Buffer
79 enc := yaml.NewEncoder(&buff)
80 enc.SetIndent(2)
81 enc.CompactSeqIndent()
82 if err := enc.Encode(in); err != nil {
83 return nil, err
84 }
85 return buff.Bytes(), nil
86}
87
// headersToKeep is the allowlist of lower-cased header names that
// hookRemoveHeaders preserves; every other header is dropped.
var headersToKeep = map[string]struct{}{
	"user-agent":   {},
	"content-type": {},
	"accept":       {},
}
93
94func hookRemoveHeaders(i *cassette.Interaction) error {
95 for k := range i.Request.Headers {
96 if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
97 delete(i.Request.Headers, k)
98 }
99 }
100 for k := range i.Response.Headers {
101 if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
102 delete(i.Response.Headers, k)
103 }
104 }
105 return nil
106}
107
// normalizeLineEndings converts s to a string and rewrites Windows line
// endings to Unix ones. It replaces both the raw "\r\n" byte pair and the
// escaped four-character sequence `\r\n` (backslash, r, backslash, n), so
// line endings inside JSON string literals are normalized as well.
func normalizeLineEndings[T string | []byte](s T) string {
	out := string(s)
	for _, pair := range [...][2]string{
		{"\r\n", "\n"}, // raw CRLF
		{`\r\n`, `\n`}, // JSON-escaped CRLF
	} {
		out = strings.ReplaceAll(out, pair[0], pair[1])
	}
	return out
}