1package app
2
3import (
4 "testing"
5
6 "github.com/charmbracelet/catwalk/pkg/catwalk"
7 "github.com/charmbracelet/crush/internal/config"
8 "github.com/stretchr/testify/require"
9)
10
11func TestParseModelStr(t *testing.T) {
12 tests := []struct {
13 name string
14 modelStr string
15 expectedFilter string
16 expectedModelID string
17 setupProviders func() map[string]config.ProviderConfig
18 }{
19 {
20 name: "simple model with no slashes",
21 modelStr: "gpt-4o",
22 expectedFilter: "",
23 expectedModelID: "gpt-4o",
24 setupProviders: setupMockProviders,
25 },
26 {
27 name: "valid provider and model",
28 modelStr: "openai/gpt-4o",
29 expectedFilter: "openai",
30 expectedModelID: "gpt-4o",
31 setupProviders: setupMockProviders,
32 },
33 {
34 name: "model with multiple slashes and first part is invalid provider",
35 modelStr: "moonshot/kimi-k2",
36 expectedFilter: "",
37 expectedModelID: "moonshot/kimi-k2",
38 setupProviders: setupMockProviders,
39 },
40 {
41 name: "full path with valid provider and model with slashes",
42 modelStr: "synthetic/moonshot/kimi-k2",
43 expectedFilter: "synthetic",
44 expectedModelID: "moonshot/kimi-k2",
45 setupProviders: setupMockProvidersWithSlashes,
46 },
47 {
48 name: "empty model string",
49 modelStr: "",
50 expectedFilter: "",
51 expectedModelID: "",
52 setupProviders: setupMockProviders,
53 },
54 {
55 name: "model with trailing slash but valid provider",
56 modelStr: "openai/",
57 expectedFilter: "openai",
58 expectedModelID: "",
59 setupProviders: setupMockProviders,
60 },
61 }
62
63 for _, tt := range tests {
64 t.Run(tt.name, func(t *testing.T) {
65 providers := tt.setupProviders()
66 filter, modelID := parseModelStr(providers, tt.modelStr)
67
68 require.Equal(t, tt.expectedFilter, filter, "provider filter mismatch")
69 require.Equal(t, tt.expectedModelID, modelID, "model ID mismatch")
70 })
71 }
72}
73
74func setupMockProviders() map[string]config.ProviderConfig {
75 return map[string]config.ProviderConfig{
76 "openai": {
77 ID: "openai",
78 Name: "OpenAI",
79 Models: []catwalk.Model{{ID: "gpt-4o"}, {ID: "gpt-4o-mini"}},
80 },
81 "anthropic": {
82 ID: "anthropic",
83 Name: "Anthropic",
84 Models: []catwalk.Model{{ID: "claude-3-sonnet"}, {ID: "claude-3-opus"}},
85 },
86 }
87}
88
89func setupMockProvidersWithSlashes() map[string]config.ProviderConfig {
90 return map[string]config.ProviderConfig{
91 "synthetic": {
92 ID: "synthetic",
93 Name: "Synthetic",
94 Models: []catwalk.Model{
95 {ID: "moonshot/kimi-k2"},
96 {ID: "deepseek/deepseek-chat"},
97 },
98 },
99 "openai": {
100 ID: "openai",
101 Name: "OpenAI",
102 Models: []catwalk.Model{{ID: "gpt-4o"}},
103 },
104 }
105}
106
107func TestFindModels(t *testing.T) {
108 tests := []struct {
109 name string
110 modelStr string
111 expectedProvider string
112 expectedModelID string
113 expectError bool
114 errorContains string
115 setupProviders func() map[string]config.ProviderConfig
116 }{
117 {
118 name: "simple model found in one provider",
119 modelStr: "gpt-4o",
120 expectedProvider: "openai",
121 expectedModelID: "gpt-4o",
122 expectError: false,
123 setupProviders: setupMockProviders,
124 },
125 {
126 name: "model with slashes in ID",
127 modelStr: "moonshot/kimi-k2",
128 expectedProvider: "synthetic",
129 expectedModelID: "moonshot/kimi-k2",
130 expectError: false,
131 setupProviders: setupMockProvidersWithSlashes,
132 },
133 {
134 name: "provider and model with slashes in ID",
135 modelStr: "synthetic/moonshot/kimi-k2",
136 expectedProvider: "synthetic",
137 expectedModelID: "moonshot/kimi-k2",
138 expectError: false,
139 setupProviders: setupMockProvidersWithSlashes,
140 },
141 {
142 name: "model not found",
143 modelStr: "nonexistent-model",
144 expectError: true,
145 errorContains: "not found",
146 setupProviders: setupMockProviders,
147 },
148 {
149 name: "invalid provider specified",
150 modelStr: "nonexistent-provider/gpt-4o",
151 expectError: true,
152 errorContains: "provider",
153 setupProviders: setupMockProviders,
154 },
155 {
156 name: "model found in multiple providers without provider filter",
157 modelStr: "shared-model",
158 expectError: true,
159 errorContains: "multiple providers",
160 setupProviders: func() map[string]config.ProviderConfig {
161 return map[string]config.ProviderConfig{
162 "openai": {
163 ID: "openai",
164 Models: []catwalk.Model{{ID: "shared-model"}},
165 },
166 "anthropic": {
167 ID: "anthropic",
168 Models: []catwalk.Model{{ID: "shared-model"}},
169 },
170 }
171 },
172 },
173 {
174 name: "empty model string",
175 modelStr: "",
176 expectError: true,
177 errorContains: "not found",
178 setupProviders: setupMockProviders,
179 },
180 }
181
182 for _, tt := range tests {
183 t.Run(tt.name, func(t *testing.T) {
184 providers := tt.setupProviders()
185
186 // Use findModels with the model as "large" and empty "small".
187 matches, _, err := findModels(providers, tt.modelStr, "")
188 if err != nil {
189 if tt.expectError {
190 require.Contains(t, err.Error(), tt.errorContains)
191 } else {
192 require.NoError(t, err)
193 }
194 return
195 }
196
197 // Validate the matches.
198 match, err := validateMatches(matches, tt.modelStr, "large")
199
200 if tt.expectError {
201 require.Error(t, err)
202 require.Contains(t, err.Error(), tt.errorContains)
203 } else {
204 require.NoError(t, err)
205 require.Equal(t, tt.expectedProvider, match.provider)
206 require.Equal(t, tt.expectedModelID, match.modelID)
207 }
208 })
209 }
210}