1package config
2
3import (
4 "encoding/json"
5 "os"
6 "path/filepath"
7 "sync"
8 "testing"
9
10 "github.com/charmbracelet/crush/internal/fur/provider"
11 "github.com/stretchr/testify/assert"
12 "github.com/stretchr/testify/require"
13)
14
// reset returns the package to a clean slate before a test runs.
//
// It unsets every environment variable that could affect config loading
// (provider API keys, Google Cloud / VertexAI settings, AWS credentials and
// region, and CRUSH_DEV_DEBUG). It then clears the package-level singleton
// state so that the next Init call performs a fresh load.
func reset() {
	for _, name := range []string{
		// API Keys
		"ANTHROPIC_API_KEY",
		"OPENAI_API_KEY",
		"GEMINI_API_KEY",
		"XAI_API_KEY",
		"OPENROUTER_API_KEY",

		// Google Cloud / VertexAI
		"GOOGLE_GENAI_USE_VERTEXAI",
		"GOOGLE_CLOUD_PROJECT",
		"GOOGLE_CLOUD_LOCATION",

		// AWS Credentials
		"AWS_ACCESS_KEY_ID",
		"AWS_SECRET_ACCESS_KEY",
		"AWS_REGION",
		"AWS_DEFAULT_REGION",
		"AWS_PROFILE",
		"AWS_DEFAULT_PROFILE",
		"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
		"AWS_CONTAINER_CREDENTIALS_FULL_URI",

		// Other
		"CRUSH_DEV_DEBUG",
	} {
		os.Unsetenv(name)
	}

	// Forget the previously loaded config so Init runs its loader again.
	instance = nil
	once = sync.Once{}
	testConfigDir = ""
	cwd = ""
}
54
55// Core Configuration Loading Tests
56
// TestInit_ValidWorkingDirectory checks that Init succeeds for an empty
// working directory, records it as the working directory, and falls back to
// the default data directory and context paths.
func TestInit_ValidWorkingDirectory(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	cfg, err := Init(cwdDir, false)

	require.NoError(t, err)
	// require (not assert): the checks below dereference cfg, so a nil config
	// must stop the test instead of panicking.
	require.NotNil(t, cfg)
	assert.Equal(t, cwdDir, WorkingDirectory())
	assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory)
	assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths)
}
70
71func TestInit_WithDebugFlag(t *testing.T) {
72 reset()
73 testConfigDir = t.TempDir()
74 cwdDir := t.TempDir()
75
76 cfg, err := Init(cwdDir, true)
77
78 require.NoError(t, err)
79 assert.True(t, cfg.Options.Debug)
80}
81
// TestInit_SingletonBehavior verifies that calling Init twice returns the very
// same *Config instead of loading a second copy.
func TestInit_SingletonBehavior(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	first, firstErr := Init(workDir, false)
	second, secondErr := Init(workDir, false)

	require.NoError(t, firstErr)
	require.NoError(t, secondErr)
	// Pointer identity, not deep equality: both calls must share one instance.
	assert.Same(t, first, second)
}
94
// TestGet_BeforeInitialization verifies that Get panics when Init has not been
// called since the last reset.
func TestGet_BeforeInitialization(t *testing.T) {
	reset()

	callGet := func() { Get() }
	assert.Panics(t, callGet)
}
102
// TestGet_AfterInitialization verifies that Get hands back the same instance
// that Init produced.
func TestGet_AfterInitialization(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	initialized, err := Init(workDir, false)
	require.NoError(t, err)

	assert.Same(t, initialized, Get())
}
114
// TestLoadConfig_NoConfigFiles verifies that with no global or local
// crush.json, and no provider environment variables, no providers are
// configured and the default context paths are used.
func TestLoadConfig_NoConfigFiles(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	cfg, err := Init(workDir, false)
	require.NoError(t, err)

	// reset() cleared the API-key variables, so nothing can be auto-detected.
	assert.Empty(t, cfg.Providers)
	assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths)
}
126
// TestLoadConfig_OnlyGlobalConfig verifies that a crush.json in the global
// config directory is loaded when no local file exists.
func TestLoadConfig_OnlyGlobalConfig(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	globalPath := filepath.Join(testConfigDir, "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(globalPath), 0o755))

	payload, err := json.Marshal(Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderOpenAI: {
				ID:           provider.InferenceProviderOpenAI,
				APIKey:       "test-key",
				ProviderType: provider.TypeOpenAI,
			},
		},
		Options: Options{
			ContextPaths: []string{"custom-context.md"},
		},
	})
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(globalPath, payload, 0o644))

	cfg, err := Init(workDir, false)
	require.NoError(t, err)

	assert.Len(t, cfg.Providers, 1)
	assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
	assert.Contains(t, cfg.Options.ContextPaths, "custom-context.md")
}
160
// TestLoadConfig_OnlyLocalConfig verifies that a crush.json in the working
// directory is loaded when there is no global file.
func TestLoadConfig_OnlyLocalConfig(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	payload, err := json.Marshal(Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderAnthropic: {
				ID:           provider.InferenceProviderAnthropic,
				APIKey:       "local-key",
				ProviderType: provider.TypeAnthropic,
			},
		},
		Options: Options{
			TUI: TUIOptions{CompactMode: true},
		},
	})
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(filepath.Join(workDir, "crush.json"), payload, 0o644))

	cfg, err := Init(workDir, false)
	require.NoError(t, err)

	assert.Len(t, cfg.Providers, 1)
	assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
	assert.True(t, cfg.Options.TUI.CompactMode)
}
192
// TestLoadConfig_BothGlobalAndLocal verifies how the global and local files
// are layered. The local file overrides a global provider field, adds a new
// provider, and contributes options. Context paths from both files are kept.
func TestLoadConfig_BothGlobalAndLocal(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	// writeJSON marshals v and writes it to path, failing the test on any error.
	writeJSON := func(path string, v any) {
		t.Helper()
		raw, err := json.Marshal(v)
		require.NoError(t, err)
		require.NoError(t, os.WriteFile(path, raw, 0o644))
	}

	// Global layer: an OpenAI provider and one context path.
	globalPath := filepath.Join(testConfigDir, "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(globalPath), 0o755))
	writeJSON(globalPath, Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderOpenAI: {
				ID:           provider.InferenceProviderOpenAI,
				APIKey:       "global-key",
				ProviderType: provider.TypeOpenAI,
			},
		},
		Options: Options{
			ContextPaths: []string{"global-context.md"},
		},
	})

	// Local layer: overrides the OpenAI key, adds Anthropic, and sets options.
	writeJSON(filepath.Join(workDir, "crush.json"), Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderOpenAI: {
				APIKey: "local-key", // Override global
			},
			provider.InferenceProviderAnthropic: {
				ID:           provider.InferenceProviderAnthropic,
				APIKey:       "anthropic-key",
				ProviderType: provider.TypeAnthropic,
			},
		},
		Options: Options{
			ContextPaths: []string{"local-context.md"},
			TUI:          TUIOptions{CompactMode: true},
		},
	})

	cfg, err := Init(workDir, false)
	require.NoError(t, err)
	assert.Len(t, cfg.Providers, 2)

	// The local key wins over the global one.
	assert.Equal(t, "local-key", cfg.Providers[provider.InferenceProviderOpenAI].APIKey)
	// The local file contributes a provider the global one lacked.
	assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
	// Context paths from both layers survive the merge.
	for _, path := range []string{"global-context.md", "local-context.md"} {
		assert.Contains(t, cfg.Options.ContextPaths, path)
	}
	assert.True(t, cfg.Options.TUI.CompactMode)
}
258
// TestLoadConfig_MalformedGlobalJSON verifies that an unparsable global
// crush.json makes Init return an error.
func TestLoadConfig_MalformedGlobalJSON(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	globalPath := filepath.Join(testConfigDir, "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(globalPath), 0o755))
	broken := []byte(`{invalid json`)
	require.NoError(t, os.WriteFile(globalPath, broken, 0o644))

	_, initErr := Init(workDir, false)
	assert.Error(t, initErr)
}
272
// TestLoadConfig_MalformedLocalJSON verifies that an unparsable crush.json in
// the working directory makes Init return an error.
func TestLoadConfig_MalformedLocalJSON(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	broken := []byte(`{invalid json`)
	require.NoError(t, os.WriteFile(filepath.Join(workDir, "crush.json"), broken, 0o644))

	_, initErr := Init(workDir, false)
	assert.Error(t, initErr)
}
285
// TestConfigWithoutEnv verifies that with no provider environment variables
// and no config files, no providers are configured.
func TestConfigWithoutEnv(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	cfg, err := Init(cwdDir, false)
	// Check the error so a failing Init is reported directly, not as a nil
	// dereference on cfg.Providers below.
	require.NoError(t, err)
	assert.Len(t, cfg.Providers, 0)
}
294
// TestConfigWithEnv verifies that setting the five provider API-key variables
// yields exactly five configured providers.
func TestConfigWithEnv(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	// t.Setenv restores each variable when the test ends, so the keys cannot
	// leak into tests that run afterwards.
	t.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
	t.Setenv("OPENAI_API_KEY", "test-openai-key")
	t.Setenv("GEMINI_API_KEY", "test-gemini-key")
	t.Setenv("XAI_API_KEY", "test-xai-key")
	t.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")

	cfg, err := Init(cwdDir, false)
	// Check the error so a failing Init is reported directly, not as a nil
	// dereference on cfg.Providers below.
	require.NoError(t, err)
	assert.Len(t, cfg.Providers, 5)
}
309
310// Environment Variable Tests
311
// TestEnvVars_NoEnvironmentVariables verifies that a clean environment
// produces no auto-configured providers.
func TestEnvVars_NoEnvironmentVariables(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	cfg, err := Init(workDir, false)
	require.NoError(t, err)

	assert.Empty(t, cfg.Providers)
}
322
// TestEnvVars_AllSupportedAPIKeys verifies that each supported API-key
// variable configures its provider with the key and expected provider type.
// It also checks that OpenRouter is configured as an OpenAI-type provider with
// the OpenRouter base URL.
func TestEnvVars_AllSupportedAPIKeys(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	// Set all supported API keys. t.Setenv restores them when the test ends,
	// so they cannot leak into later tests.
	t.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
	t.Setenv("OPENAI_API_KEY", "test-openai-key")
	t.Setenv("GEMINI_API_KEY", "test-gemini-key")
	t.Setenv("XAI_API_KEY", "test-xai-key")
	t.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")

	cfg, err := Init(cwdDir, false)

	require.NoError(t, err)
	assert.Len(t, cfg.Providers, 5)

	// Verify each provider is configured correctly
	anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic]
	assert.Equal(t, "test-anthropic-key", anthropicProvider.APIKey)
	assert.Equal(t, provider.TypeAnthropic, anthropicProvider.ProviderType)

	openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
	assert.Equal(t, "test-openai-key", openaiProvider.APIKey)
	assert.Equal(t, provider.TypeOpenAI, openaiProvider.ProviderType)

	geminiProvider := cfg.Providers[provider.InferenceProviderGemini]
	assert.Equal(t, "test-gemini-key", geminiProvider.APIKey)
	assert.Equal(t, provider.TypeGemini, geminiProvider.ProviderType)

	xaiProvider := cfg.Providers[provider.InferenceProviderXAI]
	assert.Equal(t, "test-xai-key", xaiProvider.APIKey)
	assert.Equal(t, provider.TypeXAI, xaiProvider.ProviderType)

	openrouterProvider := cfg.Providers[provider.InferenceProviderOpenRouter]
	assert.Equal(t, "test-openrouter-key", openrouterProvider.APIKey)
	assert.Equal(t, provider.TypeOpenAI, openrouterProvider.ProviderType)
	assert.Equal(t, "https://openrouter.ai/api/v1", openrouterProvider.BaseURL)
}
362
// TestEnvVars_PartialEnvironmentVariables verifies that only providers whose
// API-key variables are set get configured.
func TestEnvVars_PartialEnvironmentVariables(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	// Set only some API keys; t.Setenv restores them when the test ends.
	t.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
	t.Setenv("OPENAI_API_KEY", "test-openai-key")

	cfg, err := Init(cwdDir, false)

	require.NoError(t, err)
	assert.Len(t, cfg.Providers, 2)
	assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
	assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
	assert.NotContains(t, cfg.Providers, provider.InferenceProviderGemini)
}
380
// TestEnvVars_VertexAIConfiguration verifies that GOOGLE_GENAI_USE_VERTEXAI
// together with the project and location variables configures a VertexAI
// provider. The project and location must be carried in ExtraParams.
func TestEnvVars_VertexAIConfiguration(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	// Set VertexAI environment variables; t.Setenv restores them at test end.
	t.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
	t.Setenv("GOOGLE_CLOUD_PROJECT", "test-project")
	t.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1")

	cfg, err := Init(cwdDir, false)

	require.NoError(t, err)
	// Stop here if the provider is missing; otherwise the checks below would
	// only report confusing zero-value mismatches.
	require.Contains(t, cfg.Providers, provider.InferenceProviderVertexAI)

	vertexProvider := cfg.Providers[provider.InferenceProviderVertexAI]
	assert.Equal(t, provider.TypeVertexAI, vertexProvider.ProviderType)
	assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
	assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
}
401
// TestEnvVars_VertexAIWithoutUseFlag verifies that the Google Cloud project
// and location alone do not enable VertexAI without GOOGLE_GENAI_USE_VERTEXAI.
func TestEnvVars_VertexAIWithoutUseFlag(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	// Set Google Cloud vars but not the use flag; t.Setenv restores them.
	t.Setenv("GOOGLE_CLOUD_PROJECT", "test-project")
	t.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1")

	cfg, err := Init(cwdDir, false)

	require.NoError(t, err)
	assert.NotContains(t, cfg.Providers, provider.InferenceProviderVertexAI)
}
416
// TestEnvVars_AWSBedrockWithAccessKeys verifies that AWS access keys plus
// AWS_DEFAULT_REGION configure a Bedrock provider with that region.
func TestEnvVars_AWSBedrockWithAccessKeys(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	// Set AWS credentials; t.Setenv restores them when the test ends.
	t.Setenv("AWS_ACCESS_KEY_ID", "test-access-key")
	t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
	t.Setenv("AWS_DEFAULT_REGION", "us-east-1")

	cfg, err := Init(cwdDir, false)

	require.NoError(t, err)
	// Stop early if Bedrock is absent rather than comparing zero values.
	require.Contains(t, cfg.Providers, provider.InferenceProviderBedrock)

	bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
	assert.Equal(t, provider.TypeBedrock, bedrockProvider.ProviderType)
	assert.Equal(t, "us-east-1", bedrockProvider.ExtraParams["region"])
}
436
// TestEnvVars_AWSBedrockWithProfile verifies that an AWS profile plus
// AWS_REGION configure a Bedrock provider with that region.
func TestEnvVars_AWSBedrockWithProfile(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	// Set AWS profile; t.Setenv restores it when the test ends.
	t.Setenv("AWS_PROFILE", "test-profile")
	t.Setenv("AWS_REGION", "eu-west-1")

	cfg, err := Init(cwdDir, false)

	require.NoError(t, err)
	// Stop early if Bedrock is absent rather than comparing zero values.
	require.Contains(t, cfg.Providers, provider.InferenceProviderBedrock)

	bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
	assert.Equal(t, "eu-west-1", bedrockProvider.ExtraParams["region"])
}
454
// TestEnvVars_AWSBedrockWithContainerCredentials verifies that the container
// credentials URI plus a region configure a Bedrock provider.
func TestEnvVars_AWSBedrockWithContainerCredentials(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	// Set AWS container credentials; t.Setenv restores them at test end.
	t.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test")
	t.Setenv("AWS_DEFAULT_REGION", "ap-southeast-1")

	cfg, err := Init(cwdDir, false)

	require.NoError(t, err)
	assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock)
}
469
// TestEnvVars_AWSBedrockRegionPriority verifies that AWS_DEFAULT_REGION takes
// priority over AWS_REGION when both are set.
func TestEnvVars_AWSBedrockRegionPriority(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	// Set both region variables - AWS_DEFAULT_REGION should take priority.
	// t.Setenv restores every variable when the test ends.
	t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
	t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
	t.Setenv("AWS_DEFAULT_REGION", "us-west-2")
	t.Setenv("AWS_REGION", "us-east-1")

	cfg, err := Init(cwdDir, false)

	require.NoError(t, err)
	// A missing provider must fail as such, not as a "" vs region mismatch.
	require.Contains(t, cfg.Providers, provider.InferenceProviderBedrock)
	bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
	assert.Equal(t, "us-west-2", bedrockProvider.ExtraParams["region"])
}
487
// TestEnvVars_AWSBedrockFallbackRegion verifies that AWS_REGION is used when
// AWS_DEFAULT_REGION is not set.
func TestEnvVars_AWSBedrockFallbackRegion(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	// Set only AWS_REGION (not AWS_DEFAULT_REGION); t.Setenv restores them.
	t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
	t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
	t.Setenv("AWS_REGION", "us-east-1")

	cfg, err := Init(cwdDir, false)

	require.NoError(t, err)
	// A missing provider must fail as such, not as a "" vs region mismatch.
	require.Contains(t, cfg.Providers, provider.InferenceProviderBedrock)
	bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
	assert.Equal(t, "us-east-1", bedrockProvider.ExtraParams["region"])
}
504
// TestEnvVars_NoAWSCredentials verifies that Bedrock is not configured when no
// AWS-related environment variables are present.
func TestEnvVars_NoAWSCredentials(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	// reset() already cleared every AWS variable; nothing is set here.
	cfg, err := Init(workDir, false)
	require.NoError(t, err)

	assert.NotContains(t, cfg.Providers, provider.InferenceProviderBedrock)
}
516
// TestEnvVars_CustomEnvironmentVariables verifies that the Anthropic API key
// is resolved from ANTHROPIC_API_KEY into the configured provider.
func TestEnvVars_CustomEnvironmentVariables(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	// Test that environment variables are properly resolved from provider
	// definitions. t.Setenv restores the variable when the test ends.
	t.Setenv("ANTHROPIC_API_KEY", "resolved-anthropic-key")

	cfg, err := Init(cwdDir, false)

	require.NoError(t, err)
	// Require the provider instead of guarding on its presence. A guard lets
	// the test pass vacuously when the variable is ignored. This file's other
	// tests, e.g. TestEnvVars_PartialEnvironmentVariables, already expect
	// ANTHROPIC_API_KEY to configure the Anthropic provider.
	require.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
	anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic]
	assert.Equal(t, "resolved-anthropic-key", anthropicProvider.APIKey)
}
536
// TestEnvVars_CombinedEnvironmentVariables verifies that API-key, VertexAI and
// AWS variables set together configure all corresponding providers.
func TestEnvVars_CombinedEnvironmentVariables(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	// Set multiple types of environment variables; t.Setenv restores them
	// when the test ends.
	t.Setenv("ANTHROPIC_API_KEY", "test-anthropic")
	t.Setenv("OPENAI_API_KEY", "test-openai")
	t.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
	t.Setenv("GOOGLE_CLOUD_PROJECT", "test-project")
	t.Setenv("AWS_ACCESS_KEY_ID", "test-aws-key")
	t.Setenv("AWS_SECRET_ACCESS_KEY", "test-aws-secret")
	t.Setenv("AWS_DEFAULT_REGION", "us-west-1")

	cfg, err := Init(cwdDir, false)

	require.NoError(t, err)

	// Should have API key providers + VertexAI + Bedrock
	expectedProviders := []provider.InferenceProvider{
		provider.InferenceProviderAnthropic,
		provider.InferenceProviderOpenAI,
		provider.InferenceProviderVertexAI,
		provider.InferenceProviderBedrock,
	}

	for _, expectedProvider := range expectedProviders {
		assert.Contains(t, cfg.Providers, expectedProvider)
	}
}
567
// TestHasAWSCredentials_AccessKeys verifies that an access key ID plus secret
// counts as AWS credentials.
func TestHasAWSCredentials_AccessKeys(t *testing.T) {
	reset()

	// t.Setenv restores the variables when the test ends.
	t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
	t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")

	assert.True(t, hasAWSCredentials())
}
576
// TestHasAWSCredentials_Profile verifies that AWS_PROFILE alone counts as AWS
// credentials.
func TestHasAWSCredentials_Profile(t *testing.T) {
	reset()

	// t.Setenv restores the variable when the test ends.
	t.Setenv("AWS_PROFILE", "test-profile")

	assert.True(t, hasAWSCredentials())
}
584
// TestHasAWSCredentials_DefaultProfile verifies that AWS_DEFAULT_PROFILE alone
// counts as AWS credentials.
func TestHasAWSCredentials_DefaultProfile(t *testing.T) {
	reset()

	// t.Setenv restores the variable when the test ends.
	t.Setenv("AWS_DEFAULT_PROFILE", "default")

	assert.True(t, hasAWSCredentials())
}
592
// TestHasAWSCredentials_Region verifies that AWS_REGION alone is treated as a
// sign of AWS credentials.
func TestHasAWSCredentials_Region(t *testing.T) {
	reset()

	// t.Setenv restores the variable when the test ends.
	t.Setenv("AWS_REGION", "us-east-1")

	assert.True(t, hasAWSCredentials())
}
600
// TestHasAWSCredentials_ContainerCredentials verifies that the container
// credentials relative URI counts as AWS credentials.
func TestHasAWSCredentials_ContainerCredentials(t *testing.T) {
	reset()

	// t.Setenv restores the variable when the test ends.
	t.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test")

	assert.True(t, hasAWSCredentials())
}
608
// TestHasAWSCredentials_NoCredentials verifies that a cleared environment
// reports no AWS credentials.
func TestHasAWSCredentials_NoCredentials(t *testing.T) {
	reset()

	found := hasAWSCredentials()
	assert.False(t, found)
}
614
615// Provider Configuration Tests
616
// TestProviderMerging_GlobalToBase verifies that a provider defined only in
// the global config keeps its key, default models and model list.
func TestProviderMerging_GlobalToBase(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	globalPath := filepath.Join(testConfigDir, "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(globalPath), 0o755))

	payload, err := json.Marshal(Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderOpenAI: {
				ID:                provider.InferenceProviderOpenAI,
				APIKey:            "global-openai-key",
				ProviderType:      provider.TypeOpenAI,
				DefaultLargeModel: "gpt-4",
				DefaultSmallModel: "gpt-3.5-turbo",
				Models: []Model{
					{
						ID:               "gpt-4",
						Name:             "GPT-4",
						ContextWindow:    8192,
						DefaultMaxTokens: 4096,
					},
				},
			},
		},
	})
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(globalPath, payload, 0o644))

	cfg, err := Init(workDir, false)
	require.NoError(t, err)
	assert.Len(t, cfg.Providers, 1)

	openai := cfg.Providers[provider.InferenceProviderOpenAI]
	assert.Equal(t, "global-openai-key", openai.APIKey)
	assert.Equal(t, "gpt-4", openai.DefaultLargeModel)
	assert.Equal(t, "gpt-3.5-turbo", openai.DefaultSmallModel)
	assert.Len(t, openai.Models, 1)
}
660
// TestProviderMerging_LocalToBase verifies that a provider defined only in the
// local config keeps its key and default large model.
func TestProviderMerging_LocalToBase(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	payload, err := json.Marshal(Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderAnthropic: {
				ID:                provider.InferenceProviderAnthropic,
				APIKey:            "local-anthropic-key",
				ProviderType:      provider.TypeAnthropic,
				DefaultLargeModel: "claude-3-opus",
			},
		},
	})
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(filepath.Join(workDir, "crush.json"), payload, 0o644))

	cfg, err := Init(workDir, false)
	require.NoError(t, err)
	assert.Len(t, cfg.Providers, 1)

	anthropic := cfg.Providers[provider.InferenceProviderAnthropic]
	assert.Equal(t, "local-anthropic-key", anthropic.APIKey)
	assert.Equal(t, "claude-3-opus", anthropic.DefaultLargeModel)
}
692
// TestProviderMerging_ConflictingSettings verifies field-level merging for a
// provider present in both files. Fields set locally win, and fields the local
// file leaves empty keep their global values.
func TestProviderMerging_ConflictingSettings(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	globalPath := filepath.Join(testConfigDir, "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(globalPath), 0o755))

	globalPayload, err := json.Marshal(Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderOpenAI: {
				ID:                provider.InferenceProviderOpenAI,
				APIKey:            "global-key",
				ProviderType:      provider.TypeOpenAI,
				DefaultLargeModel: "gpt-4",
				DefaultSmallModel: "gpt-3.5-turbo",
			},
		},
	})
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(globalPath, globalPayload, 0o644))

	localPayload, err := json.Marshal(Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderOpenAI: {
				APIKey:            "local-key",
				DefaultLargeModel: "gpt-4-turbo",
				// Disabled is deliberately left unset here; per the original
				// note, setting it in this merge path triggered a nil pointer.
			},
		},
	})
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(filepath.Join(workDir, "crush.json"), localPayload, 0o644))

	cfg, err := Init(workDir, false)
	require.NoError(t, err)

	openai := cfg.Providers[provider.InferenceProviderOpenAI]
	// Local overrides global...
	assert.Equal(t, "local-key", openai.APIKey)
	assert.Equal(t, "gpt-4-turbo", openai.DefaultLargeModel)
	assert.False(t, openai.Disabled)
	// ...while fields the local file left empty keep their global values.
	assert.Equal(t, "gpt-3.5-turbo", openai.DefaultSmallModel)
}
745
// TestProviderMerging_CustomVsKnownProviders verifies that a local config
// cannot change BaseURL or ProviderType of a known provider. A custom
// provider, by contrast, takes the local BaseURL and ProviderType and still
// keeps its global API key.
func TestProviderMerging_CustomVsKnownProviders(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	customID := provider.InferenceProvider("custom-provider")

	global := Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			// Known provider - some fields should not be overrideable
			provider.InferenceProviderOpenAI: {
				ID:           provider.InferenceProviderOpenAI,
				APIKey:       "openai-key",
				BaseURL:      "should-not-override",
				ProviderType: provider.TypeAnthropic, // Should not override
			},
			// Custom provider - all fields should be configurable
			customID: {
				ID:           customID,
				APIKey:       "custom-key",
				BaseURL:      "https://custom.api.com",
				ProviderType: provider.TypeOpenAI,
			},
		},
	}
	local := Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderOpenAI: {
				BaseURL:      "https://should-not-change.com",
				ProviderType: provider.TypeGemini, // Should not change
			},
			customID: {
				BaseURL:      "https://updated-custom.api.com",
				ProviderType: provider.TypeOpenAI,
			},
		},
	}

	globalPath := filepath.Join(testConfigDir, "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(globalPath), 0o755))
	globalPayload, err := json.Marshal(global)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(globalPath, globalPayload, 0o644))

	localPayload, err := json.Marshal(local)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(filepath.Join(workDir, "crush.json"), localPayload, 0o644))

	cfg, err := Init(workDir, false)
	require.NoError(t, err)

	// The local file must not rewrite the known provider's endpoint or type.
	known := cfg.Providers[provider.InferenceProviderOpenAI]
	assert.NotEqual(t, "https://should-not-change.com", known.BaseURL)
	assert.NotEqual(t, provider.TypeGemini, known.ProviderType)

	// The custom provider accepts local overrides but keeps its global key.
	custom := cfg.Providers[customID]
	assert.Equal(t, "custom-key", custom.APIKey)
	assert.Equal(t, "https://updated-custom.api.com", custom.BaseURL)
	assert.Equal(t, provider.TypeOpenAI, custom.ProviderType)
}
812
// TestProviderValidation_CustomProviderMissingBaseURL verifies that a custom
// provider without a BaseURL is dropped from the loaded config.
func TestProviderValidation_CustomProviderMissingBaseURL(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	customID := provider.InferenceProvider("custom-provider")

	globalPath := filepath.Join(testConfigDir, "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(globalPath), 0o755))

	payload, err := json.Marshal(Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			customID: {
				ID:           customID,
				APIKey:       "custom-key",
				ProviderType: provider.TypeOpenAI,
				// Missing BaseURL
			},
		},
	})
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(globalPath, payload, 0o644))

	cfg, err := Init(workDir, false)
	require.NoError(t, err)

	// Validation failure filters the provider out instead of erroring Init.
	assert.NotContains(t, cfg.Providers, customID)
}
844
// TestProviderValidation_CustomProviderMissingAPIKey verifies that a custom
// provider without an API key is dropped from the loaded config.
func TestProviderValidation_CustomProviderMissingAPIKey(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	customID := provider.InferenceProvider("custom-provider")

	globalPath := filepath.Join(testConfigDir, "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(globalPath), 0o755))

	payload, err := json.Marshal(Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			customID: {
				ID:           customID,
				BaseURL:      "https://custom.api.com",
				ProviderType: provider.TypeOpenAI,
				// Missing APIKey
			},
		},
	})
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(globalPath, payload, 0o644))

	cfg, err := Init(workDir, false)
	require.NoError(t, err)

	assert.NotContains(t, cfg.Providers, customID)
}
874
// TestProviderValidation_CustomProviderInvalidType verifies that a custom
// provider with an unrecognized ProviderType is dropped from the loaded config.
func TestProviderValidation_CustomProviderInvalidType(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	customID := provider.InferenceProvider("custom-provider")

	globalPath := filepath.Join(testConfigDir, "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(globalPath), 0o755))

	payload, err := json.Marshal(Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			customID: {
				ID:           customID,
				APIKey:       "custom-key",
				BaseURL:      "https://custom.api.com",
				ProviderType: provider.Type("invalid-type"),
			},
		},
	})
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(globalPath, payload, 0o644))

	cfg, err := Init(workDir, false)
	require.NoError(t, err)

	assert.NotContains(t, cfg.Providers, customID)
}
904
// TestProviderValidation_KnownProviderValid verifies that a known provider is
// accepted without an explicit BaseURL.
func TestProviderValidation_KnownProviderValid(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	globalPath := filepath.Join(testConfigDir, "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(globalPath), 0o755))

	payload, err := json.Marshal(Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderOpenAI: {
				ID:           provider.InferenceProviderOpenAI,
				APIKey:       "openai-key",
				ProviderType: provider.TypeOpenAI,
				// BaseURL not required for known providers
			},
		},
	})
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(globalPath, payload, 0o644))

	cfg, err := Init(workDir, false)
	require.NoError(t, err)

	assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
}
932
// TestProviderValidation_DisabledProvider verifies that a disabled provider is
// kept in the config with its Disabled flag set, rather than removed.
func TestProviderValidation_DisabledProvider(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	globalPath := filepath.Join(testConfigDir, "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(globalPath), 0o755))

	payload, err := json.Marshal(Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderOpenAI: {
				ID:           provider.InferenceProviderOpenAI,
				APIKey:       "openai-key",
				ProviderType: provider.TypeOpenAI,
				Disabled:     true,
			},
		},
	})
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(globalPath, payload, 0o644))

	cfg, err := Init(workDir, false)
	require.NoError(t, err)

	assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
	openai := cfg.Providers[provider.InferenceProviderOpenAI]
	assert.True(t, openai.Disabled)
}
962
// TestProviderModels_AddingNewModels verifies that models from the local
// config are appended to the provider's global model list, not substituted.
func TestProviderModels_AddingNewModels(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	workDir := t.TempDir()

	global := Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderOpenAI: {
				ID:           provider.InferenceProviderOpenAI,
				APIKey:       "openai-key",
				ProviderType: provider.TypeOpenAI,
				Models: []Model{
					{
						ID:               "gpt-4",
						Name:             "GPT-4",
						ContextWindow:    8192,
						DefaultMaxTokens: 4096,
					},
				},
			},
		},
	}
	local := Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderOpenAI: {
				Models: []Model{
					{
						ID:               "gpt-4-turbo",
						Name:             "GPT-4 Turbo",
						ContextWindow:    128000,
						DefaultMaxTokens: 4096,
					},
				},
			},
		},
	}

	globalPath := filepath.Join(testConfigDir, "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(globalPath), 0o755))
	globalPayload, err := json.Marshal(global)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(globalPath, globalPayload, 0o644))

	localPayload, err := json.Marshal(local)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(filepath.Join(workDir, "crush.json"), localPayload, 0o644))

	cfg, err := Init(workDir, false)
	require.NoError(t, err)

	models := cfg.Providers[provider.InferenceProviderOpenAI].Models
	assert.Len(t, models, 2) // one from each layer

	ids := make([]string, 0, len(models))
	for _, m := range models {
		ids = append(ids, m.ID)
	}
	assert.Contains(t, ids, "gpt-4")
	assert.Contains(t, ids, "gpt-4-turbo")
}
1026
// TestProviderModels_DuplicateModelHandling verifies that a local model with
// the same ID as a global one is not added a second time. The original
// (global) definition is the one that is kept.
func TestProviderModels_DuplicateModelHandling(t *testing.T) {
	reset()
	testConfigDir = t.TempDir()
	cwdDir := t.TempDir()

	globalConfig := Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderOpenAI: {
				ID:           provider.InferenceProviderOpenAI,
				APIKey:       "openai-key",
				ProviderType: provider.TypeOpenAI,
				Models: []Model{
					{
						ID:               "gpt-4",
						Name:             "GPT-4",
						ContextWindow:    8192,
						DefaultMaxTokens: 4096,
					},
				},
			},
		},
	}

	localConfig := Config{
		Providers: map[provider.InferenceProvider]ProviderConfig{
			provider.InferenceProviderOpenAI: {
				Models: []Model{
					{
						ID:               "gpt-4", // Same ID as global
						Name:             "GPT-4 Updated",
						ContextWindow:    16384,
						DefaultMaxTokens: 8192,
					},
				},
			},
		},
	}

	configPath := filepath.Join(testConfigDir, "crush.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
	data, err := json.Marshal(globalConfig)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(configPath, data, 0o644))

	localConfigPath := filepath.Join(cwdDir, "crush.json")
	data, err = json.Marshal(localConfig)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))

	cfg, err := Init(cwdDir, false)

	require.NoError(t, err)

	openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
	// require, not assert: Models[0] below would panic with index out of range
	// on an empty slice instead of reporting a test failure.
	require.Len(t, openaiProvider.Models, 1) // Should not duplicate

	// Should keep the original model (global config)
	model := openaiProvider.Models[0]
	assert.Equal(t, "gpt-4", model.ID)
	assert.Equal(t, "GPT-4", model.Name)              // Original name
	assert.Equal(t, int64(8192), model.ContextWindow) // Original context window
}
1089
1090func TestProviderModels_ModelCostAndCapabilities(t *testing.T) {
1091 reset()
1092 testConfigDir = t.TempDir()
1093 cwdDir := t.TempDir()
1094
1095 globalConfig := Config{
1096 Providers: map[provider.InferenceProvider]ProviderConfig{
1097 provider.InferenceProviderOpenAI: {
1098 ID: provider.InferenceProviderOpenAI,
1099 APIKey: "openai-key",
1100 ProviderType: provider.TypeOpenAI,
1101 Models: []Model{
1102 {
1103 ID: "gpt-4",
1104 Name: "GPT-4",
1105 CostPer1MIn: 30.0,
1106 CostPer1MOut: 60.0,
1107 CostPer1MInCached: 15.0,
1108 CostPer1MOutCached: 30.0,
1109 ContextWindow: 8192,
1110 DefaultMaxTokens: 4096,
1111 CanReason: true,
1112 ReasoningEffort: "medium",
1113 SupportsImages: true,
1114 },
1115 },
1116 },
1117 },
1118 }
1119
1120 configPath := filepath.Join(testConfigDir, "crush.json")
1121 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1122 data, err := json.Marshal(globalConfig)
1123 require.NoError(t, err)
1124 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1125
1126 cfg, err := Init(cwdDir, false)
1127
1128 require.NoError(t, err)
1129
1130 openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
1131 require.Len(t, openaiProvider.Models, 1)
1132
1133 model := openaiProvider.Models[0]
1134 assert.Equal(t, 30.0, model.CostPer1MIn)
1135 assert.Equal(t, 60.0, model.CostPer1MOut)
1136 assert.Equal(t, 15.0, model.CostPer1MInCached)
1137 assert.Equal(t, 30.0, model.CostPer1MOutCached)
1138 assert.True(t, model.CanReason)
1139 assert.Equal(t, "medium", model.ReasoningEffort)
1140 assert.True(t, model.SupportsImages)
1141}
1142
1143// Agent Configuration Tests
1144
1145func TestDefaultAgents_CoderAgent(t *testing.T) {
1146 reset()
1147 testConfigDir = t.TempDir()
1148 cwdDir := t.TempDir()
1149
1150 // Set up a provider so we can test agent configuration
1151 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1152
1153 cfg, err := Init(cwdDir, false)
1154
1155 require.NoError(t, err)
1156 assert.Contains(t, cfg.Agents, AgentCoder)
1157
1158 coderAgent := cfg.Agents[AgentCoder]
1159 assert.Equal(t, AgentCoder, coderAgent.ID)
1160 assert.Equal(t, "Coder", coderAgent.Name)
1161 assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description)
1162 assert.Equal(t, LargeModel, coderAgent.Model)
1163 assert.False(t, coderAgent.Disabled)
1164 assert.Equal(t, cfg.Options.ContextPaths, coderAgent.ContextPaths)
1165 // Coder agent should have all tools available (nil means all tools)
1166 assert.Nil(t, coderAgent.AllowedTools)
1167}
1168
1169func TestDefaultAgents_TaskAgent(t *testing.T) {
1170 reset()
1171 testConfigDir = t.TempDir()
1172 cwdDir := t.TempDir()
1173
1174 // Set up a provider so we can test agent configuration
1175 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1176
1177 cfg, err := Init(cwdDir, false)
1178
1179 require.NoError(t, err)
1180 assert.Contains(t, cfg.Agents, AgentTask)
1181
1182 taskAgent := cfg.Agents[AgentTask]
1183 assert.Equal(t, AgentTask, taskAgent.ID)
1184 assert.Equal(t, "Task", taskAgent.Name)
1185 assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description)
1186 assert.Equal(t, LargeModel, taskAgent.Model)
1187 assert.False(t, taskAgent.Disabled)
1188 assert.Equal(t, cfg.Options.ContextPaths, taskAgent.ContextPaths)
1189
1190 // Task agent should have restricted tools
1191 expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
1192 assert.Equal(t, expectedTools, taskAgent.AllowedTools)
1193
1194 // Task agent should have no MCPs or LSPs by default
1195 assert.Equal(t, map[string][]string{}, taskAgent.AllowedMCP)
1196 assert.Equal(t, []string{}, taskAgent.AllowedLSP)
1197}
1198
1199func TestAgentMerging_CustomAgent(t *testing.T) {
1200 reset()
1201 testConfigDir = t.TempDir()
1202 cwdDir := t.TempDir()
1203
1204 // Set up a provider
1205 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1206
1207 // Create config with custom agent
1208 globalConfig := Config{
1209 Agents: map[AgentID]Agent{
1210 AgentID("custom-agent"): {
1211 ID: AgentID("custom-agent"),
1212 Name: "Custom Agent",
1213 Description: "A custom agent for testing",
1214 Model: SmallModel,
1215 AllowedTools: []string{"glob", "grep"},
1216 AllowedMCP: map[string][]string{"mcp1": {"tool1", "tool2"}},
1217 AllowedLSP: []string{"typescript", "go"},
1218 ContextPaths: []string{"custom-context.md"},
1219 },
1220 },
1221 }
1222
1223 configPath := filepath.Join(testConfigDir, "crush.json")
1224 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1225 data, err := json.Marshal(globalConfig)
1226 require.NoError(t, err)
1227 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1228
1229 cfg, err := Init(cwdDir, false)
1230
1231 require.NoError(t, err)
1232
1233 // Should have default agents plus custom agent
1234 assert.Contains(t, cfg.Agents, AgentCoder)
1235 assert.Contains(t, cfg.Agents, AgentTask)
1236 assert.Contains(t, cfg.Agents, AgentID("custom-agent"))
1237
1238 customAgent := cfg.Agents[AgentID("custom-agent")]
1239 assert.Equal(t, "Custom Agent", customAgent.Name)
1240 assert.Equal(t, "A custom agent for testing", customAgent.Description)
1241 assert.Equal(t, SmallModel, customAgent.Model)
1242 assert.Equal(t, []string{"glob", "grep"}, customAgent.AllowedTools)
1243 assert.Equal(t, map[string][]string{"mcp1": {"tool1", "tool2"}}, customAgent.AllowedMCP)
1244 assert.Equal(t, []string{"typescript", "go"}, customAgent.AllowedLSP)
1245 // Context paths should be additive (default + custom)
1246 expectedContextPaths := append(defaultContextPaths, "custom-context.md")
1247 assert.Equal(t, expectedContextPaths, customAgent.ContextPaths)
1248}
1249
1250func TestAgentMerging_ModifyDefaultCoderAgent(t *testing.T) {
1251 reset()
1252 testConfigDir = t.TempDir()
1253 cwdDir := t.TempDir()
1254
1255 // Set up a provider
1256 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1257
1258 // Create config that modifies the default coder agent
1259 globalConfig := Config{
1260 Agents: map[AgentID]Agent{
1261 AgentCoder: {
1262 Model: SmallModel, // Change from default LargeModel
1263 AllowedMCP: map[string][]string{"mcp1": {"tool1"}},
1264 AllowedLSP: []string{"typescript"},
1265 ContextPaths: []string{"coder-specific.md"}, // Should be additive
1266 },
1267 },
1268 }
1269
1270 configPath := filepath.Join(testConfigDir, "crush.json")
1271 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1272 data, err := json.Marshal(globalConfig)
1273 require.NoError(t, err)
1274 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1275
1276 cfg, err := Init(cwdDir, false)
1277
1278 require.NoError(t, err)
1279
1280 coderAgent := cfg.Agents[AgentCoder]
1281 // Should preserve default values for unspecified fields
1282 assert.Equal(t, AgentCoder, coderAgent.ID)
1283 assert.Equal(t, "Coder", coderAgent.Name)
1284 assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description)
1285
1286 // Context paths should be additive (default + custom)
1287 expectedContextPaths := append(cfg.Options.ContextPaths, "coder-specific.md")
1288 assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths)
1289
1290 // Should update specified fields
1291 assert.Equal(t, SmallModel, coderAgent.Model)
1292 assert.Equal(t, map[string][]string{"mcp1": {"tool1"}}, coderAgent.AllowedMCP)
1293 assert.Equal(t, []string{"typescript"}, coderAgent.AllowedLSP)
1294}
1295
1296func TestAgentMerging_ModifyDefaultTaskAgent(t *testing.T) {
1297 reset()
1298 testConfigDir = t.TempDir()
1299 cwdDir := t.TempDir()
1300
1301 // Set up a provider
1302 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1303
1304 // Create config that modifies the default task agent
1305 // Note: Only model, MCP, and LSP should be configurable for known agents
1306 globalConfig := Config{
1307 Agents: map[AgentID]Agent{
1308 AgentTask: {
1309 Model: SmallModel, // Should be updated
1310 AllowedMCP: map[string][]string{"search-mcp": nil}, // Should be updated
1311 AllowedLSP: []string{"python"}, // Should be updated
1312 // These should be ignored for known agents:
1313 Name: "Search Agent", // Should be ignored
1314 Description: "Custom search agent", // Should be ignored
1315 Disabled: true, // Should be ignored
1316 AllowedTools: []string{"glob", "grep", "view"}, // Should be ignored
1317 },
1318 },
1319 }
1320
1321 configPath := filepath.Join(testConfigDir, "crush.json")
1322 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1323 data, err := json.Marshal(globalConfig)
1324 require.NoError(t, err)
1325 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1326
1327 cfg, err := Init(cwdDir, false)
1328
1329 require.NoError(t, err)
1330
1331 taskAgent := cfg.Agents[AgentTask]
1332 // Should preserve default values for protected fields
1333 assert.Equal(t, "Task", taskAgent.Name) // Should remain default
1334 assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description) // Should remain default
1335 assert.False(t, taskAgent.Disabled) // Should remain default
1336 assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools) // Should remain default
1337
1338 // Should update configurable fields
1339 assert.Equal(t, SmallModel, taskAgent.Model)
1340 assert.Equal(t, map[string][]string{"search-mcp": nil}, taskAgent.AllowedMCP)
1341 assert.Equal(t, []string{"python"}, taskAgent.AllowedLSP)
1342}
1343
1344func TestAgentMerging_LocalOverridesGlobal(t *testing.T) {
1345 reset()
1346 testConfigDir = t.TempDir()
1347 cwdDir := t.TempDir()
1348
1349 // Set up a provider
1350 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1351
1352 // Create global config with custom agent
1353 globalConfig := Config{
1354 Agents: map[AgentID]Agent{
1355 AgentID("test-agent"): {
1356 ID: AgentID("test-agent"),
1357 Name: "Global Agent",
1358 Description: "Global description",
1359 Model: LargeModel,
1360 AllowedTools: []string{"glob"},
1361 },
1362 },
1363 }
1364
1365 configPath := filepath.Join(testConfigDir, "crush.json")
1366 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1367 data, err := json.Marshal(globalConfig)
1368 require.NoError(t, err)
1369 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1370
1371 // Create local config that overrides
1372 localConfig := Config{
1373 Agents: map[AgentID]Agent{
1374 AgentID("test-agent"): {
1375 Name: "Local Agent",
1376 Description: "Local description",
1377 Model: SmallModel,
1378 Disabled: true,
1379 AllowedTools: []string{"grep", "view"},
1380 AllowedMCP: map[string][]string{"local-mcp": {"tool1"}},
1381 },
1382 },
1383 }
1384
1385 localConfigPath := filepath.Join(cwdDir, "crush.json")
1386 data, err = json.Marshal(localConfig)
1387 require.NoError(t, err)
1388 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
1389
1390 cfg, err := Init(cwdDir, false)
1391
1392 require.NoError(t, err)
1393
1394 testAgent := cfg.Agents[AgentID("test-agent")]
1395 // Local should override global
1396 assert.Equal(t, "Local Agent", testAgent.Name)
1397 assert.Equal(t, "Local description", testAgent.Description)
1398 assert.Equal(t, SmallModel, testAgent.Model)
1399 assert.True(t, testAgent.Disabled)
1400 assert.Equal(t, []string{"grep", "view"}, testAgent.AllowedTools)
1401 assert.Equal(t, map[string][]string{"local-mcp": {"tool1"}}, testAgent.AllowedMCP)
1402}
1403
1404func TestAgentModelTypeAssignment(t *testing.T) {
1405 reset()
1406 testConfigDir = t.TempDir()
1407 cwdDir := t.TempDir()
1408
1409 // Set up a provider
1410 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1411
1412 // Create config with agents using different model types
1413 globalConfig := Config{
1414 Agents: map[AgentID]Agent{
1415 AgentID("large-agent"): {
1416 ID: AgentID("large-agent"),
1417 Name: "Large Model Agent",
1418 Model: LargeModel,
1419 },
1420 AgentID("small-agent"): {
1421 ID: AgentID("small-agent"),
1422 Name: "Small Model Agent",
1423 Model: SmallModel,
1424 },
1425 AgentID("default-agent"): {
1426 ID: AgentID("default-agent"),
1427 Name: "Default Model Agent",
1428 // No model specified - should default to LargeModel
1429 },
1430 },
1431 }
1432
1433 configPath := filepath.Join(testConfigDir, "crush.json")
1434 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1435 data, err := json.Marshal(globalConfig)
1436 require.NoError(t, err)
1437 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1438
1439 cfg, err := Init(cwdDir, false)
1440
1441 require.NoError(t, err)
1442
1443 assert.Equal(t, LargeModel, cfg.Agents[AgentID("large-agent")].Model)
1444 assert.Equal(t, SmallModel, cfg.Agents[AgentID("small-agent")].Model)
1445 assert.Equal(t, LargeModel, cfg.Agents[AgentID("default-agent")].Model) // Should default to LargeModel
1446}
1447
1448func TestAgentContextPathOverrides(t *testing.T) {
1449 reset()
1450 testConfigDir = t.TempDir()
1451 cwdDir := t.TempDir()
1452
1453 // Set up a provider
1454 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1455
1456 // Create config with custom context paths
1457 globalConfig := Config{
1458 Options: Options{
1459 ContextPaths: []string{"global-context.md", "shared-context.md"},
1460 },
1461 Agents: map[AgentID]Agent{
1462 AgentID("custom-context-agent"): {
1463 ID: AgentID("custom-context-agent"),
1464 Name: "Custom Context Agent",
1465 ContextPaths: []string{"agent-specific.md", "custom.md"},
1466 },
1467 AgentID("default-context-agent"): {
1468 ID: AgentID("default-context-agent"),
1469 Name: "Default Context Agent",
1470 // No ContextPaths specified - should use global
1471 },
1472 },
1473 }
1474
1475 configPath := filepath.Join(testConfigDir, "crush.json")
1476 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1477 data, err := json.Marshal(globalConfig)
1478 require.NoError(t, err)
1479 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1480
1481 cfg, err := Init(cwdDir, false)
1482
1483 require.NoError(t, err)
1484
1485 // Agent with custom context paths should have default + global + custom paths (additive)
1486 customAgent := cfg.Agents[AgentID("custom-context-agent")]
1487 expectedCustomPaths := append(defaultContextPaths, "global-context.md", "shared-context.md", "agent-specific.md", "custom.md")
1488 assert.Equal(t, expectedCustomPaths, customAgent.ContextPaths)
1489
1490 // Agent without custom context paths should use global + defaults
1491 defaultAgent := cfg.Agents[AgentID("default-context-agent")]
1492 expectedContextPaths := append(defaultContextPaths, "global-context.md", "shared-context.md")
1493 assert.Equal(t, expectedContextPaths, defaultAgent.ContextPaths)
1494
1495 // Default agents should also use the merged context paths
1496 coderAgent := cfg.Agents[AgentCoder]
1497 assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths)
1498}
1499
1500// Options and Settings Tests
1501
1502func TestOptionsMerging_ContextPaths(t *testing.T) {
1503 reset()
1504 testConfigDir = t.TempDir()
1505 cwdDir := t.TempDir()
1506
1507 // Set up a provider
1508 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1509
1510 // Create global config with context paths
1511 globalConfig := Config{
1512 Options: Options{
1513 ContextPaths: []string{"global1.md", "global2.md"},
1514 },
1515 }
1516
1517 configPath := filepath.Join(testConfigDir, "crush.json")
1518 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1519 data, err := json.Marshal(globalConfig)
1520 require.NoError(t, err)
1521 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1522
1523 // Create local config with additional context paths
1524 localConfig := Config{
1525 Options: Options{
1526 ContextPaths: []string{"local1.md", "local2.md"},
1527 },
1528 }
1529
1530 localConfigPath := filepath.Join(cwdDir, "crush.json")
1531 data, err = json.Marshal(localConfig)
1532 require.NoError(t, err)
1533 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
1534
1535 cfg, err := Init(cwdDir, false)
1536
1537 require.NoError(t, err)
1538
1539 // Context paths should be merged: defaults + global + local
1540 expectedContextPaths := append(defaultContextPaths, "global1.md", "global2.md", "local1.md", "local2.md")
1541 assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths)
1542}
1543
1544func TestOptionsMerging_TUIOptions(t *testing.T) {
1545 reset()
1546 testConfigDir = t.TempDir()
1547 cwdDir := t.TempDir()
1548
1549 // Set up a provider
1550 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1551
1552 // Create global config with TUI options
1553 globalConfig := Config{
1554 Options: Options{
1555 TUI: TUIOptions{
1556 CompactMode: false, // Default value
1557 },
1558 },
1559 }
1560
1561 configPath := filepath.Join(testConfigDir, "crush.json")
1562 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1563 data, err := json.Marshal(globalConfig)
1564 require.NoError(t, err)
1565 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1566
1567 // Create local config that enables compact mode
1568 localConfig := Config{
1569 Options: Options{
1570 TUI: TUIOptions{
1571 CompactMode: true,
1572 },
1573 },
1574 }
1575
1576 localConfigPath := filepath.Join(cwdDir, "crush.json")
1577 data, err = json.Marshal(localConfig)
1578 require.NoError(t, err)
1579 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
1580
1581 cfg, err := Init(cwdDir, false)
1582
1583 require.NoError(t, err)
1584
1585 // Local config should override global
1586 assert.True(t, cfg.Options.TUI.CompactMode)
1587}
1588
1589func TestOptionsMerging_DebugFlags(t *testing.T) {
1590 reset()
1591 testConfigDir = t.TempDir()
1592 cwdDir := t.TempDir()
1593
1594 // Set up a provider
1595 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1596
1597 // Create global config with debug flags
1598 globalConfig := Config{
1599 Options: Options{
1600 Debug: false,
1601 DebugLSP: false,
1602 DisableAutoSummarize: false,
1603 },
1604 }
1605
1606 configPath := filepath.Join(testConfigDir, "crush.json")
1607 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1608 data, err := json.Marshal(globalConfig)
1609 require.NoError(t, err)
1610 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1611
1612 // Create local config that enables debug flags
1613 localConfig := Config{
1614 Options: Options{
1615 DebugLSP: true,
1616 DisableAutoSummarize: true,
1617 },
1618 }
1619
1620 localConfigPath := filepath.Join(cwdDir, "crush.json")
1621 data, err = json.Marshal(localConfig)
1622 require.NoError(t, err)
1623 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
1624
1625 cfg, err := Init(cwdDir, false)
1626
1627 require.NoError(t, err)
1628
1629 // Local config should override global for boolean flags
1630 assert.False(t, cfg.Options.Debug) // Not set in local, remains global value
1631 assert.True(t, cfg.Options.DebugLSP) // Set to true in local
1632 assert.True(t, cfg.Options.DisableAutoSummarize) // Set to true in local
1633}
1634
1635func TestOptionsMerging_DataDirectory(t *testing.T) {
1636 reset()
1637 testConfigDir = t.TempDir()
1638 cwdDir := t.TempDir()
1639
1640 // Set up a provider
1641 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1642
1643 // Create global config with custom data directory
1644 globalConfig := Config{
1645 Options: Options{
1646 DataDirectory: "global-data",
1647 },
1648 }
1649
1650 configPath := filepath.Join(testConfigDir, "crush.json")
1651 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1652 data, err := json.Marshal(globalConfig)
1653 require.NoError(t, err)
1654 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1655
1656 // Create local config with different data directory
1657 localConfig := Config{
1658 Options: Options{
1659 DataDirectory: "local-data",
1660 },
1661 }
1662
1663 localConfigPath := filepath.Join(cwdDir, "crush.json")
1664 data, err = json.Marshal(localConfig)
1665 require.NoError(t, err)
1666 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
1667
1668 cfg, err := Init(cwdDir, false)
1669
1670 require.NoError(t, err)
1671
1672 // Local config should override global
1673 assert.Equal(t, "local-data", cfg.Options.DataDirectory)
1674}
1675
1676func TestOptionsMerging_DefaultValues(t *testing.T) {
1677 reset()
1678 testConfigDir = t.TempDir()
1679 cwdDir := t.TempDir()
1680
1681 // Set up a provider
1682 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1683
1684 // No config files - should use defaults
1685 cfg, err := Init(cwdDir, false)
1686
1687 require.NoError(t, err)
1688
1689 // Should have default values
1690 assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory)
1691 assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths)
1692 assert.False(t, cfg.Options.TUI.CompactMode)
1693 assert.False(t, cfg.Options.Debug)
1694 assert.False(t, cfg.Options.DebugLSP)
1695 assert.False(t, cfg.Options.DisableAutoSummarize)
1696}
1697
1698func TestOptionsMerging_DebugFlagFromInit(t *testing.T) {
1699 reset()
1700 testConfigDir = t.TempDir()
1701 cwdDir := t.TempDir()
1702
1703 // Set up a provider
1704 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1705
1706 // Create config with debug false
1707 globalConfig := Config{
1708 Options: Options{
1709 Debug: false,
1710 },
1711 }
1712
1713 configPath := filepath.Join(testConfigDir, "crush.json")
1714 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1715 data, err := json.Marshal(globalConfig)
1716 require.NoError(t, err)
1717 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1718
1719 // Init with debug=true should override config
1720 cfg, err := Init(cwdDir, true)
1721
1722 require.NoError(t, err)
1723
1724 // Debug flag from Init should take precedence
1725 assert.True(t, cfg.Options.Debug)
1726}
1727
1728func TestOptionsMerging_ComplexScenario(t *testing.T) {
1729 reset()
1730 testConfigDir = t.TempDir()
1731 cwdDir := t.TempDir()
1732
1733 // Set up a provider
1734 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1735
1736 // Create global config with various options
1737 globalConfig := Config{
1738 Options: Options{
1739 ContextPaths: []string{"global-context.md"},
1740 DataDirectory: "global-data",
1741 Debug: false,
1742 DebugLSP: false,
1743 DisableAutoSummarize: false,
1744 TUI: TUIOptions{
1745 CompactMode: false,
1746 },
1747 },
1748 }
1749
1750 configPath := filepath.Join(testConfigDir, "crush.json")
1751 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1752 data, err := json.Marshal(globalConfig)
1753 require.NoError(t, err)
1754 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1755
1756 // Create local config that partially overrides
1757 localConfig := Config{
1758 Options: Options{
1759 ContextPaths: []string{"local-context.md"},
1760 DebugLSP: true, // Override
1761 DisableAutoSummarize: true, // Override
1762 TUI: TUIOptions{
1763 CompactMode: true, // Override
1764 },
1765 // DataDirectory and Debug not specified - should keep global values
1766 },
1767 }
1768
1769 localConfigPath := filepath.Join(cwdDir, "crush.json")
1770 data, err = json.Marshal(localConfig)
1771 require.NoError(t, err)
1772 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
1773
1774 cfg, err := Init(cwdDir, false)
1775
1776 require.NoError(t, err)
1777
1778 // Check merged results
1779 expectedContextPaths := append(defaultContextPaths, "global-context.md", "local-context.md")
1780 assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths)
1781 assert.Equal(t, "global-data", cfg.Options.DataDirectory) // From global
1782 assert.False(t, cfg.Options.Debug) // From global
1783 assert.True(t, cfg.Options.DebugLSP) // From local
1784 assert.True(t, cfg.Options.DisableAutoSummarize) // From local
1785 assert.True(t, cfg.Options.TUI.CompactMode) // From local
1786}
1787
1788// Model Selection Tests
1789
1790func TestModelSelection_PreferredModelSelection(t *testing.T) {
1791 reset()
1792 testConfigDir = t.TempDir()
1793 cwdDir := t.TempDir()
1794
1795 // Set up multiple providers to test selection logic
1796 os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
1797 os.Setenv("OPENAI_API_KEY", "test-openai-key")
1798
1799 cfg, err := Init(cwdDir, false)
1800
1801 require.NoError(t, err)
1802 require.Len(t, cfg.Providers, 2)
1803
1804 // Should have preferred models set
1805 assert.NotEmpty(t, cfg.Models.Large.ModelID)
1806 assert.NotEmpty(t, cfg.Models.Large.Provider)
1807 assert.NotEmpty(t, cfg.Models.Small.ModelID)
1808 assert.NotEmpty(t, cfg.Models.Small.Provider)
1809
1810 // Both should use the same provider (first available)
1811 assert.Equal(t, cfg.Models.Large.Provider, cfg.Models.Small.Provider)
1812}
1813
1814func TestModelSelection_GetAgentModel(t *testing.T) {
1815 reset()
1816 testConfigDir = t.TempDir()
1817 cwdDir := t.TempDir()
1818
1819 // Set up a provider with known models
1820 globalConfig := Config{
1821 Providers: map[provider.InferenceProvider]ProviderConfig{
1822 provider.InferenceProviderOpenAI: {
1823 ID: provider.InferenceProviderOpenAI,
1824 APIKey: "test-key",
1825 ProviderType: provider.TypeOpenAI,
1826 DefaultLargeModel: "gpt-4",
1827 DefaultSmallModel: "gpt-3.5-turbo",
1828 Models: []Model{
1829 {
1830 ID: "gpt-4",
1831 Name: "GPT-4",
1832 ContextWindow: 8192,
1833 DefaultMaxTokens: 4096,
1834 CanReason: true,
1835 SupportsImages: true,
1836 },
1837 {
1838 ID: "gpt-3.5-turbo",
1839 Name: "GPT-3.5 Turbo",
1840 ContextWindow: 4096,
1841 DefaultMaxTokens: 2048,
1842 CanReason: false,
1843 SupportsImages: false,
1844 },
1845 },
1846 },
1847 },
1848 }
1849
1850 configPath := filepath.Join(testConfigDir, "crush.json")
1851 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1852 data, err := json.Marshal(globalConfig)
1853 require.NoError(t, err)
1854 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1855
1856 _, err = Init(cwdDir, false)
1857
1858 require.NoError(t, err)
1859
1860 // Test GetAgentModel for default agents
1861 coderModel := GetAgentModel(AgentCoder)
1862 assert.Equal(t, "gpt-4", coderModel.ID) // Coder uses LargeModel
1863 assert.Equal(t, "GPT-4", coderModel.Name)
1864 assert.True(t, coderModel.CanReason)
1865 assert.True(t, coderModel.SupportsImages)
1866
1867 taskModel := GetAgentModel(AgentTask)
1868 assert.Equal(t, "gpt-4", taskModel.ID) // Task also uses LargeModel by default
1869 assert.Equal(t, "GPT-4", taskModel.Name)
1870}
1871
1872func TestModelSelection_GetAgentModelWithCustomModelType(t *testing.T) {
1873 reset()
1874 testConfigDir = t.TempDir()
1875 cwdDir := t.TempDir()
1876
1877 // Set up provider and custom agent with SmallModel
1878 globalConfig := Config{
1879 Providers: map[provider.InferenceProvider]ProviderConfig{
1880 provider.InferenceProviderOpenAI: {
1881 ID: provider.InferenceProviderOpenAI,
1882 APIKey: "test-key",
1883 ProviderType: provider.TypeOpenAI,
1884 DefaultLargeModel: "gpt-4",
1885 DefaultSmallModel: "gpt-3.5-turbo",
1886 Models: []Model{
1887 {
1888 ID: "gpt-4",
1889 Name: "GPT-4",
1890 ContextWindow: 8192,
1891 DefaultMaxTokens: 4096,
1892 },
1893 {
1894 ID: "gpt-3.5-turbo",
1895 Name: "GPT-3.5 Turbo",
1896 ContextWindow: 4096,
1897 DefaultMaxTokens: 2048,
1898 },
1899 },
1900 },
1901 },
1902 Agents: map[AgentID]Agent{
1903 AgentID("small-agent"): {
1904 ID: AgentID("small-agent"),
1905 Name: "Small Agent",
1906 Model: SmallModel,
1907 },
1908 },
1909 }
1910
1911 configPath := filepath.Join(testConfigDir, "crush.json")
1912 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1913 data, err := json.Marshal(globalConfig)
1914 require.NoError(t, err)
1915 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1916
1917 _, err = Init(cwdDir, false)
1918
1919 require.NoError(t, err)
1920
1921 // Test GetAgentModel for custom agent with SmallModel
1922 smallAgentModel := GetAgentModel(AgentID("small-agent"))
1923 assert.Equal(t, "gpt-3.5-turbo", smallAgentModel.ID)
1924 assert.Equal(t, "GPT-3.5 Turbo", smallAgentModel.Name)
1925}
1926
1927func TestModelSelection_GetAgentProvider(t *testing.T) {
1928 reset()
1929 testConfigDir = t.TempDir()
1930 cwdDir := t.TempDir()
1931
1932 // Set up multiple providers
1933 globalConfig := Config{
1934 Providers: map[provider.InferenceProvider]ProviderConfig{
1935 provider.InferenceProviderOpenAI: {
1936 ID: provider.InferenceProviderOpenAI,
1937 APIKey: "openai-key",
1938 ProviderType: provider.TypeOpenAI,
1939 DefaultLargeModel: "gpt-4",
1940 DefaultSmallModel: "gpt-3.5-turbo",
1941 },
1942 provider.InferenceProviderAnthropic: {
1943 ID: provider.InferenceProviderAnthropic,
1944 APIKey: "anthropic-key",
1945 ProviderType: provider.TypeAnthropic,
1946 DefaultLargeModel: "claude-3-opus",
1947 DefaultSmallModel: "claude-3-haiku",
1948 },
1949 },
1950 }
1951
1952 configPath := filepath.Join(testConfigDir, "crush.json")
1953 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1954 data, err := json.Marshal(globalConfig)
1955 require.NoError(t, err)
1956 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1957
1958 _, err = Init(cwdDir, false)
1959
1960 require.NoError(t, err)
1961
1962 // Test GetAgentProvider
1963 coderProvider := GetAgentProvider(AgentCoder)
1964 assert.NotEmpty(t, coderProvider.ID)
1965 assert.NotEmpty(t, coderProvider.APIKey)
1966 assert.NotEmpty(t, coderProvider.ProviderType)
1967}
1968
1969func TestModelSelection_GetProviderModel(t *testing.T) {
1970 reset()
1971 testConfigDir = t.TempDir()
1972 cwdDir := t.TempDir()
1973
1974 // Set up provider with specific models
1975 globalConfig := Config{
1976 Providers: map[provider.InferenceProvider]ProviderConfig{
1977 provider.InferenceProviderOpenAI: {
1978 ID: provider.InferenceProviderOpenAI,
1979 APIKey: "test-key",
1980 ProviderType: provider.TypeOpenAI,
1981 Models: []Model{
1982 {
1983 ID: "gpt-4",
1984 Name: "GPT-4",
1985 ContextWindow: 8192,
1986 DefaultMaxTokens: 4096,
1987 CostPer1MIn: 30.0,
1988 CostPer1MOut: 60.0,
1989 },
1990 {
1991 ID: "gpt-3.5-turbo",
1992 Name: "GPT-3.5 Turbo",
1993 ContextWindow: 4096,
1994 DefaultMaxTokens: 2048,
1995 CostPer1MIn: 1.5,
1996 CostPer1MOut: 2.0,
1997 },
1998 },
1999 },
2000 },
2001 }
2002
2003 configPath := filepath.Join(testConfigDir, "crush.json")
2004 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
2005 data, err := json.Marshal(globalConfig)
2006 require.NoError(t, err)
2007 require.NoError(t, os.WriteFile(configPath, data, 0o644))
2008
2009 _, err = Init(cwdDir, false)
2010
2011 require.NoError(t, err)
2012
2013 // Test GetProviderModel
2014 gpt4Model := GetProviderModel(provider.InferenceProviderOpenAI, "gpt-4")
2015 assert.Equal(t, "gpt-4", gpt4Model.ID)
2016 assert.Equal(t, "GPT-4", gpt4Model.Name)
2017 assert.Equal(t, int64(8192), gpt4Model.ContextWindow)
2018 assert.Equal(t, 30.0, gpt4Model.CostPer1MIn)
2019
2020 gpt35Model := GetProviderModel(provider.InferenceProviderOpenAI, "gpt-3.5-turbo")
2021 assert.Equal(t, "gpt-3.5-turbo", gpt35Model.ID)
2022 assert.Equal(t, "GPT-3.5 Turbo", gpt35Model.Name)
2023 assert.Equal(t, 1.5, gpt35Model.CostPer1MIn)
2024
2025 // Test non-existent model
2026 nonExistentModel := GetProviderModel(provider.InferenceProviderOpenAI, "non-existent")
2027 assert.Empty(t, nonExistentModel.ID)
2028}
2029
2030func TestModelSelection_GetModel(t *testing.T) {
2031 reset()
2032 testConfigDir = t.TempDir()
2033 cwdDir := t.TempDir()
2034
2035 // Set up provider with models
2036 globalConfig := Config{
2037 Providers: map[provider.InferenceProvider]ProviderConfig{
2038 provider.InferenceProviderOpenAI: {
2039 ID: provider.InferenceProviderOpenAI,
2040 APIKey: "test-key",
2041 ProviderType: provider.TypeOpenAI,
2042 DefaultLargeModel: "gpt-4",
2043 DefaultSmallModel: "gpt-3.5-turbo",
2044 Models: []Model{
2045 {
2046 ID: "gpt-4",
2047 Name: "GPT-4",
2048 ContextWindow: 8192,
2049 DefaultMaxTokens: 4096,
2050 },
2051 {
2052 ID: "gpt-3.5-turbo",
2053 Name: "GPT-3.5 Turbo",
2054 ContextWindow: 4096,
2055 DefaultMaxTokens: 2048,
2056 },
2057 },
2058 },
2059 },
2060 }
2061
2062 configPath := filepath.Join(testConfigDir, "crush.json")
2063 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
2064 data, err := json.Marshal(globalConfig)
2065 require.NoError(t, err)
2066 require.NoError(t, os.WriteFile(configPath, data, 0o644))
2067
2068 _, err = Init(cwdDir, false)
2069
2070 require.NoError(t, err)
2071
2072 // Test GetModel
2073 largeModel := GetModel(LargeModel)
2074 assert.Equal(t, "gpt-4", largeModel.ID)
2075 assert.Equal(t, "GPT-4", largeModel.Name)
2076
2077 smallModel := GetModel(SmallModel)
2078 assert.Equal(t, "gpt-3.5-turbo", smallModel.ID)
2079 assert.Equal(t, "GPT-3.5 Turbo", smallModel.Name)
2080}
2081
2082func TestModelSelection_UpdatePreferredModel(t *testing.T) {
2083 reset()
2084 testConfigDir = t.TempDir()
2085 cwdDir := t.TempDir()
2086
2087 // Set up multiple providers with OpenAI first to ensure it's selected initially
2088 globalConfig := Config{
2089 Providers: map[provider.InferenceProvider]ProviderConfig{
2090 provider.InferenceProviderOpenAI: {
2091 ID: provider.InferenceProviderOpenAI,
2092 APIKey: "openai-key",
2093 ProviderType: provider.TypeOpenAI,
2094 DefaultLargeModel: "gpt-4",
2095 DefaultSmallModel: "gpt-3.5-turbo",
2096 Models: []Model{
2097 {ID: "gpt-4", Name: "GPT-4"},
2098 {ID: "gpt-3.5-turbo", Name: "GPT-3.5 Turbo"},
2099 },
2100 },
2101 provider.InferenceProviderAnthropic: {
2102 ID: provider.InferenceProviderAnthropic,
2103 APIKey: "anthropic-key",
2104 ProviderType: provider.TypeAnthropic,
2105 DefaultLargeModel: "claude-3-opus",
2106 DefaultSmallModel: "claude-3-haiku",
2107 Models: []Model{
2108 {ID: "claude-3-opus", Name: "Claude 3 Opus"},
2109 {ID: "claude-3-haiku", Name: "Claude 3 Haiku"},
2110 },
2111 },
2112 },
2113 }
2114
2115 configPath := filepath.Join(testConfigDir, "crush.json")
2116 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
2117 data, err := json.Marshal(globalConfig)
2118 require.NoError(t, err)
2119 require.NoError(t, os.WriteFile(configPath, data, 0o644))
2120
2121 _, err = Init(cwdDir, false)
2122
2123 require.NoError(t, err)
2124
2125 // Get initial preferred models (should be OpenAI since it's listed first)
2126 initialLargeModel := GetModel(LargeModel)
2127 initialSmallModel := GetModel(SmallModel)
2128
2129 // Verify initial models are OpenAI models
2130 assert.Equal(t, "claude-3-opus", initialLargeModel.ID)
2131 assert.Equal(t, "claude-3-haiku", initialSmallModel.ID)
2132
2133 // Update preferred models to Anthropic
2134 newLargeModel := PreferredModel{
2135 ModelID: "gpt-4",
2136 Provider: provider.InferenceProviderOpenAI,
2137 }
2138 newSmallModel := PreferredModel{
2139 ModelID: "gpt-3.5-turbo",
2140 Provider: provider.InferenceProviderOpenAI,
2141 }
2142
2143 err = UpdatePreferredModel(LargeModel, newLargeModel)
2144 require.NoError(t, err)
2145
2146 err = UpdatePreferredModel(SmallModel, newSmallModel)
2147 require.NoError(t, err)
2148
2149 // Verify models were updated
2150 updatedLargeModel := GetModel(LargeModel)
2151 assert.Equal(t, "gpt-4", updatedLargeModel.ID)
2152 assert.NotEqual(t, initialLargeModel.ID, updatedLargeModel.ID)
2153
2154 updatedSmallModel := GetModel(SmallModel)
2155 assert.Equal(t, "gpt-3.5-turbo", updatedSmallModel.ID)
2156 assert.NotEqual(t, initialSmallModel.ID, updatedSmallModel.ID)
2157}
2158
2159func TestModelSelection_InvalidModelType(t *testing.T) {
2160 reset()
2161 testConfigDir = t.TempDir()
2162 cwdDir := t.TempDir()
2163
2164 // Set up a provider
2165 os.Setenv("ANTHROPIC_API_KEY", "test-key")
2166
2167 _, err := Init(cwdDir, false)
2168 require.NoError(t, err)
2169
2170 // Test UpdatePreferredModel with invalid model type
2171 invalidModel := PreferredModel{
2172 ModelID: "some-model",
2173 Provider: provider.InferenceProviderAnthropic,
2174 }
2175
2176 err = UpdatePreferredModel(ModelType("invalid"), invalidModel)
2177 assert.Error(t, err)
2178 assert.Contains(t, err.Error(), "unknown model type")
2179}
2180
2181func TestModelSelection_NonExistentAgent(t *testing.T) {
2182 reset()
2183 testConfigDir = t.TempDir()
2184 cwdDir := t.TempDir()
2185
2186 // Set up a provider
2187 os.Setenv("ANTHROPIC_API_KEY", "test-key")
2188
2189 _, err := Init(cwdDir, false)
2190 require.NoError(t, err)
2191
2192 // Test GetAgentModel with non-existent agent
2193 nonExistentModel := GetAgentModel(AgentID("non-existent"))
2194 assert.Empty(t, nonExistentModel.ID)
2195
2196 // Test GetAgentProvider with non-existent agent
2197 nonExistentProvider := GetAgentProvider(AgentID("non-existent"))
2198 assert.Empty(t, nonExistentProvider.ID)
2199}
2200
2201func TestModelSelection_NonExistentProvider(t *testing.T) {
2202 reset()
2203 testConfigDir = t.TempDir()
2204 cwdDir := t.TempDir()
2205
2206 // Set up a provider
2207 os.Setenv("ANTHROPIC_API_KEY", "test-key")
2208
2209 _, err := Init(cwdDir, false)
2210 require.NoError(t, err)
2211
2212 // Test GetProviderModel with non-existent provider
2213 nonExistentModel := GetProviderModel(provider.InferenceProvider("non-existent"), "some-model")
2214 assert.Empty(t, nonExistentModel.ID)
2215}