package config

import (
	"io"
	"log/slog"
	"os"
	"path/filepath"
	"testing"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/env"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestMain(m *testing.M) {
	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))

	exitVal := m.Run()
	os.Exit(exitVal)
}

func TestConfig_setDefaults(t *testing.T) {
	cfg := &Config{}

	cfg.setDefaults("/tmp", "")

	require.NotNil(t, cfg.Options)
	require.NotNil(t, cfg.Options.TUI)
	require.NotNil(t, cfg.Options.ContextPaths)
	require.NotNil(t, cfg.Providers)
	require.NotNil(t, cfg.Models)
	require.NotNil(t, cfg.LSP)
	require.NotNil(t, cfg.MCP)
	require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
	for _, path := range defaultContextPaths {
		require.Contains(t, cfg.Options.ContextPaths, path)
	}
	require.Equal(t, "/tmp", cfg.workingDir)
}

func TestConfig_configureProviders(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	// The API key should be kept as the placeholder, not the resolved value.
	pc, _ := cfg.Providers.Get("openai")
	require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
}

func TestConfig_configureProvidersWithOverride(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: csync.NewMap[string, ProviderConfig](),
	}
	cfg.Providers.Set("openai", ProviderConfig{
		APIKey:  "xyz",
		BaseURL: "https://api.openai.com/v2",
		Models: []catwalk.Model{
			{
				ID:   "test-model",
				Name: "Updated",
			},
			{
				ID: "another-model",
			},
		},
	})
	cfg.setDefaults("/tmp", "")

	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	// The user-configured values should take precedence over the known provider's defaults.
	pc, _ := cfg.Providers.Get("openai")
	require.Equal(t, "xyz", pc.APIKey)
	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
	require.Len(t, pc.Models, 2)
	require.Equal(t, "Updated", pc.Models[0].Name)
}

func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: csync.NewMapFrom(map[string]ProviderConfig{
			"custom": {
				APIKey:  "xyz",
				BaseURL: "https://api.someendpoint.com/v2",
				Models: []catwalk.Model{
					{
						ID: "test-model",
					},
				},
			},
		}),
	}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	// Should be two: the custom provider plus openai, which is enabled via the env variable.
	require.Equal(t, 2, cfg.Providers.Len())

	// The custom provider should keep its configured API key.
	pc, _ := cfg.Providers.Get("custom")
	require.Equal(t, "xyz", pc.APIKey)
	// Make sure the ID is set correctly.
	require.Equal(t, "custom", pc.ID)
	require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
	require.Len(t, pc.Models, 1)

	_, ok := cfg.Providers.Get("openai")
	require.True(t, ok, "OpenAI provider should still be present")
}

func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderBedrock,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"AWS_ACCESS_KEY_ID":     "test-key-id",
		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	bedrockProvider, ok := cfg.Providers.Get("bedrock")
	require.True(t, ok, "Bedrock provider should be present")
	require.Len(t, bedrockProvider.Models, 1)
	require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
}

func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderBedrock,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	// Provider should not be configured without credentials
	require.Equal(t, 0, cfg.Providers.Len())
}

func TestConfig_configureProvidersBedrockWithUnsupportedModel(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderBedrock,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "some-random-model",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"AWS_ACCESS_KEY_ID":     "test-key-id",
		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	// An unsupported Bedrock model should cause configuration to fail.
	require.Error(t, err)
}

func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"VERTEXAI_PROJECT":  "test-project",
		"VERTEXAI_LOCATION": "us-central1",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	vertexProvider, ok := cfg.Providers.Get("vertexai")
	require.True(t, ok, "VertexAI provider should be present")
	require.Len(t, vertexProvider.Models, 1)
	require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
	require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
	require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
}

func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"GOOGLE_GENAI_USE_VERTEXAI": "false",
		"GOOGLE_CLOUD_PROJECT":      "test-project",
		"GOOGLE_CLOUD_LOCATION":     "us-central1",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	// Provider should not be configured without proper credentials
	require.Equal(t, 0, cfg.Providers.Len())
}

func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"GOOGLE_GENAI_USE_VERTEXAI": "true",
		"GOOGLE_CLOUD_LOCATION":     "us-central1",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	// Provider should not be configured without a project
	require.Equal(t, 0, cfg.Providers.Len())
}

func TestConfig_configureProvidersSetProviderID(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp", "")
	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	// Provider ID should be set
	pc, _ := cfg.Providers.Get("openai")
	require.Equal(t, "openai", pc.ID)
}

func TestConfig_EnabledProviders(t *testing.T) {
	t.Run("all providers enabled", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: false,
				},
				"anthropic": {
					ID:      "anthropic",
					APIKey:  "key2",
					Disable: false,
				},
			}),
		}

		enabled := cfg.EnabledProviders()
		require.Len(t, enabled, 2)
	})

	t.Run("some providers disabled", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: false,
				},
				"anthropic": {
					ID:      "anthropic",
					APIKey:  "key2",
					Disable: true,
				},
			}),
		}

		enabled := cfg.EnabledProviders()
		require.Len(t, enabled, 1)
		require.Equal(t, "openai", enabled[0].ID)
	})

	t.Run("empty providers map", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMap[string, ProviderConfig](),
		}

		enabled := cfg.EnabledProviders()
		require.Len(t, enabled, 0)
	})
}

func TestConfig_IsConfigured(t *testing.T) {
	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: false,
				},
			}),
		}

		require.True(t, cfg.IsConfigured())
	})

	t.Run("returns false when no providers are configured", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMap[string, ProviderConfig](),
		}

		require.False(t, cfg.IsConfigured())
	})

	t.Run("returns false when all providers are disabled", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: true,
				},
				"anthropic": {
					ID:      "anthropic",
					APIKey:  "key2",
					Disable: true,
				},
			}),
		}

		require.False(t, cfg.IsConfigured())
	})
}

func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
	cfg := &Config{
		Options: &Options{
			DisabledTools: []string{},
		},
	}

	cfg.SetupAgents()
	coderAgent, ok := cfg.Agents["coder"]
	require.True(t, ok)
	assert.Equal(t, allToolNames(), coderAgent.AllowedTools)

	taskAgent, ok := cfg.Agents["task"]
	require.True(t, ok)
	assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
}

func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
	cfg := &Config{
		Options: &Options{
			DisabledTools: []string{
				"edit",
				"download",
				"grep",
			},
		},
	}

	cfg.SetupAgents()
	coderAgent, ok := cfg.Agents["coder"]
	require.True(t, ok)
	assert.Equal(t, []string{"agent", "bash", "multiedit", "fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)

	taskAgent, ok := cfg.Agents["task"]
	require.True(t, ok)
	assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
}

func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
	cfg := &Config{
		Options: &Options{
			DisabledTools: []string{
				"glob",
				"grep",
				"ls",
				"sourcegraph",
				"view",
			},
		},
	}

	cfg.SetupAgents()
	coderAgent, ok := cfg.Agents["coder"]
	require.True(t, ok)
	assert.Equal(t, []string{"agent", "bash", "download", "edit", "multiedit", "fetch", "write"}, coderAgent.AllowedTools)

	taskAgent, ok := cfg.Agents["task"]
	require.True(t, ok)
	assert.Equal(t, []string{}, taskAgent.AllowedTools)
}

func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: csync.NewMapFrom(map[string]ProviderConfig{
			"openai": {
				Disable: true,
			},
		}),
	}
	cfg.setDefaults("/tmp", "")

	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	require.NoError(t, err)

	require.Equal(t, 1, cfg.Providers.Len())
	prov, exists := cfg.Providers.Get("openai")
	require.True(t, exists)
	require.True(t, prov.Disable)
}

func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					BaseURL: "https://api.custom.com/v1",
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
				"openai": {
					APIKey: "$MISSING",
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 1, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.True(t, exists)
	})

	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey: "test-key",
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.False(t, exists)
	})

	t.Run("custom provider with no models is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models:  []catwalk.Model{},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.False(t, exists)
	})

	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Type:    "unsupported",
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.False(t, exists)
	})

	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Type:    catwalk.TypeOpenAI,
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 1, cfg.Providers.Len())
		customProvider, exists := cfg.Providers.Get("custom")
		require.True(t, exists)
		require.Equal(t, "custom", customProvider.ID)
		require.Equal(t, "test-key", customProvider.APIKey)
		require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
	})

	t.Run("custom anthropic provider is supported", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom-anthropic": {
					APIKey:  "test-key",
					BaseURL: "https://api.anthropic.com/v1",
					Type:    catwalk.TypeAnthropic,
					Models: []catwalk.Model{{
						ID: "claude-3-sonnet",
					}},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 1, cfg.Providers.Len())
		customProvider, exists := cfg.Providers.Get("custom-anthropic")
		require.True(t, exists)
		require.Equal(t, "custom-anthropic", customProvider.ID)
		require.Equal(t, "test-key", customProvider.APIKey)
		require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
		require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
	})

	t.Run("disabled custom provider is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Type:    catwalk.TypeOpenAI,
					Disable: true,
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.False(t, exists)
	})
}

func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          catwalk.InferenceProviderVertexAI,
				APIKey:      "",
				APIEndpoint: "",
				Models: []catwalk.Model{{
					ID: "gemini-pro",
				}},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"vertexai": {
					BaseURL: "custom-url",
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{
			"GOOGLE_GENAI_USE_VERTEXAI": "false",
		})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("vertexai")
		require.False(t, exists)
	})

	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          catwalk.InferenceProviderBedrock,
				APIKey:      "",
				APIEndpoint: "",
				Models: []catwalk.Model{{
					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
				}},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"bedrock": {
					BaseURL: "custom-url",
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("bedrock")
		require.False(t, exists)
	})

	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          "openai",
				APIKey:      "$MISSING_API_KEY",
				APIEndpoint: "https://api.openai.com/v1",
				Models: []catwalk.Model{{
					ID: "test-model",
				}},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					BaseURL: "custom-url",
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("openai")
		require.False(t, exists)
	})

	t.Run("known provider is still added when its endpoint is missing; the client falls back to its default endpoint", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          "openai",
				APIKey:      "$OPENAI_API_KEY",
				APIEndpoint: "$MISSING_ENDPOINT",
				Models: []catwalk.Model{{
					ID: "test-model",
				}},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					APIKey: "test-key",
				},
			}),
		}
		cfg.setDefaults("/tmp", "")

		env := env.NewFromMap(map[string]string{
			"OPENAI_API_KEY": "test-key",
		})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		require.Equal(t, 1, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("openai")
		require.True(t, exists)
	})
}

func TestConfig_defaultModelSelection(t *testing.T) {
	t.Run("default behavior uses the default models for the given provider", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		large, small, err := cfg.defaultModelSelection(knownProviders)
		require.NoError(t, err)
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(1000), large.MaxTokens)
		require.Equal(t, "small-model", small.Model)
		require.Equal(t, "openai", small.Provider)
		require.Equal(t, int64(500), small.MaxTokens)
	})
	t.Run("should error if no providers are configured", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "$MISSING_KEY",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		_, _, err = cfg.defaultModelSelection(knownProviders)
		require.Error(t, err)
	})
	t.Run("should error if the model is missing", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "not-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)
		_, _, err = cfg.defaultModelSelection(knownProviders)
		require.Error(t, err)
	})

	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "$MISSING", // will not be included in the config
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "not-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models: []catwalk.Model{
						{
							ID:               "model",
							DefaultMaxTokens: 600,
						},
					},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)
		large, small, err := cfg.defaultModelSelection(knownProviders)
		require.NoError(t, err)
		require.Equal(t, "model", large.Model)
		require.Equal(t, "custom", large.Provider)
		require.Equal(t, int64(600), large.MaxTokens)
		require.Equal(t, "model", small.Model)
		require.Equal(t, "custom", small.Provider)
		require.Equal(t, int64(600), small.MaxTokens)
	})

	t.Run("should fail if no model is configured", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "$MISSING", // will not be included in the config
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "not-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models:  []catwalk.Model{},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)
		_, _, err = cfg.defaultModelSelection(knownProviders)
		require.Error(t, err)
	})
	t.Run("should use the default provider first", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "set",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models: []catwalk.Model{
						{
							ID:               "large-model",
							DefaultMaxTokens: 1000,
						},
					},
				},
			}),
		}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)
		large, small, err := cfg.defaultModelSelection(knownProviders)
		require.NoError(t, err)
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(1000), large.MaxTokens)
		require.Equal(t, "small-model", small.Model)
		require.Equal(t, "openai", small.Provider)
		require.Equal(t, int64(500), small.MaxTokens)
	})
}

func TestConfig_configureSelectedModels(t *testing.T) {
	t.Run("should override defaults", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "larger-model",
						DefaultMaxTokens: 2000,
					},
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				"large": {
					Model: "larger-model",
				},
			},
		}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		err = cfg.configureSelectedModels(knownProviders)
		require.NoError(t, err)
		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		require.Equal(t, "larger-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(2000), large.MaxTokens)
		require.Equal(t, "small-model", small.Model)
		require.Equal(t, "openai", small.Provider)
		require.Equal(t, int64(500), small.MaxTokens)
	})
	t.Run("should be possible to use multiple providers", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
			{
				ID:                  "anthropic",
				APIKey:              "abc",
				DefaultLargeModelID: "a-large-model",
				DefaultSmallModelID: "a-small-model",
				Models: []catwalk.Model{
					{
						ID:               "a-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "a-small-model",
						DefaultMaxTokens: 200,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				"small": {
					Model:     "a-small-model",
					Provider:  "anthropic",
					MaxTokens: 300,
				},
			},
		}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		err = cfg.configureSelectedModels(knownProviders)
		require.NoError(t, err)
		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(1000), large.MaxTokens)
		require.Equal(t, "a-small-model", small.Model)
		require.Equal(t, "anthropic", small.Provider)
		require.Equal(t, int64(300), small.MaxTokens)
	})

	t.Run("should override the max tokens only", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				"large": {
					MaxTokens: 100,
				},
			},
		}
		cfg.setDefaults("/tmp", "")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		require.NoError(t, err)

		err = cfg.configureSelectedModels(knownProviders)
		require.NoError(t, err)
		large := cfg.Models[SelectedModelTypeLarge]
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(100), large.MaxTokens)
	})
}