1package config
2
3import (
4 "encoding/json"
5 "os"
6 "path/filepath"
7 "sync"
8 "testing"
9
10 "github.com/charmbracelet/crush/internal/fur/provider"
11 "github.com/stretchr/testify/assert"
12 "github.com/stretchr/testify/require"
13)
14
15func reset() {
16 // Clear all environment variables that could affect config
17 envVarsToUnset := []string{
18 // API Keys
19 "ANTHROPIC_API_KEY",
20 "OPENAI_API_KEY",
21 "GEMINI_API_KEY",
22 "XAI_API_KEY",
23 "OPENROUTER_API_KEY",
24
25 // Google Cloud / VertexAI
26 "GOOGLE_GENAI_USE_VERTEXAI",
27 "GOOGLE_CLOUD_PROJECT",
28 "GOOGLE_CLOUD_LOCATION",
29
30 // AWS Credentials
31 "AWS_ACCESS_KEY_ID",
32 "AWS_SECRET_ACCESS_KEY",
33 "AWS_REGION",
34 "AWS_DEFAULT_REGION",
35 "AWS_PROFILE",
36 "AWS_DEFAULT_PROFILE",
37 "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
38 "AWS_CONTAINER_CREDENTIALS_FULL_URI",
39
40 // Other
41 "CRUSH_DEV_DEBUG",
42 }
43
44 for _, envVar := range envVarsToUnset {
45 os.Unsetenv(envVar)
46 }
47
48 // Reset singleton
49 once = sync.Once{}
50 instance = nil
51 cwd = ""
52 testConfigDir = ""
53
54 // Enable mock providers for all tests to avoid API calls
55 UseMockProviders = true
56 ResetProviders()
57}
58
59// Core Configuration Loading Tests
60
61func TestInit_ValidWorkingDirectory(t *testing.T) {
62 reset()
63 testConfigDir = t.TempDir()
64 cwdDir := t.TempDir()
65
66 cfg, err := Init(cwdDir, false)
67
68 require.NoError(t, err)
69 assert.NotNil(t, cfg)
70 assert.Equal(t, cwdDir, WorkingDirectory())
71 assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory)
72 assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths)
73}
74
75func TestInit_WithDebugFlag(t *testing.T) {
76 reset()
77 testConfigDir = t.TempDir()
78 cwdDir := t.TempDir()
79
80 cfg, err := Init(cwdDir, true)
81
82 require.NoError(t, err)
83 assert.True(t, cfg.Options.Debug)
84}
85
86func TestInit_SingletonBehavior(t *testing.T) {
87 reset()
88 testConfigDir = t.TempDir()
89 cwdDir := t.TempDir()
90
91 cfg1, err1 := Init(cwdDir, false)
92 cfg2, err2 := Init(cwdDir, false)
93
94 require.NoError(t, err1)
95 require.NoError(t, err2)
96 assert.Same(t, cfg1, cfg2) // Should be the same instance
97}
98
99func TestGet_BeforeInitialization(t *testing.T) {
100 reset()
101
102 assert.Panics(t, func() {
103 Get()
104 })
105}
106
107func TestGet_AfterInitialization(t *testing.T) {
108 reset()
109 testConfigDir = t.TempDir()
110 cwdDir := t.TempDir()
111
112 cfg1, err := Init(cwdDir, false)
113 require.NoError(t, err)
114
115 cfg2 := Get()
116 assert.Same(t, cfg1, cfg2)
117}
118
119func TestLoadConfig_NoConfigFiles(t *testing.T) {
120 reset()
121 testConfigDir = t.TempDir()
122 cwdDir := t.TempDir()
123
124 cfg, err := Init(cwdDir, false)
125
126 require.NoError(t, err)
127 assert.Len(t, cfg.Providers, 0) // No providers without env vars or config files
128 assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths)
129}
130
131func TestLoadConfig_OnlyGlobalConfig(t *testing.T) {
132 reset()
133 testConfigDir = t.TempDir()
134 cwdDir := t.TempDir()
135
136 // Create global config file
137 globalConfig := Config{
138 Providers: map[provider.InferenceProvider]ProviderConfig{
139 provider.InferenceProviderOpenAI: {
140 ID: provider.InferenceProviderOpenAI,
141 APIKey: "test-key",
142 ProviderType: provider.TypeOpenAI,
143 DefaultLargeModel: "gpt-4",
144 DefaultSmallModel: "gpt-3.5-turbo",
145 Models: []Model{
146 {
147 ID: "gpt-4",
148 Name: "GPT-4",
149 CostPer1MIn: 30.0,
150 CostPer1MOut: 60.0,
151 ContextWindow: 8192,
152 DefaultMaxTokens: 4096,
153 },
154 {
155 ID: "gpt-3.5-turbo",
156 Name: "GPT-3.5 Turbo",
157 CostPer1MIn: 1.0,
158 CostPer1MOut: 2.0,
159 ContextWindow: 4096,
160 DefaultMaxTokens: 4096,
161 },
162 },
163 },
164 },
165 Options: Options{
166 ContextPaths: []string{"custom-context.md"},
167 },
168 }
169
170 configPath := filepath.Join(testConfigDir, "crush.json")
171 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
172
173 data, err := json.Marshal(globalConfig)
174 require.NoError(t, err)
175 require.NoError(t, os.WriteFile(configPath, data, 0o644))
176
177 cfg, err := Init(cwdDir, false)
178
179 require.NoError(t, err)
180 assert.Len(t, cfg.Providers, 1)
181 assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
182 assert.Contains(t, cfg.Options.ContextPaths, "custom-context.md")
183}
184
185func TestLoadConfig_OnlyLocalConfig(t *testing.T) {
186 reset()
187 testConfigDir = t.TempDir()
188 cwdDir := t.TempDir()
189
190 // Create local config file
191 localConfig := Config{
192 Providers: map[provider.InferenceProvider]ProviderConfig{
193 provider.InferenceProviderAnthropic: {
194 ID: provider.InferenceProviderAnthropic,
195 APIKey: "local-key",
196 ProviderType: provider.TypeAnthropic,
197 DefaultLargeModel: "claude-3-opus",
198 DefaultSmallModel: "claude-3-haiku",
199 Models: []Model{
200 {
201 ID: "claude-3-opus",
202 Name: "Claude 3 Opus",
203 CostPer1MIn: 15.0,
204 CostPer1MOut: 75.0,
205 ContextWindow: 200000,
206 DefaultMaxTokens: 4096,
207 },
208 {
209 ID: "claude-3-haiku",
210 Name: "Claude 3 Haiku",
211 CostPer1MIn: 0.25,
212 CostPer1MOut: 1.25,
213 ContextWindow: 200000,
214 DefaultMaxTokens: 4096,
215 },
216 },
217 },
218 },
219 Options: Options{
220 TUI: TUIOptions{CompactMode: true},
221 },
222 }
223
224 localConfigPath := filepath.Join(cwdDir, "crush.json")
225 data, err := json.Marshal(localConfig)
226 require.NoError(t, err)
227 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
228
229 cfg, err := Init(cwdDir, false)
230
231 require.NoError(t, err)
232 assert.Len(t, cfg.Providers, 1)
233 assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
234 assert.True(t, cfg.Options.TUI.CompactMode)
235}
236
237func TestLoadConfig_BothGlobalAndLocal(t *testing.T) {
238 reset()
239 testConfigDir = t.TempDir()
240 cwdDir := t.TempDir()
241
242 // Create global config
243 globalConfig := Config{
244 Providers: map[provider.InferenceProvider]ProviderConfig{
245 provider.InferenceProviderOpenAI: {
246 ID: provider.InferenceProviderOpenAI,
247 APIKey: "global-key",
248 ProviderType: provider.TypeOpenAI,
249 DefaultLargeModel: "gpt-4",
250 DefaultSmallModel: "gpt-3.5-turbo",
251 Models: []Model{
252 {
253 ID: "gpt-4",
254 Name: "GPT-4",
255 CostPer1MIn: 30.0,
256 CostPer1MOut: 60.0,
257 ContextWindow: 8192,
258 DefaultMaxTokens: 4096,
259 },
260 {
261 ID: "gpt-3.5-turbo",
262 Name: "GPT-3.5 Turbo",
263 CostPer1MIn: 1.0,
264 CostPer1MOut: 2.0,
265 ContextWindow: 4096,
266 DefaultMaxTokens: 4096,
267 },
268 },
269 },
270 },
271 Options: Options{
272 ContextPaths: []string{"global-context.md"},
273 },
274 }
275
276 configPath := filepath.Join(testConfigDir, "crush.json")
277 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
278 data, err := json.Marshal(globalConfig)
279 require.NoError(t, err)
280 require.NoError(t, os.WriteFile(configPath, data, 0o644))
281
282 // Create local config that overrides and adds
283 localConfig := Config{
284 Providers: map[provider.InferenceProvider]ProviderConfig{
285 provider.InferenceProviderOpenAI: {
286 APIKey: "local-key", // Override global
287 },
288 provider.InferenceProviderAnthropic: {
289 ID: provider.InferenceProviderAnthropic,
290 APIKey: "anthropic-key",
291 ProviderType: provider.TypeAnthropic,
292 DefaultLargeModel: "claude-3-opus",
293 DefaultSmallModel: "claude-3-haiku",
294 Models: []Model{
295 {
296 ID: "claude-3-opus",
297 Name: "Claude 3 Opus",
298 CostPer1MIn: 15.0,
299 CostPer1MOut: 75.0,
300 ContextWindow: 200000,
301 DefaultMaxTokens: 4096,
302 },
303 {
304 ID: "claude-3-haiku",
305 Name: "Claude 3 Haiku",
306 CostPer1MIn: 0.25,
307 CostPer1MOut: 1.25,
308 ContextWindow: 200000,
309 DefaultMaxTokens: 4096,
310 },
311 },
312 },
313 },
314 Options: Options{
315 ContextPaths: []string{"local-context.md"},
316 TUI: TUIOptions{CompactMode: true},
317 },
318 }
319
320 localConfigPath := filepath.Join(cwdDir, "crush.json")
321 data, err = json.Marshal(localConfig)
322 require.NoError(t, err)
323 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
324
325 cfg, err := Init(cwdDir, false)
326
327 require.NoError(t, err)
328 assert.Len(t, cfg.Providers, 2)
329
330 // Check that local config overrode global
331 openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
332 assert.Equal(t, "local-key", openaiProvider.APIKey)
333
334 // Check that local config added new provider
335 assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
336
337 // Check that context paths were merged
338 assert.Contains(t, cfg.Options.ContextPaths, "global-context.md")
339 assert.Contains(t, cfg.Options.ContextPaths, "local-context.md")
340 assert.True(t, cfg.Options.TUI.CompactMode)
341}
342
343func TestLoadConfig_MalformedGlobalJSON(t *testing.T) {
344 reset()
345 testConfigDir = t.TempDir()
346 cwdDir := t.TempDir()
347
348 // Create malformed global config
349 configPath := filepath.Join(testConfigDir, "crush.json")
350 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
351 require.NoError(t, os.WriteFile(configPath, []byte(`{invalid json`), 0o644))
352
353 _, err := Init(cwdDir, false)
354 assert.Error(t, err)
355}
356
357func TestLoadConfig_MalformedLocalJSON(t *testing.T) {
358 reset()
359 testConfigDir = t.TempDir()
360 cwdDir := t.TempDir()
361
362 // Create malformed local config
363 localConfigPath := filepath.Join(cwdDir, "crush.json")
364 require.NoError(t, os.WriteFile(localConfigPath, []byte(`{invalid json`), 0o644))
365
366 _, err := Init(cwdDir, false)
367 assert.Error(t, err)
368}
369
370func TestConfigWithoutEnv(t *testing.T) {
371 reset()
372 testConfigDir = t.TempDir()
373 cwdDir := t.TempDir()
374
375 cfg, _ := Init(cwdDir, false)
376 assert.Len(t, cfg.Providers, 0)
377}
378
379func TestConfigWithEnv(t *testing.T) {
380 reset()
381 testConfigDir = t.TempDir()
382 cwdDir := t.TempDir()
383
384 os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
385 os.Setenv("OPENAI_API_KEY", "test-openai-key")
386 os.Setenv("GEMINI_API_KEY", "test-gemini-key")
387 os.Setenv("XAI_API_KEY", "test-xai-key")
388 os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
389
390 cfg, _ := Init(cwdDir, false)
391 assert.Len(t, cfg.Providers, 5)
392}
393
394// Environment Variable Tests
395
396func TestEnvVars_NoEnvironmentVariables(t *testing.T) {
397 reset()
398 testConfigDir = t.TempDir()
399 cwdDir := t.TempDir()
400
401 cfg, err := Init(cwdDir, false)
402
403 require.NoError(t, err)
404 assert.Len(t, cfg.Providers, 0)
405}
406
407func TestEnvVars_AllSupportedAPIKeys(t *testing.T) {
408 reset()
409 testConfigDir = t.TempDir()
410 cwdDir := t.TempDir()
411
412 // Set all supported API keys
413 os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
414 os.Setenv("OPENAI_API_KEY", "test-openai-key")
415 os.Setenv("GEMINI_API_KEY", "test-gemini-key")
416 os.Setenv("XAI_API_KEY", "test-xai-key")
417 os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
418
419 cfg, err := Init(cwdDir, false)
420
421 require.NoError(t, err)
422 assert.Len(t, cfg.Providers, 5)
423
424 // Verify each provider is configured correctly
425 anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic]
426 assert.Equal(t, "test-anthropic-key", anthropicProvider.APIKey)
427 assert.Equal(t, provider.TypeAnthropic, anthropicProvider.ProviderType)
428
429 openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
430 assert.Equal(t, "test-openai-key", openaiProvider.APIKey)
431 assert.Equal(t, provider.TypeOpenAI, openaiProvider.ProviderType)
432
433 geminiProvider := cfg.Providers[provider.InferenceProviderGemini]
434 assert.Equal(t, "test-gemini-key", geminiProvider.APIKey)
435 assert.Equal(t, provider.TypeGemini, geminiProvider.ProviderType)
436
437 xaiProvider := cfg.Providers[provider.InferenceProviderXAI]
438 assert.Equal(t, "test-xai-key", xaiProvider.APIKey)
439 assert.Equal(t, provider.TypeXAI, xaiProvider.ProviderType)
440
441 openrouterProvider := cfg.Providers[provider.InferenceProviderOpenRouter]
442 assert.Equal(t, "test-openrouter-key", openrouterProvider.APIKey)
443 assert.Equal(t, provider.TypeOpenAI, openrouterProvider.ProviderType)
444 assert.Equal(t, "https://openrouter.ai/api/v1", openrouterProvider.BaseURL)
445}
446
447func TestEnvVars_PartialEnvironmentVariables(t *testing.T) {
448 reset()
449 testConfigDir = t.TempDir()
450 cwdDir := t.TempDir()
451
452 // Set only some API keys
453 os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
454 os.Setenv("OPENAI_API_KEY", "test-openai-key")
455
456 cfg, err := Init(cwdDir, false)
457
458 require.NoError(t, err)
459 assert.Len(t, cfg.Providers, 2)
460 assert.Contains(t, cfg.Providers, provider.InferenceProviderAnthropic)
461 assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
462 assert.NotContains(t, cfg.Providers, provider.InferenceProviderGemini)
463}
464
465func TestEnvVars_VertexAIConfiguration(t *testing.T) {
466 reset()
467 testConfigDir = t.TempDir()
468 cwdDir := t.TempDir()
469
470 // Set VertexAI environment variables
471 os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
472 os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project")
473 os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1")
474
475 cfg, err := Init(cwdDir, false)
476
477 require.NoError(t, err)
478 assert.Contains(t, cfg.Providers, provider.InferenceProviderVertexAI)
479
480 vertexProvider := cfg.Providers[provider.InferenceProviderVertexAI]
481 assert.Equal(t, provider.TypeVertexAI, vertexProvider.ProviderType)
482 assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
483 assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
484}
485
486func TestEnvVars_VertexAIWithoutUseFlag(t *testing.T) {
487 reset()
488 testConfigDir = t.TempDir()
489 cwdDir := t.TempDir()
490
491 // Set Google Cloud vars but not the use flag
492 os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project")
493 os.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1")
494
495 cfg, err := Init(cwdDir, false)
496
497 require.NoError(t, err)
498 assert.NotContains(t, cfg.Providers, provider.InferenceProviderVertexAI)
499}
500
501func TestEnvVars_AWSBedrockWithAccessKeys(t *testing.T) {
502 reset()
503 testConfigDir = t.TempDir()
504 cwdDir := t.TempDir()
505
506 // Set AWS credentials
507 os.Setenv("AWS_ACCESS_KEY_ID", "test-access-key")
508 os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
509 os.Setenv("AWS_DEFAULT_REGION", "us-east-1")
510
511 cfg, err := Init(cwdDir, false)
512
513 require.NoError(t, err)
514 assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock)
515
516 bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
517 assert.Equal(t, provider.TypeBedrock, bedrockProvider.ProviderType)
518 assert.Equal(t, "us-east-1", bedrockProvider.ExtraParams["region"])
519}
520
521func TestEnvVars_AWSBedrockWithProfile(t *testing.T) {
522 reset()
523 testConfigDir = t.TempDir()
524 cwdDir := t.TempDir()
525
526 // Set AWS profile
527 os.Setenv("AWS_PROFILE", "test-profile")
528 os.Setenv("AWS_REGION", "eu-west-1")
529
530 cfg, err := Init(cwdDir, false)
531
532 require.NoError(t, err)
533 assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock)
534
535 bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
536 assert.Equal(t, "eu-west-1", bedrockProvider.ExtraParams["region"])
537}
538
539func TestEnvVars_AWSBedrockWithContainerCredentials(t *testing.T) {
540 reset()
541 testConfigDir = t.TempDir()
542 cwdDir := t.TempDir()
543
544 // Set AWS container credentials
545 os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test")
546 os.Setenv("AWS_DEFAULT_REGION", "ap-southeast-1")
547
548 cfg, err := Init(cwdDir, false)
549
550 require.NoError(t, err)
551 assert.Contains(t, cfg.Providers, provider.InferenceProviderBedrock)
552}
553
554func TestEnvVars_AWSBedrockRegionPriority(t *testing.T) {
555 reset()
556 testConfigDir = t.TempDir()
557 cwdDir := t.TempDir()
558
559 // Set both region variables - AWS_DEFAULT_REGION should take priority
560 os.Setenv("AWS_ACCESS_KEY_ID", "test-key")
561 os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
562 os.Setenv("AWS_DEFAULT_REGION", "us-west-2")
563 os.Setenv("AWS_REGION", "us-east-1")
564
565 cfg, err := Init(cwdDir, false)
566
567 require.NoError(t, err)
568 bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
569 assert.Equal(t, "us-west-2", bedrockProvider.ExtraParams["region"])
570}
571
572func TestEnvVars_AWSBedrockFallbackRegion(t *testing.T) {
573 reset()
574 testConfigDir = t.TempDir()
575 cwdDir := t.TempDir()
576
577 // Set only AWS_REGION (not AWS_DEFAULT_REGION)
578 os.Setenv("AWS_ACCESS_KEY_ID", "test-key")
579 os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
580 os.Setenv("AWS_REGION", "us-east-1")
581
582 cfg, err := Init(cwdDir, false)
583
584 require.NoError(t, err)
585 bedrockProvider := cfg.Providers[provider.InferenceProviderBedrock]
586 assert.Equal(t, "us-east-1", bedrockProvider.ExtraParams["region"])
587}
588
589func TestEnvVars_NoAWSCredentials(t *testing.T) {
590 reset()
591 testConfigDir = t.TempDir()
592 cwdDir := t.TempDir()
593
594 // Don't set any AWS credentials
595 cfg, err := Init(cwdDir, false)
596
597 require.NoError(t, err)
598 assert.NotContains(t, cfg.Providers, provider.InferenceProviderBedrock)
599}
600
601func TestEnvVars_CustomEnvironmentVariables(t *testing.T) {
602 reset()
603 testConfigDir = t.TempDir()
604 cwdDir := t.TempDir()
605
606 // Test that environment variables are properly resolved from provider definitions
607 // This test assumes the provider system uses $VARIABLE_NAME format
608 os.Setenv("ANTHROPIC_API_KEY", "resolved-anthropic-key")
609
610 cfg, err := Init(cwdDir, false)
611
612 require.NoError(t, err)
613 if len(cfg.Providers) > 0 {
614 // Verify that the environment variable was resolved
615 if anthropicProvider, exists := cfg.Providers[provider.InferenceProviderAnthropic]; exists {
616 assert.Equal(t, "resolved-anthropic-key", anthropicProvider.APIKey)
617 }
618 }
619}
620
621func TestEnvVars_CombinedEnvironmentVariables(t *testing.T) {
622 reset()
623 testConfigDir = t.TempDir()
624 cwdDir := t.TempDir()
625
626 // Set multiple types of environment variables
627 os.Setenv("ANTHROPIC_API_KEY", "test-anthropic")
628 os.Setenv("OPENAI_API_KEY", "test-openai")
629 os.Setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
630 os.Setenv("GOOGLE_CLOUD_PROJECT", "test-project")
631 os.Setenv("AWS_ACCESS_KEY_ID", "test-aws-key")
632 os.Setenv("AWS_SECRET_ACCESS_KEY", "test-aws-secret")
633 os.Setenv("AWS_DEFAULT_REGION", "us-west-1")
634
635 cfg, err := Init(cwdDir, false)
636
637 require.NoError(t, err)
638
639 // Should have API key providers + VertexAI + Bedrock
640 expectedProviders := []provider.InferenceProvider{
641 provider.InferenceProviderAnthropic,
642 provider.InferenceProviderOpenAI,
643 provider.InferenceProviderVertexAI,
644 provider.InferenceProviderBedrock,
645 }
646
647 for _, expectedProvider := range expectedProviders {
648 assert.Contains(t, cfg.Providers, expectedProvider)
649 }
650}
651
652func TestHasAWSCredentials_AccessKeys(t *testing.T) {
653 reset()
654
655 os.Setenv("AWS_ACCESS_KEY_ID", "test-key")
656 os.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
657
658 assert.True(t, hasAWSCredentials())
659}
660
661func TestHasAWSCredentials_Profile(t *testing.T) {
662 reset()
663
664 os.Setenv("AWS_PROFILE", "test-profile")
665
666 assert.True(t, hasAWSCredentials())
667}
668
669func TestHasAWSCredentials_DefaultProfile(t *testing.T) {
670 reset()
671
672 os.Setenv("AWS_DEFAULT_PROFILE", "default")
673
674 assert.True(t, hasAWSCredentials())
675}
676
677func TestHasAWSCredentials_Region(t *testing.T) {
678 reset()
679
680 os.Setenv("AWS_REGION", "us-east-1")
681
682 assert.True(t, hasAWSCredentials())
683}
684
685func TestHasAWSCredentials_ContainerCredentials(t *testing.T) {
686 reset()
687
688 os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials/test")
689
690 assert.True(t, hasAWSCredentials())
691}
692
693func TestHasAWSCredentials_NoCredentials(t *testing.T) {
694 reset()
695
696 assert.False(t, hasAWSCredentials())
697}
698
699// Provider Configuration Tests
700
701func TestProviderMerging_GlobalToBase(t *testing.T) {
702 reset()
703 testConfigDir = t.TempDir()
704 cwdDir := t.TempDir()
705
706 // Create global config with provider
707 globalConfig := Config{
708 Providers: map[provider.InferenceProvider]ProviderConfig{
709 provider.InferenceProviderOpenAI: {
710 ID: provider.InferenceProviderOpenAI,
711 APIKey: "global-openai-key",
712 ProviderType: provider.TypeOpenAI,
713 DefaultLargeModel: "gpt-4",
714 DefaultSmallModel: "gpt-3.5-turbo",
715 Models: []Model{
716 {
717 ID: "gpt-4",
718 Name: "GPT-4",
719 ContextWindow: 8192,
720 DefaultMaxTokens: 4096,
721 },
722 },
723 },
724 },
725 }
726
727 configPath := filepath.Join(testConfigDir, "crush.json")
728 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
729 data, err := json.Marshal(globalConfig)
730 require.NoError(t, err)
731 require.NoError(t, os.WriteFile(configPath, data, 0o644))
732
733 cfg, err := Init(cwdDir, false)
734
735 require.NoError(t, err)
736 assert.Len(t, cfg.Providers, 1)
737
738 openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
739 assert.Equal(t, "global-openai-key", openaiProvider.APIKey)
740 assert.Equal(t, "gpt-4", openaiProvider.DefaultLargeModel)
741 assert.Equal(t, "gpt-3.5-turbo", openaiProvider.DefaultSmallModel)
742 assert.Len(t, openaiProvider.Models, 1)
743}
744
745func TestProviderMerging_LocalToBase(t *testing.T) {
746 reset()
747 testConfigDir = t.TempDir()
748 cwdDir := t.TempDir()
749
750 // Create local config with provider
751 localConfig := Config{
752 Providers: map[provider.InferenceProvider]ProviderConfig{
753 provider.InferenceProviderAnthropic: {
754 ID: provider.InferenceProviderAnthropic,
755 APIKey: "local-anthropic-key",
756 ProviderType: provider.TypeAnthropic,
757 DefaultLargeModel: "claude-3-opus",
758 },
759 },
760 }
761
762 localConfigPath := filepath.Join(cwdDir, "crush.json")
763 data, err := json.Marshal(localConfig)
764 require.NoError(t, err)
765 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
766
767 cfg, err := Init(cwdDir, false)
768
769 require.NoError(t, err)
770 assert.Len(t, cfg.Providers, 1)
771
772 anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic]
773 assert.Equal(t, "local-anthropic-key", anthropicProvider.APIKey)
774 assert.Equal(t, "claude-3-opus", anthropicProvider.DefaultLargeModel)
775}
776
777func TestProviderMerging_ConflictingSettings(t *testing.T) {
778 reset()
779 testConfigDir = t.TempDir()
780 cwdDir := t.TempDir()
781
782 // Create global config
783 globalConfig := Config{
784 Providers: map[provider.InferenceProvider]ProviderConfig{
785 provider.InferenceProviderOpenAI: {
786 ID: provider.InferenceProviderOpenAI,
787 APIKey: "global-key",
788 ProviderType: provider.TypeOpenAI,
789 DefaultLargeModel: "gpt-4",
790 DefaultSmallModel: "gpt-3.5-turbo",
791 },
792 },
793 }
794
795 configPath := filepath.Join(testConfigDir, "crush.json")
796 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
797 data, err := json.Marshal(globalConfig)
798 require.NoError(t, err)
799 require.NoError(t, os.WriteFile(configPath, data, 0o644))
800
801 // Create local config that overrides
802 localConfig := Config{
803 Providers: map[provider.InferenceProvider]ProviderConfig{
804 provider.InferenceProviderOpenAI: {
805 APIKey: "local-key",
806 DefaultLargeModel: "gpt-4-turbo",
807 // Test disabled separately - don't disable here as it causes nil pointer
808 },
809 },
810 }
811
812 localConfigPath := filepath.Join(cwdDir, "crush.json")
813 data, err = json.Marshal(localConfig)
814 require.NoError(t, err)
815 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
816
817 cfg, err := Init(cwdDir, false)
818
819 require.NoError(t, err)
820
821 openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
822 // Local should override global
823 assert.Equal(t, "local-key", openaiProvider.APIKey)
824 assert.Equal(t, "gpt-4-turbo", openaiProvider.DefaultLargeModel)
825 assert.False(t, openaiProvider.Disabled) // Should not be disabled
826 // Global values should remain where not overridden
827 assert.Equal(t, "gpt-3.5-turbo", openaiProvider.DefaultSmallModel)
828}
829
830func TestProviderMerging_CustomVsKnownProviders(t *testing.T) {
831 reset()
832 testConfigDir = t.TempDir()
833 cwdDir := t.TempDir()
834
835 customProviderID := provider.InferenceProvider("custom-provider")
836
837 // Create config with both known and custom providers
838 globalConfig := Config{
839 Providers: map[provider.InferenceProvider]ProviderConfig{
840 // Known provider - some fields should not be overrideable
841 provider.InferenceProviderOpenAI: {
842 ID: provider.InferenceProviderOpenAI,
843 APIKey: "openai-key",
844 BaseURL: "should-not-override",
845 ProviderType: provider.TypeAnthropic, // Should not override
846 },
847 // Custom provider - all fields should be configurable
848 customProviderID: {
849 ID: customProviderID,
850 APIKey: "custom-key",
851 BaseURL: "https://custom.api.com",
852 ProviderType: provider.TypeOpenAI,
853 },
854 },
855 }
856
857 localConfig := Config{
858 Providers: map[provider.InferenceProvider]ProviderConfig{
859 provider.InferenceProviderOpenAI: {
860 BaseURL: "https://should-not-change.com",
861 ProviderType: provider.TypeGemini, // Should not change
862 },
863 customProviderID: {
864 BaseURL: "https://updated-custom.api.com",
865 ProviderType: provider.TypeOpenAI,
866 },
867 },
868 }
869
870 configPath := filepath.Join(testConfigDir, "crush.json")
871 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
872 data, err := json.Marshal(globalConfig)
873 require.NoError(t, err)
874 require.NoError(t, os.WriteFile(configPath, data, 0o644))
875
876 localConfigPath := filepath.Join(cwdDir, "crush.json")
877 data, err = json.Marshal(localConfig)
878 require.NoError(t, err)
879 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
880
881 cfg, err := Init(cwdDir, false)
882
883 require.NoError(t, err)
884
885 // Known provider should not have BaseURL/ProviderType overridden
886 openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
887 assert.NotEqual(t, "https://should-not-change.com", openaiProvider.BaseURL)
888 assert.NotEqual(t, provider.TypeGemini, openaiProvider.ProviderType)
889
890 // Custom provider should have all fields configurable
891 customProvider := cfg.Providers[customProviderID]
892 assert.Equal(t, "custom-key", customProvider.APIKey) // Should preserve from global
893 assert.Equal(t, "https://updated-custom.api.com", customProvider.BaseURL)
894 assert.Equal(t, provider.TypeOpenAI, customProvider.ProviderType)
895}
896
897func TestProviderValidation_CustomProviderMissingBaseURL(t *testing.T) {
898 reset()
899 testConfigDir = t.TempDir()
900 cwdDir := t.TempDir()
901
902 customProviderID := provider.InferenceProvider("custom-provider")
903
904 // Create config with custom provider missing BaseURL
905 globalConfig := Config{
906 Providers: map[provider.InferenceProvider]ProviderConfig{
907 customProviderID: {
908 ID: customProviderID,
909 APIKey: "custom-key",
910 ProviderType: provider.TypeOpenAI,
911 // Missing BaseURL
912 },
913 },
914 }
915
916 configPath := filepath.Join(testConfigDir, "crush.json")
917 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
918 data, err := json.Marshal(globalConfig)
919 require.NoError(t, err)
920 require.NoError(t, os.WriteFile(configPath, data, 0o644))
921
922 cfg, err := Init(cwdDir, false)
923
924 require.NoError(t, err)
925 // Provider should be filtered out due to validation failure
926 assert.NotContains(t, cfg.Providers, customProviderID)
927}
928
929func TestProviderValidation_CustomProviderMissingAPIKey(t *testing.T) {
930 reset()
931 testConfigDir = t.TempDir()
932 cwdDir := t.TempDir()
933
934 customProviderID := provider.InferenceProvider("custom-provider")
935
936 globalConfig := Config{
937 Providers: map[provider.InferenceProvider]ProviderConfig{
938 customProviderID: {
939 ID: customProviderID,
940 BaseURL: "https://custom.api.com",
941 ProviderType: provider.TypeOpenAI,
942 // Missing APIKey
943 },
944 },
945 }
946
947 configPath := filepath.Join(testConfigDir, "crush.json")
948 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
949 data, err := json.Marshal(globalConfig)
950 require.NoError(t, err)
951 require.NoError(t, os.WriteFile(configPath, data, 0o644))
952
953 cfg, err := Init(cwdDir, false)
954
955 require.NoError(t, err)
956 assert.NotContains(t, cfg.Providers, customProviderID)
957}
958
959func TestProviderValidation_CustomProviderInvalidType(t *testing.T) {
960 reset()
961 testConfigDir = t.TempDir()
962 cwdDir := t.TempDir()
963
964 customProviderID := provider.InferenceProvider("custom-provider")
965
966 globalConfig := Config{
967 Providers: map[provider.InferenceProvider]ProviderConfig{
968 customProviderID: {
969 ID: customProviderID,
970 APIKey: "custom-key",
971 BaseURL: "https://custom.api.com",
972 ProviderType: provider.Type("invalid-type"),
973 },
974 },
975 }
976
977 configPath := filepath.Join(testConfigDir, "crush.json")
978 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
979 data, err := json.Marshal(globalConfig)
980 require.NoError(t, err)
981 require.NoError(t, os.WriteFile(configPath, data, 0o644))
982
983 cfg, err := Init(cwdDir, false)
984
985 require.NoError(t, err)
986 assert.NotContains(t, cfg.Providers, customProviderID)
987}
988
989func TestProviderValidation_KnownProviderValid(t *testing.T) {
990 reset()
991 testConfigDir = t.TempDir()
992 cwdDir := t.TempDir()
993
994 globalConfig := Config{
995 Providers: map[provider.InferenceProvider]ProviderConfig{
996 provider.InferenceProviderOpenAI: {
997 ID: provider.InferenceProviderOpenAI,
998 APIKey: "openai-key",
999 ProviderType: provider.TypeOpenAI,
1000 // BaseURL not required for known providers
1001 },
1002 },
1003 }
1004
1005 configPath := filepath.Join(testConfigDir, "crush.json")
1006 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1007 data, err := json.Marshal(globalConfig)
1008 require.NoError(t, err)
1009 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1010
1011 cfg, err := Init(cwdDir, false)
1012
1013 require.NoError(t, err)
1014 assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
1015}
1016
1017func TestProviderValidation_DisabledProvider(t *testing.T) {
1018 reset()
1019 testConfigDir = t.TempDir()
1020 cwdDir := t.TempDir()
1021
1022 globalConfig := Config{
1023 Providers: map[provider.InferenceProvider]ProviderConfig{
1024 provider.InferenceProviderOpenAI: {
1025 ID: provider.InferenceProviderOpenAI,
1026 APIKey: "openai-key",
1027 ProviderType: provider.TypeOpenAI,
1028 Disabled: true,
1029 },
1030 },
1031 }
1032
1033 configPath := filepath.Join(testConfigDir, "crush.json")
1034 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1035 data, err := json.Marshal(globalConfig)
1036 require.NoError(t, err)
1037 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1038
1039 cfg, err := Init(cwdDir, false)
1040
1041 require.NoError(t, err)
1042 // Disabled providers should still be in the config but marked as disabled
1043 assert.Contains(t, cfg.Providers, provider.InferenceProviderOpenAI)
1044 assert.True(t, cfg.Providers[provider.InferenceProviderOpenAI].Disabled)
1045}
1046
1047func TestProviderModels_AddingNewModels(t *testing.T) {
1048 reset()
1049 testConfigDir = t.TempDir()
1050 cwdDir := t.TempDir()
1051
1052 globalConfig := Config{
1053 Providers: map[provider.InferenceProvider]ProviderConfig{
1054 provider.InferenceProviderOpenAI: {
1055 ID: provider.InferenceProviderOpenAI,
1056 APIKey: "openai-key",
1057 ProviderType: provider.TypeOpenAI,
1058 Models: []Model{
1059 {
1060 ID: "gpt-4",
1061 Name: "GPT-4",
1062 ContextWindow: 8192,
1063 DefaultMaxTokens: 4096,
1064 },
1065 },
1066 },
1067 },
1068 }
1069
1070 localConfig := Config{
1071 Providers: map[provider.InferenceProvider]ProviderConfig{
1072 provider.InferenceProviderOpenAI: {
1073 Models: []Model{
1074 {
1075 ID: "gpt-4-turbo",
1076 Name: "GPT-4 Turbo",
1077 ContextWindow: 128000,
1078 DefaultMaxTokens: 4096,
1079 },
1080 },
1081 },
1082 },
1083 }
1084
1085 configPath := filepath.Join(testConfigDir, "crush.json")
1086 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1087 data, err := json.Marshal(globalConfig)
1088 require.NoError(t, err)
1089 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1090
1091 localConfigPath := filepath.Join(cwdDir, "crush.json")
1092 data, err = json.Marshal(localConfig)
1093 require.NoError(t, err)
1094 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
1095
1096 cfg, err := Init(cwdDir, false)
1097
1098 require.NoError(t, err)
1099
1100 openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
1101 assert.Len(t, openaiProvider.Models, 2) // Should have both models
1102
1103 modelIDs := make([]string, len(openaiProvider.Models))
1104 for i, model := range openaiProvider.Models {
1105 modelIDs[i] = model.ID
1106 }
1107 assert.Contains(t, modelIDs, "gpt-4")
1108 assert.Contains(t, modelIDs, "gpt-4-turbo")
1109}
1110
1111func TestProviderModels_DuplicateModelHandling(t *testing.T) {
1112 reset()
1113 testConfigDir = t.TempDir()
1114 cwdDir := t.TempDir()
1115
1116 globalConfig := Config{
1117 Providers: map[provider.InferenceProvider]ProviderConfig{
1118 provider.InferenceProviderOpenAI: {
1119 ID: provider.InferenceProviderOpenAI,
1120 APIKey: "openai-key",
1121 ProviderType: provider.TypeOpenAI,
1122 Models: []Model{
1123 {
1124 ID: "gpt-4",
1125 Name: "GPT-4",
1126 ContextWindow: 8192,
1127 DefaultMaxTokens: 4096,
1128 },
1129 },
1130 },
1131 },
1132 }
1133
1134 localConfig := Config{
1135 Providers: map[provider.InferenceProvider]ProviderConfig{
1136 provider.InferenceProviderOpenAI: {
1137 Models: []Model{
1138 {
1139 ID: "gpt-4", // Same ID as global
1140 Name: "GPT-4 Updated",
1141 ContextWindow: 16384,
1142 DefaultMaxTokens: 8192,
1143 },
1144 },
1145 },
1146 },
1147 }
1148
1149 configPath := filepath.Join(testConfigDir, "crush.json")
1150 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1151 data, err := json.Marshal(globalConfig)
1152 require.NoError(t, err)
1153 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1154
1155 localConfigPath := filepath.Join(cwdDir, "crush.json")
1156 data, err = json.Marshal(localConfig)
1157 require.NoError(t, err)
1158 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
1159
1160 cfg, err := Init(cwdDir, false)
1161
1162 require.NoError(t, err)
1163
1164 openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
1165 assert.Len(t, openaiProvider.Models, 1) // Should not duplicate
1166
1167 // Should keep the original model (global config)
1168 model := openaiProvider.Models[0]
1169 assert.Equal(t, "gpt-4", model.ID)
1170 assert.Equal(t, "GPT-4", model.Name) // Original name
1171 assert.Equal(t, int64(8192), model.ContextWindow) // Original context window
1172}
1173
1174func TestProviderModels_ModelCostAndCapabilities(t *testing.T) {
1175 reset()
1176 testConfigDir = t.TempDir()
1177 cwdDir := t.TempDir()
1178
1179 globalConfig := Config{
1180 Providers: map[provider.InferenceProvider]ProviderConfig{
1181 provider.InferenceProviderOpenAI: {
1182 ID: provider.InferenceProviderOpenAI,
1183 APIKey: "openai-key",
1184 ProviderType: provider.TypeOpenAI,
1185 Models: []Model{
1186 {
1187 ID: "gpt-4",
1188 Name: "GPT-4",
1189 CostPer1MIn: 30.0,
1190 CostPer1MOut: 60.0,
1191 CostPer1MInCached: 15.0,
1192 CostPer1MOutCached: 30.0,
1193 ContextWindow: 8192,
1194 DefaultMaxTokens: 4096,
1195 CanReason: true,
1196 ReasoningEffort: "medium",
1197 SupportsImages: true,
1198 },
1199 },
1200 },
1201 },
1202 }
1203
1204 configPath := filepath.Join(testConfigDir, "crush.json")
1205 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1206 data, err := json.Marshal(globalConfig)
1207 require.NoError(t, err)
1208 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1209
1210 cfg, err := Init(cwdDir, false)
1211
1212 require.NoError(t, err)
1213
1214 openaiProvider := cfg.Providers[provider.InferenceProviderOpenAI]
1215 require.Len(t, openaiProvider.Models, 1)
1216
1217 model := openaiProvider.Models[0]
1218 assert.Equal(t, 30.0, model.CostPer1MIn)
1219 assert.Equal(t, 60.0, model.CostPer1MOut)
1220 assert.Equal(t, 15.0, model.CostPer1MInCached)
1221 assert.Equal(t, 30.0, model.CostPer1MOutCached)
1222 assert.True(t, model.CanReason)
1223 assert.Equal(t, "medium", model.ReasoningEffort)
1224 assert.True(t, model.SupportsImages)
1225}
1226
1227// Agent Configuration Tests
1228
1229func TestDefaultAgents_CoderAgent(t *testing.T) {
1230 reset()
1231 testConfigDir = t.TempDir()
1232 cwdDir := t.TempDir()
1233
1234 // Set up a provider so we can test agent configuration
1235 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1236
1237 cfg, err := Init(cwdDir, false)
1238
1239 require.NoError(t, err)
1240 assert.Contains(t, cfg.Agents, AgentCoder)
1241
1242 coderAgent := cfg.Agents[AgentCoder]
1243 assert.Equal(t, AgentCoder, coderAgent.ID)
1244 assert.Equal(t, "Coder", coderAgent.Name)
1245 assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description)
1246 assert.Equal(t, LargeModel, coderAgent.Model)
1247 assert.False(t, coderAgent.Disabled)
1248 assert.Equal(t, cfg.Options.ContextPaths, coderAgent.ContextPaths)
1249 // Coder agent should have all tools available (nil means all tools)
1250 assert.Nil(t, coderAgent.AllowedTools)
1251}
1252
1253func TestDefaultAgents_TaskAgent(t *testing.T) {
1254 reset()
1255 testConfigDir = t.TempDir()
1256 cwdDir := t.TempDir()
1257
1258 // Set up a provider so we can test agent configuration
1259 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1260
1261 cfg, err := Init(cwdDir, false)
1262
1263 require.NoError(t, err)
1264 assert.Contains(t, cfg.Agents, AgentTask)
1265
1266 taskAgent := cfg.Agents[AgentTask]
1267 assert.Equal(t, AgentTask, taskAgent.ID)
1268 assert.Equal(t, "Task", taskAgent.Name)
1269 assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description)
1270 assert.Equal(t, LargeModel, taskAgent.Model)
1271 assert.False(t, taskAgent.Disabled)
1272 assert.Equal(t, cfg.Options.ContextPaths, taskAgent.ContextPaths)
1273
1274 // Task agent should have restricted tools
1275 expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
1276 assert.Equal(t, expectedTools, taskAgent.AllowedTools)
1277
1278 // Task agent should have no MCPs or LSPs by default
1279 assert.Equal(t, map[string][]string{}, taskAgent.AllowedMCP)
1280 assert.Equal(t, []string{}, taskAgent.AllowedLSP)
1281}
1282
1283func TestAgentMerging_CustomAgent(t *testing.T) {
1284 reset()
1285 testConfigDir = t.TempDir()
1286 cwdDir := t.TempDir()
1287
1288 // Set up a provider
1289 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1290
1291 // Create config with custom agent
1292 globalConfig := Config{
1293 Agents: map[AgentID]Agent{
1294 AgentID("custom-agent"): {
1295 ID: AgentID("custom-agent"),
1296 Name: "Custom Agent",
1297 Description: "A custom agent for testing",
1298 Model: SmallModel,
1299 AllowedTools: []string{"glob", "grep"},
1300 AllowedMCP: map[string][]string{"mcp1": {"tool1", "tool2"}},
1301 AllowedLSP: []string{"typescript", "go"},
1302 ContextPaths: []string{"custom-context.md"},
1303 },
1304 },
1305 }
1306
1307 configPath := filepath.Join(testConfigDir, "crush.json")
1308 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1309 data, err := json.Marshal(globalConfig)
1310 require.NoError(t, err)
1311 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1312
1313 cfg, err := Init(cwdDir, false)
1314
1315 require.NoError(t, err)
1316
1317 // Should have default agents plus custom agent
1318 assert.Contains(t, cfg.Agents, AgentCoder)
1319 assert.Contains(t, cfg.Agents, AgentTask)
1320 assert.Contains(t, cfg.Agents, AgentID("custom-agent"))
1321
1322 customAgent := cfg.Agents[AgentID("custom-agent")]
1323 assert.Equal(t, "Custom Agent", customAgent.Name)
1324 assert.Equal(t, "A custom agent for testing", customAgent.Description)
1325 assert.Equal(t, SmallModel, customAgent.Model)
1326 assert.Equal(t, []string{"glob", "grep"}, customAgent.AllowedTools)
1327 assert.Equal(t, map[string][]string{"mcp1": {"tool1", "tool2"}}, customAgent.AllowedMCP)
1328 assert.Equal(t, []string{"typescript", "go"}, customAgent.AllowedLSP)
1329 // Context paths should be additive (default + custom)
1330 expectedContextPaths := append(defaultContextPaths, "custom-context.md")
1331 assert.Equal(t, expectedContextPaths, customAgent.ContextPaths)
1332}
1333
1334func TestAgentMerging_ModifyDefaultCoderAgent(t *testing.T) {
1335 reset()
1336 testConfigDir = t.TempDir()
1337 cwdDir := t.TempDir()
1338
1339 // Set up a provider
1340 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1341
1342 // Create config that modifies the default coder agent
1343 globalConfig := Config{
1344 Agents: map[AgentID]Agent{
1345 AgentCoder: {
1346 Model: SmallModel, // Change from default LargeModel
1347 AllowedMCP: map[string][]string{"mcp1": {"tool1"}},
1348 AllowedLSP: []string{"typescript"},
1349 ContextPaths: []string{"coder-specific.md"}, // Should be additive
1350 },
1351 },
1352 }
1353
1354 configPath := filepath.Join(testConfigDir, "crush.json")
1355 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1356 data, err := json.Marshal(globalConfig)
1357 require.NoError(t, err)
1358 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1359
1360 cfg, err := Init(cwdDir, false)
1361
1362 require.NoError(t, err)
1363
1364 coderAgent := cfg.Agents[AgentCoder]
1365 // Should preserve default values for unspecified fields
1366 assert.Equal(t, AgentCoder, coderAgent.ID)
1367 assert.Equal(t, "Coder", coderAgent.Name)
1368 assert.Equal(t, "An agent that helps with executing coding tasks.", coderAgent.Description)
1369
1370 // Context paths should be additive (default + custom)
1371 expectedContextPaths := append(cfg.Options.ContextPaths, "coder-specific.md")
1372 assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths)
1373
1374 // Should update specified fields
1375 assert.Equal(t, SmallModel, coderAgent.Model)
1376 assert.Equal(t, map[string][]string{"mcp1": {"tool1"}}, coderAgent.AllowedMCP)
1377 assert.Equal(t, []string{"typescript"}, coderAgent.AllowedLSP)
1378}
1379
1380func TestAgentMerging_ModifyDefaultTaskAgent(t *testing.T) {
1381 reset()
1382 testConfigDir = t.TempDir()
1383 cwdDir := t.TempDir()
1384
1385 // Set up a provider
1386 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1387
1388 // Create config that modifies the default task agent
1389 // Note: Only model, MCP, and LSP should be configurable for known agents
1390 globalConfig := Config{
1391 Agents: map[AgentID]Agent{
1392 AgentTask: {
1393 Model: SmallModel, // Should be updated
1394 AllowedMCP: map[string][]string{"search-mcp": nil}, // Should be updated
1395 AllowedLSP: []string{"python"}, // Should be updated
1396 // These should be ignored for known agents:
1397 Name: "Search Agent", // Should be ignored
1398 Description: "Custom search agent", // Should be ignored
1399 Disabled: true, // Should be ignored
1400 AllowedTools: []string{"glob", "grep", "view"}, // Should be ignored
1401 },
1402 },
1403 }
1404
1405 configPath := filepath.Join(testConfigDir, "crush.json")
1406 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1407 data, err := json.Marshal(globalConfig)
1408 require.NoError(t, err)
1409 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1410
1411 cfg, err := Init(cwdDir, false)
1412
1413 require.NoError(t, err)
1414
1415 taskAgent := cfg.Agents[AgentTask]
1416 // Should preserve default values for protected fields
1417 assert.Equal(t, "Task", taskAgent.Name) // Should remain default
1418 assert.Equal(t, "An agent that helps with searching for context and finding implementation details.", taskAgent.Description) // Should remain default
1419 assert.False(t, taskAgent.Disabled) // Should remain default
1420 assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools) // Should remain default
1421
1422 // Should update configurable fields
1423 assert.Equal(t, SmallModel, taskAgent.Model)
1424 assert.Equal(t, map[string][]string{"search-mcp": nil}, taskAgent.AllowedMCP)
1425 assert.Equal(t, []string{"python"}, taskAgent.AllowedLSP)
1426}
1427
1428func TestAgentMerging_LocalOverridesGlobal(t *testing.T) {
1429 reset()
1430 testConfigDir = t.TempDir()
1431 cwdDir := t.TempDir()
1432
1433 // Set up a provider
1434 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1435
1436 // Create global config with custom agent
1437 globalConfig := Config{
1438 Agents: map[AgentID]Agent{
1439 AgentID("test-agent"): {
1440 ID: AgentID("test-agent"),
1441 Name: "Global Agent",
1442 Description: "Global description",
1443 Model: LargeModel,
1444 AllowedTools: []string{"glob"},
1445 },
1446 },
1447 }
1448
1449 configPath := filepath.Join(testConfigDir, "crush.json")
1450 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1451 data, err := json.Marshal(globalConfig)
1452 require.NoError(t, err)
1453 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1454
1455 // Create local config that overrides
1456 localConfig := Config{
1457 Agents: map[AgentID]Agent{
1458 AgentID("test-agent"): {
1459 Name: "Local Agent",
1460 Description: "Local description",
1461 Model: SmallModel,
1462 Disabled: true,
1463 AllowedTools: []string{"grep", "view"},
1464 AllowedMCP: map[string][]string{"local-mcp": {"tool1"}},
1465 },
1466 },
1467 }
1468
1469 localConfigPath := filepath.Join(cwdDir, "crush.json")
1470 data, err = json.Marshal(localConfig)
1471 require.NoError(t, err)
1472 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
1473
1474 cfg, err := Init(cwdDir, false)
1475
1476 require.NoError(t, err)
1477
1478 testAgent := cfg.Agents[AgentID("test-agent")]
1479 // Local should override global
1480 assert.Equal(t, "Local Agent", testAgent.Name)
1481 assert.Equal(t, "Local description", testAgent.Description)
1482 assert.Equal(t, SmallModel, testAgent.Model)
1483 assert.True(t, testAgent.Disabled)
1484 assert.Equal(t, []string{"grep", "view"}, testAgent.AllowedTools)
1485 assert.Equal(t, map[string][]string{"local-mcp": {"tool1"}}, testAgent.AllowedMCP)
1486}
1487
1488func TestAgentModelTypeAssignment(t *testing.T) {
1489 reset()
1490 testConfigDir = t.TempDir()
1491 cwdDir := t.TempDir()
1492
1493 // Set up a provider
1494 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1495
1496 // Create config with agents using different model types
1497 globalConfig := Config{
1498 Agents: map[AgentID]Agent{
1499 AgentID("large-agent"): {
1500 ID: AgentID("large-agent"),
1501 Name: "Large Model Agent",
1502 Model: LargeModel,
1503 },
1504 AgentID("small-agent"): {
1505 ID: AgentID("small-agent"),
1506 Name: "Small Model Agent",
1507 Model: SmallModel,
1508 },
1509 AgentID("default-agent"): {
1510 ID: AgentID("default-agent"),
1511 Name: "Default Model Agent",
1512 // No model specified - should default to LargeModel
1513 },
1514 },
1515 }
1516
1517 configPath := filepath.Join(testConfigDir, "crush.json")
1518 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1519 data, err := json.Marshal(globalConfig)
1520 require.NoError(t, err)
1521 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1522
1523 cfg, err := Init(cwdDir, false)
1524
1525 require.NoError(t, err)
1526
1527 assert.Equal(t, LargeModel, cfg.Agents[AgentID("large-agent")].Model)
1528 assert.Equal(t, SmallModel, cfg.Agents[AgentID("small-agent")].Model)
1529 assert.Equal(t, LargeModel, cfg.Agents[AgentID("default-agent")].Model) // Should default to LargeModel
1530}
1531
1532func TestAgentContextPathOverrides(t *testing.T) {
1533 reset()
1534 testConfigDir = t.TempDir()
1535 cwdDir := t.TempDir()
1536
1537 // Set up a provider
1538 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1539
1540 // Create config with custom context paths
1541 globalConfig := Config{
1542 Options: Options{
1543 ContextPaths: []string{"global-context.md", "shared-context.md"},
1544 },
1545 Agents: map[AgentID]Agent{
1546 AgentID("custom-context-agent"): {
1547 ID: AgentID("custom-context-agent"),
1548 Name: "Custom Context Agent",
1549 ContextPaths: []string{"agent-specific.md", "custom.md"},
1550 },
1551 AgentID("default-context-agent"): {
1552 ID: AgentID("default-context-agent"),
1553 Name: "Default Context Agent",
1554 // No ContextPaths specified - should use global
1555 },
1556 },
1557 }
1558
1559 configPath := filepath.Join(testConfigDir, "crush.json")
1560 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1561 data, err := json.Marshal(globalConfig)
1562 require.NoError(t, err)
1563 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1564
1565 cfg, err := Init(cwdDir, false)
1566
1567 require.NoError(t, err)
1568
1569 // Agent with custom context paths should have default + global + custom paths (additive)
1570 customAgent := cfg.Agents[AgentID("custom-context-agent")]
1571 expectedCustomPaths := append(defaultContextPaths, "global-context.md", "shared-context.md", "agent-specific.md", "custom.md")
1572 assert.Equal(t, expectedCustomPaths, customAgent.ContextPaths)
1573
1574 // Agent without custom context paths should use global + defaults
1575 defaultAgent := cfg.Agents[AgentID("default-context-agent")]
1576 expectedContextPaths := append(defaultContextPaths, "global-context.md", "shared-context.md")
1577 assert.Equal(t, expectedContextPaths, defaultAgent.ContextPaths)
1578
1579 // Default agents should also use the merged context paths
1580 coderAgent := cfg.Agents[AgentCoder]
1581 assert.Equal(t, expectedContextPaths, coderAgent.ContextPaths)
1582}
1583
1584// Options and Settings Tests
1585
1586func TestOptionsMerging_ContextPaths(t *testing.T) {
1587 reset()
1588 testConfigDir = t.TempDir()
1589 cwdDir := t.TempDir()
1590
1591 // Set up a provider
1592 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1593
1594 // Create global config with context paths
1595 globalConfig := Config{
1596 Options: Options{
1597 ContextPaths: []string{"global1.md", "global2.md"},
1598 },
1599 }
1600
1601 configPath := filepath.Join(testConfigDir, "crush.json")
1602 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1603 data, err := json.Marshal(globalConfig)
1604 require.NoError(t, err)
1605 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1606
1607 // Create local config with additional context paths
1608 localConfig := Config{
1609 Options: Options{
1610 ContextPaths: []string{"local1.md", "local2.md"},
1611 },
1612 }
1613
1614 localConfigPath := filepath.Join(cwdDir, "crush.json")
1615 data, err = json.Marshal(localConfig)
1616 require.NoError(t, err)
1617 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
1618
1619 cfg, err := Init(cwdDir, false)
1620
1621 require.NoError(t, err)
1622
1623 // Context paths should be merged: defaults + global + local
1624 expectedContextPaths := append(defaultContextPaths, "global1.md", "global2.md", "local1.md", "local2.md")
1625 assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths)
1626}
1627
1628func TestOptionsMerging_TUIOptions(t *testing.T) {
1629 reset()
1630 testConfigDir = t.TempDir()
1631 cwdDir := t.TempDir()
1632
1633 // Set up a provider
1634 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1635
1636 // Create global config with TUI options
1637 globalConfig := Config{
1638 Options: Options{
1639 TUI: TUIOptions{
1640 CompactMode: false, // Default value
1641 },
1642 },
1643 }
1644
1645 configPath := filepath.Join(testConfigDir, "crush.json")
1646 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1647 data, err := json.Marshal(globalConfig)
1648 require.NoError(t, err)
1649 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1650
1651 // Create local config that enables compact mode
1652 localConfig := Config{
1653 Options: Options{
1654 TUI: TUIOptions{
1655 CompactMode: true,
1656 },
1657 },
1658 }
1659
1660 localConfigPath := filepath.Join(cwdDir, "crush.json")
1661 data, err = json.Marshal(localConfig)
1662 require.NoError(t, err)
1663 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
1664
1665 cfg, err := Init(cwdDir, false)
1666
1667 require.NoError(t, err)
1668
1669 // Local config should override global
1670 assert.True(t, cfg.Options.TUI.CompactMode)
1671}
1672
1673func TestOptionsMerging_DebugFlags(t *testing.T) {
1674 reset()
1675 testConfigDir = t.TempDir()
1676 cwdDir := t.TempDir()
1677
1678 // Set up a provider
1679 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1680
1681 // Create global config with debug flags
1682 globalConfig := Config{
1683 Options: Options{
1684 Debug: false,
1685 DebugLSP: false,
1686 DisableAutoSummarize: false,
1687 },
1688 }
1689
1690 configPath := filepath.Join(testConfigDir, "crush.json")
1691 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1692 data, err := json.Marshal(globalConfig)
1693 require.NoError(t, err)
1694 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1695
1696 // Create local config that enables debug flags
1697 localConfig := Config{
1698 Options: Options{
1699 DebugLSP: true,
1700 DisableAutoSummarize: true,
1701 },
1702 }
1703
1704 localConfigPath := filepath.Join(cwdDir, "crush.json")
1705 data, err = json.Marshal(localConfig)
1706 require.NoError(t, err)
1707 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
1708
1709 cfg, err := Init(cwdDir, false)
1710
1711 require.NoError(t, err)
1712
1713 // Local config should override global for boolean flags
1714 assert.False(t, cfg.Options.Debug) // Not set in local, remains global value
1715 assert.True(t, cfg.Options.DebugLSP) // Set to true in local
1716 assert.True(t, cfg.Options.DisableAutoSummarize) // Set to true in local
1717}
1718
1719func TestOptionsMerging_DataDirectory(t *testing.T) {
1720 reset()
1721 testConfigDir = t.TempDir()
1722 cwdDir := t.TempDir()
1723
1724 // Set up a provider
1725 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1726
1727 // Create global config with custom data directory
1728 globalConfig := Config{
1729 Options: Options{
1730 DataDirectory: "global-data",
1731 },
1732 }
1733
1734 configPath := filepath.Join(testConfigDir, "crush.json")
1735 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1736 data, err := json.Marshal(globalConfig)
1737 require.NoError(t, err)
1738 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1739
1740 // Create local config with different data directory
1741 localConfig := Config{
1742 Options: Options{
1743 DataDirectory: "local-data",
1744 },
1745 }
1746
1747 localConfigPath := filepath.Join(cwdDir, "crush.json")
1748 data, err = json.Marshal(localConfig)
1749 require.NoError(t, err)
1750 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
1751
1752 cfg, err := Init(cwdDir, false)
1753
1754 require.NoError(t, err)
1755
1756 // Local config should override global
1757 assert.Equal(t, "local-data", cfg.Options.DataDirectory)
1758}
1759
1760func TestOptionsMerging_DefaultValues(t *testing.T) {
1761 reset()
1762 testConfigDir = t.TempDir()
1763 cwdDir := t.TempDir()
1764
1765 // Set up a provider
1766 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1767
1768 // No config files - should use defaults
1769 cfg, err := Init(cwdDir, false)
1770
1771 require.NoError(t, err)
1772
1773 // Should have default values
1774 assert.Equal(t, defaultDataDirectory, cfg.Options.DataDirectory)
1775 assert.Equal(t, defaultContextPaths, cfg.Options.ContextPaths)
1776 assert.False(t, cfg.Options.TUI.CompactMode)
1777 assert.False(t, cfg.Options.Debug)
1778 assert.False(t, cfg.Options.DebugLSP)
1779 assert.False(t, cfg.Options.DisableAutoSummarize)
1780}
1781
1782func TestOptionsMerging_DebugFlagFromInit(t *testing.T) {
1783 reset()
1784 testConfigDir = t.TempDir()
1785 cwdDir := t.TempDir()
1786
1787 // Set up a provider
1788 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1789
1790 // Create config with debug false
1791 globalConfig := Config{
1792 Options: Options{
1793 Debug: false,
1794 },
1795 }
1796
1797 configPath := filepath.Join(testConfigDir, "crush.json")
1798 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1799 data, err := json.Marshal(globalConfig)
1800 require.NoError(t, err)
1801 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1802
1803 // Init with debug=true should override config
1804 cfg, err := Init(cwdDir, true)
1805
1806 require.NoError(t, err)
1807
1808 // Debug flag from Init should take precedence
1809 assert.True(t, cfg.Options.Debug)
1810}
1811
1812func TestOptionsMerging_ComplexScenario(t *testing.T) {
1813 reset()
1814 testConfigDir = t.TempDir()
1815 cwdDir := t.TempDir()
1816
1817 // Set up a provider
1818 os.Setenv("ANTHROPIC_API_KEY", "test-key")
1819
1820 // Create global config with various options
1821 globalConfig := Config{
1822 Options: Options{
1823 ContextPaths: []string{"global-context.md"},
1824 DataDirectory: "global-data",
1825 Debug: false,
1826 DebugLSP: false,
1827 DisableAutoSummarize: false,
1828 TUI: TUIOptions{
1829 CompactMode: false,
1830 },
1831 },
1832 }
1833
1834 configPath := filepath.Join(testConfigDir, "crush.json")
1835 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1836 data, err := json.Marshal(globalConfig)
1837 require.NoError(t, err)
1838 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1839
1840 // Create local config that partially overrides
1841 localConfig := Config{
1842 Options: Options{
1843 ContextPaths: []string{"local-context.md"},
1844 DebugLSP: true, // Override
1845 DisableAutoSummarize: true, // Override
1846 TUI: TUIOptions{
1847 CompactMode: true, // Override
1848 },
1849 // DataDirectory and Debug not specified - should keep global values
1850 },
1851 }
1852
1853 localConfigPath := filepath.Join(cwdDir, "crush.json")
1854 data, err = json.Marshal(localConfig)
1855 require.NoError(t, err)
1856 require.NoError(t, os.WriteFile(localConfigPath, data, 0o644))
1857
1858 cfg, err := Init(cwdDir, false)
1859
1860 require.NoError(t, err)
1861
1862 // Check merged results
1863 expectedContextPaths := append(defaultContextPaths, "global-context.md", "local-context.md")
1864 assert.Equal(t, expectedContextPaths, cfg.Options.ContextPaths)
1865 assert.Equal(t, "global-data", cfg.Options.DataDirectory) // From global
1866 assert.False(t, cfg.Options.Debug) // From global
1867 assert.True(t, cfg.Options.DebugLSP) // From local
1868 assert.True(t, cfg.Options.DisableAutoSummarize) // From local
1869 assert.True(t, cfg.Options.TUI.CompactMode) // From local
1870}
1871
1872// Model Selection Tests
1873
1874func TestModelSelection_PreferredModelSelection(t *testing.T) {
1875 reset()
1876 testConfigDir = t.TempDir()
1877 cwdDir := t.TempDir()
1878
1879 // Set up multiple providers to test selection logic
1880 os.Setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
1881 os.Setenv("OPENAI_API_KEY", "test-openai-key")
1882
1883 cfg, err := Init(cwdDir, false)
1884
1885 require.NoError(t, err)
1886 require.Len(t, cfg.Providers, 2)
1887
1888 // Should have preferred models set
1889 assert.NotEmpty(t, cfg.Models.Large.ModelID)
1890 assert.NotEmpty(t, cfg.Models.Large.Provider)
1891 assert.NotEmpty(t, cfg.Models.Small.ModelID)
1892 assert.NotEmpty(t, cfg.Models.Small.Provider)
1893
1894 // Both should use the same provider (first available)
1895 assert.Equal(t, cfg.Models.Large.Provider, cfg.Models.Small.Provider)
1896}
1897
1898func TestModelSelection_GetAgentModel(t *testing.T) {
1899 reset()
1900 testConfigDir = t.TempDir()
1901 cwdDir := t.TempDir()
1902
1903 // Set up a provider with known models
1904 globalConfig := Config{
1905 Providers: map[provider.InferenceProvider]ProviderConfig{
1906 provider.InferenceProviderOpenAI: {
1907 ID: provider.InferenceProviderOpenAI,
1908 APIKey: "test-key",
1909 ProviderType: provider.TypeOpenAI,
1910 DefaultLargeModel: "gpt-4",
1911 DefaultSmallModel: "gpt-3.5-turbo",
1912 Models: []Model{
1913 {
1914 ID: "gpt-4",
1915 Name: "GPT-4",
1916 ContextWindow: 8192,
1917 DefaultMaxTokens: 4096,
1918 CanReason: true,
1919 SupportsImages: true,
1920 },
1921 {
1922 ID: "gpt-3.5-turbo",
1923 Name: "GPT-3.5 Turbo",
1924 ContextWindow: 4096,
1925 DefaultMaxTokens: 2048,
1926 CanReason: false,
1927 SupportsImages: false,
1928 },
1929 },
1930 },
1931 },
1932 }
1933
1934 configPath := filepath.Join(testConfigDir, "crush.json")
1935 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1936 data, err := json.Marshal(globalConfig)
1937 require.NoError(t, err)
1938 require.NoError(t, os.WriteFile(configPath, data, 0o644))
1939
1940 _, err = Init(cwdDir, false)
1941
1942 require.NoError(t, err)
1943
1944 // Test GetAgentModel for default agents
1945 coderModel := GetAgentModel(AgentCoder)
1946 assert.Equal(t, "gpt-4", coderModel.ID) // Coder uses LargeModel
1947 assert.Equal(t, "GPT-4", coderModel.Name)
1948 assert.True(t, coderModel.CanReason)
1949 assert.True(t, coderModel.SupportsImages)
1950
1951 taskModel := GetAgentModel(AgentTask)
1952 assert.Equal(t, "gpt-4", taskModel.ID) // Task also uses LargeModel by default
1953 assert.Equal(t, "GPT-4", taskModel.Name)
1954}
1955
1956func TestModelSelection_GetAgentModelWithCustomModelType(t *testing.T) {
1957 reset()
1958 testConfigDir = t.TempDir()
1959 cwdDir := t.TempDir()
1960
1961 // Set up provider and custom agent with SmallModel
1962 globalConfig := Config{
1963 Providers: map[provider.InferenceProvider]ProviderConfig{
1964 provider.InferenceProviderOpenAI: {
1965 ID: provider.InferenceProviderOpenAI,
1966 APIKey: "test-key",
1967 ProviderType: provider.TypeOpenAI,
1968 DefaultLargeModel: "gpt-4",
1969 DefaultSmallModel: "gpt-3.5-turbo",
1970 Models: []Model{
1971 {
1972 ID: "gpt-4",
1973 Name: "GPT-4",
1974 ContextWindow: 8192,
1975 DefaultMaxTokens: 4096,
1976 },
1977 {
1978 ID: "gpt-3.5-turbo",
1979 Name: "GPT-3.5 Turbo",
1980 ContextWindow: 4096,
1981 DefaultMaxTokens: 2048,
1982 },
1983 },
1984 },
1985 },
1986 Agents: map[AgentID]Agent{
1987 AgentID("small-agent"): {
1988 ID: AgentID("small-agent"),
1989 Name: "Small Agent",
1990 Model: SmallModel,
1991 },
1992 },
1993 }
1994
1995 configPath := filepath.Join(testConfigDir, "crush.json")
1996 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
1997 data, err := json.Marshal(globalConfig)
1998 require.NoError(t, err)
1999 require.NoError(t, os.WriteFile(configPath, data, 0o644))
2000
2001 _, err = Init(cwdDir, false)
2002
2003 require.NoError(t, err)
2004
2005 // Test GetAgentModel for custom agent with SmallModel
2006 smallAgentModel := GetAgentModel(AgentID("small-agent"))
2007 assert.Equal(t, "gpt-3.5-turbo", smallAgentModel.ID)
2008 assert.Equal(t, "GPT-3.5 Turbo", smallAgentModel.Name)
2009}
2010
2011func TestModelSelection_GetAgentProvider(t *testing.T) {
2012 reset()
2013 testConfigDir = t.TempDir()
2014 cwdDir := t.TempDir()
2015
2016 // Set up multiple providers
2017 globalConfig := Config{
2018 Providers: map[provider.InferenceProvider]ProviderConfig{
2019 provider.InferenceProviderOpenAI: {
2020 ID: provider.InferenceProviderOpenAI,
2021 APIKey: "openai-key",
2022 ProviderType: provider.TypeOpenAI,
2023 DefaultLargeModel: "gpt-4",
2024 DefaultSmallModel: "gpt-3.5-turbo",
2025 },
2026 provider.InferenceProviderAnthropic: {
2027 ID: provider.InferenceProviderAnthropic,
2028 APIKey: "anthropic-key",
2029 ProviderType: provider.TypeAnthropic,
2030 DefaultLargeModel: "claude-3-opus",
2031 DefaultSmallModel: "claude-3-haiku",
2032 },
2033 },
2034 }
2035
2036 configPath := filepath.Join(testConfigDir, "crush.json")
2037 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
2038 data, err := json.Marshal(globalConfig)
2039 require.NoError(t, err)
2040 require.NoError(t, os.WriteFile(configPath, data, 0o644))
2041
2042 _, err = Init(cwdDir, false)
2043
2044 require.NoError(t, err)
2045
2046 // Test GetAgentProvider
2047 coderProvider := GetAgentProvider(AgentCoder)
2048 assert.NotEmpty(t, coderProvider.ID)
2049 assert.NotEmpty(t, coderProvider.APIKey)
2050 assert.NotEmpty(t, coderProvider.ProviderType)
2051}
2052
2053func TestModelSelection_GetProviderModel(t *testing.T) {
2054 reset()
2055 testConfigDir = t.TempDir()
2056 cwdDir := t.TempDir()
2057
2058 // Set up provider with specific models
2059 globalConfig := Config{
2060 Providers: map[provider.InferenceProvider]ProviderConfig{
2061 provider.InferenceProviderOpenAI: {
2062 ID: provider.InferenceProviderOpenAI,
2063 APIKey: "test-key",
2064 ProviderType: provider.TypeOpenAI,
2065 Models: []Model{
2066 {
2067 ID: "gpt-4",
2068 Name: "GPT-4",
2069 ContextWindow: 8192,
2070 DefaultMaxTokens: 4096,
2071 CostPer1MIn: 30.0,
2072 CostPer1MOut: 60.0,
2073 },
2074 {
2075 ID: "gpt-3.5-turbo",
2076 Name: "GPT-3.5 Turbo",
2077 ContextWindow: 4096,
2078 DefaultMaxTokens: 2048,
2079 CostPer1MIn: 1.5,
2080 CostPer1MOut: 2.0,
2081 },
2082 },
2083 },
2084 },
2085 }
2086
2087 configPath := filepath.Join(testConfigDir, "crush.json")
2088 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
2089 data, err := json.Marshal(globalConfig)
2090 require.NoError(t, err)
2091 require.NoError(t, os.WriteFile(configPath, data, 0o644))
2092
2093 _, err = Init(cwdDir, false)
2094
2095 require.NoError(t, err)
2096
2097 // Test GetProviderModel
2098 gpt4Model := GetProviderModel(provider.InferenceProviderOpenAI, "gpt-4")
2099 assert.Equal(t, "gpt-4", gpt4Model.ID)
2100 assert.Equal(t, "GPT-4", gpt4Model.Name)
2101 assert.Equal(t, int64(8192), gpt4Model.ContextWindow)
2102 assert.Equal(t, 30.0, gpt4Model.CostPer1MIn)
2103
2104 gpt35Model := GetProviderModel(provider.InferenceProviderOpenAI, "gpt-3.5-turbo")
2105 assert.Equal(t, "gpt-3.5-turbo", gpt35Model.ID)
2106 assert.Equal(t, "GPT-3.5 Turbo", gpt35Model.Name)
2107 assert.Equal(t, 1.5, gpt35Model.CostPer1MIn)
2108
2109 // Test non-existent model
2110 nonExistentModel := GetProviderModel(provider.InferenceProviderOpenAI, "non-existent")
2111 assert.Empty(t, nonExistentModel.ID)
2112}
2113
2114func TestModelSelection_GetModel(t *testing.T) {
2115 reset()
2116 testConfigDir = t.TempDir()
2117 cwdDir := t.TempDir()
2118
2119 // Set up provider with models
2120 globalConfig := Config{
2121 Providers: map[provider.InferenceProvider]ProviderConfig{
2122 provider.InferenceProviderOpenAI: {
2123 ID: provider.InferenceProviderOpenAI,
2124 APIKey: "test-key",
2125 ProviderType: provider.TypeOpenAI,
2126 DefaultLargeModel: "gpt-4",
2127 DefaultSmallModel: "gpt-3.5-turbo",
2128 Models: []Model{
2129 {
2130 ID: "gpt-4",
2131 Name: "GPT-4",
2132 ContextWindow: 8192,
2133 DefaultMaxTokens: 4096,
2134 },
2135 {
2136 ID: "gpt-3.5-turbo",
2137 Name: "GPT-3.5 Turbo",
2138 ContextWindow: 4096,
2139 DefaultMaxTokens: 2048,
2140 },
2141 },
2142 },
2143 },
2144 }
2145
2146 configPath := filepath.Join(testConfigDir, "crush.json")
2147 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
2148 data, err := json.Marshal(globalConfig)
2149 require.NoError(t, err)
2150 require.NoError(t, os.WriteFile(configPath, data, 0o644))
2151
2152 _, err = Init(cwdDir, false)
2153
2154 require.NoError(t, err)
2155
2156 // Test GetModel
2157 largeModel := GetModel(LargeModel)
2158 assert.Equal(t, "gpt-4", largeModel.ID)
2159 assert.Equal(t, "GPT-4", largeModel.Name)
2160
2161 smallModel := GetModel(SmallModel)
2162 assert.Equal(t, "gpt-3.5-turbo", smallModel.ID)
2163 assert.Equal(t, "GPT-3.5 Turbo", smallModel.Name)
2164}
2165
2166func TestModelSelection_UpdatePreferredModel(t *testing.T) {
2167 reset()
2168 testConfigDir = t.TempDir()
2169 cwdDir := t.TempDir()
2170
2171 // Set up multiple providers with OpenAI first to ensure it's selected initially
2172 globalConfig := Config{
2173 Providers: map[provider.InferenceProvider]ProviderConfig{
2174 provider.InferenceProviderOpenAI: {
2175 ID: provider.InferenceProviderOpenAI,
2176 APIKey: "openai-key",
2177 ProviderType: provider.TypeOpenAI,
2178 DefaultLargeModel: "gpt-4",
2179 DefaultSmallModel: "gpt-3.5-turbo",
2180 Models: []Model{
2181 {ID: "gpt-4", Name: "GPT-4"},
2182 {ID: "gpt-3.5-turbo", Name: "GPT-3.5 Turbo"},
2183 },
2184 },
2185 provider.InferenceProviderAnthropic: {
2186 ID: provider.InferenceProviderAnthropic,
2187 APIKey: "anthropic-key",
2188 ProviderType: provider.TypeAnthropic,
2189 DefaultLargeModel: "claude-3-opus",
2190 DefaultSmallModel: "claude-3-haiku",
2191 Models: []Model{
2192 {ID: "claude-3-opus", Name: "Claude 3 Opus"},
2193 {ID: "claude-3-haiku", Name: "Claude 3 Haiku"},
2194 },
2195 },
2196 },
2197 }
2198
2199 configPath := filepath.Join(testConfigDir, "crush.json")
2200 require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
2201 data, err := json.Marshal(globalConfig)
2202 require.NoError(t, err)
2203 require.NoError(t, os.WriteFile(configPath, data, 0o644))
2204
2205 _, err = Init(cwdDir, false)
2206
2207 require.NoError(t, err)
2208
2209 // Get initial preferred models (should be OpenAI since it's listed first)
2210 initialLargeModel := GetModel(LargeModel)
2211 initialSmallModel := GetModel(SmallModel)
2212
2213 // Verify initial models are OpenAI models
2214 assert.Equal(t, "claude-3-opus", initialLargeModel.ID)
2215 assert.Equal(t, "claude-3-haiku", initialSmallModel.ID)
2216
2217 // Update preferred models to Anthropic
2218 newLargeModel := PreferredModel{
2219 ModelID: "gpt-4",
2220 Provider: provider.InferenceProviderOpenAI,
2221 }
2222 newSmallModel := PreferredModel{
2223 ModelID: "gpt-3.5-turbo",
2224 Provider: provider.InferenceProviderOpenAI,
2225 }
2226
2227 err = UpdatePreferredModel(LargeModel, newLargeModel)
2228 require.NoError(t, err)
2229
2230 err = UpdatePreferredModel(SmallModel, newSmallModel)
2231 require.NoError(t, err)
2232
2233 // Verify models were updated
2234 updatedLargeModel := GetModel(LargeModel)
2235 assert.Equal(t, "gpt-4", updatedLargeModel.ID)
2236 assert.NotEqual(t, initialLargeModel.ID, updatedLargeModel.ID)
2237
2238 updatedSmallModel := GetModel(SmallModel)
2239 assert.Equal(t, "gpt-3.5-turbo", updatedSmallModel.ID)
2240 assert.NotEqual(t, initialSmallModel.ID, updatedSmallModel.ID)
2241}
2242
2243func TestModelSelection_InvalidModelType(t *testing.T) {
2244 reset()
2245 testConfigDir = t.TempDir()
2246 cwdDir := t.TempDir()
2247
2248 // Set up a provider
2249 os.Setenv("ANTHROPIC_API_KEY", "test-key")
2250
2251 _, err := Init(cwdDir, false)
2252 require.NoError(t, err)
2253
2254 // Test UpdatePreferredModel with invalid model type
2255 invalidModel := PreferredModel{
2256 ModelID: "some-model",
2257 Provider: provider.InferenceProviderAnthropic,
2258 }
2259
2260 err = UpdatePreferredModel(ModelType("invalid"), invalidModel)
2261 assert.Error(t, err)
2262 assert.Contains(t, err.Error(), "unknown model type")
2263}
2264
2265func TestModelSelection_NonExistentAgent(t *testing.T) {
2266 reset()
2267 testConfigDir = t.TempDir()
2268 cwdDir := t.TempDir()
2269
2270 // Set up a provider
2271 os.Setenv("ANTHROPIC_API_KEY", "test-key")
2272
2273 _, err := Init(cwdDir, false)
2274 require.NoError(t, err)
2275
2276 // Test GetAgentModel with non-existent agent
2277 nonExistentModel := GetAgentModel(AgentID("non-existent"))
2278 assert.Empty(t, nonExistentModel.ID)
2279
2280 // Test GetAgentProvider with non-existent agent
2281 nonExistentProvider := GetAgentProvider(AgentID("non-existent"))
2282 assert.Empty(t, nonExistentProvider.ID)
2283}
2284
2285func TestModelSelection_NonExistentProvider(t *testing.T) {
2286 reset()
2287 testConfigDir = t.TempDir()
2288 cwdDir := t.TempDir()
2289
2290 // Set up a provider
2291 os.Setenv("ANTHROPIC_API_KEY", "test-key")
2292
2293 _, err := Init(cwdDir, false)
2294 require.NoError(t, err)
2295
2296 // Test GetProviderModel with non-existent provider
2297 nonExistentModel := GetProviderModel(provider.InferenceProvider("non-existent"), "some-model")
2298 assert.Empty(t, nonExistentModel.ID)
2299}