agent_useragent_test.go

 1package fantasy
 2
 3import (
 4	"context"
 5	"testing"
 6
 7	"github.com/stretchr/testify/assert"
 8	"github.com/stretchr/testify/require"
 9)
10
11func TestAgent_WithUserAgent_PropagatesOnGenerate(t *testing.T) {
12	t.Parallel()
13
14	var capturedCall Call
15	model := &mockLanguageModel{
16		generateFunc: func(_ context.Context, call Call) (*Response, error) {
17			capturedCall = call
18			return &Response{
19				Content:      []Content{TextContent{Text: "ok"}},
20				FinishReason: FinishReasonStop,
21			}, nil
22		},
23	}
24
25	agent := NewAgent(model, WithUserAgent("MyApp/2.0"))
26	_, err := agent.Generate(context.Background(), AgentCall{Prompt: "hi"})
27	require.NoError(t, err)
28	assert.Equal(t, "MyApp/2.0", capturedCall.UserAgent)
29}
30
31func TestAgent_WithUserAgent_PropagatesOnStream(t *testing.T) {
32	t.Parallel()
33
34	var capturedCall Call
35	model := &mockLanguageModel{
36		streamFunc: func(_ context.Context, call Call) (StreamResponse, error) {
37			capturedCall = call
38			return func(yield func(StreamPart) bool) {
39				yield(StreamPart{
40					Type:         StreamPartTypeFinish,
41					FinishReason: FinishReasonStop,
42				})
43			}, nil
44		},
45	}
46
47	agent := NewAgent(model, WithUserAgent("StreamApp/1.0"))
48	_, err := agent.Stream(context.Background(), AgentStreamCall{Prompt: "hi"})
49	require.NoError(t, err)
50	assert.Equal(t, "StreamApp/1.0", capturedCall.UserAgent)
51}
52
53func TestAgent_NoUA_OmitsCallLevelFields(t *testing.T) {
54	t.Parallel()
55
56	var capturedCall Call
57	model := &mockLanguageModel{
58		generateFunc: func(_ context.Context, call Call) (*Response, error) {
59			capturedCall = call
60			return &Response{
61				Content:      []Content{TextContent{Text: "ok"}},
62				FinishReason: FinishReasonStop,
63			}, nil
64		},
65	}
66
67	agent := NewAgent(model)
68	_, err := agent.Generate(context.Background(), AgentCall{Prompt: "hi"})
69	require.NoError(t, err)
70	assert.Empty(t, capturedCall.UserAgent)
71}