1package azure
2
3import (
4 "encoding/json"
5 "net/http"
6 "net/http/httptest"
7 "testing"
8
9 "charm.land/fantasy"
10 "github.com/stretchr/testify/assert"
11 "github.com/stretchr/testify/require"
12)
13
14func TestUserAgent(t *testing.T) {
15 t.Parallel()
16
17 newUAServer := func() (*httptest.Server, *[]map[string]string) {
18 var captured []map[string]string
19 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
20 h := make(map[string]string)
21 for k, v := range r.Header {
22 if len(v) > 0 {
23 h[k] = v[0]
24 }
25 }
26 captured = append(captured, h)
27
28 w.Header().Set("Content-Type", "application/json")
29 _ = json.NewEncoder(w).Encode(mockOpenAIResponse())
30 }))
31 return server, &captured
32 }
33
34 prompt := fantasy.Prompt{
35 {
36 Role: fantasy.MessageRoleUser,
37 Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Hi"}},
38 },
39 }
40
41 t.Run("default UA applied", func(t *testing.T) {
42 t.Parallel()
43 server, captured := newUAServer()
44 defer server.Close()
45
46 p, err := New(WithAPIKey("k"), WithBaseURL(server.URL))
47 require.NoError(t, err)
48 model, _ := p.LanguageModel(t.Context(), "gpt-4")
49 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})
50
51 require.Len(t, *captured, 1)
52 assert.Equal(t, "Charm Fantasy/"+fantasy.Version, (*captured)[0]["User-Agent"])
53 })
54
55 t.Run("WithUserAgent wins over default", func(t *testing.T) {
56 t.Parallel()
57 server, captured := newUAServer()
58 defer server.Close()
59
60 p, err := New(WithAPIKey("k"), WithBaseURL(server.URL), WithUserAgent("explicit-ua"))
61 require.NoError(t, err)
62 model, _ := p.LanguageModel(t.Context(), "gpt-4")
63 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})
64
65 require.Len(t, *captured, 1)
66 assert.Equal(t, "explicit-ua", (*captured)[0]["User-Agent"])
67 })
68
69 t.Run("WithUserAgent wins over WithHeaders", func(t *testing.T) {
70 t.Parallel()
71 server, captured := newUAServer()
72 defer server.Close()
73
74 p, err := New(
75 WithAPIKey("k"),
76 WithBaseURL(server.URL),
77 WithHeaders(map[string]string{"User-Agent": "from-headers"}),
78 WithUserAgent("explicit-ua"),
79 )
80 require.NoError(t, err)
81 model, _ := p.LanguageModel(t.Context(), "gpt-4")
82 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})
83
84 require.Len(t, *captured, 1)
85 assert.Equal(t, "explicit-ua", (*captured)[0]["User-Agent"])
86 })
87}
88
// mockOpenAIResponse returns a minimal chat-completion payload. It has a
// single assistant choice with content "Hi there" and finish reason "stop",
// plus token usage counts. Tests encode it as the JSON body of a fake endpoint.
func mockOpenAIResponse() map[string]any {
	assistantMsg := map[string]any{
		"role":    "assistant",
		"content": "Hi there",
	}

	onlyChoice := map[string]any{
		"index":         0,
		"message":       assistantMsg,
		"finish_reason": "stop",
	}

	tokenUsage := map[string]any{
		"prompt_tokens":     4,
		"completion_tokens": 2,
		"total_tokens":      6,
	}

	resp := make(map[string]any, 6)
	resp["id"] = "chatcmpl-test"
	resp["object"] = "chat.completion"
	resp["created"] = 1711115037
	resp["model"] = "gpt-4"
	resp["choices"] = []map[string]any{onlyChoice}
	resp["usage"] = tokenUsage
	return resp
}