models_test.go

  1package models
  2
  3import (
  4	"context"
  5	"log/slog"
  6	"net/http"
  7	"testing"
  8
  9	"shelley.exe.dev/llm"
 10)
 11
 12func TestAll(t *testing.T) {
 13	models := All()
 14	if len(models) == 0 {
 15		t.Fatal("expected at least one model")
 16	}
 17
 18	// Verify all models have required fields
 19	for _, m := range models {
 20		if m.ID == "" {
 21			t.Errorf("model missing ID")
 22		}
 23		if m.Provider == "" {
 24			t.Errorf("model %s missing Provider", m.ID)
 25		}
 26		if m.Factory == nil {
 27			t.Errorf("model %s missing Factory", m.ID)
 28		}
 29	}
 30}
 31
 32func TestByID(t *testing.T) {
 33	tests := []struct {
 34		id      string
 35		wantID  string
 36		wantNil bool
 37	}{
 38		{id: "qwen3-coder-fireworks", wantID: "qwen3-coder-fireworks", wantNil: false},
 39		{id: "gpt-5.2-codex", wantID: "gpt-5.2-codex", wantNil: false},
 40		{id: "claude-sonnet-4.5", wantID: "claude-sonnet-4.5", wantNil: false},
 41		{id: "claude-haiku-4.5", wantID: "claude-haiku-4.5", wantNil: false},
 42		{id: "claude-opus-4.5", wantID: "claude-opus-4.5", wantNil: false},
 43		{id: "claude-opus-4.6", wantID: "claude-opus-4.6", wantNil: false},
 44		{id: "nonexistent", wantNil: true},
 45	}
 46
 47	for _, tt := range tests {
 48		t.Run(tt.id, func(t *testing.T) {
 49			m := ByID(tt.id)
 50			if tt.wantNil {
 51				if m != nil {
 52					t.Errorf("ByID(%q) = %v, want nil", tt.id, m)
 53				}
 54			} else {
 55				if m == nil {
 56					t.Fatalf("ByID(%q) = nil, want non-nil", tt.id)
 57				}
 58				if m.ID != tt.wantID {
 59					t.Errorf("ByID(%q).ID = %q, want %q", tt.id, m.ID, tt.wantID)
 60				}
 61			}
 62		})
 63	}
 64}
 65
 66func TestDefault(t *testing.T) {
 67	d := Default()
 68	if d.ID != "claude-opus-4.6" {
 69		t.Errorf("Default().ID = %q, want %q", d.ID, "claude-opus-4.6")
 70	}
 71}
 72
 73func TestIDs(t *testing.T) {
 74	ids := IDs()
 75	if len(ids) == 0 {
 76		t.Fatal("expected at least one model ID")
 77	}
 78
 79	// Verify all IDs are unique
 80	seen := make(map[string]bool)
 81	for _, id := range ids {
 82		if seen[id] {
 83			t.Errorf("duplicate model ID: %s", id)
 84		}
 85		seen[id] = true
 86	}
 87}
 88
 89func TestFactory(t *testing.T) {
 90	// Test that we can create services with empty config (should fail for most models)
 91	cfg := &Config{}
 92
 93	// Predictable should work without any config
 94	m := ByID("predictable")
 95	if m == nil {
 96		t.Fatal("predictable model not found")
 97	}
 98
 99	svc, err := m.Factory(cfg, nil)
100	if err != nil {
101		t.Fatalf("predictable Factory() failed: %v", err)
102	}
103	if svc == nil {
104		t.Fatal("predictable Factory() returned nil service")
105	}
106}
107
108func TestManagerGetAvailableModelsOrder(t *testing.T) {
109	// Test that GetAvailableModels returns models in consistent order
110	cfg := &Config{}
111
112	// Create manager - should only have predictable model since no API keys
113	manager, err := NewManager(cfg)
114	if err != nil {
115		t.Fatalf("NewManager failed: %v", err)
116	}
117
118	// Get available models multiple times
119	firstCall := manager.GetAvailableModels()
120	secondCall := manager.GetAvailableModels()
121	thirdCall := manager.GetAvailableModels()
122
123	// Should return at least predictable model
124	if len(firstCall) == 0 {
125		t.Fatal("expected at least one model")
126	}
127
128	// All calls should return identical order
129	if len(firstCall) != len(secondCall) || len(firstCall) != len(thirdCall) {
130		t.Errorf("calls returned different lengths: %d, %d, %d", len(firstCall), len(secondCall), len(thirdCall))
131	}
132
133	for i := range firstCall {
134		if firstCall[i] != secondCall[i] {
135			t.Errorf("call 1 and 2 differ at index %d: %q vs %q", i, firstCall[i], secondCall[i])
136		}
137		if firstCall[i] != thirdCall[i] {
138			t.Errorf("call 1 and 3 differ at index %d: %q vs %q", i, firstCall[i], thirdCall[i])
139		}
140	}
141}
142
143func TestManagerGetAvailableModelsMatchesAllOrder(t *testing.T) {
144	// Test that available models are returned in the same order as All()
145	cfg := &Config{
146		AnthropicAPIKey: "test-key",
147		OpenAIAPIKey:    "test-key",
148		GeminiAPIKey:    "test-key",
149		FireworksAPIKey: "test-key",
150	}
151
152	manager, err := NewManager(cfg)
153	if err != nil {
154		t.Fatalf("NewManager failed: %v", err)
155	}
156
157	available := manager.GetAvailableModels()
158	all := All()
159
160	// Build expected order from All()
161	var expected []string
162	for _, m := range all {
163		if manager.HasModel(m.ID) {
164			expected = append(expected, m.ID)
165		}
166	}
167
168	// Should match
169	if len(available) != len(expected) {
170		t.Fatalf("available models count %d != expected count %d", len(available), len(expected))
171	}
172
173	for i := range available {
174		if available[i] != expected[i] {
175			t.Errorf("model at index %d: got %q, want %q", i, available[i], expected[i])
176		}
177	}
178}
179
180func TestLoggingService(t *testing.T) {
181	// Create a mock service for testing
182	mockService := &mockLLMService{}
183	logger := slog.Default()
184
185	loggingSvc := &loggingService{
186		service:  mockService,
187		logger:   logger,
188		modelID:  "test-model",
189		provider: ProviderBuiltIn,
190	}
191
192	// Test Do method
193	ctx := context.Background()
194	request := &llm.Request{
195		Messages: []llm.Message{
196			llm.UserStringMessage("Hello"),
197		},
198	}
199
200	response, err := loggingSvc.Do(ctx, request)
201	if err != nil {
202		t.Errorf("Do returned unexpected error: %v", err)
203	}
204
205	if response == nil {
206		t.Error("Do returned nil response")
207	}
208
209	// Test TokenContextWindow
210	window := loggingSvc.TokenContextWindow()
211	if window != mockService.TokenContextWindow() {
212		t.Errorf("TokenContextWindow returned %d, expected %d", window, mockService.TokenContextWindow())
213	}
214
215	// Test MaxImageDimension
216	dimension := loggingSvc.MaxImageDimension()
217	if dimension != mockService.MaxImageDimension() {
218		t.Errorf("MaxImageDimension returned %d, expected %d", dimension, mockService.MaxImageDimension())
219	}
220
221	// Test UseSimplifiedPatch
222	useSimplified := loggingSvc.UseSimplifiedPatch()
223	if useSimplified != mockService.UseSimplifiedPatch() {
224		t.Errorf("UseSimplifiedPatch returned %t, expected %t", useSimplified, mockService.UseSimplifiedPatch())
225	}
226}
227
228// mockLLMService implements llm.Service for testing
229type mockLLMService struct {
230	tokenContextWindow int
231	maxImageDimension  int
232	useSimplifiedPatch bool
233}
234
235func (m *mockLLMService) Do(ctx context.Context, request *llm.Request) (*llm.Response, error) {
236	return &llm.Response{
237		Content: llm.TextContent("Hello, world!"),
238		Usage: llm.Usage{
239			InputTokens:  10,
240			OutputTokens: 5,
241			CostUSD:      0.001,
242		},
243	}, nil
244}
245
246func (m *mockLLMService) TokenContextWindow() int {
247	if m.tokenContextWindow == 0 {
248		return 4096
249	}
250	return m.tokenContextWindow
251}
252
253func (m *mockLLMService) MaxImageDimension() int {
254	if m.maxImageDimension == 0 {
255		return 2048
256	}
257	return m.maxImageDimension
258}
259
260func (m *mockLLMService) UseSimplifiedPatch() bool {
261	return m.useSimplifiedPatch
262}
263
264func TestManagerGetService(t *testing.T) {
265	// Test with predictable model (no API keys needed)
266	cfg := &Config{}
267
268	manager, err := NewManager(cfg)
269	if err != nil {
270		t.Fatalf("NewManager failed: %v", err)
271	}
272
273	// Test getting predictable service (should work)
274	svc, err := manager.GetService("predictable")
275	if err != nil {
276		t.Errorf("GetService('predictable') failed: %v", err)
277	}
278	if svc == nil {
279		t.Error("GetService('predictable') returned nil service")
280	}
281
282	// Test getting non-existent service
283	_, err = manager.GetService("non-existent-model")
284	if err == nil {
285		t.Error("GetService('non-existent-model') should have failed but didn't")
286	}
287}
288
289func TestManagerHasModel(t *testing.T) {
290	cfg := &Config{}
291
292	manager, err := NewManager(cfg)
293	if err != nil {
294		t.Fatalf("NewManager failed: %v", err)
295	}
296
297	// Should have predictable model
298	if !manager.HasModel("predictable") {
299		t.Error("HasModel('predictable') should return true")
300	}
301
302	// Should not have models requiring API keys
303	if manager.HasModel("claude-opus-4.6") {
304		t.Error("HasModel('claude-opus-4.6') should return false without API key")
305	}
306
307	// Should not have non-existent model
308	if manager.HasModel("non-existent-model") {
309		t.Error("HasModel('non-existent-model') should return false")
310	}
311}
312
313func TestConfigGetURLMethods(t *testing.T) {
314	// Test getGeminiURL with no gateway
315	cfg := &Config{}
316	if cfg.getGeminiURL() != "" {
317		t.Errorf("getGeminiURL with no gateway should return empty string, got %q", cfg.getGeminiURL())
318	}
319
320	// Test getGeminiURL with gateway
321	cfg.Gateway = "https://gateway.example.com"
322	expected := "https://gateway.example.com/_/gateway/gemini/v1/models/generate"
323	if cfg.getGeminiURL() != expected {
324		t.Errorf("getGeminiURL with gateway should return %q, got %q", expected, cfg.getGeminiURL())
325	}
326
327	// Test other URL methods for completeness
328	if cfg.getAnthropicURL() != "https://gateway.example.com/_/gateway/anthropic/v1/messages" {
329		t.Error("getAnthropicURL did not return expected URL with gateway")
330	}
331
332	if cfg.getOpenAIURL() != "https://gateway.example.com/_/gateway/openai/v1" {
333		t.Error("getOpenAIURL did not return expected URL with gateway")
334	}
335
336	if cfg.getFireworksURL() != "https://gateway.example.com/_/gateway/fireworks/inference/v1" {
337		t.Error("getFireworksURL did not return expected URL with gateway")
338	}
339}
340
341func TestUseSimplifiedPatch(t *testing.T) {
342	// Test with a service that doesn't implement SimplifiedPatcher
343	mockService := &mockLLMService{}
344	logger := slog.Default()
345
346	loggingSvc := &loggingService{
347		service:  mockService,
348		logger:   logger,
349		modelID:  "test-model",
350		provider: ProviderBuiltIn,
351	}
352
353	// Should return false since mockService doesn't implement SimplifiedPatcher
354	result := loggingSvc.UseSimplifiedPatch()
355	if result != false {
356		t.Errorf("UseSimplifiedPatch should return false for non-SimplifiedPatcher, got %t", result)
357	}
358
359	// Test with a service that implements SimplifiedPatcher
360	mockSimplifiedService := &mockSimplifiedLLMService{useSimplified: true}
361	loggingSvc2 := &loggingService{
362		service:  mockSimplifiedService,
363		logger:   logger,
364		modelID:  "test-model-2",
365		provider: ProviderBuiltIn,
366	}
367
368	// Should return true since mockSimplifiedService implements SimplifiedPatcher and returns true
369	result = loggingSvc2.UseSimplifiedPatch()
370	if result != true {
371		t.Errorf("UseSimplifiedPatch should return true for SimplifiedPatcher returning true, got %t", result)
372	}
373}
374
375// mockSimplifiedLLMService implements llm.Service and llm.SimplifiedPatcher for testing
376type mockSimplifiedLLMService struct {
377	mockLLMService
378	useSimplified bool
379}
380
381func (m *mockSimplifiedLLMService) UseSimplifiedPatch() bool {
382	return m.useSimplified
383}
384
385func TestHTTPClientPassedToFactory(t *testing.T) {
386	// Test that HTTP client is passed to factory and used by services
387	cfg := &Config{
388		AnthropicAPIKey: "test-key",
389	}
390
391	// Create a custom HTTP client
392	customClient := &http.Client{}
393
394	// Test that claude factory accepts HTTP client
395	m := ByID("claude-opus-4.5")
396	if m == nil {
397		t.Fatal("claude-opus-4.5 model not found")
398	}
399
400	svc, err := m.Factory(cfg, customClient)
401	if err != nil {
402		t.Fatalf("Factory with custom HTTP client failed: %v", err)
403	}
404	if svc == nil {
405		t.Fatal("Factory returned nil service")
406	}
407}
408
409func TestGetModelSource(t *testing.T) {
410	tests := []struct {
411		name    string
412		cfg     *Config
413		modelID string
414		want    string
415	}{
416		{
417			name:    "anthropic with env var only",
418			cfg:     &Config{AnthropicAPIKey: "test-key"},
419			modelID: "claude-opus-4.6",
420			want:    "$ANTHROPIC_API_KEY",
421		},
422		{
423			name:    "anthropic with gateway implicit key",
424			cfg:     &Config{Gateway: "https://gateway.example.com", AnthropicAPIKey: "implicit"},
425			modelID: "claude-opus-4.6",
426			want:    "exe.dev gateway",
427		},
428		{
429			name:    "anthropic with gateway but explicit key",
430			cfg:     &Config{Gateway: "https://gateway.example.com", AnthropicAPIKey: "actual-key"},
431			modelID: "claude-opus-4.6",
432			want:    "$ANTHROPIC_API_KEY",
433		},
434		{
435			name:    "fireworks with env var only",
436			cfg:     &Config{FireworksAPIKey: "test-key"},
437			modelID: "qwen3-coder-fireworks",
438			want:    "$FIREWORKS_API_KEY",
439		},
440		{
441			name:    "fireworks with gateway implicit key",
442			cfg:     &Config{Gateway: "https://gateway.example.com", FireworksAPIKey: "implicit"},
443			modelID: "qwen3-coder-fireworks",
444			want:    "exe.dev gateway",
445		},
446		{
447			name:    "openai with env var only",
448			cfg:     &Config{OpenAIAPIKey: "test-key"},
449			modelID: "gpt-5.2-codex",
450			want:    "$OPENAI_API_KEY",
451		},
452		{
453			name:    "gemini with env var only",
454			cfg:     &Config{GeminiAPIKey: "test-key"},
455			modelID: "gemini-3-pro",
456			want:    "$GEMINI_API_KEY",
457		},
458		{
459			name:    "predictable has no source",
460			cfg:     &Config{},
461			modelID: "predictable",
462			want:    "",
463		},
464	}
465
466	for _, tt := range tests {
467		t.Run(tt.name, func(t *testing.T) {
468			manager, err := NewManager(tt.cfg)
469			if err != nil {
470				t.Fatalf("NewManager failed: %v", err)
471			}
472
473			info := manager.GetModelInfo(tt.modelID)
474			if info == nil {
475				t.Fatalf("GetModelInfo(%q) returned nil", tt.modelID)
476			}
477			if info.Source != tt.want {
478				t.Errorf("GetModelInfo(%q).Source = %q, want %q", tt.modelID, info.Source, tt.want)
479			}
480		})
481	}
482}
483
484func TestGetAvailableModelsUnion(t *testing.T) {
485	// Test that GetAvailableModels returns both built-in and custom models
486	// This test just verifies the union behavior with built-in models only
487	// (testing with custom models requires a database)
488	cfg := &Config{
489		AnthropicAPIKey: "test-key",
490		FireworksAPIKey: "test-key",
491	}
492
493	manager, err := NewManager(cfg)
494	if err != nil {
495		t.Fatalf("NewManager failed: %v", err)
496	}
497
498	models := manager.GetAvailableModels()
499
500	// Should have anthropic models and fireworks models, plus predictable
501	expectedModels := []string{"claude-opus-4.6", "claude-opus-4.5", "qwen3-coder-fireworks", "glm-4p6-fireworks", "claude-sonnet-4.5", "claude-haiku-4.5", "predictable"}
502	for _, expected := range expectedModels {
503		found := false
504		for _, m := range models {
505			if m == expected {
506				found = true
507				break
508			}
509		}
510		if !found {
511			t.Errorf("Expected model %q not found in available models: %v", expected, models)
512		}
513	}
514}