model_family_test.go

 1package agent
 2
 3import (
 4	"testing"
 5
 6	"github.com/stretchr/testify/require"
 7)
 8
 9func TestDetectModelFamily(t *testing.T) {
10	t.Parallel()
11
12	tests := []struct {
13		name     string
14		model    string
15		expected ModelFamily
16	}{
17		// Anthropic models
18		{"claude-3-5-sonnet", "claude-3-5-sonnet-20241022", ModelFamilyAnthropic},
19		{"claude-3-opus", "claude-3-opus-20240229", ModelFamilyAnthropic},
20		{"claude-2", "claude-2.1", ModelFamilyAnthropic},
21		{"claude-instant", "claude-instant-1.2", ModelFamilyAnthropic},
22
23		// OpenAI models
24		{"gpt-4", "gpt-4-turbo", ModelFamilyOpenAI},
25		{"gpt-4o", "gpt-4o", ModelFamilyOpenAI},
26		{"gpt-3.5-turbo", "gpt-3.5-turbo", ModelFamilyOpenAI},
27		{"o1-preview", "o1-preview", ModelFamilyOpenAI},
28		{"o1-mini", "o1-mini", ModelFamilyOpenAI},
29		{"chatgpt", "chatgpt-4o-latest", ModelFamilyOpenAI},
30
31		// Google models
32		{"gemini-pro", "gemini-pro", ModelFamilyGoogle},
33		{"gemini-1.5-pro", "gemini-1.5-pro-latest", ModelFamilyGoogle},
34		{"gemini-1.5-flash", "gemini-1.5-flash-002", ModelFamilyGoogle},
35
36		// Default/unknown models
37		{"llama", "llama-3-70b", ModelFamilyDefault},
38		{"mistral", "mistral-large", ModelFamilyDefault},
39		{"unknown", "some-unknown-model", ModelFamilyDefault},
40	}
41
42	for _, tt := range tests {
43		t.Run(tt.name, func(t *testing.T) {
44			t.Parallel()
45			result := DetectModelFamily(tt.model)
46			require.Equal(t, tt.expected, result, "model: %s", tt.model)
47		})
48	}
49}