1package config
2
3import (
4 "io"
5 "log/slog"
6 "os"
7 "path/filepath"
8 "strings"
9 "testing"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/csync"
13 "github.com/charmbracelet/crush/internal/env"
14 "github.com/stretchr/testify/require"
15)
16
// TestMain replaces the default slog logger with one that writes to
// io.Discard, runs the package's tests, and exits the process with the code
// returned by m.Run.
func TestMain(m *testing.M) {
	discardHandler := slog.NewTextHandler(io.Discard, nil)
	slog.SetDefault(slog.New(discardHandler))

	os.Exit(m.Run())
}
23
24func TestConfig_LoadFromReaders(t *testing.T) {
25 data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
26 data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
27 data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
28
29 loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
30
31 require.NoError(t, err)
32 require.NotNil(t, loadedConfig)
33 require.Equal(t, 1, loadedConfig.Providers.Len())
34 pc, _ := loadedConfig.Providers.Get("openai")
35 require.Equal(t, "key2", pc.APIKey)
36 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
37}
38
39func TestConfig_setDefaults(t *testing.T) {
40 cfg := &Config{}
41
42 cfg.setDefaults("/tmp")
43
44 require.NotNil(t, cfg.Options)
45 require.NotNil(t, cfg.Options.TUI)
46 require.NotNil(t, cfg.Options.ContextPaths)
47 require.NotNil(t, cfg.Providers)
48 require.NotNil(t, cfg.Models)
49 require.NotNil(t, cfg.LSP)
50 require.NotNil(t, cfg.MCP)
51 require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
52 for _, path := range defaultContextPaths {
53 require.Contains(t, cfg.Options.ContextPaths, path)
54 }
55 require.Equal(t, "/tmp", cfg.workingDir)
56}
57
58func TestConfig_configureProviders(t *testing.T) {
59 knownProviders := []catwalk.Provider{
60 {
61 ID: "openai",
62 APIKey: "$OPENAI_API_KEY",
63 APIEndpoint: "https://api.openai.com/v1",
64 Models: []catwalk.Model{{
65 ID: "test-model",
66 }},
67 },
68 }
69
70 cfg := &Config{}
71 cfg.setDefaults("/tmp")
72 env := env.NewFromMap(map[string]string{
73 "OPENAI_API_KEY": "test-key",
74 })
75 resolver := NewEnvironmentVariableResolver(env)
76 err := cfg.configureProviders(env, resolver, knownProviders)
77 require.NoError(t, err)
78 require.Equal(t, 1, cfg.Providers.Len())
79
80 // We want to make sure that we keep the configured API key as a placeholder
81 pc, _ := cfg.Providers.Get("openai")
82 require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
83}
84
85func TestConfig_configureProvidersWithOverride(t *testing.T) {
86 knownProviders := []catwalk.Provider{
87 {
88 ID: "openai",
89 APIKey: "$OPENAI_API_KEY",
90 APIEndpoint: "https://api.openai.com/v1",
91 Models: []catwalk.Model{{
92 ID: "test-model",
93 }},
94 },
95 }
96
97 cfg := &Config{
98 Providers: csync.NewMap[string, ProviderConfig](),
99 }
100 cfg.Providers.Set("openai", ProviderConfig{
101 APIKey: "xyz",
102 BaseURL: "https://api.openai.com/v2",
103 Models: []catwalk.Model{
104 {
105 ID: "test-model",
106 Name: "Updated",
107 },
108 {
109 ID: "another-model",
110 },
111 },
112 })
113 cfg.setDefaults("/tmp")
114
115 env := env.NewFromMap(map[string]string{
116 "OPENAI_API_KEY": "test-key",
117 })
118 resolver := NewEnvironmentVariableResolver(env)
119 err := cfg.configureProviders(env, resolver, knownProviders)
120 require.NoError(t, err)
121 require.Equal(t, 1, cfg.Providers.Len())
122
123 // We want to make sure that we keep the configured API key as a placeholder
124 pc, _ := cfg.Providers.Get("openai")
125 require.Equal(t, "xyz", pc.APIKey)
126 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
127 require.Len(t, pc.Models, 2)
128 require.Equal(t, "Updated", pc.Models[0].Name)
129}
130
131func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
132 knownProviders := []catwalk.Provider{
133 {
134 ID: "openai",
135 APIKey: "$OPENAI_API_KEY",
136 APIEndpoint: "https://api.openai.com/v1",
137 Models: []catwalk.Model{{
138 ID: "test-model",
139 }},
140 },
141 }
142
143 cfg := &Config{
144 Providers: csync.NewMapFrom(map[string]ProviderConfig{
145 "custom": {
146 APIKey: "xyz",
147 BaseURL: "https://api.someendpoint.com/v2",
148 Models: []catwalk.Model{
149 {
150 ID: "test-model",
151 },
152 },
153 },
154 }),
155 }
156 cfg.setDefaults("/tmp")
157 env := env.NewFromMap(map[string]string{
158 "OPENAI_API_KEY": "test-key",
159 })
160 resolver := NewEnvironmentVariableResolver(env)
161 err := cfg.configureProviders(env, resolver, knownProviders)
162 require.NoError(t, err)
163 // Should be to because of the env variable
164 require.Equal(t, cfg.Providers.Len(), 2)
165
166 // We want to make sure that we keep the configured API key as a placeholder
167 pc, _ := cfg.Providers.Get("custom")
168 require.Equal(t, "xyz", pc.APIKey)
169 // Make sure we set the ID correctly
170 require.Equal(t, "custom", pc.ID)
171 require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
172 require.Len(t, pc.Models, 1)
173
174 _, ok := cfg.Providers.Get("openai")
175 require.True(t, ok, "OpenAI provider should still be present")
176}
177
178func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
179 knownProviders := []catwalk.Provider{
180 {
181 ID: catwalk.InferenceProviderBedrock,
182 APIKey: "",
183 APIEndpoint: "",
184 Models: []catwalk.Model{{
185 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
186 }},
187 },
188 }
189
190 cfg := &Config{}
191 cfg.setDefaults("/tmp")
192 env := env.NewFromMap(map[string]string{
193 "AWS_ACCESS_KEY_ID": "test-key-id",
194 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
195 })
196 resolver := NewEnvironmentVariableResolver(env)
197 err := cfg.configureProviders(env, resolver, knownProviders)
198 require.NoError(t, err)
199 require.Equal(t, cfg.Providers.Len(), 1)
200
201 bedrockProvider, ok := cfg.Providers.Get("bedrock")
202 require.True(t, ok, "Bedrock provider should be present")
203 require.Len(t, bedrockProvider.Models, 1)
204 require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
205}
206
207func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
208 knownProviders := []catwalk.Provider{
209 {
210 ID: catwalk.InferenceProviderBedrock,
211 APIKey: "",
212 APIEndpoint: "",
213 Models: []catwalk.Model{{
214 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
215 }},
216 },
217 }
218
219 cfg := &Config{}
220 cfg.setDefaults("/tmp")
221 env := env.NewFromMap(map[string]string{})
222 resolver := NewEnvironmentVariableResolver(env)
223 err := cfg.configureProviders(env, resolver, knownProviders)
224 require.NoError(t, err)
225 // Provider should not be configured without credentials
226 require.Equal(t, cfg.Providers.Len(), 0)
227}
228
229func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
230 knownProviders := []catwalk.Provider{
231 {
232 ID: catwalk.InferenceProviderBedrock,
233 APIKey: "",
234 APIEndpoint: "",
235 Models: []catwalk.Model{{
236 ID: "some-random-model",
237 }},
238 },
239 }
240
241 cfg := &Config{}
242 cfg.setDefaults("/tmp")
243 env := env.NewFromMap(map[string]string{
244 "AWS_ACCESS_KEY_ID": "test-key-id",
245 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
246 })
247 resolver := NewEnvironmentVariableResolver(env)
248 err := cfg.configureProviders(env, resolver, knownProviders)
249 require.Error(t, err)
250}
251
252func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
253 knownProviders := []catwalk.Provider{
254 {
255 ID: catwalk.InferenceProviderVertexAI,
256 APIKey: "",
257 APIEndpoint: "",
258 Models: []catwalk.Model{{
259 ID: "gemini-pro",
260 }},
261 },
262 }
263
264 cfg := &Config{}
265 cfg.setDefaults("/tmp")
266 env := env.NewFromMap(map[string]string{
267 "VERTEXAI_PROJECT": "test-project",
268 "VERTEXAI_LOCATION": "us-central1",
269 })
270 resolver := NewEnvironmentVariableResolver(env)
271 err := cfg.configureProviders(env, resolver, knownProviders)
272 require.NoError(t, err)
273 require.Equal(t, cfg.Providers.Len(), 1)
274
275 vertexProvider, ok := cfg.Providers.Get("vertexai")
276 require.True(t, ok, "VertexAI provider should be present")
277 require.Len(t, vertexProvider.Models, 1)
278 require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
279 require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
280 require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
281}
282
283func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
284 knownProviders := []catwalk.Provider{
285 {
286 ID: catwalk.InferenceProviderVertexAI,
287 APIKey: "",
288 APIEndpoint: "",
289 Models: []catwalk.Model{{
290 ID: "gemini-pro",
291 }},
292 },
293 }
294
295 cfg := &Config{}
296 cfg.setDefaults("/tmp")
297 env := env.NewFromMap(map[string]string{
298 "GOOGLE_GENAI_USE_VERTEXAI": "false",
299 "GOOGLE_CLOUD_PROJECT": "test-project",
300 "GOOGLE_CLOUD_LOCATION": "us-central1",
301 })
302 resolver := NewEnvironmentVariableResolver(env)
303 err := cfg.configureProviders(env, resolver, knownProviders)
304 require.NoError(t, err)
305 // Provider should not be configured without proper credentials
306 require.Equal(t, cfg.Providers.Len(), 0)
307}
308
309func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
310 knownProviders := []catwalk.Provider{
311 {
312 ID: catwalk.InferenceProviderVertexAI,
313 APIKey: "",
314 APIEndpoint: "",
315 Models: []catwalk.Model{{
316 ID: "gemini-pro",
317 }},
318 },
319 }
320
321 cfg := &Config{}
322 cfg.setDefaults("/tmp")
323 env := env.NewFromMap(map[string]string{
324 "GOOGLE_GENAI_USE_VERTEXAI": "true",
325 "GOOGLE_CLOUD_LOCATION": "us-central1",
326 })
327 resolver := NewEnvironmentVariableResolver(env)
328 err := cfg.configureProviders(env, resolver, knownProviders)
329 require.NoError(t, err)
330 // Provider should not be configured without project
331 require.Equal(t, cfg.Providers.Len(), 0)
332}
333
334func TestConfig_configureProvidersSetProviderID(t *testing.T) {
335 knownProviders := []catwalk.Provider{
336 {
337 ID: "openai",
338 APIKey: "$OPENAI_API_KEY",
339 APIEndpoint: "https://api.openai.com/v1",
340 Models: []catwalk.Model{{
341 ID: "test-model",
342 }},
343 },
344 }
345
346 cfg := &Config{}
347 cfg.setDefaults("/tmp")
348 env := env.NewFromMap(map[string]string{
349 "OPENAI_API_KEY": "test-key",
350 })
351 resolver := NewEnvironmentVariableResolver(env)
352 err := cfg.configureProviders(env, resolver, knownProviders)
353 require.NoError(t, err)
354 require.Equal(t, cfg.Providers.Len(), 1)
355
356 // Provider ID should be set
357 pc, _ := cfg.Providers.Get("openai")
358 require.Equal(t, "openai", pc.ID)
359}
360
361func TestConfig_EnabledProviders(t *testing.T) {
362 t.Run("all providers enabled", func(t *testing.T) {
363 cfg := &Config{
364 Providers: csync.NewMapFrom(map[string]ProviderConfig{
365 "openai": {
366 ID: "openai",
367 APIKey: "key1",
368 Disable: false,
369 },
370 "anthropic": {
371 ID: "anthropic",
372 APIKey: "key2",
373 Disable: false,
374 },
375 }),
376 }
377
378 enabled := cfg.EnabledProviders()
379 require.Len(t, enabled, 2)
380 })
381
382 t.Run("some providers disabled", func(t *testing.T) {
383 cfg := &Config{
384 Providers: csync.NewMapFrom(map[string]ProviderConfig{
385 "openai": {
386 ID: "openai",
387 APIKey: "key1",
388 Disable: false,
389 },
390 "anthropic": {
391 ID: "anthropic",
392 APIKey: "key2",
393 Disable: true,
394 },
395 }),
396 }
397
398 enabled := cfg.EnabledProviders()
399 require.Len(t, enabled, 1)
400 require.Equal(t, "openai", enabled[0].ID)
401 })
402
403 t.Run("empty providers map", func(t *testing.T) {
404 cfg := &Config{
405 Providers: csync.NewMap[string, ProviderConfig](),
406 }
407
408 enabled := cfg.EnabledProviders()
409 require.Len(t, enabled, 0)
410 })
411}
412
413func TestConfig_IsConfigured(t *testing.T) {
414 t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
415 cfg := &Config{
416 Providers: csync.NewMapFrom(map[string]ProviderConfig{
417 "openai": {
418 ID: "openai",
419 APIKey: "key1",
420 Disable: false,
421 },
422 }),
423 }
424
425 require.True(t, cfg.IsConfigured())
426 })
427
428 t.Run("returns false when no providers are configured", func(t *testing.T) {
429 cfg := &Config{
430 Providers: csync.NewMap[string, ProviderConfig](),
431 }
432
433 require.False(t, cfg.IsConfigured())
434 })
435
436 t.Run("returns false when all providers are disabled", func(t *testing.T) {
437 cfg := &Config{
438 Providers: csync.NewMapFrom(map[string]ProviderConfig{
439 "openai": {
440 ID: "openai",
441 APIKey: "key1",
442 Disable: true,
443 },
444 "anthropic": {
445 ID: "anthropic",
446 APIKey: "key2",
447 Disable: true,
448 },
449 }),
450 }
451
452 require.False(t, cfg.IsConfigured())
453 })
454}
455
456func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
457 knownProviders := []catwalk.Provider{
458 {
459 ID: "openai",
460 APIKey: "$OPENAI_API_KEY",
461 APIEndpoint: "https://api.openai.com/v1",
462 Models: []catwalk.Model{{
463 ID: "test-model",
464 }},
465 },
466 }
467
468 cfg := &Config{
469 Providers: csync.NewMapFrom(map[string]ProviderConfig{
470 "openai": {
471 Disable: true,
472 },
473 }),
474 }
475 cfg.setDefaults("/tmp")
476
477 env := env.NewFromMap(map[string]string{
478 "OPENAI_API_KEY": "test-key",
479 })
480 resolver := NewEnvironmentVariableResolver(env)
481 err := cfg.configureProviders(env, resolver, knownProviders)
482 require.NoError(t, err)
483
484 // Provider should be removed from config when disabled
485 require.Equal(t, cfg.Providers.Len(), 0)
486 _, exists := cfg.Providers.Get("openai")
487 require.False(t, exists)
488}
489
490func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
491 t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
492 cfg := &Config{
493 Providers: csync.NewMapFrom(map[string]ProviderConfig{
494 "custom": {
495 BaseURL: "https://api.custom.com/v1",
496 Models: []catwalk.Model{{
497 ID: "test-model",
498 }},
499 },
500 "openai": {
501 APIKey: "$MISSING",
502 },
503 }),
504 }
505 cfg.setDefaults("/tmp")
506
507 env := env.NewFromMap(map[string]string{})
508 resolver := NewEnvironmentVariableResolver(env)
509 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
510 require.NoError(t, err)
511
512 require.Equal(t, cfg.Providers.Len(), 1)
513 _, exists := cfg.Providers.Get("custom")
514 require.True(t, exists)
515 })
516
517 t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
518 cfg := &Config{
519 Providers: csync.NewMapFrom(map[string]ProviderConfig{
520 "custom": {
521 APIKey: "test-key",
522 Models: []catwalk.Model{{
523 ID: "test-model",
524 }},
525 },
526 }),
527 }
528 cfg.setDefaults("/tmp")
529
530 env := env.NewFromMap(map[string]string{})
531 resolver := NewEnvironmentVariableResolver(env)
532 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
533 require.NoError(t, err)
534
535 require.Equal(t, cfg.Providers.Len(), 0)
536 _, exists := cfg.Providers.Get("custom")
537 require.False(t, exists)
538 })
539
540 t.Run("custom provider with no models is removed", func(t *testing.T) {
541 cfg := &Config{
542 Providers: csync.NewMapFrom(map[string]ProviderConfig{
543 "custom": {
544 APIKey: "test-key",
545 BaseURL: "https://api.custom.com/v1",
546 Models: []catwalk.Model{},
547 },
548 }),
549 }
550 cfg.setDefaults("/tmp")
551
552 env := env.NewFromMap(map[string]string{})
553 resolver := NewEnvironmentVariableResolver(env)
554 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
555 require.NoError(t, err)
556
557 require.Equal(t, cfg.Providers.Len(), 0)
558 _, exists := cfg.Providers.Get("custom")
559 require.False(t, exists)
560 })
561
562 t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
563 cfg := &Config{
564 Providers: csync.NewMapFrom(map[string]ProviderConfig{
565 "custom": {
566 APIKey: "test-key",
567 BaseURL: "https://api.custom.com/v1",
568 Type: "unsupported",
569 Models: []catwalk.Model{{
570 ID: "test-model",
571 }},
572 },
573 }),
574 }
575 cfg.setDefaults("/tmp")
576
577 env := env.NewFromMap(map[string]string{})
578 resolver := NewEnvironmentVariableResolver(env)
579 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
580 require.NoError(t, err)
581
582 require.Equal(t, cfg.Providers.Len(), 0)
583 _, exists := cfg.Providers.Get("custom")
584 require.False(t, exists)
585 })
586
587 t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
588 cfg := &Config{
589 Providers: csync.NewMapFrom(map[string]ProviderConfig{
590 "custom": {
591 APIKey: "test-key",
592 BaseURL: "https://api.custom.com/v1",
593 Type: catwalk.TypeOpenAI,
594 Models: []catwalk.Model{{
595 ID: "test-model",
596 }},
597 },
598 }),
599 }
600 cfg.setDefaults("/tmp")
601
602 env := env.NewFromMap(map[string]string{})
603 resolver := NewEnvironmentVariableResolver(env)
604 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
605 require.NoError(t, err)
606
607 require.Equal(t, cfg.Providers.Len(), 1)
608 customProvider, exists := cfg.Providers.Get("custom")
609 require.True(t, exists)
610 require.Equal(t, "custom", customProvider.ID)
611 require.Equal(t, "test-key", customProvider.APIKey)
612 require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
613 })
614
615 t.Run("custom anthropic provider is supported", func(t *testing.T) {
616 cfg := &Config{
617 Providers: csync.NewMapFrom(map[string]ProviderConfig{
618 "custom-anthropic": {
619 APIKey: "test-key",
620 BaseURL: "https://api.anthropic.com/v1",
621 Type: catwalk.TypeAnthropic,
622 Models: []catwalk.Model{{
623 ID: "claude-3-sonnet",
624 }},
625 },
626 }),
627 }
628 cfg.setDefaults("/tmp")
629
630 env := env.NewFromMap(map[string]string{})
631 resolver := NewEnvironmentVariableResolver(env)
632 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
633 require.NoError(t, err)
634
635 require.Equal(t, cfg.Providers.Len(), 1)
636 customProvider, exists := cfg.Providers.Get("custom-anthropic")
637 require.True(t, exists)
638 require.Equal(t, "custom-anthropic", customProvider.ID)
639 require.Equal(t, "test-key", customProvider.APIKey)
640 require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
641 require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
642 })
643
644 t.Run("disabled custom provider is removed", func(t *testing.T) {
645 cfg := &Config{
646 Providers: csync.NewMapFrom(map[string]ProviderConfig{
647 "custom": {
648 APIKey: "test-key",
649 BaseURL: "https://api.custom.com/v1",
650 Type: catwalk.TypeOpenAI,
651 Disable: true,
652 Models: []catwalk.Model{{
653 ID: "test-model",
654 }},
655 },
656 }),
657 }
658 cfg.setDefaults("/tmp")
659
660 env := env.NewFromMap(map[string]string{})
661 resolver := NewEnvironmentVariableResolver(env)
662 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
663 require.NoError(t, err)
664
665 require.Equal(t, cfg.Providers.Len(), 0)
666 _, exists := cfg.Providers.Get("custom")
667 require.False(t, exists)
668 })
669}
670
671func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
672 t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
673 knownProviders := []catwalk.Provider{
674 {
675 ID: catwalk.InferenceProviderVertexAI,
676 APIKey: "",
677 APIEndpoint: "",
678 Models: []catwalk.Model{{
679 ID: "gemini-pro",
680 }},
681 },
682 }
683
684 cfg := &Config{
685 Providers: csync.NewMapFrom(map[string]ProviderConfig{
686 "vertexai": {
687 BaseURL: "custom-url",
688 },
689 }),
690 }
691 cfg.setDefaults("/tmp")
692
693 env := env.NewFromMap(map[string]string{
694 "GOOGLE_GENAI_USE_VERTEXAI": "false",
695 })
696 resolver := NewEnvironmentVariableResolver(env)
697 err := cfg.configureProviders(env, resolver, knownProviders)
698 require.NoError(t, err)
699
700 require.Equal(t, cfg.Providers.Len(), 0)
701 _, exists := cfg.Providers.Get("vertexai")
702 require.False(t, exists)
703 })
704
705 t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
706 knownProviders := []catwalk.Provider{
707 {
708 ID: catwalk.InferenceProviderBedrock,
709 APIKey: "",
710 APIEndpoint: "",
711 Models: []catwalk.Model{{
712 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
713 }},
714 },
715 }
716
717 cfg := &Config{
718 Providers: csync.NewMapFrom(map[string]ProviderConfig{
719 "bedrock": {
720 BaseURL: "custom-url",
721 },
722 }),
723 }
724 cfg.setDefaults("/tmp")
725
726 env := env.NewFromMap(map[string]string{})
727 resolver := NewEnvironmentVariableResolver(env)
728 err := cfg.configureProviders(env, resolver, knownProviders)
729 require.NoError(t, err)
730
731 require.Equal(t, cfg.Providers.Len(), 0)
732 _, exists := cfg.Providers.Get("bedrock")
733 require.False(t, exists)
734 })
735
736 t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
737 knownProviders := []catwalk.Provider{
738 {
739 ID: "openai",
740 APIKey: "$MISSING_API_KEY",
741 APIEndpoint: "https://api.openai.com/v1",
742 Models: []catwalk.Model{{
743 ID: "test-model",
744 }},
745 },
746 }
747
748 cfg := &Config{
749 Providers: csync.NewMapFrom(map[string]ProviderConfig{
750 "openai": {
751 BaseURL: "custom-url",
752 },
753 }),
754 }
755 cfg.setDefaults("/tmp")
756
757 env := env.NewFromMap(map[string]string{})
758 resolver := NewEnvironmentVariableResolver(env)
759 err := cfg.configureProviders(env, resolver, knownProviders)
760 require.NoError(t, err)
761
762 require.Equal(t, cfg.Providers.Len(), 0)
763 _, exists := cfg.Providers.Get("openai")
764 require.False(t, exists)
765 })
766
767 t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
768 knownProviders := []catwalk.Provider{
769 {
770 ID: "openai",
771 APIKey: "$OPENAI_API_KEY",
772 APIEndpoint: "$MISSING_ENDPOINT",
773 Models: []catwalk.Model{{
774 ID: "test-model",
775 }},
776 },
777 }
778
779 cfg := &Config{
780 Providers: csync.NewMapFrom(map[string]ProviderConfig{
781 "openai": {
782 APIKey: "test-key",
783 },
784 }),
785 }
786 cfg.setDefaults("/tmp")
787
788 env := env.NewFromMap(map[string]string{
789 "OPENAI_API_KEY": "test-key",
790 })
791 resolver := NewEnvironmentVariableResolver(env)
792 err := cfg.configureProviders(env, resolver, knownProviders)
793 require.NoError(t, err)
794
795 require.Equal(t, cfg.Providers.Len(), 1)
796 _, exists := cfg.Providers.Get("openai")
797 require.True(t, exists)
798 })
799}
800
801func TestConfig_defaultModelSelection(t *testing.T) {
802 t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
803 knownProviders := []catwalk.Provider{
804 {
805 ID: "openai",
806 APIKey: "abc",
807 DefaultLargeModelID: "large-model",
808 DefaultSmallModelID: "small-model",
809 Models: []catwalk.Model{
810 {
811 ID: "large-model",
812 DefaultMaxTokens: 1000,
813 },
814 {
815 ID: "small-model",
816 DefaultMaxTokens: 500,
817 },
818 },
819 },
820 }
821
822 cfg := &Config{}
823 cfg.setDefaults("/tmp")
824 env := env.NewFromMap(map[string]string{})
825 resolver := NewEnvironmentVariableResolver(env)
826 err := cfg.configureProviders(env, resolver, knownProviders)
827 require.NoError(t, err)
828
829 large, small, err := cfg.defaultModelSelection(knownProviders)
830 require.NoError(t, err)
831 require.Equal(t, "large-model", large.Model)
832 require.Equal(t, "openai", large.Provider)
833 require.Equal(t, int64(1000), large.MaxTokens)
834 require.Equal(t, "small-model", small.Model)
835 require.Equal(t, "openai", small.Provider)
836 require.Equal(t, int64(500), small.MaxTokens)
837 })
838 t.Run("should error if no providers configured", func(t *testing.T) {
839 knownProviders := []catwalk.Provider{
840 {
841 ID: "openai",
842 APIKey: "$MISSING_KEY",
843 DefaultLargeModelID: "large-model",
844 DefaultSmallModelID: "small-model",
845 Models: []catwalk.Model{
846 {
847 ID: "large-model",
848 DefaultMaxTokens: 1000,
849 },
850 {
851 ID: "small-model",
852 DefaultMaxTokens: 500,
853 },
854 },
855 },
856 }
857
858 cfg := &Config{}
859 cfg.setDefaults("/tmp")
860 env := env.NewFromMap(map[string]string{})
861 resolver := NewEnvironmentVariableResolver(env)
862 err := cfg.configureProviders(env, resolver, knownProviders)
863 require.NoError(t, err)
864
865 _, _, err = cfg.defaultModelSelection(knownProviders)
866 require.Error(t, err)
867 })
868 t.Run("should error if model is missing", func(t *testing.T) {
869 knownProviders := []catwalk.Provider{
870 {
871 ID: "openai",
872 APIKey: "abc",
873 DefaultLargeModelID: "large-model",
874 DefaultSmallModelID: "small-model",
875 Models: []catwalk.Model{
876 {
877 ID: "not-large-model",
878 DefaultMaxTokens: 1000,
879 },
880 {
881 ID: "small-model",
882 DefaultMaxTokens: 500,
883 },
884 },
885 },
886 }
887
888 cfg := &Config{}
889 cfg.setDefaults("/tmp")
890 env := env.NewFromMap(map[string]string{})
891 resolver := NewEnvironmentVariableResolver(env)
892 err := cfg.configureProviders(env, resolver, knownProviders)
893 require.NoError(t, err)
894 _, _, err = cfg.defaultModelSelection(knownProviders)
895 require.Error(t, err)
896 })
897
898 t.Run("should configure the default models with a custom provider", func(t *testing.T) {
899 knownProviders := []catwalk.Provider{
900 {
901 ID: "openai",
902 APIKey: "$MISSING", // will not be included in the config
903 DefaultLargeModelID: "large-model",
904 DefaultSmallModelID: "small-model",
905 Models: []catwalk.Model{
906 {
907 ID: "not-large-model",
908 DefaultMaxTokens: 1000,
909 },
910 {
911 ID: "small-model",
912 DefaultMaxTokens: 500,
913 },
914 },
915 },
916 }
917
918 cfg := &Config{
919 Providers: csync.NewMapFrom(map[string]ProviderConfig{
920 "custom": {
921 APIKey: "test-key",
922 BaseURL: "https://api.custom.com/v1",
923 Models: []catwalk.Model{
924 {
925 ID: "model",
926 DefaultMaxTokens: 600,
927 },
928 },
929 },
930 }),
931 }
932 cfg.setDefaults("/tmp")
933 env := env.NewFromMap(map[string]string{})
934 resolver := NewEnvironmentVariableResolver(env)
935 err := cfg.configureProviders(env, resolver, knownProviders)
936 require.NoError(t, err)
937 large, small, err := cfg.defaultModelSelection(knownProviders)
938 require.NoError(t, err)
939 require.Equal(t, "model", large.Model)
940 require.Equal(t, "custom", large.Provider)
941 require.Equal(t, int64(600), large.MaxTokens)
942 require.Equal(t, "model", small.Model)
943 require.Equal(t, "custom", small.Provider)
944 require.Equal(t, int64(600), small.MaxTokens)
945 })
946
947 t.Run("should fail if no model configured", func(t *testing.T) {
948 knownProviders := []catwalk.Provider{
949 {
950 ID: "openai",
951 APIKey: "$MISSING", // will not be included in the config
952 DefaultLargeModelID: "large-model",
953 DefaultSmallModelID: "small-model",
954 Models: []catwalk.Model{
955 {
956 ID: "not-large-model",
957 DefaultMaxTokens: 1000,
958 },
959 {
960 ID: "small-model",
961 DefaultMaxTokens: 500,
962 },
963 },
964 },
965 }
966
967 cfg := &Config{
968 Providers: csync.NewMapFrom(map[string]ProviderConfig{
969 "custom": {
970 APIKey: "test-key",
971 BaseURL: "https://api.custom.com/v1",
972 Models: []catwalk.Model{},
973 },
974 }),
975 }
976 cfg.setDefaults("/tmp")
977 env := env.NewFromMap(map[string]string{})
978 resolver := NewEnvironmentVariableResolver(env)
979 err := cfg.configureProviders(env, resolver, knownProviders)
980 require.NoError(t, err)
981 _, _, err = cfg.defaultModelSelection(knownProviders)
982 require.Error(t, err)
983 })
984 t.Run("should use the default provider first", func(t *testing.T) {
985 knownProviders := []catwalk.Provider{
986 {
987 ID: "openai",
988 APIKey: "set",
989 DefaultLargeModelID: "large-model",
990 DefaultSmallModelID: "small-model",
991 Models: []catwalk.Model{
992 {
993 ID: "large-model",
994 DefaultMaxTokens: 1000,
995 },
996 {
997 ID: "small-model",
998 DefaultMaxTokens: 500,
999 },
1000 },
1001 },
1002 }
1003
1004 cfg := &Config{
1005 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1006 "custom": {
1007 APIKey: "test-key",
1008 BaseURL: "https://api.custom.com/v1",
1009 Models: []catwalk.Model{
1010 {
1011 ID: "large-model",
1012 DefaultMaxTokens: 1000,
1013 },
1014 },
1015 },
1016 }),
1017 }
1018 cfg.setDefaults("/tmp")
1019 env := env.NewFromMap(map[string]string{})
1020 resolver := NewEnvironmentVariableResolver(env)
1021 err := cfg.configureProviders(env, resolver, knownProviders)
1022 require.NoError(t, err)
1023 large, small, err := cfg.defaultModelSelection(knownProviders)
1024 require.NoError(t, err)
1025 require.Equal(t, "large-model", large.Model)
1026 require.Equal(t, "openai", large.Provider)
1027 require.Equal(t, int64(1000), large.MaxTokens)
1028 require.Equal(t, "small-model", small.Model)
1029 require.Equal(t, "openai", small.Provider)
1030 require.Equal(t, int64(500), small.MaxTokens)
1031 })
1032}
1033
1034func TestConfig_configureSelectedModels(t *testing.T) {
1035 t.Run("should override defaults", func(t *testing.T) {
1036 knownProviders := []catwalk.Provider{
1037 {
1038 ID: "openai",
1039 APIKey: "abc",
1040 DefaultLargeModelID: "large-model",
1041 DefaultSmallModelID: "small-model",
1042 Models: []catwalk.Model{
1043 {
1044 ID: "larger-model",
1045 DefaultMaxTokens: 2000,
1046 },
1047 {
1048 ID: "large-model",
1049 DefaultMaxTokens: 1000,
1050 },
1051 {
1052 ID: "small-model",
1053 DefaultMaxTokens: 500,
1054 },
1055 },
1056 },
1057 }
1058
1059 cfg := &Config{
1060 Models: map[SelectedModelType]SelectedModel{
1061 "large": {
1062 Model: "larger-model",
1063 },
1064 },
1065 }
1066 cfg.setDefaults("/tmp")
1067 env := env.NewFromMap(map[string]string{})
1068 resolver := NewEnvironmentVariableResolver(env)
1069 err := cfg.configureProviders(env, resolver, knownProviders)
1070 require.NoError(t, err)
1071
1072 err = cfg.configureSelectedModels(knownProviders)
1073 require.NoError(t, err)
1074 large := cfg.Models[SelectedModelTypeLarge]
1075 small := cfg.Models[SelectedModelTypeSmall]
1076 require.Equal(t, "larger-model", large.Model)
1077 require.Equal(t, "openai", large.Provider)
1078 require.Equal(t, int64(2000), large.MaxTokens)
1079 require.Equal(t, "small-model", small.Model)
1080 require.Equal(t, "openai", small.Provider)
1081 require.Equal(t, int64(500), small.MaxTokens)
1082 })
1083 t.Run("should be possible to use multiple providers", func(t *testing.T) {
1084 knownProviders := []catwalk.Provider{
1085 {
1086 ID: "openai",
1087 APIKey: "abc",
1088 DefaultLargeModelID: "large-model",
1089 DefaultSmallModelID: "small-model",
1090 Models: []catwalk.Model{
1091 {
1092 ID: "large-model",
1093 DefaultMaxTokens: 1000,
1094 },
1095 {
1096 ID: "small-model",
1097 DefaultMaxTokens: 500,
1098 },
1099 },
1100 },
1101 {
1102 ID: "anthropic",
1103 APIKey: "abc",
1104 DefaultLargeModelID: "a-large-model",
1105 DefaultSmallModelID: "a-small-model",
1106 Models: []catwalk.Model{
1107 {
1108 ID: "a-large-model",
1109 DefaultMaxTokens: 1000,
1110 },
1111 {
1112 ID: "a-small-model",
1113 DefaultMaxTokens: 200,
1114 },
1115 },
1116 },
1117 }
1118
1119 cfg := &Config{
1120 Models: map[SelectedModelType]SelectedModel{
1121 "small": {
1122 Model: "a-small-model",
1123 Provider: "anthropic",
1124 MaxTokens: 300,
1125 },
1126 },
1127 }
1128 cfg.setDefaults("/tmp")
1129 env := env.NewFromMap(map[string]string{})
1130 resolver := NewEnvironmentVariableResolver(env)
1131 err := cfg.configureProviders(env, resolver, knownProviders)
1132 require.NoError(t, err)
1133
1134 err = cfg.configureSelectedModels(knownProviders)
1135 require.NoError(t, err)
1136 large := cfg.Models[SelectedModelTypeLarge]
1137 small := cfg.Models[SelectedModelTypeSmall]
1138 require.Equal(t, "large-model", large.Model)
1139 require.Equal(t, "openai", large.Provider)
1140 require.Equal(t, int64(1000), large.MaxTokens)
1141 require.Equal(t, "a-small-model", small.Model)
1142 require.Equal(t, "anthropic", small.Provider)
1143 require.Equal(t, int64(300), small.MaxTokens)
1144 })
1145
1146 t.Run("should override the max tokens only", func(t *testing.T) {
1147 knownProviders := []catwalk.Provider{
1148 {
1149 ID: "openai",
1150 APIKey: "abc",
1151 DefaultLargeModelID: "large-model",
1152 DefaultSmallModelID: "small-model",
1153 Models: []catwalk.Model{
1154 {
1155 ID: "large-model",
1156 DefaultMaxTokens: 1000,
1157 },
1158 {
1159 ID: "small-model",
1160 DefaultMaxTokens: 500,
1161 },
1162 },
1163 },
1164 }
1165
1166 cfg := &Config{
1167 Models: map[SelectedModelType]SelectedModel{
1168 "large": {
1169 MaxTokens: 100,
1170 },
1171 },
1172 }
1173 cfg.setDefaults("/tmp")
1174 env := env.NewFromMap(map[string]string{})
1175 resolver := NewEnvironmentVariableResolver(env)
1176 err := cfg.configureProviders(env, resolver, knownProviders)
1177 require.NoError(t, err)
1178
1179 err = cfg.configureSelectedModels(knownProviders)
1180 require.NoError(t, err)
1181 large := cfg.Models[SelectedModelTypeLarge]
1182 require.Equal(t, "large-model", large.Model)
1183 require.Equal(t, "openai", large.Provider)
1184 require.Equal(t, int64(100), large.MaxTokens)
1185 })
1186}