1package models
2
3import (
4 "context"
5 "log/slog"
6 "net/http"
7 "testing"
8
9 "shelley.exe.dev/llm"
10)
11
12func TestAll(t *testing.T) {
13 models := All()
14 if len(models) == 0 {
15 t.Fatal("expected at least one model")
16 }
17
18 // Verify all models have required fields
19 for _, m := range models {
20 if m.ID == "" {
21 t.Errorf("model missing ID")
22 }
23 if m.Provider == "" {
24 t.Errorf("model %s missing Provider", m.ID)
25 }
26 if m.Factory == nil {
27 t.Errorf("model %s missing Factory", m.ID)
28 }
29 }
30}
31
32func TestByID(t *testing.T) {
33 tests := []struct {
34 id string
35 wantID string
36 wantNil bool
37 }{
38 {id: "qwen3-coder-fireworks", wantID: "qwen3-coder-fireworks", wantNil: false},
39 {id: "gpt-5.2-codex", wantID: "gpt-5.2-codex", wantNil: false},
40 {id: "claude-sonnet-4.5", wantID: "claude-sonnet-4.5", wantNil: false},
41 {id: "claude-haiku-4.5", wantID: "claude-haiku-4.5", wantNil: false},
42 {id: "claude-opus-4.5", wantID: "claude-opus-4.5", wantNil: false},
43 {id: "claude-opus-4.6", wantID: "claude-opus-4.6", wantNil: false},
44 {id: "nonexistent", wantNil: true},
45 }
46
47 for _, tt := range tests {
48 t.Run(tt.id, func(t *testing.T) {
49 m := ByID(tt.id)
50 if tt.wantNil {
51 if m != nil {
52 t.Errorf("ByID(%q) = %v, want nil", tt.id, m)
53 }
54 } else {
55 if m == nil {
56 t.Fatalf("ByID(%q) = nil, want non-nil", tt.id)
57 }
58 if m.ID != tt.wantID {
59 t.Errorf("ByID(%q).ID = %q, want %q", tt.id, m.ID, tt.wantID)
60 }
61 }
62 })
63 }
64}
65
66func TestDefault(t *testing.T) {
67 d := Default()
68 if d.ID != "claude-opus-4.6" {
69 t.Errorf("Default().ID = %q, want %q", d.ID, "claude-opus-4.6")
70 }
71}
72
73func TestIDs(t *testing.T) {
74 ids := IDs()
75 if len(ids) == 0 {
76 t.Fatal("expected at least one model ID")
77 }
78
79 // Verify all IDs are unique
80 seen := make(map[string]bool)
81 for _, id := range ids {
82 if seen[id] {
83 t.Errorf("duplicate model ID: %s", id)
84 }
85 seen[id] = true
86 }
87}
88
89func TestFactory(t *testing.T) {
90 // Test that we can create services with empty config (should fail for most models)
91 cfg := &Config{}
92
93 // Predictable should work without any config
94 m := ByID("predictable")
95 if m == nil {
96 t.Fatal("predictable model not found")
97 }
98
99 svc, err := m.Factory(cfg, nil)
100 if err != nil {
101 t.Fatalf("predictable Factory() failed: %v", err)
102 }
103 if svc == nil {
104 t.Fatal("predictable Factory() returned nil service")
105 }
106}
107
108func TestManagerGetAvailableModelsOrder(t *testing.T) {
109 // Test that GetAvailableModels returns models in consistent order
110 cfg := &Config{}
111
112 // Create manager - should only have predictable model since no API keys
113 manager, err := NewManager(cfg)
114 if err != nil {
115 t.Fatalf("NewManager failed: %v", err)
116 }
117
118 // Get available models multiple times
119 firstCall := manager.GetAvailableModels()
120 secondCall := manager.GetAvailableModels()
121 thirdCall := manager.GetAvailableModels()
122
123 // Should return at least predictable model
124 if len(firstCall) == 0 {
125 t.Fatal("expected at least one model")
126 }
127
128 // All calls should return identical order
129 if len(firstCall) != len(secondCall) || len(firstCall) != len(thirdCall) {
130 t.Errorf("calls returned different lengths: %d, %d, %d", len(firstCall), len(secondCall), len(thirdCall))
131 }
132
133 for i := range firstCall {
134 if firstCall[i] != secondCall[i] {
135 t.Errorf("call 1 and 2 differ at index %d: %q vs %q", i, firstCall[i], secondCall[i])
136 }
137 if firstCall[i] != thirdCall[i] {
138 t.Errorf("call 1 and 3 differ at index %d: %q vs %q", i, firstCall[i], thirdCall[i])
139 }
140 }
141}
142
143func TestManagerGetAvailableModelsMatchesAllOrder(t *testing.T) {
144 // Test that available models are returned in the same order as All()
145 cfg := &Config{
146 AnthropicAPIKey: "test-key",
147 OpenAIAPIKey: "test-key",
148 GeminiAPIKey: "test-key",
149 FireworksAPIKey: "test-key",
150 }
151
152 manager, err := NewManager(cfg)
153 if err != nil {
154 t.Fatalf("NewManager failed: %v", err)
155 }
156
157 available := manager.GetAvailableModels()
158 all := All()
159
160 // Build expected order from All()
161 var expected []string
162 for _, m := range all {
163 if manager.HasModel(m.ID) {
164 expected = append(expected, m.ID)
165 }
166 }
167
168 // Should match
169 if len(available) != len(expected) {
170 t.Fatalf("available models count %d != expected count %d", len(available), len(expected))
171 }
172
173 for i := range available {
174 if available[i] != expected[i] {
175 t.Errorf("model at index %d: got %q, want %q", i, available[i], expected[i])
176 }
177 }
178}
179
180func TestLoggingService(t *testing.T) {
181 // Create a mock service for testing
182 mockService := &mockLLMService{}
183 logger := slog.Default()
184
185 loggingSvc := &loggingService{
186 service: mockService,
187 logger: logger,
188 modelID: "test-model",
189 provider: ProviderBuiltIn,
190 }
191
192 // Test Do method
193 ctx := context.Background()
194 request := &llm.Request{
195 Messages: []llm.Message{
196 llm.UserStringMessage("Hello"),
197 },
198 }
199
200 response, err := loggingSvc.Do(ctx, request)
201 if err != nil {
202 t.Errorf("Do returned unexpected error: %v", err)
203 }
204
205 if response == nil {
206 t.Error("Do returned nil response")
207 }
208
209 // Test TokenContextWindow
210 window := loggingSvc.TokenContextWindow()
211 if window != mockService.TokenContextWindow() {
212 t.Errorf("TokenContextWindow returned %d, expected %d", window, mockService.TokenContextWindow())
213 }
214
215 // Test MaxImageDimension
216 dimension := loggingSvc.MaxImageDimension()
217 if dimension != mockService.MaxImageDimension() {
218 t.Errorf("MaxImageDimension returned %d, expected %d", dimension, mockService.MaxImageDimension())
219 }
220
221 // Test UseSimplifiedPatch
222 useSimplified := loggingSvc.UseSimplifiedPatch()
223 if useSimplified != mockService.UseSimplifiedPatch() {
224 t.Errorf("UseSimplifiedPatch returned %t, expected %t", useSimplified, mockService.UseSimplifiedPatch())
225 }
226}
227
228// mockLLMService implements llm.Service for testing
229type mockLLMService struct {
230 tokenContextWindow int
231 maxImageDimension int
232 useSimplifiedPatch bool
233}
234
235func (m *mockLLMService) Do(ctx context.Context, request *llm.Request) (*llm.Response, error) {
236 return &llm.Response{
237 Content: llm.TextContent("Hello, world!"),
238 Usage: llm.Usage{
239 InputTokens: 10,
240 OutputTokens: 5,
241 CostUSD: 0.001,
242 },
243 }, nil
244}
245
246func (m *mockLLMService) TokenContextWindow() int {
247 if m.tokenContextWindow == 0 {
248 return 4096
249 }
250 return m.tokenContextWindow
251}
252
253func (m *mockLLMService) MaxImageDimension() int {
254 if m.maxImageDimension == 0 {
255 return 2048
256 }
257 return m.maxImageDimension
258}
259
260func (m *mockLLMService) UseSimplifiedPatch() bool {
261 return m.useSimplifiedPatch
262}
263
264func TestManagerGetService(t *testing.T) {
265 // Test with predictable model (no API keys needed)
266 cfg := &Config{}
267
268 manager, err := NewManager(cfg)
269 if err != nil {
270 t.Fatalf("NewManager failed: %v", err)
271 }
272
273 // Test getting predictable service (should work)
274 svc, err := manager.GetService("predictable")
275 if err != nil {
276 t.Errorf("GetService('predictable') failed: %v", err)
277 }
278 if svc == nil {
279 t.Error("GetService('predictable') returned nil service")
280 }
281
282 // Test getting non-existent service
283 _, err = manager.GetService("non-existent-model")
284 if err == nil {
285 t.Error("GetService('non-existent-model') should have failed but didn't")
286 }
287}
288
289func TestManagerHasModel(t *testing.T) {
290 cfg := &Config{}
291
292 manager, err := NewManager(cfg)
293 if err != nil {
294 t.Fatalf("NewManager failed: %v", err)
295 }
296
297 // Should have predictable model
298 if !manager.HasModel("predictable") {
299 t.Error("HasModel('predictable') should return true")
300 }
301
302 // Should not have models requiring API keys
303 if manager.HasModel("claude-opus-4.6") {
304 t.Error("HasModel('claude-opus-4.6') should return false without API key")
305 }
306
307 // Should not have non-existent model
308 if manager.HasModel("non-existent-model") {
309 t.Error("HasModel('non-existent-model') should return false")
310 }
311}
312
313func TestConfigGetURLMethods(t *testing.T) {
314 // Test getGeminiURL with no gateway
315 cfg := &Config{}
316 if cfg.getGeminiURL() != "" {
317 t.Errorf("getGeminiURL with no gateway should return empty string, got %q", cfg.getGeminiURL())
318 }
319
320 // Test getGeminiURL with gateway
321 cfg.Gateway = "https://gateway.example.com"
322 expected := "https://gateway.example.com/_/gateway/gemini/v1/models/generate"
323 if cfg.getGeminiURL() != expected {
324 t.Errorf("getGeminiURL with gateway should return %q, got %q", expected, cfg.getGeminiURL())
325 }
326
327 // Test other URL methods for completeness
328 if cfg.getAnthropicURL() != "https://gateway.example.com/_/gateway/anthropic/v1/messages" {
329 t.Error("getAnthropicURL did not return expected URL with gateway")
330 }
331
332 if cfg.getOpenAIURL() != "https://gateway.example.com/_/gateway/openai/v1" {
333 t.Error("getOpenAIURL did not return expected URL with gateway")
334 }
335
336 if cfg.getFireworksURL() != "https://gateway.example.com/_/gateway/fireworks/inference/v1" {
337 t.Error("getFireworksURL did not return expected URL with gateway")
338 }
339}
340
341func TestUseSimplifiedPatch(t *testing.T) {
342 // Test with a service that doesn't implement SimplifiedPatcher
343 mockService := &mockLLMService{}
344 logger := slog.Default()
345
346 loggingSvc := &loggingService{
347 service: mockService,
348 logger: logger,
349 modelID: "test-model",
350 provider: ProviderBuiltIn,
351 }
352
353 // Should return false since mockService doesn't implement SimplifiedPatcher
354 result := loggingSvc.UseSimplifiedPatch()
355 if result != false {
356 t.Errorf("UseSimplifiedPatch should return false for non-SimplifiedPatcher, got %t", result)
357 }
358
359 // Test with a service that implements SimplifiedPatcher
360 mockSimplifiedService := &mockSimplifiedLLMService{useSimplified: true}
361 loggingSvc2 := &loggingService{
362 service: mockSimplifiedService,
363 logger: logger,
364 modelID: "test-model-2",
365 provider: ProviderBuiltIn,
366 }
367
368 // Should return true since mockSimplifiedService implements SimplifiedPatcher and returns true
369 result = loggingSvc2.UseSimplifiedPatch()
370 if result != true {
371 t.Errorf("UseSimplifiedPatch should return true for SimplifiedPatcher returning true, got %t", result)
372 }
373}
374
375// mockSimplifiedLLMService implements llm.Service and llm.SimplifiedPatcher for testing
376type mockSimplifiedLLMService struct {
377 mockLLMService
378 useSimplified bool
379}
380
381func (m *mockSimplifiedLLMService) UseSimplifiedPatch() bool {
382 return m.useSimplified
383}
384
385func TestHTTPClientPassedToFactory(t *testing.T) {
386 // Test that HTTP client is passed to factory and used by services
387 cfg := &Config{
388 AnthropicAPIKey: "test-key",
389 }
390
391 // Create a custom HTTP client
392 customClient := &http.Client{}
393
394 // Test that claude factory accepts HTTP client
395 m := ByID("claude-opus-4.5")
396 if m == nil {
397 t.Fatal("claude-opus-4.5 model not found")
398 }
399
400 svc, err := m.Factory(cfg, customClient)
401 if err != nil {
402 t.Fatalf("Factory with custom HTTP client failed: %v", err)
403 }
404 if svc == nil {
405 t.Fatal("Factory returned nil service")
406 }
407}
408
409func TestGetModelSource(t *testing.T) {
410 tests := []struct {
411 name string
412 cfg *Config
413 modelID string
414 want string
415 }{
416 {
417 name: "anthropic with env var only",
418 cfg: &Config{AnthropicAPIKey: "test-key"},
419 modelID: "claude-opus-4.6",
420 want: "$ANTHROPIC_API_KEY",
421 },
422 {
423 name: "anthropic with gateway implicit key",
424 cfg: &Config{Gateway: "https://gateway.example.com", AnthropicAPIKey: "implicit"},
425 modelID: "claude-opus-4.6",
426 want: "exe.dev gateway",
427 },
428 {
429 name: "anthropic with gateway but explicit key",
430 cfg: &Config{Gateway: "https://gateway.example.com", AnthropicAPIKey: "actual-key"},
431 modelID: "claude-opus-4.6",
432 want: "$ANTHROPIC_API_KEY",
433 },
434 {
435 name: "fireworks with env var only",
436 cfg: &Config{FireworksAPIKey: "test-key"},
437 modelID: "qwen3-coder-fireworks",
438 want: "$FIREWORKS_API_KEY",
439 },
440 {
441 name: "fireworks with gateway implicit key",
442 cfg: &Config{Gateway: "https://gateway.example.com", FireworksAPIKey: "implicit"},
443 modelID: "qwen3-coder-fireworks",
444 want: "exe.dev gateway",
445 },
446 {
447 name: "openai with env var only",
448 cfg: &Config{OpenAIAPIKey: "test-key"},
449 modelID: "gpt-5.2-codex",
450 want: "$OPENAI_API_KEY",
451 },
452 {
453 name: "gemini with env var only",
454 cfg: &Config{GeminiAPIKey: "test-key"},
455 modelID: "gemini-3-pro",
456 want: "$GEMINI_API_KEY",
457 },
458 {
459 name: "predictable has no source",
460 cfg: &Config{},
461 modelID: "predictable",
462 want: "",
463 },
464 }
465
466 for _, tt := range tests {
467 t.Run(tt.name, func(t *testing.T) {
468 manager, err := NewManager(tt.cfg)
469 if err != nil {
470 t.Fatalf("NewManager failed: %v", err)
471 }
472
473 info := manager.GetModelInfo(tt.modelID)
474 if info == nil {
475 t.Fatalf("GetModelInfo(%q) returned nil", tt.modelID)
476 }
477 if info.Source != tt.want {
478 t.Errorf("GetModelInfo(%q).Source = %q, want %q", tt.modelID, info.Source, tt.want)
479 }
480 })
481 }
482}
483
484func TestGetAvailableModelsUnion(t *testing.T) {
485 // Test that GetAvailableModels returns both built-in and custom models
486 // This test just verifies the union behavior with built-in models only
487 // (testing with custom models requires a database)
488 cfg := &Config{
489 AnthropicAPIKey: "test-key",
490 FireworksAPIKey: "test-key",
491 }
492
493 manager, err := NewManager(cfg)
494 if err != nil {
495 t.Fatalf("NewManager failed: %v", err)
496 }
497
498 models := manager.GetAvailableModels()
499
500 // Should have anthropic models and fireworks models, plus predictable
501 expectedModels := []string{"claude-opus-4.6", "claude-opus-4.5", "qwen3-coder-fireworks", "glm-4p6-fireworks", "claude-sonnet-4.5", "claude-haiku-4.5", "predictable"}
502 for _, expected := range expectedModels {
503 found := false
504 for _, m := range models {
505 if m == expected {
506 found = true
507 break
508 }
509 }
510 if !found {
511 t.Errorf("Expected model %q not found in available models: %v", expected, models)
512 }
513 }
514}