1package config
2
3import (
4 "io"
5 "log/slog"
6 "os"
7 "path/filepath"
8 "strings"
9 "testing"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/csync"
13 "github.com/charmbracelet/crush/internal/env"
14 "github.com/stretchr/testify/assert"
15 "github.com/stretchr/testify/require"
16)
17
// TestMain installs a default slog logger that discards everything, so log
// output produced by the code under test does not clutter the test output,
// and then exits with the status reported by m.Run.
func TestMain(m *testing.M) {
	discard := slog.NewTextHandler(io.Discard, nil)
	slog.SetDefault(slog.New(discard))

	os.Exit(m.Run())
}
24
25func TestConfig_LoadFromReaders(t *testing.T) {
26 data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
27 data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
28 data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
29
30 loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
31
32 require.NoError(t, err)
33 require.NotNil(t, loadedConfig)
34 require.Equal(t, 1, loadedConfig.Providers.Len())
35 pc, _ := loadedConfig.Providers.Get("openai")
36 require.Equal(t, "key2", pc.APIKey)
37 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
38}
39
40func TestConfig_setDefaults(t *testing.T) {
41 cfg := &Config{}
42
43 cfg.setDefaults("/tmp", "")
44
45 require.NotNil(t, cfg.Options)
46 require.NotNil(t, cfg.Options.TUI)
47 require.NotNil(t, cfg.Options.ContextPaths)
48 require.NotNil(t, cfg.Providers)
49 require.NotNil(t, cfg.Models)
50 require.NotNil(t, cfg.LSP)
51 require.NotNil(t, cfg.MCP)
52 require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
53 for _, path := range defaultContextPaths {
54 require.Contains(t, cfg.Options.ContextPaths, path)
55 }
56 require.Equal(t, "/tmp", cfg.workingDir)
57}
58
59func TestConfig_configureProviders(t *testing.T) {
60 knownProviders := []catwalk.Provider{
61 {
62 ID: "openai",
63 APIKey: "$OPENAI_API_KEY",
64 APIEndpoint: "https://api.openai.com/v1",
65 Models: []catwalk.Model{{
66 ID: "test-model",
67 }},
68 },
69 }
70
71 cfg := &Config{}
72 cfg.setDefaults("/tmp", "")
73 env := env.NewFromMap(map[string]string{
74 "OPENAI_API_KEY": "test-key",
75 })
76 resolver := NewEnvironmentVariableResolver(env)
77 err := cfg.configureProviders(env, resolver, knownProviders)
78 require.NoError(t, err)
79 require.Equal(t, 1, cfg.Providers.Len())
80
81 // We want to make sure that we keep the configured API key as a placeholder
82 pc, _ := cfg.Providers.Get("openai")
83 require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
84}
85
86func TestConfig_configureProvidersWithOverride(t *testing.T) {
87 knownProviders := []catwalk.Provider{
88 {
89 ID: "openai",
90 APIKey: "$OPENAI_API_KEY",
91 APIEndpoint: "https://api.openai.com/v1",
92 Models: []catwalk.Model{{
93 ID: "test-model",
94 }},
95 },
96 }
97
98 cfg := &Config{
99 Providers: csync.NewMap[string, ProviderConfig](),
100 }
101 cfg.Providers.Set("openai", ProviderConfig{
102 APIKey: "xyz",
103 BaseURL: "https://api.openai.com/v2",
104 Models: []catwalk.Model{
105 {
106 ID: "test-model",
107 Name: "Updated",
108 },
109 {
110 ID: "another-model",
111 },
112 },
113 })
114 cfg.setDefaults("/tmp", "")
115
116 env := env.NewFromMap(map[string]string{
117 "OPENAI_API_KEY": "test-key",
118 })
119 resolver := NewEnvironmentVariableResolver(env)
120 err := cfg.configureProviders(env, resolver, knownProviders)
121 require.NoError(t, err)
122 require.Equal(t, 1, cfg.Providers.Len())
123
124 // We want to make sure that we keep the configured API key as a placeholder
125 pc, _ := cfg.Providers.Get("openai")
126 require.Equal(t, "xyz", pc.APIKey)
127 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
128 require.Len(t, pc.Models, 2)
129 require.Equal(t, "Updated", pc.Models[0].Name)
130}
131
132func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
133 knownProviders := []catwalk.Provider{
134 {
135 ID: "openai",
136 APIKey: "$OPENAI_API_KEY",
137 APIEndpoint: "https://api.openai.com/v1",
138 Models: []catwalk.Model{{
139 ID: "test-model",
140 }},
141 },
142 }
143
144 cfg := &Config{
145 Providers: csync.NewMapFrom(map[string]ProviderConfig{
146 "custom": {
147 APIKey: "xyz",
148 BaseURL: "https://api.someendpoint.com/v2",
149 Models: []catwalk.Model{
150 {
151 ID: "test-model",
152 },
153 },
154 },
155 }),
156 }
157 cfg.setDefaults("/tmp", "")
158 env := env.NewFromMap(map[string]string{
159 "OPENAI_API_KEY": "test-key",
160 })
161 resolver := NewEnvironmentVariableResolver(env)
162 err := cfg.configureProviders(env, resolver, knownProviders)
163 require.NoError(t, err)
164 // Should be to because of the env variable
165 require.Equal(t, cfg.Providers.Len(), 2)
166
167 // We want to make sure that we keep the configured API key as a placeholder
168 pc, _ := cfg.Providers.Get("custom")
169 require.Equal(t, "xyz", pc.APIKey)
170 // Make sure we set the ID correctly
171 require.Equal(t, "custom", pc.ID)
172 require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
173 require.Len(t, pc.Models, 1)
174
175 _, ok := cfg.Providers.Get("openai")
176 require.True(t, ok, "OpenAI provider should still be present")
177}
178
179func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
180 knownProviders := []catwalk.Provider{
181 {
182 ID: catwalk.InferenceProviderBedrock,
183 APIKey: "",
184 APIEndpoint: "",
185 Models: []catwalk.Model{{
186 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
187 }},
188 },
189 }
190
191 cfg := &Config{}
192 cfg.setDefaults("/tmp", "")
193 env := env.NewFromMap(map[string]string{
194 "AWS_ACCESS_KEY_ID": "test-key-id",
195 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
196 })
197 resolver := NewEnvironmentVariableResolver(env)
198 err := cfg.configureProviders(env, resolver, knownProviders)
199 require.NoError(t, err)
200 require.Equal(t, cfg.Providers.Len(), 1)
201
202 bedrockProvider, ok := cfg.Providers.Get("bedrock")
203 require.True(t, ok, "Bedrock provider should be present")
204 require.Len(t, bedrockProvider.Models, 1)
205 require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
206}
207
208func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
209 knownProviders := []catwalk.Provider{
210 {
211 ID: catwalk.InferenceProviderBedrock,
212 APIKey: "",
213 APIEndpoint: "",
214 Models: []catwalk.Model{{
215 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
216 }},
217 },
218 }
219
220 cfg := &Config{}
221 cfg.setDefaults("/tmp", "")
222 env := env.NewFromMap(map[string]string{})
223 resolver := NewEnvironmentVariableResolver(env)
224 err := cfg.configureProviders(env, resolver, knownProviders)
225 require.NoError(t, err)
226 // Provider should not be configured without credentials
227 require.Equal(t, cfg.Providers.Len(), 0)
228}
229
230func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
231 knownProviders := []catwalk.Provider{
232 {
233 ID: catwalk.InferenceProviderBedrock,
234 APIKey: "",
235 APIEndpoint: "",
236 Models: []catwalk.Model{{
237 ID: "some-random-model",
238 }},
239 },
240 }
241
242 cfg := &Config{}
243 cfg.setDefaults("/tmp", "")
244 env := env.NewFromMap(map[string]string{
245 "AWS_ACCESS_KEY_ID": "test-key-id",
246 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
247 })
248 resolver := NewEnvironmentVariableResolver(env)
249 err := cfg.configureProviders(env, resolver, knownProviders)
250 require.Error(t, err)
251}
252
253func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
254 knownProviders := []catwalk.Provider{
255 {
256 ID: catwalk.InferenceProviderVertexAI,
257 APIKey: "",
258 APIEndpoint: "",
259 Models: []catwalk.Model{{
260 ID: "gemini-pro",
261 }},
262 },
263 }
264
265 cfg := &Config{}
266 cfg.setDefaults("/tmp", "")
267 env := env.NewFromMap(map[string]string{
268 "VERTEXAI_PROJECT": "test-project",
269 "VERTEXAI_LOCATION": "us-central1",
270 })
271 resolver := NewEnvironmentVariableResolver(env)
272 err := cfg.configureProviders(env, resolver, knownProviders)
273 require.NoError(t, err)
274 require.Equal(t, cfg.Providers.Len(), 1)
275
276 vertexProvider, ok := cfg.Providers.Get("vertexai")
277 require.True(t, ok, "VertexAI provider should be present")
278 require.Len(t, vertexProvider.Models, 1)
279 require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
280 require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
281 require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
282}
283
284func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
285 knownProviders := []catwalk.Provider{
286 {
287 ID: catwalk.InferenceProviderVertexAI,
288 APIKey: "",
289 APIEndpoint: "",
290 Models: []catwalk.Model{{
291 ID: "gemini-pro",
292 }},
293 },
294 }
295
296 cfg := &Config{}
297 cfg.setDefaults("/tmp", "")
298 env := env.NewFromMap(map[string]string{
299 "GOOGLE_GENAI_USE_VERTEXAI": "false",
300 "GOOGLE_CLOUD_PROJECT": "test-project",
301 "GOOGLE_CLOUD_LOCATION": "us-central1",
302 })
303 resolver := NewEnvironmentVariableResolver(env)
304 err := cfg.configureProviders(env, resolver, knownProviders)
305 require.NoError(t, err)
306 // Provider should not be configured without proper credentials
307 require.Equal(t, cfg.Providers.Len(), 0)
308}
309
310func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
311 knownProviders := []catwalk.Provider{
312 {
313 ID: catwalk.InferenceProviderVertexAI,
314 APIKey: "",
315 APIEndpoint: "",
316 Models: []catwalk.Model{{
317 ID: "gemini-pro",
318 }},
319 },
320 }
321
322 cfg := &Config{}
323 cfg.setDefaults("/tmp", "")
324 env := env.NewFromMap(map[string]string{
325 "GOOGLE_GENAI_USE_VERTEXAI": "true",
326 "GOOGLE_CLOUD_LOCATION": "us-central1",
327 })
328 resolver := NewEnvironmentVariableResolver(env)
329 err := cfg.configureProviders(env, resolver, knownProviders)
330 require.NoError(t, err)
331 // Provider should not be configured without project
332 require.Equal(t, cfg.Providers.Len(), 0)
333}
334
335func TestConfig_configureProvidersSetProviderID(t *testing.T) {
336 knownProviders := []catwalk.Provider{
337 {
338 ID: "openai",
339 APIKey: "$OPENAI_API_KEY",
340 APIEndpoint: "https://api.openai.com/v1",
341 Models: []catwalk.Model{{
342 ID: "test-model",
343 }},
344 },
345 }
346
347 cfg := &Config{}
348 cfg.setDefaults("/tmp", "")
349 env := env.NewFromMap(map[string]string{
350 "OPENAI_API_KEY": "test-key",
351 })
352 resolver := NewEnvironmentVariableResolver(env)
353 err := cfg.configureProviders(env, resolver, knownProviders)
354 require.NoError(t, err)
355 require.Equal(t, cfg.Providers.Len(), 1)
356
357 // Provider ID should be set
358 pc, _ := cfg.Providers.Get("openai")
359 require.Equal(t, "openai", pc.ID)
360}
361
362func TestConfig_EnabledProviders(t *testing.T) {
363 t.Run("all providers enabled", func(t *testing.T) {
364 cfg := &Config{
365 Providers: csync.NewMapFrom(map[string]ProviderConfig{
366 "openai": {
367 ID: "openai",
368 APIKey: "key1",
369 Disable: false,
370 },
371 "anthropic": {
372 ID: "anthropic",
373 APIKey: "key2",
374 Disable: false,
375 },
376 }),
377 }
378
379 enabled := cfg.EnabledProviders()
380 require.Len(t, enabled, 2)
381 })
382
383 t.Run("some providers disabled", func(t *testing.T) {
384 cfg := &Config{
385 Providers: csync.NewMapFrom(map[string]ProviderConfig{
386 "openai": {
387 ID: "openai",
388 APIKey: "key1",
389 Disable: false,
390 },
391 "anthropic": {
392 ID: "anthropic",
393 APIKey: "key2",
394 Disable: true,
395 },
396 }),
397 }
398
399 enabled := cfg.EnabledProviders()
400 require.Len(t, enabled, 1)
401 require.Equal(t, "openai", enabled[0].ID)
402 })
403
404 t.Run("empty providers map", func(t *testing.T) {
405 cfg := &Config{
406 Providers: csync.NewMap[string, ProviderConfig](),
407 }
408
409 enabled := cfg.EnabledProviders()
410 require.Len(t, enabled, 0)
411 })
412}
413
414func TestConfig_IsConfigured(t *testing.T) {
415 t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
416 cfg := &Config{
417 Providers: csync.NewMapFrom(map[string]ProviderConfig{
418 "openai": {
419 ID: "openai",
420 APIKey: "key1",
421 Disable: false,
422 },
423 }),
424 }
425
426 require.True(t, cfg.IsConfigured())
427 })
428
429 t.Run("returns false when no providers are configured", func(t *testing.T) {
430 cfg := &Config{
431 Providers: csync.NewMap[string, ProviderConfig](),
432 }
433
434 require.False(t, cfg.IsConfigured())
435 })
436
437 t.Run("returns false when all providers are disabled", func(t *testing.T) {
438 cfg := &Config{
439 Providers: csync.NewMapFrom(map[string]ProviderConfig{
440 "openai": {
441 ID: "openai",
442 APIKey: "key1",
443 Disable: true,
444 },
445 "anthropic": {
446 ID: "anthropic",
447 APIKey: "key2",
448 Disable: true,
449 },
450 }),
451 }
452
453 require.False(t, cfg.IsConfigured())
454 })
455}
456
457func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
458 cfg := &Config{
459 Options: &Options{
460 DisabledTools: []string{},
461 },
462 }
463
464 cfg.SetupAgents()
465 coderAgent, ok := cfg.Agents["coder"]
466 require.True(t, ok)
467 assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
468
469 taskAgent, ok := cfg.Agents["task"]
470 require.True(t, ok)
471 assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
472}
473
474func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
475 cfg := &Config{
476 Options: &Options{
477 DisabledTools: []string{
478 "edit",
479 "download",
480 "grep",
481 },
482 },
483 }
484
485 cfg.SetupAgents()
486 coderAgent, ok := cfg.Agents["coder"]
487 require.True(t, ok)
488 assert.Equal(t, []string{"bash", "multiedit", "fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)
489
490 taskAgent, ok := cfg.Agents["task"]
491 require.True(t, ok)
492 assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
493}
494
495func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
496 knownProviders := []catwalk.Provider{
497 {
498 ID: "openai",
499 APIKey: "$OPENAI_API_KEY",
500 APIEndpoint: "https://api.openai.com/v1",
501 Models: []catwalk.Model{{
502 ID: "test-model",
503 }},
504 },
505 }
506
507 cfg := &Config{
508 Providers: csync.NewMapFrom(map[string]ProviderConfig{
509 "openai": {
510 Disable: true,
511 },
512 }),
513 }
514 cfg.setDefaults("/tmp", "")
515
516 env := env.NewFromMap(map[string]string{
517 "OPENAI_API_KEY": "test-key",
518 })
519 resolver := NewEnvironmentVariableResolver(env)
520 err := cfg.configureProviders(env, resolver, knownProviders)
521 require.NoError(t, err)
522
523 // Provider should be removed from config when disabled
524 require.Equal(t, cfg.Providers.Len(), 0)
525 _, exists := cfg.Providers.Get("openai")
526 require.False(t, exists)
527}
528
529func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
530 t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
531 cfg := &Config{
532 Providers: csync.NewMapFrom(map[string]ProviderConfig{
533 "custom": {
534 BaseURL: "https://api.custom.com/v1",
535 Models: []catwalk.Model{{
536 ID: "test-model",
537 }},
538 },
539 "openai": {
540 APIKey: "$MISSING",
541 },
542 }),
543 }
544 cfg.setDefaults("/tmp", "")
545
546 env := env.NewFromMap(map[string]string{})
547 resolver := NewEnvironmentVariableResolver(env)
548 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
549 require.NoError(t, err)
550
551 require.Equal(t, cfg.Providers.Len(), 1)
552 _, exists := cfg.Providers.Get("custom")
553 require.True(t, exists)
554 })
555
556 t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
557 cfg := &Config{
558 Providers: csync.NewMapFrom(map[string]ProviderConfig{
559 "custom": {
560 APIKey: "test-key",
561 Models: []catwalk.Model{{
562 ID: "test-model",
563 }},
564 },
565 }),
566 }
567 cfg.setDefaults("/tmp", "")
568
569 env := env.NewFromMap(map[string]string{})
570 resolver := NewEnvironmentVariableResolver(env)
571 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
572 require.NoError(t, err)
573
574 require.Equal(t, cfg.Providers.Len(), 0)
575 _, exists := cfg.Providers.Get("custom")
576 require.False(t, exists)
577 })
578
579 t.Run("custom provider with no models is removed", func(t *testing.T) {
580 cfg := &Config{
581 Providers: csync.NewMapFrom(map[string]ProviderConfig{
582 "custom": {
583 APIKey: "test-key",
584 BaseURL: "https://api.custom.com/v1",
585 Models: []catwalk.Model{},
586 },
587 }),
588 }
589 cfg.setDefaults("/tmp", "")
590
591 env := env.NewFromMap(map[string]string{})
592 resolver := NewEnvironmentVariableResolver(env)
593 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
594 require.NoError(t, err)
595
596 require.Equal(t, cfg.Providers.Len(), 0)
597 _, exists := cfg.Providers.Get("custom")
598 require.False(t, exists)
599 })
600
601 t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
602 cfg := &Config{
603 Providers: csync.NewMapFrom(map[string]ProviderConfig{
604 "custom": {
605 APIKey: "test-key",
606 BaseURL: "https://api.custom.com/v1",
607 Type: "unsupported",
608 Models: []catwalk.Model{{
609 ID: "test-model",
610 }},
611 },
612 }),
613 }
614 cfg.setDefaults("/tmp", "")
615
616 env := env.NewFromMap(map[string]string{})
617 resolver := NewEnvironmentVariableResolver(env)
618 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
619 require.NoError(t, err)
620
621 require.Equal(t, cfg.Providers.Len(), 0)
622 _, exists := cfg.Providers.Get("custom")
623 require.False(t, exists)
624 })
625
626 t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
627 cfg := &Config{
628 Providers: csync.NewMapFrom(map[string]ProviderConfig{
629 "custom": {
630 APIKey: "test-key",
631 BaseURL: "https://api.custom.com/v1",
632 Type: catwalk.TypeOpenAI,
633 Models: []catwalk.Model{{
634 ID: "test-model",
635 }},
636 },
637 }),
638 }
639 cfg.setDefaults("/tmp", "")
640
641 env := env.NewFromMap(map[string]string{})
642 resolver := NewEnvironmentVariableResolver(env)
643 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
644 require.NoError(t, err)
645
646 require.Equal(t, cfg.Providers.Len(), 1)
647 customProvider, exists := cfg.Providers.Get("custom")
648 require.True(t, exists)
649 require.Equal(t, "custom", customProvider.ID)
650 require.Equal(t, "test-key", customProvider.APIKey)
651 require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
652 })
653
654 t.Run("custom anthropic provider is supported", func(t *testing.T) {
655 cfg := &Config{
656 Providers: csync.NewMapFrom(map[string]ProviderConfig{
657 "custom-anthropic": {
658 APIKey: "test-key",
659 BaseURL: "https://api.anthropic.com/v1",
660 Type: catwalk.TypeAnthropic,
661 Models: []catwalk.Model{{
662 ID: "claude-3-sonnet",
663 }},
664 },
665 }),
666 }
667 cfg.setDefaults("/tmp", "")
668
669 env := env.NewFromMap(map[string]string{})
670 resolver := NewEnvironmentVariableResolver(env)
671 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
672 require.NoError(t, err)
673
674 require.Equal(t, cfg.Providers.Len(), 1)
675 customProvider, exists := cfg.Providers.Get("custom-anthropic")
676 require.True(t, exists)
677 require.Equal(t, "custom-anthropic", customProvider.ID)
678 require.Equal(t, "test-key", customProvider.APIKey)
679 require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
680 require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
681 })
682
683 t.Run("disabled custom provider is removed", func(t *testing.T) {
684 cfg := &Config{
685 Providers: csync.NewMapFrom(map[string]ProviderConfig{
686 "custom": {
687 APIKey: "test-key",
688 BaseURL: "https://api.custom.com/v1",
689 Type: catwalk.TypeOpenAI,
690 Disable: true,
691 Models: []catwalk.Model{{
692 ID: "test-model",
693 }},
694 },
695 }),
696 }
697 cfg.setDefaults("/tmp", "")
698
699 env := env.NewFromMap(map[string]string{})
700 resolver := NewEnvironmentVariableResolver(env)
701 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
702 require.NoError(t, err)
703
704 require.Equal(t, cfg.Providers.Len(), 0)
705 _, exists := cfg.Providers.Get("custom")
706 require.False(t, exists)
707 })
708}
709
710func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
711 t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
712 knownProviders := []catwalk.Provider{
713 {
714 ID: catwalk.InferenceProviderVertexAI,
715 APIKey: "",
716 APIEndpoint: "",
717 Models: []catwalk.Model{{
718 ID: "gemini-pro",
719 }},
720 },
721 }
722
723 cfg := &Config{
724 Providers: csync.NewMapFrom(map[string]ProviderConfig{
725 "vertexai": {
726 BaseURL: "custom-url",
727 },
728 }),
729 }
730 cfg.setDefaults("/tmp", "")
731
732 env := env.NewFromMap(map[string]string{
733 "GOOGLE_GENAI_USE_VERTEXAI": "false",
734 })
735 resolver := NewEnvironmentVariableResolver(env)
736 err := cfg.configureProviders(env, resolver, knownProviders)
737 require.NoError(t, err)
738
739 require.Equal(t, cfg.Providers.Len(), 0)
740 _, exists := cfg.Providers.Get("vertexai")
741 require.False(t, exists)
742 })
743
744 t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
745 knownProviders := []catwalk.Provider{
746 {
747 ID: catwalk.InferenceProviderBedrock,
748 APIKey: "",
749 APIEndpoint: "",
750 Models: []catwalk.Model{{
751 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
752 }},
753 },
754 }
755
756 cfg := &Config{
757 Providers: csync.NewMapFrom(map[string]ProviderConfig{
758 "bedrock": {
759 BaseURL: "custom-url",
760 },
761 }),
762 }
763 cfg.setDefaults("/tmp", "")
764
765 env := env.NewFromMap(map[string]string{})
766 resolver := NewEnvironmentVariableResolver(env)
767 err := cfg.configureProviders(env, resolver, knownProviders)
768 require.NoError(t, err)
769
770 require.Equal(t, cfg.Providers.Len(), 0)
771 _, exists := cfg.Providers.Get("bedrock")
772 require.False(t, exists)
773 })
774
775 t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
776 knownProviders := []catwalk.Provider{
777 {
778 ID: "openai",
779 APIKey: "$MISSING_API_KEY",
780 APIEndpoint: "https://api.openai.com/v1",
781 Models: []catwalk.Model{{
782 ID: "test-model",
783 }},
784 },
785 }
786
787 cfg := &Config{
788 Providers: csync.NewMapFrom(map[string]ProviderConfig{
789 "openai": {
790 BaseURL: "custom-url",
791 },
792 }),
793 }
794 cfg.setDefaults("/tmp", "")
795
796 env := env.NewFromMap(map[string]string{})
797 resolver := NewEnvironmentVariableResolver(env)
798 err := cfg.configureProviders(env, resolver, knownProviders)
799 require.NoError(t, err)
800
801 require.Equal(t, cfg.Providers.Len(), 0)
802 _, exists := cfg.Providers.Get("openai")
803 require.False(t, exists)
804 })
805
806 t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
807 knownProviders := []catwalk.Provider{
808 {
809 ID: "openai",
810 APIKey: "$OPENAI_API_KEY",
811 APIEndpoint: "$MISSING_ENDPOINT",
812 Models: []catwalk.Model{{
813 ID: "test-model",
814 }},
815 },
816 }
817
818 cfg := &Config{
819 Providers: csync.NewMapFrom(map[string]ProviderConfig{
820 "openai": {
821 APIKey: "test-key",
822 },
823 }),
824 }
825 cfg.setDefaults("/tmp", "")
826
827 env := env.NewFromMap(map[string]string{
828 "OPENAI_API_KEY": "test-key",
829 })
830 resolver := NewEnvironmentVariableResolver(env)
831 err := cfg.configureProviders(env, resolver, knownProviders)
832 require.NoError(t, err)
833
834 require.Equal(t, cfg.Providers.Len(), 1)
835 _, exists := cfg.Providers.Get("openai")
836 require.True(t, exists)
837 })
838}
839
840func TestConfig_defaultModelSelection(t *testing.T) {
841 t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
842 knownProviders := []catwalk.Provider{
843 {
844 ID: "openai",
845 APIKey: "abc",
846 DefaultLargeModelID: "large-model",
847 DefaultSmallModelID: "small-model",
848 Models: []catwalk.Model{
849 {
850 ID: "large-model",
851 DefaultMaxTokens: 1000,
852 },
853 {
854 ID: "small-model",
855 DefaultMaxTokens: 500,
856 },
857 },
858 },
859 }
860
861 cfg := &Config{}
862 cfg.setDefaults("/tmp", "")
863 env := env.NewFromMap(map[string]string{})
864 resolver := NewEnvironmentVariableResolver(env)
865 err := cfg.configureProviders(env, resolver, knownProviders)
866 require.NoError(t, err)
867
868 large, small, err := cfg.defaultModelSelection(knownProviders)
869 require.NoError(t, err)
870 require.Equal(t, "large-model", large.Model)
871 require.Equal(t, "openai", large.Provider)
872 require.Equal(t, int64(1000), large.MaxTokens)
873 require.Equal(t, "small-model", small.Model)
874 require.Equal(t, "openai", small.Provider)
875 require.Equal(t, int64(500), small.MaxTokens)
876 })
877 t.Run("should error if no providers configured", func(t *testing.T) {
878 knownProviders := []catwalk.Provider{
879 {
880 ID: "openai",
881 APIKey: "$MISSING_KEY",
882 DefaultLargeModelID: "large-model",
883 DefaultSmallModelID: "small-model",
884 Models: []catwalk.Model{
885 {
886 ID: "large-model",
887 DefaultMaxTokens: 1000,
888 },
889 {
890 ID: "small-model",
891 DefaultMaxTokens: 500,
892 },
893 },
894 },
895 }
896
897 cfg := &Config{}
898 cfg.setDefaults("/tmp", "")
899 env := env.NewFromMap(map[string]string{})
900 resolver := NewEnvironmentVariableResolver(env)
901 err := cfg.configureProviders(env, resolver, knownProviders)
902 require.NoError(t, err)
903
904 _, _, err = cfg.defaultModelSelection(knownProviders)
905 require.Error(t, err)
906 })
907 t.Run("should error if model is missing", func(t *testing.T) {
908 knownProviders := []catwalk.Provider{
909 {
910 ID: "openai",
911 APIKey: "abc",
912 DefaultLargeModelID: "large-model",
913 DefaultSmallModelID: "small-model",
914 Models: []catwalk.Model{
915 {
916 ID: "not-large-model",
917 DefaultMaxTokens: 1000,
918 },
919 {
920 ID: "small-model",
921 DefaultMaxTokens: 500,
922 },
923 },
924 },
925 }
926
927 cfg := &Config{}
928 cfg.setDefaults("/tmp", "")
929 env := env.NewFromMap(map[string]string{})
930 resolver := NewEnvironmentVariableResolver(env)
931 err := cfg.configureProviders(env, resolver, knownProviders)
932 require.NoError(t, err)
933 _, _, err = cfg.defaultModelSelection(knownProviders)
934 require.Error(t, err)
935 })
936
937 t.Run("should configure the default models with a custom provider", func(t *testing.T) {
938 knownProviders := []catwalk.Provider{
939 {
940 ID: "openai",
941 APIKey: "$MISSING", // will not be included in the config
942 DefaultLargeModelID: "large-model",
943 DefaultSmallModelID: "small-model",
944 Models: []catwalk.Model{
945 {
946 ID: "not-large-model",
947 DefaultMaxTokens: 1000,
948 },
949 {
950 ID: "small-model",
951 DefaultMaxTokens: 500,
952 },
953 },
954 },
955 }
956
957 cfg := &Config{
958 Providers: csync.NewMapFrom(map[string]ProviderConfig{
959 "custom": {
960 APIKey: "test-key",
961 BaseURL: "https://api.custom.com/v1",
962 Models: []catwalk.Model{
963 {
964 ID: "model",
965 DefaultMaxTokens: 600,
966 },
967 },
968 },
969 }),
970 }
971 cfg.setDefaults("/tmp", "")
972 env := env.NewFromMap(map[string]string{})
973 resolver := NewEnvironmentVariableResolver(env)
974 err := cfg.configureProviders(env, resolver, knownProviders)
975 require.NoError(t, err)
976 large, small, err := cfg.defaultModelSelection(knownProviders)
977 require.NoError(t, err)
978 require.Equal(t, "model", large.Model)
979 require.Equal(t, "custom", large.Provider)
980 require.Equal(t, int64(600), large.MaxTokens)
981 require.Equal(t, "model", small.Model)
982 require.Equal(t, "custom", small.Provider)
983 require.Equal(t, int64(600), small.MaxTokens)
984 })
985
986 t.Run("should fail if no model configured", func(t *testing.T) {
987 knownProviders := []catwalk.Provider{
988 {
989 ID: "openai",
990 APIKey: "$MISSING", // will not be included in the config
991 DefaultLargeModelID: "large-model",
992 DefaultSmallModelID: "small-model",
993 Models: []catwalk.Model{
994 {
995 ID: "not-large-model",
996 DefaultMaxTokens: 1000,
997 },
998 {
999 ID: "small-model",
1000 DefaultMaxTokens: 500,
1001 },
1002 },
1003 },
1004 }
1005
1006 cfg := &Config{
1007 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1008 "custom": {
1009 APIKey: "test-key",
1010 BaseURL: "https://api.custom.com/v1",
1011 Models: []catwalk.Model{},
1012 },
1013 }),
1014 }
1015 cfg.setDefaults("/tmp", "")
1016 env := env.NewFromMap(map[string]string{})
1017 resolver := NewEnvironmentVariableResolver(env)
1018 err := cfg.configureProviders(env, resolver, knownProviders)
1019 require.NoError(t, err)
1020 _, _, err = cfg.defaultModelSelection(knownProviders)
1021 require.Error(t, err)
1022 })
1023 t.Run("should use the default provider first", func(t *testing.T) {
1024 knownProviders := []catwalk.Provider{
1025 {
1026 ID: "openai",
1027 APIKey: "set",
1028 DefaultLargeModelID: "large-model",
1029 DefaultSmallModelID: "small-model",
1030 Models: []catwalk.Model{
1031 {
1032 ID: "large-model",
1033 DefaultMaxTokens: 1000,
1034 },
1035 {
1036 ID: "small-model",
1037 DefaultMaxTokens: 500,
1038 },
1039 },
1040 },
1041 }
1042
1043 cfg := &Config{
1044 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1045 "custom": {
1046 APIKey: "test-key",
1047 BaseURL: "https://api.custom.com/v1",
1048 Models: []catwalk.Model{
1049 {
1050 ID: "large-model",
1051 DefaultMaxTokens: 1000,
1052 },
1053 },
1054 },
1055 }),
1056 }
1057 cfg.setDefaults("/tmp", "")
1058 env := env.NewFromMap(map[string]string{})
1059 resolver := NewEnvironmentVariableResolver(env)
1060 err := cfg.configureProviders(env, resolver, knownProviders)
1061 require.NoError(t, err)
1062 large, small, err := cfg.defaultModelSelection(knownProviders)
1063 require.NoError(t, err)
1064 require.Equal(t, "large-model", large.Model)
1065 require.Equal(t, "openai", large.Provider)
1066 require.Equal(t, int64(1000), large.MaxTokens)
1067 require.Equal(t, "small-model", small.Model)
1068 require.Equal(t, "openai", small.Provider)
1069 require.Equal(t, int64(500), small.MaxTokens)
1070 })
1071}
1072
1073func TestConfig_configureSelectedModels(t *testing.T) {
1074 t.Run("should override defaults", func(t *testing.T) {
1075 knownProviders := []catwalk.Provider{
1076 {
1077 ID: "openai",
1078 APIKey: "abc",
1079 DefaultLargeModelID: "large-model",
1080 DefaultSmallModelID: "small-model",
1081 Models: []catwalk.Model{
1082 {
1083 ID: "larger-model",
1084 DefaultMaxTokens: 2000,
1085 },
1086 {
1087 ID: "large-model",
1088 DefaultMaxTokens: 1000,
1089 },
1090 {
1091 ID: "small-model",
1092 DefaultMaxTokens: 500,
1093 },
1094 },
1095 },
1096 }
1097
1098 cfg := &Config{
1099 Models: map[SelectedModelType]SelectedModel{
1100 "large": {
1101 Model: "larger-model",
1102 },
1103 },
1104 }
1105 cfg.setDefaults("/tmp", "")
1106 env := env.NewFromMap(map[string]string{})
1107 resolver := NewEnvironmentVariableResolver(env)
1108 err := cfg.configureProviders(env, resolver, knownProviders)
1109 require.NoError(t, err)
1110
1111 err = cfg.configureSelectedModels(knownProviders)
1112 require.NoError(t, err)
1113 large := cfg.Models[SelectedModelTypeLarge]
1114 small := cfg.Models[SelectedModelTypeSmall]
1115 require.Equal(t, "larger-model", large.Model)
1116 require.Equal(t, "openai", large.Provider)
1117 require.Equal(t, int64(2000), large.MaxTokens)
1118 require.Equal(t, "small-model", small.Model)
1119 require.Equal(t, "openai", small.Provider)
1120 require.Equal(t, int64(500), small.MaxTokens)
1121 })
1122 t.Run("should be possible to use multiple providers", func(t *testing.T) {
1123 knownProviders := []catwalk.Provider{
1124 {
1125 ID: "openai",
1126 APIKey: "abc",
1127 DefaultLargeModelID: "large-model",
1128 DefaultSmallModelID: "small-model",
1129 Models: []catwalk.Model{
1130 {
1131 ID: "large-model",
1132 DefaultMaxTokens: 1000,
1133 },
1134 {
1135 ID: "small-model",
1136 DefaultMaxTokens: 500,
1137 },
1138 },
1139 },
1140 {
1141 ID: "anthropic",
1142 APIKey: "abc",
1143 DefaultLargeModelID: "a-large-model",
1144 DefaultSmallModelID: "a-small-model",
1145 Models: []catwalk.Model{
1146 {
1147 ID: "a-large-model",
1148 DefaultMaxTokens: 1000,
1149 },
1150 {
1151 ID: "a-small-model",
1152 DefaultMaxTokens: 200,
1153 },
1154 },
1155 },
1156 }
1157
1158 cfg := &Config{
1159 Models: map[SelectedModelType]SelectedModel{
1160 "small": {
1161 Model: "a-small-model",
1162 Provider: "anthropic",
1163 MaxTokens: 300,
1164 },
1165 },
1166 }
1167 cfg.setDefaults("/tmp", "")
1168 env := env.NewFromMap(map[string]string{})
1169 resolver := NewEnvironmentVariableResolver(env)
1170 err := cfg.configureProviders(env, resolver, knownProviders)
1171 require.NoError(t, err)
1172
1173 err = cfg.configureSelectedModels(knownProviders)
1174 require.NoError(t, err)
1175 large := cfg.Models[SelectedModelTypeLarge]
1176 small := cfg.Models[SelectedModelTypeSmall]
1177 require.Equal(t, "large-model", large.Model)
1178 require.Equal(t, "openai", large.Provider)
1179 require.Equal(t, int64(1000), large.MaxTokens)
1180 require.Equal(t, "a-small-model", small.Model)
1181 require.Equal(t, "anthropic", small.Provider)
1182 require.Equal(t, int64(300), small.MaxTokens)
1183 })
1184
1185 t.Run("should override the max tokens only", func(t *testing.T) {
1186 knownProviders := []catwalk.Provider{
1187 {
1188 ID: "openai",
1189 APIKey: "abc",
1190 DefaultLargeModelID: "large-model",
1191 DefaultSmallModelID: "small-model",
1192 Models: []catwalk.Model{
1193 {
1194 ID: "large-model",
1195 DefaultMaxTokens: 1000,
1196 },
1197 {
1198 ID: "small-model",
1199 DefaultMaxTokens: 500,
1200 },
1201 },
1202 },
1203 }
1204
1205 cfg := &Config{
1206 Models: map[SelectedModelType]SelectedModel{
1207 "large": {
1208 MaxTokens: 100,
1209 },
1210 },
1211 }
1212 cfg.setDefaults("/tmp", "")
1213 env := env.NewFromMap(map[string]string{})
1214 resolver := NewEnvironmentVariableResolver(env)
1215 err := cfg.configureProviders(env, resolver, knownProviders)
1216 require.NoError(t, err)
1217
1218 err = cfg.configureSelectedModels(knownProviders)
1219 require.NoError(t, err)
1220 large := cfg.Models[SelectedModelTypeLarge]
1221 require.Equal(t, "large-model", large.Model)
1222 require.Equal(t, "openai", large.Provider)
1223 require.Equal(t, int64(100), large.MaxTokens)
1224 })
1225}