1package config
2
3import (
4 "io"
5 "log/slog"
6 "os"
7 "path/filepath"
8 "testing"
9
10 "github.com/charmbracelet/catwalk/pkg/catwalk"
11 "github.com/charmbracelet/crush/internal/csync"
12 "github.com/charmbracelet/crush/internal/env"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15)
16
// TestMain routes the default slog logger to io.Discard for the whole package
// test run so that log output from the code under test does not clutter test
// output, then exits with the status reported by the test runner.
func TestMain(m *testing.M) {
	silent := slog.NewTextHandler(io.Discard, nil)
	slog.SetDefault(slog.New(silent))
	os.Exit(m.Run())
}
23
24func TestConfig_LoadFromBytes(t *testing.T) {
25 data1 := []byte(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
26 data2 := []byte(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
27 data3 := []byte(`{"providers": {"openai": {}}}`)
28
29 loadedConfig, err := loadFromBytes([][]byte{data1, data2, data3})
30
31 require.NoError(t, err)
32 require.NotNil(t, loadedConfig)
33 require.Equal(t, 1, loadedConfig.Providers.Len())
34 pc, _ := loadedConfig.Providers.Get("openai")
35 require.Equal(t, "key2", pc.APIKey)
36 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
37}
38
39func TestConfig_setDefaults(t *testing.T) {
40 cfg := &Config{}
41
42 cfg.setDefaults("/tmp", "")
43
44 require.NotNil(t, cfg.Options)
45 require.NotNil(t, cfg.Options.TUI)
46 require.NotNil(t, cfg.Options.ContextPaths)
47 require.NotNil(t, cfg.Providers)
48 require.NotNil(t, cfg.Models)
49 require.NotNil(t, cfg.LSP)
50 require.NotNil(t, cfg.MCP)
51 require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
52 require.Equal(t, "AGENTS.md", cfg.Options.InitializeAs)
53 for _, path := range defaultContextPaths {
54 require.Contains(t, cfg.Options.ContextPaths, path)
55 }
56 require.Equal(t, "/tmp", cfg.workingDir)
57}
58
59func TestConfig_configureProviders(t *testing.T) {
60 knownProviders := []catwalk.Provider{
61 {
62 ID: "openai",
63 APIKey: "$OPENAI_API_KEY",
64 APIEndpoint: "https://api.openai.com/v1",
65 Models: []catwalk.Model{{
66 ID: "test-model",
67 }},
68 },
69 }
70
71 cfg := &Config{}
72 cfg.setDefaults("/tmp", "")
73 env := env.NewFromMap(map[string]string{
74 "OPENAI_API_KEY": "test-key",
75 })
76 resolver := NewEnvironmentVariableResolver(env)
77 err := cfg.configureProviders(env, resolver, knownProviders)
78 require.NoError(t, err)
79 require.Equal(t, 1, cfg.Providers.Len())
80
81 // We want to make sure that we keep the configured API key as a placeholder
82 pc, _ := cfg.Providers.Get("openai")
83 require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
84}
85
86func TestConfig_configureProvidersWithOverride(t *testing.T) {
87 knownProviders := []catwalk.Provider{
88 {
89 ID: "openai",
90 APIKey: "$OPENAI_API_KEY",
91 APIEndpoint: "https://api.openai.com/v1",
92 Models: []catwalk.Model{{
93 ID: "test-model",
94 }},
95 },
96 }
97
98 cfg := &Config{
99 Providers: csync.NewMap[string, ProviderConfig](),
100 }
101 cfg.Providers.Set("openai", ProviderConfig{
102 APIKey: "xyz",
103 BaseURL: "https://api.openai.com/v2",
104 Models: []catwalk.Model{
105 {
106 ID: "test-model",
107 Name: "Updated",
108 },
109 {
110 ID: "another-model",
111 },
112 },
113 })
114 cfg.setDefaults("/tmp", "")
115
116 env := env.NewFromMap(map[string]string{
117 "OPENAI_API_KEY": "test-key",
118 })
119 resolver := NewEnvironmentVariableResolver(env)
120 err := cfg.configureProviders(env, resolver, knownProviders)
121 require.NoError(t, err)
122 require.Equal(t, 1, cfg.Providers.Len())
123
124 // We want to make sure that we keep the configured API key as a placeholder
125 pc, _ := cfg.Providers.Get("openai")
126 require.Equal(t, "xyz", pc.APIKey)
127 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
128 require.Len(t, pc.Models, 2)
129 require.Equal(t, "Updated", pc.Models[0].Name)
130}
131
132func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
133 knownProviders := []catwalk.Provider{
134 {
135 ID: "openai",
136 APIKey: "$OPENAI_API_KEY",
137 APIEndpoint: "https://api.openai.com/v1",
138 Models: []catwalk.Model{{
139 ID: "test-model",
140 }},
141 },
142 }
143
144 cfg := &Config{
145 Providers: csync.NewMapFrom(map[string]ProviderConfig{
146 "custom": {
147 APIKey: "xyz",
148 BaseURL: "https://api.someendpoint.com/v2",
149 Models: []catwalk.Model{
150 {
151 ID: "test-model",
152 },
153 },
154 },
155 }),
156 }
157 cfg.setDefaults("/tmp", "")
158 env := env.NewFromMap(map[string]string{
159 "OPENAI_API_KEY": "test-key",
160 })
161 resolver := NewEnvironmentVariableResolver(env)
162 err := cfg.configureProviders(env, resolver, knownProviders)
163 require.NoError(t, err)
164 // Should be to because of the env variable
165 require.Equal(t, cfg.Providers.Len(), 2)
166
167 // We want to make sure that we keep the configured API key as a placeholder
168 pc, _ := cfg.Providers.Get("custom")
169 require.Equal(t, "xyz", pc.APIKey)
170 // Make sure we set the ID correctly
171 require.Equal(t, "custom", pc.ID)
172 require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
173 require.Len(t, pc.Models, 1)
174
175 _, ok := cfg.Providers.Get("openai")
176 require.True(t, ok, "OpenAI provider should still be present")
177}
178
179func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
180 knownProviders := []catwalk.Provider{
181 {
182 ID: catwalk.InferenceProviderBedrock,
183 APIKey: "",
184 APIEndpoint: "",
185 Models: []catwalk.Model{{
186 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
187 }},
188 },
189 }
190
191 cfg := &Config{}
192 cfg.setDefaults("/tmp", "")
193 env := env.NewFromMap(map[string]string{
194 "AWS_ACCESS_KEY_ID": "test-key-id",
195 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
196 })
197 resolver := NewEnvironmentVariableResolver(env)
198 err := cfg.configureProviders(env, resolver, knownProviders)
199 require.NoError(t, err)
200 require.Equal(t, cfg.Providers.Len(), 1)
201
202 bedrockProvider, ok := cfg.Providers.Get("bedrock")
203 require.True(t, ok, "Bedrock provider should be present")
204 require.Len(t, bedrockProvider.Models, 1)
205 require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
206}
207
208func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
209 knownProviders := []catwalk.Provider{
210 {
211 ID: catwalk.InferenceProviderBedrock,
212 APIKey: "",
213 APIEndpoint: "",
214 Models: []catwalk.Model{{
215 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
216 }},
217 },
218 }
219
220 cfg := &Config{}
221 cfg.setDefaults("/tmp", "")
222 env := env.NewFromMap(map[string]string{})
223 resolver := NewEnvironmentVariableResolver(env)
224 err := cfg.configureProviders(env, resolver, knownProviders)
225 require.NoError(t, err)
226 // Provider should not be configured without credentials
227 require.Equal(t, cfg.Providers.Len(), 0)
228}
229
230func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
231 knownProviders := []catwalk.Provider{
232 {
233 ID: catwalk.InferenceProviderBedrock,
234 APIKey: "",
235 APIEndpoint: "",
236 Models: []catwalk.Model{{
237 ID: "some-random-model",
238 }},
239 },
240 }
241
242 cfg := &Config{}
243 cfg.setDefaults("/tmp", "")
244 env := env.NewFromMap(map[string]string{
245 "AWS_ACCESS_KEY_ID": "test-key-id",
246 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
247 })
248 resolver := NewEnvironmentVariableResolver(env)
249 err := cfg.configureProviders(env, resolver, knownProviders)
250 require.Error(t, err)
251}
252
253func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
254 knownProviders := []catwalk.Provider{
255 {
256 ID: catwalk.InferenceProviderVertexAI,
257 APIKey: "",
258 APIEndpoint: "",
259 Models: []catwalk.Model{{
260 ID: "gemini-pro",
261 }},
262 },
263 }
264
265 cfg := &Config{}
266 cfg.setDefaults("/tmp", "")
267 env := env.NewFromMap(map[string]string{
268 "VERTEXAI_PROJECT": "test-project",
269 "VERTEXAI_LOCATION": "us-central1",
270 })
271 resolver := NewEnvironmentVariableResolver(env)
272 err := cfg.configureProviders(env, resolver, knownProviders)
273 require.NoError(t, err)
274 require.Equal(t, cfg.Providers.Len(), 1)
275
276 vertexProvider, ok := cfg.Providers.Get("vertexai")
277 require.True(t, ok, "VertexAI provider should be present")
278 require.Len(t, vertexProvider.Models, 1)
279 require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
280 require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
281 require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
282}
283
284func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
285 knownProviders := []catwalk.Provider{
286 {
287 ID: catwalk.InferenceProviderVertexAI,
288 APIKey: "",
289 APIEndpoint: "",
290 Models: []catwalk.Model{{
291 ID: "gemini-pro",
292 }},
293 },
294 }
295
296 cfg := &Config{}
297 cfg.setDefaults("/tmp", "")
298 env := env.NewFromMap(map[string]string{
299 "GOOGLE_GENAI_USE_VERTEXAI": "false",
300 "GOOGLE_CLOUD_PROJECT": "test-project",
301 "GOOGLE_CLOUD_LOCATION": "us-central1",
302 })
303 resolver := NewEnvironmentVariableResolver(env)
304 err := cfg.configureProviders(env, resolver, knownProviders)
305 require.NoError(t, err)
306 // Provider should not be configured without proper credentials
307 require.Equal(t, cfg.Providers.Len(), 0)
308}
309
310func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
311 knownProviders := []catwalk.Provider{
312 {
313 ID: catwalk.InferenceProviderVertexAI,
314 APIKey: "",
315 APIEndpoint: "",
316 Models: []catwalk.Model{{
317 ID: "gemini-pro",
318 }},
319 },
320 }
321
322 cfg := &Config{}
323 cfg.setDefaults("/tmp", "")
324 env := env.NewFromMap(map[string]string{
325 "GOOGLE_GENAI_USE_VERTEXAI": "true",
326 "GOOGLE_CLOUD_LOCATION": "us-central1",
327 })
328 resolver := NewEnvironmentVariableResolver(env)
329 err := cfg.configureProviders(env, resolver, knownProviders)
330 require.NoError(t, err)
331 // Provider should not be configured without project
332 require.Equal(t, cfg.Providers.Len(), 0)
333}
334
335func TestConfig_configureProvidersSetProviderID(t *testing.T) {
336 knownProviders := []catwalk.Provider{
337 {
338 ID: "openai",
339 APIKey: "$OPENAI_API_KEY",
340 APIEndpoint: "https://api.openai.com/v1",
341 Models: []catwalk.Model{{
342 ID: "test-model",
343 }},
344 },
345 }
346
347 cfg := &Config{}
348 cfg.setDefaults("/tmp", "")
349 env := env.NewFromMap(map[string]string{
350 "OPENAI_API_KEY": "test-key",
351 })
352 resolver := NewEnvironmentVariableResolver(env)
353 err := cfg.configureProviders(env, resolver, knownProviders)
354 require.NoError(t, err)
355 require.Equal(t, cfg.Providers.Len(), 1)
356
357 // Provider ID should be set
358 pc, _ := cfg.Providers.Get("openai")
359 require.Equal(t, "openai", pc.ID)
360}
361
362func TestConfig_EnabledProviders(t *testing.T) {
363 t.Run("all providers enabled", func(t *testing.T) {
364 cfg := &Config{
365 Providers: csync.NewMapFrom(map[string]ProviderConfig{
366 "openai": {
367 ID: "openai",
368 APIKey: "key1",
369 Disable: false,
370 },
371 "anthropic": {
372 ID: "anthropic",
373 APIKey: "key2",
374 Disable: false,
375 },
376 }),
377 }
378
379 enabled := cfg.EnabledProviders()
380 require.Len(t, enabled, 2)
381 })
382
383 t.Run("some providers disabled", func(t *testing.T) {
384 cfg := &Config{
385 Providers: csync.NewMapFrom(map[string]ProviderConfig{
386 "openai": {
387 ID: "openai",
388 APIKey: "key1",
389 Disable: false,
390 },
391 "anthropic": {
392 ID: "anthropic",
393 APIKey: "key2",
394 Disable: true,
395 },
396 }),
397 }
398
399 enabled := cfg.EnabledProviders()
400 require.Len(t, enabled, 1)
401 require.Equal(t, "openai", enabled[0].ID)
402 })
403
404 t.Run("empty providers map", func(t *testing.T) {
405 cfg := &Config{
406 Providers: csync.NewMap[string, ProviderConfig](),
407 }
408
409 enabled := cfg.EnabledProviders()
410 require.Len(t, enabled, 0)
411 })
412}
413
414func TestConfig_IsConfigured(t *testing.T) {
415 t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
416 cfg := &Config{
417 Providers: csync.NewMapFrom(map[string]ProviderConfig{
418 "openai": {
419 ID: "openai",
420 APIKey: "key1",
421 Disable: false,
422 },
423 }),
424 }
425
426 require.True(t, cfg.IsConfigured())
427 })
428
429 t.Run("returns false when no providers are configured", func(t *testing.T) {
430 cfg := &Config{
431 Providers: csync.NewMap[string, ProviderConfig](),
432 }
433
434 require.False(t, cfg.IsConfigured())
435 })
436
437 t.Run("returns false when all providers are disabled", func(t *testing.T) {
438 cfg := &Config{
439 Providers: csync.NewMapFrom(map[string]ProviderConfig{
440 "openai": {
441 ID: "openai",
442 APIKey: "key1",
443 Disable: true,
444 },
445 "anthropic": {
446 ID: "anthropic",
447 APIKey: "key2",
448 Disable: true,
449 },
450 }),
451 }
452
453 require.False(t, cfg.IsConfigured())
454 })
455}
456
457func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
458 cfg := &Config{
459 Options: &Options{
460 DisabledTools: []string{},
461 },
462 }
463
464 cfg.SetupAgents()
465 coderAgent, ok := cfg.Agents[AgentCoder]
466 require.True(t, ok)
467 assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
468
469 taskAgent, ok := cfg.Agents[AgentTask]
470 require.True(t, ok)
471 assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
472}
473
474func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
475 cfg := &Config{
476 Options: &Options{
477 DisabledTools: []string{
478 "edit",
479 "download",
480 "grep",
481 },
482 },
483 }
484
485 cfg.SetupAgents()
486 coderAgent, ok := cfg.Agents[AgentCoder]
487 require.True(t, ok)
488
489 assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "todos", "view", "write"}, coderAgent.AllowedTools)
490
491 taskAgent, ok := cfg.Agents[AgentTask]
492 require.True(t, ok)
493 assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
494}
495
496func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
497 cfg := &Config{
498 Options: &Options{
499 DisabledTools: []string{
500 "glob",
501 "grep",
502 "ls",
503 "sourcegraph",
504 "view",
505 },
506 },
507 }
508
509 cfg.SetupAgents()
510 coderAgent, ok := cfg.Agents[AgentCoder]
511 require.True(t, ok)
512 assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "todos", "write"}, coderAgent.AllowedTools)
513
514 taskAgent, ok := cfg.Agents[AgentTask]
515 require.True(t, ok)
516 assert.Equal(t, []string{}, taskAgent.AllowedTools)
517}
518
519func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
520 knownProviders := []catwalk.Provider{
521 {
522 ID: "openai",
523 APIKey: "$OPENAI_API_KEY",
524 APIEndpoint: "https://api.openai.com/v1",
525 Models: []catwalk.Model{{
526 ID: "test-model",
527 }},
528 },
529 }
530
531 cfg := &Config{
532 Providers: csync.NewMapFrom(map[string]ProviderConfig{
533 "openai": {
534 Disable: true,
535 },
536 }),
537 }
538 cfg.setDefaults("/tmp", "")
539
540 env := env.NewFromMap(map[string]string{
541 "OPENAI_API_KEY": "test-key",
542 })
543 resolver := NewEnvironmentVariableResolver(env)
544 err := cfg.configureProviders(env, resolver, knownProviders)
545 require.NoError(t, err)
546
547 require.Equal(t, cfg.Providers.Len(), 1)
548 prov, exists := cfg.Providers.Get("openai")
549 require.True(t, exists)
550 require.True(t, prov.Disable)
551}
552
553func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
554 t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
555 cfg := &Config{
556 Providers: csync.NewMapFrom(map[string]ProviderConfig{
557 "custom": {
558 BaseURL: "https://api.custom.com/v1",
559 Models: []catwalk.Model{{
560 ID: "test-model",
561 }},
562 },
563 "openai": {
564 APIKey: "$MISSING",
565 },
566 }),
567 }
568 cfg.setDefaults("/tmp", "")
569
570 env := env.NewFromMap(map[string]string{})
571 resolver := NewEnvironmentVariableResolver(env)
572 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
573 require.NoError(t, err)
574
575 require.Equal(t, cfg.Providers.Len(), 1)
576 _, exists := cfg.Providers.Get("custom")
577 require.True(t, exists)
578 })
579
580 t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
581 cfg := &Config{
582 Providers: csync.NewMapFrom(map[string]ProviderConfig{
583 "custom": {
584 APIKey: "test-key",
585 Models: []catwalk.Model{{
586 ID: "test-model",
587 }},
588 },
589 }),
590 }
591 cfg.setDefaults("/tmp", "")
592
593 env := env.NewFromMap(map[string]string{})
594 resolver := NewEnvironmentVariableResolver(env)
595 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
596 require.NoError(t, err)
597
598 require.Equal(t, cfg.Providers.Len(), 0)
599 _, exists := cfg.Providers.Get("custom")
600 require.False(t, exists)
601 })
602
603 t.Run("custom provider with no models is removed", func(t *testing.T) {
604 cfg := &Config{
605 Providers: csync.NewMapFrom(map[string]ProviderConfig{
606 "custom": {
607 APIKey: "test-key",
608 BaseURL: "https://api.custom.com/v1",
609 Models: []catwalk.Model{},
610 },
611 }),
612 }
613 cfg.setDefaults("/tmp", "")
614
615 env := env.NewFromMap(map[string]string{})
616 resolver := NewEnvironmentVariableResolver(env)
617 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
618 require.NoError(t, err)
619
620 require.Equal(t, cfg.Providers.Len(), 0)
621 _, exists := cfg.Providers.Get("custom")
622 require.False(t, exists)
623 })
624
625 t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
626 cfg := &Config{
627 Providers: csync.NewMapFrom(map[string]ProviderConfig{
628 "custom": {
629 APIKey: "test-key",
630 BaseURL: "https://api.custom.com/v1",
631 Type: "unsupported",
632 Models: []catwalk.Model{{
633 ID: "test-model",
634 }},
635 },
636 }),
637 }
638 cfg.setDefaults("/tmp", "")
639
640 env := env.NewFromMap(map[string]string{})
641 resolver := NewEnvironmentVariableResolver(env)
642 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
643 require.NoError(t, err)
644
645 require.Equal(t, cfg.Providers.Len(), 0)
646 _, exists := cfg.Providers.Get("custom")
647 require.False(t, exists)
648 })
649
650 t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
651 cfg := &Config{
652 Providers: csync.NewMapFrom(map[string]ProviderConfig{
653 "custom": {
654 APIKey: "test-key",
655 BaseURL: "https://api.custom.com/v1",
656 Type: catwalk.TypeOpenAI,
657 Models: []catwalk.Model{{
658 ID: "test-model",
659 }},
660 },
661 }),
662 }
663 cfg.setDefaults("/tmp", "")
664
665 env := env.NewFromMap(map[string]string{})
666 resolver := NewEnvironmentVariableResolver(env)
667 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
668 require.NoError(t, err)
669
670 require.Equal(t, cfg.Providers.Len(), 1)
671 customProvider, exists := cfg.Providers.Get("custom")
672 require.True(t, exists)
673 require.Equal(t, "custom", customProvider.ID)
674 require.Equal(t, "test-key", customProvider.APIKey)
675 require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
676 })
677
678 t.Run("custom anthropic provider is supported", func(t *testing.T) {
679 cfg := &Config{
680 Providers: csync.NewMapFrom(map[string]ProviderConfig{
681 "custom-anthropic": {
682 APIKey: "test-key",
683 BaseURL: "https://api.anthropic.com/v1",
684 Type: catwalk.TypeAnthropic,
685 Models: []catwalk.Model{{
686 ID: "claude-3-sonnet",
687 }},
688 },
689 }),
690 }
691 cfg.setDefaults("/tmp", "")
692
693 env := env.NewFromMap(map[string]string{})
694 resolver := NewEnvironmentVariableResolver(env)
695 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
696 require.NoError(t, err)
697
698 require.Equal(t, cfg.Providers.Len(), 1)
699 customProvider, exists := cfg.Providers.Get("custom-anthropic")
700 require.True(t, exists)
701 require.Equal(t, "custom-anthropic", customProvider.ID)
702 require.Equal(t, "test-key", customProvider.APIKey)
703 require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
704 require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
705 })
706
707 t.Run("disabled custom provider is removed", func(t *testing.T) {
708 cfg := &Config{
709 Providers: csync.NewMapFrom(map[string]ProviderConfig{
710 "custom": {
711 APIKey: "test-key",
712 BaseURL: "https://api.custom.com/v1",
713 Type: catwalk.TypeOpenAI,
714 Disable: true,
715 Models: []catwalk.Model{{
716 ID: "test-model",
717 }},
718 },
719 }),
720 }
721 cfg.setDefaults("/tmp", "")
722
723 env := env.NewFromMap(map[string]string{})
724 resolver := NewEnvironmentVariableResolver(env)
725 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
726 require.NoError(t, err)
727
728 require.Equal(t, cfg.Providers.Len(), 0)
729 _, exists := cfg.Providers.Get("custom")
730 require.False(t, exists)
731 })
732}
733
734func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
735 t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
736 knownProviders := []catwalk.Provider{
737 {
738 ID: catwalk.InferenceProviderVertexAI,
739 APIKey: "",
740 APIEndpoint: "",
741 Models: []catwalk.Model{{
742 ID: "gemini-pro",
743 }},
744 },
745 }
746
747 cfg := &Config{
748 Providers: csync.NewMapFrom(map[string]ProviderConfig{
749 "vertexai": {
750 BaseURL: "custom-url",
751 },
752 }),
753 }
754 cfg.setDefaults("/tmp", "")
755
756 env := env.NewFromMap(map[string]string{
757 "GOOGLE_GENAI_USE_VERTEXAI": "false",
758 })
759 resolver := NewEnvironmentVariableResolver(env)
760 err := cfg.configureProviders(env, resolver, knownProviders)
761 require.NoError(t, err)
762
763 require.Equal(t, cfg.Providers.Len(), 0)
764 _, exists := cfg.Providers.Get("vertexai")
765 require.False(t, exists)
766 })
767
768 t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
769 knownProviders := []catwalk.Provider{
770 {
771 ID: catwalk.InferenceProviderBedrock,
772 APIKey: "",
773 APIEndpoint: "",
774 Models: []catwalk.Model{{
775 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
776 }},
777 },
778 }
779
780 cfg := &Config{
781 Providers: csync.NewMapFrom(map[string]ProviderConfig{
782 "bedrock": {
783 BaseURL: "custom-url",
784 },
785 }),
786 }
787 cfg.setDefaults("/tmp", "")
788
789 env := env.NewFromMap(map[string]string{})
790 resolver := NewEnvironmentVariableResolver(env)
791 err := cfg.configureProviders(env, resolver, knownProviders)
792 require.NoError(t, err)
793
794 require.Equal(t, cfg.Providers.Len(), 0)
795 _, exists := cfg.Providers.Get("bedrock")
796 require.False(t, exists)
797 })
798
799 t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
800 knownProviders := []catwalk.Provider{
801 {
802 ID: "openai",
803 APIKey: "$MISSING_API_KEY",
804 APIEndpoint: "https://api.openai.com/v1",
805 Models: []catwalk.Model{{
806 ID: "test-model",
807 }},
808 },
809 }
810
811 cfg := &Config{
812 Providers: csync.NewMapFrom(map[string]ProviderConfig{
813 "openai": {
814 BaseURL: "custom-url",
815 },
816 }),
817 }
818 cfg.setDefaults("/tmp", "")
819
820 env := env.NewFromMap(map[string]string{})
821 resolver := NewEnvironmentVariableResolver(env)
822 err := cfg.configureProviders(env, resolver, knownProviders)
823 require.NoError(t, err)
824
825 require.Equal(t, cfg.Providers.Len(), 0)
826 _, exists := cfg.Providers.Get("openai")
827 require.False(t, exists)
828 })
829
830 t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
831 knownProviders := []catwalk.Provider{
832 {
833 ID: "openai",
834 APIKey: "$OPENAI_API_KEY",
835 APIEndpoint: "$MISSING_ENDPOINT",
836 Models: []catwalk.Model{{
837 ID: "test-model",
838 }},
839 },
840 }
841
842 cfg := &Config{
843 Providers: csync.NewMapFrom(map[string]ProviderConfig{
844 "openai": {
845 APIKey: "test-key",
846 },
847 }),
848 }
849 cfg.setDefaults("/tmp", "")
850
851 env := env.NewFromMap(map[string]string{
852 "OPENAI_API_KEY": "test-key",
853 })
854 resolver := NewEnvironmentVariableResolver(env)
855 err := cfg.configureProviders(env, resolver, knownProviders)
856 require.NoError(t, err)
857
858 require.Equal(t, cfg.Providers.Len(), 1)
859 _, exists := cfg.Providers.Get("openai")
860 require.True(t, exists)
861 })
862}
863
864func TestConfig_defaultModelSelection(t *testing.T) {
865 t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
866 knownProviders := []catwalk.Provider{
867 {
868 ID: "openai",
869 APIKey: "abc",
870 DefaultLargeModelID: "large-model",
871 DefaultSmallModelID: "small-model",
872 Models: []catwalk.Model{
873 {
874 ID: "large-model",
875 DefaultMaxTokens: 1000,
876 },
877 {
878 ID: "small-model",
879 DefaultMaxTokens: 500,
880 },
881 },
882 },
883 }
884
885 cfg := &Config{}
886 cfg.setDefaults("/tmp", "")
887 env := env.NewFromMap(map[string]string{})
888 resolver := NewEnvironmentVariableResolver(env)
889 err := cfg.configureProviders(env, resolver, knownProviders)
890 require.NoError(t, err)
891
892 large, small, err := cfg.defaultModelSelection(knownProviders)
893 require.NoError(t, err)
894 require.Equal(t, "large-model", large.Model)
895 require.Equal(t, "openai", large.Provider)
896 require.Equal(t, int64(1000), large.MaxTokens)
897 require.Equal(t, "small-model", small.Model)
898 require.Equal(t, "openai", small.Provider)
899 require.Equal(t, int64(500), small.MaxTokens)
900 })
901 t.Run("should error if no providers configured", func(t *testing.T) {
902 knownProviders := []catwalk.Provider{
903 {
904 ID: "openai",
905 APIKey: "$MISSING_KEY",
906 DefaultLargeModelID: "large-model",
907 DefaultSmallModelID: "small-model",
908 Models: []catwalk.Model{
909 {
910 ID: "large-model",
911 DefaultMaxTokens: 1000,
912 },
913 {
914 ID: "small-model",
915 DefaultMaxTokens: 500,
916 },
917 },
918 },
919 }
920
921 cfg := &Config{}
922 cfg.setDefaults("/tmp", "")
923 env := env.NewFromMap(map[string]string{})
924 resolver := NewEnvironmentVariableResolver(env)
925 err := cfg.configureProviders(env, resolver, knownProviders)
926 require.NoError(t, err)
927
928 _, _, err = cfg.defaultModelSelection(knownProviders)
929 require.Error(t, err)
930 })
931 t.Run("should error if model is missing", func(t *testing.T) {
932 knownProviders := []catwalk.Provider{
933 {
934 ID: "openai",
935 APIKey: "abc",
936 DefaultLargeModelID: "large-model",
937 DefaultSmallModelID: "small-model",
938 Models: []catwalk.Model{
939 {
940 ID: "not-large-model",
941 DefaultMaxTokens: 1000,
942 },
943 {
944 ID: "small-model",
945 DefaultMaxTokens: 500,
946 },
947 },
948 },
949 }
950
951 cfg := &Config{}
952 cfg.setDefaults("/tmp", "")
953 env := env.NewFromMap(map[string]string{})
954 resolver := NewEnvironmentVariableResolver(env)
955 err := cfg.configureProviders(env, resolver, knownProviders)
956 require.NoError(t, err)
957 _, _, err = cfg.defaultModelSelection(knownProviders)
958 require.Error(t, err)
959 })
960
961 t.Run("should configure the default models with a custom provider", func(t *testing.T) {
962 knownProviders := []catwalk.Provider{
963 {
964 ID: "openai",
965 APIKey: "$MISSING", // will not be included in the config
966 DefaultLargeModelID: "large-model",
967 DefaultSmallModelID: "small-model",
968 Models: []catwalk.Model{
969 {
970 ID: "not-large-model",
971 DefaultMaxTokens: 1000,
972 },
973 {
974 ID: "small-model",
975 DefaultMaxTokens: 500,
976 },
977 },
978 },
979 }
980
981 cfg := &Config{
982 Providers: csync.NewMapFrom(map[string]ProviderConfig{
983 "custom": {
984 APIKey: "test-key",
985 BaseURL: "https://api.custom.com/v1",
986 Models: []catwalk.Model{
987 {
988 ID: "model",
989 DefaultMaxTokens: 600,
990 },
991 },
992 },
993 }),
994 }
995 cfg.setDefaults("/tmp", "")
996 env := env.NewFromMap(map[string]string{})
997 resolver := NewEnvironmentVariableResolver(env)
998 err := cfg.configureProviders(env, resolver, knownProviders)
999 require.NoError(t, err)
1000 large, small, err := cfg.defaultModelSelection(knownProviders)
1001 require.NoError(t, err)
1002 require.Equal(t, "model", large.Model)
1003 require.Equal(t, "custom", large.Provider)
1004 require.Equal(t, int64(600), large.MaxTokens)
1005 require.Equal(t, "model", small.Model)
1006 require.Equal(t, "custom", small.Provider)
1007 require.Equal(t, int64(600), small.MaxTokens)
1008 })
1009
1010 t.Run("should fail if no model configured", func(t *testing.T) {
1011 knownProviders := []catwalk.Provider{
1012 {
1013 ID: "openai",
1014 APIKey: "$MISSING", // will not be included in the config
1015 DefaultLargeModelID: "large-model",
1016 DefaultSmallModelID: "small-model",
1017 Models: []catwalk.Model{
1018 {
1019 ID: "not-large-model",
1020 DefaultMaxTokens: 1000,
1021 },
1022 {
1023 ID: "small-model",
1024 DefaultMaxTokens: 500,
1025 },
1026 },
1027 },
1028 }
1029
1030 cfg := &Config{
1031 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1032 "custom": {
1033 APIKey: "test-key",
1034 BaseURL: "https://api.custom.com/v1",
1035 Models: []catwalk.Model{},
1036 },
1037 }),
1038 }
1039 cfg.setDefaults("/tmp", "")
1040 env := env.NewFromMap(map[string]string{})
1041 resolver := NewEnvironmentVariableResolver(env)
1042 err := cfg.configureProviders(env, resolver, knownProviders)
1043 require.NoError(t, err)
1044 _, _, err = cfg.defaultModelSelection(knownProviders)
1045 require.Error(t, err)
1046 })
1047 t.Run("should use the default provider first", func(t *testing.T) {
1048 knownProviders := []catwalk.Provider{
1049 {
1050 ID: "openai",
1051 APIKey: "set",
1052 DefaultLargeModelID: "large-model",
1053 DefaultSmallModelID: "small-model",
1054 Models: []catwalk.Model{
1055 {
1056 ID: "large-model",
1057 DefaultMaxTokens: 1000,
1058 },
1059 {
1060 ID: "small-model",
1061 DefaultMaxTokens: 500,
1062 },
1063 },
1064 },
1065 }
1066
1067 cfg := &Config{
1068 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1069 "custom": {
1070 APIKey: "test-key",
1071 BaseURL: "https://api.custom.com/v1",
1072 Models: []catwalk.Model{
1073 {
1074 ID: "large-model",
1075 DefaultMaxTokens: 1000,
1076 },
1077 },
1078 },
1079 }),
1080 }
1081 cfg.setDefaults("/tmp", "")
1082 env := env.NewFromMap(map[string]string{})
1083 resolver := NewEnvironmentVariableResolver(env)
1084 err := cfg.configureProviders(env, resolver, knownProviders)
1085 require.NoError(t, err)
1086 large, small, err := cfg.defaultModelSelection(knownProviders)
1087 require.NoError(t, err)
1088 require.Equal(t, "large-model", large.Model)
1089 require.Equal(t, "openai", large.Provider)
1090 require.Equal(t, int64(1000), large.MaxTokens)
1091 require.Equal(t, "small-model", small.Model)
1092 require.Equal(t, "openai", small.Provider)
1093 require.Equal(t, int64(500), small.MaxTokens)
1094 })
1095}
1096
1097func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
1098 t.Run("when enabled, ignores all default providers and requires full specification", func(t *testing.T) {
1099 knownProviders := []catwalk.Provider{
1100 {
1101 ID: "openai",
1102 APIKey: "$OPENAI_API_KEY",
1103 APIEndpoint: "https://api.openai.com/v1",
1104 Models: []catwalk.Model{{
1105 ID: "gpt-4",
1106 }},
1107 },
1108 }
1109
1110 // User references openai but doesn't fully specify it (no base_url, no
1111 // models). This should be rejected because disable_default_providers
1112 // treats all providers as custom.
1113 cfg := &Config{
1114 Options: &Options{
1115 DisableDefaultProviders: true,
1116 },
1117 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1118 "openai": {
1119 APIKey: "$OPENAI_API_KEY",
1120 },
1121 }),
1122 }
1123 cfg.setDefaults("/tmp", "")
1124
1125 env := env.NewFromMap(map[string]string{
1126 "OPENAI_API_KEY": "test-key",
1127 })
1128 resolver := NewEnvironmentVariableResolver(env)
1129 err := cfg.configureProviders(env, resolver, knownProviders)
1130 require.NoError(t, err)
1131
1132 // openai should NOT be present because it lacks base_url and models.
1133 require.Equal(t, 0, cfg.Providers.Len())
1134 _, exists := cfg.Providers.Get("openai")
1135 require.False(t, exists, "openai should not be present without full specification")
1136 })
1137
1138 t.Run("when enabled, fully specified providers work", func(t *testing.T) {
1139 knownProviders := []catwalk.Provider{
1140 {
1141 ID: "openai",
1142 APIKey: "$OPENAI_API_KEY",
1143 APIEndpoint: "https://api.openai.com/v1",
1144 Models: []catwalk.Model{{
1145 ID: "gpt-4",
1146 }},
1147 },
1148 }
1149
1150 // User fully specifies their provider.
1151 cfg := &Config{
1152 Options: &Options{
1153 DisableDefaultProviders: true,
1154 },
1155 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1156 "my-llm": {
1157 APIKey: "$MY_API_KEY",
1158 BaseURL: "https://my-llm.example.com/v1",
1159 Models: []catwalk.Model{{
1160 ID: "my-model",
1161 }},
1162 },
1163 }),
1164 }
1165 cfg.setDefaults("/tmp", "")
1166
1167 env := env.NewFromMap(map[string]string{
1168 "MY_API_KEY": "test-key",
1169 "OPENAI_API_KEY": "test-key",
1170 })
1171 resolver := NewEnvironmentVariableResolver(env)
1172 err := cfg.configureProviders(env, resolver, knownProviders)
1173 require.NoError(t, err)
1174
1175 // Only fully specified provider should be present.
1176 require.Equal(t, 1, cfg.Providers.Len())
1177 provider, exists := cfg.Providers.Get("my-llm")
1178 require.True(t, exists, "my-llm should be present")
1179 require.Equal(t, "https://my-llm.example.com/v1", provider.BaseURL)
1180 require.Len(t, provider.Models, 1)
1181
1182 // Default openai should NOT be present.
1183 _, exists = cfg.Providers.Get("openai")
1184 require.False(t, exists, "openai should not be present")
1185 })
1186
1187 t.Run("when disabled, includes all known providers with valid credentials", func(t *testing.T) {
1188 knownProviders := []catwalk.Provider{
1189 {
1190 ID: "openai",
1191 APIKey: "$OPENAI_API_KEY",
1192 APIEndpoint: "https://api.openai.com/v1",
1193 Models: []catwalk.Model{{
1194 ID: "gpt-4",
1195 }},
1196 },
1197 {
1198 ID: "anthropic",
1199 APIKey: "$ANTHROPIC_API_KEY",
1200 APIEndpoint: "https://api.anthropic.com/v1",
1201 Models: []catwalk.Model{{
1202 ID: "claude-3",
1203 }},
1204 },
1205 }
1206
1207 // User only configures openai, both API keys are available, but option
1208 // is disabled.
1209 cfg := &Config{
1210 Options: &Options{
1211 DisableDefaultProviders: false,
1212 },
1213 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1214 "openai": {
1215 APIKey: "$OPENAI_API_KEY",
1216 },
1217 }),
1218 }
1219 cfg.setDefaults("/tmp", "")
1220
1221 env := env.NewFromMap(map[string]string{
1222 "OPENAI_API_KEY": "test-key",
1223 "ANTHROPIC_API_KEY": "test-key",
1224 })
1225 resolver := NewEnvironmentVariableResolver(env)
1226 err := cfg.configureProviders(env, resolver, knownProviders)
1227 require.NoError(t, err)
1228
1229 // Both providers should be present.
1230 require.Equal(t, 2, cfg.Providers.Len())
1231 _, exists := cfg.Providers.Get("openai")
1232 require.True(t, exists, "openai should be present")
1233 _, exists = cfg.Providers.Get("anthropic")
1234 require.True(t, exists, "anthropic should be present")
1235 })
1236
1237 t.Run("when enabled, provider missing models is rejected", func(t *testing.T) {
1238 cfg := &Config{
1239 Options: &Options{
1240 DisableDefaultProviders: true,
1241 },
1242 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1243 "my-llm": {
1244 APIKey: "test-key",
1245 BaseURL: "https://my-llm.example.com/v1",
1246 Models: []catwalk.Model{}, // No models.
1247 },
1248 }),
1249 }
1250 cfg.setDefaults("/tmp", "")
1251
1252 env := env.NewFromMap(map[string]string{})
1253 resolver := NewEnvironmentVariableResolver(env)
1254 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
1255 require.NoError(t, err)
1256
1257 // Provider should be rejected for missing models.
1258 require.Equal(t, 0, cfg.Providers.Len())
1259 })
1260
1261 t.Run("when enabled, provider missing base_url is rejected", func(t *testing.T) {
1262 cfg := &Config{
1263 Options: &Options{
1264 DisableDefaultProviders: true,
1265 },
1266 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1267 "my-llm": {
1268 APIKey: "test-key",
1269 Models: []catwalk.Model{{ID: "model"}},
1270 // No BaseURL.
1271 },
1272 }),
1273 }
1274 cfg.setDefaults("/tmp", "")
1275
1276 env := env.NewFromMap(map[string]string{})
1277 resolver := NewEnvironmentVariableResolver(env)
1278 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
1279 require.NoError(t, err)
1280
1281 // Provider should be rejected for missing base_url.
1282 require.Equal(t, 0, cfg.Providers.Len())
1283 })
1284}
1285
1286func TestConfig_setDefaultsDisableDefaultProvidersEnvVar(t *testing.T) {
1287 t.Run("sets option from environment variable", func(t *testing.T) {
1288 t.Setenv("CRUSH_DISABLE_DEFAULT_PROVIDERS", "true")
1289
1290 cfg := &Config{}
1291 cfg.setDefaults("/tmp", "")
1292
1293 require.True(t, cfg.Options.DisableDefaultProviders)
1294 })
1295
1296 t.Run("does not override when env var is not set", func(t *testing.T) {
1297 cfg := &Config{
1298 Options: &Options{
1299 DisableDefaultProviders: true,
1300 },
1301 }
1302 cfg.setDefaults("/tmp", "")
1303
1304 require.True(t, cfg.Options.DisableDefaultProviders)
1305 })
1306}
1307
1308func TestConfig_configureSelectedModels(t *testing.T) {
1309 t.Run("should override defaults", func(t *testing.T) {
1310 knownProviders := []catwalk.Provider{
1311 {
1312 ID: "openai",
1313 APIKey: "abc",
1314 DefaultLargeModelID: "large-model",
1315 DefaultSmallModelID: "small-model",
1316 Models: []catwalk.Model{
1317 {
1318 ID: "larger-model",
1319 DefaultMaxTokens: 2000,
1320 },
1321 {
1322 ID: "large-model",
1323 DefaultMaxTokens: 1000,
1324 },
1325 {
1326 ID: "small-model",
1327 DefaultMaxTokens: 500,
1328 },
1329 },
1330 },
1331 }
1332
1333 cfg := &Config{
1334 Models: map[SelectedModelType]SelectedModel{
1335 "large": {
1336 Model: "larger-model",
1337 },
1338 },
1339 }
1340 cfg.setDefaults("/tmp", "")
1341 env := env.NewFromMap(map[string]string{})
1342 resolver := NewEnvironmentVariableResolver(env)
1343 err := cfg.configureProviders(env, resolver, knownProviders)
1344 require.NoError(t, err)
1345
1346 err = cfg.configureSelectedModels(knownProviders)
1347 require.NoError(t, err)
1348 large := cfg.Models[SelectedModelTypeLarge]
1349 small := cfg.Models[SelectedModelTypeSmall]
1350 require.Equal(t, "larger-model", large.Model)
1351 require.Equal(t, "openai", large.Provider)
1352 require.Equal(t, int64(2000), large.MaxTokens)
1353 require.Equal(t, "small-model", small.Model)
1354 require.Equal(t, "openai", small.Provider)
1355 require.Equal(t, int64(500), small.MaxTokens)
1356 })
1357 t.Run("should be possible to use multiple providers", func(t *testing.T) {
1358 knownProviders := []catwalk.Provider{
1359 {
1360 ID: "openai",
1361 APIKey: "abc",
1362 DefaultLargeModelID: "large-model",
1363 DefaultSmallModelID: "small-model",
1364 Models: []catwalk.Model{
1365 {
1366 ID: "large-model",
1367 DefaultMaxTokens: 1000,
1368 },
1369 {
1370 ID: "small-model",
1371 DefaultMaxTokens: 500,
1372 },
1373 },
1374 },
1375 {
1376 ID: "anthropic",
1377 APIKey: "abc",
1378 DefaultLargeModelID: "a-large-model",
1379 DefaultSmallModelID: "a-small-model",
1380 Models: []catwalk.Model{
1381 {
1382 ID: "a-large-model",
1383 DefaultMaxTokens: 1000,
1384 },
1385 {
1386 ID: "a-small-model",
1387 DefaultMaxTokens: 200,
1388 },
1389 },
1390 },
1391 }
1392
1393 cfg := &Config{
1394 Models: map[SelectedModelType]SelectedModel{
1395 "small": {
1396 Model: "a-small-model",
1397 Provider: "anthropic",
1398 MaxTokens: 300,
1399 },
1400 },
1401 }
1402 cfg.setDefaults("/tmp", "")
1403 env := env.NewFromMap(map[string]string{})
1404 resolver := NewEnvironmentVariableResolver(env)
1405 err := cfg.configureProviders(env, resolver, knownProviders)
1406 require.NoError(t, err)
1407
1408 err = cfg.configureSelectedModels(knownProviders)
1409 require.NoError(t, err)
1410 large := cfg.Models[SelectedModelTypeLarge]
1411 small := cfg.Models[SelectedModelTypeSmall]
1412 require.Equal(t, "large-model", large.Model)
1413 require.Equal(t, "openai", large.Provider)
1414 require.Equal(t, int64(1000), large.MaxTokens)
1415 require.Equal(t, "a-small-model", small.Model)
1416 require.Equal(t, "anthropic", small.Provider)
1417 require.Equal(t, int64(300), small.MaxTokens)
1418 })
1419
1420 t.Run("should override the max tokens only", func(t *testing.T) {
1421 knownProviders := []catwalk.Provider{
1422 {
1423 ID: "openai",
1424 APIKey: "abc",
1425 DefaultLargeModelID: "large-model",
1426 DefaultSmallModelID: "small-model",
1427 Models: []catwalk.Model{
1428 {
1429 ID: "large-model",
1430 DefaultMaxTokens: 1000,
1431 },
1432 {
1433 ID: "small-model",
1434 DefaultMaxTokens: 500,
1435 },
1436 },
1437 },
1438 }
1439
1440 cfg := &Config{
1441 Models: map[SelectedModelType]SelectedModel{
1442 "large": {
1443 MaxTokens: 100,
1444 },
1445 },
1446 }
1447 cfg.setDefaults("/tmp", "")
1448 env := env.NewFromMap(map[string]string{})
1449 resolver := NewEnvironmentVariableResolver(env)
1450 err := cfg.configureProviders(env, resolver, knownProviders)
1451 require.NoError(t, err)
1452
1453 err = cfg.configureSelectedModels(knownProviders)
1454 require.NoError(t, err)
1455 large := cfg.Models[SelectedModelTypeLarge]
1456 require.Equal(t, "large-model", large.Model)
1457 require.Equal(t, "openai", large.Provider)
1458 require.Equal(t, int64(100), large.MaxTokens)
1459 })
1460}