1package agent
  2
  3import (
  4	"bytes"
  5	"encoding/json"
  6	"io"
  7	"net/http"
  8	"path/filepath"
  9	"reflect"
 10	"strings"
 11	"testing"
 12
 13	"go.yaml.in/yaml/v4"
 14	"gopkg.in/dnaeon/go-vcr.v4/pkg/cassette"
 15	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 16)
 17
 18func newRecorder(t *testing.T) *recorder.Recorder {
 19	cassetteName := filepath.Join("testdata", t.Name())
 20
 21	r, err := recorder.New(
 22		cassetteName,
 23		recorder.WithMode(recorder.ModeRecordOnce),
 24		recorder.WithMatcher(customMatcher(t)),
 25		recorder.WithMarshalFunc(marshalFunc),
 26		recorder.WithSkipRequestLatency(true), // disable sleep to simulate response time, makes tests faster
 27		recorder.WithHook(hookRemoveHeaders, recorder.AfterCaptureHook),
 28	)
 29	if err != nil {
 30		t.Fatalf("recorder: failed to create recorder: %v", err)
 31	}
 32
 33	t.Cleanup(func() {
 34		if err := r.Stop(); err != nil {
 35			t.Errorf("recorder: failed to stop recorder: %v", err)
 36		}
 37	})
 38
 39	return r
 40}
 41
 42func customMatcher(t *testing.T) recorder.MatcherFunc {
 43	return func(r *http.Request, i cassette.Request) bool {
 44		if r.Body == nil || r.Body == http.NoBody {
 45			return cassette.DefaultMatcher(r, i)
 46		}
 47		if r.Method != i.Method || r.URL.String() != i.URL {
 48			return false
 49		}
 50
 51		reqBody, err := io.ReadAll(r.Body)
 52		if err != nil {
 53			t.Fatalf("recorder: failed to read request body")
 54		}
 55		r.Body.Close()
 56		r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
 57
 58		// Some providers can sometimes generate JSON requests with keys in
 59		// a different order, which means a direct string comparison will fail.
 60		// Falling back to deserializing the content if we don't have a match.
 61		if string(reqBody) == i.Body { // hot path
 62			return true
 63		}
 64		var content1, content2 any
 65		if err := json.Unmarshal(reqBody, &content1); err != nil {
 66			return false
 67		}
 68		if err := json.Unmarshal([]byte(i.Body), &content2); err != nil {
 69			return false
 70		}
 71		return reflect.DeepEqual(content1, content2)
 72	}
 73}
 74
 75func marshalFunc(in any) ([]byte, error) {
 76	var buff bytes.Buffer
 77	enc := yaml.NewEncoder(&buff)
 78	enc.SetIndent(2)
 79	enc.CompactSeqIndent()
 80	if err := enc.Encode(in); err != nil {
 81		return nil, err
 82	}
 83	return buff.Bytes(), nil
 84}
 85
 86var headersToKeep = map[string]struct{}{
 87	"accept":       {},
 88	"content-type": {},
 89	"user-agent":   {},
 90}
 91
 92func hookRemoveHeaders(i *cassette.Interaction) error {
 93	for k := range i.Request.Headers {
 94		if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
 95			delete(i.Request.Headers, k)
 96		}
 97	}
 98	for k := range i.Response.Headers {
 99		if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
100			delete(i.Response.Headers, k)
101		}
102	}
103	return nil
104}