1package config
2
3import (
4 "io"
5 "log/slog"
6 "os"
7 "path/filepath"
8 "strings"
9 "testing"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/csync"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15)
16
// TestMain replaces the default slog logger with one that writes to
// io.Discard, runs the package's tests, and exits with the code m.Run returns.
func TestMain(m *testing.M) {
	discard := slog.NewTextHandler(io.Discard, nil)
	slog.SetDefault(slog.New(discard))

	os.Exit(m.Run())
}
23
24func TestConfig_LoadFromReaders(t *testing.T) {
25 data1 := strings.NewReader(`{"providers": {"openai": {"api_key=key1", "base_url": "https://api.openai.com/v1"}}}`)
26 data2 := strings.NewReader(`{"providers": {"openai": {"api_key=key2", "base_url": "https://api.openai.com/v2"}}}`)
27 data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
28
29 loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
30
31 require.NoError(t, err)
32 require.NotNil(t, loadedConfig)
33 require.Equal(t, 1, loadedConfig.Providers.Len())
34 pc, _ := loadedConfig.Providers.Get("openai")
35 require.Equal(t, "key2", pc.APIKey)
36 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
37}
38
39func TestConfig_setDefaults(t *testing.T) {
40 cfg := &Config{}
41
42 cfg.setDefaults("/tmp", "")
43
44 require.NotNil(t, cfg.Options)
45 require.NotNil(t, cfg.Options.TUI)
46 require.NotNil(t, cfg.Options.ContextPaths)
47 require.NotNil(t, cfg.Providers)
48 require.NotNil(t, cfg.Models)
49 require.NotNil(t, cfg.LSP)
50 require.NotNil(t, cfg.MCP)
51 require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
52 for _, path := range defaultContextPaths {
53 require.Contains(t, cfg.Options.ContextPaths, path)
54 }
55 require.Equal(t, "/tmp", cfg.workingDir)
56}
57
58func TestConfig_configureProviders(t *testing.T) {
59 knownProviders := []catwalk.Provider{
60 {
61 ID: "openai",
62 APIKey: "$OPENAI_API_KEY",
63 APIEndpoint: "https://api.openai.com/v1",
64 Models: []catwalk.Model{{
65 ID: "test-model",
66 }},
67 },
68 }
69
70 cfg := &Config{}
71 cfg.setDefaults("/tmp", "")
72 env := environ([]string{
73 "OPENAI_API_KEY=test-key",
74 })
75 resolver := NewEnvironmentVariableResolver(env)
76 err := cfg.configureProviders(env, resolver, knownProviders)
77 require.NoError(t, err)
78 require.Equal(t, 1, cfg.Providers.Len())
79
80 // We want to make sure that we keep the configured API key as a placeholder
81 pc, _ := cfg.Providers.Get("openai")
82 require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
83}
84
85func TestConfig_configureProvidersWithOverride(t *testing.T) {
86 knownProviders := []catwalk.Provider{
87 {
88 ID: "openai",
89 APIKey: "$OPENAI_API_KEY",
90 APIEndpoint: "https://api.openai.com/v1",
91 Models: []catwalk.Model{{
92 ID: "test-model",
93 }},
94 },
95 }
96
97 cfg := &Config{
98 Providers: csync.NewMap[string, ProviderConfig](),
99 }
100 cfg.Providers.Set("openai", ProviderConfig{
101 APIKey: "xyz",
102 BaseURL: "https://api.openai.com/v2",
103 Models: []catwalk.Model{
104 {
105 ID: "test-model",
106 Name: "Updated",
107 },
108 {
109 ID: "another-model",
110 },
111 },
112 })
113 cfg.setDefaults("/tmp", "")
114
115 env := environ([]string{
116 "OPENAI_API_KEY=test-key",
117 })
118 resolver := NewEnvironmentVariableResolver(env)
119 err := cfg.configureProviders(env, resolver, knownProviders)
120 require.NoError(t, err)
121 require.Equal(t, 1, cfg.Providers.Len())
122
123 // We want to make sure that we keep the configured API key as a placeholder
124 pc, _ := cfg.Providers.Get("openai")
125 require.Equal(t, "xyz", pc.APIKey)
126 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
127 require.Len(t, pc.Models, 2)
128 require.Equal(t, "Updated", pc.Models[0].Name)
129}
130
131func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
132 knownProviders := []catwalk.Provider{
133 {
134 ID: "openai",
135 APIKey: "$OPENAI_API_KEY",
136 APIEndpoint: "https://api.openai.com/v1",
137 Models: []catwalk.Model{{
138 ID: "test-model",
139 }},
140 },
141 }
142
143 cfg := &Config{
144 Providers: csync.NewMapFrom(map[string]ProviderConfig{
145 "custom": {
146 APIKey: "xyz",
147 BaseURL: "https://api.someendpoint.com/v2",
148 Models: []catwalk.Model{
149 {
150 ID: "test-model",
151 },
152 },
153 },
154 }),
155 }
156 cfg.setDefaults("/tmp", "")
157 env := environ([]string{
158 "OPENAI_API_KEY=test-key",
159 })
160 resolver := NewEnvironmentVariableResolver(env)
161 err := cfg.configureProviders(env, resolver, knownProviders)
162 require.NoError(t, err)
163 // Should be to because of the env variable
164 require.Equal(t, cfg.Providers.Len(), 2)
165
166 // We want to make sure that we keep the configured API key as a placeholder
167 pc, _ := cfg.Providers.Get("custom")
168 require.Equal(t, "xyz", pc.APIKey)
169 // Make sure we set the ID correctly
170 require.Equal(t, "custom", pc.ID)
171 require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
172 require.Len(t, pc.Models, 1)
173
174 _, ok := cfg.Providers.Get("openai")
175 require.True(t, ok, "OpenAI provider should still be present")
176}
177
178func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
179 knownProviders := []catwalk.Provider{
180 {
181 ID: catwalk.InferenceProviderBedrock,
182 APIKey: "",
183 APIEndpoint: "",
184 Models: []catwalk.Model{{
185 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
186 }},
187 },
188 }
189
190 cfg := &Config{}
191 cfg.setDefaults("/tmp", "")
192 env := environ([]string{
193 "AWS_ACCESS_KEY_ID=test-key-id",
194 "AWS_SECRET_ACCESS_KEY=test-secret-key",
195 })
196 resolver := NewEnvironmentVariableResolver(env)
197 err := cfg.configureProviders(env, resolver, knownProviders)
198 require.NoError(t, err)
199 require.Equal(t, cfg.Providers.Len(), 1)
200
201 bedrockProvider, ok := cfg.Providers.Get("bedrock")
202 require.True(t, ok, "Bedrock provider should be present")
203 require.Len(t, bedrockProvider.Models, 1)
204 require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
205}
206
207func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
208 knownProviders := []catwalk.Provider{
209 {
210 ID: catwalk.InferenceProviderBedrock,
211 APIKey: "",
212 APIEndpoint: "",
213 Models: []catwalk.Model{{
214 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
215 }},
216 },
217 }
218
219 cfg := &Config{}
220 cfg.setDefaults("/tmp", "")
221 env := environ([]string{})
222 resolver := NewEnvironmentVariableResolver(env)
223 err := cfg.configureProviders(env, resolver, knownProviders)
224 require.NoError(t, err)
225 // Provider should not be configured without credentials
226 require.Equal(t, cfg.Providers.Len(), 0)
227}
228
229func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
230 knownProviders := []catwalk.Provider{
231 {
232 ID: catwalk.InferenceProviderBedrock,
233 APIKey: "",
234 APIEndpoint: "",
235 Models: []catwalk.Model{{
236 ID: "some-random-model",
237 }},
238 },
239 }
240
241 cfg := &Config{}
242 cfg.setDefaults("/tmp", "")
243 env := environ([]string{
244 "AWS_ACCESS_KEY_ID=test-key-id",
245 "AWS_SECRET_ACCESS_KEY=test-secret-key",
246 })
247 resolver := NewEnvironmentVariableResolver(env)
248 err := cfg.configureProviders(env, resolver, knownProviders)
249 require.Error(t, err)
250}
251
252func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
253 knownProviders := []catwalk.Provider{
254 {
255 ID: catwalk.InferenceProviderVertexAI,
256 APIKey: "",
257 APIEndpoint: "",
258 Models: []catwalk.Model{{
259 ID: "gemini-pro",
260 }},
261 },
262 }
263
264 cfg := &Config{}
265 cfg.setDefaults("/tmp", "")
266 env := environ([]string{
267 "VERTEXAI_PROJECT=test-project",
268 "VERTEXAI_LOCATION=us-central1",
269 })
270 resolver := NewEnvironmentVariableResolver(env)
271 err := cfg.configureProviders(env, resolver, knownProviders)
272 require.NoError(t, err)
273 require.Equal(t, cfg.Providers.Len(), 1)
274
275 vertexProvider, ok := cfg.Providers.Get("vertexai")
276 require.True(t, ok, "VertexAI provider should be present")
277 require.Len(t, vertexProvider.Models, 1)
278 require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
279 require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
280 require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
281}
282
283func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
284 knownProviders := []catwalk.Provider{
285 {
286 ID: catwalk.InferenceProviderVertexAI,
287 APIKey: "",
288 APIEndpoint: "",
289 Models: []catwalk.Model{{
290 ID: "gemini-pro",
291 }},
292 },
293 }
294
295 cfg := &Config{}
296 cfg.setDefaults("/tmp", "")
297 env := environ([]string{
298 "GOOGLE_GENAI_USE_VERTEXAI=false",
299 "GOOGLE_CLOUD_PROJECT=test-project",
300 "GOOGLE_CLOUD_LOCATION=us-central1",
301 })
302 resolver := NewEnvironmentVariableResolver(env)
303 err := cfg.configureProviders(env, resolver, knownProviders)
304 require.NoError(t, err)
305 // Provider should not be configured without proper credentials
306 require.Equal(t, cfg.Providers.Len(), 0)
307}
308
309func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
310 knownProviders := []catwalk.Provider{
311 {
312 ID: catwalk.InferenceProviderVertexAI,
313 APIKey: "",
314 APIEndpoint: "",
315 Models: []catwalk.Model{{
316 ID: "gemini-pro",
317 }},
318 },
319 }
320
321 cfg := &Config{}
322 cfg.setDefaults("/tmp", "")
323 env := environ([]string{
324 "GOOGLE_GENAI_USE_VERTEXAI=true",
325 "GOOGLE_CLOUD_LOCATION=us-central1",
326 })
327 resolver := NewEnvironmentVariableResolver(env)
328 err := cfg.configureProviders(env, resolver, knownProviders)
329 require.NoError(t, err)
330 // Provider should not be configured without project
331 require.Equal(t, cfg.Providers.Len(), 0)
332}
333
334func TestConfig_configureProvidersSetProviderID(t *testing.T) {
335 knownProviders := []catwalk.Provider{
336 {
337 ID: "openai",
338 APIKey: "$OPENAI_API_KEY",
339 APIEndpoint: "https://api.openai.com/v1",
340 Models: []catwalk.Model{{
341 ID: "test-model",
342 }},
343 },
344 }
345
346 cfg := &Config{}
347 cfg.setDefaults("/tmp", "")
348 env := environ([]string{
349 "OPENAI_API_KEY=test-key",
350 })
351 resolver := NewEnvironmentVariableResolver(env)
352 err := cfg.configureProviders(env, resolver, knownProviders)
353 require.NoError(t, err)
354 require.Equal(t, cfg.Providers.Len(), 1)
355
356 // Provider ID should be set
357 pc, _ := cfg.Providers.Get("openai")
358 require.Equal(t, "openai", pc.ID)
359}
360
361func TestConfig_EnabledProviders(t *testing.T) {
362 t.Run("all providers enabled", func(t *testing.T) {
363 cfg := &Config{
364 Providers: csync.NewMapFrom(map[string]ProviderConfig{
365 "openai": {
366 ID: "openai",
367 APIKey: "key1",
368 Disable: false,
369 },
370 "anthropic": {
371 ID: "anthropic",
372 APIKey: "key2",
373 Disable: false,
374 },
375 }),
376 }
377
378 enabled := cfg.EnabledProviders()
379 require.Len(t, enabled, 2)
380 })
381
382 t.Run("some providers disabled", func(t *testing.T) {
383 cfg := &Config{
384 Providers: csync.NewMapFrom(map[string]ProviderConfig{
385 "openai": {
386 ID: "openai",
387 APIKey: "key1",
388 Disable: false,
389 },
390 "anthropic": {
391 ID: "anthropic",
392 APIKey: "key2",
393 Disable: true,
394 },
395 }),
396 }
397
398 enabled := cfg.EnabledProviders()
399 require.Len(t, enabled, 1)
400 require.Equal(t, "openai", enabled[0].ID)
401 })
402
403 t.Run("empty providers map", func(t *testing.T) {
404 cfg := &Config{
405 Providers: csync.NewMap[string, ProviderConfig](),
406 }
407
408 enabled := cfg.EnabledProviders()
409 require.Len(t, enabled, 0)
410 })
411}
412
413func TestConfig_IsConfigured(t *testing.T) {
414 t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
415 cfg := &Config{
416 Providers: csync.NewMapFrom(map[string]ProviderConfig{
417 "openai": {
418 ID: "openai",
419 APIKey: "key1",
420 Disable: false,
421 },
422 }),
423 }
424
425 require.True(t, cfg.IsConfigured())
426 })
427
428 t.Run("returns false when no providers are configured", func(t *testing.T) {
429 cfg := &Config{
430 Providers: csync.NewMap[string, ProviderConfig](),
431 }
432
433 require.False(t, cfg.IsConfigured())
434 })
435
436 t.Run("returns false when all providers are disabled", func(t *testing.T) {
437 cfg := &Config{
438 Providers: csync.NewMapFrom(map[string]ProviderConfig{
439 "openai": {
440 ID: "openai",
441 APIKey: "key1",
442 Disable: true,
443 },
444 "anthropic": {
445 ID: "anthropic",
446 APIKey: "key2",
447 Disable: true,
448 },
449 }),
450 }
451
452 require.False(t, cfg.IsConfigured())
453 })
454}
455
456func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
457 cfg := &Config{
458 Options: &Options{
459 DisabledTools: []string{},
460 },
461 }
462
463 cfg.SetupAgents()
464 coderAgent, ok := cfg.Agents["coder"]
465 require.True(t, ok)
466 assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
467
468 taskAgent, ok := cfg.Agents["task"]
469 require.True(t, ok)
470 assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
471}
472
473func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
474 cfg := &Config{
475 Options: &Options{
476 DisabledTools: []string{
477 "edit",
478 "download",
479 "grep",
480 },
481 },
482 }
483
484 cfg.SetupAgents()
485 coderAgent, ok := cfg.Agents["coder"]
486 require.True(t, ok)
487 assert.Equal(t, []string{"agent", "bash", "multiedit", "fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)
488
489 taskAgent, ok := cfg.Agents["task"]
490 require.True(t, ok)
491 assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
492}
493
494func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
495 cfg := &Config{
496 Options: &Options{
497 DisabledTools: []string{
498 "glob",
499 "grep",
500 "ls",
501 "sourcegraph",
502 "view",
503 },
504 },
505 }
506
507 cfg.SetupAgents()
508 coderAgent, ok := cfg.Agents["coder"]
509 require.True(t, ok)
510 assert.Equal(t, []string{"agent", "bash", "download", "edit", "multiedit", "fetch", "write"}, coderAgent.AllowedTools)
511
512 taskAgent, ok := cfg.Agents["task"]
513 require.True(t, ok)
514 assert.Equal(t, []string{}, taskAgent.AllowedTools)
515}
516
517func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
518 knownProviders := []catwalk.Provider{
519 {
520 ID: "openai",
521 APIKey: "$OPENAI_API_KEY",
522 APIEndpoint: "https://api.openai.com/v1",
523 Models: []catwalk.Model{{
524 ID: "test-model",
525 }},
526 },
527 }
528
529 cfg := &Config{
530 Providers: csync.NewMapFrom(map[string]ProviderConfig{
531 "openai": {
532 Disable: true,
533 },
534 }),
535 }
536 cfg.setDefaults("/tmp", "")
537
538 env := environ([]string{
539 "OPENAI_API_KEY=test-key",
540 })
541 resolver := NewEnvironmentVariableResolver(env)
542 err := cfg.configureProviders(env, resolver, knownProviders)
543 require.NoError(t, err)
544
545 require.Equal(t, cfg.Providers.Len(), 1)
546 prov, exists := cfg.Providers.Get("openai")
547 require.True(t, exists)
548 require.True(t, prov.Disable)
549}
550
551func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
552 t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
553 cfg := &Config{
554 Providers: csync.NewMapFrom(map[string]ProviderConfig{
555 "custom": {
556 BaseURL: "https://api.custom.com/v1",
557 Models: []catwalk.Model{{
558 ID: "test-model",
559 }},
560 },
561 "openai": {
562 APIKey: "$MISSING",
563 },
564 }),
565 }
566 cfg.setDefaults("/tmp", "")
567
568 env := environ([]string{})
569 resolver := NewEnvironmentVariableResolver(env)
570 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
571 require.NoError(t, err)
572
573 require.Equal(t, cfg.Providers.Len(), 1)
574 _, exists := cfg.Providers.Get("custom")
575 require.True(t, exists)
576 })
577
578 t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
579 cfg := &Config{
580 Providers: csync.NewMapFrom(map[string]ProviderConfig{
581 "custom": {
582 APIKey: "test-key",
583 Models: []catwalk.Model{{
584 ID: "test-model",
585 }},
586 },
587 }),
588 }
589 cfg.setDefaults("/tmp", "")
590
591 env := environ([]string{})
592 resolver := NewEnvironmentVariableResolver(env)
593 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
594 require.NoError(t, err)
595
596 require.Equal(t, cfg.Providers.Len(), 0)
597 _, exists := cfg.Providers.Get("custom")
598 require.False(t, exists)
599 })
600
601 t.Run("custom provider with no models is removed", func(t *testing.T) {
602 cfg := &Config{
603 Providers: csync.NewMapFrom(map[string]ProviderConfig{
604 "custom": {
605 APIKey: "test-key",
606 BaseURL: "https://api.custom.com/v1",
607 Models: []catwalk.Model{},
608 },
609 }),
610 }
611 cfg.setDefaults("/tmp", "")
612
613 env := environ([]string{})
614 resolver := NewEnvironmentVariableResolver(env)
615 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
616 require.NoError(t, err)
617
618 require.Equal(t, cfg.Providers.Len(), 0)
619 _, exists := cfg.Providers.Get("custom")
620 require.False(t, exists)
621 })
622
623 t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
624 cfg := &Config{
625 Providers: csync.NewMapFrom(map[string]ProviderConfig{
626 "custom": {
627 APIKey: "test-key",
628 BaseURL: "https://api.custom.com/v1",
629 Type: "unsupported",
630 Models: []catwalk.Model{{
631 ID: "test-model",
632 }},
633 },
634 }),
635 }
636 cfg.setDefaults("/tmp", "")
637
638 env := environ([]string{})
639 resolver := NewEnvironmentVariableResolver(env)
640 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
641 require.NoError(t, err)
642
643 require.Equal(t, cfg.Providers.Len(), 0)
644 _, exists := cfg.Providers.Get("custom")
645 require.False(t, exists)
646 })
647
648 t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
649 cfg := &Config{
650 Providers: csync.NewMapFrom(map[string]ProviderConfig{
651 "custom": {
652 APIKey: "test-key",
653 BaseURL: "https://api.custom.com/v1",
654 Type: catwalk.TypeOpenAI,
655 Models: []catwalk.Model{{
656 ID: "test-model",
657 }},
658 },
659 }),
660 }
661 cfg.setDefaults("/tmp", "")
662
663 env := environ([]string{})
664 resolver := NewEnvironmentVariableResolver(env)
665 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
666 require.NoError(t, err)
667
668 require.Equal(t, cfg.Providers.Len(), 1)
669 customProvider, exists := cfg.Providers.Get("custom")
670 require.True(t, exists)
671 require.Equal(t, "custom", customProvider.ID)
672 require.Equal(t, "test-key", customProvider.APIKey)
673 require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
674 })
675
676 t.Run("custom anthropic provider is supported", func(t *testing.T) {
677 cfg := &Config{
678 Providers: csync.NewMapFrom(map[string]ProviderConfig{
679 "custom-anthropic": {
680 APIKey: "test-key",
681 BaseURL: "https://api.anthropic.com/v1",
682 Type: catwalk.TypeAnthropic,
683 Models: []catwalk.Model{{
684 ID: "claude-3-sonnet",
685 }},
686 },
687 }),
688 }
689 cfg.setDefaults("/tmp", "")
690
691 env := environ([]string{})
692 resolver := NewEnvironmentVariableResolver(env)
693 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
694 require.NoError(t, err)
695
696 require.Equal(t, cfg.Providers.Len(), 1)
697 customProvider, exists := cfg.Providers.Get("custom-anthropic")
698 require.True(t, exists)
699 require.Equal(t, "custom-anthropic", customProvider.ID)
700 require.Equal(t, "test-key", customProvider.APIKey)
701 require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
702 require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
703 })
704
705 t.Run("disabled custom provider is removed", func(t *testing.T) {
706 cfg := &Config{
707 Providers: csync.NewMapFrom(map[string]ProviderConfig{
708 "custom": {
709 APIKey: "test-key",
710 BaseURL: "https://api.custom.com/v1",
711 Type: catwalk.TypeOpenAI,
712 Disable: true,
713 Models: []catwalk.Model{{
714 ID: "test-model",
715 }},
716 },
717 }),
718 }
719 cfg.setDefaults("/tmp", "")
720
721 env := environ([]string{})
722 resolver := NewEnvironmentVariableResolver(env)
723 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
724 require.NoError(t, err)
725
726 require.Equal(t, cfg.Providers.Len(), 0)
727 _, exists := cfg.Providers.Get("custom")
728 require.False(t, exists)
729 })
730}
731
732func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
733 t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
734 knownProviders := []catwalk.Provider{
735 {
736 ID: catwalk.InferenceProviderVertexAI,
737 APIKey: "",
738 APIEndpoint: "",
739 Models: []catwalk.Model{{
740 ID: "gemini-pro",
741 }},
742 },
743 }
744
745 cfg := &Config{
746 Providers: csync.NewMapFrom(map[string]ProviderConfig{
747 "vertexai": {
748 BaseURL: "custom-url",
749 },
750 }),
751 }
752 cfg.setDefaults("/tmp", "")
753
754 env := environ([]string{
755 "GOOGLE_GENAI_USE_VERTEXAI=false",
756 })
757 resolver := NewEnvironmentVariableResolver(env)
758 err := cfg.configureProviders(env, resolver, knownProviders)
759 require.NoError(t, err)
760
761 require.Equal(t, cfg.Providers.Len(), 0)
762 _, exists := cfg.Providers.Get("vertexai")
763 require.False(t, exists)
764 })
765
766 t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
767 knownProviders := []catwalk.Provider{
768 {
769 ID: catwalk.InferenceProviderBedrock,
770 APIKey: "",
771 APIEndpoint: "",
772 Models: []catwalk.Model{{
773 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
774 }},
775 },
776 }
777
778 cfg := &Config{
779 Providers: csync.NewMapFrom(map[string]ProviderConfig{
780 "bedrock": {
781 BaseURL: "custom-url",
782 },
783 }),
784 }
785 cfg.setDefaults("/tmp", "")
786
787 env := environ([]string{})
788 resolver := NewEnvironmentVariableResolver(env)
789 err := cfg.configureProviders(env, resolver, knownProviders)
790 require.NoError(t, err)
791
792 require.Equal(t, cfg.Providers.Len(), 0)
793 _, exists := cfg.Providers.Get("bedrock")
794 require.False(t, exists)
795 })
796
797 t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
798 knownProviders := []catwalk.Provider{
799 {
800 ID: "openai",
801 APIKey: "$MISSING_API_KEY",
802 APIEndpoint: "https://api.openai.com/v1",
803 Models: []catwalk.Model{{
804 ID: "test-model",
805 }},
806 },
807 }
808
809 cfg := &Config{
810 Providers: csync.NewMapFrom(map[string]ProviderConfig{
811 "openai": {
812 BaseURL: "custom-url",
813 },
814 }),
815 }
816 cfg.setDefaults("/tmp", "")
817
818 env := environ([]string{})
819 resolver := NewEnvironmentVariableResolver(env)
820 err := cfg.configureProviders(env, resolver, knownProviders)
821 require.NoError(t, err)
822
823 require.Equal(t, cfg.Providers.Len(), 0)
824 _, exists := cfg.Providers.Get("openai")
825 require.False(t, exists)
826 })
827
828 t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
829 knownProviders := []catwalk.Provider{
830 {
831 ID: "openai",
832 APIKey: "$OPENAI_API_KEY",
833 APIEndpoint: "$MISSING_ENDPOINT",
834 Models: []catwalk.Model{{
835 ID: "test-model",
836 }},
837 },
838 }
839
840 cfg := &Config{
841 Providers: csync.NewMapFrom(map[string]ProviderConfig{
842 "openai": {
843 APIKey: "test-key",
844 },
845 }),
846 }
847 cfg.setDefaults("/tmp", "")
848
849 env := environ([]string{
850 "OPENAI_API_KEY=test-key",
851 })
852 resolver := NewEnvironmentVariableResolver(env)
853 err := cfg.configureProviders(env, resolver, knownProviders)
854 require.NoError(t, err)
855
856 require.Equal(t, cfg.Providers.Len(), 1)
857 _, exists := cfg.Providers.Get("openai")
858 require.True(t, exists)
859 })
860}
861
862func TestConfig_defaultModelSelection(t *testing.T) {
863 t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
864 knownProviders := []catwalk.Provider{
865 {
866 ID: "openai",
867 APIKey: "abc",
868 DefaultLargeModelID: "large-model",
869 DefaultSmallModelID: "small-model",
870 Models: []catwalk.Model{
871 {
872 ID: "large-model",
873 DefaultMaxTokens: 1000,
874 },
875 {
876 ID: "small-model",
877 DefaultMaxTokens: 500,
878 },
879 },
880 },
881 }
882
883 cfg := &Config{}
884 cfg.setDefaults("/tmp", "")
885 env := environ([]string{})
886 resolver := NewEnvironmentVariableResolver(env)
887 err := cfg.configureProviders(env, resolver, knownProviders)
888 require.NoError(t, err)
889
890 large, small, err := cfg.defaultModelSelection(knownProviders)
891 require.NoError(t, err)
892 require.Equal(t, "large-model", large.Model)
893 require.Equal(t, "openai", large.Provider)
894 require.Equal(t, int64(1000), large.MaxTokens)
895 require.Equal(t, "small-model", small.Model)
896 require.Equal(t, "openai", small.Provider)
897 require.Equal(t, int64(500), small.MaxTokens)
898 })
899 t.Run("should error if no providers configured", func(t *testing.T) {
900 knownProviders := []catwalk.Provider{
901 {
902 ID: "openai",
903 APIKey: "$MISSING_KEY",
904 DefaultLargeModelID: "large-model",
905 DefaultSmallModelID: "small-model",
906 Models: []catwalk.Model{
907 {
908 ID: "large-model",
909 DefaultMaxTokens: 1000,
910 },
911 {
912 ID: "small-model",
913 DefaultMaxTokens: 500,
914 },
915 },
916 },
917 }
918
919 cfg := &Config{}
920 cfg.setDefaults("/tmp", "")
921 env := environ([]string{})
922 resolver := NewEnvironmentVariableResolver(env)
923 err := cfg.configureProviders(env, resolver, knownProviders)
924 require.NoError(t, err)
925
926 _, _, err = cfg.defaultModelSelection(knownProviders)
927 require.Error(t, err)
928 })
929 t.Run("should error if model is missing", func(t *testing.T) {
930 knownProviders := []catwalk.Provider{
931 {
932 ID: "openai",
933 APIKey: "abc",
934 DefaultLargeModelID: "large-model",
935 DefaultSmallModelID: "small-model",
936 Models: []catwalk.Model{
937 {
938 ID: "not-large-model",
939 DefaultMaxTokens: 1000,
940 },
941 {
942 ID: "small-model",
943 DefaultMaxTokens: 500,
944 },
945 },
946 },
947 }
948
949 cfg := &Config{}
950 cfg.setDefaults("/tmp", "")
951 env := environ([]string{})
952 resolver := NewEnvironmentVariableResolver(env)
953 err := cfg.configureProviders(env, resolver, knownProviders)
954 require.NoError(t, err)
955 _, _, err = cfg.defaultModelSelection(knownProviders)
956 require.Error(t, err)
957 })
958
959 t.Run("should configure the default models with a custom provider", func(t *testing.T) {
960 knownProviders := []catwalk.Provider{
961 {
962 ID: "openai",
963 APIKey: "$MISSING", // will not be included in the config
964 DefaultLargeModelID: "large-model",
965 DefaultSmallModelID: "small-model",
966 Models: []catwalk.Model{
967 {
968 ID: "not-large-model",
969 DefaultMaxTokens: 1000,
970 },
971 {
972 ID: "small-model",
973 DefaultMaxTokens: 500,
974 },
975 },
976 },
977 }
978
979 cfg := &Config{
980 Providers: csync.NewMapFrom(map[string]ProviderConfig{
981 "custom": {
982 APIKey: "test-key",
983 BaseURL: "https://api.custom.com/v1",
984 Models: []catwalk.Model{
985 {
986 ID: "model",
987 DefaultMaxTokens: 600,
988 },
989 },
990 },
991 }),
992 }
993 cfg.setDefaults("/tmp", "")
994 env := environ([]string{})
995 resolver := NewEnvironmentVariableResolver(env)
996 err := cfg.configureProviders(env, resolver, knownProviders)
997 require.NoError(t, err)
998 large, small, err := cfg.defaultModelSelection(knownProviders)
999 require.NoError(t, err)
1000 require.Equal(t, "model", large.Model)
1001 require.Equal(t, "custom", large.Provider)
1002 require.Equal(t, int64(600), large.MaxTokens)
1003 require.Equal(t, "model", small.Model)
1004 require.Equal(t, "custom", small.Provider)
1005 require.Equal(t, int64(600), small.MaxTokens)
1006 })
1007
1008 t.Run("should fail if no model configured", func(t *testing.T) {
1009 knownProviders := []catwalk.Provider{
1010 {
1011 ID: "openai",
1012 APIKey: "$MISSING", // will not be included in the config
1013 DefaultLargeModelID: "large-model",
1014 DefaultSmallModelID: "small-model",
1015 Models: []catwalk.Model{
1016 {
1017 ID: "not-large-model",
1018 DefaultMaxTokens: 1000,
1019 },
1020 {
1021 ID: "small-model",
1022 DefaultMaxTokens: 500,
1023 },
1024 },
1025 },
1026 }
1027
1028 cfg := &Config{
1029 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1030 "custom": {
1031 APIKey: "test-key",
1032 BaseURL: "https://api.custom.com/v1",
1033 Models: []catwalk.Model{},
1034 },
1035 }),
1036 }
1037 cfg.setDefaults("/tmp", "")
1038 env := environ([]string{})
1039 resolver := NewEnvironmentVariableResolver(env)
1040 err := cfg.configureProviders(env, resolver, knownProviders)
1041 require.NoError(t, err)
1042 _, _, err = cfg.defaultModelSelection(knownProviders)
1043 require.Error(t, err)
1044 })
1045 t.Run("should use the default provider first", func(t *testing.T) {
1046 knownProviders := []catwalk.Provider{
1047 {
1048 ID: "openai",
1049 APIKey: "set",
1050 DefaultLargeModelID: "large-model",
1051 DefaultSmallModelID: "small-model",
1052 Models: []catwalk.Model{
1053 {
1054 ID: "large-model",
1055 DefaultMaxTokens: 1000,
1056 },
1057 {
1058 ID: "small-model",
1059 DefaultMaxTokens: 500,
1060 },
1061 },
1062 },
1063 }
1064
1065 cfg := &Config{
1066 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1067 "custom": {
1068 APIKey: "test-key",
1069 BaseURL: "https://api.custom.com/v1",
1070 Models: []catwalk.Model{
1071 {
1072 ID: "large-model",
1073 DefaultMaxTokens: 1000,
1074 },
1075 },
1076 },
1077 }),
1078 }
1079 cfg.setDefaults("/tmp", "")
1080 env := environ([]string{})
1081 resolver := NewEnvironmentVariableResolver(env)
1082 err := cfg.configureProviders(env, resolver, knownProviders)
1083 require.NoError(t, err)
1084 large, small, err := cfg.defaultModelSelection(knownProviders)
1085 require.NoError(t, err)
1086 require.Equal(t, "large-model", large.Model)
1087 require.Equal(t, "openai", large.Provider)
1088 require.Equal(t, int64(1000), large.MaxTokens)
1089 require.Equal(t, "small-model", small.Model)
1090 require.Equal(t, "openai", small.Provider)
1091 require.Equal(t, int64(500), small.MaxTokens)
1092 })
1093}
1094
1095func TestConfig_configureSelectedModels(t *testing.T) {
1096 t.Run("should override defaults", func(t *testing.T) {
1097 knownProviders := []catwalk.Provider{
1098 {
1099 ID: "openai",
1100 APIKey: "abc",
1101 DefaultLargeModelID: "large-model",
1102 DefaultSmallModelID: "small-model",
1103 Models: []catwalk.Model{
1104 {
1105 ID: "larger-model",
1106 DefaultMaxTokens: 2000,
1107 },
1108 {
1109 ID: "large-model",
1110 DefaultMaxTokens: 1000,
1111 },
1112 {
1113 ID: "small-model",
1114 DefaultMaxTokens: 500,
1115 },
1116 },
1117 },
1118 }
1119
1120 cfg := &Config{
1121 Models: map[SelectedModelType]SelectedModel{
1122 "large": {
1123 Model: "larger-model",
1124 },
1125 },
1126 }
1127 cfg.setDefaults("/tmp", "")
1128 env := environ([]string{})
1129 resolver := NewEnvironmentVariableResolver(env)
1130 err := cfg.configureProviders(env, resolver, knownProviders)
1131 require.NoError(t, err)
1132
1133 err = cfg.configureSelectedModels(knownProviders)
1134 require.NoError(t, err)
1135 large := cfg.Models[SelectedModelTypeLarge]
1136 small := cfg.Models[SelectedModelTypeSmall]
1137 require.Equal(t, "larger-model", large.Model)
1138 require.Equal(t, "openai", large.Provider)
1139 require.Equal(t, int64(2000), large.MaxTokens)
1140 require.Equal(t, "small-model", small.Model)
1141 require.Equal(t, "openai", small.Provider)
1142 require.Equal(t, int64(500), small.MaxTokens)
1143 })
1144 t.Run("should be possible to use multiple providers", func(t *testing.T) {
1145 knownProviders := []catwalk.Provider{
1146 {
1147 ID: "openai",
1148 APIKey: "abc",
1149 DefaultLargeModelID: "large-model",
1150 DefaultSmallModelID: "small-model",
1151 Models: []catwalk.Model{
1152 {
1153 ID: "large-model",
1154 DefaultMaxTokens: 1000,
1155 },
1156 {
1157 ID: "small-model",
1158 DefaultMaxTokens: 500,
1159 },
1160 },
1161 },
1162 {
1163 ID: "anthropic",
1164 APIKey: "abc",
1165 DefaultLargeModelID: "a-large-model",
1166 DefaultSmallModelID: "a-small-model",
1167 Models: []catwalk.Model{
1168 {
1169 ID: "a-large-model",
1170 DefaultMaxTokens: 1000,
1171 },
1172 {
1173 ID: "a-small-model",
1174 DefaultMaxTokens: 200,
1175 },
1176 },
1177 },
1178 }
1179
1180 cfg := &Config{
1181 Models: map[SelectedModelType]SelectedModel{
1182 "small": {
1183 Model: "a-small-model",
1184 Provider: "anthropic",
1185 MaxTokens: 300,
1186 },
1187 },
1188 }
1189 cfg.setDefaults("/tmp", "")
1190 env := environ([]string{})
1191 resolver := NewEnvironmentVariableResolver(env)
1192 err := cfg.configureProviders(env, resolver, knownProviders)
1193 require.NoError(t, err)
1194
1195 err = cfg.configureSelectedModels(knownProviders)
1196 require.NoError(t, err)
1197 large := cfg.Models[SelectedModelTypeLarge]
1198 small := cfg.Models[SelectedModelTypeSmall]
1199 require.Equal(t, "large-model", large.Model)
1200 require.Equal(t, "openai", large.Provider)
1201 require.Equal(t, int64(1000), large.MaxTokens)
1202 require.Equal(t, "a-small-model", small.Model)
1203 require.Equal(t, "anthropic", small.Provider)
1204 require.Equal(t, int64(300), small.MaxTokens)
1205 })
1206
1207 t.Run("should override the max tokens only", func(t *testing.T) {
1208 knownProviders := []catwalk.Provider{
1209 {
1210 ID: "openai",
1211 APIKey: "abc",
1212 DefaultLargeModelID: "large-model",
1213 DefaultSmallModelID: "small-model",
1214 Models: []catwalk.Model{
1215 {
1216 ID: "large-model",
1217 DefaultMaxTokens: 1000,
1218 },
1219 {
1220 ID: "small-model",
1221 DefaultMaxTokens: 500,
1222 },
1223 },
1224 },
1225 }
1226
1227 cfg := &Config{
1228 Models: map[SelectedModelType]SelectedModel{
1229 "large": {
1230 MaxTokens: 100,
1231 },
1232 },
1233 }
1234 cfg.setDefaults("/tmp", "")
1235 env := environ([]string{})
1236 resolver := NewEnvironmentVariableResolver(env)
1237 err := cfg.configureProviders(env, resolver, knownProviders)
1238 require.NoError(t, err)
1239
1240 err = cfg.configureSelectedModels(knownProviders)
1241 require.NoError(t, err)
1242 large := cfg.Models[SelectedModelTypeLarge]
1243 require.Equal(t, "large-model", large.Model)
1244 require.Equal(t, "openai", large.Provider)
1245 require.Equal(t, int64(100), large.MaxTokens)
1246 })
1247}