1package config
2
3import (
4 "io"
5 "log/slog"
6 "os"
7 "path/filepath"
8 "strings"
9 "testing"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/csync"
13 "github.com/charmbracelet/crush/internal/env"
14 "github.com/charmbracelet/crush/internal/llm/agent"
15 "github.com/charmbracelet/crush/internal/llm/provider"
16 "github.com/charmbracelet/crush/internal/resolver"
17 "github.com/stretchr/testify/assert"
18)
19
// TestMain routes the package-wide default slog logger to io.Discard so log
// lines do not clutter test output, then runs the tests and exits with their
// status code.
func TestMain(m *testing.M) {
	discard := slog.NewTextHandler(io.Discard, nil)
	slog.SetDefault(slog.New(discard))
	os.Exit(m.Run())
}
26
27func TestConfig_LoadFromReaders(t *testing.T) {
28 data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
29 data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
30 data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
31
32 loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
33
34 assert.NoError(t, err)
35 assert.NotNil(t, loadedConfig)
36 assert.Equal(t, 1, loadedConfig.Providers.Len())
37 pc, _ := loadedConfig.Providers.Get("openai")
38 assert.Equal(t, "key2", pc.APIKey)
39 assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
40}
41
42func TestConfig_setDefaults(t *testing.T) {
43 cfg := &Config{}
44
45 cfg.setDefaults("/tmp")
46
47 assert.NotNil(t, cfg.Options)
48 assert.NotNil(t, cfg.Options.TUI)
49 assert.NotNil(t, cfg.Options.ContextPaths)
50 assert.NotNil(t, cfg.Providers)
51 assert.NotNil(t, cfg.Models)
52 assert.NotNil(t, cfg.LSP)
53 assert.NotNil(t, cfg.MCP)
54 assert.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
55 for _, path := range defaultContextPaths {
56 assert.Contains(t, cfg.Options.ContextPaths, path)
57 }
58 assert.Equal(t, "/tmp", cfg.workingDir)
59}
60
61func TestConfig_configureProviders(t *testing.T) {
62 knownProviders := []catwalk.Provider{
63 {
64 ID: "openai",
65 APIKey: "$OPENAI_API_KEY",
66 APIEndpoint: "https://api.openai.com/v1",
67 Models: []catwalk.Model{{
68 ID: "test-model",
69 }},
70 },
71 }
72
73 cfg := &Config{}
74 cfg.setDefaults("/tmp")
75 env := env.NewFromMap(map[string]string{
76 "OPENAI_API_KEY": "test-key",
77 })
78 resolver := resolver.NewEnvironmentVariableResolver(env)
79 err := cfg.configureProviders(env, resolver, knownProviders)
80 assert.NoError(t, err)
81 assert.Equal(t, 1, cfg.Providers.Len())
82
83 // We want to make sure that we keep the configured API key as a placeholder
84 pc, _ := cfg.Providers.Get("openai")
85 assert.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
86}
87
88func TestConfig_configureProvidersWithOverride(t *testing.T) {
89 knownProviders := []catwalk.Provider{
90 {
91 ID: "openai",
92 APIKey: "$OPENAI_API_KEY",
93 APIEndpoint: "https://api.openai.com/v1",
94 Models: []catwalk.Model{{
95 ID: "test-model",
96 }},
97 },
98 }
99
100 cfg := &Config{
101 Providers: csync.NewMap[string, provider.Config](),
102 }
103 cfg.Providers.Set("openai", provider.Config{
104 APIKey: "xyz",
105 BaseURL: "https://api.openai.com/v2",
106 Models: []catwalk.Model{
107 {
108 ID: "test-model",
109 Name: "Updated",
110 },
111 {
112 ID: "another-model",
113 },
114 },
115 })
116 cfg.setDefaults("/tmp")
117
118 env := env.NewFromMap(map[string]string{
119 "OPENAI_API_KEY": "test-key",
120 })
121 resolver := resolver.NewEnvironmentVariableResolver(env)
122 err := cfg.configureProviders(env, resolver, knownProviders)
123 assert.NoError(t, err)
124 assert.Equal(t, 1, cfg.Providers.Len())
125
126 // We want to make sure that we keep the configured API key as a placeholder
127 pc, _ := cfg.Providers.Get("openai")
128 assert.Equal(t, "xyz", pc.APIKey)
129 assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
130 assert.Len(t, pc.Models, 2)
131 assert.Equal(t, "Updated", pc.Models[0].Name)
132}
133
134func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
135 knownProviders := []catwalk.Provider{
136 {
137 ID: "openai",
138 APIKey: "$OPENAI_API_KEY",
139 APIEndpoint: "https://api.openai.com/v1",
140 Models: []catwalk.Model{{
141 ID: "test-model",
142 }},
143 },
144 }
145
146 cfg := &Config{
147 Providers: csync.NewMapFrom(map[string]provider.Config{
148 "custom": {
149 APIKey: "xyz",
150 BaseURL: "https://api.someendpoint.com/v2",
151 Models: []catwalk.Model{
152 {
153 ID: "test-model",
154 },
155 },
156 },
157 }),
158 }
159 cfg.setDefaults("/tmp")
160 env := env.NewFromMap(map[string]string{
161 "OPENAI_API_KEY": "test-key",
162 })
163 resolver := resolver.NewEnvironmentVariableResolver(env)
164 err := cfg.configureProviders(env, resolver, knownProviders)
165 assert.NoError(t, err)
166 // Should be to because of the env variable
167 assert.Equal(t, cfg.Providers.Len(), 2)
168
169 // We want to make sure that we keep the configured API key as a placeholder
170 pc, _ := cfg.Providers.Get("custom")
171 assert.Equal(t, "xyz", pc.APIKey)
172 // Make sure we set the ID correctly
173 assert.Equal(t, "custom", pc.ID)
174 assert.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
175 assert.Len(t, pc.Models, 1)
176
177 _, ok := cfg.Providers.Get("openai")
178 assert.True(t, ok, "OpenAI provider should still be present")
179}
180
181func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
182 knownProviders := []catwalk.Provider{
183 {
184 ID: catwalk.InferenceProviderBedrock,
185 APIKey: "",
186 APIEndpoint: "",
187 Models: []catwalk.Model{{
188 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
189 }},
190 },
191 }
192
193 cfg := &Config{}
194 cfg.setDefaults("/tmp")
195 env := env.NewFromMap(map[string]string{
196 "AWS_ACCESS_KEY_ID": "test-key-id",
197 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
198 })
199 resolver := resolver.NewEnvironmentVariableResolver(env)
200 err := cfg.configureProviders(env, resolver, knownProviders)
201 assert.NoError(t, err)
202 assert.Equal(t, cfg.Providers.Len(), 1)
203
204 bedrockProvider, ok := cfg.Providers.Get("bedrock")
205 assert.True(t, ok, "Bedrock provider should be present")
206 assert.Len(t, bedrockProvider.Models, 1)
207 assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
208}
209
210func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
211 knownProviders := []catwalk.Provider{
212 {
213 ID: catwalk.InferenceProviderBedrock,
214 APIKey: "",
215 APIEndpoint: "",
216 Models: []catwalk.Model{{
217 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
218 }},
219 },
220 }
221
222 cfg := &Config{}
223 cfg.setDefaults("/tmp")
224 env := env.NewFromMap(map[string]string{})
225 resolver := resolver.NewEnvironmentVariableResolver(env)
226 err := cfg.configureProviders(env, resolver, knownProviders)
227 assert.NoError(t, err)
228 // Provider should not be configured without credentials
229 assert.Equal(t, cfg.Providers.Len(), 0)
230}
231
232func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
233 knownProviders := []catwalk.Provider{
234 {
235 ID: catwalk.InferenceProviderBedrock,
236 APIKey: "",
237 APIEndpoint: "",
238 Models: []catwalk.Model{{
239 ID: "some-random-model",
240 }},
241 },
242 }
243
244 cfg := &Config{}
245 cfg.setDefaults("/tmp")
246 env := env.NewFromMap(map[string]string{
247 "AWS_ACCESS_KEY_ID": "test-key-id",
248 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
249 })
250 resolver := resolver.NewEnvironmentVariableResolver(env)
251 err := cfg.configureProviders(env, resolver, knownProviders)
252 assert.Error(t, err)
253}
254
255func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
256 knownProviders := []catwalk.Provider{
257 {
258 ID: catwalk.InferenceProviderVertexAI,
259 APIKey: "",
260 APIEndpoint: "",
261 Models: []catwalk.Model{{
262 ID: "gemini-pro",
263 }},
264 },
265 }
266
267 cfg := &Config{}
268 cfg.setDefaults("/tmp")
269 env := env.NewFromMap(map[string]string{
270 "GOOGLE_GENAI_USE_VERTEXAI": "true",
271 "GOOGLE_CLOUD_PROJECT": "test-project",
272 "GOOGLE_CLOUD_LOCATION": "us-central1",
273 })
274 resolver := resolver.NewEnvironmentVariableResolver(env)
275 err := cfg.configureProviders(env, resolver, knownProviders)
276 assert.NoError(t, err)
277 assert.Equal(t, cfg.Providers.Len(), 1)
278
279 vertexProvider, ok := cfg.Providers.Get("vertexai")
280 assert.True(t, ok, "VertexAI provider should be present")
281 assert.Len(t, vertexProvider.Models, 1)
282 assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
283 assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
284 assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
285}
286
287func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
288 knownProviders := []catwalk.Provider{
289 {
290 ID: catwalk.InferenceProviderVertexAI,
291 APIKey: "",
292 APIEndpoint: "",
293 Models: []catwalk.Model{{
294 ID: "gemini-pro",
295 }},
296 },
297 }
298
299 cfg := &Config{}
300 cfg.setDefaults("/tmp")
301 env := env.NewFromMap(map[string]string{
302 "GOOGLE_GENAI_USE_VERTEXAI": "false",
303 "GOOGLE_CLOUD_PROJECT": "test-project",
304 "GOOGLE_CLOUD_LOCATION": "us-central1",
305 })
306 resolver := resolver.NewEnvironmentVariableResolver(env)
307 err := cfg.configureProviders(env, resolver, knownProviders)
308 assert.NoError(t, err)
309 // Provider should not be configured without proper credentials
310 assert.Equal(t, cfg.Providers.Len(), 0)
311}
312
313func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
314 knownProviders := []catwalk.Provider{
315 {
316 ID: catwalk.InferenceProviderVertexAI,
317 APIKey: "",
318 APIEndpoint: "",
319 Models: []catwalk.Model{{
320 ID: "gemini-pro",
321 }},
322 },
323 }
324
325 cfg := &Config{}
326 cfg.setDefaults("/tmp")
327 env := env.NewFromMap(map[string]string{
328 "GOOGLE_GENAI_USE_VERTEXAI": "true",
329 "GOOGLE_CLOUD_LOCATION": "us-central1",
330 })
331 resolver := resolver.NewEnvironmentVariableResolver(env)
332 err := cfg.configureProviders(env, resolver, knownProviders)
333 assert.NoError(t, err)
334 // Provider should not be configured without project
335 assert.Equal(t, cfg.Providers.Len(), 0)
336}
337
338func TestConfig_configureProvidersSetProviderID(t *testing.T) {
339 knownProviders := []catwalk.Provider{
340 {
341 ID: "openai",
342 APIKey: "$OPENAI_API_KEY",
343 APIEndpoint: "https://api.openai.com/v1",
344 Models: []catwalk.Model{{
345 ID: "test-model",
346 }},
347 },
348 }
349
350 cfg := &Config{}
351 cfg.setDefaults("/tmp")
352 env := env.NewFromMap(map[string]string{
353 "OPENAI_API_KEY": "test-key",
354 })
355 resolver := resolver.NewEnvironmentVariableResolver(env)
356 err := cfg.configureProviders(env, resolver, knownProviders)
357 assert.NoError(t, err)
358 assert.Equal(t, cfg.Providers.Len(), 1)
359
360 // Provider ID should be set
361 pc, _ := cfg.Providers.Get("openai")
362 assert.Equal(t, "openai", pc.ID)
363}
364
365func TestConfig_EnabledProviders(t *testing.T) {
366 t.Run("all providers enabled", func(t *testing.T) {
367 cfg := &Config{
368 Providers: csync.NewMapFrom(map[string]provider.Config{
369 "openai": {
370 ID: "openai",
371 APIKey: "key1",
372 Disable: false,
373 },
374 "anthropic": {
375 ID: "anthropic",
376 APIKey: "key2",
377 Disable: false,
378 },
379 }),
380 }
381
382 enabled := cfg.EnabledProviders()
383 assert.Len(t, enabled, 2)
384 })
385
386 t.Run("some providers disabled", func(t *testing.T) {
387 cfg := &Config{
388 Providers: csync.NewMapFrom(map[string]provider.Config{
389 "openai": {
390 ID: "openai",
391 APIKey: "key1",
392 Disable: false,
393 },
394 "anthropic": {
395 ID: "anthropic",
396 APIKey: "key2",
397 Disable: true,
398 },
399 }),
400 }
401
402 enabled := cfg.EnabledProviders()
403 assert.Len(t, enabled, 1)
404 assert.Equal(t, "openai", enabled[0].ID)
405 })
406
407 t.Run("empty providers map", func(t *testing.T) {
408 cfg := &Config{
409 Providers: csync.NewMap[string, provider.Config](),
410 }
411
412 enabled := cfg.EnabledProviders()
413 assert.Len(t, enabled, 0)
414 })
415}
416
417func TestConfig_IsConfigured(t *testing.T) {
418 t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
419 cfg := &Config{
420 Providers: csync.NewMapFrom(map[string]provider.Config{
421 "openai": {
422 ID: "openai",
423 APIKey: "key1",
424 Disable: false,
425 },
426 }),
427 }
428
429 assert.True(t, cfg.IsConfigured())
430 })
431
432 t.Run("returns false when no providers are configured", func(t *testing.T) {
433 cfg := &Config{
434 Providers: csync.NewMap[string, provider.Config](),
435 }
436
437 assert.False(t, cfg.IsConfigured())
438 })
439
440 t.Run("returns false when all providers are disabled", func(t *testing.T) {
441 cfg := &Config{
442 Providers: csync.NewMapFrom(map[string]provider.Config{
443 "openai": {
444 ID: "openai",
445 APIKey: "key1",
446 Disable: true,
447 },
448 "anthropic": {
449 ID: "anthropic",
450 APIKey: "key2",
451 Disable: true,
452 },
453 }),
454 }
455
456 assert.False(t, cfg.IsConfigured())
457 })
458}
459
460func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
461 knownProviders := []catwalk.Provider{
462 {
463 ID: "openai",
464 APIKey: "$OPENAI_API_KEY",
465 APIEndpoint: "https://api.openai.com/v1",
466 Models: []catwalk.Model{{
467 ID: "test-model",
468 }},
469 },
470 }
471
472 cfg := &Config{
473 Providers: csync.NewMapFrom(map[string]provider.Config{
474 "openai": {
475 Disable: true,
476 },
477 }),
478 }
479 cfg.setDefaults("/tmp")
480
481 env := env.NewFromMap(map[string]string{
482 "OPENAI_API_KEY": "test-key",
483 })
484 resolver := resolver.NewEnvironmentVariableResolver(env)
485 err := cfg.configureProviders(env, resolver, knownProviders)
486 assert.NoError(t, err)
487
488 // Provider should be removed from config when disabled
489 assert.Equal(t, cfg.Providers.Len(), 0)
490 _, exists := cfg.Providers.Get("openai")
491 assert.False(t, exists)
492}
493
494func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
495 t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
496 cfg := &Config{
497 Providers: csync.NewMapFrom(map[string]provider.Config{
498 "custom": {
499 BaseURL: "https://api.custom.com/v1",
500 Models: []catwalk.Model{{
501 ID: "test-model",
502 }},
503 },
504 "openai": {
505 APIKey: "$MISSING",
506 },
507 }),
508 }
509 cfg.setDefaults("/tmp")
510
511 env := env.NewFromMap(map[string]string{})
512 resolver := resolver.NewEnvironmentVariableResolver(env)
513 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
514 assert.NoError(t, err)
515
516 assert.Equal(t, cfg.Providers.Len(), 1)
517 _, exists := cfg.Providers.Get("custom")
518 assert.True(t, exists)
519 })
520
521 t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
522 cfg := &Config{
523 Providers: csync.NewMapFrom(map[string]provider.Config{
524 "custom": {
525 APIKey: "test-key",
526 Models: []catwalk.Model{{
527 ID: "test-model",
528 }},
529 },
530 }),
531 }
532 cfg.setDefaults("/tmp")
533
534 env := env.NewFromMap(map[string]string{})
535 resolver := resolver.NewEnvironmentVariableResolver(env)
536 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
537 assert.NoError(t, err)
538
539 assert.Equal(t, cfg.Providers.Len(), 0)
540 _, exists := cfg.Providers.Get("custom")
541 assert.False(t, exists)
542 })
543
544 t.Run("custom provider with no models is removed", func(t *testing.T) {
545 cfg := &Config{
546 Providers: csync.NewMapFrom(map[string]provider.Config{
547 "custom": {
548 APIKey: "test-key",
549 BaseURL: "https://api.custom.com/v1",
550 Models: []catwalk.Model{},
551 },
552 }),
553 }
554 cfg.setDefaults("/tmp")
555
556 env := env.NewFromMap(map[string]string{})
557 resolver := resolver.NewEnvironmentVariableResolver(env)
558 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
559 assert.NoError(t, err)
560
561 assert.Equal(t, cfg.Providers.Len(), 0)
562 _, exists := cfg.Providers.Get("custom")
563 assert.False(t, exists)
564 })
565
566 t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
567 cfg := &Config{
568 Providers: csync.NewMapFrom(map[string]provider.Config{
569 "custom": {
570 APIKey: "test-key",
571 BaseURL: "https://api.custom.com/v1",
572 Type: "unsupported",
573 Models: []catwalk.Model{{
574 ID: "test-model",
575 }},
576 },
577 }),
578 }
579 cfg.setDefaults("/tmp")
580
581 env := env.NewFromMap(map[string]string{})
582 resolver := resolver.NewEnvironmentVariableResolver(env)
583 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
584 assert.NoError(t, err)
585
586 assert.Equal(t, cfg.Providers.Len(), 0)
587 _, exists := cfg.Providers.Get("custom")
588 assert.False(t, exists)
589 })
590
591 t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
592 cfg := &Config{
593 Providers: csync.NewMapFrom(map[string]provider.Config{
594 "custom": {
595 APIKey: "test-key",
596 BaseURL: "https://api.custom.com/v1",
597 Type: catwalk.TypeOpenAI,
598 Models: []catwalk.Model{{
599 ID: "test-model",
600 }},
601 },
602 }),
603 }
604 cfg.setDefaults("/tmp")
605
606 env := env.NewFromMap(map[string]string{})
607 resolver := resolver.NewEnvironmentVariableResolver(env)
608 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
609 assert.NoError(t, err)
610
611 assert.Equal(t, cfg.Providers.Len(), 1)
612 customProvider, exists := cfg.Providers.Get("custom")
613 assert.True(t, exists)
614 assert.Equal(t, "custom", customProvider.ID)
615 assert.Equal(t, "test-key", customProvider.APIKey)
616 assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
617 })
618
619 t.Run("custom anthropic provider is supported", func(t *testing.T) {
620 cfg := &Config{
621 Providers: csync.NewMapFrom(map[string]provider.Config{
622 "custom-anthropic": {
623 APIKey: "test-key",
624 BaseURL: "https://api.anthropic.com/v1",
625 Type: catwalk.TypeAnthropic,
626 Models: []catwalk.Model{{
627 ID: "claude-3-sonnet",
628 }},
629 },
630 }),
631 }
632 cfg.setDefaults("/tmp")
633
634 env := env.NewFromMap(map[string]string{})
635 resolver := resolver.NewEnvironmentVariableResolver(env)
636 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
637 assert.NoError(t, err)
638
639 assert.Equal(t, cfg.Providers.Len(), 1)
640 customProvider, exists := cfg.Providers.Get("custom-anthropic")
641 assert.True(t, exists)
642 assert.Equal(t, "custom-anthropic", customProvider.ID)
643 assert.Equal(t, "test-key", customProvider.APIKey)
644 assert.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
645 assert.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
646 })
647
648 t.Run("disabled custom provider is removed", func(t *testing.T) {
649 cfg := &Config{
650 Providers: csync.NewMapFrom(map[string]provider.Config{
651 "custom": {
652 APIKey: "test-key",
653 BaseURL: "https://api.custom.com/v1",
654 Type: catwalk.TypeOpenAI,
655 Disable: true,
656 Models: []catwalk.Model{{
657 ID: "test-model",
658 }},
659 },
660 }),
661 }
662 cfg.setDefaults("/tmp")
663
664 env := env.NewFromMap(map[string]string{})
665 resolver := resolver.NewEnvironmentVariableResolver(env)
666 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
667 assert.NoError(t, err)
668
669 assert.Equal(t, cfg.Providers.Len(), 0)
670 _, exists := cfg.Providers.Get("custom")
671 assert.False(t, exists)
672 })
673}
674
675func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
676 t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
677 knownProviders := []catwalk.Provider{
678 {
679 ID: catwalk.InferenceProviderVertexAI,
680 APIKey: "",
681 APIEndpoint: "",
682 Models: []catwalk.Model{{
683 ID: "gemini-pro",
684 }},
685 },
686 }
687
688 cfg := &Config{
689 Providers: csync.NewMapFrom(map[string]provider.Config{
690 "vertexai": {
691 BaseURL: "custom-url",
692 },
693 }),
694 }
695 cfg.setDefaults("/tmp")
696
697 env := env.NewFromMap(map[string]string{
698 "GOOGLE_GENAI_USE_VERTEXAI": "false",
699 })
700 resolver := resolver.NewEnvironmentVariableResolver(env)
701 err := cfg.configureProviders(env, resolver, knownProviders)
702 assert.NoError(t, err)
703
704 assert.Equal(t, cfg.Providers.Len(), 0)
705 _, exists := cfg.Providers.Get("vertexai")
706 assert.False(t, exists)
707 })
708
709 t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
710 knownProviders := []catwalk.Provider{
711 {
712 ID: catwalk.InferenceProviderBedrock,
713 APIKey: "",
714 APIEndpoint: "",
715 Models: []catwalk.Model{{
716 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
717 }},
718 },
719 }
720
721 cfg := &Config{
722 Providers: csync.NewMapFrom(map[string]provider.Config{
723 "bedrock": {
724 BaseURL: "custom-url",
725 },
726 }),
727 }
728 cfg.setDefaults("/tmp")
729
730 env := env.NewFromMap(map[string]string{})
731 resolver := resolver.NewEnvironmentVariableResolver(env)
732 err := cfg.configureProviders(env, resolver, knownProviders)
733 assert.NoError(t, err)
734
735 assert.Equal(t, cfg.Providers.Len(), 0)
736 _, exists := cfg.Providers.Get("bedrock")
737 assert.False(t, exists)
738 })
739
740 t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
741 knownProviders := []catwalk.Provider{
742 {
743 ID: "openai",
744 APIKey: "$MISSING_API_KEY",
745 APIEndpoint: "https://api.openai.com/v1",
746 Models: []catwalk.Model{{
747 ID: "test-model",
748 }},
749 },
750 }
751
752 cfg := &Config{
753 Providers: csync.NewMapFrom(map[string]provider.Config{
754 "openai": {
755 BaseURL: "custom-url",
756 },
757 }),
758 }
759 cfg.setDefaults("/tmp")
760
761 env := env.NewFromMap(map[string]string{})
762 resolver := resolver.NewEnvironmentVariableResolver(env)
763 err := cfg.configureProviders(env, resolver, knownProviders)
764 assert.NoError(t, err)
765
766 assert.Equal(t, cfg.Providers.Len(), 0)
767 _, exists := cfg.Providers.Get("openai")
768 assert.False(t, exists)
769 })
770
771 t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
772 knownProviders := []catwalk.Provider{
773 {
774 ID: "openai",
775 APIKey: "$OPENAI_API_KEY",
776 APIEndpoint: "$MISSING_ENDPOINT",
777 Models: []catwalk.Model{{
778 ID: "test-model",
779 }},
780 },
781 }
782
783 cfg := &Config{
784 Providers: csync.NewMapFrom(map[string]provider.Config{
785 "openai": {
786 APIKey: "test-key",
787 },
788 }),
789 }
790 cfg.setDefaults("/tmp")
791
792 env := env.NewFromMap(map[string]string{
793 "OPENAI_API_KEY": "test-key",
794 })
795 resolver := resolver.NewEnvironmentVariableResolver(env)
796 err := cfg.configureProviders(env, resolver, knownProviders)
797 assert.NoError(t, err)
798
799 assert.Equal(t, cfg.Providers.Len(), 1)
800 _, exists := cfg.Providers.Get("openai")
801 assert.True(t, exists)
802 })
803}
804
805func TestConfig_defaultModelSelection(t *testing.T) {
806 t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
807 knownProviders := []catwalk.Provider{
808 {
809 ID: "openai",
810 APIKey: "abc",
811 DefaultLargeModelID: "large-model",
812 DefaultSmallModelID: "small-model",
813 Models: []catwalk.Model{
814 {
815 ID: "large-model",
816 DefaultMaxTokens: 1000,
817 },
818 {
819 ID: "small-model",
820 DefaultMaxTokens: 500,
821 },
822 },
823 },
824 }
825
826 cfg := &Config{}
827 cfg.setDefaults("/tmp")
828 env := env.NewFromMap(map[string]string{})
829 resolver := resolver.NewEnvironmentVariableResolver(env)
830 err := cfg.configureProviders(env, resolver, knownProviders)
831 assert.NoError(t, err)
832
833 large, small, err := cfg.defaultModelSelection(knownProviders)
834 assert.NoError(t, err)
835 assert.Equal(t, "large-model", large.Model)
836 assert.Equal(t, "openai", large.Provider)
837 assert.Equal(t, int64(1000), large.MaxTokens)
838 assert.Equal(t, "small-model", small.Model)
839 assert.Equal(t, "openai", small.Provider)
840 assert.Equal(t, int64(500), small.MaxTokens)
841 })
842 t.Run("should error if no providers configured", func(t *testing.T) {
843 knownProviders := []catwalk.Provider{
844 {
845 ID: "openai",
846 APIKey: "$MISSING_KEY",
847 DefaultLargeModelID: "large-model",
848 DefaultSmallModelID: "small-model",
849 Models: []catwalk.Model{
850 {
851 ID: "large-model",
852 DefaultMaxTokens: 1000,
853 },
854 {
855 ID: "small-model",
856 DefaultMaxTokens: 500,
857 },
858 },
859 },
860 }
861
862 cfg := &Config{}
863 cfg.setDefaults("/tmp")
864 env := env.NewFromMap(map[string]string{})
865 resolver := resolver.NewEnvironmentVariableResolver(env)
866 err := cfg.configureProviders(env, resolver, knownProviders)
867 assert.NoError(t, err)
868
869 _, _, err = cfg.defaultModelSelection(knownProviders)
870 assert.Error(t, err)
871 })
872 t.Run("should error if model is missing", func(t *testing.T) {
873 knownProviders := []catwalk.Provider{
874 {
875 ID: "openai",
876 APIKey: "abc",
877 DefaultLargeModelID: "large-model",
878 DefaultSmallModelID: "small-model",
879 Models: []catwalk.Model{
880 {
881 ID: "not-large-model",
882 DefaultMaxTokens: 1000,
883 },
884 {
885 ID: "small-model",
886 DefaultMaxTokens: 500,
887 },
888 },
889 },
890 }
891
892 cfg := &Config{}
893 cfg.setDefaults("/tmp")
894 env := env.NewFromMap(map[string]string{})
895 resolver := resolver.NewEnvironmentVariableResolver(env)
896 err := cfg.configureProviders(env, resolver, knownProviders)
897 assert.NoError(t, err)
898 _, _, err = cfg.defaultModelSelection(knownProviders)
899 assert.Error(t, err)
900 })
901
902 t.Run("should configure the default models with a custom provider", func(t *testing.T) {
903 knownProviders := []catwalk.Provider{
904 {
905 ID: "openai",
906 APIKey: "$MISSING", // will not be included in the config
907 DefaultLargeModelID: "large-model",
908 DefaultSmallModelID: "small-model",
909 Models: []catwalk.Model{
910 {
911 ID: "not-large-model",
912 DefaultMaxTokens: 1000,
913 },
914 {
915 ID: "small-model",
916 DefaultMaxTokens: 500,
917 },
918 },
919 },
920 }
921
922 cfg := &Config{
923 Providers: csync.NewMapFrom(map[string]provider.Config{
924 "custom": {
925 APIKey: "test-key",
926 BaseURL: "https://api.custom.com/v1",
927 Models: []catwalk.Model{
928 {
929 ID: "model",
930 DefaultMaxTokens: 600,
931 },
932 },
933 },
934 }),
935 }
936 cfg.setDefaults("/tmp")
937 env := env.NewFromMap(map[string]string{})
938 resolver := resolver.NewEnvironmentVariableResolver(env)
939 err := cfg.configureProviders(env, resolver, knownProviders)
940 assert.NoError(t, err)
941 large, small, err := cfg.defaultModelSelection(knownProviders)
942 assert.NoError(t, err)
943 assert.Equal(t, "model", large.Model)
944 assert.Equal(t, "custom", large.Provider)
945 assert.Equal(t, int64(600), large.MaxTokens)
946 assert.Equal(t, "model", small.Model)
947 assert.Equal(t, "custom", small.Provider)
948 assert.Equal(t, int64(600), small.MaxTokens)
949 })
950
951 t.Run("should fail if no model configured", func(t *testing.T) {
952 knownProviders := []catwalk.Provider{
953 {
954 ID: "openai",
955 APIKey: "$MISSING", // will not be included in the config
956 DefaultLargeModelID: "large-model",
957 DefaultSmallModelID: "small-model",
958 Models: []catwalk.Model{
959 {
960 ID: "not-large-model",
961 DefaultMaxTokens: 1000,
962 },
963 {
964 ID: "small-model",
965 DefaultMaxTokens: 500,
966 },
967 },
968 },
969 }
970
971 cfg := &Config{
972 Providers: csync.NewMapFrom(map[string]provider.Config{
973 "custom": {
974 APIKey: "test-key",
975 BaseURL: "https://api.custom.com/v1",
976 Models: []catwalk.Model{},
977 },
978 }),
979 }
980 cfg.setDefaults("/tmp")
981 env := env.NewFromMap(map[string]string{})
982 resolver := resolver.NewEnvironmentVariableResolver(env)
983 err := cfg.configureProviders(env, resolver, knownProviders)
984 assert.NoError(t, err)
985 _, _, err = cfg.defaultModelSelection(knownProviders)
986 assert.Error(t, err)
987 })
988 t.Run("should use the default provider first", func(t *testing.T) {
989 knownProviders := []catwalk.Provider{
990 {
991 ID: "openai",
992 APIKey: "set",
993 DefaultLargeModelID: "large-model",
994 DefaultSmallModelID: "small-model",
995 Models: []catwalk.Model{
996 {
997 ID: "large-model",
998 DefaultMaxTokens: 1000,
999 },
1000 {
1001 ID: "small-model",
1002 DefaultMaxTokens: 500,
1003 },
1004 },
1005 },
1006 }
1007
1008 cfg := &Config{
1009 Providers: csync.NewMapFrom(map[string]provider.Config{
1010 "custom": {
1011 APIKey: "test-key",
1012 BaseURL: "https://api.custom.com/v1",
1013 Models: []catwalk.Model{
1014 {
1015 ID: "large-model",
1016 DefaultMaxTokens: 1000,
1017 },
1018 },
1019 },
1020 }),
1021 }
1022 cfg.setDefaults("/tmp")
1023 env := env.NewFromMap(map[string]string{})
1024 resolver := resolver.NewEnvironmentVariableResolver(env)
1025 err := cfg.configureProviders(env, resolver, knownProviders)
1026 assert.NoError(t, err)
1027 large, small, err := cfg.defaultModelSelection(knownProviders)
1028 assert.NoError(t, err)
1029 assert.Equal(t, "large-model", large.Model)
1030 assert.Equal(t, "openai", large.Provider)
1031 assert.Equal(t, int64(1000), large.MaxTokens)
1032 assert.Equal(t, "small-model", small.Model)
1033 assert.Equal(t, "openai", small.Provider)
1034 assert.Equal(t, int64(500), small.MaxTokens)
1035 })
1036}
1037
1038func TestConfig_configureSelectedModels(t *testing.T) {
1039 t.Run("should override defaults", func(t *testing.T) {
1040 knownProviders := []catwalk.Provider{
1041 {
1042 ID: "openai",
1043 APIKey: "abc",
1044 DefaultLargeModelID: "large-model",
1045 DefaultSmallModelID: "small-model",
1046 Models: []catwalk.Model{
1047 {
1048 ID: "larger-model",
1049 DefaultMaxTokens: 2000,
1050 },
1051 {
1052 ID: "large-model",
1053 DefaultMaxTokens: 1000,
1054 },
1055 {
1056 ID: "small-model",
1057 DefaultMaxTokens: 500,
1058 },
1059 },
1060 },
1061 }
1062
1063 cfg := &Config{
1064 Models: map[SelectedModelType]agent.Model{
1065 "large": {
1066 Model: "larger-model",
1067 },
1068 },
1069 }
1070 cfg.setDefaults("/tmp")
1071 env := env.NewFromMap(map[string]string{})
1072 resolver := resolver.NewEnvironmentVariableResolver(env)
1073 err := cfg.configureProviders(env, resolver, knownProviders)
1074 assert.NoError(t, err)
1075
1076 err = cfg.configureSelectedModels(knownProviders)
1077 assert.NoError(t, err)
1078 large := cfg.Models[SelectedModelTypeLarge]
1079 small := cfg.Models[SelectedModelTypeSmall]
1080 assert.Equal(t, "larger-model", large.Model)
1081 assert.Equal(t, "openai", large.Provider)
1082 assert.Equal(t, int64(2000), large.MaxTokens)
1083 assert.Equal(t, "small-model", small.Model)
1084 assert.Equal(t, "openai", small.Provider)
1085 assert.Equal(t, int64(500), small.MaxTokens)
1086 })
1087 t.Run("should be possible to use multiple providers", func(t *testing.T) {
1088 knownProviders := []catwalk.Provider{
1089 {
1090 ID: "openai",
1091 APIKey: "abc",
1092 DefaultLargeModelID: "large-model",
1093 DefaultSmallModelID: "small-model",
1094 Models: []catwalk.Model{
1095 {
1096 ID: "large-model",
1097 DefaultMaxTokens: 1000,
1098 },
1099 {
1100 ID: "small-model",
1101 DefaultMaxTokens: 500,
1102 },
1103 },
1104 },
1105 {
1106 ID: "anthropic",
1107 APIKey: "abc",
1108 DefaultLargeModelID: "a-large-model",
1109 DefaultSmallModelID: "a-small-model",
1110 Models: []catwalk.Model{
1111 {
1112 ID: "a-large-model",
1113 DefaultMaxTokens: 1000,
1114 },
1115 {
1116 ID: "a-small-model",
1117 DefaultMaxTokens: 200,
1118 },
1119 },
1120 },
1121 }
1122
1123 cfg := &Config{
1124 Models: map[SelectedModelType]agent.Model{
1125 "small": {
1126 Model: "a-small-model",
1127 Provider: "anthropic",
1128 MaxTokens: 300,
1129 },
1130 },
1131 }
1132 cfg.setDefaults("/tmp")
1133 env := env.NewFromMap(map[string]string{})
1134 resolver := resolver.NewEnvironmentVariableResolver(env)
1135 err := cfg.configureProviders(env, resolver, knownProviders)
1136 assert.NoError(t, err)
1137
1138 err = cfg.configureSelectedModels(knownProviders)
1139 assert.NoError(t, err)
1140 large := cfg.Models[SelectedModelTypeLarge]
1141 small := cfg.Models[SelectedModelTypeSmall]
1142 assert.Equal(t, "large-model", large.Model)
1143 assert.Equal(t, "openai", large.Provider)
1144 assert.Equal(t, int64(1000), large.MaxTokens)
1145 assert.Equal(t, "a-small-model", small.Model)
1146 assert.Equal(t, "anthropic", small.Provider)
1147 assert.Equal(t, int64(300), small.MaxTokens)
1148 })
1149
1150 t.Run("should override the max tokens only", func(t *testing.T) {
1151 knownProviders := []catwalk.Provider{
1152 {
1153 ID: "openai",
1154 APIKey: "abc",
1155 DefaultLargeModelID: "large-model",
1156 DefaultSmallModelID: "small-model",
1157 Models: []catwalk.Model{
1158 {
1159 ID: "large-model",
1160 DefaultMaxTokens: 1000,
1161 },
1162 {
1163 ID: "small-model",
1164 DefaultMaxTokens: 500,
1165 },
1166 },
1167 },
1168 }
1169
1170 cfg := &Config{
1171 Models: map[SelectedModelType]agent.Model{
1172 "large": {
1173 MaxTokens: 100,
1174 },
1175 },
1176 }
1177 cfg.setDefaults("/tmp")
1178 env := env.NewFromMap(map[string]string{})
1179 resolver := resolver.NewEnvironmentVariableResolver(env)
1180 err := cfg.configureProviders(env, resolver, knownProviders)
1181 assert.NoError(t, err)
1182
1183 err = cfg.configureSelectedModels(knownProviders)
1184 assert.NoError(t, err)
1185 large := cfg.Models[SelectedModelTypeLarge]
1186 assert.Equal(t, "large-model", large.Model)
1187 assert.Equal(t, "openai", large.Provider)
1188 assert.Equal(t, int64(100), large.MaxTokens)
1189 })
1190}