1package config
2
3import (
4 "io"
5 "log/slog"
6 "os"
7 "path/filepath"
8 "strings"
9 "testing"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/csync"
13 "github.com/charmbracelet/crush/internal/env"
14 "github.com/stretchr/testify/assert"
15 "github.com/stretchr/testify/require"
16)
17
// TestMain replaces the default slog logger with one that writes to
// io.Discard, so log output produced while the package's tests run is
// dropped, then runs the tests and exits with their status code.
func TestMain(m *testing.M) {
	discard := slog.NewTextHandler(io.Discard, nil)
	slog.SetDefault(slog.New(discard))
	os.Exit(m.Run())
}
24
25func TestConfig_LoadFromReaders(t *testing.T) {
26 data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
27 data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
28 data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
29
30 loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
31
32 require.NoError(t, err)
33 require.NotNil(t, loadedConfig)
34 require.Equal(t, 1, loadedConfig.Providers.Len())
35 pc, _ := loadedConfig.Providers.Get("openai")
36 require.Equal(t, "key2", pc.APIKey)
37 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
38}
39
40func TestConfig_setDefaults(t *testing.T) {
41 cfg := &Config{}
42
43 cfg.setDefaults("/tmp", "")
44
45 require.NotNil(t, cfg.Options)
46 require.NotNil(t, cfg.Options.TUI)
47 require.NotNil(t, cfg.Options.ContextPaths)
48 require.NotNil(t, cfg.Providers)
49 require.NotNil(t, cfg.Models)
50 require.NotNil(t, cfg.LSP)
51 require.NotNil(t, cfg.MCP)
52 require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
53 for _, path := range defaultContextPaths {
54 require.Contains(t, cfg.Options.ContextPaths, path)
55 }
56 require.Equal(t, "/tmp", cfg.workingDir)
57}
58
59func TestConfig_configureProviders(t *testing.T) {
60 knownProviders := []catwalk.Provider{
61 {
62 ID: "openai",
63 APIKey: "$OPENAI_API_KEY",
64 APIEndpoint: "https://api.openai.com/v1",
65 Models: []catwalk.Model{{
66 ID: "test-model",
67 }},
68 },
69 }
70
71 cfg := &Config{}
72 cfg.setDefaults("/tmp", "")
73 env := env.NewFromMap(map[string]string{
74 "OPENAI_API_KEY": "test-key",
75 })
76 resolver := NewEnvironmentVariableResolver(env)
77 err := cfg.configureProviders(env, resolver, knownProviders)
78 require.NoError(t, err)
79 require.Equal(t, 1, cfg.Providers.Len())
80
81 // We want to make sure that we keep the configured API key as a placeholder
82 pc, _ := cfg.Providers.Get("openai")
83 require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
84}
85
86func TestConfig_configureProvidersWithOverride(t *testing.T) {
87 knownProviders := []catwalk.Provider{
88 {
89 ID: "openai",
90 APIKey: "$OPENAI_API_KEY",
91 APIEndpoint: "https://api.openai.com/v1",
92 Models: []catwalk.Model{{
93 ID: "test-model",
94 }},
95 },
96 }
97
98 cfg := &Config{
99 Providers: csync.NewMap[string, ProviderConfig](),
100 }
101 cfg.Providers.Set("openai", ProviderConfig{
102 APIKey: "xyz",
103 BaseURL: "https://api.openai.com/v2",
104 Models: []catwalk.Model{
105 {
106 ID: "test-model",
107 Name: "Updated",
108 },
109 {
110 ID: "another-model",
111 },
112 },
113 })
114 cfg.setDefaults("/tmp", "")
115
116 env := env.NewFromMap(map[string]string{
117 "OPENAI_API_KEY": "test-key",
118 })
119 resolver := NewEnvironmentVariableResolver(env)
120 err := cfg.configureProviders(env, resolver, knownProviders)
121 require.NoError(t, err)
122 require.Equal(t, 1, cfg.Providers.Len())
123
124 // We want to make sure that we keep the configured API key as a placeholder
125 pc, _ := cfg.Providers.Get("openai")
126 require.Equal(t, "xyz", pc.APIKey)
127 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
128 require.Len(t, pc.Models, 2)
129 require.Equal(t, "Updated", pc.Models[0].Name)
130}
131
132func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
133 knownProviders := []catwalk.Provider{
134 {
135 ID: "openai",
136 APIKey: "$OPENAI_API_KEY",
137 APIEndpoint: "https://api.openai.com/v1",
138 Models: []catwalk.Model{{
139 ID: "test-model",
140 }},
141 },
142 }
143
144 cfg := &Config{
145 Providers: csync.NewMapFrom(map[string]ProviderConfig{
146 "custom": {
147 APIKey: "xyz",
148 BaseURL: "https://api.someendpoint.com/v2",
149 Models: []catwalk.Model{
150 {
151 ID: "test-model",
152 },
153 },
154 },
155 }),
156 }
157 cfg.setDefaults("/tmp", "")
158 env := env.NewFromMap(map[string]string{
159 "OPENAI_API_KEY": "test-key",
160 })
161 resolver := NewEnvironmentVariableResolver(env)
162 err := cfg.configureProviders(env, resolver, knownProviders)
163 require.NoError(t, err)
164 // Should be to because of the env variable
165 require.Equal(t, cfg.Providers.Len(), 2)
166
167 // We want to make sure that we keep the configured API key as a placeholder
168 pc, _ := cfg.Providers.Get("custom")
169 require.Equal(t, "xyz", pc.APIKey)
170 // Make sure we set the ID correctly
171 require.Equal(t, "custom", pc.ID)
172 require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
173 require.Len(t, pc.Models, 1)
174
175 _, ok := cfg.Providers.Get("openai")
176 require.True(t, ok, "OpenAI provider should still be present")
177}
178
179func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
180 knownProviders := []catwalk.Provider{
181 {
182 ID: catwalk.InferenceProviderBedrock,
183 APIKey: "",
184 APIEndpoint: "",
185 Models: []catwalk.Model{{
186 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
187 }},
188 },
189 }
190
191 cfg := &Config{}
192 cfg.setDefaults("/tmp", "")
193 env := env.NewFromMap(map[string]string{
194 "AWS_ACCESS_KEY_ID": "test-key-id",
195 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
196 })
197 resolver := NewEnvironmentVariableResolver(env)
198 err := cfg.configureProviders(env, resolver, knownProviders)
199 require.NoError(t, err)
200 require.Equal(t, cfg.Providers.Len(), 1)
201
202 bedrockProvider, ok := cfg.Providers.Get("bedrock")
203 require.True(t, ok, "Bedrock provider should be present")
204 require.Len(t, bedrockProvider.Models, 1)
205 require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
206}
207
208func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
209 knownProviders := []catwalk.Provider{
210 {
211 ID: catwalk.InferenceProviderBedrock,
212 APIKey: "",
213 APIEndpoint: "",
214 Models: []catwalk.Model{{
215 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
216 }},
217 },
218 }
219
220 cfg := &Config{}
221 cfg.setDefaults("/tmp", "")
222 env := env.NewFromMap(map[string]string{})
223 resolver := NewEnvironmentVariableResolver(env)
224 err := cfg.configureProviders(env, resolver, knownProviders)
225 require.NoError(t, err)
226 // Provider should not be configured without credentials
227 require.Equal(t, cfg.Providers.Len(), 0)
228}
229
230func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
231 knownProviders := []catwalk.Provider{
232 {
233 ID: catwalk.InferenceProviderBedrock,
234 APIKey: "",
235 APIEndpoint: "",
236 Models: []catwalk.Model{{
237 ID: "some-random-model",
238 }},
239 },
240 }
241
242 cfg := &Config{}
243 cfg.setDefaults("/tmp", "")
244 env := env.NewFromMap(map[string]string{
245 "AWS_ACCESS_KEY_ID": "test-key-id",
246 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
247 })
248 resolver := NewEnvironmentVariableResolver(env)
249 err := cfg.configureProviders(env, resolver, knownProviders)
250 require.Error(t, err)
251}
252
253func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
254 knownProviders := []catwalk.Provider{
255 {
256 ID: catwalk.InferenceProviderVertexAI,
257 APIKey: "",
258 APIEndpoint: "",
259 Models: []catwalk.Model{{
260 ID: "gemini-pro",
261 }},
262 },
263 }
264
265 cfg := &Config{}
266 cfg.setDefaults("/tmp", "")
267 env := env.NewFromMap(map[string]string{
268 "VERTEXAI_PROJECT": "test-project",
269 "VERTEXAI_LOCATION": "us-central1",
270 })
271 resolver := NewEnvironmentVariableResolver(env)
272 err := cfg.configureProviders(env, resolver, knownProviders)
273 require.NoError(t, err)
274 require.Equal(t, cfg.Providers.Len(), 1)
275
276 vertexProvider, ok := cfg.Providers.Get("vertexai")
277 require.True(t, ok, "VertexAI provider should be present")
278 require.Len(t, vertexProvider.Models, 1)
279 require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
280 require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
281 require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
282}
283
284func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
285 knownProviders := []catwalk.Provider{
286 {
287 ID: catwalk.InferenceProviderVertexAI,
288 APIKey: "",
289 APIEndpoint: "",
290 Models: []catwalk.Model{{
291 ID: "gemini-pro",
292 }},
293 },
294 }
295
296 cfg := &Config{}
297 cfg.setDefaults("/tmp", "")
298 env := env.NewFromMap(map[string]string{
299 "GOOGLE_GENAI_USE_VERTEXAI": "false",
300 "GOOGLE_CLOUD_PROJECT": "test-project",
301 "GOOGLE_CLOUD_LOCATION": "us-central1",
302 })
303 resolver := NewEnvironmentVariableResolver(env)
304 err := cfg.configureProviders(env, resolver, knownProviders)
305 require.NoError(t, err)
306 // Provider should not be configured without proper credentials
307 require.Equal(t, cfg.Providers.Len(), 0)
308}
309
310func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
311 knownProviders := []catwalk.Provider{
312 {
313 ID: catwalk.InferenceProviderVertexAI,
314 APIKey: "",
315 APIEndpoint: "",
316 Models: []catwalk.Model{{
317 ID: "gemini-pro",
318 }},
319 },
320 }
321
322 cfg := &Config{}
323 cfg.setDefaults("/tmp", "")
324 env := env.NewFromMap(map[string]string{
325 "GOOGLE_GENAI_USE_VERTEXAI": "true",
326 "GOOGLE_CLOUD_LOCATION": "us-central1",
327 })
328 resolver := NewEnvironmentVariableResolver(env)
329 err := cfg.configureProviders(env, resolver, knownProviders)
330 require.NoError(t, err)
331 // Provider should not be configured without project
332 require.Equal(t, cfg.Providers.Len(), 0)
333}
334
335func TestConfig_configureProvidersSetProviderID(t *testing.T) {
336 knownProviders := []catwalk.Provider{
337 {
338 ID: "openai",
339 APIKey: "$OPENAI_API_KEY",
340 APIEndpoint: "https://api.openai.com/v1",
341 Models: []catwalk.Model{{
342 ID: "test-model",
343 }},
344 },
345 }
346
347 cfg := &Config{}
348 cfg.setDefaults("/tmp", "")
349 env := env.NewFromMap(map[string]string{
350 "OPENAI_API_KEY": "test-key",
351 })
352 resolver := NewEnvironmentVariableResolver(env)
353 err := cfg.configureProviders(env, resolver, knownProviders)
354 require.NoError(t, err)
355 require.Equal(t, cfg.Providers.Len(), 1)
356
357 // Provider ID should be set
358 pc, _ := cfg.Providers.Get("openai")
359 require.Equal(t, "openai", pc.ID)
360}
361
362func TestConfig_EnabledProviders(t *testing.T) {
363 t.Run("all providers enabled", func(t *testing.T) {
364 cfg := &Config{
365 Providers: csync.NewMapFrom(map[string]ProviderConfig{
366 "openai": {
367 ID: "openai",
368 APIKey: "key1",
369 Disable: false,
370 },
371 "anthropic": {
372 ID: "anthropic",
373 APIKey: "key2",
374 Disable: false,
375 },
376 }),
377 }
378
379 enabled := cfg.EnabledProviders()
380 require.Len(t, enabled, 2)
381 })
382
383 t.Run("some providers disabled", func(t *testing.T) {
384 cfg := &Config{
385 Providers: csync.NewMapFrom(map[string]ProviderConfig{
386 "openai": {
387 ID: "openai",
388 APIKey: "key1",
389 Disable: false,
390 },
391 "anthropic": {
392 ID: "anthropic",
393 APIKey: "key2",
394 Disable: true,
395 },
396 }),
397 }
398
399 enabled := cfg.EnabledProviders()
400 require.Len(t, enabled, 1)
401 require.Equal(t, "openai", enabled[0].ID)
402 })
403
404 t.Run("empty providers map", func(t *testing.T) {
405 cfg := &Config{
406 Providers: csync.NewMap[string, ProviderConfig](),
407 }
408
409 enabled := cfg.EnabledProviders()
410 require.Len(t, enabled, 0)
411 })
412}
413
414func TestConfig_IsConfigured(t *testing.T) {
415 t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
416 cfg := &Config{
417 Providers: csync.NewMapFrom(map[string]ProviderConfig{
418 "openai": {
419 ID: "openai",
420 APIKey: "key1",
421 Disable: false,
422 },
423 }),
424 }
425
426 require.True(t, cfg.IsConfigured())
427 })
428
429 t.Run("returns false when no providers are configured", func(t *testing.T) {
430 cfg := &Config{
431 Providers: csync.NewMap[string, ProviderConfig](),
432 }
433
434 require.False(t, cfg.IsConfigured())
435 })
436
437 t.Run("returns false when all providers are disabled", func(t *testing.T) {
438 cfg := &Config{
439 Providers: csync.NewMapFrom(map[string]ProviderConfig{
440 "openai": {
441 ID: "openai",
442 APIKey: "key1",
443 Disable: true,
444 },
445 "anthropic": {
446 ID: "anthropic",
447 APIKey: "key2",
448 Disable: true,
449 },
450 }),
451 }
452
453 require.False(t, cfg.IsConfigured())
454 })
455}
456
457func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
458 cfg := &Config{
459 Options: &Options{
460 DisabledTools: []string{},
461 },
462 }
463
464 cfg.SetupAgents()
465 coderAgent, ok := cfg.Agents[AgentCoder]
466 require.True(t, ok)
467 assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
468
469 taskAgent, ok := cfg.Agents[AgentTask]
470 require.True(t, ok)
471 assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
472}
473
474func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
475 cfg := &Config{
476 Options: &Options{
477 DisabledTools: []string{
478 "edit",
479 "download",
480 "grep",
481 },
482 },
483 }
484
485 cfg.SetupAgents()
486 coderAgent, ok := cfg.Agents[AgentCoder]
487 require.True(t, ok)
488 assert.Equal(t, []string{"agent", "bash", "bash_output", "bash_kill", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)
489
490 taskAgent, ok := cfg.Agents[AgentTask]
491 require.True(t, ok)
492 assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
493}
494
495func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
496 cfg := &Config{
497 Options: &Options{
498 DisabledTools: []string{
499 "glob",
500 "grep",
501 "ls",
502 "sourcegraph",
503 "view",
504 },
505 },
506 }
507
508 cfg.SetupAgents()
509 coderAgent, ok := cfg.Agents[AgentCoder]
510 require.True(t, ok)
511 assert.Equal(t, []string{"agent", "bash", "bash_output", "bash_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "write"}, coderAgent.AllowedTools)
512
513 taskAgent, ok := cfg.Agents[AgentTask]
514 require.True(t, ok)
515 assert.Equal(t, []string{}, taskAgent.AllowedTools)
516}
517
518func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
519 knownProviders := []catwalk.Provider{
520 {
521 ID: "openai",
522 APIKey: "$OPENAI_API_KEY",
523 APIEndpoint: "https://api.openai.com/v1",
524 Models: []catwalk.Model{{
525 ID: "test-model",
526 }},
527 },
528 }
529
530 cfg := &Config{
531 Providers: csync.NewMapFrom(map[string]ProviderConfig{
532 "openai": {
533 Disable: true,
534 },
535 }),
536 }
537 cfg.setDefaults("/tmp", "")
538
539 env := env.NewFromMap(map[string]string{
540 "OPENAI_API_KEY": "test-key",
541 })
542 resolver := NewEnvironmentVariableResolver(env)
543 err := cfg.configureProviders(env, resolver, knownProviders)
544 require.NoError(t, err)
545
546 require.Equal(t, cfg.Providers.Len(), 1)
547 prov, exists := cfg.Providers.Get("openai")
548 require.True(t, exists)
549 require.True(t, prov.Disable)
550}
551
552func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
553 t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
554 cfg := &Config{
555 Providers: csync.NewMapFrom(map[string]ProviderConfig{
556 "custom": {
557 BaseURL: "https://api.custom.com/v1",
558 Models: []catwalk.Model{{
559 ID: "test-model",
560 }},
561 },
562 "openai": {
563 APIKey: "$MISSING",
564 },
565 }),
566 }
567 cfg.setDefaults("/tmp", "")
568
569 env := env.NewFromMap(map[string]string{})
570 resolver := NewEnvironmentVariableResolver(env)
571 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
572 require.NoError(t, err)
573
574 require.Equal(t, cfg.Providers.Len(), 1)
575 _, exists := cfg.Providers.Get("custom")
576 require.True(t, exists)
577 })
578
579 t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
580 cfg := &Config{
581 Providers: csync.NewMapFrom(map[string]ProviderConfig{
582 "custom": {
583 APIKey: "test-key",
584 Models: []catwalk.Model{{
585 ID: "test-model",
586 }},
587 },
588 }),
589 }
590 cfg.setDefaults("/tmp", "")
591
592 env := env.NewFromMap(map[string]string{})
593 resolver := NewEnvironmentVariableResolver(env)
594 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
595 require.NoError(t, err)
596
597 require.Equal(t, cfg.Providers.Len(), 0)
598 _, exists := cfg.Providers.Get("custom")
599 require.False(t, exists)
600 })
601
602 t.Run("custom provider with no models is removed", func(t *testing.T) {
603 cfg := &Config{
604 Providers: csync.NewMapFrom(map[string]ProviderConfig{
605 "custom": {
606 APIKey: "test-key",
607 BaseURL: "https://api.custom.com/v1",
608 Models: []catwalk.Model{},
609 },
610 }),
611 }
612 cfg.setDefaults("/tmp", "")
613
614 env := env.NewFromMap(map[string]string{})
615 resolver := NewEnvironmentVariableResolver(env)
616 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
617 require.NoError(t, err)
618
619 require.Equal(t, cfg.Providers.Len(), 0)
620 _, exists := cfg.Providers.Get("custom")
621 require.False(t, exists)
622 })
623
624 t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
625 cfg := &Config{
626 Providers: csync.NewMapFrom(map[string]ProviderConfig{
627 "custom": {
628 APIKey: "test-key",
629 BaseURL: "https://api.custom.com/v1",
630 Type: "unsupported",
631 Models: []catwalk.Model{{
632 ID: "test-model",
633 }},
634 },
635 }),
636 }
637 cfg.setDefaults("/tmp", "")
638
639 env := env.NewFromMap(map[string]string{})
640 resolver := NewEnvironmentVariableResolver(env)
641 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
642 require.NoError(t, err)
643
644 require.Equal(t, cfg.Providers.Len(), 0)
645 _, exists := cfg.Providers.Get("custom")
646 require.False(t, exists)
647 })
648
649 t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
650 cfg := &Config{
651 Providers: csync.NewMapFrom(map[string]ProviderConfig{
652 "custom": {
653 APIKey: "test-key",
654 BaseURL: "https://api.custom.com/v1",
655 Type: catwalk.TypeOpenAI,
656 Models: []catwalk.Model{{
657 ID: "test-model",
658 }},
659 },
660 }),
661 }
662 cfg.setDefaults("/tmp", "")
663
664 env := env.NewFromMap(map[string]string{})
665 resolver := NewEnvironmentVariableResolver(env)
666 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
667 require.NoError(t, err)
668
669 require.Equal(t, cfg.Providers.Len(), 1)
670 customProvider, exists := cfg.Providers.Get("custom")
671 require.True(t, exists)
672 require.Equal(t, "custom", customProvider.ID)
673 require.Equal(t, "test-key", customProvider.APIKey)
674 require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
675 })
676
677 t.Run("custom anthropic provider is supported", func(t *testing.T) {
678 cfg := &Config{
679 Providers: csync.NewMapFrom(map[string]ProviderConfig{
680 "custom-anthropic": {
681 APIKey: "test-key",
682 BaseURL: "https://api.anthropic.com/v1",
683 Type: catwalk.TypeAnthropic,
684 Models: []catwalk.Model{{
685 ID: "claude-3-sonnet",
686 }},
687 },
688 }),
689 }
690 cfg.setDefaults("/tmp", "")
691
692 env := env.NewFromMap(map[string]string{})
693 resolver := NewEnvironmentVariableResolver(env)
694 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
695 require.NoError(t, err)
696
697 require.Equal(t, cfg.Providers.Len(), 1)
698 customProvider, exists := cfg.Providers.Get("custom-anthropic")
699 require.True(t, exists)
700 require.Equal(t, "custom-anthropic", customProvider.ID)
701 require.Equal(t, "test-key", customProvider.APIKey)
702 require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
703 require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
704 })
705
706 t.Run("disabled custom provider is removed", func(t *testing.T) {
707 cfg := &Config{
708 Providers: csync.NewMapFrom(map[string]ProviderConfig{
709 "custom": {
710 APIKey: "test-key",
711 BaseURL: "https://api.custom.com/v1",
712 Type: catwalk.TypeOpenAI,
713 Disable: true,
714 Models: []catwalk.Model{{
715 ID: "test-model",
716 }},
717 },
718 }),
719 }
720 cfg.setDefaults("/tmp", "")
721
722 env := env.NewFromMap(map[string]string{})
723 resolver := NewEnvironmentVariableResolver(env)
724 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
725 require.NoError(t, err)
726
727 require.Equal(t, cfg.Providers.Len(), 0)
728 _, exists := cfg.Providers.Get("custom")
729 require.False(t, exists)
730 })
731}
732
733func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
734 t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
735 knownProviders := []catwalk.Provider{
736 {
737 ID: catwalk.InferenceProviderVertexAI,
738 APIKey: "",
739 APIEndpoint: "",
740 Models: []catwalk.Model{{
741 ID: "gemini-pro",
742 }},
743 },
744 }
745
746 cfg := &Config{
747 Providers: csync.NewMapFrom(map[string]ProviderConfig{
748 "vertexai": {
749 BaseURL: "custom-url",
750 },
751 }),
752 }
753 cfg.setDefaults("/tmp", "")
754
755 env := env.NewFromMap(map[string]string{
756 "GOOGLE_GENAI_USE_VERTEXAI": "false",
757 })
758 resolver := NewEnvironmentVariableResolver(env)
759 err := cfg.configureProviders(env, resolver, knownProviders)
760 require.NoError(t, err)
761
762 require.Equal(t, cfg.Providers.Len(), 0)
763 _, exists := cfg.Providers.Get("vertexai")
764 require.False(t, exists)
765 })
766
767 t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
768 knownProviders := []catwalk.Provider{
769 {
770 ID: catwalk.InferenceProviderBedrock,
771 APIKey: "",
772 APIEndpoint: "",
773 Models: []catwalk.Model{{
774 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
775 }},
776 },
777 }
778
779 cfg := &Config{
780 Providers: csync.NewMapFrom(map[string]ProviderConfig{
781 "bedrock": {
782 BaseURL: "custom-url",
783 },
784 }),
785 }
786 cfg.setDefaults("/tmp", "")
787
788 env := env.NewFromMap(map[string]string{})
789 resolver := NewEnvironmentVariableResolver(env)
790 err := cfg.configureProviders(env, resolver, knownProviders)
791 require.NoError(t, err)
792
793 require.Equal(t, cfg.Providers.Len(), 0)
794 _, exists := cfg.Providers.Get("bedrock")
795 require.False(t, exists)
796 })
797
798 t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
799 knownProviders := []catwalk.Provider{
800 {
801 ID: "openai",
802 APIKey: "$MISSING_API_KEY",
803 APIEndpoint: "https://api.openai.com/v1",
804 Models: []catwalk.Model{{
805 ID: "test-model",
806 }},
807 },
808 }
809
810 cfg := &Config{
811 Providers: csync.NewMapFrom(map[string]ProviderConfig{
812 "openai": {
813 BaseURL: "custom-url",
814 },
815 }),
816 }
817 cfg.setDefaults("/tmp", "")
818
819 env := env.NewFromMap(map[string]string{})
820 resolver := NewEnvironmentVariableResolver(env)
821 err := cfg.configureProviders(env, resolver, knownProviders)
822 require.NoError(t, err)
823
824 require.Equal(t, cfg.Providers.Len(), 0)
825 _, exists := cfg.Providers.Get("openai")
826 require.False(t, exists)
827 })
828
829 t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
830 knownProviders := []catwalk.Provider{
831 {
832 ID: "openai",
833 APIKey: "$OPENAI_API_KEY",
834 APIEndpoint: "$MISSING_ENDPOINT",
835 Models: []catwalk.Model{{
836 ID: "test-model",
837 }},
838 },
839 }
840
841 cfg := &Config{
842 Providers: csync.NewMapFrom(map[string]ProviderConfig{
843 "openai": {
844 APIKey: "test-key",
845 },
846 }),
847 }
848 cfg.setDefaults("/tmp", "")
849
850 env := env.NewFromMap(map[string]string{
851 "OPENAI_API_KEY": "test-key",
852 })
853 resolver := NewEnvironmentVariableResolver(env)
854 err := cfg.configureProviders(env, resolver, knownProviders)
855 require.NoError(t, err)
856
857 require.Equal(t, cfg.Providers.Len(), 1)
858 _, exists := cfg.Providers.Get("openai")
859 require.True(t, exists)
860 })
861}
862
863func TestConfig_defaultModelSelection(t *testing.T) {
864 t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
865 knownProviders := []catwalk.Provider{
866 {
867 ID: "openai",
868 APIKey: "abc",
869 DefaultLargeModelID: "large-model",
870 DefaultSmallModelID: "small-model",
871 Models: []catwalk.Model{
872 {
873 ID: "large-model",
874 DefaultMaxTokens: 1000,
875 },
876 {
877 ID: "small-model",
878 DefaultMaxTokens: 500,
879 },
880 },
881 },
882 }
883
884 cfg := &Config{}
885 cfg.setDefaults("/tmp", "")
886 env := env.NewFromMap(map[string]string{})
887 resolver := NewEnvironmentVariableResolver(env)
888 err := cfg.configureProviders(env, resolver, knownProviders)
889 require.NoError(t, err)
890
891 large, small, err := cfg.defaultModelSelection(knownProviders)
892 require.NoError(t, err)
893 require.Equal(t, "large-model", large.Model)
894 require.Equal(t, "openai", large.Provider)
895 require.Equal(t, int64(1000), large.MaxTokens)
896 require.Equal(t, "small-model", small.Model)
897 require.Equal(t, "openai", small.Provider)
898 require.Equal(t, int64(500), small.MaxTokens)
899 })
900 t.Run("should error if no providers configured", func(t *testing.T) {
901 knownProviders := []catwalk.Provider{
902 {
903 ID: "openai",
904 APIKey: "$MISSING_KEY",
905 DefaultLargeModelID: "large-model",
906 DefaultSmallModelID: "small-model",
907 Models: []catwalk.Model{
908 {
909 ID: "large-model",
910 DefaultMaxTokens: 1000,
911 },
912 {
913 ID: "small-model",
914 DefaultMaxTokens: 500,
915 },
916 },
917 },
918 }
919
920 cfg := &Config{}
921 cfg.setDefaults("/tmp", "")
922 env := env.NewFromMap(map[string]string{})
923 resolver := NewEnvironmentVariableResolver(env)
924 err := cfg.configureProviders(env, resolver, knownProviders)
925 require.NoError(t, err)
926
927 _, _, err = cfg.defaultModelSelection(knownProviders)
928 require.Error(t, err)
929 })
930 t.Run("should error if model is missing", func(t *testing.T) {
931 knownProviders := []catwalk.Provider{
932 {
933 ID: "openai",
934 APIKey: "abc",
935 DefaultLargeModelID: "large-model",
936 DefaultSmallModelID: "small-model",
937 Models: []catwalk.Model{
938 {
939 ID: "not-large-model",
940 DefaultMaxTokens: 1000,
941 },
942 {
943 ID: "small-model",
944 DefaultMaxTokens: 500,
945 },
946 },
947 },
948 }
949
950 cfg := &Config{}
951 cfg.setDefaults("/tmp", "")
952 env := env.NewFromMap(map[string]string{})
953 resolver := NewEnvironmentVariableResolver(env)
954 err := cfg.configureProviders(env, resolver, knownProviders)
955 require.NoError(t, err)
956 _, _, err = cfg.defaultModelSelection(knownProviders)
957 require.Error(t, err)
958 })
959
960 t.Run("should configure the default models with a custom provider", func(t *testing.T) {
961 knownProviders := []catwalk.Provider{
962 {
963 ID: "openai",
964 APIKey: "$MISSING", // will not be included in the config
965 DefaultLargeModelID: "large-model",
966 DefaultSmallModelID: "small-model",
967 Models: []catwalk.Model{
968 {
969 ID: "not-large-model",
970 DefaultMaxTokens: 1000,
971 },
972 {
973 ID: "small-model",
974 DefaultMaxTokens: 500,
975 },
976 },
977 },
978 }
979
980 cfg := &Config{
981 Providers: csync.NewMapFrom(map[string]ProviderConfig{
982 "custom": {
983 APIKey: "test-key",
984 BaseURL: "https://api.custom.com/v1",
985 Models: []catwalk.Model{
986 {
987 ID: "model",
988 DefaultMaxTokens: 600,
989 },
990 },
991 },
992 }),
993 }
994 cfg.setDefaults("/tmp", "")
995 env := env.NewFromMap(map[string]string{})
996 resolver := NewEnvironmentVariableResolver(env)
997 err := cfg.configureProviders(env, resolver, knownProviders)
998 require.NoError(t, err)
999 large, small, err := cfg.defaultModelSelection(knownProviders)
1000 require.NoError(t, err)
1001 require.Equal(t, "model", large.Model)
1002 require.Equal(t, "custom", large.Provider)
1003 require.Equal(t, int64(600), large.MaxTokens)
1004 require.Equal(t, "model", small.Model)
1005 require.Equal(t, "custom", small.Provider)
1006 require.Equal(t, int64(600), small.MaxTokens)
1007 })
1008
1009 t.Run("should fail if no model configured", func(t *testing.T) {
1010 knownProviders := []catwalk.Provider{
1011 {
1012 ID: "openai",
1013 APIKey: "$MISSING", // will not be included in the config
1014 DefaultLargeModelID: "large-model",
1015 DefaultSmallModelID: "small-model",
1016 Models: []catwalk.Model{
1017 {
1018 ID: "not-large-model",
1019 DefaultMaxTokens: 1000,
1020 },
1021 {
1022 ID: "small-model",
1023 DefaultMaxTokens: 500,
1024 },
1025 },
1026 },
1027 }
1028
1029 cfg := &Config{
1030 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1031 "custom": {
1032 APIKey: "test-key",
1033 BaseURL: "https://api.custom.com/v1",
1034 Models: []catwalk.Model{},
1035 },
1036 }),
1037 }
1038 cfg.setDefaults("/tmp", "")
1039 env := env.NewFromMap(map[string]string{})
1040 resolver := NewEnvironmentVariableResolver(env)
1041 err := cfg.configureProviders(env, resolver, knownProviders)
1042 require.NoError(t, err)
1043 _, _, err = cfg.defaultModelSelection(knownProviders)
1044 require.Error(t, err)
1045 })
1046 t.Run("should use the default provider first", func(t *testing.T) {
1047 knownProviders := []catwalk.Provider{
1048 {
1049 ID: "openai",
1050 APIKey: "set",
1051 DefaultLargeModelID: "large-model",
1052 DefaultSmallModelID: "small-model",
1053 Models: []catwalk.Model{
1054 {
1055 ID: "large-model",
1056 DefaultMaxTokens: 1000,
1057 },
1058 {
1059 ID: "small-model",
1060 DefaultMaxTokens: 500,
1061 },
1062 },
1063 },
1064 }
1065
1066 cfg := &Config{
1067 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1068 "custom": {
1069 APIKey: "test-key",
1070 BaseURL: "https://api.custom.com/v1",
1071 Models: []catwalk.Model{
1072 {
1073 ID: "large-model",
1074 DefaultMaxTokens: 1000,
1075 },
1076 },
1077 },
1078 }),
1079 }
1080 cfg.setDefaults("/tmp", "")
1081 env := env.NewFromMap(map[string]string{})
1082 resolver := NewEnvironmentVariableResolver(env)
1083 err := cfg.configureProviders(env, resolver, knownProviders)
1084 require.NoError(t, err)
1085 large, small, err := cfg.defaultModelSelection(knownProviders)
1086 require.NoError(t, err)
1087 require.Equal(t, "large-model", large.Model)
1088 require.Equal(t, "openai", large.Provider)
1089 require.Equal(t, int64(1000), large.MaxTokens)
1090 require.Equal(t, "small-model", small.Model)
1091 require.Equal(t, "openai", small.Provider)
1092 require.Equal(t, int64(500), small.MaxTokens)
1093 })
1094}
1095
1096func TestConfig_configureSelectedModels(t *testing.T) {
1097 t.Run("should override defaults", func(t *testing.T) {
1098 knownProviders := []catwalk.Provider{
1099 {
1100 ID: "openai",
1101 APIKey: "abc",
1102 DefaultLargeModelID: "large-model",
1103 DefaultSmallModelID: "small-model",
1104 Models: []catwalk.Model{
1105 {
1106 ID: "larger-model",
1107 DefaultMaxTokens: 2000,
1108 },
1109 {
1110 ID: "large-model",
1111 DefaultMaxTokens: 1000,
1112 },
1113 {
1114 ID: "small-model",
1115 DefaultMaxTokens: 500,
1116 },
1117 },
1118 },
1119 }
1120
1121 cfg := &Config{
1122 Models: map[SelectedModelType]SelectedModel{
1123 "large": {
1124 Model: "larger-model",
1125 },
1126 },
1127 }
1128 cfg.setDefaults("/tmp", "")
1129 env := env.NewFromMap(map[string]string{})
1130 resolver := NewEnvironmentVariableResolver(env)
1131 err := cfg.configureProviders(env, resolver, knownProviders)
1132 require.NoError(t, err)
1133
1134 err = cfg.configureSelectedModels(knownProviders)
1135 require.NoError(t, err)
1136 large := cfg.Models[SelectedModelTypeLarge]
1137 small := cfg.Models[SelectedModelTypeSmall]
1138 require.Equal(t, "larger-model", large.Model)
1139 require.Equal(t, "openai", large.Provider)
1140 require.Equal(t, int64(2000), large.MaxTokens)
1141 require.Equal(t, "small-model", small.Model)
1142 require.Equal(t, "openai", small.Provider)
1143 require.Equal(t, int64(500), small.MaxTokens)
1144 })
1145 t.Run("should be possible to use multiple providers", func(t *testing.T) {
1146 knownProviders := []catwalk.Provider{
1147 {
1148 ID: "openai",
1149 APIKey: "abc",
1150 DefaultLargeModelID: "large-model",
1151 DefaultSmallModelID: "small-model",
1152 Models: []catwalk.Model{
1153 {
1154 ID: "large-model",
1155 DefaultMaxTokens: 1000,
1156 },
1157 {
1158 ID: "small-model",
1159 DefaultMaxTokens: 500,
1160 },
1161 },
1162 },
1163 {
1164 ID: "anthropic",
1165 APIKey: "abc",
1166 DefaultLargeModelID: "a-large-model",
1167 DefaultSmallModelID: "a-small-model",
1168 Models: []catwalk.Model{
1169 {
1170 ID: "a-large-model",
1171 DefaultMaxTokens: 1000,
1172 },
1173 {
1174 ID: "a-small-model",
1175 DefaultMaxTokens: 200,
1176 },
1177 },
1178 },
1179 }
1180
1181 cfg := &Config{
1182 Models: map[SelectedModelType]SelectedModel{
1183 "small": {
1184 Model: "a-small-model",
1185 Provider: "anthropic",
1186 MaxTokens: 300,
1187 },
1188 },
1189 }
1190 cfg.setDefaults("/tmp", "")
1191 env := env.NewFromMap(map[string]string{})
1192 resolver := NewEnvironmentVariableResolver(env)
1193 err := cfg.configureProviders(env, resolver, knownProviders)
1194 require.NoError(t, err)
1195
1196 err = cfg.configureSelectedModels(knownProviders)
1197 require.NoError(t, err)
1198 large := cfg.Models[SelectedModelTypeLarge]
1199 small := cfg.Models[SelectedModelTypeSmall]
1200 require.Equal(t, "large-model", large.Model)
1201 require.Equal(t, "openai", large.Provider)
1202 require.Equal(t, int64(1000), large.MaxTokens)
1203 require.Equal(t, "a-small-model", small.Model)
1204 require.Equal(t, "anthropic", small.Provider)
1205 require.Equal(t, int64(300), small.MaxTokens)
1206 })
1207
1208 t.Run("should override the max tokens only", func(t *testing.T) {
1209 knownProviders := []catwalk.Provider{
1210 {
1211 ID: "openai",
1212 APIKey: "abc",
1213 DefaultLargeModelID: "large-model",
1214 DefaultSmallModelID: "small-model",
1215 Models: []catwalk.Model{
1216 {
1217 ID: "large-model",
1218 DefaultMaxTokens: 1000,
1219 },
1220 {
1221 ID: "small-model",
1222 DefaultMaxTokens: 500,
1223 },
1224 },
1225 },
1226 }
1227
1228 cfg := &Config{
1229 Models: map[SelectedModelType]SelectedModel{
1230 "large": {
1231 MaxTokens: 100,
1232 },
1233 },
1234 }
1235 cfg.setDefaults("/tmp", "")
1236 env := env.NewFromMap(map[string]string{})
1237 resolver := NewEnvironmentVariableResolver(env)
1238 err := cfg.configureProviders(env, resolver, knownProviders)
1239 require.NoError(t, err)
1240
1241 err = cfg.configureSelectedModels(knownProviders)
1242 require.NoError(t, err)
1243 large := cfg.Models[SelectedModelTypeLarge]
1244 require.Equal(t, "large-model", large.Model)
1245 require.Equal(t, "openai", large.Provider)
1246 require.Equal(t, int64(100), large.MaxTokens)
1247 })
1248}