1package agent
  2
  3import (
  4	"bytes"
  5	"encoding/json"
  6	"io"
  7	"net/http"
  8	"path/filepath"
  9	"reflect"
 10	"strings"
 11	"testing"
 12
 13	"go.yaml.in/yaml/v4"
 14	"gopkg.in/dnaeon/go-vcr.v4/pkg/cassette"
 15	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 16)
 17
 18func newRecorder(t *testing.T) *recorder.Recorder {
 19	cassetteName := filepath.Join("testdata", t.Name())
 20
 21	r, err := recorder.New(
 22		cassetteName,
 23		recorder.WithMode(recorder.ModeRecordOnce),
 24		recorder.WithMatcher(customMatcher(t)),
 25		recorder.WithMarshalFunc(marshalFunc),
 26		recorder.WithSkipRequestLatency(true), // disable sleep to simulate response time, makes tests faster
 27		recorder.WithHook(hookRemoveHeaders, recorder.AfterCaptureHook),
 28	)
 29	if err != nil {
 30		t.Fatalf("recorder: failed to create recorder: %v", err)
 31	}
 32
 33	t.Cleanup(func() {
 34		if err := r.Stop(); err != nil {
 35			t.Errorf("recorder: failed to stop recorder: %v", err)
 36		}
 37	})
 38
 39	return r
 40}
 41
 42func customMatcher(t *testing.T) recorder.MatcherFunc {
 43	return func(r *http.Request, i cassette.Request) bool {
 44		if r.Body == nil || r.Body == http.NoBody {
 45			return cassette.DefaultMatcher(r, i)
 46		}
 47		if r.Method != i.Method || r.URL.String() != i.URL {
 48			return false
 49		}
 50
 51		reqBody, err := io.ReadAll(r.Body)
 52		if err != nil {
 53			t.Fatalf("recorder: failed to read request body")
 54		}
 55		r.Body.Close()
 56		r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
 57
 58		// Some providers can sometimes generate JSON requests with keys in
 59		// a different order, which means a direct string comparison will fail.
 60		// Falling back to deserializing the content if we don't have a match.
 61		requestContent := normalizeLineEndings(reqBody)
 62		cassetteContent := normalizeLineEndings(i.Body)
 63		if requestContent == cassetteContent {
 64			return true
 65		}
 66		var content1, content2 any
 67		if err := json.Unmarshal([]byte(requestContent), &content1); err != nil {
 68			return false
 69		}
 70		if err := json.Unmarshal([]byte(cassetteContent), &content2); err != nil {
 71			return false
 72		}
 73		return reflect.DeepEqual(content1, content2)
 74	}
 75}
 76
 77func marshalFunc(in any) ([]byte, error) {
 78	var buff bytes.Buffer
 79	enc := yaml.NewEncoder(&buff)
 80	enc.SetIndent(2)
 81	enc.CompactSeqIndent()
 82	if err := enc.Encode(in); err != nil {
 83		return nil, err
 84	}
 85	return buff.Bytes(), nil
 86}
 87
 88var headersToKeep = map[string]struct{}{
 89	"accept":       {},
 90	"content-type": {},
 91	"user-agent":   {},
 92}
 93
 94func hookRemoveHeaders(i *cassette.Interaction) error {
 95	for k := range i.Request.Headers {
 96		if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
 97			delete(i.Request.Headers, k)
 98		}
 99	}
100	for k := range i.Response.Headers {
101		if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
102			delete(i.Response.Headers, k)
103		}
104	}
105	return nil
106}
107
// normalizeLineEndings converts Windows line endings to Unix ones and
// returns the result as a string. It rewrites two forms:
//   - real CRLF byte pairs: "\r\n" becomes "\n";
//   - the escaped text `\r\n` (backslash-r, backslash-n) becomes `\n`, so that
//     line endings written inside JSON string literals are normalized too.
//
// Both rewrites happen in a single left-to-right, non-overlapping pass.
func normalizeLineEndings[T string | []byte](s T) string {
	replacer := strings.NewReplacer(
		"\r\n", "\n", // carriage return + line feed bytes
		`\r\n`, `\n`, // JSON-escaped form inside string values
	)
	return replacer.Replace(string(s))
}