package config

import (
	"io"
	"log/slog"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/env"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestMain(m *testing.M) {
	// Silence log output while the tests run.
	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))

	exitVal := m.Run()
	os.Exit(exitVal)
}

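// TestConfig_LoadFromReaders verifies that later config readers take
// precedence over earlier ones, and that an empty provider entry in a later
// config does not wipe out previously merged values.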
func TestConfig_LoadFromReaders(t *testing.T) {
	data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
	data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
	data3 := strings.NewReader(`{"providers": {"openai": {}}}`)

	loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})

	require.NoError(t, err)
	require.NotNil(t, loadedConfig)
	require.Equal(t, 1, loadedConfig.Providers.Len())
	pc, _ := loadedConfig.Providers.Get("openai")
	require.Equal(t, "key2", pc.APIKey)
	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
}

func TestConfig_setDefaults(t *testing.T) {
	cfg := &Config{}

	cfg.setDefaults("/tmp", "")

	require.NotNil(t, cfg.Options)
	require.NotNil(t, cfg.Options.TUI)
	require.NotNil(t, cfg.Options.ContextPaths)
	require.NotNil(t, cfg.Providers)
	require.NotNil(t, cfg.Models)
	require.NotNil(t, cfg.LSP)
	require.NotNil(t, cfg.MCP)
	require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
	for _, path := range defaultContextPaths {
		require.Contains(t, cfg.Options.ContextPaths, path)
	}
	require.Equal(t, "/tmp", cfg.workingDir)
}

func TestConfig_configureProviders(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	// The API key should remain the environment placeholder, not the resolved value.
	pc, _ := cfg.Providers.Get("openai")
	require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
}

func TestConfig_configureProvidersWithOverride(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: csync.NewMap[string, ProviderConfig](),
	}
	cfg.Providers.Set("openai", ProviderConfig{
		APIKey:  "xyz",
		BaseURL: "https://api.openai.com/v2",
		Models: []catwalk.Model{
			{
				ID:   "test-model",
				Name: "Updated",
			},
			{
				ID: "another-model",
			},
		},
	})
	cfg.setDefaults("/tmp", "")

	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	// The user-configured API key, base URL, and models should override the
	// known provider's defaults.
	pc, _ := cfg.Providers.Get("openai")
	require.Equal(t, "xyz", pc.APIKey)
	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
	require.Len(t, pc.Models, 2)
	require.Equal(t, "Updated", pc.Models[0].Name)
}

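// TestConfig_configureProvidersWithNewProvider verifies that a provider that
// exists only in the user config (not in the known-provider catalog) is kept,
// gets its ID populated from the map key, and that known providers enabled via
// environment variables are configured alongside it.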
func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: csync.NewMapFrom(map[string]ProviderConfig{
			"custom": {
				APIKey:  "xyz",
				BaseURL: "https://api.someendpoint.com/v2",
				Models: []catwalk.Model{
					{
						ID: "test-model",
					},
				},
			},
		}),
	}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	// Should be two: the custom provider plus openai, enabled via the env variable.
	require.Equal(t, 2, cfg.Providers.Len())

	// The custom provider's configured API key should be preserved.
	pc, _ := cfg.Providers.Get("custom")
	require.Equal(t, "xyz", pc.APIKey)
	// Make sure we set the ID correctly.
	require.Equal(t, "custom", pc.ID)
	require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
	require.Len(t, pc.Models, 1)

	_, ok := cfg.Providers.Get("openai")
	require.True(t, ok, "OpenAI provider should still be present")
}

func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderBedrock,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"AWS_ACCESS_KEY_ID":     "test-key-id",
		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	bedrockProvider, ok := cfg.Providers.Get("bedrock")
	require.True(t, ok, "Bedrock provider should be present")
	require.Len(t, bedrockProvider.Models, 1)
	require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
}

func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderBedrock,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	// The provider should not be configured without credentials.
	require.Equal(t, 0, cfg.Providers.Len())
}

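// Bedrock provider configuration should fail outright when the catalog lists
// a model ID that Crush does not recognize.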
func TestConfig_configureProvidersBedrockWithUnsupportedModel(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderBedrock,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "some-random-model",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"AWS_ACCESS_KEY_ID":     "test-key-id",
		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.Error(t, err)
}

func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"VERTEXAI_PROJECT":  "test-project",
		"VERTEXAI_LOCATION": "us-central1",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	vertexProvider, ok := cfg.Providers.Get("vertexai")
	require.True(t, ok, "VertexAI provider should be present")
	require.Len(t, vertexProvider.Models, 1)
	require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
	require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
	require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
}

func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"GOOGLE_GENAI_USE_VERTEXAI": "false",
		"GOOGLE_CLOUD_PROJECT":      "test-project",
		"GOOGLE_CLOUD_LOCATION":     "us-central1",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	// The provider should not be configured without proper credentials.
	require.Equal(t, 0, cfg.Providers.Len())
}

func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"GOOGLE_GENAI_USE_VERTEXAI": "true",
		"GOOGLE_CLOUD_LOCATION":     "us-central1",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	// The provider should not be configured without a project.
	require.Equal(t, 0, cfg.Providers.Len())
}

func TestConfig_configureProvidersSetProviderID(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	// The provider ID should be set.
	pc, _ := cfg.Providers.Get("openai")
	require.Equal(t, "openai", pc.ID)
}

func TestConfig_EnabledProviders(t *testing.T) {
	t.Run("all providers enabled", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: false,
				},
				"anthropic": {
					ID:      "anthropic",
					APIKey:  "key2",
					Disable: false,
				},
			}),
		}

		enabled := cfg.EnabledProviders()
		require.Len(t, enabled, 2)
	})

	t.Run("some providers disabled", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: false,
				},
				"anthropic": {
					ID:      "anthropic",
					APIKey:  "key2",
					Disable: true,
				},
			}),
		}

		enabled := cfg.EnabledProviders()
		require.Len(t, enabled, 1)
		require.Equal(t, "openai", enabled[0].ID)
	})

	t.Run("empty providers map", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMap[string, ProviderConfig](),
		}

		enabled := cfg.EnabledProviders()
		require.Len(t, enabled, 0)
	})
}

func TestConfig_IsConfigured(t *testing.T) {
	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: false,
				},
			}),
		}

		require.True(t, cfg.IsConfigured())
	})

	t.Run("returns false when no providers are configured", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMap[string, ProviderConfig](),
		}

		require.False(t, cfg.IsConfigured())
	})

	t.Run("returns false when all providers are disabled", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: true,
				},
				"anthropic": {
					ID:      "anthropic",
					APIKey:  "key2",
					Disable: true,
				},
			}),
		}

		require.False(t, cfg.IsConfigured())
	})
}

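// SetupAgents should build the coder, task, and fetch agents and filter any
// tools listed in Options.DisabledTools out of each agent's allowed tool set.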
func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
	cfg := &Config{
		Options: &Options{
			DisabledTools: []string{},
		},
	}

	cfg.SetupAgents()
	coderAgent, ok := cfg.Agents[AgentCoder]
	require.True(t, ok)
	assert.Equal(t, allToolNames(), coderAgent.AllowedTools)

	taskAgent, ok := cfg.Agents[AgentTask]
	require.True(t, ok)
	assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)

	fetchAgent, ok := cfg.Agents[AgentFetch]
	require.True(t, ok)
	assert.Equal(t, "fetch", fetchAgent.ID)
	assert.Equal(t, SelectedModelTypeSmall, fetchAgent.Model)
	assert.Equal(t, []string{"web_fetch", "grep", "glob", "view"}, fetchAgent.AllowedTools)
	assert.Empty(t, fetchAgent.ContextPaths)
}

func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
	cfg := &Config{
		Options: &Options{
			DisabledTools: []string{
				"edit",
				"download",
				"grep",
			},
		},
	}

	cfg.SetupAgents()
	coderAgent, ok := cfg.Agents[AgentCoder]
	require.True(t, ok)
	assert.Equal(t, []string{"agent", "bash", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)

	taskAgent, ok := cfg.Agents[AgentTask]
	require.True(t, ok)
	assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
}

func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
	cfg := &Config{
		Options: &Options{
			DisabledTools: []string{
				"glob",
				"grep",
				"ls",
				"sourcegraph",
				"view",
			},
		},
	}

	cfg.SetupAgents()
	coderAgent, ok := cfg.Agents[AgentCoder]
	require.True(t, ok)
	assert.Equal(t, []string{"agent", "bash", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "write"}, coderAgent.AllowedTools)

	taskAgent, ok := cfg.Agents[AgentTask]
	require.True(t, ok)
	assert.Equal(t, []string{}, taskAgent.AllowedTools)
}

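// A known provider that the user explicitly disables should stay in the
// provider map but remain marked as disabled.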
func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: csync.NewMapFrom(map[string]ProviderConfig{
			"openai": {
				Disable: true,
			},
		}),
	}
	cfg.setDefaults("/tmp", "")

	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)

	require.Equal(t, 1, cfg.Providers.Len())
	prov, exists := cfg.Providers.Get("openai")
	require.True(t, exists)
	require.True(t, prov.Disable)
}

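// Custom (non-catalog) providers are validated during configuration: they need
// a base URL, at least one model, and a supported type, and they are dropped
// when disabled. A missing API key is tolerated for custom providers but not
// for known ones.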
func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
	t.Run("custom provider with missing API key is allowed, but a known provider is not", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					BaseURL: "https://api.custom.com/v1",
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
				"openai": {
					APIKey: "$MISSING",
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 1, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.True(t, exists)
	})

	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey: "test-key",
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.False(t, exists)
	})

	t.Run("custom provider with no models is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models:  []catwalk.Model{},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.False(t, exists)
	})

	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Type:    "unsupported",
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.False(t, exists)
	})

	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Type:    catwalk.TypeOpenAI,
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 1, cfg.Providers.Len())
		customProvider, exists := cfg.Providers.Get("custom")
		require.True(t, exists)
		require.Equal(t, "custom", customProvider.ID)
		require.Equal(t, "test-key", customProvider.APIKey)
		require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
	})

	t.Run("custom anthropic provider is supported", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom-anthropic": {
					APIKey:  "test-key",
					BaseURL: "https://api.anthropic.com/v1",
					Type:    catwalk.TypeAnthropic,
					Models: []catwalk.Model{{
						ID: "claude-3-sonnet",
					}},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 1, cfg.Providers.Len())
		customProvider, exists := cfg.Providers.Get("custom-anthropic")
		require.True(t, exists)
		require.Equal(t, "custom-anthropic", customProvider.ID)
		require.Equal(t, "test-key", customProvider.APIKey)
		require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
		require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
	})

	t.Run("disabled custom provider is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Type:    catwalk.TypeOpenAI,
					Disable: true,
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.False(t, exists)
	})
}

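// Providers that already have an entry in the user config should still be
// removed when the credentials they need cannot be resolved; a known provider
// with a missing endpoint is kept, since the client falls back to its default
// endpoint.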
func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          catwalk.InferenceProviderVertexAI,
				APIKey:      "",
				APIEndpoint: "",
				Models: []catwalk.Model{{
					ID: "gemini-pro",
				}},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"vertexai": {
					BaseURL: "custom-url",
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{
			"GOOGLE_GENAI_USE_VERTEXAI": "false",
		})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("vertexai")
		require.False(t, exists)
	})

	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          catwalk.InferenceProviderBedrock,
				APIKey:      "",
				APIEndpoint: "",
				Models: []catwalk.Model{{
					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
				}},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"bedrock": {
					BaseURL: "custom-url",
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("bedrock")
		require.False(t, exists)
	})

	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          "openai",
				APIKey:      "$MISSING_API_KEY",
				APIEndpoint: "https://api.openai.com/v1",
				Models: []catwalk.Model{{
					ID: "test-model",
				}},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					BaseURL: "custom-url",
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("openai")
		require.False(t, exists)
	})

	t.Run("known provider is still added when the endpoint is missing; the client falls back to its default endpoint", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          "openai",
				APIKey:      "$OPENAI_API_KEY",
				APIEndpoint: "$MISSING_ENDPOINT",
				Models: []catwalk.Model{{
					ID: "test-model",
				}},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					APIKey: "test-key",
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{
			"OPENAI_API_KEY": "test-key",
		})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		require.Equal(t, 1, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("openai")
		require.True(t, exists)
	})
}

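// Default model selection should use the catalog's default large and small
// model IDs for a configured provider, fall back to a custom provider's models
// when no known provider is usable, and error out when no usable model exists.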
func TestConfig_defaultModelSelection(t *testing.T) {
	t.Run("default behavior uses the default models for the given provider", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		large, small, err := cfg.defaultModelSelection(knownProviders)
		require.NoError(t, err)
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(1000), large.MaxTokens)
		require.Equal(t, "small-model", small.Model)
		require.Equal(t, "openai", small.Provider)
		require.Equal(t, int64(500), small.MaxTokens)
	})
	t.Run("should error if no providers are configured", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "$MISSING_KEY",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		_, _, err = cfg.defaultModelSelection(knownProviders)
		require.Error(t, err)
	})
	t.Run("should error if the model is missing", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "not-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)
		_, _, err = cfg.defaultModelSelection(knownProviders)
		require.Error(t, err)
	})

	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "$MISSING", // will not be included in the config
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "not-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models: []catwalk.Model{
						{
							ID:               "model",
							DefaultMaxTokens: 600,
						},
					},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)
		large, small, err := cfg.defaultModelSelection(knownProviders)
		require.NoError(t, err)
		require.Equal(t, "model", large.Model)
		require.Equal(t, "custom", large.Provider)
		require.Equal(t, int64(600), large.MaxTokens)
		require.Equal(t, "model", small.Model)
		require.Equal(t, "custom", small.Provider)
		require.Equal(t, int64(600), small.MaxTokens)
	})

	t.Run("should fail if no model is configured", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "$MISSING", // will not be included in the config
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "not-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models:  []catwalk.Model{},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)
		_, _, err = cfg.defaultModelSelection(knownProviders)
		require.Error(t, err)
	})
	t.Run("should use the default provider first", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "set",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models: []catwalk.Model{
						{
							ID:               "large-model",
							DefaultMaxTokens: 1000,
						},
					},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)
		large, small, err := cfg.defaultModelSelection(knownProviders)
		require.NoError(t, err)
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(1000), large.MaxTokens)
		require.Equal(t, "small-model", small.Model)
		require.Equal(t, "openai", small.Provider)
		require.Equal(t, int64(500), small.MaxTokens)
	})
}

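// Explicitly selected models should override the defaults: a different model
// for a type, a different provider per type, or just an overridden max-token
// budget while keeping the default model.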
func TestConfig_configureSelectedModels(t *testing.T) {
	t.Run("should override defaults", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "larger-model",
						DefaultMaxTokens: 2000,
					},
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				"large": {
					Model: "larger-model",
				},
			},
		}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		err = cfg.configureSelectedModels(knownProviders)
		require.NoError(t, err)
		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		require.Equal(t, "larger-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(2000), large.MaxTokens)
		require.Equal(t, "small-model", small.Model)
		require.Equal(t, "openai", small.Provider)
		require.Equal(t, int64(500), small.MaxTokens)
	})
	t.Run("should be possible to use multiple providers", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
			{
				ID:                  "anthropic",
				APIKey:              "abc",
				DefaultLargeModelID: "a-large-model",
				DefaultSmallModelID: "a-small-model",
				Models: []catwalk.Model{
					{
						ID:               "a-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "a-small-model",
						DefaultMaxTokens: 200,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				"small": {
					Model:     "a-small-model",
					Provider:  "anthropic",
					MaxTokens: 300,
				},
			},
		}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		err = cfg.configureSelectedModels(knownProviders)
		require.NoError(t, err)
		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(1000), large.MaxTokens)
		require.Equal(t, "a-small-model", small.Model)
		require.Equal(t, "anthropic", small.Provider)
		require.Equal(t, int64(300), small.MaxTokens)
	})

	t.Run("should override the max tokens only", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				"large": {
					MaxTokens: 100,
				},
			},
		}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		err = cfg.configureSelectedModels(knownProviders)
		require.NoError(t, err)
		large := cfg.Models[SelectedModelTypeLarge]
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(100), large.MaxTokens)
	})
}