1package config
2
3import (
4 "io"
5 "log/slog"
6 "os"
7 "path/filepath"
8 "strings"
9 "testing"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/csync"
13 "github.com/charmbracelet/crush/internal/env"
14 "github.com/stretchr/testify/assert"
15 "github.com/stretchr/testify/require"
16)
17
// TestMain replaces the default slog logger with one that writes to
// io.Discard, runs the package's tests, and exits the process with the
// status code returned by m.Run.
func TestMain(m *testing.M) {
	discard := slog.NewTextHandler(io.Discard, nil)
	slog.SetDefault(slog.New(discard))

	os.Exit(m.Run())
}
24
25func TestConfig_LoadFromReaders(t *testing.T) {
26 data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
27 data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
28 data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
29
30 loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
31
32 require.NoError(t, err)
33 require.NotNil(t, loadedConfig)
34 require.Equal(t, 1, loadedConfig.Providers.Len())
35 pc, _ := loadedConfig.Providers.Get("openai")
36 require.Equal(t, "key2", pc.APIKey)
37 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
38}
39
40func TestConfig_setDefaults(t *testing.T) {
41 cfg := &Config{}
42
43 cfg.setDefaults("/tmp", "")
44
45 require.NotNil(t, cfg.Options)
46 require.NotNil(t, cfg.Options.TUI)
47 require.NotNil(t, cfg.Options.ContextPaths)
48 require.NotNil(t, cfg.Providers)
49 require.NotNil(t, cfg.Models)
50 require.NotNil(t, cfg.LSP)
51 require.NotNil(t, cfg.MCP)
52 require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
53 for _, path := range defaultContextPaths {
54 require.Contains(t, cfg.Options.ContextPaths, path)
55 }
56 require.Equal(t, "/tmp", cfg.workingDir)
57}
58
59func TestConfig_configureProviders(t *testing.T) {
60 knownProviders := []catwalk.Provider{
61 {
62 ID: "openai",
63 APIKey: "$OPENAI_API_KEY",
64 APIEndpoint: "https://api.openai.com/v1",
65 Models: []catwalk.Model{{
66 ID: "test-model",
67 }},
68 },
69 }
70
71 cfg := &Config{}
72 cfg.setDefaults("/tmp", "")
73 env := env.NewFromMap(map[string]string{
74 "OPENAI_API_KEY": "test-key",
75 })
76 resolver := NewEnvironmentVariableResolver(env)
77 err := cfg.configureProviders(env, resolver, knownProviders)
78 require.NoError(t, err)
79 require.Equal(t, 1, cfg.Providers.Len())
80
81 // We want to make sure that we keep the configured API key as a placeholder
82 pc, _ := cfg.Providers.Get("openai")
83 require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
84}
85
86func TestConfig_configureProvidersWithOverride(t *testing.T) {
87 knownProviders := []catwalk.Provider{
88 {
89 ID: "openai",
90 APIKey: "$OPENAI_API_KEY",
91 APIEndpoint: "https://api.openai.com/v1",
92 Models: []catwalk.Model{{
93 ID: "test-model",
94 }},
95 },
96 }
97
98 cfg := &Config{
99 Providers: csync.NewMap[string, ProviderConfig](),
100 }
101 cfg.Providers.Set("openai", ProviderConfig{
102 APIKey: "xyz",
103 BaseURL: "https://api.openai.com/v2",
104 Models: []catwalk.Model{
105 {
106 ID: "test-model",
107 Name: "Updated",
108 },
109 {
110 ID: "another-model",
111 },
112 },
113 })
114 cfg.setDefaults("/tmp", "")
115
116 env := env.NewFromMap(map[string]string{
117 "OPENAI_API_KEY": "test-key",
118 })
119 resolver := NewEnvironmentVariableResolver(env)
120 err := cfg.configureProviders(env, resolver, knownProviders)
121 require.NoError(t, err)
122 require.Equal(t, 1, cfg.Providers.Len())
123
124 // We want to make sure that we keep the configured API key as a placeholder
125 pc, _ := cfg.Providers.Get("openai")
126 require.Equal(t, "xyz", pc.APIKey)
127 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
128 require.Len(t, pc.Models, 2)
129 require.Equal(t, "Updated", pc.Models[0].Name)
130}
131
132func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
133 knownProviders := []catwalk.Provider{
134 {
135 ID: "openai",
136 APIKey: "$OPENAI_API_KEY",
137 APIEndpoint: "https://api.openai.com/v1",
138 Models: []catwalk.Model{{
139 ID: "test-model",
140 }},
141 },
142 }
143
144 cfg := &Config{
145 Providers: csync.NewMapFrom(map[string]ProviderConfig{
146 "custom": {
147 APIKey: "xyz",
148 BaseURL: "https://api.someendpoint.com/v2",
149 Models: []catwalk.Model{
150 {
151 ID: "test-model",
152 },
153 },
154 },
155 }),
156 }
157 cfg.setDefaults("/tmp", "")
158 env := env.NewFromMap(map[string]string{
159 "OPENAI_API_KEY": "test-key",
160 })
161 resolver := NewEnvironmentVariableResolver(env)
162 err := cfg.configureProviders(env, resolver, knownProviders)
163 require.NoError(t, err)
164 // Should be to because of the env variable
165 require.Equal(t, cfg.Providers.Len(), 2)
166
167 // We want to make sure that we keep the configured API key as a placeholder
168 pc, _ := cfg.Providers.Get("custom")
169 require.Equal(t, "xyz", pc.APIKey)
170 // Make sure we set the ID correctly
171 require.Equal(t, "custom", pc.ID)
172 require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
173 require.Len(t, pc.Models, 1)
174
175 _, ok := cfg.Providers.Get("openai")
176 require.True(t, ok, "OpenAI provider should still be present")
177}
178
179func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
180 knownProviders := []catwalk.Provider{
181 {
182 ID: catwalk.InferenceProviderBedrock,
183 APIKey: "",
184 APIEndpoint: "",
185 Models: []catwalk.Model{{
186 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
187 }},
188 },
189 }
190
191 cfg := &Config{}
192 cfg.setDefaults("/tmp", "")
193 env := env.NewFromMap(map[string]string{
194 "AWS_ACCESS_KEY_ID": "test-key-id",
195 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
196 })
197 resolver := NewEnvironmentVariableResolver(env)
198 err := cfg.configureProviders(env, resolver, knownProviders)
199 require.NoError(t, err)
200 require.Equal(t, cfg.Providers.Len(), 1)
201
202 bedrockProvider, ok := cfg.Providers.Get("bedrock")
203 require.True(t, ok, "Bedrock provider should be present")
204 require.Len(t, bedrockProvider.Models, 1)
205 require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
206}
207
208func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
209 knownProviders := []catwalk.Provider{
210 {
211 ID: catwalk.InferenceProviderBedrock,
212 APIKey: "",
213 APIEndpoint: "",
214 Models: []catwalk.Model{{
215 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
216 }},
217 },
218 }
219
220 cfg := &Config{}
221 cfg.setDefaults("/tmp", "")
222 env := env.NewFromMap(map[string]string{})
223 resolver := NewEnvironmentVariableResolver(env)
224 err := cfg.configureProviders(env, resolver, knownProviders)
225 require.NoError(t, err)
226 // Provider should not be configured without credentials
227 require.Equal(t, cfg.Providers.Len(), 0)
228}
229
230func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
231 knownProviders := []catwalk.Provider{
232 {
233 ID: catwalk.InferenceProviderBedrock,
234 APIKey: "",
235 APIEndpoint: "",
236 Models: []catwalk.Model{{
237 ID: "some-random-model",
238 }},
239 },
240 }
241
242 cfg := &Config{}
243 cfg.setDefaults("/tmp", "")
244 env := env.NewFromMap(map[string]string{
245 "AWS_ACCESS_KEY_ID": "test-key-id",
246 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
247 })
248 resolver := NewEnvironmentVariableResolver(env)
249 err := cfg.configureProviders(env, resolver, knownProviders)
250 require.Error(t, err)
251}
252
253func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
254 knownProviders := []catwalk.Provider{
255 {
256 ID: catwalk.InferenceProviderVertexAI,
257 APIKey: "",
258 APIEndpoint: "",
259 Models: []catwalk.Model{{
260 ID: "gemini-pro",
261 }},
262 },
263 }
264
265 cfg := &Config{}
266 cfg.setDefaults("/tmp", "")
267 env := env.NewFromMap(map[string]string{
268 "VERTEXAI_PROJECT": "test-project",
269 "VERTEXAI_LOCATION": "us-central1",
270 })
271 resolver := NewEnvironmentVariableResolver(env)
272 err := cfg.configureProviders(env, resolver, knownProviders)
273 require.NoError(t, err)
274 require.Equal(t, cfg.Providers.Len(), 1)
275
276 vertexProvider, ok := cfg.Providers.Get("vertexai")
277 require.True(t, ok, "VertexAI provider should be present")
278 require.Len(t, vertexProvider.Models, 1)
279 require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
280 require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
281 require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
282}
283
284func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
285 knownProviders := []catwalk.Provider{
286 {
287 ID: catwalk.InferenceProviderVertexAI,
288 APIKey: "",
289 APIEndpoint: "",
290 Models: []catwalk.Model{{
291 ID: "gemini-pro",
292 }},
293 },
294 }
295
296 cfg := &Config{}
297 cfg.setDefaults("/tmp", "")
298 env := env.NewFromMap(map[string]string{
299 "GOOGLE_GENAI_USE_VERTEXAI": "false",
300 "GOOGLE_CLOUD_PROJECT": "test-project",
301 "GOOGLE_CLOUD_LOCATION": "us-central1",
302 })
303 resolver := NewEnvironmentVariableResolver(env)
304 err := cfg.configureProviders(env, resolver, knownProviders)
305 require.NoError(t, err)
306 // Provider should not be configured without proper credentials
307 require.Equal(t, cfg.Providers.Len(), 0)
308}
309
310func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
311 knownProviders := []catwalk.Provider{
312 {
313 ID: catwalk.InferenceProviderVertexAI,
314 APIKey: "",
315 APIEndpoint: "",
316 Models: []catwalk.Model{{
317 ID: "gemini-pro",
318 }},
319 },
320 }
321
322 cfg := &Config{}
323 cfg.setDefaults("/tmp", "")
324 env := env.NewFromMap(map[string]string{
325 "GOOGLE_GENAI_USE_VERTEXAI": "true",
326 "GOOGLE_CLOUD_LOCATION": "us-central1",
327 })
328 resolver := NewEnvironmentVariableResolver(env)
329 err := cfg.configureProviders(env, resolver, knownProviders)
330 require.NoError(t, err)
331 // Provider should not be configured without project
332 require.Equal(t, cfg.Providers.Len(), 0)
333}
334
335func TestConfig_configureProvidersSetProviderID(t *testing.T) {
336 knownProviders := []catwalk.Provider{
337 {
338 ID: "openai",
339 APIKey: "$OPENAI_API_KEY",
340 APIEndpoint: "https://api.openai.com/v1",
341 Models: []catwalk.Model{{
342 ID: "test-model",
343 }},
344 },
345 }
346
347 cfg := &Config{}
348 cfg.setDefaults("/tmp", "")
349 env := env.NewFromMap(map[string]string{
350 "OPENAI_API_KEY": "test-key",
351 })
352 resolver := NewEnvironmentVariableResolver(env)
353 err := cfg.configureProviders(env, resolver, knownProviders)
354 require.NoError(t, err)
355 require.Equal(t, cfg.Providers.Len(), 1)
356
357 // Provider ID should be set
358 pc, _ := cfg.Providers.Get("openai")
359 require.Equal(t, "openai", pc.ID)
360}
361
362func TestConfig_EnabledProviders(t *testing.T) {
363 t.Run("all providers enabled", func(t *testing.T) {
364 cfg := &Config{
365 Providers: csync.NewMapFrom(map[string]ProviderConfig{
366 "openai": {
367 ID: "openai",
368 APIKey: "key1",
369 Disable: false,
370 },
371 "anthropic": {
372 ID: "anthropic",
373 APIKey: "key2",
374 Disable: false,
375 },
376 }),
377 }
378
379 enabled := cfg.EnabledProviders()
380 require.Len(t, enabled, 2)
381 })
382
383 t.Run("some providers disabled", func(t *testing.T) {
384 cfg := &Config{
385 Providers: csync.NewMapFrom(map[string]ProviderConfig{
386 "openai": {
387 ID: "openai",
388 APIKey: "key1",
389 Disable: false,
390 },
391 "anthropic": {
392 ID: "anthropic",
393 APIKey: "key2",
394 Disable: true,
395 },
396 }),
397 }
398
399 enabled := cfg.EnabledProviders()
400 require.Len(t, enabled, 1)
401 require.Equal(t, "openai", enabled[0].ID)
402 })
403
404 t.Run("empty providers map", func(t *testing.T) {
405 cfg := &Config{
406 Providers: csync.NewMap[string, ProviderConfig](),
407 }
408
409 enabled := cfg.EnabledProviders()
410 require.Len(t, enabled, 0)
411 })
412}
413
414func TestConfig_IsConfigured(t *testing.T) {
415 t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
416 cfg := &Config{
417 Providers: csync.NewMapFrom(map[string]ProviderConfig{
418 "openai": {
419 ID: "openai",
420 APIKey: "key1",
421 Disable: false,
422 },
423 }),
424 }
425
426 require.True(t, cfg.IsConfigured())
427 })
428
429 t.Run("returns false when no providers are configured", func(t *testing.T) {
430 cfg := &Config{
431 Providers: csync.NewMap[string, ProviderConfig](),
432 }
433
434 require.False(t, cfg.IsConfigured())
435 })
436
437 t.Run("returns false when all providers are disabled", func(t *testing.T) {
438 cfg := &Config{
439 Providers: csync.NewMapFrom(map[string]ProviderConfig{
440 "openai": {
441 ID: "openai",
442 APIKey: "key1",
443 Disable: true,
444 },
445 "anthropic": {
446 ID: "anthropic",
447 APIKey: "key2",
448 Disable: true,
449 },
450 }),
451 }
452
453 require.False(t, cfg.IsConfigured())
454 })
455}
456
457func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
458 cfg := &Config{
459 Options: &Options{
460 DisabledTools: []string{},
461 },
462 }
463
464 cfg.SetupAgents()
465 coderAgent, ok := cfg.Agents[AgentCoder]
466 require.True(t, ok)
467 assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
468
469 taskAgent, ok := cfg.Agents[AgentTask]
470 require.True(t, ok)
471 assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
472}
473
474func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
475 cfg := &Config{
476 Options: &Options{
477 DisabledTools: []string{
478 "edit",
479 "download",
480 "grep",
481 },
482 },
483 }
484
485 cfg.SetupAgents()
486 coderAgent, ok := cfg.Agents[AgentCoder]
487 require.True(t, ok)
488 assert.Equal(t, []string{
489 "agent",
490 "bash",
491 "multiedit",
492 "lsp_diagnostics",
493 "lsp_references",
494 "fetch",
495 "glob",
496 "ls",
497 "sourcegraph",
498 "view",
499 "write",
500 "read_mcp_resource",
501 "list_mcp_resources",
502 }, coderAgent.AllowedTools)
503
504 taskAgent, ok := cfg.Agents[AgentTask]
505 require.True(t, ok)
506 assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
507}
508
509func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
510 cfg := &Config{
511 Options: &Options{
512 DisabledTools: []string{
513 "glob",
514 "grep",
515 "ls",
516 "sourcegraph",
517 "view",
518 "read_mcp_resource",
519 "list_mcp_resources",
520 },
521 },
522 }
523
524 cfg.SetupAgents()
525 coderAgent, ok := cfg.Agents[AgentCoder]
526 require.True(t, ok)
527 assert.Equal(t, []string{"agent", "bash", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "write"}, coderAgent.AllowedTools)
528
529 taskAgent, ok := cfg.Agents[AgentTask]
530 require.True(t, ok)
531 assert.Equal(t, []string{}, taskAgent.AllowedTools)
532}
533
534func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
535 knownProviders := []catwalk.Provider{
536 {
537 ID: "openai",
538 APIKey: "$OPENAI_API_KEY",
539 APIEndpoint: "https://api.openai.com/v1",
540 Models: []catwalk.Model{{
541 ID: "test-model",
542 }},
543 },
544 }
545
546 cfg := &Config{
547 Providers: csync.NewMapFrom(map[string]ProviderConfig{
548 "openai": {
549 Disable: true,
550 },
551 }),
552 }
553 cfg.setDefaults("/tmp", "")
554
555 env := env.NewFromMap(map[string]string{
556 "OPENAI_API_KEY": "test-key",
557 })
558 resolver := NewEnvironmentVariableResolver(env)
559 err := cfg.configureProviders(env, resolver, knownProviders)
560 require.NoError(t, err)
561
562 require.Equal(t, cfg.Providers.Len(), 1)
563 prov, exists := cfg.Providers.Get("openai")
564 require.True(t, exists)
565 require.True(t, prov.Disable)
566}
567
568func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
569 t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
570 cfg := &Config{
571 Providers: csync.NewMapFrom(map[string]ProviderConfig{
572 "custom": {
573 BaseURL: "https://api.custom.com/v1",
574 Models: []catwalk.Model{{
575 ID: "test-model",
576 }},
577 },
578 "openai": {
579 APIKey: "$MISSING",
580 },
581 }),
582 }
583 cfg.setDefaults("/tmp", "")
584
585 env := env.NewFromMap(map[string]string{})
586 resolver := NewEnvironmentVariableResolver(env)
587 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
588 require.NoError(t, err)
589
590 require.Equal(t, cfg.Providers.Len(), 1)
591 _, exists := cfg.Providers.Get("custom")
592 require.True(t, exists)
593 })
594
595 t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
596 cfg := &Config{
597 Providers: csync.NewMapFrom(map[string]ProviderConfig{
598 "custom": {
599 APIKey: "test-key",
600 Models: []catwalk.Model{{
601 ID: "test-model",
602 }},
603 },
604 }),
605 }
606 cfg.setDefaults("/tmp", "")
607
608 env := env.NewFromMap(map[string]string{})
609 resolver := NewEnvironmentVariableResolver(env)
610 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
611 require.NoError(t, err)
612
613 require.Equal(t, cfg.Providers.Len(), 0)
614 _, exists := cfg.Providers.Get("custom")
615 require.False(t, exists)
616 })
617
618 t.Run("custom provider with no models is removed", func(t *testing.T) {
619 cfg := &Config{
620 Providers: csync.NewMapFrom(map[string]ProviderConfig{
621 "custom": {
622 APIKey: "test-key",
623 BaseURL: "https://api.custom.com/v1",
624 Models: []catwalk.Model{},
625 },
626 }),
627 }
628 cfg.setDefaults("/tmp", "")
629
630 env := env.NewFromMap(map[string]string{})
631 resolver := NewEnvironmentVariableResolver(env)
632 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
633 require.NoError(t, err)
634
635 require.Equal(t, cfg.Providers.Len(), 0)
636 _, exists := cfg.Providers.Get("custom")
637 require.False(t, exists)
638 })
639
640 t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
641 cfg := &Config{
642 Providers: csync.NewMapFrom(map[string]ProviderConfig{
643 "custom": {
644 APIKey: "test-key",
645 BaseURL: "https://api.custom.com/v1",
646 Type: "unsupported",
647 Models: []catwalk.Model{{
648 ID: "test-model",
649 }},
650 },
651 }),
652 }
653 cfg.setDefaults("/tmp", "")
654
655 env := env.NewFromMap(map[string]string{})
656 resolver := NewEnvironmentVariableResolver(env)
657 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
658 require.NoError(t, err)
659
660 require.Equal(t, cfg.Providers.Len(), 0)
661 _, exists := cfg.Providers.Get("custom")
662 require.False(t, exists)
663 })
664
665 t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
666 cfg := &Config{
667 Providers: csync.NewMapFrom(map[string]ProviderConfig{
668 "custom": {
669 APIKey: "test-key",
670 BaseURL: "https://api.custom.com/v1",
671 Type: catwalk.TypeOpenAI,
672 Models: []catwalk.Model{{
673 ID: "test-model",
674 }},
675 },
676 }),
677 }
678 cfg.setDefaults("/tmp", "")
679
680 env := env.NewFromMap(map[string]string{})
681 resolver := NewEnvironmentVariableResolver(env)
682 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
683 require.NoError(t, err)
684
685 require.Equal(t, cfg.Providers.Len(), 1)
686 customProvider, exists := cfg.Providers.Get("custom")
687 require.True(t, exists)
688 require.Equal(t, "custom", customProvider.ID)
689 require.Equal(t, "test-key", customProvider.APIKey)
690 require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
691 })
692
693 t.Run("custom anthropic provider is supported", func(t *testing.T) {
694 cfg := &Config{
695 Providers: csync.NewMapFrom(map[string]ProviderConfig{
696 "custom-anthropic": {
697 APIKey: "test-key",
698 BaseURL: "https://api.anthropic.com/v1",
699 Type: catwalk.TypeAnthropic,
700 Models: []catwalk.Model{{
701 ID: "claude-3-sonnet",
702 }},
703 },
704 }),
705 }
706 cfg.setDefaults("/tmp", "")
707
708 env := env.NewFromMap(map[string]string{})
709 resolver := NewEnvironmentVariableResolver(env)
710 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
711 require.NoError(t, err)
712
713 require.Equal(t, cfg.Providers.Len(), 1)
714 customProvider, exists := cfg.Providers.Get("custom-anthropic")
715 require.True(t, exists)
716 require.Equal(t, "custom-anthropic", customProvider.ID)
717 require.Equal(t, "test-key", customProvider.APIKey)
718 require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
719 require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
720 })
721
722 t.Run("disabled custom provider is removed", func(t *testing.T) {
723 cfg := &Config{
724 Providers: csync.NewMapFrom(map[string]ProviderConfig{
725 "custom": {
726 APIKey: "test-key",
727 BaseURL: "https://api.custom.com/v1",
728 Type: catwalk.TypeOpenAI,
729 Disable: true,
730 Models: []catwalk.Model{{
731 ID: "test-model",
732 }},
733 },
734 }),
735 }
736 cfg.setDefaults("/tmp", "")
737
738 env := env.NewFromMap(map[string]string{})
739 resolver := NewEnvironmentVariableResolver(env)
740 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
741 require.NoError(t, err)
742
743 require.Equal(t, cfg.Providers.Len(), 0)
744 _, exists := cfg.Providers.Get("custom")
745 require.False(t, exists)
746 })
747}
748
749func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
750 t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
751 knownProviders := []catwalk.Provider{
752 {
753 ID: catwalk.InferenceProviderVertexAI,
754 APIKey: "",
755 APIEndpoint: "",
756 Models: []catwalk.Model{{
757 ID: "gemini-pro",
758 }},
759 },
760 }
761
762 cfg := &Config{
763 Providers: csync.NewMapFrom(map[string]ProviderConfig{
764 "vertexai": {
765 BaseURL: "custom-url",
766 },
767 }),
768 }
769 cfg.setDefaults("/tmp", "")
770
771 env := env.NewFromMap(map[string]string{
772 "GOOGLE_GENAI_USE_VERTEXAI": "false",
773 })
774 resolver := NewEnvironmentVariableResolver(env)
775 err := cfg.configureProviders(env, resolver, knownProviders)
776 require.NoError(t, err)
777
778 require.Equal(t, cfg.Providers.Len(), 0)
779 _, exists := cfg.Providers.Get("vertexai")
780 require.False(t, exists)
781 })
782
783 t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
784 knownProviders := []catwalk.Provider{
785 {
786 ID: catwalk.InferenceProviderBedrock,
787 APIKey: "",
788 APIEndpoint: "",
789 Models: []catwalk.Model{{
790 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
791 }},
792 },
793 }
794
795 cfg := &Config{
796 Providers: csync.NewMapFrom(map[string]ProviderConfig{
797 "bedrock": {
798 BaseURL: "custom-url",
799 },
800 }),
801 }
802 cfg.setDefaults("/tmp", "")
803
804 env := env.NewFromMap(map[string]string{})
805 resolver := NewEnvironmentVariableResolver(env)
806 err := cfg.configureProviders(env, resolver, knownProviders)
807 require.NoError(t, err)
808
809 require.Equal(t, cfg.Providers.Len(), 0)
810 _, exists := cfg.Providers.Get("bedrock")
811 require.False(t, exists)
812 })
813
814 t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
815 knownProviders := []catwalk.Provider{
816 {
817 ID: "openai",
818 APIKey: "$MISSING_API_KEY",
819 APIEndpoint: "https://api.openai.com/v1",
820 Models: []catwalk.Model{{
821 ID: "test-model",
822 }},
823 },
824 }
825
826 cfg := &Config{
827 Providers: csync.NewMapFrom(map[string]ProviderConfig{
828 "openai": {
829 BaseURL: "custom-url",
830 },
831 }),
832 }
833 cfg.setDefaults("/tmp", "")
834
835 env := env.NewFromMap(map[string]string{})
836 resolver := NewEnvironmentVariableResolver(env)
837 err := cfg.configureProviders(env, resolver, knownProviders)
838 require.NoError(t, err)
839
840 require.Equal(t, cfg.Providers.Len(), 0)
841 _, exists := cfg.Providers.Get("openai")
842 require.False(t, exists)
843 })
844
845 t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
846 knownProviders := []catwalk.Provider{
847 {
848 ID: "openai",
849 APIKey: "$OPENAI_API_KEY",
850 APIEndpoint: "$MISSING_ENDPOINT",
851 Models: []catwalk.Model{{
852 ID: "test-model",
853 }},
854 },
855 }
856
857 cfg := &Config{
858 Providers: csync.NewMapFrom(map[string]ProviderConfig{
859 "openai": {
860 APIKey: "test-key",
861 },
862 }),
863 }
864 cfg.setDefaults("/tmp", "")
865
866 env := env.NewFromMap(map[string]string{
867 "OPENAI_API_KEY": "test-key",
868 })
869 resolver := NewEnvironmentVariableResolver(env)
870 err := cfg.configureProviders(env, resolver, knownProviders)
871 require.NoError(t, err)
872
873 require.Equal(t, cfg.Providers.Len(), 1)
874 _, exists := cfg.Providers.Get("openai")
875 require.True(t, exists)
876 })
877}
878
879func TestConfig_defaultModelSelection(t *testing.T) {
880 t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
881 knownProviders := []catwalk.Provider{
882 {
883 ID: "openai",
884 APIKey: "abc",
885 DefaultLargeModelID: "large-model",
886 DefaultSmallModelID: "small-model",
887 Models: []catwalk.Model{
888 {
889 ID: "large-model",
890 DefaultMaxTokens: 1000,
891 },
892 {
893 ID: "small-model",
894 DefaultMaxTokens: 500,
895 },
896 },
897 },
898 }
899
900 cfg := &Config{}
901 cfg.setDefaults("/tmp", "")
902 env := env.NewFromMap(map[string]string{})
903 resolver := NewEnvironmentVariableResolver(env)
904 err := cfg.configureProviders(env, resolver, knownProviders)
905 require.NoError(t, err)
906
907 large, small, err := cfg.defaultModelSelection(knownProviders)
908 require.NoError(t, err)
909 require.Equal(t, "large-model", large.Model)
910 require.Equal(t, "openai", large.Provider)
911 require.Equal(t, int64(1000), large.MaxTokens)
912 require.Equal(t, "small-model", small.Model)
913 require.Equal(t, "openai", small.Provider)
914 require.Equal(t, int64(500), small.MaxTokens)
915 })
916 t.Run("should error if no providers configured", func(t *testing.T) {
917 knownProviders := []catwalk.Provider{
918 {
919 ID: "openai",
920 APIKey: "$MISSING_KEY",
921 DefaultLargeModelID: "large-model",
922 DefaultSmallModelID: "small-model",
923 Models: []catwalk.Model{
924 {
925 ID: "large-model",
926 DefaultMaxTokens: 1000,
927 },
928 {
929 ID: "small-model",
930 DefaultMaxTokens: 500,
931 },
932 },
933 },
934 }
935
936 cfg := &Config{}
937 cfg.setDefaults("/tmp", "")
938 env := env.NewFromMap(map[string]string{})
939 resolver := NewEnvironmentVariableResolver(env)
940 err := cfg.configureProviders(env, resolver, knownProviders)
941 require.NoError(t, err)
942
943 _, _, err = cfg.defaultModelSelection(knownProviders)
944 require.Error(t, err)
945 })
946 t.Run("should error if model is missing", func(t *testing.T) {
947 knownProviders := []catwalk.Provider{
948 {
949 ID: "openai",
950 APIKey: "abc",
951 DefaultLargeModelID: "large-model",
952 DefaultSmallModelID: "small-model",
953 Models: []catwalk.Model{
954 {
955 ID: "not-large-model",
956 DefaultMaxTokens: 1000,
957 },
958 {
959 ID: "small-model",
960 DefaultMaxTokens: 500,
961 },
962 },
963 },
964 }
965
966 cfg := &Config{}
967 cfg.setDefaults("/tmp", "")
968 env := env.NewFromMap(map[string]string{})
969 resolver := NewEnvironmentVariableResolver(env)
970 err := cfg.configureProviders(env, resolver, knownProviders)
971 require.NoError(t, err)
972 _, _, err = cfg.defaultModelSelection(knownProviders)
973 require.Error(t, err)
974 })
975
976 t.Run("should configure the default models with a custom provider", func(t *testing.T) {
977 knownProviders := []catwalk.Provider{
978 {
979 ID: "openai",
980 APIKey: "$MISSING", // will not be included in the config
981 DefaultLargeModelID: "large-model",
982 DefaultSmallModelID: "small-model",
983 Models: []catwalk.Model{
984 {
985 ID: "not-large-model",
986 DefaultMaxTokens: 1000,
987 },
988 {
989 ID: "small-model",
990 DefaultMaxTokens: 500,
991 },
992 },
993 },
994 }
995
996 cfg := &Config{
997 Providers: csync.NewMapFrom(map[string]ProviderConfig{
998 "custom": {
999 APIKey: "test-key",
1000 BaseURL: "https://api.custom.com/v1",
1001 Models: []catwalk.Model{
1002 {
1003 ID: "model",
1004 DefaultMaxTokens: 600,
1005 },
1006 },
1007 },
1008 }),
1009 }
1010 cfg.setDefaults("/tmp", "")
1011 env := env.NewFromMap(map[string]string{})
1012 resolver := NewEnvironmentVariableResolver(env)
1013 err := cfg.configureProviders(env, resolver, knownProviders)
1014 require.NoError(t, err)
1015 large, small, err := cfg.defaultModelSelection(knownProviders)
1016 require.NoError(t, err)
1017 require.Equal(t, "model", large.Model)
1018 require.Equal(t, "custom", large.Provider)
1019 require.Equal(t, int64(600), large.MaxTokens)
1020 require.Equal(t, "model", small.Model)
1021 require.Equal(t, "custom", small.Provider)
1022 require.Equal(t, int64(600), small.MaxTokens)
1023 })
1024
1025 t.Run("should fail if no model configured", func(t *testing.T) {
1026 knownProviders := []catwalk.Provider{
1027 {
1028 ID: "openai",
1029 APIKey: "$MISSING", // will not be included in the config
1030 DefaultLargeModelID: "large-model",
1031 DefaultSmallModelID: "small-model",
1032 Models: []catwalk.Model{
1033 {
1034 ID: "not-large-model",
1035 DefaultMaxTokens: 1000,
1036 },
1037 {
1038 ID: "small-model",
1039 DefaultMaxTokens: 500,
1040 },
1041 },
1042 },
1043 }
1044
1045 cfg := &Config{
1046 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1047 "custom": {
1048 APIKey: "test-key",
1049 BaseURL: "https://api.custom.com/v1",
1050 Models: []catwalk.Model{},
1051 },
1052 }),
1053 }
1054 cfg.setDefaults("/tmp", "")
1055 env := env.NewFromMap(map[string]string{})
1056 resolver := NewEnvironmentVariableResolver(env)
1057 err := cfg.configureProviders(env, resolver, knownProviders)
1058 require.NoError(t, err)
1059 _, _, err = cfg.defaultModelSelection(knownProviders)
1060 require.Error(t, err)
1061 })
1062 t.Run("should use the default provider first", func(t *testing.T) {
1063 knownProviders := []catwalk.Provider{
1064 {
1065 ID: "openai",
1066 APIKey: "set",
1067 DefaultLargeModelID: "large-model",
1068 DefaultSmallModelID: "small-model",
1069 Models: []catwalk.Model{
1070 {
1071 ID: "large-model",
1072 DefaultMaxTokens: 1000,
1073 },
1074 {
1075 ID: "small-model",
1076 DefaultMaxTokens: 500,
1077 },
1078 },
1079 },
1080 }
1081
1082 cfg := &Config{
1083 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1084 "custom": {
1085 APIKey: "test-key",
1086 BaseURL: "https://api.custom.com/v1",
1087 Models: []catwalk.Model{
1088 {
1089 ID: "large-model",
1090 DefaultMaxTokens: 1000,
1091 },
1092 },
1093 },
1094 }),
1095 }
1096 cfg.setDefaults("/tmp", "")
1097 env := env.NewFromMap(map[string]string{})
1098 resolver := NewEnvironmentVariableResolver(env)
1099 err := cfg.configureProviders(env, resolver, knownProviders)
1100 require.NoError(t, err)
1101 large, small, err := cfg.defaultModelSelection(knownProviders)
1102 require.NoError(t, err)
1103 require.Equal(t, "large-model", large.Model)
1104 require.Equal(t, "openai", large.Provider)
1105 require.Equal(t, int64(1000), large.MaxTokens)
1106 require.Equal(t, "small-model", small.Model)
1107 require.Equal(t, "openai", small.Provider)
1108 require.Equal(t, int64(500), small.MaxTokens)
1109 })
1110}
1111
1112func TestConfig_configureSelectedModels(t *testing.T) {
1113 t.Run("should override defaults", func(t *testing.T) {
1114 knownProviders := []catwalk.Provider{
1115 {
1116 ID: "openai",
1117 APIKey: "abc",
1118 DefaultLargeModelID: "large-model",
1119 DefaultSmallModelID: "small-model",
1120 Models: []catwalk.Model{
1121 {
1122 ID: "larger-model",
1123 DefaultMaxTokens: 2000,
1124 },
1125 {
1126 ID: "large-model",
1127 DefaultMaxTokens: 1000,
1128 },
1129 {
1130 ID: "small-model",
1131 DefaultMaxTokens: 500,
1132 },
1133 },
1134 },
1135 }
1136
1137 cfg := &Config{
1138 Models: map[SelectedModelType]SelectedModel{
1139 "large": {
1140 Model: "larger-model",
1141 },
1142 },
1143 }
1144 cfg.setDefaults("/tmp", "")
1145 env := env.NewFromMap(map[string]string{})
1146 resolver := NewEnvironmentVariableResolver(env)
1147 err := cfg.configureProviders(env, resolver, knownProviders)
1148 require.NoError(t, err)
1149
1150 err = cfg.configureSelectedModels(knownProviders)
1151 require.NoError(t, err)
1152 large := cfg.Models[SelectedModelTypeLarge]
1153 small := cfg.Models[SelectedModelTypeSmall]
1154 require.Equal(t, "larger-model", large.Model)
1155 require.Equal(t, "openai", large.Provider)
1156 require.Equal(t, int64(2000), large.MaxTokens)
1157 require.Equal(t, "small-model", small.Model)
1158 require.Equal(t, "openai", small.Provider)
1159 require.Equal(t, int64(500), small.MaxTokens)
1160 })
1161 t.Run("should be possible to use multiple providers", func(t *testing.T) {
1162 knownProviders := []catwalk.Provider{
1163 {
1164 ID: "openai",
1165 APIKey: "abc",
1166 DefaultLargeModelID: "large-model",
1167 DefaultSmallModelID: "small-model",
1168 Models: []catwalk.Model{
1169 {
1170 ID: "large-model",
1171 DefaultMaxTokens: 1000,
1172 },
1173 {
1174 ID: "small-model",
1175 DefaultMaxTokens: 500,
1176 },
1177 },
1178 },
1179 {
1180 ID: "anthropic",
1181 APIKey: "abc",
1182 DefaultLargeModelID: "a-large-model",
1183 DefaultSmallModelID: "a-small-model",
1184 Models: []catwalk.Model{
1185 {
1186 ID: "a-large-model",
1187 DefaultMaxTokens: 1000,
1188 },
1189 {
1190 ID: "a-small-model",
1191 DefaultMaxTokens: 200,
1192 },
1193 },
1194 },
1195 }
1196
1197 cfg := &Config{
1198 Models: map[SelectedModelType]SelectedModel{
1199 "small": {
1200 Model: "a-small-model",
1201 Provider: "anthropic",
1202 MaxTokens: 300,
1203 },
1204 },
1205 }
1206 cfg.setDefaults("/tmp", "")
1207 env := env.NewFromMap(map[string]string{})
1208 resolver := NewEnvironmentVariableResolver(env)
1209 err := cfg.configureProviders(env, resolver, knownProviders)
1210 require.NoError(t, err)
1211
1212 err = cfg.configureSelectedModels(knownProviders)
1213 require.NoError(t, err)
1214 large := cfg.Models[SelectedModelTypeLarge]
1215 small := cfg.Models[SelectedModelTypeSmall]
1216 require.Equal(t, "large-model", large.Model)
1217 require.Equal(t, "openai", large.Provider)
1218 require.Equal(t, int64(1000), large.MaxTokens)
1219 require.Equal(t, "a-small-model", small.Model)
1220 require.Equal(t, "anthropic", small.Provider)
1221 require.Equal(t, int64(300), small.MaxTokens)
1222 })
1223
1224 t.Run("should override the max tokens only", func(t *testing.T) {
1225 knownProviders := []catwalk.Provider{
1226 {
1227 ID: "openai",
1228 APIKey: "abc",
1229 DefaultLargeModelID: "large-model",
1230 DefaultSmallModelID: "small-model",
1231 Models: []catwalk.Model{
1232 {
1233 ID: "large-model",
1234 DefaultMaxTokens: 1000,
1235 },
1236 {
1237 ID: "small-model",
1238 DefaultMaxTokens: 500,
1239 },
1240 },
1241 },
1242 }
1243
1244 cfg := &Config{
1245 Models: map[SelectedModelType]SelectedModel{
1246 "large": {
1247 MaxTokens: 100,
1248 },
1249 },
1250 }
1251 cfg.setDefaults("/tmp", "")
1252 env := env.NewFromMap(map[string]string{})
1253 resolver := NewEnvironmentVariableResolver(env)
1254 err := cfg.configureProviders(env, resolver, knownProviders)
1255 require.NoError(t, err)
1256
1257 err = cfg.configureSelectedModels(knownProviders)
1258 require.NoError(t, err)
1259 large := cfg.Models[SelectedModelTypeLarge]
1260 require.Equal(t, "large-model", large.Model)
1261 require.Equal(t, "openai", large.Provider)
1262 require.Equal(t, int64(100), large.MaxTokens)
1263 })
1264}