models_test.go

  1package models
  2
  3import (
  4	"context"
  5	"log/slog"
  6	"net/http"
  7	"testing"
  8
  9	"shelley.exe.dev/llm"
 10)
 11
 12func TestAll(t *testing.T) {
 13	models := All()
 14	if len(models) == 0 {
 15		t.Fatal("expected at least one model")
 16	}
 17
 18	// Verify all models have required fields
 19	for _, m := range models {
 20		if m.ID == "" {
 21			t.Errorf("model missing ID")
 22		}
 23		if m.Provider == "" {
 24			t.Errorf("model %s missing Provider", m.ID)
 25		}
 26		if m.Factory == nil {
 27			t.Errorf("model %s missing Factory", m.ID)
 28		}
 29	}
 30}
 31
 32func TestByID(t *testing.T) {
 33	tests := []struct {
 34		id      string
 35		wantID  string
 36		wantNil bool
 37	}{
 38		{id: "qwen3-coder-fireworks", wantID: "qwen3-coder-fireworks", wantNil: false},
 39		{id: "gpt-5.2-codex", wantID: "gpt-5.2-codex", wantNil: false},
 40		{id: "claude-sonnet-4.5", wantID: "claude-sonnet-4.5", wantNil: false},
 41		{id: "claude-haiku-4.5", wantID: "claude-haiku-4.5", wantNil: false},
 42		{id: "claude-opus-4.5", wantID: "claude-opus-4.5", wantNil: false},
 43		{id: "nonexistent", wantNil: true},
 44	}
 45
 46	for _, tt := range tests {
 47		t.Run(tt.id, func(t *testing.T) {
 48			m := ByID(tt.id)
 49			if tt.wantNil {
 50				if m != nil {
 51					t.Errorf("ByID(%q) = %v, want nil", tt.id, m)
 52				}
 53			} else {
 54				if m == nil {
 55					t.Fatalf("ByID(%q) = nil, want non-nil", tt.id)
 56				}
 57				if m.ID != tt.wantID {
 58					t.Errorf("ByID(%q).ID = %q, want %q", tt.id, m.ID, tt.wantID)
 59				}
 60			}
 61		})
 62	}
 63}
 64
 65func TestDefault(t *testing.T) {
 66	d := Default()
 67	if d.ID != "claude-opus-4.5" {
 68		t.Errorf("Default().ID = %q, want %q", d.ID, "claude-opus-4.5")
 69	}
 70}
 71
 72func TestIDs(t *testing.T) {
 73	ids := IDs()
 74	if len(ids) == 0 {
 75		t.Fatal("expected at least one model ID")
 76	}
 77
 78	// Verify all IDs are unique
 79	seen := make(map[string]bool)
 80	for _, id := range ids {
 81		if seen[id] {
 82			t.Errorf("duplicate model ID: %s", id)
 83		}
 84		seen[id] = true
 85	}
 86}
 87
 88func TestFactory(t *testing.T) {
 89	// Test that we can create services with empty config (should fail for most models)
 90	cfg := &Config{}
 91
 92	// Predictable should work without any config
 93	m := ByID("predictable")
 94	if m == nil {
 95		t.Fatal("predictable model not found")
 96	}
 97
 98	svc, err := m.Factory(cfg, nil)
 99	if err != nil {
100		t.Fatalf("predictable Factory() failed: %v", err)
101	}
102	if svc == nil {
103		t.Fatal("predictable Factory() returned nil service")
104	}
105}
106
107func TestManagerGetAvailableModelsOrder(t *testing.T) {
108	// Test that GetAvailableModels returns models in consistent order
109	cfg := &Config{}
110
111	// Create manager - should only have predictable model since no API keys
112	manager, err := NewManager(cfg)
113	if err != nil {
114		t.Fatalf("NewManager failed: %v", err)
115	}
116
117	// Get available models multiple times
118	firstCall := manager.GetAvailableModels()
119	secondCall := manager.GetAvailableModels()
120	thirdCall := manager.GetAvailableModels()
121
122	// Should return at least predictable model
123	if len(firstCall) == 0 {
124		t.Fatal("expected at least one model")
125	}
126
127	// All calls should return identical order
128	if len(firstCall) != len(secondCall) || len(firstCall) != len(thirdCall) {
129		t.Errorf("calls returned different lengths: %d, %d, %d", len(firstCall), len(secondCall), len(thirdCall))
130	}
131
132	for i := range firstCall {
133		if firstCall[i] != secondCall[i] {
134			t.Errorf("call 1 and 2 differ at index %d: %q vs %q", i, firstCall[i], secondCall[i])
135		}
136		if firstCall[i] != thirdCall[i] {
137			t.Errorf("call 1 and 3 differ at index %d: %q vs %q", i, firstCall[i], thirdCall[i])
138		}
139	}
140}
141
142func TestManagerGetAvailableModelsMatchesAllOrder(t *testing.T) {
143	// Test that available models are returned in the same order as All()
144	cfg := &Config{
145		AnthropicAPIKey: "test-key",
146		OpenAIAPIKey:    "test-key",
147		GeminiAPIKey:    "test-key",
148		FireworksAPIKey: "test-key",
149	}
150
151	manager, err := NewManager(cfg)
152	if err != nil {
153		t.Fatalf("NewManager failed: %v", err)
154	}
155
156	available := manager.GetAvailableModels()
157	all := All()
158
159	// Build expected order from All()
160	var expected []string
161	for _, m := range all {
162		if manager.HasModel(m.ID) {
163			expected = append(expected, m.ID)
164		}
165	}
166
167	// Should match
168	if len(available) != len(expected) {
169		t.Fatalf("available models count %d != expected count %d", len(available), len(expected))
170	}
171
172	for i := range available {
173		if available[i] != expected[i] {
174			t.Errorf("model at index %d: got %q, want %q", i, available[i], expected[i])
175		}
176	}
177}
178
179func TestLoggingService(t *testing.T) {
180	// Create a mock service for testing
181	mockService := &mockLLMService{}
182	logger := slog.Default()
183
184	loggingSvc := &loggingService{
185		service:  mockService,
186		logger:   logger,
187		modelID:  "test-model",
188		provider: ProviderBuiltIn,
189	}
190
191	// Test Do method
192	ctx := context.Background()
193	request := &llm.Request{
194		Messages: []llm.Message{
195			llm.UserStringMessage("Hello"),
196		},
197	}
198
199	response, err := loggingSvc.Do(ctx, request)
200	if err != nil {
201		t.Errorf("Do returned unexpected error: %v", err)
202	}
203
204	if response == nil {
205		t.Error("Do returned nil response")
206	}
207
208	// Test TokenContextWindow
209	window := loggingSvc.TokenContextWindow()
210	if window != mockService.TokenContextWindow() {
211		t.Errorf("TokenContextWindow returned %d, expected %d", window, mockService.TokenContextWindow())
212	}
213
214	// Test MaxImageDimension
215	dimension := loggingSvc.MaxImageDimension()
216	if dimension != mockService.MaxImageDimension() {
217		t.Errorf("MaxImageDimension returned %d, expected %d", dimension, mockService.MaxImageDimension())
218	}
219
220	// Test UseSimplifiedPatch
221	useSimplified := loggingSvc.UseSimplifiedPatch()
222	if useSimplified != mockService.UseSimplifiedPatch() {
223		t.Errorf("UseSimplifiedPatch returned %t, expected %t", useSimplified, mockService.UseSimplifiedPatch())
224	}
225}
226
// mockLLMService is a test double used as the wrapped service in the
// loggingService tests. Its fields configure what the accessor methods
// report; the zero value is usable as-is.
type mockLLMService struct {
	// tokenContextWindow is returned by TokenContextWindow; when left at 0
	// that method reports 4096 instead.
	tokenContextWindow int
	// maxImageDimension is returned by MaxImageDimension; when left at 0
	// that method reports 2048 instead.
	maxImageDimension int
	// useSimplifiedPatch is returned verbatim by UseSimplifiedPatch.
	useSimplifiedPatch bool
}
233
234func (m *mockLLMService) Do(ctx context.Context, request *llm.Request) (*llm.Response, error) {
235	return &llm.Response{
236		Content: llm.TextContent("Hello, world!"),
237		Usage: llm.Usage{
238			InputTokens:  10,
239			OutputTokens: 5,
240			CostUSD:      0.001,
241		},
242	}, nil
243}
244
245func (m *mockLLMService) TokenContextWindow() int {
246	if m.tokenContextWindow == 0 {
247		return 4096
248	}
249	return m.tokenContextWindow
250}
251
252func (m *mockLLMService) MaxImageDimension() int {
253	if m.maxImageDimension == 0 {
254		return 2048
255	}
256	return m.maxImageDimension
257}
258
// UseSimplifiedPatch returns the useSimplifiedPatch field unchanged, so a
// zero-value mock reports false.
func (m *mockLLMService) UseSimplifiedPatch() bool {
	return m.useSimplifiedPatch
}
262
263func TestManagerGetService(t *testing.T) {
264	// Test with predictable model (no API keys needed)
265	cfg := &Config{}
266
267	manager, err := NewManager(cfg)
268	if err != nil {
269		t.Fatalf("NewManager failed: %v", err)
270	}
271
272	// Test getting predictable service (should work)
273	svc, err := manager.GetService("predictable")
274	if err != nil {
275		t.Errorf("GetService('predictable') failed: %v", err)
276	}
277	if svc == nil {
278		t.Error("GetService('predictable') returned nil service")
279	}
280
281	// Test getting non-existent service
282	_, err = manager.GetService("non-existent-model")
283	if err == nil {
284		t.Error("GetService('non-existent-model') should have failed but didn't")
285	}
286}
287
288func TestManagerHasModel(t *testing.T) {
289	cfg := &Config{}
290
291	manager, err := NewManager(cfg)
292	if err != nil {
293		t.Fatalf("NewManager failed: %v", err)
294	}
295
296	// Should have predictable model
297	if !manager.HasModel("predictable") {
298		t.Error("HasModel('predictable') should return true")
299	}
300
301	// Should not have models requiring API keys
302	if manager.HasModel("claude-opus-4.5") {
303		t.Error("HasModel('claude-opus-4.5') should return false without API key")
304	}
305
306	// Should not have non-existent model
307	if manager.HasModel("non-existent-model") {
308		t.Error("HasModel('non-existent-model') should return false")
309	}
310}
311
312func TestConfigGetURLMethods(t *testing.T) {
313	// Test getGeminiURL with no gateway
314	cfg := &Config{}
315	if cfg.getGeminiURL() != "" {
316		t.Errorf("getGeminiURL with no gateway should return empty string, got %q", cfg.getGeminiURL())
317	}
318
319	// Test getGeminiURL with gateway
320	cfg.Gateway = "https://gateway.example.com"
321	expected := "https://gateway.example.com/_/gateway/gemini/v1/models/generate"
322	if cfg.getGeminiURL() != expected {
323		t.Errorf("getGeminiURL with gateway should return %q, got %q", expected, cfg.getGeminiURL())
324	}
325
326	// Test other URL methods for completeness
327	if cfg.getAnthropicURL() != "https://gateway.example.com/_/gateway/anthropic/v1/messages" {
328		t.Error("getAnthropicURL did not return expected URL with gateway")
329	}
330
331	if cfg.getOpenAIURL() != "https://gateway.example.com/_/gateway/openai/v1" {
332		t.Error("getOpenAIURL did not return expected URL with gateway")
333	}
334
335	if cfg.getFireworksURL() != "https://gateway.example.com/_/gateway/fireworks/inference/v1" {
336		t.Error("getFireworksURL did not return expected URL with gateway")
337	}
338}
339
340func TestUseSimplifiedPatch(t *testing.T) {
341	// Test with a service that doesn't implement SimplifiedPatcher
342	mockService := &mockLLMService{}
343	logger := slog.Default()
344
345	loggingSvc := &loggingService{
346		service:  mockService,
347		logger:   logger,
348		modelID:  "test-model",
349		provider: ProviderBuiltIn,
350	}
351
352	// Should return false since mockService doesn't implement SimplifiedPatcher
353	result := loggingSvc.UseSimplifiedPatch()
354	if result != false {
355		t.Errorf("UseSimplifiedPatch should return false for non-SimplifiedPatcher, got %t", result)
356	}
357
358	// Test with a service that implements SimplifiedPatcher
359	mockSimplifiedService := &mockSimplifiedLLMService{useSimplified: true}
360	loggingSvc2 := &loggingService{
361		service:  mockSimplifiedService,
362		logger:   logger,
363		modelID:  "test-model-2",
364		provider: ProviderBuiltIn,
365	}
366
367	// Should return true since mockSimplifiedService implements SimplifiedPatcher and returns true
368	result = loggingSvc2.UseSimplifiedPatch()
369	if result != true {
370		t.Errorf("UseSimplifiedPatch should return true for SimplifiedPatcher returning true, got %t", result)
371	}
372}
373
// mockSimplifiedLLMService is a test double that embeds mockLLMService and
// adds its own switch for the UseSimplifiedPatch answer.
type mockSimplifiedLLMService struct {
	// mockLLMService is embedded so its methods are promoted to this type.
	mockLLMService
	// useSimplified is the value this type's UseSimplifiedPatch returns.
	useSimplified bool
}
379
// UseSimplifiedPatch returns the useSimplified field. Being declared on the
// outer type, it takes precedence over the method of the same name promoted
// from the embedded mockLLMService.
func (m *mockSimplifiedLLMService) UseSimplifiedPatch() bool {
	return m.useSimplified
}
383
384func TestHTTPClientPassedToFactory(t *testing.T) {
385	// Test that HTTP client is passed to factory and used by services
386	cfg := &Config{
387		AnthropicAPIKey: "test-key",
388	}
389
390	// Create a custom HTTP client
391	customClient := &http.Client{}
392
393	// Test that claude factory accepts HTTP client
394	m := ByID("claude-opus-4.5")
395	if m == nil {
396		t.Fatal("claude-opus-4.5 model not found")
397	}
398
399	svc, err := m.Factory(cfg, customClient)
400	if err != nil {
401		t.Fatalf("Factory with custom HTTP client failed: %v", err)
402	}
403	if svc == nil {
404		t.Fatal("Factory returned nil service")
405	}
406}
407
408func TestGetModelSource(t *testing.T) {
409	tests := []struct {
410		name    string
411		cfg     *Config
412		modelID string
413		want    string
414	}{
415		{
416			name:    "anthropic with env var only",
417			cfg:     &Config{AnthropicAPIKey: "test-key"},
418			modelID: "claude-opus-4.5",
419			want:    "$ANTHROPIC_API_KEY",
420		},
421		{
422			name:    "anthropic with gateway implicit key",
423			cfg:     &Config{Gateway: "https://gateway.example.com", AnthropicAPIKey: "implicit"},
424			modelID: "claude-opus-4.5",
425			want:    "exe.dev gateway",
426		},
427		{
428			name:    "anthropic with gateway but explicit key",
429			cfg:     &Config{Gateway: "https://gateway.example.com", AnthropicAPIKey: "actual-key"},
430			modelID: "claude-opus-4.5",
431			want:    "$ANTHROPIC_API_KEY",
432		},
433		{
434			name:    "fireworks with env var only",
435			cfg:     &Config{FireworksAPIKey: "test-key"},
436			modelID: "qwen3-coder-fireworks",
437			want:    "$FIREWORKS_API_KEY",
438		},
439		{
440			name:    "fireworks with gateway implicit key",
441			cfg:     &Config{Gateway: "https://gateway.example.com", FireworksAPIKey: "implicit"},
442			modelID: "qwen3-coder-fireworks",
443			want:    "exe.dev gateway",
444		},
445		{
446			name:    "openai with env var only",
447			cfg:     &Config{OpenAIAPIKey: "test-key"},
448			modelID: "gpt-5.2-codex",
449			want:    "$OPENAI_API_KEY",
450		},
451		{
452			name:    "gemini with env var only",
453			cfg:     &Config{GeminiAPIKey: "test-key"},
454			modelID: "gemini-3-pro",
455			want:    "$GEMINI_API_KEY",
456		},
457		{
458			name:    "predictable has no source",
459			cfg:     &Config{},
460			modelID: "predictable",
461			want:    "",
462		},
463	}
464
465	for _, tt := range tests {
466		t.Run(tt.name, func(t *testing.T) {
467			manager, err := NewManager(tt.cfg)
468			if err != nil {
469				t.Fatalf("NewManager failed: %v", err)
470			}
471
472			info := manager.GetModelInfo(tt.modelID)
473			if info == nil {
474				t.Fatalf("GetModelInfo(%q) returned nil", tt.modelID)
475			}
476			if info.Source != tt.want {
477				t.Errorf("GetModelInfo(%q).Source = %q, want %q", tt.modelID, info.Source, tt.want)
478			}
479		})
480	}
481}
482
483func TestGetAvailableModelsUnion(t *testing.T) {
484	// Test that GetAvailableModels returns both built-in and custom models
485	// This test just verifies the union behavior with built-in models only
486	// (testing with custom models requires a database)
487	cfg := &Config{
488		AnthropicAPIKey: "test-key",
489		FireworksAPIKey: "test-key",
490	}
491
492	manager, err := NewManager(cfg)
493	if err != nil {
494		t.Fatalf("NewManager failed: %v", err)
495	}
496
497	models := manager.GetAvailableModels()
498
499	// Should have anthropic models and fireworks models, plus predictable
500	expectedModels := []string{"claude-opus-4.5", "qwen3-coder-fireworks", "glm-4p6-fireworks", "claude-sonnet-4.5", "claude-haiku-4.5", "predictable"}
501	for _, expected := range expectedModels {
502		found := false
503		for _, m := range models {
504			if m == expected {
505				found = true
506				break
507			}
508		}
509		if !found {
510			t.Errorf("Expected model %q not found in available models: %v", expected, models)
511		}
512	}
513}