1package agent
2
3import (
4 "testing"
5
6 "github.com/stretchr/testify/require"
7)
8
9func TestDetectModelFamily(t *testing.T) {
10 t.Parallel()
11
12 tests := []struct {
13 name string
14 model string
15 expected ModelFamily
16 }{
17 // Anthropic models
18 {"claude-3-5-sonnet", "claude-3-5-sonnet-20241022", ModelFamilyAnthropic},
19 {"claude-3-opus", "claude-3-opus-20240229", ModelFamilyAnthropic},
20 {"claude-2", "claude-2.1", ModelFamilyAnthropic},
21 {"claude-instant", "claude-instant-1.2", ModelFamilyAnthropic},
22
23 // OpenAI models
24 {"gpt-4", "gpt-4-turbo", ModelFamilyOpenAI},
25 {"gpt-4o", "gpt-4o", ModelFamilyOpenAI},
26 {"gpt-3.5-turbo", "gpt-3.5-turbo", ModelFamilyOpenAI},
27 {"o1-preview", "o1-preview", ModelFamilyOpenAI},
28 {"o1-mini", "o1-mini", ModelFamilyOpenAI},
29 {"chatgpt", "chatgpt-4o-latest", ModelFamilyOpenAI},
30
31 // Google models
32 {"gemini-pro", "gemini-pro", ModelFamilyGoogle},
33 {"gemini-1.5-pro", "gemini-1.5-pro-latest", ModelFamilyGoogle},
34 {"gemini-1.5-flash", "gemini-1.5-flash-002", ModelFamilyGoogle},
35
36 // Default/unknown models
37 {"llama", "llama-3-70b", ModelFamilyDefault},
38 {"mistral", "mistral-large", ModelFamilyDefault},
39 {"unknown", "some-unknown-model", ModelFamilyDefault},
40 }
41
42 for _, tt := range tests {
43 t.Run(tt.name, func(t *testing.T) {
44 t.Parallel()
45 result := DetectModelFamily(tt.model)
46 require.Equal(t, tt.expected, result, "model: %s", tt.model)
47 })
48 }
49}