useragent_test.go

  1package google
  2
  3import (
  4	"encoding/json"
  5	"net/http"
  6	"net/http/httptest"
  7	"testing"
  8
  9	"charm.land/fantasy"
 10	"github.com/stretchr/testify/assert"
 11	"github.com/stretchr/testify/require"
 12)
 13
 14func TestUserAgent(t *testing.T) {
 15	t.Parallel()
 16
 17	newUAServer := func() (*httptest.Server, *[]map[string]string) {
 18		var captured []map[string]string
 19		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 20			h := make(map[string]string)
 21			for k, v := range r.Header {
 22				if len(v) > 0 {
 23					h[k] = v[0]
 24				}
 25			}
 26			captured = append(captured, h)
 27
 28			w.Header().Set("Content-Type", "application/json")
 29			_ = json.NewEncoder(w).Encode(map[string]any{
 30				"candidates": []map[string]any{
 31					{
 32						"content": map[string]any{
 33							"role": "model",
 34							"parts": []map[string]any{
 35								{"text": "Hello"},
 36							},
 37						},
 38						"finishReason": "STOP",
 39					},
 40				},
 41				"usageMetadata": map[string]any{
 42					"promptTokenCount":     5,
 43					"candidatesTokenCount": 2,
 44					"totalTokenCount":      7,
 45				},
 46			})
 47		}))
 48		return server, &captured
 49	}
 50
 51	prompt := fantasy.Prompt{
 52		{
 53			Role:    fantasy.MessageRoleUser,
 54			Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Hi"}},
 55		},
 56	}
 57
 58	findUA := func(captured *[]map[string]string, want string) bool {
 59		for _, h := range *captured {
 60			if ua, ok := h["User-Agent"]; ok && ua == want {
 61				return true
 62			}
 63		}
 64		return false
 65	}
 66
 67	t.Run("default UA applied", func(t *testing.T) {
 68		t.Parallel()
 69		server, captured := newUAServer()
 70		defer server.Close()
 71
 72		p, err := New(
 73			WithVertex("test-project", "us-central1"),
 74			WithBaseURL(server.URL),
 75			WithSkipAuth(true),
 76		)
 77		require.NoError(t, err)
 78		model, err := p.LanguageModel(t.Context(), "gemini-2.0-flash")
 79		require.NoError(t, err)
 80		_, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})
 81
 82		require.NotEmpty(t, *captured)
 83		assert.True(t, findUA(captured, "Charm Fantasy/"+fantasy.Version))
 84	})
 85
 86	t.Run("WithUserAgent wins over default", func(t *testing.T) {
 87		t.Parallel()
 88		server, captured := newUAServer()
 89		defer server.Close()
 90
 91		p, err := New(
 92			WithVertex("test-project", "us-central1"),
 93			WithBaseURL(server.URL),
 94			WithSkipAuth(true),
 95			WithUserAgent("explicit-ua"),
 96		)
 97		require.NoError(t, err)
 98		model, err := p.LanguageModel(t.Context(), "gemini-2.0-flash")
 99		require.NoError(t, err)
100		_, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})
101
102		require.NotEmpty(t, *captured)
103		assert.True(t, findUA(captured, "explicit-ua"))
104	})
105
106	t.Run("WithHeaders User-Agent wins over default", func(t *testing.T) {
107		t.Parallel()
108		server, captured := newUAServer()
109		defer server.Close()
110
111		p, err := New(
112			WithVertex("test-project", "us-central1"),
113			WithBaseURL(server.URL),
114			WithSkipAuth(true),
115			WithHeaders(map[string]string{"User-Agent": "custom-from-headers"}),
116		)
117		require.NoError(t, err)
118		model, err := p.LanguageModel(t.Context(), "gemini-2.0-flash")
119		require.NoError(t, err)
120		_, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})
121
122		require.NotEmpty(t, *captured)
123		assert.True(t, findUA(captured, "custom-from-headers"))
124	})
125
126	t.Run("WithUserAgent wins over WithHeaders", func(t *testing.T) {
127		t.Parallel()
128		server, captured := newUAServer()
129		defer server.Close()
130
131		p, err := New(
132			WithVertex("test-project", "us-central1"),
133			WithBaseURL(server.URL),
134			WithSkipAuth(true),
135			WithHeaders(map[string]string{"User-Agent": "from-headers"}),
136			WithUserAgent("explicit-ua"),
137		)
138		require.NoError(t, err)
139		model, err := p.LanguageModel(t.Context(), "gemini-2.0-flash")
140		require.NoError(t, err)
141		_, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})
142
143		require.NotEmpty(t, *captured)
144		assert.True(t, findUA(captured, "explicit-ua"))
145	})
146}