1package models
2
3import (
4 "context"
5 "log/slog"
6 "net/http"
7 "testing"
8
9 "shelley.exe.dev/llm"
10)
11
12func TestAll(t *testing.T) {
13 models := All()
14 if len(models) == 0 {
15 t.Fatal("expected at least one model")
16 }
17
18 // Verify all models have required fields
19 for _, m := range models {
20 if m.ID == "" {
21 t.Errorf("model missing ID")
22 }
23 if m.Provider == "" {
24 t.Errorf("model %s missing Provider", m.ID)
25 }
26 if m.Factory == nil {
27 t.Errorf("model %s missing Factory", m.ID)
28 }
29 }
30}
31
32func TestByID(t *testing.T) {
33 tests := []struct {
34 id string
35 wantID string
36 wantNil bool
37 }{
38 {id: "qwen3-coder-fireworks", wantID: "qwen3-coder-fireworks", wantNil: false},
39 {id: "gpt-5.2-codex", wantID: "gpt-5.2-codex", wantNil: false},
40 {id: "claude-sonnet-4.5", wantID: "claude-sonnet-4.5", wantNil: false},
41 {id: "claude-haiku-4.5", wantID: "claude-haiku-4.5", wantNil: false},
42 {id: "claude-opus-4.5", wantID: "claude-opus-4.5", wantNil: false},
43 {id: "nonexistent", wantNil: true},
44 }
45
46 for _, tt := range tests {
47 t.Run(tt.id, func(t *testing.T) {
48 m := ByID(tt.id)
49 if tt.wantNil {
50 if m != nil {
51 t.Errorf("ByID(%q) = %v, want nil", tt.id, m)
52 }
53 } else {
54 if m == nil {
55 t.Fatalf("ByID(%q) = nil, want non-nil", tt.id)
56 }
57 if m.ID != tt.wantID {
58 t.Errorf("ByID(%q).ID = %q, want %q", tt.id, m.ID, tt.wantID)
59 }
60 }
61 })
62 }
63}
64
65func TestDefault(t *testing.T) {
66 d := Default()
67 if d.ID != "claude-opus-4.5" {
68 t.Errorf("Default().ID = %q, want %q", d.ID, "claude-opus-4.5")
69 }
70}
71
72func TestIDs(t *testing.T) {
73 ids := IDs()
74 if len(ids) == 0 {
75 t.Fatal("expected at least one model ID")
76 }
77
78 // Verify all IDs are unique
79 seen := make(map[string]bool)
80 for _, id := range ids {
81 if seen[id] {
82 t.Errorf("duplicate model ID: %s", id)
83 }
84 seen[id] = true
85 }
86}
87
88func TestFactory(t *testing.T) {
89 // Test that we can create services with empty config (should fail for most models)
90 cfg := &Config{}
91
92 // Predictable should work without any config
93 m := ByID("predictable")
94 if m == nil {
95 t.Fatal("predictable model not found")
96 }
97
98 svc, err := m.Factory(cfg, nil)
99 if err != nil {
100 t.Fatalf("predictable Factory() failed: %v", err)
101 }
102 if svc == nil {
103 t.Fatal("predictable Factory() returned nil service")
104 }
105}
106
107func TestManagerGetAvailableModelsOrder(t *testing.T) {
108 // Test that GetAvailableModels returns models in consistent order
109 cfg := &Config{}
110
111 // Create manager - should only have predictable model since no API keys
112 manager, err := NewManager(cfg)
113 if err != nil {
114 t.Fatalf("NewManager failed: %v", err)
115 }
116
117 // Get available models multiple times
118 firstCall := manager.GetAvailableModels()
119 secondCall := manager.GetAvailableModels()
120 thirdCall := manager.GetAvailableModels()
121
122 // Should return at least predictable model
123 if len(firstCall) == 0 {
124 t.Fatal("expected at least one model")
125 }
126
127 // All calls should return identical order
128 if len(firstCall) != len(secondCall) || len(firstCall) != len(thirdCall) {
129 t.Errorf("calls returned different lengths: %d, %d, %d", len(firstCall), len(secondCall), len(thirdCall))
130 }
131
132 for i := range firstCall {
133 if firstCall[i] != secondCall[i] {
134 t.Errorf("call 1 and 2 differ at index %d: %q vs %q", i, firstCall[i], secondCall[i])
135 }
136 if firstCall[i] != thirdCall[i] {
137 t.Errorf("call 1 and 3 differ at index %d: %q vs %q", i, firstCall[i], thirdCall[i])
138 }
139 }
140}
141
142func TestManagerGetAvailableModelsMatchesAllOrder(t *testing.T) {
143 // Test that available models are returned in the same order as All()
144 cfg := &Config{
145 AnthropicAPIKey: "test-key",
146 OpenAIAPIKey: "test-key",
147 GeminiAPIKey: "test-key",
148 FireworksAPIKey: "test-key",
149 }
150
151 manager, err := NewManager(cfg)
152 if err != nil {
153 t.Fatalf("NewManager failed: %v", err)
154 }
155
156 available := manager.GetAvailableModels()
157 all := All()
158
159 // Build expected order from All()
160 var expected []string
161 for _, m := range all {
162 if manager.HasModel(m.ID) {
163 expected = append(expected, m.ID)
164 }
165 }
166
167 // Should match
168 if len(available) != len(expected) {
169 t.Fatalf("available models count %d != expected count %d", len(available), len(expected))
170 }
171
172 for i := range available {
173 if available[i] != expected[i] {
174 t.Errorf("model at index %d: got %q, want %q", i, available[i], expected[i])
175 }
176 }
177}
178
179func TestLoggingService(t *testing.T) {
180 // Create a mock service for testing
181 mockService := &mockLLMService{}
182 logger := slog.Default()
183
184 loggingSvc := &loggingService{
185 service: mockService,
186 logger: logger,
187 modelID: "test-model",
188 provider: ProviderBuiltIn,
189 }
190
191 // Test Do method
192 ctx := context.Background()
193 request := &llm.Request{
194 Messages: []llm.Message{
195 llm.UserStringMessage("Hello"),
196 },
197 }
198
199 response, err := loggingSvc.Do(ctx, request)
200 if err != nil {
201 t.Errorf("Do returned unexpected error: %v", err)
202 }
203
204 if response == nil {
205 t.Error("Do returned nil response")
206 }
207
208 // Test TokenContextWindow
209 window := loggingSvc.TokenContextWindow()
210 if window != mockService.TokenContextWindow() {
211 t.Errorf("TokenContextWindow returned %d, expected %d", window, mockService.TokenContextWindow())
212 }
213
214 // Test MaxImageDimension
215 dimension := loggingSvc.MaxImageDimension()
216 if dimension != mockService.MaxImageDimension() {
217 t.Errorf("MaxImageDimension returned %d, expected %d", dimension, mockService.MaxImageDimension())
218 }
219
220 // Test UseSimplifiedPatch
221 useSimplified := loggingSvc.UseSimplifiedPatch()
222 if useSimplified != mockService.UseSimplifiedPatch() {
223 t.Errorf("UseSimplifiedPatch returned %t, expected %t", useSimplified, mockService.UseSimplifiedPatch())
224 }
225}
226
// mockLLMService is a test double standing in for an llm.Service.
type mockLLMService struct {
	// Values reported by TokenContextWindow and MaxImageDimension; when
	// left at zero those methods return fixed fallbacks instead.
	tokenContextWindow, maxImageDimension int

	// Value reported by UseSimplifiedPatch.
	useSimplifiedPatch bool
}
233
234func (m *mockLLMService) Do(ctx context.Context, request *llm.Request) (*llm.Response, error) {
235 return &llm.Response{
236 Content: llm.TextContent("Hello, world!"),
237 Usage: llm.Usage{
238 InputTokens: 10,
239 OutputTokens: 5,
240 CostUSD: 0.001,
241 },
242 }, nil
243}
244
245func (m *mockLLMService) TokenContextWindow() int {
246 if m.tokenContextWindow == 0 {
247 return 4096
248 }
249 return m.tokenContextWindow
250}
251
252func (m *mockLLMService) MaxImageDimension() int {
253 if m.maxImageDimension == 0 {
254 return 2048
255 }
256 return m.maxImageDimension
257}
258
259func (m *mockLLMService) UseSimplifiedPatch() bool {
260 return m.useSimplifiedPatch
261}
262
263func TestManagerGetService(t *testing.T) {
264 // Test with predictable model (no API keys needed)
265 cfg := &Config{}
266
267 manager, err := NewManager(cfg)
268 if err != nil {
269 t.Fatalf("NewManager failed: %v", err)
270 }
271
272 // Test getting predictable service (should work)
273 svc, err := manager.GetService("predictable")
274 if err != nil {
275 t.Errorf("GetService('predictable') failed: %v", err)
276 }
277 if svc == nil {
278 t.Error("GetService('predictable') returned nil service")
279 }
280
281 // Test getting non-existent service
282 _, err = manager.GetService("non-existent-model")
283 if err == nil {
284 t.Error("GetService('non-existent-model') should have failed but didn't")
285 }
286}
287
288func TestManagerHasModel(t *testing.T) {
289 cfg := &Config{}
290
291 manager, err := NewManager(cfg)
292 if err != nil {
293 t.Fatalf("NewManager failed: %v", err)
294 }
295
296 // Should have predictable model
297 if !manager.HasModel("predictable") {
298 t.Error("HasModel('predictable') should return true")
299 }
300
301 // Should not have models requiring API keys
302 if manager.HasModel("claude-opus-4.5") {
303 t.Error("HasModel('claude-opus-4.5') should return false without API key")
304 }
305
306 // Should not have non-existent model
307 if manager.HasModel("non-existent-model") {
308 t.Error("HasModel('non-existent-model') should return false")
309 }
310}
311
312func TestConfigGetURLMethods(t *testing.T) {
313 // Test getGeminiURL with no gateway
314 cfg := &Config{}
315 if cfg.getGeminiURL() != "" {
316 t.Errorf("getGeminiURL with no gateway should return empty string, got %q", cfg.getGeminiURL())
317 }
318
319 // Test getGeminiURL with gateway
320 cfg.Gateway = "https://gateway.example.com"
321 expected := "https://gateway.example.com/_/gateway/gemini/v1/models/generate"
322 if cfg.getGeminiURL() != expected {
323 t.Errorf("getGeminiURL with gateway should return %q, got %q", expected, cfg.getGeminiURL())
324 }
325
326 // Test other URL methods for completeness
327 if cfg.getAnthropicURL() != "https://gateway.example.com/_/gateway/anthropic/v1/messages" {
328 t.Error("getAnthropicURL did not return expected URL with gateway")
329 }
330
331 if cfg.getOpenAIURL() != "https://gateway.example.com/_/gateway/openai/v1" {
332 t.Error("getOpenAIURL did not return expected URL with gateway")
333 }
334
335 if cfg.getFireworksURL() != "https://gateway.example.com/_/gateway/fireworks/inference/v1" {
336 t.Error("getFireworksURL did not return expected URL with gateway")
337 }
338}
339
340func TestUseSimplifiedPatch(t *testing.T) {
341 // Test with a service that doesn't implement SimplifiedPatcher
342 mockService := &mockLLMService{}
343 logger := slog.Default()
344
345 loggingSvc := &loggingService{
346 service: mockService,
347 logger: logger,
348 modelID: "test-model",
349 provider: ProviderBuiltIn,
350 }
351
352 // Should return false since mockService doesn't implement SimplifiedPatcher
353 result := loggingSvc.UseSimplifiedPatch()
354 if result != false {
355 t.Errorf("UseSimplifiedPatch should return false for non-SimplifiedPatcher, got %t", result)
356 }
357
358 // Test with a service that implements SimplifiedPatcher
359 mockSimplifiedService := &mockSimplifiedLLMService{useSimplified: true}
360 loggingSvc2 := &loggingService{
361 service: mockSimplifiedService,
362 logger: logger,
363 modelID: "test-model-2",
364 provider: ProviderBuiltIn,
365 }
366
367 // Should return true since mockSimplifiedService implements SimplifiedPatcher and returns true
368 result = loggingSvc2.UseSimplifiedPatch()
369 if result != true {
370 t.Errorf("UseSimplifiedPatch should return true for SimplifiedPatcher returning true, got %t", result)
371 }
372}
373
374// mockSimplifiedLLMService implements llm.Service and llm.SimplifiedPatcher for testing
375type mockSimplifiedLLMService struct {
376 mockLLMService
377 useSimplified bool
378}
379
380func (m *mockSimplifiedLLMService) UseSimplifiedPatch() bool {
381 return m.useSimplified
382}
383
384func TestHTTPClientPassedToFactory(t *testing.T) {
385 // Test that HTTP client is passed to factory and used by services
386 cfg := &Config{
387 AnthropicAPIKey: "test-key",
388 }
389
390 // Create a custom HTTP client
391 customClient := &http.Client{}
392
393 // Test that claude factory accepts HTTP client
394 m := ByID("claude-opus-4.5")
395 if m == nil {
396 t.Fatal("claude-opus-4.5 model not found")
397 }
398
399 svc, err := m.Factory(cfg, customClient)
400 if err != nil {
401 t.Fatalf("Factory with custom HTTP client failed: %v", err)
402 }
403 if svc == nil {
404 t.Fatal("Factory returned nil service")
405 }
406}
407
408func TestGetModelSource(t *testing.T) {
409 tests := []struct {
410 name string
411 cfg *Config
412 modelID string
413 want string
414 }{
415 {
416 name: "anthropic with env var only",
417 cfg: &Config{AnthropicAPIKey: "test-key"},
418 modelID: "claude-opus-4.5",
419 want: "$ANTHROPIC_API_KEY",
420 },
421 {
422 name: "anthropic with gateway implicit key",
423 cfg: &Config{Gateway: "https://gateway.example.com", AnthropicAPIKey: "implicit"},
424 modelID: "claude-opus-4.5",
425 want: "exe.dev gateway",
426 },
427 {
428 name: "anthropic with gateway but explicit key",
429 cfg: &Config{Gateway: "https://gateway.example.com", AnthropicAPIKey: "actual-key"},
430 modelID: "claude-opus-4.5",
431 want: "$ANTHROPIC_API_KEY",
432 },
433 {
434 name: "fireworks with env var only",
435 cfg: &Config{FireworksAPIKey: "test-key"},
436 modelID: "qwen3-coder-fireworks",
437 want: "$FIREWORKS_API_KEY",
438 },
439 {
440 name: "fireworks with gateway implicit key",
441 cfg: &Config{Gateway: "https://gateway.example.com", FireworksAPIKey: "implicit"},
442 modelID: "qwen3-coder-fireworks",
443 want: "exe.dev gateway",
444 },
445 {
446 name: "openai with env var only",
447 cfg: &Config{OpenAIAPIKey: "test-key"},
448 modelID: "gpt-5.2-codex",
449 want: "$OPENAI_API_KEY",
450 },
451 {
452 name: "gemini with env var only",
453 cfg: &Config{GeminiAPIKey: "test-key"},
454 modelID: "gemini-3-pro",
455 want: "$GEMINI_API_KEY",
456 },
457 {
458 name: "predictable has no source",
459 cfg: &Config{},
460 modelID: "predictable",
461 want: "",
462 },
463 }
464
465 for _, tt := range tests {
466 t.Run(tt.name, func(t *testing.T) {
467 manager, err := NewManager(tt.cfg)
468 if err != nil {
469 t.Fatalf("NewManager failed: %v", err)
470 }
471
472 info := manager.GetModelInfo(tt.modelID)
473 if info == nil {
474 t.Fatalf("GetModelInfo(%q) returned nil", tt.modelID)
475 }
476 if info.Source != tt.want {
477 t.Errorf("GetModelInfo(%q).Source = %q, want %q", tt.modelID, info.Source, tt.want)
478 }
479 })
480 }
481}
482
483func TestGetAvailableModelsUnion(t *testing.T) {
484 // Test that GetAvailableModels returns both built-in and custom models
485 // This test just verifies the union behavior with built-in models only
486 // (testing with custom models requires a database)
487 cfg := &Config{
488 AnthropicAPIKey: "test-key",
489 FireworksAPIKey: "test-key",
490 }
491
492 manager, err := NewManager(cfg)
493 if err != nil {
494 t.Fatalf("NewManager failed: %v", err)
495 }
496
497 models := manager.GetAvailableModels()
498
499 // Should have anthropic models and fireworks models, plus predictable
500 expectedModels := []string{"claude-opus-4.5", "qwen3-coder-fireworks", "glm-4p6-fireworks", "claude-sonnet-4.5", "claude-haiku-4.5", "predictable"}
501 for _, expected := range expectedModels {
502 found := false
503 for _, m := range models {
504 if m == expected {
505 found = true
506 break
507 }
508 }
509 if !found {
510 t.Errorf("Expected model %q not found in available models: %v", expected, models)
511 }
512 }
513}