1package config
2
3import (
4 "io"
5 "log/slog"
6 "os"
7 "path/filepath"
8 "testing"
9
10 "charm.land/catwalk/pkg/catwalk"
11 "github.com/charmbracelet/crush/internal/csync"
12 "github.com/charmbracelet/crush/internal/env"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15)
16
// TestMain replaces the default slog logger with one that writes to
// io.Discard, runs the package's tests, and exits with the code they return.
func TestMain(m *testing.M) {
	quiet := slog.NewTextHandler(io.Discard, nil)
	slog.SetDefault(slog.New(quiet))

	os.Exit(m.Run())
}
23
24func TestConfig_LoadFromBytes(t *testing.T) {
25 data1 := []byte(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
26 data2 := []byte(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
27 data3 := []byte(`{"providers": {"openai": {}}}`)
28
29 loadedConfig, err := loadFromBytes([][]byte{data1, data2, data3})
30
31 require.NoError(t, err)
32 require.NotNil(t, loadedConfig)
33 require.Equal(t, 1, loadedConfig.Providers.Len())
34 pc, _ := loadedConfig.Providers.Get("openai")
35 require.Equal(t, "key2", pc.APIKey)
36 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
37}
38
39// testStore wraps a Config in a minimal ConfigStore for testing.
40func testStore(cfg *Config) *ConfigStore {
41 return &ConfigStore{config: cfg}
42}
43
44func TestConfig_setDefaults(t *testing.T) {
45 cfg := &Config{}
46
47 cfg.setDefaults("/tmp", "")
48
49 require.NotNil(t, cfg.Options)
50 require.NotNil(t, cfg.Options.TUI)
51 require.NotNil(t, cfg.Options.ContextPaths)
52 require.NotNil(t, cfg.Providers)
53 require.NotNil(t, cfg.Models)
54 require.NotNil(t, cfg.LSP)
55 require.NotNil(t, cfg.MCP)
56 require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
57 require.Equal(t, "AGENTS.md", cfg.Options.InitializeAs)
58 for _, path := range defaultContextPaths {
59 require.Contains(t, cfg.Options.ContextPaths, path)
60 }
61}
62
63func TestConfig_configureProviders(t *testing.T) {
64 knownProviders := []catwalk.Provider{
65 {
66 ID: "openai",
67 APIKey: "$OPENAI_API_KEY",
68 APIEndpoint: "https://api.openai.com/v1",
69 Models: []catwalk.Model{{
70 ID: "test-model",
71 }},
72 },
73 }
74
75 cfg := &Config{}
76 cfg.setDefaults("/tmp", "")
77 env := env.NewFromMap(map[string]string{
78 "OPENAI_API_KEY": "test-key",
79 })
80 resolver := NewEnvironmentVariableResolver(env)
81 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
82 require.NoError(t, err)
83 require.Equal(t, 1, cfg.Providers.Len())
84
85 // We want to make sure that we keep the configured API key as a placeholder
86 pc, _ := cfg.Providers.Get("openai")
87 require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
88}
89
90func TestConfig_configureProvidersWithOverride(t *testing.T) {
91 knownProviders := []catwalk.Provider{
92 {
93 ID: "openai",
94 APIKey: "$OPENAI_API_KEY",
95 APIEndpoint: "https://api.openai.com/v1",
96 Models: []catwalk.Model{{
97 ID: "test-model",
98 }},
99 },
100 }
101
102 cfg := &Config{
103 Providers: csync.NewMap[string, ProviderConfig](),
104 }
105 cfg.Providers.Set("openai", ProviderConfig{
106 APIKey: "xyz",
107 BaseURL: "https://api.openai.com/v2",
108 Models: []catwalk.Model{
109 {
110 ID: "test-model",
111 Name: "Updated",
112 },
113 {
114 ID: "another-model",
115 },
116 },
117 })
118 cfg.setDefaults("/tmp", "")
119
120 env := env.NewFromMap(map[string]string{
121 "OPENAI_API_KEY": "test-key",
122 })
123 resolver := NewEnvironmentVariableResolver(env)
124 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
125 require.NoError(t, err)
126 require.Equal(t, 1, cfg.Providers.Len())
127
128 // We want to make sure that we keep the configured API key as a placeholder
129 pc, _ := cfg.Providers.Get("openai")
130 require.Equal(t, "xyz", pc.APIKey)
131 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
132 require.Len(t, pc.Models, 2)
133 require.Equal(t, "Updated", pc.Models[0].Name)
134}
135
136func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
137 knownProviders := []catwalk.Provider{
138 {
139 ID: "openai",
140 APIKey: "$OPENAI_API_KEY",
141 APIEndpoint: "https://api.openai.com/v1",
142 Models: []catwalk.Model{{
143 ID: "test-model",
144 }},
145 },
146 }
147
148 cfg := &Config{
149 Providers: csync.NewMapFrom(map[string]ProviderConfig{
150 "custom": {
151 APIKey: "xyz",
152 BaseURL: "https://api.someendpoint.com/v2",
153 Models: []catwalk.Model{
154 {
155 ID: "test-model",
156 },
157 },
158 },
159 }),
160 }
161 cfg.setDefaults("/tmp", "")
162 env := env.NewFromMap(map[string]string{
163 "OPENAI_API_KEY": "test-key",
164 })
165 resolver := NewEnvironmentVariableResolver(env)
166 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
167 require.NoError(t, err)
168 // Should be to because of the env variable
169 require.Equal(t, cfg.Providers.Len(), 2)
170
171 // We want to make sure that we keep the configured API key as a placeholder
172 pc, _ := cfg.Providers.Get("custom")
173 require.Equal(t, "xyz", pc.APIKey)
174 // Make sure we set the ID correctly
175 require.Equal(t, "custom", pc.ID)
176 require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
177 require.Len(t, pc.Models, 1)
178
179 _, ok := cfg.Providers.Get("openai")
180 require.True(t, ok, "OpenAI provider should still be present")
181}
182
183func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
184 knownProviders := []catwalk.Provider{
185 {
186 ID: catwalk.InferenceProviderBedrock,
187 APIKey: "",
188 APIEndpoint: "",
189 Models: []catwalk.Model{{
190 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
191 }},
192 },
193 }
194
195 cfg := &Config{}
196 cfg.setDefaults("/tmp", "")
197 env := env.NewFromMap(map[string]string{
198 "AWS_ACCESS_KEY_ID": "test-key-id",
199 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
200 })
201 resolver := NewEnvironmentVariableResolver(env)
202 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
203 require.NoError(t, err)
204 require.Equal(t, cfg.Providers.Len(), 1)
205
206 bedrockProvider, ok := cfg.Providers.Get("bedrock")
207 require.True(t, ok, "Bedrock provider should be present")
208 require.Len(t, bedrockProvider.Models, 1)
209 require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
210}
211
212func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
213 knownProviders := []catwalk.Provider{
214 {
215 ID: catwalk.InferenceProviderBedrock,
216 APIKey: "",
217 APIEndpoint: "",
218 Models: []catwalk.Model{{
219 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
220 }},
221 },
222 }
223
224 cfg := &Config{}
225 cfg.setDefaults("/tmp", "")
226 env := env.NewFromMap(map[string]string{})
227 resolver := NewEnvironmentVariableResolver(env)
228 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
229 require.NoError(t, err)
230 // Provider should not be configured without credentials
231 require.Equal(t, cfg.Providers.Len(), 0)
232}
233
234func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
235 knownProviders := []catwalk.Provider{
236 {
237 ID: catwalk.InferenceProviderBedrock,
238 APIKey: "",
239 APIEndpoint: "",
240 Models: []catwalk.Model{{
241 ID: "some-random-model",
242 }},
243 },
244 }
245
246 cfg := &Config{}
247 cfg.setDefaults("/tmp", "")
248 env := env.NewFromMap(map[string]string{
249 "AWS_ACCESS_KEY_ID": "test-key-id",
250 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
251 })
252 resolver := NewEnvironmentVariableResolver(env)
253 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
254 require.Error(t, err)
255}
256
257func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
258 knownProviders := []catwalk.Provider{
259 {
260 ID: catwalk.InferenceProviderVertexAI,
261 APIKey: "",
262 APIEndpoint: "",
263 Models: []catwalk.Model{{
264 ID: "gemini-pro",
265 }},
266 },
267 }
268
269 cfg := &Config{}
270 cfg.setDefaults("/tmp", "")
271 env := env.NewFromMap(map[string]string{
272 "VERTEXAI_PROJECT": "test-project",
273 "VERTEXAI_LOCATION": "us-central1",
274 })
275 resolver := NewEnvironmentVariableResolver(env)
276 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
277 require.NoError(t, err)
278 require.Equal(t, cfg.Providers.Len(), 1)
279
280 vertexProvider, ok := cfg.Providers.Get("vertexai")
281 require.True(t, ok, "VertexAI provider should be present")
282 require.Len(t, vertexProvider.Models, 1)
283 require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
284 require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
285 require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
286}
287
288func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
289 knownProviders := []catwalk.Provider{
290 {
291 ID: catwalk.InferenceProviderVertexAI,
292 APIKey: "",
293 APIEndpoint: "",
294 Models: []catwalk.Model{{
295 ID: "gemini-pro",
296 }},
297 },
298 }
299
300 cfg := &Config{}
301 cfg.setDefaults("/tmp", "")
302 env := env.NewFromMap(map[string]string{
303 "GOOGLE_GENAI_USE_VERTEXAI": "false",
304 "GOOGLE_CLOUD_PROJECT": "test-project",
305 "GOOGLE_CLOUD_LOCATION": "us-central1",
306 })
307 resolver := NewEnvironmentVariableResolver(env)
308 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
309 require.NoError(t, err)
310 // Provider should not be configured without proper credentials
311 require.Equal(t, cfg.Providers.Len(), 0)
312}
313
314func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
315 knownProviders := []catwalk.Provider{
316 {
317 ID: catwalk.InferenceProviderVertexAI,
318 APIKey: "",
319 APIEndpoint: "",
320 Models: []catwalk.Model{{
321 ID: "gemini-pro",
322 }},
323 },
324 }
325
326 cfg := &Config{}
327 cfg.setDefaults("/tmp", "")
328 env := env.NewFromMap(map[string]string{
329 "GOOGLE_GENAI_USE_VERTEXAI": "true",
330 "GOOGLE_CLOUD_LOCATION": "us-central1",
331 })
332 resolver := NewEnvironmentVariableResolver(env)
333 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
334 require.NoError(t, err)
335 // Provider should not be configured without project
336 require.Equal(t, cfg.Providers.Len(), 0)
337}
338
339func TestConfig_configureProvidersSetProviderID(t *testing.T) {
340 knownProviders := []catwalk.Provider{
341 {
342 ID: "openai",
343 APIKey: "$OPENAI_API_KEY",
344 APIEndpoint: "https://api.openai.com/v1",
345 Models: []catwalk.Model{{
346 ID: "test-model",
347 }},
348 },
349 }
350
351 cfg := &Config{}
352 cfg.setDefaults("/tmp", "")
353 env := env.NewFromMap(map[string]string{
354 "OPENAI_API_KEY": "test-key",
355 })
356 resolver := NewEnvironmentVariableResolver(env)
357 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
358 require.NoError(t, err)
359 require.Equal(t, cfg.Providers.Len(), 1)
360
361 // Provider ID should be set
362 pc, _ := cfg.Providers.Get("openai")
363 require.Equal(t, "openai", pc.ID)
364}
365
366func TestConfig_EnabledProviders(t *testing.T) {
367 t.Run("all providers enabled", func(t *testing.T) {
368 cfg := &Config{
369 Providers: csync.NewMapFrom(map[string]ProviderConfig{
370 "openai": {
371 ID: "openai",
372 APIKey: "key1",
373 Disable: false,
374 },
375 "anthropic": {
376 ID: "anthropic",
377 APIKey: "key2",
378 Disable: false,
379 },
380 }),
381 }
382
383 enabled := cfg.EnabledProviders()
384 require.Len(t, enabled, 2)
385 })
386
387 t.Run("some providers disabled", func(t *testing.T) {
388 cfg := &Config{
389 Providers: csync.NewMapFrom(map[string]ProviderConfig{
390 "openai": {
391 ID: "openai",
392 APIKey: "key1",
393 Disable: false,
394 },
395 "anthropic": {
396 ID: "anthropic",
397 APIKey: "key2",
398 Disable: true,
399 },
400 }),
401 }
402
403 enabled := cfg.EnabledProviders()
404 require.Len(t, enabled, 1)
405 require.Equal(t, "openai", enabled[0].ID)
406 })
407
408 t.Run("empty providers map", func(t *testing.T) {
409 cfg := &Config{
410 Providers: csync.NewMap[string, ProviderConfig](),
411 }
412
413 enabled := cfg.EnabledProviders()
414 require.Len(t, enabled, 0)
415 })
416}
417
418func TestConfig_IsConfigured(t *testing.T) {
419 t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
420 cfg := &Config{
421 Providers: csync.NewMapFrom(map[string]ProviderConfig{
422 "openai": {
423 ID: "openai",
424 APIKey: "key1",
425 Disable: false,
426 },
427 }),
428 }
429
430 require.True(t, cfg.IsConfigured())
431 })
432
433 t.Run("returns false when no providers are configured", func(t *testing.T) {
434 cfg := &Config{
435 Providers: csync.NewMap[string, ProviderConfig](),
436 }
437
438 require.False(t, cfg.IsConfigured())
439 })
440
441 t.Run("returns false when all providers are disabled", func(t *testing.T) {
442 cfg := &Config{
443 Providers: csync.NewMapFrom(map[string]ProviderConfig{
444 "openai": {
445 ID: "openai",
446 APIKey: "key1",
447 Disable: true,
448 },
449 "anthropic": {
450 ID: "anthropic",
451 APIKey: "key2",
452 Disable: true,
453 },
454 }),
455 }
456
457 require.False(t, cfg.IsConfigured())
458 })
459}
460
461func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
462 cfg := &Config{
463 Options: &Options{
464 DisabledTools: []string{},
465 },
466 }
467
468 cfg.SetupAgents()
469 coderAgent, ok := cfg.Agents[AgentCoder]
470 require.True(t, ok)
471 assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
472
473 taskAgent, ok := cfg.Agents[AgentTask]
474 require.True(t, ok)
475 assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
476}
477
478func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
479 cfg := &Config{
480 Options: &Options{
481 DisabledTools: []string{
482 "edit",
483 "download",
484 "grep",
485 },
486 },
487 }
488
489 cfg.SetupAgents()
490 coderAgent, ok := cfg.Agents[AgentCoder]
491 require.True(t, ok)
492
493 assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "multiedit", "lsp_diagnostics", "lsp_references", "lsp_restart", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "todos", "view", "write", "list_mcp_resources", "read_mcp_resource"}, coderAgent.AllowedTools)
494
495 taskAgent, ok := cfg.Agents[AgentTask]
496 require.True(t, ok)
497 assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
498}
499
500func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
501 cfg := &Config{
502 Options: &Options{
503 DisabledTools: []string{
504 "glob",
505 "grep",
506 "ls",
507 "sourcegraph",
508 "view",
509 },
510 },
511 }
512
513 cfg.SetupAgents()
514 coderAgent, ok := cfg.Agents[AgentCoder]
515 require.True(t, ok)
516 assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "lsp_restart", "fetch", "agentic_fetch", "todos", "write", "list_mcp_resources", "read_mcp_resource"}, coderAgent.AllowedTools)
517
518 taskAgent, ok := cfg.Agents[AgentTask]
519 require.True(t, ok)
520 assert.Len(t, taskAgent.AllowedTools, 0)
521}
522
523func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
524 knownProviders := []catwalk.Provider{
525 {
526 ID: "openai",
527 APIKey: "$OPENAI_API_KEY",
528 APIEndpoint: "https://api.openai.com/v1",
529 Models: []catwalk.Model{{
530 ID: "test-model",
531 }},
532 },
533 }
534
535 cfg := &Config{
536 Providers: csync.NewMapFrom(map[string]ProviderConfig{
537 "openai": {
538 Disable: true,
539 },
540 }),
541 }
542 cfg.setDefaults("/tmp", "")
543
544 env := env.NewFromMap(map[string]string{
545 "OPENAI_API_KEY": "test-key",
546 })
547 resolver := NewEnvironmentVariableResolver(env)
548 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
549 require.NoError(t, err)
550
551 require.Equal(t, cfg.Providers.Len(), 1)
552 prov, exists := cfg.Providers.Get("openai")
553 require.True(t, exists)
554 require.True(t, prov.Disable)
555}
556
557func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
558 t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
559 cfg := &Config{
560 Providers: csync.NewMapFrom(map[string]ProviderConfig{
561 "custom": {
562 BaseURL: "https://api.custom.com/v1",
563 Models: []catwalk.Model{{
564 ID: "test-model",
565 }},
566 },
567 "openai": {
568 APIKey: "$MISSING",
569 },
570 }),
571 }
572 cfg.setDefaults("/tmp", "")
573
574 env := env.NewFromMap(map[string]string{})
575 resolver := NewEnvironmentVariableResolver(env)
576 err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
577 require.NoError(t, err)
578
579 require.Equal(t, cfg.Providers.Len(), 1)
580 _, exists := cfg.Providers.Get("custom")
581 require.True(t, exists)
582 })
583
584 t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
585 cfg := &Config{
586 Providers: csync.NewMapFrom(map[string]ProviderConfig{
587 "custom": {
588 APIKey: "test-key",
589 Models: []catwalk.Model{{
590 ID: "test-model",
591 }},
592 },
593 }),
594 }
595 cfg.setDefaults("/tmp", "")
596
597 env := env.NewFromMap(map[string]string{})
598 resolver := NewEnvironmentVariableResolver(env)
599 err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
600 require.NoError(t, err)
601
602 require.Equal(t, cfg.Providers.Len(), 0)
603 _, exists := cfg.Providers.Get("custom")
604 require.False(t, exists)
605 })
606
607 t.Run("custom provider with no models is removed", func(t *testing.T) {
608 cfg := &Config{
609 Providers: csync.NewMapFrom(map[string]ProviderConfig{
610 "custom": {
611 APIKey: "test-key",
612 BaseURL: "https://api.custom.com/v1",
613 Models: []catwalk.Model{},
614 },
615 }),
616 }
617 cfg.setDefaults("/tmp", "")
618
619 env := env.NewFromMap(map[string]string{})
620 resolver := NewEnvironmentVariableResolver(env)
621 err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
622 require.NoError(t, err)
623
624 require.Equal(t, cfg.Providers.Len(), 0)
625 _, exists := cfg.Providers.Get("custom")
626 require.False(t, exists)
627 })
628
629 t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
630 cfg := &Config{
631 Providers: csync.NewMapFrom(map[string]ProviderConfig{
632 "custom": {
633 APIKey: "test-key",
634 BaseURL: "https://api.custom.com/v1",
635 Type: "unsupported",
636 Models: []catwalk.Model{{
637 ID: "test-model",
638 }},
639 },
640 }),
641 }
642 cfg.setDefaults("/tmp", "")
643
644 env := env.NewFromMap(map[string]string{})
645 resolver := NewEnvironmentVariableResolver(env)
646 err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
647 require.NoError(t, err)
648
649 require.Equal(t, cfg.Providers.Len(), 0)
650 _, exists := cfg.Providers.Get("custom")
651 require.False(t, exists)
652 })
653
654 t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
655 cfg := &Config{
656 Providers: csync.NewMapFrom(map[string]ProviderConfig{
657 "custom": {
658 APIKey: "test-key",
659 BaseURL: "https://api.custom.com/v1",
660 Type: catwalk.TypeOpenAI,
661 Models: []catwalk.Model{{
662 ID: "test-model",
663 }},
664 },
665 }),
666 }
667 cfg.setDefaults("/tmp", "")
668
669 env := env.NewFromMap(map[string]string{})
670 resolver := NewEnvironmentVariableResolver(env)
671 err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
672 require.NoError(t, err)
673
674 require.Equal(t, cfg.Providers.Len(), 1)
675 customProvider, exists := cfg.Providers.Get("custom")
676 require.True(t, exists)
677 require.Equal(t, "custom", customProvider.ID)
678 require.Equal(t, "test-key", customProvider.APIKey)
679 require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
680 })
681
682 t.Run("custom anthropic provider is supported", func(t *testing.T) {
683 cfg := &Config{
684 Providers: csync.NewMapFrom(map[string]ProviderConfig{
685 "custom-anthropic": {
686 APIKey: "test-key",
687 BaseURL: "https://api.anthropic.com/v1",
688 Type: catwalk.TypeAnthropic,
689 Models: []catwalk.Model{{
690 ID: "claude-3-sonnet",
691 }},
692 },
693 }),
694 }
695 cfg.setDefaults("/tmp", "")
696
697 env := env.NewFromMap(map[string]string{})
698 resolver := NewEnvironmentVariableResolver(env)
699 err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
700 require.NoError(t, err)
701
702 require.Equal(t, cfg.Providers.Len(), 1)
703 customProvider, exists := cfg.Providers.Get("custom-anthropic")
704 require.True(t, exists)
705 require.Equal(t, "custom-anthropic", customProvider.ID)
706 require.Equal(t, "test-key", customProvider.APIKey)
707 require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
708 require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
709 })
710
711 t.Run("disabled custom provider is removed", func(t *testing.T) {
712 cfg := &Config{
713 Providers: csync.NewMapFrom(map[string]ProviderConfig{
714 "custom": {
715 APIKey: "test-key",
716 BaseURL: "https://api.custom.com/v1",
717 Type: catwalk.TypeOpenAI,
718 Disable: true,
719 Models: []catwalk.Model{{
720 ID: "test-model",
721 }},
722 },
723 }),
724 }
725 cfg.setDefaults("/tmp", "")
726
727 env := env.NewFromMap(map[string]string{})
728 resolver := NewEnvironmentVariableResolver(env)
729 err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
730 require.NoError(t, err)
731
732 require.Equal(t, cfg.Providers.Len(), 0)
733 _, exists := cfg.Providers.Get("custom")
734 require.False(t, exists)
735 })
736}
737
738func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
739 t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
740 knownProviders := []catwalk.Provider{
741 {
742 ID: catwalk.InferenceProviderVertexAI,
743 APIKey: "",
744 APIEndpoint: "",
745 Models: []catwalk.Model{{
746 ID: "gemini-pro",
747 }},
748 },
749 }
750
751 cfg := &Config{
752 Providers: csync.NewMapFrom(map[string]ProviderConfig{
753 "vertexai": {
754 BaseURL: "custom-url",
755 },
756 }),
757 }
758 cfg.setDefaults("/tmp", "")
759
760 env := env.NewFromMap(map[string]string{
761 "GOOGLE_GENAI_USE_VERTEXAI": "false",
762 })
763 resolver := NewEnvironmentVariableResolver(env)
764 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
765 require.NoError(t, err)
766
767 require.Equal(t, cfg.Providers.Len(), 0)
768 _, exists := cfg.Providers.Get("vertexai")
769 require.False(t, exists)
770 })
771
772 t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
773 knownProviders := []catwalk.Provider{
774 {
775 ID: catwalk.InferenceProviderBedrock,
776 APIKey: "",
777 APIEndpoint: "",
778 Models: []catwalk.Model{{
779 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
780 }},
781 },
782 }
783
784 cfg := &Config{
785 Providers: csync.NewMapFrom(map[string]ProviderConfig{
786 "bedrock": {
787 BaseURL: "custom-url",
788 },
789 }),
790 }
791 cfg.setDefaults("/tmp", "")
792
793 env := env.NewFromMap(map[string]string{})
794 resolver := NewEnvironmentVariableResolver(env)
795 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
796 require.NoError(t, err)
797
798 require.Equal(t, cfg.Providers.Len(), 0)
799 _, exists := cfg.Providers.Get("bedrock")
800 require.False(t, exists)
801 })
802
803 t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
804 knownProviders := []catwalk.Provider{
805 {
806 ID: "openai",
807 APIKey: "$MISSING_API_KEY",
808 APIEndpoint: "https://api.openai.com/v1",
809 Models: []catwalk.Model{{
810 ID: "test-model",
811 }},
812 },
813 }
814
815 cfg := &Config{
816 Providers: csync.NewMapFrom(map[string]ProviderConfig{
817 "openai": {
818 BaseURL: "custom-url",
819 },
820 }),
821 }
822 cfg.setDefaults("/tmp", "")
823
824 env := env.NewFromMap(map[string]string{})
825 resolver := NewEnvironmentVariableResolver(env)
826 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
827 require.NoError(t, err)
828
829 require.Equal(t, cfg.Providers.Len(), 0)
830 _, exists := cfg.Providers.Get("openai")
831 require.False(t, exists)
832 })
833
834 t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
835 knownProviders := []catwalk.Provider{
836 {
837 ID: "openai",
838 APIKey: "$OPENAI_API_KEY",
839 APIEndpoint: "$MISSING_ENDPOINT",
840 Models: []catwalk.Model{{
841 ID: "test-model",
842 }},
843 },
844 }
845
846 cfg := &Config{
847 Providers: csync.NewMapFrom(map[string]ProviderConfig{
848 "openai": {
849 APIKey: "test-key",
850 },
851 }),
852 }
853 cfg.setDefaults("/tmp", "")
854
855 env := env.NewFromMap(map[string]string{
856 "OPENAI_API_KEY": "test-key",
857 })
858 resolver := NewEnvironmentVariableResolver(env)
859 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
860 require.NoError(t, err)
861
862 require.Equal(t, cfg.Providers.Len(), 1)
863 _, exists := cfg.Providers.Get("openai")
864 require.True(t, exists)
865 })
866}
867
868func TestConfig_defaultModelSelection(t *testing.T) {
869 t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
870 knownProviders := []catwalk.Provider{
871 {
872 ID: "openai",
873 APIKey: "abc",
874 DefaultLargeModelID: "large-model",
875 DefaultSmallModelID: "small-model",
876 Models: []catwalk.Model{
877 {
878 ID: "large-model",
879 DefaultMaxTokens: 1000,
880 },
881 {
882 ID: "small-model",
883 DefaultMaxTokens: 500,
884 },
885 },
886 },
887 }
888
889 cfg := &Config{}
890 cfg.setDefaults("/tmp", "")
891 env := env.NewFromMap(map[string]string{})
892 resolver := NewEnvironmentVariableResolver(env)
893 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
894 require.NoError(t, err)
895
896 large, small, err := cfg.defaultModelSelection(knownProviders)
897 require.NoError(t, err)
898 require.Equal(t, "large-model", large.Model)
899 require.Equal(t, "openai", large.Provider)
900 require.Equal(t, int64(1000), large.MaxTokens)
901 require.Equal(t, "small-model", small.Model)
902 require.Equal(t, "openai", small.Provider)
903 require.Equal(t, int64(500), small.MaxTokens)
904 })
905 t.Run("should error if no providers configured", func(t *testing.T) {
906 knownProviders := []catwalk.Provider{
907 {
908 ID: "openai",
909 APIKey: "$MISSING_KEY",
910 DefaultLargeModelID: "large-model",
911 DefaultSmallModelID: "small-model",
912 Models: []catwalk.Model{
913 {
914 ID: "large-model",
915 DefaultMaxTokens: 1000,
916 },
917 {
918 ID: "small-model",
919 DefaultMaxTokens: 500,
920 },
921 },
922 },
923 }
924
925 cfg := &Config{}
926 cfg.setDefaults("/tmp", "")
927 env := env.NewFromMap(map[string]string{})
928 resolver := NewEnvironmentVariableResolver(env)
929 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
930 require.NoError(t, err)
931
932 _, _, err = cfg.defaultModelSelection(knownProviders)
933 require.Error(t, err)
934 })
935 t.Run("should error if model is missing", func(t *testing.T) {
936 knownProviders := []catwalk.Provider{
937 {
938 ID: "openai",
939 APIKey: "abc",
940 DefaultLargeModelID: "large-model",
941 DefaultSmallModelID: "small-model",
942 Models: []catwalk.Model{
943 {
944 ID: "not-large-model",
945 DefaultMaxTokens: 1000,
946 },
947 {
948 ID: "small-model",
949 DefaultMaxTokens: 500,
950 },
951 },
952 },
953 }
954
955 cfg := &Config{}
956 cfg.setDefaults("/tmp", "")
957 env := env.NewFromMap(map[string]string{})
958 resolver := NewEnvironmentVariableResolver(env)
959 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
960 require.NoError(t, err)
961 _, _, err = cfg.defaultModelSelection(knownProviders)
962 require.Error(t, err)
963 })
964
965 t.Run("should configure the default models with a custom provider", func(t *testing.T) {
966 knownProviders := []catwalk.Provider{
967 {
968 ID: "openai",
969 APIKey: "$MISSING", // will not be included in the config
970 DefaultLargeModelID: "large-model",
971 DefaultSmallModelID: "small-model",
972 Models: []catwalk.Model{
973 {
974 ID: "not-large-model",
975 DefaultMaxTokens: 1000,
976 },
977 {
978 ID: "small-model",
979 DefaultMaxTokens: 500,
980 },
981 },
982 },
983 }
984
985 cfg := &Config{
986 Providers: csync.NewMapFrom(map[string]ProviderConfig{
987 "custom": {
988 APIKey: "test-key",
989 BaseURL: "https://api.custom.com/v1",
990 Models: []catwalk.Model{
991 {
992 ID: "model",
993 DefaultMaxTokens: 600,
994 },
995 },
996 },
997 }),
998 }
999 cfg.setDefaults("/tmp", "")
1000 env := env.NewFromMap(map[string]string{})
1001 resolver := NewEnvironmentVariableResolver(env)
1002 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
1003 require.NoError(t, err)
1004 large, small, err := cfg.defaultModelSelection(knownProviders)
1005 require.NoError(t, err)
1006 require.Equal(t, "model", large.Model)
1007 require.Equal(t, "custom", large.Provider)
1008 require.Equal(t, int64(600), large.MaxTokens)
1009 require.Equal(t, "model", small.Model)
1010 require.Equal(t, "custom", small.Provider)
1011 require.Equal(t, int64(600), small.MaxTokens)
1012 })
1013
1014 t.Run("should fail if no model configured", func(t *testing.T) {
1015 knownProviders := []catwalk.Provider{
1016 {
1017 ID: "openai",
1018 APIKey: "$MISSING", // will not be included in the config
1019 DefaultLargeModelID: "large-model",
1020 DefaultSmallModelID: "small-model",
1021 Models: []catwalk.Model{
1022 {
1023 ID: "not-large-model",
1024 DefaultMaxTokens: 1000,
1025 },
1026 {
1027 ID: "small-model",
1028 DefaultMaxTokens: 500,
1029 },
1030 },
1031 },
1032 }
1033
1034 cfg := &Config{
1035 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1036 "custom": {
1037 APIKey: "test-key",
1038 BaseURL: "https://api.custom.com/v1",
1039 Models: []catwalk.Model{},
1040 },
1041 }),
1042 }
1043 cfg.setDefaults("/tmp", "")
1044 env := env.NewFromMap(map[string]string{})
1045 resolver := NewEnvironmentVariableResolver(env)
1046 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
1047 require.NoError(t, err)
1048 _, _, err = cfg.defaultModelSelection(knownProviders)
1049 require.Error(t, err)
1050 })
1051 t.Run("should use the default provider first", func(t *testing.T) {
1052 knownProviders := []catwalk.Provider{
1053 {
1054 ID: "openai",
1055 APIKey: "set",
1056 DefaultLargeModelID: "large-model",
1057 DefaultSmallModelID: "small-model",
1058 Models: []catwalk.Model{
1059 {
1060 ID: "large-model",
1061 DefaultMaxTokens: 1000,
1062 },
1063 {
1064 ID: "small-model",
1065 DefaultMaxTokens: 500,
1066 },
1067 },
1068 },
1069 }
1070
1071 cfg := &Config{
1072 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1073 "custom": {
1074 APIKey: "test-key",
1075 BaseURL: "https://api.custom.com/v1",
1076 Models: []catwalk.Model{
1077 {
1078 ID: "large-model",
1079 DefaultMaxTokens: 1000,
1080 },
1081 },
1082 },
1083 }),
1084 }
1085 cfg.setDefaults("/tmp", "")
1086 env := env.NewFromMap(map[string]string{})
1087 resolver := NewEnvironmentVariableResolver(env)
1088 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
1089 require.NoError(t, err)
1090 large, small, err := cfg.defaultModelSelection(knownProviders)
1091 require.NoError(t, err)
1092 require.Equal(t, "large-model", large.Model)
1093 require.Equal(t, "openai", large.Provider)
1094 require.Equal(t, int64(1000), large.MaxTokens)
1095 require.Equal(t, "small-model", small.Model)
1096 require.Equal(t, "openai", small.Provider)
1097 require.Equal(t, int64(500), small.MaxTokens)
1098 })
1099}
1100
1101func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
1102 t.Run("when enabled, ignores all default providers and requires full specification", func(t *testing.T) {
1103 knownProviders := []catwalk.Provider{
1104 {
1105 ID: "openai",
1106 APIKey: "$OPENAI_API_KEY",
1107 APIEndpoint: "https://api.openai.com/v1",
1108 Models: []catwalk.Model{{
1109 ID: "gpt-4",
1110 }},
1111 },
1112 }
1113
1114 // User references openai but doesn't fully specify it (no base_url, no
1115 // models). This should be rejected because disable_default_providers
1116 // treats all providers as custom.
1117 cfg := &Config{
1118 Options: &Options{
1119 DisableDefaultProviders: true,
1120 },
1121 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1122 "openai": {
1123 APIKey: "$OPENAI_API_KEY",
1124 },
1125 }),
1126 }
1127 cfg.setDefaults("/tmp", "")
1128
1129 env := env.NewFromMap(map[string]string{
1130 "OPENAI_API_KEY": "test-key",
1131 })
1132 resolver := NewEnvironmentVariableResolver(env)
1133 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
1134 require.ErrorContains(t, err, "no custom providers")
1135
1136 // openai should NOT be present because it lacks base_url and models.
1137 require.Equal(t, 0, cfg.Providers.Len())
1138 _, exists := cfg.Providers.Get("openai")
1139 require.False(t, exists, "openai should not be present without full specification")
1140 })
1141
1142 t.Run("when enabled, fully specified providers work", func(t *testing.T) {
1143 knownProviders := []catwalk.Provider{
1144 {
1145 ID: "openai",
1146 APIKey: "$OPENAI_API_KEY",
1147 APIEndpoint: "https://api.openai.com/v1",
1148 Models: []catwalk.Model{{
1149 ID: "gpt-4",
1150 }},
1151 },
1152 }
1153
1154 // User fully specifies their provider.
1155 cfg := &Config{
1156 Options: &Options{
1157 DisableDefaultProviders: true,
1158 },
1159 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1160 "my-llm": {
1161 APIKey: "$MY_API_KEY",
1162 BaseURL: "https://my-llm.example.com/v1",
1163 Models: []catwalk.Model{{
1164 ID: "my-model",
1165 }},
1166 },
1167 }),
1168 }
1169 cfg.setDefaults("/tmp", "")
1170
1171 env := env.NewFromMap(map[string]string{
1172 "MY_API_KEY": "test-key",
1173 "OPENAI_API_KEY": "test-key",
1174 })
1175 resolver := NewEnvironmentVariableResolver(env)
1176 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
1177 require.NoError(t, err)
1178
1179 // Only fully specified provider should be present.
1180 require.Equal(t, 1, cfg.Providers.Len())
1181 provider, exists := cfg.Providers.Get("my-llm")
1182 require.True(t, exists, "my-llm should be present")
1183 require.Equal(t, "https://my-llm.example.com/v1", provider.BaseURL)
1184 require.Len(t, provider.Models, 1)
1185
1186 // Default openai should NOT be present.
1187 _, exists = cfg.Providers.Get("openai")
1188 require.False(t, exists, "openai should not be present")
1189 })
1190
1191 t.Run("when disabled, includes all known providers with valid credentials", func(t *testing.T) {
1192 knownProviders := []catwalk.Provider{
1193 {
1194 ID: "openai",
1195 APIKey: "$OPENAI_API_KEY",
1196 APIEndpoint: "https://api.openai.com/v1",
1197 Models: []catwalk.Model{{
1198 ID: "gpt-4",
1199 }},
1200 },
1201 {
1202 ID: "anthropic",
1203 APIKey: "$ANTHROPIC_API_KEY",
1204 APIEndpoint: "https://api.anthropic.com/v1",
1205 Models: []catwalk.Model{{
1206 ID: "claude-3",
1207 }},
1208 },
1209 }
1210
1211 // User only configures openai, both API keys are available, but option
1212 // is disabled.
1213 cfg := &Config{
1214 Options: &Options{
1215 DisableDefaultProviders: false,
1216 },
1217 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1218 "openai": {
1219 APIKey: "$OPENAI_API_KEY",
1220 },
1221 }),
1222 }
1223 cfg.setDefaults("/tmp", "")
1224
1225 env := env.NewFromMap(map[string]string{
1226 "OPENAI_API_KEY": "test-key",
1227 "ANTHROPIC_API_KEY": "test-key",
1228 })
1229 resolver := NewEnvironmentVariableResolver(env)
1230 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
1231 require.NoError(t, err)
1232
1233 // Both providers should be present.
1234 require.Equal(t, 2, cfg.Providers.Len())
1235 _, exists := cfg.Providers.Get("openai")
1236 require.True(t, exists, "openai should be present")
1237 _, exists = cfg.Providers.Get("anthropic")
1238 require.True(t, exists, "anthropic should be present")
1239 })
1240
1241 t.Run("when enabled, provider missing models is rejected", func(t *testing.T) {
1242 cfg := &Config{
1243 Options: &Options{
1244 DisableDefaultProviders: true,
1245 },
1246 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1247 "my-llm": {
1248 APIKey: "test-key",
1249 BaseURL: "https://my-llm.example.com/v1",
1250 Models: []catwalk.Model{}, // No models.
1251 },
1252 }),
1253 }
1254 cfg.setDefaults("/tmp", "")
1255
1256 env := env.NewFromMap(map[string]string{})
1257 resolver := NewEnvironmentVariableResolver(env)
1258 err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
1259 require.ErrorContains(t, err, "no custom providers")
1260
1261 // Provider should be rejected for missing models.
1262 require.Equal(t, 0, cfg.Providers.Len())
1263 })
1264
1265 t.Run("when enabled, provider missing base_url is rejected", func(t *testing.T) {
1266 cfg := &Config{
1267 Options: &Options{
1268 DisableDefaultProviders: true,
1269 },
1270 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1271 "my-llm": {
1272 APIKey: "test-key",
1273 Models: []catwalk.Model{{ID: "model"}},
1274 // No BaseURL.
1275 },
1276 }),
1277 }
1278 cfg.setDefaults("/tmp", "")
1279
1280 env := env.NewFromMap(map[string]string{})
1281 resolver := NewEnvironmentVariableResolver(env)
1282 err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
1283 require.ErrorContains(t, err, "no custom providers")
1284
1285 // Provider should be rejected for missing base_url.
1286 require.Equal(t, 0, cfg.Providers.Len())
1287 })
1288}
1289
1290func TestConfig_setDefaultsDisableDefaultProvidersEnvVar(t *testing.T) {
1291 t.Run("sets option from environment variable", func(t *testing.T) {
1292 t.Setenv("CRUSH_DISABLE_DEFAULT_PROVIDERS", "true")
1293
1294 cfg := &Config{}
1295 cfg.setDefaults("/tmp", "")
1296
1297 require.True(t, cfg.Options.DisableDefaultProviders)
1298 })
1299
1300 t.Run("does not override when env var is not set", func(t *testing.T) {
1301 cfg := &Config{
1302 Options: &Options{
1303 DisableDefaultProviders: true,
1304 },
1305 }
1306 cfg.setDefaults("/tmp", "")
1307
1308 require.True(t, cfg.Options.DisableDefaultProviders)
1309 })
1310}
1311
1312func TestConfig_configureSelectedModels(t *testing.T) {
1313 t.Run("should override defaults", func(t *testing.T) {
1314 knownProviders := []catwalk.Provider{
1315 {
1316 ID: "openai",
1317 APIKey: "abc",
1318 DefaultLargeModelID: "large-model",
1319 DefaultSmallModelID: "small-model",
1320 Models: []catwalk.Model{
1321 {
1322 ID: "larger-model",
1323 DefaultMaxTokens: 2000,
1324 },
1325 {
1326 ID: "large-model",
1327 DefaultMaxTokens: 1000,
1328 },
1329 {
1330 ID: "small-model",
1331 DefaultMaxTokens: 500,
1332 },
1333 },
1334 },
1335 }
1336
1337 cfg := &Config{
1338 Models: map[SelectedModelType]SelectedModel{
1339 "large": {
1340 Model: "larger-model",
1341 },
1342 },
1343 }
1344 cfg.setDefaults("/tmp", "")
1345 env := env.NewFromMap(map[string]string{})
1346 resolver := NewEnvironmentVariableResolver(env)
1347 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
1348 require.NoError(t, err)
1349
1350 err = configureSelectedModels(testStore(cfg), knownProviders)
1351 require.NoError(t, err)
1352 large := cfg.Models[SelectedModelTypeLarge]
1353 small := cfg.Models[SelectedModelTypeSmall]
1354 require.Equal(t, "larger-model", large.Model)
1355 require.Equal(t, "openai", large.Provider)
1356 require.Equal(t, int64(2000), large.MaxTokens)
1357 require.Equal(t, "small-model", small.Model)
1358 require.Equal(t, "openai", small.Provider)
1359 require.Equal(t, int64(500), small.MaxTokens)
1360 })
1361 t.Run("should be possible to use multiple providers", func(t *testing.T) {
1362 knownProviders := []catwalk.Provider{
1363 {
1364 ID: "openai",
1365 APIKey: "abc",
1366 DefaultLargeModelID: "large-model",
1367 DefaultSmallModelID: "small-model",
1368 Models: []catwalk.Model{
1369 {
1370 ID: "large-model",
1371 DefaultMaxTokens: 1000,
1372 },
1373 {
1374 ID: "small-model",
1375 DefaultMaxTokens: 500,
1376 },
1377 },
1378 },
1379 {
1380 ID: "anthropic",
1381 APIKey: "abc",
1382 DefaultLargeModelID: "a-large-model",
1383 DefaultSmallModelID: "a-small-model",
1384 Models: []catwalk.Model{
1385 {
1386 ID: "a-large-model",
1387 DefaultMaxTokens: 1000,
1388 },
1389 {
1390 ID: "a-small-model",
1391 DefaultMaxTokens: 200,
1392 },
1393 },
1394 },
1395 }
1396
1397 cfg := &Config{
1398 Models: map[SelectedModelType]SelectedModel{
1399 "small": {
1400 Model: "a-small-model",
1401 Provider: "anthropic",
1402 MaxTokens: 300,
1403 },
1404 },
1405 }
1406 cfg.setDefaults("/tmp", "")
1407 env := env.NewFromMap(map[string]string{})
1408 resolver := NewEnvironmentVariableResolver(env)
1409 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
1410 require.NoError(t, err)
1411
1412 err = configureSelectedModels(testStore(cfg), knownProviders)
1413 require.NoError(t, err)
1414 large := cfg.Models[SelectedModelTypeLarge]
1415 small := cfg.Models[SelectedModelTypeSmall]
1416 require.Equal(t, "large-model", large.Model)
1417 require.Equal(t, "openai", large.Provider)
1418 require.Equal(t, int64(1000), large.MaxTokens)
1419 require.Equal(t, "a-small-model", small.Model)
1420 require.Equal(t, "anthropic", small.Provider)
1421 require.Equal(t, int64(300), small.MaxTokens)
1422 })
1423
1424 t.Run("should override the max tokens only", func(t *testing.T) {
1425 knownProviders := []catwalk.Provider{
1426 {
1427 ID: "openai",
1428 APIKey: "abc",
1429 DefaultLargeModelID: "large-model",
1430 DefaultSmallModelID: "small-model",
1431 Models: []catwalk.Model{
1432 {
1433 ID: "large-model",
1434 DefaultMaxTokens: 1000,
1435 },
1436 {
1437 ID: "small-model",
1438 DefaultMaxTokens: 500,
1439 },
1440 },
1441 },
1442 }
1443
1444 cfg := &Config{
1445 Models: map[SelectedModelType]SelectedModel{
1446 "large": {
1447 MaxTokens: 100,
1448 },
1449 },
1450 }
1451 cfg.setDefaults("/tmp", "")
1452 env := env.NewFromMap(map[string]string{})
1453 resolver := NewEnvironmentVariableResolver(env)
1454 err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
1455 require.NoError(t, err)
1456
1457 err = configureSelectedModels(testStore(cfg), knownProviders)
1458 require.NoError(t, err)
1459 large := cfg.Models[SelectedModelTypeLarge]
1460 require.Equal(t, "large-model", large.Model)
1461 require.Equal(t, "openai", large.Provider)
1462 require.Equal(t, int64(100), large.MaxTokens)
1463 })
1464}