1package config
2
3import (
4 "io"
5 "log/slog"
6 "os"
7 "path/filepath"
8 "strings"
9 "testing"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/csync"
13 "github.com/charmbracelet/crush/internal/env"
14 "github.com/stretchr/testify/assert"
15)
16
// TestMain routes the default slog logger to io.Discard for the whole test
// binary, runs the tests, and exits with the status the runner reports.
func TestMain(m *testing.M) {
	discard := slog.NewTextHandler(io.Discard, nil)
	slog.SetDefault(slog.New(discard))

	os.Exit(m.Run())
}
23
24func TestConfig_LoadFromReaders(t *testing.T) {
25 data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
26 data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
27 data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
28
29 loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
30
31 assert.NoError(t, err)
32 assert.NotNil(t, loadedConfig)
33 assert.Equal(t, 1, loadedConfig.Providers.Len())
34 pc, _ := loadedConfig.Providers.Get("openai")
35 assert.Equal(t, "key2", pc.APIKey)
36 assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
37}
38
39func TestConfig_setDefaults(t *testing.T) {
40 cfg := &Config{}
41
42 cfg.setDefaults("/tmp")
43
44 assert.NotNil(t, cfg.Options)
45 assert.NotNil(t, cfg.Options.TUI)
46 assert.NotNil(t, cfg.Options.ContextPaths)
47 assert.NotNil(t, cfg.Providers)
48 assert.NotNil(t, cfg.Models)
49 assert.NotNil(t, cfg.LSP)
50 assert.NotNil(t, cfg.MCP)
51 assert.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
52 for _, path := range defaultContextPaths {
53 assert.Contains(t, cfg.Options.ContextPaths, path)
54 }
55 assert.Equal(t, "/tmp", cfg.workingDir)
56}
57
58func TestConfig_configureProviders(t *testing.T) {
59 knownProviders := []catwalk.Provider{
60 {
61 ID: "openai",
62 APIKey: "$OPENAI_API_KEY",
63 APIEndpoint: "https://api.openai.com/v1",
64 Models: []catwalk.Model{{
65 ID: "test-model",
66 }},
67 },
68 }
69
70 cfg := &Config{}
71 cfg.setDefaults("/tmp")
72 env := env.NewFromMap(map[string]string{
73 "OPENAI_API_KEY": "test-key",
74 })
75 resolver := NewEnvironmentVariableResolver(env)
76 err := cfg.configureProviders(env, resolver, knownProviders)
77 assert.NoError(t, err)
78 assert.Equal(t, 1, cfg.Providers.Len())
79
80 // We want to make sure that we keep the configured API key as a placeholder
81 pc, _ := cfg.Providers.Get("openai")
82 assert.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
83}
84
85func TestConfig_configureProvidersWithOverride(t *testing.T) {
86 knownProviders := []catwalk.Provider{
87 {
88 ID: "openai",
89 APIKey: "$OPENAI_API_KEY",
90 APIEndpoint: "https://api.openai.com/v1",
91 Models: []catwalk.Model{{
92 ID: "test-model",
93 }},
94 },
95 }
96
97 cfg := &Config{
98 Providers: csync.NewMap[string, ProviderConfig](),
99 }
100 cfg.Providers.Set("openai", ProviderConfig{
101 APIKey: "xyz",
102 BaseURL: "https://api.openai.com/v2",
103 Models: []catwalk.Model{
104 {
105 ID: "test-model",
106 Name: "Updated",
107 },
108 {
109 ID: "another-model",
110 },
111 },
112 })
113 cfg.setDefaults("/tmp")
114
115 env := env.NewFromMap(map[string]string{
116 "OPENAI_API_KEY": "test-key",
117 })
118 resolver := NewEnvironmentVariableResolver(env)
119 err := cfg.configureProviders(env, resolver, knownProviders)
120 assert.NoError(t, err)
121 assert.Equal(t, 1, cfg.Providers.Len())
122
123 // We want to make sure that we keep the configured API key as a placeholder
124 pc, _ := cfg.Providers.Get("openai")
125 assert.Equal(t, "xyz", pc.APIKey)
126 assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
127 assert.Len(t, pc.Models, 2)
128 assert.Equal(t, "Updated", pc.Models[0].Name)
129}
130
131func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
132 knownProviders := []catwalk.Provider{
133 {
134 ID: "openai",
135 APIKey: "$OPENAI_API_KEY",
136 APIEndpoint: "https://api.openai.com/v1",
137 Models: []catwalk.Model{{
138 ID: "test-model",
139 }},
140 },
141 }
142
143 cfg := &Config{
144 Providers: csync.NewMapFrom(map[string]ProviderConfig{
145 "custom": {
146 APIKey: "xyz",
147 BaseURL: "https://api.someendpoint.com/v2",
148 Models: []catwalk.Model{
149 {
150 ID: "test-model",
151 },
152 },
153 },
154 }),
155 }
156 cfg.setDefaults("/tmp")
157 env := env.NewFromMap(map[string]string{
158 "OPENAI_API_KEY": "test-key",
159 })
160 resolver := NewEnvironmentVariableResolver(env)
161 err := cfg.configureProviders(env, resolver, knownProviders)
162 assert.NoError(t, err)
163 // Should be to because of the env variable
164 assert.Equal(t, cfg.Providers.Len(), 2)
165
166 // We want to make sure that we keep the configured API key as a placeholder
167 pc, _ := cfg.Providers.Get("custom")
168 assert.Equal(t, "xyz", pc.APIKey)
169 // Make sure we set the ID correctly
170 assert.Equal(t, "custom", pc.ID)
171 assert.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
172 assert.Len(t, pc.Models, 1)
173
174 _, ok := cfg.Providers.Get("openai")
175 assert.True(t, ok, "OpenAI provider should still be present")
176}
177
178func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
179 knownProviders := []catwalk.Provider{
180 {
181 ID: catwalk.InferenceProviderBedrock,
182 APIKey: "",
183 APIEndpoint: "",
184 Models: []catwalk.Model{{
185 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
186 }},
187 },
188 }
189
190 cfg := &Config{}
191 cfg.setDefaults("/tmp")
192 env := env.NewFromMap(map[string]string{
193 "AWS_ACCESS_KEY_ID": "test-key-id",
194 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
195 })
196 resolver := NewEnvironmentVariableResolver(env)
197 err := cfg.configureProviders(env, resolver, knownProviders)
198 assert.NoError(t, err)
199 assert.Equal(t, cfg.Providers.Len(), 1)
200
201 bedrockProvider, ok := cfg.Providers.Get("bedrock")
202 assert.True(t, ok, "Bedrock provider should be present")
203 assert.Len(t, bedrockProvider.Models, 1)
204 assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
205}
206
207func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
208 knownProviders := []catwalk.Provider{
209 {
210 ID: catwalk.InferenceProviderBedrock,
211 APIKey: "",
212 APIEndpoint: "",
213 Models: []catwalk.Model{{
214 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
215 }},
216 },
217 }
218
219 cfg := &Config{}
220 cfg.setDefaults("/tmp")
221 env := env.NewFromMap(map[string]string{})
222 resolver := NewEnvironmentVariableResolver(env)
223 err := cfg.configureProviders(env, resolver, knownProviders)
224 assert.NoError(t, err)
225 // Provider should not be configured without credentials
226 assert.Equal(t, cfg.Providers.Len(), 0)
227}
228
229func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
230 knownProviders := []catwalk.Provider{
231 {
232 ID: catwalk.InferenceProviderBedrock,
233 APIKey: "",
234 APIEndpoint: "",
235 Models: []catwalk.Model{{
236 ID: "some-random-model",
237 }},
238 },
239 }
240
241 cfg := &Config{}
242 cfg.setDefaults("/tmp")
243 env := env.NewFromMap(map[string]string{
244 "AWS_ACCESS_KEY_ID": "test-key-id",
245 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
246 })
247 resolver := NewEnvironmentVariableResolver(env)
248 err := cfg.configureProviders(env, resolver, knownProviders)
249 assert.Error(t, err)
250}
251
252func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
253 knownProviders := []catwalk.Provider{
254 {
255 ID: catwalk.InferenceProviderVertexAI,
256 APIKey: "",
257 APIEndpoint: "",
258 Models: []catwalk.Model{{
259 ID: "gemini-pro",
260 }},
261 },
262 }
263
264 cfg := &Config{}
265 cfg.setDefaults("/tmp")
266 env := env.NewFromMap(map[string]string{
267 "GOOGLE_GENAI_USE_VERTEXAI": "true",
268 "GOOGLE_CLOUD_PROJECT": "test-project",
269 "GOOGLE_CLOUD_LOCATION": "us-central1",
270 })
271 resolver := NewEnvironmentVariableResolver(env)
272 err := cfg.configureProviders(env, resolver, knownProviders)
273 assert.NoError(t, err)
274 assert.Equal(t, cfg.Providers.Len(), 1)
275
276 vertexProvider, ok := cfg.Providers.Get("vertexai")
277 assert.True(t, ok, "VertexAI provider should be present")
278 assert.Len(t, vertexProvider.Models, 1)
279 assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
280 assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
281 assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
282}
283
284func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
285 knownProviders := []catwalk.Provider{
286 {
287 ID: catwalk.InferenceProviderVertexAI,
288 APIKey: "",
289 APIEndpoint: "",
290 Models: []catwalk.Model{{
291 ID: "gemini-pro",
292 }},
293 },
294 }
295
296 cfg := &Config{}
297 cfg.setDefaults("/tmp")
298 env := env.NewFromMap(map[string]string{
299 "GOOGLE_GENAI_USE_VERTEXAI": "false",
300 "GOOGLE_CLOUD_PROJECT": "test-project",
301 "GOOGLE_CLOUD_LOCATION": "us-central1",
302 })
303 resolver := NewEnvironmentVariableResolver(env)
304 err := cfg.configureProviders(env, resolver, knownProviders)
305 assert.NoError(t, err)
306 // Provider should not be configured without proper credentials
307 assert.Equal(t, cfg.Providers.Len(), 0)
308}
309
310func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
311 knownProviders := []catwalk.Provider{
312 {
313 ID: catwalk.InferenceProviderVertexAI,
314 APIKey: "",
315 APIEndpoint: "",
316 Models: []catwalk.Model{{
317 ID: "gemini-pro",
318 }},
319 },
320 }
321
322 cfg := &Config{}
323 cfg.setDefaults("/tmp")
324 env := env.NewFromMap(map[string]string{
325 "GOOGLE_GENAI_USE_VERTEXAI": "true",
326 "GOOGLE_CLOUD_LOCATION": "us-central1",
327 })
328 resolver := NewEnvironmentVariableResolver(env)
329 err := cfg.configureProviders(env, resolver, knownProviders)
330 assert.NoError(t, err)
331 // Provider should not be configured without project
332 assert.Equal(t, cfg.Providers.Len(), 0)
333}
334
335func TestConfig_configureProvidersSetProviderID(t *testing.T) {
336 knownProviders := []catwalk.Provider{
337 {
338 ID: "openai",
339 APIKey: "$OPENAI_API_KEY",
340 APIEndpoint: "https://api.openai.com/v1",
341 Models: []catwalk.Model{{
342 ID: "test-model",
343 }},
344 },
345 }
346
347 cfg := &Config{}
348 cfg.setDefaults("/tmp")
349 env := env.NewFromMap(map[string]string{
350 "OPENAI_API_KEY": "test-key",
351 })
352 resolver := NewEnvironmentVariableResolver(env)
353 err := cfg.configureProviders(env, resolver, knownProviders)
354 assert.NoError(t, err)
355 assert.Equal(t, cfg.Providers.Len(), 1)
356
357 // Provider ID should be set
358 pc, _ := cfg.Providers.Get("openai")
359 assert.Equal(t, "openai", pc.ID)
360}
361
362func TestConfig_EnabledProviders(t *testing.T) {
363 t.Run("all providers enabled", func(t *testing.T) {
364 cfg := &Config{
365 Providers: csync.NewMapFrom(map[string]ProviderConfig{
366 "openai": {
367 ID: "openai",
368 APIKey: "key1",
369 Disable: false,
370 },
371 "anthropic": {
372 ID: "anthropic",
373 APIKey: "key2",
374 Disable: false,
375 },
376 }),
377 }
378
379 enabled := cfg.EnabledProviders()
380 assert.Len(t, enabled, 2)
381 })
382
383 t.Run("some providers disabled", func(t *testing.T) {
384 cfg := &Config{
385 Providers: csync.NewMapFrom(map[string]ProviderConfig{
386 "openai": {
387 ID: "openai",
388 APIKey: "key1",
389 Disable: false,
390 },
391 "anthropic": {
392 ID: "anthropic",
393 APIKey: "key2",
394 Disable: true,
395 },
396 }),
397 }
398
399 enabled := cfg.EnabledProviders()
400 assert.Len(t, enabled, 1)
401 assert.Equal(t, "openai", enabled[0].ID)
402 })
403
404 t.Run("empty providers map", func(t *testing.T) {
405 cfg := &Config{
406 Providers: csync.NewMap[string, ProviderConfig](),
407 }
408
409 enabled := cfg.EnabledProviders()
410 assert.Len(t, enabled, 0)
411 })
412}
413
414func TestConfig_IsConfigured(t *testing.T) {
415 t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
416 cfg := &Config{
417 Providers: csync.NewMapFrom(map[string]ProviderConfig{
418 "openai": {
419 ID: "openai",
420 APIKey: "key1",
421 Disable: false,
422 },
423 }),
424 }
425
426 assert.True(t, cfg.IsConfigured())
427 })
428
429 t.Run("returns false when no providers are configured", func(t *testing.T) {
430 cfg := &Config{
431 Providers: csync.NewMap[string, ProviderConfig](),
432 }
433
434 assert.False(t, cfg.IsConfigured())
435 })
436
437 t.Run("returns false when all providers are disabled", func(t *testing.T) {
438 cfg := &Config{
439 Providers: csync.NewMapFrom(map[string]ProviderConfig{
440 "openai": {
441 ID: "openai",
442 APIKey: "key1",
443 Disable: true,
444 },
445 "anthropic": {
446 ID: "anthropic",
447 APIKey: "key2",
448 Disable: true,
449 },
450 }),
451 }
452
453 assert.False(t, cfg.IsConfigured())
454 })
455}
456
457func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
458 knownProviders := []catwalk.Provider{
459 {
460 ID: "openai",
461 APIKey: "$OPENAI_API_KEY",
462 APIEndpoint: "https://api.openai.com/v1",
463 Models: []catwalk.Model{{
464 ID: "test-model",
465 }},
466 },
467 }
468
469 cfg := &Config{
470 Providers: csync.NewMapFrom(map[string]ProviderConfig{
471 "openai": {
472 Disable: true,
473 },
474 }),
475 }
476 cfg.setDefaults("/tmp")
477
478 env := env.NewFromMap(map[string]string{
479 "OPENAI_API_KEY": "test-key",
480 })
481 resolver := NewEnvironmentVariableResolver(env)
482 err := cfg.configureProviders(env, resolver, knownProviders)
483 assert.NoError(t, err)
484
485 // Provider should be removed from config when disabled
486 assert.Equal(t, cfg.Providers.Len(), 0)
487 _, exists := cfg.Providers.Get("openai")
488 assert.False(t, exists)
489}
490
491func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
492 t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
493 cfg := &Config{
494 Providers: csync.NewMapFrom(map[string]ProviderConfig{
495 "custom": {
496 BaseURL: "https://api.custom.com/v1",
497 Models: []catwalk.Model{{
498 ID: "test-model",
499 }},
500 },
501 "openai": {
502 APIKey: "$MISSING",
503 },
504 }),
505 }
506 cfg.setDefaults("/tmp")
507
508 env := env.NewFromMap(map[string]string{})
509 resolver := NewEnvironmentVariableResolver(env)
510 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
511 assert.NoError(t, err)
512
513 assert.Equal(t, cfg.Providers.Len(), 1)
514 _, exists := cfg.Providers.Get("custom")
515 assert.True(t, exists)
516 })
517
518 t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
519 cfg := &Config{
520 Providers: csync.NewMapFrom(map[string]ProviderConfig{
521 "custom": {
522 APIKey: "test-key",
523 Models: []catwalk.Model{{
524 ID: "test-model",
525 }},
526 },
527 }),
528 }
529 cfg.setDefaults("/tmp")
530
531 env := env.NewFromMap(map[string]string{})
532 resolver := NewEnvironmentVariableResolver(env)
533 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
534 assert.NoError(t, err)
535
536 assert.Equal(t, cfg.Providers.Len(), 0)
537 _, exists := cfg.Providers.Get("custom")
538 assert.False(t, exists)
539 })
540
541 t.Run("custom provider with no models is removed", func(t *testing.T) {
542 cfg := &Config{
543 Providers: csync.NewMapFrom(map[string]ProviderConfig{
544 "custom": {
545 APIKey: "test-key",
546 BaseURL: "https://api.custom.com/v1",
547 Models: []catwalk.Model{},
548 },
549 }),
550 }
551 cfg.setDefaults("/tmp")
552
553 env := env.NewFromMap(map[string]string{})
554 resolver := NewEnvironmentVariableResolver(env)
555 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
556 assert.NoError(t, err)
557
558 assert.Equal(t, cfg.Providers.Len(), 0)
559 _, exists := cfg.Providers.Get("custom")
560 assert.False(t, exists)
561 })
562
563 t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
564 cfg := &Config{
565 Providers: csync.NewMapFrom(map[string]ProviderConfig{
566 "custom": {
567 APIKey: "test-key",
568 BaseURL: "https://api.custom.com/v1",
569 Type: "unsupported",
570 Models: []catwalk.Model{{
571 ID: "test-model",
572 }},
573 },
574 }),
575 }
576 cfg.setDefaults("/tmp")
577
578 env := env.NewFromMap(map[string]string{})
579 resolver := NewEnvironmentVariableResolver(env)
580 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
581 assert.NoError(t, err)
582
583 assert.Equal(t, cfg.Providers.Len(), 0)
584 _, exists := cfg.Providers.Get("custom")
585 assert.False(t, exists)
586 })
587
588 t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
589 cfg := &Config{
590 Providers: csync.NewMapFrom(map[string]ProviderConfig{
591 "custom": {
592 APIKey: "test-key",
593 BaseURL: "https://api.custom.com/v1",
594 Type: catwalk.TypeOpenAI,
595 Models: []catwalk.Model{{
596 ID: "test-model",
597 }},
598 },
599 }),
600 }
601 cfg.setDefaults("/tmp")
602
603 env := env.NewFromMap(map[string]string{})
604 resolver := NewEnvironmentVariableResolver(env)
605 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
606 assert.NoError(t, err)
607
608 assert.Equal(t, cfg.Providers.Len(), 1)
609 customProvider, exists := cfg.Providers.Get("custom")
610 assert.True(t, exists)
611 assert.Equal(t, "custom", customProvider.ID)
612 assert.Equal(t, "test-key", customProvider.APIKey)
613 assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
614 })
615
616 t.Run("custom anthropic provider is supported", func(t *testing.T) {
617 cfg := &Config{
618 Providers: csync.NewMapFrom(map[string]ProviderConfig{
619 "custom-anthropic": {
620 APIKey: "test-key",
621 BaseURL: "https://api.anthropic.com/v1",
622 Type: catwalk.TypeAnthropic,
623 Models: []catwalk.Model{{
624 ID: "claude-3-sonnet",
625 }},
626 },
627 }),
628 }
629 cfg.setDefaults("/tmp")
630
631 env := env.NewFromMap(map[string]string{})
632 resolver := NewEnvironmentVariableResolver(env)
633 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
634 assert.NoError(t, err)
635
636 assert.Equal(t, cfg.Providers.Len(), 1)
637 customProvider, exists := cfg.Providers.Get("custom-anthropic")
638 assert.True(t, exists)
639 assert.Equal(t, "custom-anthropic", customProvider.ID)
640 assert.Equal(t, "test-key", customProvider.APIKey)
641 assert.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
642 assert.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
643 })
644
645 t.Run("disabled custom provider is removed", func(t *testing.T) {
646 cfg := &Config{
647 Providers: csync.NewMapFrom(map[string]ProviderConfig{
648 "custom": {
649 APIKey: "test-key",
650 BaseURL: "https://api.custom.com/v1",
651 Type: catwalk.TypeOpenAI,
652 Disable: true,
653 Models: []catwalk.Model{{
654 ID: "test-model",
655 }},
656 },
657 }),
658 }
659 cfg.setDefaults("/tmp")
660
661 env := env.NewFromMap(map[string]string{})
662 resolver := NewEnvironmentVariableResolver(env)
663 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
664 assert.NoError(t, err)
665
666 assert.Equal(t, cfg.Providers.Len(), 0)
667 _, exists := cfg.Providers.Get("custom")
668 assert.False(t, exists)
669 })
670}
671
672func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
673 t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
674 knownProviders := []catwalk.Provider{
675 {
676 ID: catwalk.InferenceProviderVertexAI,
677 APIKey: "",
678 APIEndpoint: "",
679 Models: []catwalk.Model{{
680 ID: "gemini-pro",
681 }},
682 },
683 }
684
685 cfg := &Config{
686 Providers: csync.NewMapFrom(map[string]ProviderConfig{
687 "vertexai": {
688 BaseURL: "custom-url",
689 },
690 }),
691 }
692 cfg.setDefaults("/tmp")
693
694 env := env.NewFromMap(map[string]string{
695 "GOOGLE_GENAI_USE_VERTEXAI": "false",
696 })
697 resolver := NewEnvironmentVariableResolver(env)
698 err := cfg.configureProviders(env, resolver, knownProviders)
699 assert.NoError(t, err)
700
701 assert.Equal(t, cfg.Providers.Len(), 0)
702 _, exists := cfg.Providers.Get("vertexai")
703 assert.False(t, exists)
704 })
705
706 t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
707 knownProviders := []catwalk.Provider{
708 {
709 ID: catwalk.InferenceProviderBedrock,
710 APIKey: "",
711 APIEndpoint: "",
712 Models: []catwalk.Model{{
713 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
714 }},
715 },
716 }
717
718 cfg := &Config{
719 Providers: csync.NewMapFrom(map[string]ProviderConfig{
720 "bedrock": {
721 BaseURL: "custom-url",
722 },
723 }),
724 }
725 cfg.setDefaults("/tmp")
726
727 env := env.NewFromMap(map[string]string{})
728 resolver := NewEnvironmentVariableResolver(env)
729 err := cfg.configureProviders(env, resolver, knownProviders)
730 assert.NoError(t, err)
731
732 assert.Equal(t, cfg.Providers.Len(), 0)
733 _, exists := cfg.Providers.Get("bedrock")
734 assert.False(t, exists)
735 })
736
737 t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
738 knownProviders := []catwalk.Provider{
739 {
740 ID: "openai",
741 APIKey: "$MISSING_API_KEY",
742 APIEndpoint: "https://api.openai.com/v1",
743 Models: []catwalk.Model{{
744 ID: "test-model",
745 }},
746 },
747 }
748
749 cfg := &Config{
750 Providers: csync.NewMapFrom(map[string]ProviderConfig{
751 "openai": {
752 BaseURL: "custom-url",
753 },
754 }),
755 }
756 cfg.setDefaults("/tmp")
757
758 env := env.NewFromMap(map[string]string{})
759 resolver := NewEnvironmentVariableResolver(env)
760 err := cfg.configureProviders(env, resolver, knownProviders)
761 assert.NoError(t, err)
762
763 assert.Equal(t, cfg.Providers.Len(), 0)
764 _, exists := cfg.Providers.Get("openai")
765 assert.False(t, exists)
766 })
767
768 t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
769 knownProviders := []catwalk.Provider{
770 {
771 ID: "openai",
772 APIKey: "$OPENAI_API_KEY",
773 APIEndpoint: "$MISSING_ENDPOINT",
774 Models: []catwalk.Model{{
775 ID: "test-model",
776 }},
777 },
778 }
779
780 cfg := &Config{
781 Providers: csync.NewMapFrom(map[string]ProviderConfig{
782 "openai": {
783 APIKey: "test-key",
784 },
785 }),
786 }
787 cfg.setDefaults("/tmp")
788
789 env := env.NewFromMap(map[string]string{
790 "OPENAI_API_KEY": "test-key",
791 })
792 resolver := NewEnvironmentVariableResolver(env)
793 err := cfg.configureProviders(env, resolver, knownProviders)
794 assert.NoError(t, err)
795
796 assert.Equal(t, cfg.Providers.Len(), 1)
797 _, exists := cfg.Providers.Get("openai")
798 assert.True(t, exists)
799 })
800}
801
802func TestConfig_defaultModelSelection(t *testing.T) {
803 t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
804 knownProviders := []catwalk.Provider{
805 {
806 ID: "openai",
807 APIKey: "abc",
808 DefaultLargeModelID: "large-model",
809 DefaultSmallModelID: "small-model",
810 Models: []catwalk.Model{
811 {
812 ID: "large-model",
813 DefaultMaxTokens: 1000,
814 },
815 {
816 ID: "small-model",
817 DefaultMaxTokens: 500,
818 },
819 },
820 },
821 }
822
823 cfg := &Config{}
824 cfg.setDefaults("/tmp")
825 env := env.NewFromMap(map[string]string{})
826 resolver := NewEnvironmentVariableResolver(env)
827 err := cfg.configureProviders(env, resolver, knownProviders)
828 assert.NoError(t, err)
829
830 large, small, err := cfg.defaultModelSelection(knownProviders)
831 assert.NoError(t, err)
832 assert.Equal(t, "large-model", large.Model)
833 assert.Equal(t, "openai", large.Provider)
834 assert.Equal(t, int64(1000), large.MaxTokens)
835 assert.Equal(t, "small-model", small.Model)
836 assert.Equal(t, "openai", small.Provider)
837 assert.Equal(t, int64(500), small.MaxTokens)
838 })
839 t.Run("should error if no providers configured", func(t *testing.T) {
840 knownProviders := []catwalk.Provider{
841 {
842 ID: "openai",
843 APIKey: "$MISSING_KEY",
844 DefaultLargeModelID: "large-model",
845 DefaultSmallModelID: "small-model",
846 Models: []catwalk.Model{
847 {
848 ID: "large-model",
849 DefaultMaxTokens: 1000,
850 },
851 {
852 ID: "small-model",
853 DefaultMaxTokens: 500,
854 },
855 },
856 },
857 }
858
859 cfg := &Config{}
860 cfg.setDefaults("/tmp")
861 env := env.NewFromMap(map[string]string{})
862 resolver := NewEnvironmentVariableResolver(env)
863 err := cfg.configureProviders(env, resolver, knownProviders)
864 assert.NoError(t, err)
865
866 _, _, err = cfg.defaultModelSelection(knownProviders)
867 assert.Error(t, err)
868 })
869 t.Run("should error if model is missing", func(t *testing.T) {
870 knownProviders := []catwalk.Provider{
871 {
872 ID: "openai",
873 APIKey: "abc",
874 DefaultLargeModelID: "large-model",
875 DefaultSmallModelID: "small-model",
876 Models: []catwalk.Model{
877 {
878 ID: "not-large-model",
879 DefaultMaxTokens: 1000,
880 },
881 {
882 ID: "small-model",
883 DefaultMaxTokens: 500,
884 },
885 },
886 },
887 }
888
889 cfg := &Config{}
890 cfg.setDefaults("/tmp")
891 env := env.NewFromMap(map[string]string{})
892 resolver := NewEnvironmentVariableResolver(env)
893 err := cfg.configureProviders(env, resolver, knownProviders)
894 assert.NoError(t, err)
895 _, _, err = cfg.defaultModelSelection(knownProviders)
896 assert.Error(t, err)
897 })
898
899 t.Run("should configure the default models with a custom provider", func(t *testing.T) {
900 knownProviders := []catwalk.Provider{
901 {
902 ID: "openai",
903 APIKey: "$MISSING", // will not be included in the config
904 DefaultLargeModelID: "large-model",
905 DefaultSmallModelID: "small-model",
906 Models: []catwalk.Model{
907 {
908 ID: "not-large-model",
909 DefaultMaxTokens: 1000,
910 },
911 {
912 ID: "small-model",
913 DefaultMaxTokens: 500,
914 },
915 },
916 },
917 }
918
919 cfg := &Config{
920 Providers: csync.NewMapFrom(map[string]ProviderConfig{
921 "custom": {
922 APIKey: "test-key",
923 BaseURL: "https://api.custom.com/v1",
924 Models: []catwalk.Model{
925 {
926 ID: "model",
927 DefaultMaxTokens: 600,
928 },
929 },
930 },
931 }),
932 }
933 cfg.setDefaults("/tmp")
934 env := env.NewFromMap(map[string]string{})
935 resolver := NewEnvironmentVariableResolver(env)
936 err := cfg.configureProviders(env, resolver, knownProviders)
937 assert.NoError(t, err)
938 large, small, err := cfg.defaultModelSelection(knownProviders)
939 assert.NoError(t, err)
940 assert.Equal(t, "model", large.Model)
941 assert.Equal(t, "custom", large.Provider)
942 assert.Equal(t, int64(600), large.MaxTokens)
943 assert.Equal(t, "model", small.Model)
944 assert.Equal(t, "custom", small.Provider)
945 assert.Equal(t, int64(600), small.MaxTokens)
946 })
947
948 t.Run("should fail if no model configured", func(t *testing.T) {
949 knownProviders := []catwalk.Provider{
950 {
951 ID: "openai",
952 APIKey: "$MISSING", // will not be included in the config
953 DefaultLargeModelID: "large-model",
954 DefaultSmallModelID: "small-model",
955 Models: []catwalk.Model{
956 {
957 ID: "not-large-model",
958 DefaultMaxTokens: 1000,
959 },
960 {
961 ID: "small-model",
962 DefaultMaxTokens: 500,
963 },
964 },
965 },
966 }
967
968 cfg := &Config{
969 Providers: csync.NewMapFrom(map[string]ProviderConfig{
970 "custom": {
971 APIKey: "test-key",
972 BaseURL: "https://api.custom.com/v1",
973 Models: []catwalk.Model{},
974 },
975 }),
976 }
977 cfg.setDefaults("/tmp")
978 env := env.NewFromMap(map[string]string{})
979 resolver := NewEnvironmentVariableResolver(env)
980 err := cfg.configureProviders(env, resolver, knownProviders)
981 assert.NoError(t, err)
982 _, _, err = cfg.defaultModelSelection(knownProviders)
983 assert.Error(t, err)
984 })
985 t.Run("should use the default provider first", func(t *testing.T) {
986 knownProviders := []catwalk.Provider{
987 {
988 ID: "openai",
989 APIKey: "set",
990 DefaultLargeModelID: "large-model",
991 DefaultSmallModelID: "small-model",
992 Models: []catwalk.Model{
993 {
994 ID: "large-model",
995 DefaultMaxTokens: 1000,
996 },
997 {
998 ID: "small-model",
999 DefaultMaxTokens: 500,
1000 },
1001 },
1002 },
1003 }
1004
1005 cfg := &Config{
1006 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1007 "custom": {
1008 APIKey: "test-key",
1009 BaseURL: "https://api.custom.com/v1",
1010 Models: []catwalk.Model{
1011 {
1012 ID: "large-model",
1013 DefaultMaxTokens: 1000,
1014 },
1015 },
1016 },
1017 }),
1018 }
1019 cfg.setDefaults("/tmp")
1020 env := env.NewFromMap(map[string]string{})
1021 resolver := NewEnvironmentVariableResolver(env)
1022 err := cfg.configureProviders(env, resolver, knownProviders)
1023 assert.NoError(t, err)
1024 large, small, err := cfg.defaultModelSelection(knownProviders)
1025 assert.NoError(t, err)
1026 assert.Equal(t, "large-model", large.Model)
1027 assert.Equal(t, "openai", large.Provider)
1028 assert.Equal(t, int64(1000), large.MaxTokens)
1029 assert.Equal(t, "small-model", small.Model)
1030 assert.Equal(t, "openai", small.Provider)
1031 assert.Equal(t, int64(500), small.MaxTokens)
1032 })
1033}
1034
1035func TestConfig_configureSelectedModels(t *testing.T) {
1036 t.Run("should override defaults", func(t *testing.T) {
1037 knownProviders := []catwalk.Provider{
1038 {
1039 ID: "openai",
1040 APIKey: "abc",
1041 DefaultLargeModelID: "large-model",
1042 DefaultSmallModelID: "small-model",
1043 Models: []catwalk.Model{
1044 {
1045 ID: "larger-model",
1046 DefaultMaxTokens: 2000,
1047 },
1048 {
1049 ID: "large-model",
1050 DefaultMaxTokens: 1000,
1051 },
1052 {
1053 ID: "small-model",
1054 DefaultMaxTokens: 500,
1055 },
1056 },
1057 },
1058 }
1059
1060 cfg := &Config{
1061 Models: map[SelectedModelType]SelectedModel{
1062 "large": {
1063 Model: "larger-model",
1064 },
1065 },
1066 }
1067 cfg.setDefaults("/tmp")
1068 env := env.NewFromMap(map[string]string{})
1069 resolver := NewEnvironmentVariableResolver(env)
1070 err := cfg.configureProviders(env, resolver, knownProviders)
1071 assert.NoError(t, err)
1072
1073 err = cfg.configureSelectedModels(knownProviders)
1074 assert.NoError(t, err)
1075 large := cfg.Models[SelectedModelTypeLarge]
1076 small := cfg.Models[SelectedModelTypeSmall]
1077 assert.Equal(t, "larger-model", large.Model)
1078 assert.Equal(t, "openai", large.Provider)
1079 assert.Equal(t, int64(2000), large.MaxTokens)
1080 assert.Equal(t, "small-model", small.Model)
1081 assert.Equal(t, "openai", small.Provider)
1082 assert.Equal(t, int64(500), small.MaxTokens)
1083 })
1084 t.Run("should be possible to use multiple providers", func(t *testing.T) {
1085 knownProviders := []catwalk.Provider{
1086 {
1087 ID: "openai",
1088 APIKey: "abc",
1089 DefaultLargeModelID: "large-model",
1090 DefaultSmallModelID: "small-model",
1091 Models: []catwalk.Model{
1092 {
1093 ID: "large-model",
1094 DefaultMaxTokens: 1000,
1095 },
1096 {
1097 ID: "small-model",
1098 DefaultMaxTokens: 500,
1099 },
1100 },
1101 },
1102 {
1103 ID: "anthropic",
1104 APIKey: "abc",
1105 DefaultLargeModelID: "a-large-model",
1106 DefaultSmallModelID: "a-small-model",
1107 Models: []catwalk.Model{
1108 {
1109 ID: "a-large-model",
1110 DefaultMaxTokens: 1000,
1111 },
1112 {
1113 ID: "a-small-model",
1114 DefaultMaxTokens: 200,
1115 },
1116 },
1117 },
1118 }
1119
1120 cfg := &Config{
1121 Models: map[SelectedModelType]SelectedModel{
1122 "small": {
1123 Model: "a-small-model",
1124 Provider: "anthropic",
1125 MaxTokens: 300,
1126 },
1127 },
1128 }
1129 cfg.setDefaults("/tmp")
1130 env := env.NewFromMap(map[string]string{})
1131 resolver := NewEnvironmentVariableResolver(env)
1132 err := cfg.configureProviders(env, resolver, knownProviders)
1133 assert.NoError(t, err)
1134
1135 err = cfg.configureSelectedModels(knownProviders)
1136 assert.NoError(t, err)
1137 large := cfg.Models[SelectedModelTypeLarge]
1138 small := cfg.Models[SelectedModelTypeSmall]
1139 assert.Equal(t, "large-model", large.Model)
1140 assert.Equal(t, "openai", large.Provider)
1141 assert.Equal(t, int64(1000), large.MaxTokens)
1142 assert.Equal(t, "a-small-model", small.Model)
1143 assert.Equal(t, "anthropic", small.Provider)
1144 assert.Equal(t, int64(300), small.MaxTokens)
1145 })
1146
1147 t.Run("should override the max tokens only", func(t *testing.T) {
1148 knownProviders := []catwalk.Provider{
1149 {
1150 ID: "openai",
1151 APIKey: "abc",
1152 DefaultLargeModelID: "large-model",
1153 DefaultSmallModelID: "small-model",
1154 Models: []catwalk.Model{
1155 {
1156 ID: "large-model",
1157 DefaultMaxTokens: 1000,
1158 },
1159 {
1160 ID: "small-model",
1161 DefaultMaxTokens: 500,
1162 },
1163 },
1164 },
1165 }
1166
1167 cfg := &Config{
1168 Models: map[SelectedModelType]SelectedModel{
1169 "large": {
1170 MaxTokens: 100,
1171 },
1172 },
1173 }
1174 cfg.setDefaults("/tmp")
1175 env := env.NewFromMap(map[string]string{})
1176 resolver := NewEnvironmentVariableResolver(env)
1177 err := cfg.configureProviders(env, resolver, knownProviders)
1178 assert.NoError(t, err)
1179
1180 err = cfg.configureSelectedModels(knownProviders)
1181 assert.NoError(t, err)
1182 large := cfg.Models[SelectedModelTypeLarge]
1183 assert.Equal(t, "large-model", large.Model)
1184 assert.Equal(t, "openai", large.Provider)
1185 assert.Equal(t, int64(100), large.MaxTokens)
1186 })
1187}