1package google
2
3import (
4 "encoding/json"
5 "net/http"
6 "net/http/httptest"
7 "testing"
8
9 "charm.land/fantasy"
10 "github.com/stretchr/testify/assert"
11 "github.com/stretchr/testify/require"
12)
13
14func TestUserAgent(t *testing.T) {
15 t.Parallel()
16
17 newUAServer := func() (*httptest.Server, *[]map[string]string) {
18 var captured []map[string]string
19 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
20 h := make(map[string]string)
21 for k, v := range r.Header {
22 if len(v) > 0 {
23 h[k] = v[0]
24 }
25 }
26 captured = append(captured, h)
27
28 w.Header().Set("Content-Type", "application/json")
29 _ = json.NewEncoder(w).Encode(map[string]any{
30 "candidates": []map[string]any{
31 {
32 "content": map[string]any{
33 "role": "model",
34 "parts": []map[string]any{
35 {"text": "Hello"},
36 },
37 },
38 "finishReason": "STOP",
39 },
40 },
41 "usageMetadata": map[string]any{
42 "promptTokenCount": 5,
43 "candidatesTokenCount": 2,
44 "totalTokenCount": 7,
45 },
46 })
47 }))
48 return server, &captured
49 }
50
51 prompt := fantasy.Prompt{
52 {
53 Role: fantasy.MessageRoleUser,
54 Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Hi"}},
55 },
56 }
57
58 findUA := func(captured *[]map[string]string, want string) bool {
59 for _, h := range *captured {
60 if ua, ok := h["User-Agent"]; ok && ua == want {
61 return true
62 }
63 }
64 return false
65 }
66
67 t.Run("default UA applied", func(t *testing.T) {
68 t.Parallel()
69 server, captured := newUAServer()
70 defer server.Close()
71
72 p, err := New(
73 WithVertex("test-project", "us-central1"),
74 WithBaseURL(server.URL),
75 WithSkipAuth(true),
76 )
77 require.NoError(t, err)
78 model, err := p.LanguageModel(t.Context(), "gemini-2.0-flash")
79 require.NoError(t, err)
80 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})
81
82 require.NotEmpty(t, *captured)
83 assert.True(t, findUA(captured, "Charm Fantasy/"+fantasy.Version))
84 })
85
86 t.Run("WithUserAgent wins over default", func(t *testing.T) {
87 t.Parallel()
88 server, captured := newUAServer()
89 defer server.Close()
90
91 p, err := New(
92 WithVertex("test-project", "us-central1"),
93 WithBaseURL(server.URL),
94 WithSkipAuth(true),
95 WithUserAgent("explicit-ua"),
96 )
97 require.NoError(t, err)
98 model, err := p.LanguageModel(t.Context(), "gemini-2.0-flash")
99 require.NoError(t, err)
100 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})
101
102 require.NotEmpty(t, *captured)
103 assert.True(t, findUA(captured, "explicit-ua"))
104 })
105
106 t.Run("WithHeaders User-Agent wins over default", func(t *testing.T) {
107 t.Parallel()
108 server, captured := newUAServer()
109 defer server.Close()
110
111 p, err := New(
112 WithVertex("test-project", "us-central1"),
113 WithBaseURL(server.URL),
114 WithSkipAuth(true),
115 WithHeaders(map[string]string{"User-Agent": "custom-from-headers"}),
116 )
117 require.NoError(t, err)
118 model, err := p.LanguageModel(t.Context(), "gemini-2.0-flash")
119 require.NoError(t, err)
120 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})
121
122 require.NotEmpty(t, *captured)
123 assert.True(t, findUA(captured, "custom-from-headers"))
124 })
125
126 t.Run("WithUserAgent wins over WithHeaders", func(t *testing.T) {
127 t.Parallel()
128 server, captured := newUAServer()
129 defer server.Close()
130
131 p, err := New(
132 WithVertex("test-project", "us-central1"),
133 WithBaseURL(server.URL),
134 WithSkipAuth(true),
135 WithHeaders(map[string]string{"User-Agent": "from-headers"}),
136 WithUserAgent("explicit-ua"),
137 )
138 require.NoError(t, err)
139 model, err := p.LanguageModel(t.Context(), "gemini-2.0-flash")
140 require.NoError(t, err)
141 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})
142
143 require.NotEmpty(t, *captured)
144 assert.True(t, findUA(captured, "explicit-ua"))
145 })
146}