useragent_test.go

  1package azure
  2
  3import (
  4	"encoding/json"
  5	"net/http"
  6	"net/http/httptest"
  7	"testing"
  8
  9	"charm.land/fantasy"
 10	"github.com/stretchr/testify/assert"
 11	"github.com/stretchr/testify/require"
 12)
 13
 14func TestUserAgent(t *testing.T) {
 15	t.Parallel()
 16
 17	newUAServer := func() (*httptest.Server, *[]map[string]string) {
 18		var captured []map[string]string
 19		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 20			h := make(map[string]string)
 21			for k, v := range r.Header {
 22				if len(v) > 0 {
 23					h[k] = v[0]
 24				}
 25			}
 26			captured = append(captured, h)
 27
 28			w.Header().Set("Content-Type", "application/json")
 29			_ = json.NewEncoder(w).Encode(mockOpenAIResponse())
 30		}))
 31		return server, &captured
 32	}
 33
 34	prompt := fantasy.Prompt{
 35		{
 36			Role:    fantasy.MessageRoleUser,
 37			Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Hi"}},
 38		},
 39	}
 40
 41	t.Run("default UA applied", func(t *testing.T) {
 42		t.Parallel()
 43		server, captured := newUAServer()
 44		defer server.Close()
 45
 46		p, err := New(WithAPIKey("k"), WithBaseURL(server.URL))
 47		require.NoError(t, err)
 48		model, _ := p.LanguageModel(t.Context(), "gpt-4")
 49		_, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})
 50
 51		require.Len(t, *captured, 1)
 52		assert.Equal(t, "Charm Fantasy/"+fantasy.Version, (*captured)[0]["User-Agent"])
 53	})
 54
 55	t.Run("WithUserAgent wins over default", func(t *testing.T) {
 56		t.Parallel()
 57		server, captured := newUAServer()
 58		defer server.Close()
 59
 60		p, err := New(WithAPIKey("k"), WithBaseURL(server.URL), WithUserAgent("explicit-ua"))
 61		require.NoError(t, err)
 62		model, _ := p.LanguageModel(t.Context(), "gpt-4")
 63		_, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})
 64
 65		require.Len(t, *captured, 1)
 66		assert.Equal(t, "explicit-ua", (*captured)[0]["User-Agent"])
 67	})
 68
 69	t.Run("WithUserAgent wins over WithHeaders", func(t *testing.T) {
 70		t.Parallel()
 71		server, captured := newUAServer()
 72		defer server.Close()
 73
 74		p, err := New(
 75			WithAPIKey("k"),
 76			WithBaseURL(server.URL),
 77			WithHeaders(map[string]string{"User-Agent": "from-headers"}),
 78			WithUserAgent("explicit-ua"),
 79		)
 80		require.NoError(t, err)
 81		model, _ := p.LanguageModel(t.Context(), "gpt-4")
 82		_, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})
 83
 84		require.Len(t, *captured, 1)
 85		assert.Equal(t, "explicit-ua", (*captured)[0]["User-Agent"])
 86	})
 87}
 88
 89func mockOpenAIResponse() map[string]any {
 90	return map[string]any{
 91		"id":      "chatcmpl-test",
 92		"object":  "chat.completion",
 93		"created": 1711115037,
 94		"model":   "gpt-4",
 95		"choices": []map[string]any{
 96			{
 97				"index": 0,
 98				"message": map[string]any{
 99					"role":    "assistant",
100					"content": "Hi there",
101				},
102				"finish_reason": "stop",
103			},
104		},
105		"usage": map[string]any{
106			"prompt_tokens":     4,
107			"total_tokens":      6,
108			"completion_tokens": 2,
109		},
110	}
111}