1package config
2
3import (
4 "io"
5 "log/slog"
6 "os"
7 "path/filepath"
8 "strings"
9 "testing"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/csync"
13 "github.com/charmbracelet/crush/internal/env"
14 "github.com/stretchr/testify/assert"
15 "github.com/stretchr/testify/require"
16)
17
// TestMain routes the default slog logger to io.Discard for the whole package
// test run, then exits with the status code produced by the test runner.
func TestMain(m *testing.M) {
	discard := slog.NewTextHandler(io.Discard, nil)
	slog.SetDefault(slog.New(discard))

	os.Exit(m.Run())
}
24
25func TestConfig_LoadFromReaders(t *testing.T) {
26 data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
27 data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
28 data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
29
30 loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
31
32 require.NoError(t, err)
33 require.NotNil(t, loadedConfig)
34 require.Equal(t, 1, loadedConfig.Providers.Len())
35 pc, _ := loadedConfig.Providers.Get("openai")
36 require.Equal(t, "key2", pc.APIKey)
37 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
38}
39
40func TestConfig_setDefaults(t *testing.T) {
41 cfg := &Config{}
42
43 cfg.setDefaults("/tmp", "")
44
45 require.NotNil(t, cfg.Options)
46 require.NotNil(t, cfg.Options.TUI)
47 require.NotNil(t, cfg.Options.ContextPaths)
48 require.NotNil(t, cfg.Providers)
49 require.NotNil(t, cfg.Models)
50 require.NotNil(t, cfg.LSP)
51 require.NotNil(t, cfg.MCP)
52 require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
53 for _, path := range defaultContextPaths {
54 require.Contains(t, cfg.Options.ContextPaths, path)
55 }
56 require.Equal(t, "/tmp", cfg.workingDir)
57}
58
59func TestConfig_configureProviders(t *testing.T) {
60 knownProviders := []catwalk.Provider{
61 {
62 ID: "openai",
63 APIKey: "$OPENAI_API_KEY",
64 APIEndpoint: "https://api.openai.com/v1",
65 Models: []catwalk.Model{{
66 ID: "test-model",
67 }},
68 },
69 }
70
71 cfg := &Config{}
72 cfg.setDefaults("/tmp", "")
73 env := env.NewFromMap(map[string]string{
74 "OPENAI_API_KEY": "test-key",
75 })
76 resolver := NewEnvironmentVariableResolver(env)
77 err := cfg.configureProviders(env, resolver, knownProviders)
78 require.NoError(t, err)
79 require.Equal(t, 1, cfg.Providers.Len())
80
81 // We want to make sure that we keep the configured API key as a placeholder
82 pc, _ := cfg.Providers.Get("openai")
83 require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
84}
85
86func TestConfig_configureProvidersWithOverride(t *testing.T) {
87 knownProviders := []catwalk.Provider{
88 {
89 ID: "openai",
90 APIKey: "$OPENAI_API_KEY",
91 APIEndpoint: "https://api.openai.com/v1",
92 Models: []catwalk.Model{{
93 ID: "test-model",
94 }},
95 },
96 }
97
98 cfg := &Config{
99 Providers: csync.NewMap[string, ProviderConfig](),
100 }
101 cfg.Providers.Set("openai", ProviderConfig{
102 APIKey: "xyz",
103 BaseURL: "https://api.openai.com/v2",
104 Models: []catwalk.Model{
105 {
106 ID: "test-model",
107 Name: "Updated",
108 },
109 {
110 ID: "another-model",
111 },
112 },
113 })
114 cfg.setDefaults("/tmp", "")
115
116 env := env.NewFromMap(map[string]string{
117 "OPENAI_API_KEY": "test-key",
118 })
119 resolver := NewEnvironmentVariableResolver(env)
120 err := cfg.configureProviders(env, resolver, knownProviders)
121 require.NoError(t, err)
122 require.Equal(t, 1, cfg.Providers.Len())
123
124 // We want to make sure that we keep the configured API key as a placeholder
125 pc, _ := cfg.Providers.Get("openai")
126 require.Equal(t, "xyz", pc.APIKey)
127 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
128 require.Len(t, pc.Models, 2)
129 require.Equal(t, "Updated", pc.Models[0].Name)
130}
131
132func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
133 knownProviders := []catwalk.Provider{
134 {
135 ID: "openai",
136 APIKey: "$OPENAI_API_KEY",
137 APIEndpoint: "https://api.openai.com/v1",
138 Models: []catwalk.Model{{
139 ID: "test-model",
140 }},
141 },
142 }
143
144 cfg := &Config{
145 Providers: csync.NewMapFrom(map[string]ProviderConfig{
146 "custom": {
147 APIKey: "xyz",
148 BaseURL: "https://api.someendpoint.com/v2",
149 Models: []catwalk.Model{
150 {
151 ID: "test-model",
152 },
153 },
154 },
155 }),
156 }
157 cfg.setDefaults("/tmp", "")
158 env := env.NewFromMap(map[string]string{
159 "OPENAI_API_KEY": "test-key",
160 })
161 resolver := NewEnvironmentVariableResolver(env)
162 err := cfg.configureProviders(env, resolver, knownProviders)
163 require.NoError(t, err)
164 // Should be to because of the env variable
165 require.Equal(t, cfg.Providers.Len(), 2)
166
167 // We want to make sure that we keep the configured API key as a placeholder
168 pc, _ := cfg.Providers.Get("custom")
169 require.Equal(t, "xyz", pc.APIKey)
170 // Make sure we set the ID correctly
171 require.Equal(t, "custom", pc.ID)
172 require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
173 require.Len(t, pc.Models, 1)
174
175 _, ok := cfg.Providers.Get("openai")
176 require.True(t, ok, "OpenAI provider should still be present")
177}
178
179func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
180 knownProviders := []catwalk.Provider{
181 {
182 ID: catwalk.InferenceProviderBedrock,
183 APIKey: "",
184 APIEndpoint: "",
185 Models: []catwalk.Model{{
186 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
187 }},
188 },
189 }
190
191 cfg := &Config{}
192 cfg.setDefaults("/tmp", "")
193 env := env.NewFromMap(map[string]string{
194 "AWS_ACCESS_KEY_ID": "test-key-id",
195 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
196 })
197 resolver := NewEnvironmentVariableResolver(env)
198 err := cfg.configureProviders(env, resolver, knownProviders)
199 require.NoError(t, err)
200 require.Equal(t, cfg.Providers.Len(), 1)
201
202 bedrockProvider, ok := cfg.Providers.Get("bedrock")
203 require.True(t, ok, "Bedrock provider should be present")
204 require.Len(t, bedrockProvider.Models, 1)
205 require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
206}
207
208func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
209 knownProviders := []catwalk.Provider{
210 {
211 ID: catwalk.InferenceProviderBedrock,
212 APIKey: "",
213 APIEndpoint: "",
214 Models: []catwalk.Model{{
215 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
216 }},
217 },
218 }
219
220 cfg := &Config{}
221 cfg.setDefaults("/tmp", "")
222 env := env.NewFromMap(map[string]string{})
223 resolver := NewEnvironmentVariableResolver(env)
224 err := cfg.configureProviders(env, resolver, knownProviders)
225 require.NoError(t, err)
226 // Provider should not be configured without credentials
227 require.Equal(t, cfg.Providers.Len(), 0)
228}
229
230func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
231 knownProviders := []catwalk.Provider{
232 {
233 ID: catwalk.InferenceProviderBedrock,
234 APIKey: "",
235 APIEndpoint: "",
236 Models: []catwalk.Model{{
237 ID: "some-random-model",
238 }},
239 },
240 }
241
242 cfg := &Config{}
243 cfg.setDefaults("/tmp", "")
244 env := env.NewFromMap(map[string]string{
245 "AWS_ACCESS_KEY_ID": "test-key-id",
246 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
247 })
248 resolver := NewEnvironmentVariableResolver(env)
249 err := cfg.configureProviders(env, resolver, knownProviders)
250 require.Error(t, err)
251}
252
253func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
254 knownProviders := []catwalk.Provider{
255 {
256 ID: catwalk.InferenceProviderVertexAI,
257 APIKey: "",
258 APIEndpoint: "",
259 Models: []catwalk.Model{{
260 ID: "gemini-pro",
261 }},
262 },
263 }
264
265 cfg := &Config{}
266 cfg.setDefaults("/tmp", "")
267 env := env.NewFromMap(map[string]string{
268 "VERTEXAI_PROJECT": "test-project",
269 "VERTEXAI_LOCATION": "us-central1",
270 })
271 resolver := NewEnvironmentVariableResolver(env)
272 err := cfg.configureProviders(env, resolver, knownProviders)
273 require.NoError(t, err)
274 require.Equal(t, cfg.Providers.Len(), 1)
275
276 vertexProvider, ok := cfg.Providers.Get("vertexai")
277 require.True(t, ok, "VertexAI provider should be present")
278 require.Len(t, vertexProvider.Models, 1)
279 require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
280 require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
281 require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
282}
283
284func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
285 knownProviders := []catwalk.Provider{
286 {
287 ID: catwalk.InferenceProviderVertexAI,
288 APIKey: "",
289 APIEndpoint: "",
290 Models: []catwalk.Model{{
291 ID: "gemini-pro",
292 }},
293 },
294 }
295
296 cfg := &Config{}
297 cfg.setDefaults("/tmp", "")
298 env := env.NewFromMap(map[string]string{
299 "GOOGLE_GENAI_USE_VERTEXAI": "false",
300 "GOOGLE_CLOUD_PROJECT": "test-project",
301 "GOOGLE_CLOUD_LOCATION": "us-central1",
302 })
303 resolver := NewEnvironmentVariableResolver(env)
304 err := cfg.configureProviders(env, resolver, knownProviders)
305 require.NoError(t, err)
306 // Provider should not be configured without proper credentials
307 require.Equal(t, cfg.Providers.Len(), 0)
308}
309
310func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
311 knownProviders := []catwalk.Provider{
312 {
313 ID: catwalk.InferenceProviderVertexAI,
314 APIKey: "",
315 APIEndpoint: "",
316 Models: []catwalk.Model{{
317 ID: "gemini-pro",
318 }},
319 },
320 }
321
322 cfg := &Config{}
323 cfg.setDefaults("/tmp", "")
324 env := env.NewFromMap(map[string]string{
325 "GOOGLE_GENAI_USE_VERTEXAI": "true",
326 "GOOGLE_CLOUD_LOCATION": "us-central1",
327 })
328 resolver := NewEnvironmentVariableResolver(env)
329 err := cfg.configureProviders(env, resolver, knownProviders)
330 require.NoError(t, err)
331 // Provider should not be configured without project
332 require.Equal(t, cfg.Providers.Len(), 0)
333}
334
335func TestConfig_configureProvidersSetProviderID(t *testing.T) {
336 knownProviders := []catwalk.Provider{
337 {
338 ID: "openai",
339 APIKey: "$OPENAI_API_KEY",
340 APIEndpoint: "https://api.openai.com/v1",
341 Models: []catwalk.Model{{
342 ID: "test-model",
343 }},
344 },
345 }
346
347 cfg := &Config{}
348 cfg.setDefaults("/tmp", "")
349 env := env.NewFromMap(map[string]string{
350 "OPENAI_API_KEY": "test-key",
351 })
352 resolver := NewEnvironmentVariableResolver(env)
353 err := cfg.configureProviders(env, resolver, knownProviders)
354 require.NoError(t, err)
355 require.Equal(t, cfg.Providers.Len(), 1)
356
357 // Provider ID should be set
358 pc, _ := cfg.Providers.Get("openai")
359 require.Equal(t, "openai", pc.ID)
360}
361
362func TestConfig_EnabledProviders(t *testing.T) {
363 t.Run("all providers enabled", func(t *testing.T) {
364 cfg := &Config{
365 Providers: csync.NewMapFrom(map[string]ProviderConfig{
366 "openai": {
367 ID: "openai",
368 APIKey: "key1",
369 Disable: false,
370 },
371 "anthropic": {
372 ID: "anthropic",
373 APIKey: "key2",
374 Disable: false,
375 },
376 }),
377 }
378
379 enabled := cfg.EnabledProviders()
380 require.Len(t, enabled, 2)
381 })
382
383 t.Run("some providers disabled", func(t *testing.T) {
384 cfg := &Config{
385 Providers: csync.NewMapFrom(map[string]ProviderConfig{
386 "openai": {
387 ID: "openai",
388 APIKey: "key1",
389 Disable: false,
390 },
391 "anthropic": {
392 ID: "anthropic",
393 APIKey: "key2",
394 Disable: true,
395 },
396 }),
397 }
398
399 enabled := cfg.EnabledProviders()
400 require.Len(t, enabled, 1)
401 require.Equal(t, "openai", enabled[0].ID)
402 })
403
404 t.Run("empty providers map", func(t *testing.T) {
405 cfg := &Config{
406 Providers: csync.NewMap[string, ProviderConfig](),
407 }
408
409 enabled := cfg.EnabledProviders()
410 require.Len(t, enabled, 0)
411 })
412}
413
414func TestConfig_IsConfigured(t *testing.T) {
415 t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
416 cfg := &Config{
417 Providers: csync.NewMapFrom(map[string]ProviderConfig{
418 "openai": {
419 ID: "openai",
420 APIKey: "key1",
421 Disable: false,
422 },
423 }),
424 }
425
426 require.True(t, cfg.IsConfigured())
427 })
428
429 t.Run("returns false when no providers are configured", func(t *testing.T) {
430 cfg := &Config{
431 Providers: csync.NewMap[string, ProviderConfig](),
432 }
433
434 require.False(t, cfg.IsConfigured())
435 })
436
437 t.Run("returns false when all providers are disabled", func(t *testing.T) {
438 cfg := &Config{
439 Providers: csync.NewMapFrom(map[string]ProviderConfig{
440 "openai": {
441 ID: "openai",
442 APIKey: "key1",
443 Disable: true,
444 },
445 "anthropic": {
446 ID: "anthropic",
447 APIKey: "key2",
448 Disable: true,
449 },
450 }),
451 }
452
453 require.False(t, cfg.IsConfigured())
454 })
455}
456
457func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
458 cfg := &Config{
459 Options: &Options{
460 DisabledTools: []string{},
461 },
462 }
463
464 cfg.SetupAgents()
465 coderAgent, ok := cfg.Agents[AgentCoder]
466 require.True(t, ok)
467 assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
468
469 taskAgent, ok := cfg.Agents[AgentTask]
470 require.True(t, ok)
471 assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
472}
473
474func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
475 cfg := &Config{
476 Options: &Options{
477 DisabledTools: []string{
478 "edit",
479 "download",
480 "grep",
481 },
482 },
483 }
484
485 cfg.SetupAgents()
486 coderAgent, ok := cfg.Agents[AgentCoder]
487 require.True(t, ok)
488
489 assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)
490
491 taskAgent, ok := cfg.Agents[AgentTask]
492 require.True(t, ok)
493 assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
494}
495
496func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
497 cfg := &Config{
498 Options: &Options{
499 DisabledTools: []string{
500 "glob",
501 "grep",
502 "ls",
503 "sourcegraph",
504 "view",
505 },
506 },
507 }
508
509 cfg.SetupAgents()
510 coderAgent, ok := cfg.Agents[AgentCoder]
511 require.True(t, ok)
512 assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "write"}, coderAgent.AllowedTools)
513
514 taskAgent, ok := cfg.Agents[AgentTask]
515 require.True(t, ok)
516 assert.Equal(t, []string{}, taskAgent.AllowedTools)
517}
518
519func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
520 knownProviders := []catwalk.Provider{
521 {
522 ID: "openai",
523 APIKey: "$OPENAI_API_KEY",
524 APIEndpoint: "https://api.openai.com/v1",
525 Models: []catwalk.Model{{
526 ID: "test-model",
527 }},
528 },
529 }
530
531 cfg := &Config{
532 Providers: csync.NewMapFrom(map[string]ProviderConfig{
533 "openai": {
534 Disable: true,
535 },
536 }),
537 }
538 cfg.setDefaults("/tmp", "")
539
540 env := env.NewFromMap(map[string]string{
541 "OPENAI_API_KEY": "test-key",
542 })
543 resolver := NewEnvironmentVariableResolver(env)
544 err := cfg.configureProviders(env, resolver, knownProviders)
545 require.NoError(t, err)
546
547 require.Equal(t, cfg.Providers.Len(), 1)
548 prov, exists := cfg.Providers.Get("openai")
549 require.True(t, exists)
550 require.True(t, prov.Disable)
551}
552
553func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
554 t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
555 cfg := &Config{
556 Providers: csync.NewMapFrom(map[string]ProviderConfig{
557 "custom": {
558 BaseURL: "https://api.custom.com/v1",
559 Models: []catwalk.Model{{
560 ID: "test-model",
561 }},
562 },
563 "openai": {
564 APIKey: "$MISSING",
565 },
566 }),
567 }
568 cfg.setDefaults("/tmp", "")
569
570 env := env.NewFromMap(map[string]string{})
571 resolver := NewEnvironmentVariableResolver(env)
572 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
573 require.NoError(t, err)
574
575 require.Equal(t, cfg.Providers.Len(), 1)
576 _, exists := cfg.Providers.Get("custom")
577 require.True(t, exists)
578 })
579
580 t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
581 cfg := &Config{
582 Providers: csync.NewMapFrom(map[string]ProviderConfig{
583 "custom": {
584 APIKey: "test-key",
585 Models: []catwalk.Model{{
586 ID: "test-model",
587 }},
588 },
589 }),
590 }
591 cfg.setDefaults("/tmp", "")
592
593 env := env.NewFromMap(map[string]string{})
594 resolver := NewEnvironmentVariableResolver(env)
595 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
596 require.NoError(t, err)
597
598 require.Equal(t, cfg.Providers.Len(), 0)
599 _, exists := cfg.Providers.Get("custom")
600 require.False(t, exists)
601 })
602
603 t.Run("custom provider with no models is removed", func(t *testing.T) {
604 cfg := &Config{
605 Providers: csync.NewMapFrom(map[string]ProviderConfig{
606 "custom": {
607 APIKey: "test-key",
608 BaseURL: "https://api.custom.com/v1",
609 Models: []catwalk.Model{},
610 },
611 }),
612 }
613 cfg.setDefaults("/tmp", "")
614
615 env := env.NewFromMap(map[string]string{})
616 resolver := NewEnvironmentVariableResolver(env)
617 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
618 require.NoError(t, err)
619
620 require.Equal(t, cfg.Providers.Len(), 0)
621 _, exists := cfg.Providers.Get("custom")
622 require.False(t, exists)
623 })
624
625 t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
626 cfg := &Config{
627 Providers: csync.NewMapFrom(map[string]ProviderConfig{
628 "custom": {
629 APIKey: "test-key",
630 BaseURL: "https://api.custom.com/v1",
631 Type: "unsupported",
632 Models: []catwalk.Model{{
633 ID: "test-model",
634 }},
635 },
636 }),
637 }
638 cfg.setDefaults("/tmp", "")
639
640 env := env.NewFromMap(map[string]string{})
641 resolver := NewEnvironmentVariableResolver(env)
642 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
643 require.NoError(t, err)
644
645 require.Equal(t, cfg.Providers.Len(), 0)
646 _, exists := cfg.Providers.Get("custom")
647 require.False(t, exists)
648 })
649
650 t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
651 cfg := &Config{
652 Providers: csync.NewMapFrom(map[string]ProviderConfig{
653 "custom": {
654 APIKey: "test-key",
655 BaseURL: "https://api.custom.com/v1",
656 Type: catwalk.TypeOpenAI,
657 Models: []catwalk.Model{{
658 ID: "test-model",
659 }},
660 },
661 }),
662 }
663 cfg.setDefaults("/tmp", "")
664
665 env := env.NewFromMap(map[string]string{})
666 resolver := NewEnvironmentVariableResolver(env)
667 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
668 require.NoError(t, err)
669
670 require.Equal(t, cfg.Providers.Len(), 1)
671 customProvider, exists := cfg.Providers.Get("custom")
672 require.True(t, exists)
673 require.Equal(t, "custom", customProvider.ID)
674 require.Equal(t, "test-key", customProvider.APIKey)
675 require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
676 })
677
678 t.Run("custom anthropic provider is supported", func(t *testing.T) {
679 cfg := &Config{
680 Providers: csync.NewMapFrom(map[string]ProviderConfig{
681 "custom-anthropic": {
682 APIKey: "test-key",
683 BaseURL: "https://api.anthropic.com/v1",
684 Type: catwalk.TypeAnthropic,
685 Models: []catwalk.Model{{
686 ID: "claude-3-sonnet",
687 }},
688 },
689 }),
690 }
691 cfg.setDefaults("/tmp", "")
692
693 env := env.NewFromMap(map[string]string{})
694 resolver := NewEnvironmentVariableResolver(env)
695 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
696 require.NoError(t, err)
697
698 require.Equal(t, cfg.Providers.Len(), 1)
699 customProvider, exists := cfg.Providers.Get("custom-anthropic")
700 require.True(t, exists)
701 require.Equal(t, "custom-anthropic", customProvider.ID)
702 require.Equal(t, "test-key", customProvider.APIKey)
703 require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
704 require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
705 })
706
707 t.Run("disabled custom provider is removed", func(t *testing.T) {
708 cfg := &Config{
709 Providers: csync.NewMapFrom(map[string]ProviderConfig{
710 "custom": {
711 APIKey: "test-key",
712 BaseURL: "https://api.custom.com/v1",
713 Type: catwalk.TypeOpenAI,
714 Disable: true,
715 Models: []catwalk.Model{{
716 ID: "test-model",
717 }},
718 },
719 }),
720 }
721 cfg.setDefaults("/tmp", "")
722
723 env := env.NewFromMap(map[string]string{})
724 resolver := NewEnvironmentVariableResolver(env)
725 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
726 require.NoError(t, err)
727
728 require.Equal(t, cfg.Providers.Len(), 0)
729 _, exists := cfg.Providers.Get("custom")
730 require.False(t, exists)
731 })
732}
733
734func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
735 t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
736 knownProviders := []catwalk.Provider{
737 {
738 ID: catwalk.InferenceProviderVertexAI,
739 APIKey: "",
740 APIEndpoint: "",
741 Models: []catwalk.Model{{
742 ID: "gemini-pro",
743 }},
744 },
745 }
746
747 cfg := &Config{
748 Providers: csync.NewMapFrom(map[string]ProviderConfig{
749 "vertexai": {
750 BaseURL: "custom-url",
751 },
752 }),
753 }
754 cfg.setDefaults("/tmp", "")
755
756 env := env.NewFromMap(map[string]string{
757 "GOOGLE_GENAI_USE_VERTEXAI": "false",
758 })
759 resolver := NewEnvironmentVariableResolver(env)
760 err := cfg.configureProviders(env, resolver, knownProviders)
761 require.NoError(t, err)
762
763 require.Equal(t, cfg.Providers.Len(), 0)
764 _, exists := cfg.Providers.Get("vertexai")
765 require.False(t, exists)
766 })
767
768 t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
769 knownProviders := []catwalk.Provider{
770 {
771 ID: catwalk.InferenceProviderBedrock,
772 APIKey: "",
773 APIEndpoint: "",
774 Models: []catwalk.Model{{
775 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
776 }},
777 },
778 }
779
780 cfg := &Config{
781 Providers: csync.NewMapFrom(map[string]ProviderConfig{
782 "bedrock": {
783 BaseURL: "custom-url",
784 },
785 }),
786 }
787 cfg.setDefaults("/tmp", "")
788
789 env := env.NewFromMap(map[string]string{})
790 resolver := NewEnvironmentVariableResolver(env)
791 err := cfg.configureProviders(env, resolver, knownProviders)
792 require.NoError(t, err)
793
794 require.Equal(t, cfg.Providers.Len(), 0)
795 _, exists := cfg.Providers.Get("bedrock")
796 require.False(t, exists)
797 })
798
799 t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
800 knownProviders := []catwalk.Provider{
801 {
802 ID: "openai",
803 APIKey: "$MISSING_API_KEY",
804 APIEndpoint: "https://api.openai.com/v1",
805 Models: []catwalk.Model{{
806 ID: "test-model",
807 }},
808 },
809 }
810
811 cfg := &Config{
812 Providers: csync.NewMapFrom(map[string]ProviderConfig{
813 "openai": {
814 BaseURL: "custom-url",
815 },
816 }),
817 }
818 cfg.setDefaults("/tmp", "")
819
820 env := env.NewFromMap(map[string]string{})
821 resolver := NewEnvironmentVariableResolver(env)
822 err := cfg.configureProviders(env, resolver, knownProviders)
823 require.NoError(t, err)
824
825 require.Equal(t, cfg.Providers.Len(), 0)
826 _, exists := cfg.Providers.Get("openai")
827 require.False(t, exists)
828 })
829
830 t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
831 knownProviders := []catwalk.Provider{
832 {
833 ID: "openai",
834 APIKey: "$OPENAI_API_KEY",
835 APIEndpoint: "$MISSING_ENDPOINT",
836 Models: []catwalk.Model{{
837 ID: "test-model",
838 }},
839 },
840 }
841
842 cfg := &Config{
843 Providers: csync.NewMapFrom(map[string]ProviderConfig{
844 "openai": {
845 APIKey: "test-key",
846 },
847 }),
848 }
849 cfg.setDefaults("/tmp", "")
850
851 env := env.NewFromMap(map[string]string{
852 "OPENAI_API_KEY": "test-key",
853 })
854 resolver := NewEnvironmentVariableResolver(env)
855 err := cfg.configureProviders(env, resolver, knownProviders)
856 require.NoError(t, err)
857
858 require.Equal(t, cfg.Providers.Len(), 1)
859 _, exists := cfg.Providers.Get("openai")
860 require.True(t, exists)
861 })
862}
863
864func TestConfig_defaultModelSelection(t *testing.T) {
865 t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
866 knownProviders := []catwalk.Provider{
867 {
868 ID: "openai",
869 APIKey: "abc",
870 DefaultLargeModelID: "large-model",
871 DefaultSmallModelID: "small-model",
872 Models: []catwalk.Model{
873 {
874 ID: "large-model",
875 DefaultMaxTokens: 1000,
876 },
877 {
878 ID: "small-model",
879 DefaultMaxTokens: 500,
880 },
881 },
882 },
883 }
884
885 cfg := &Config{}
886 cfg.setDefaults("/tmp", "")
887 env := env.NewFromMap(map[string]string{})
888 resolver := NewEnvironmentVariableResolver(env)
889 err := cfg.configureProviders(env, resolver, knownProviders)
890 require.NoError(t, err)
891
892 large, small, err := cfg.defaultModelSelection(knownProviders)
893 require.NoError(t, err)
894 require.Equal(t, "large-model", large.Model)
895 require.Equal(t, "openai", large.Provider)
896 require.Equal(t, int64(1000), large.MaxTokens)
897 require.Equal(t, "small-model", small.Model)
898 require.Equal(t, "openai", small.Provider)
899 require.Equal(t, int64(500), small.MaxTokens)
900 })
901 t.Run("should error if no providers configured", func(t *testing.T) {
902 knownProviders := []catwalk.Provider{
903 {
904 ID: "openai",
905 APIKey: "$MISSING_KEY",
906 DefaultLargeModelID: "large-model",
907 DefaultSmallModelID: "small-model",
908 Models: []catwalk.Model{
909 {
910 ID: "large-model",
911 DefaultMaxTokens: 1000,
912 },
913 {
914 ID: "small-model",
915 DefaultMaxTokens: 500,
916 },
917 },
918 },
919 }
920
921 cfg := &Config{}
922 cfg.setDefaults("/tmp", "")
923 env := env.NewFromMap(map[string]string{})
924 resolver := NewEnvironmentVariableResolver(env)
925 err := cfg.configureProviders(env, resolver, knownProviders)
926 require.NoError(t, err)
927
928 _, _, err = cfg.defaultModelSelection(knownProviders)
929 require.Error(t, err)
930 })
931 t.Run("should error if model is missing", func(t *testing.T) {
932 knownProviders := []catwalk.Provider{
933 {
934 ID: "openai",
935 APIKey: "abc",
936 DefaultLargeModelID: "large-model",
937 DefaultSmallModelID: "small-model",
938 Models: []catwalk.Model{
939 {
940 ID: "not-large-model",
941 DefaultMaxTokens: 1000,
942 },
943 {
944 ID: "small-model",
945 DefaultMaxTokens: 500,
946 },
947 },
948 },
949 }
950
951 cfg := &Config{}
952 cfg.setDefaults("/tmp", "")
953 env := env.NewFromMap(map[string]string{})
954 resolver := NewEnvironmentVariableResolver(env)
955 err := cfg.configureProviders(env, resolver, knownProviders)
956 require.NoError(t, err)
957 _, _, err = cfg.defaultModelSelection(knownProviders)
958 require.Error(t, err)
959 })
960
961 t.Run("should configure the default models with a custom provider", func(t *testing.T) {
962 knownProviders := []catwalk.Provider{
963 {
964 ID: "openai",
965 APIKey: "$MISSING", // will not be included in the config
966 DefaultLargeModelID: "large-model",
967 DefaultSmallModelID: "small-model",
968 Models: []catwalk.Model{
969 {
970 ID: "not-large-model",
971 DefaultMaxTokens: 1000,
972 },
973 {
974 ID: "small-model",
975 DefaultMaxTokens: 500,
976 },
977 },
978 },
979 }
980
981 cfg := &Config{
982 Providers: csync.NewMapFrom(map[string]ProviderConfig{
983 "custom": {
984 APIKey: "test-key",
985 BaseURL: "https://api.custom.com/v1",
986 Models: []catwalk.Model{
987 {
988 ID: "model",
989 DefaultMaxTokens: 600,
990 },
991 },
992 },
993 }),
994 }
995 cfg.setDefaults("/tmp", "")
996 env := env.NewFromMap(map[string]string{})
997 resolver := NewEnvironmentVariableResolver(env)
998 err := cfg.configureProviders(env, resolver, knownProviders)
999 require.NoError(t, err)
1000 large, small, err := cfg.defaultModelSelection(knownProviders)
1001 require.NoError(t, err)
1002 require.Equal(t, "model", large.Model)
1003 require.Equal(t, "custom", large.Provider)
1004 require.Equal(t, int64(600), large.MaxTokens)
1005 require.Equal(t, "model", small.Model)
1006 require.Equal(t, "custom", small.Provider)
1007 require.Equal(t, int64(600), small.MaxTokens)
1008 })
1009
1010 t.Run("should fail if no model configured", func(t *testing.T) {
1011 knownProviders := []catwalk.Provider{
1012 {
1013 ID: "openai",
1014 APIKey: "$MISSING", // will not be included in the config
1015 DefaultLargeModelID: "large-model",
1016 DefaultSmallModelID: "small-model",
1017 Models: []catwalk.Model{
1018 {
1019 ID: "not-large-model",
1020 DefaultMaxTokens: 1000,
1021 },
1022 {
1023 ID: "small-model",
1024 DefaultMaxTokens: 500,
1025 },
1026 },
1027 },
1028 }
1029
1030 cfg := &Config{
1031 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1032 "custom": {
1033 APIKey: "test-key",
1034 BaseURL: "https://api.custom.com/v1",
1035 Models: []catwalk.Model{},
1036 },
1037 }),
1038 }
1039 cfg.setDefaults("/tmp", "")
1040 env := env.NewFromMap(map[string]string{})
1041 resolver := NewEnvironmentVariableResolver(env)
1042 err := cfg.configureProviders(env, resolver, knownProviders)
1043 require.NoError(t, err)
1044 _, _, err = cfg.defaultModelSelection(knownProviders)
1045 require.Error(t, err)
1046 })
1047 t.Run("should use the default provider first", func(t *testing.T) {
1048 knownProviders := []catwalk.Provider{
1049 {
1050 ID: "openai",
1051 APIKey: "set",
1052 DefaultLargeModelID: "large-model",
1053 DefaultSmallModelID: "small-model",
1054 Models: []catwalk.Model{
1055 {
1056 ID: "large-model",
1057 DefaultMaxTokens: 1000,
1058 },
1059 {
1060 ID: "small-model",
1061 DefaultMaxTokens: 500,
1062 },
1063 },
1064 },
1065 }
1066
1067 cfg := &Config{
1068 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1069 "custom": {
1070 APIKey: "test-key",
1071 BaseURL: "https://api.custom.com/v1",
1072 Models: []catwalk.Model{
1073 {
1074 ID: "large-model",
1075 DefaultMaxTokens: 1000,
1076 },
1077 },
1078 },
1079 }),
1080 }
1081 cfg.setDefaults("/tmp", "")
1082 env := env.NewFromMap(map[string]string{})
1083 resolver := NewEnvironmentVariableResolver(env)
1084 err := cfg.configureProviders(env, resolver, knownProviders)
1085 require.NoError(t, err)
1086 large, small, err := cfg.defaultModelSelection(knownProviders)
1087 require.NoError(t, err)
1088 require.Equal(t, "large-model", large.Model)
1089 require.Equal(t, "openai", large.Provider)
1090 require.Equal(t, int64(1000), large.MaxTokens)
1091 require.Equal(t, "small-model", small.Model)
1092 require.Equal(t, "openai", small.Provider)
1093 require.Equal(t, int64(500), small.MaxTokens)
1094 })
1095}
1096
1097func TestConfig_configureSelectedModels(t *testing.T) {
1098 t.Run("should override defaults", func(t *testing.T) {
1099 knownProviders := []catwalk.Provider{
1100 {
1101 ID: "openai",
1102 APIKey: "abc",
1103 DefaultLargeModelID: "large-model",
1104 DefaultSmallModelID: "small-model",
1105 Models: []catwalk.Model{
1106 {
1107 ID: "larger-model",
1108 DefaultMaxTokens: 2000,
1109 },
1110 {
1111 ID: "large-model",
1112 DefaultMaxTokens: 1000,
1113 },
1114 {
1115 ID: "small-model",
1116 DefaultMaxTokens: 500,
1117 },
1118 },
1119 },
1120 }
1121
1122 cfg := &Config{
1123 Models: map[SelectedModelType]SelectedModel{
1124 "large": {
1125 Model: "larger-model",
1126 },
1127 },
1128 }
1129 cfg.setDefaults("/tmp", "")
1130 env := env.NewFromMap(map[string]string{})
1131 resolver := NewEnvironmentVariableResolver(env)
1132 err := cfg.configureProviders(env, resolver, knownProviders)
1133 require.NoError(t, err)
1134
1135 err = cfg.configureSelectedModels(knownProviders)
1136 require.NoError(t, err)
1137 large := cfg.Models[SelectedModelTypeLarge]
1138 small := cfg.Models[SelectedModelTypeSmall]
1139 require.Equal(t, "larger-model", large.Model)
1140 require.Equal(t, "openai", large.Provider)
1141 require.Equal(t, int64(2000), large.MaxTokens)
1142 require.Equal(t, "small-model", small.Model)
1143 require.Equal(t, "openai", small.Provider)
1144 require.Equal(t, int64(500), small.MaxTokens)
1145 })
1146 t.Run("should be possible to use multiple providers", func(t *testing.T) {
1147 knownProviders := []catwalk.Provider{
1148 {
1149 ID: "openai",
1150 APIKey: "abc",
1151 DefaultLargeModelID: "large-model",
1152 DefaultSmallModelID: "small-model",
1153 Models: []catwalk.Model{
1154 {
1155 ID: "large-model",
1156 DefaultMaxTokens: 1000,
1157 },
1158 {
1159 ID: "small-model",
1160 DefaultMaxTokens: 500,
1161 },
1162 },
1163 },
1164 {
1165 ID: "anthropic",
1166 APIKey: "abc",
1167 DefaultLargeModelID: "a-large-model",
1168 DefaultSmallModelID: "a-small-model",
1169 Models: []catwalk.Model{
1170 {
1171 ID: "a-large-model",
1172 DefaultMaxTokens: 1000,
1173 },
1174 {
1175 ID: "a-small-model",
1176 DefaultMaxTokens: 200,
1177 },
1178 },
1179 },
1180 }
1181
1182 cfg := &Config{
1183 Models: map[SelectedModelType]SelectedModel{
1184 "small": {
1185 Model: "a-small-model",
1186 Provider: "anthropic",
1187 MaxTokens: 300,
1188 },
1189 },
1190 }
1191 cfg.setDefaults("/tmp", "")
1192 env := env.NewFromMap(map[string]string{})
1193 resolver := NewEnvironmentVariableResolver(env)
1194 err := cfg.configureProviders(env, resolver, knownProviders)
1195 require.NoError(t, err)
1196
1197 err = cfg.configureSelectedModels(knownProviders)
1198 require.NoError(t, err)
1199 large := cfg.Models[SelectedModelTypeLarge]
1200 small := cfg.Models[SelectedModelTypeSmall]
1201 require.Equal(t, "large-model", large.Model)
1202 require.Equal(t, "openai", large.Provider)
1203 require.Equal(t, int64(1000), large.MaxTokens)
1204 require.Equal(t, "a-small-model", small.Model)
1205 require.Equal(t, "anthropic", small.Provider)
1206 require.Equal(t, int64(300), small.MaxTokens)
1207 })
1208
1209 t.Run("should override the max tokens only", func(t *testing.T) {
1210 knownProviders := []catwalk.Provider{
1211 {
1212 ID: "openai",
1213 APIKey: "abc",
1214 DefaultLargeModelID: "large-model",
1215 DefaultSmallModelID: "small-model",
1216 Models: []catwalk.Model{
1217 {
1218 ID: "large-model",
1219 DefaultMaxTokens: 1000,
1220 },
1221 {
1222 ID: "small-model",
1223 DefaultMaxTokens: 500,
1224 },
1225 },
1226 },
1227 }
1228
1229 cfg := &Config{
1230 Models: map[SelectedModelType]SelectedModel{
1231 "large": {
1232 MaxTokens: 100,
1233 },
1234 },
1235 }
1236 cfg.setDefaults("/tmp", "")
1237 env := env.NewFromMap(map[string]string{})
1238 resolver := NewEnvironmentVariableResolver(env)
1239 err := cfg.configureProviders(env, resolver, knownProviders)
1240 require.NoError(t, err)
1241
1242 err = cfg.configureSelectedModels(knownProviders)
1243 require.NoError(t, err)
1244 large := cfg.Models[SelectedModelTypeLarge]
1245 require.Equal(t, "large-model", large.Model)
1246 require.Equal(t, "openai", large.Provider)
1247 require.Equal(t, int64(100), large.MaxTokens)
1248 })
1249}