1package config
2
3import (
4 "io"
5 "log/slog"
6 "os"
7 "path/filepath"
8 "strings"
9 "testing"
10
11 "git.secluded.site/crush/internal/csync"
12 "git.secluded.site/crush/internal/env"
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/stretchr/testify/assert"
15 "github.com/stretchr/testify/require"
16)
17
// TestMain routes the default slog logger to io.Discard for every test in the
// package, so log output from the code under test stays out of test output,
// then runs the suite and exits with its status code.
func TestMain(m *testing.M) {
	quiet := slog.NewTextHandler(io.Discard, nil)
	slog.SetDefault(slog.New(quiet))
	os.Exit(m.Run())
}
24
25func TestConfig_LoadFromReaders(t *testing.T) {
26 data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
27 data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
28 data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
29
30 loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
31
32 require.NoError(t, err)
33 require.NotNil(t, loadedConfig)
34 require.Equal(t, 1, loadedConfig.Providers.Len())
35 pc, _ := loadedConfig.Providers.Get("openai")
36 require.Equal(t, "key2", pc.APIKey)
37 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
38}
39
40func TestConfig_setDefaults(t *testing.T) {
41 cfg := &Config{}
42
43 cfg.setDefaults("/tmp", "")
44
45 require.NotNil(t, cfg.Options)
46 require.NotNil(t, cfg.Options.TUI)
47 require.NotNil(t, cfg.Options.ContextPaths)
48 require.NotNil(t, cfg.Providers)
49 require.NotNil(t, cfg.Models)
50 require.NotNil(t, cfg.LSP)
51 require.NotNil(t, cfg.MCP)
52 require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
53 require.Equal(t, "AGENTS.md", cfg.Options.InitializeAs)
54 for _, path := range defaultContextPaths {
55 require.Contains(t, cfg.Options.ContextPaths, path)
56 }
57 require.Equal(t, "/tmp", cfg.workingDir)
58}
59
60func TestConfig_configureProviders(t *testing.T) {
61 knownProviders := []catwalk.Provider{
62 {
63 ID: "openai",
64 APIKey: "$OPENAI_API_KEY",
65 APIEndpoint: "https://api.openai.com/v1",
66 Models: []catwalk.Model{{
67 ID: "test-model",
68 }},
69 },
70 }
71
72 cfg := &Config{}
73 cfg.setDefaults("/tmp", "")
74 env := env.NewFromMap(map[string]string{
75 "OPENAI_API_KEY": "test-key",
76 })
77 resolver := NewEnvironmentVariableResolver(env)
78 err := cfg.configureProviders(env, resolver, knownProviders)
79 require.NoError(t, err)
80 require.Equal(t, 1, cfg.Providers.Len())
81
82 // We want to make sure that we keep the configured API key as a placeholder
83 pc, _ := cfg.Providers.Get("openai")
84 require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
85}
86
87func TestConfig_configureProvidersWithOverride(t *testing.T) {
88 knownProviders := []catwalk.Provider{
89 {
90 ID: "openai",
91 APIKey: "$OPENAI_API_KEY",
92 APIEndpoint: "https://api.openai.com/v1",
93 Models: []catwalk.Model{{
94 ID: "test-model",
95 }},
96 },
97 }
98
99 cfg := &Config{
100 Providers: csync.NewMap[string, ProviderConfig](),
101 }
102 cfg.Providers.Set("openai", ProviderConfig{
103 APIKey: "xyz",
104 BaseURL: "https://api.openai.com/v2",
105 Models: []catwalk.Model{
106 {
107 ID: "test-model",
108 Name: "Updated",
109 },
110 {
111 ID: "another-model",
112 },
113 },
114 })
115 cfg.setDefaults("/tmp", "")
116
117 env := env.NewFromMap(map[string]string{
118 "OPENAI_API_KEY": "test-key",
119 })
120 resolver := NewEnvironmentVariableResolver(env)
121 err := cfg.configureProviders(env, resolver, knownProviders)
122 require.NoError(t, err)
123 require.Equal(t, 1, cfg.Providers.Len())
124
125 // We want to make sure that we keep the configured API key as a placeholder
126 pc, _ := cfg.Providers.Get("openai")
127 require.Equal(t, "xyz", pc.APIKey)
128 require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
129 require.Len(t, pc.Models, 2)
130 require.Equal(t, "Updated", pc.Models[0].Name)
131}
132
133func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
134 knownProviders := []catwalk.Provider{
135 {
136 ID: "openai",
137 APIKey: "$OPENAI_API_KEY",
138 APIEndpoint: "https://api.openai.com/v1",
139 Models: []catwalk.Model{{
140 ID: "test-model",
141 }},
142 },
143 }
144
145 cfg := &Config{
146 Providers: csync.NewMapFrom(map[string]ProviderConfig{
147 "custom": {
148 APIKey: "xyz",
149 BaseURL: "https://api.someendpoint.com/v2",
150 Models: []catwalk.Model{
151 {
152 ID: "test-model",
153 },
154 },
155 },
156 }),
157 }
158 cfg.setDefaults("/tmp", "")
159 env := env.NewFromMap(map[string]string{
160 "OPENAI_API_KEY": "test-key",
161 })
162 resolver := NewEnvironmentVariableResolver(env)
163 err := cfg.configureProviders(env, resolver, knownProviders)
164 require.NoError(t, err)
165 // Should be to because of the env variable
166 require.Equal(t, cfg.Providers.Len(), 2)
167
168 // We want to make sure that we keep the configured API key as a placeholder
169 pc, _ := cfg.Providers.Get("custom")
170 require.Equal(t, "xyz", pc.APIKey)
171 // Make sure we set the ID correctly
172 require.Equal(t, "custom", pc.ID)
173 require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
174 require.Len(t, pc.Models, 1)
175
176 _, ok := cfg.Providers.Get("openai")
177 require.True(t, ok, "OpenAI provider should still be present")
178}
179
180func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
181 knownProviders := []catwalk.Provider{
182 {
183 ID: catwalk.InferenceProviderBedrock,
184 APIKey: "",
185 APIEndpoint: "",
186 Models: []catwalk.Model{{
187 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
188 }},
189 },
190 }
191
192 cfg := &Config{}
193 cfg.setDefaults("/tmp", "")
194 env := env.NewFromMap(map[string]string{
195 "AWS_ACCESS_KEY_ID": "test-key-id",
196 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
197 })
198 resolver := NewEnvironmentVariableResolver(env)
199 err := cfg.configureProviders(env, resolver, knownProviders)
200 require.NoError(t, err)
201 require.Equal(t, cfg.Providers.Len(), 1)
202
203 bedrockProvider, ok := cfg.Providers.Get("bedrock")
204 require.True(t, ok, "Bedrock provider should be present")
205 require.Len(t, bedrockProvider.Models, 1)
206 require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
207}
208
209func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
210 knownProviders := []catwalk.Provider{
211 {
212 ID: catwalk.InferenceProviderBedrock,
213 APIKey: "",
214 APIEndpoint: "",
215 Models: []catwalk.Model{{
216 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
217 }},
218 },
219 }
220
221 cfg := &Config{}
222 cfg.setDefaults("/tmp", "")
223 env := env.NewFromMap(map[string]string{})
224 resolver := NewEnvironmentVariableResolver(env)
225 err := cfg.configureProviders(env, resolver, knownProviders)
226 require.NoError(t, err)
227 // Provider should not be configured without credentials
228 require.Equal(t, cfg.Providers.Len(), 0)
229}
230
231func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
232 knownProviders := []catwalk.Provider{
233 {
234 ID: catwalk.InferenceProviderBedrock,
235 APIKey: "",
236 APIEndpoint: "",
237 Models: []catwalk.Model{{
238 ID: "some-random-model",
239 }},
240 },
241 }
242
243 cfg := &Config{}
244 cfg.setDefaults("/tmp", "")
245 env := env.NewFromMap(map[string]string{
246 "AWS_ACCESS_KEY_ID": "test-key-id",
247 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
248 })
249 resolver := NewEnvironmentVariableResolver(env)
250 err := cfg.configureProviders(env, resolver, knownProviders)
251 require.Error(t, err)
252}
253
254func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
255 knownProviders := []catwalk.Provider{
256 {
257 ID: catwalk.InferenceProviderVertexAI,
258 APIKey: "",
259 APIEndpoint: "",
260 Models: []catwalk.Model{{
261 ID: "gemini-pro",
262 }},
263 },
264 }
265
266 cfg := &Config{}
267 cfg.setDefaults("/tmp", "")
268 env := env.NewFromMap(map[string]string{
269 "VERTEXAI_PROJECT": "test-project",
270 "VERTEXAI_LOCATION": "us-central1",
271 })
272 resolver := NewEnvironmentVariableResolver(env)
273 err := cfg.configureProviders(env, resolver, knownProviders)
274 require.NoError(t, err)
275 require.Equal(t, cfg.Providers.Len(), 1)
276
277 vertexProvider, ok := cfg.Providers.Get("vertexai")
278 require.True(t, ok, "VertexAI provider should be present")
279 require.Len(t, vertexProvider.Models, 1)
280 require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
281 require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
282 require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
283}
284
285func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
286 knownProviders := []catwalk.Provider{
287 {
288 ID: catwalk.InferenceProviderVertexAI,
289 APIKey: "",
290 APIEndpoint: "",
291 Models: []catwalk.Model{{
292 ID: "gemini-pro",
293 }},
294 },
295 }
296
297 cfg := &Config{}
298 cfg.setDefaults("/tmp", "")
299 env := env.NewFromMap(map[string]string{
300 "GOOGLE_GENAI_USE_VERTEXAI": "false",
301 "GOOGLE_CLOUD_PROJECT": "test-project",
302 "GOOGLE_CLOUD_LOCATION": "us-central1",
303 })
304 resolver := NewEnvironmentVariableResolver(env)
305 err := cfg.configureProviders(env, resolver, knownProviders)
306 require.NoError(t, err)
307 // Provider should not be configured without proper credentials
308 require.Equal(t, cfg.Providers.Len(), 0)
309}
310
311func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
312 knownProviders := []catwalk.Provider{
313 {
314 ID: catwalk.InferenceProviderVertexAI,
315 APIKey: "",
316 APIEndpoint: "",
317 Models: []catwalk.Model{{
318 ID: "gemini-pro",
319 }},
320 },
321 }
322
323 cfg := &Config{}
324 cfg.setDefaults("/tmp", "")
325 env := env.NewFromMap(map[string]string{
326 "GOOGLE_GENAI_USE_VERTEXAI": "true",
327 "GOOGLE_CLOUD_LOCATION": "us-central1",
328 })
329 resolver := NewEnvironmentVariableResolver(env)
330 err := cfg.configureProviders(env, resolver, knownProviders)
331 require.NoError(t, err)
332 // Provider should not be configured without project
333 require.Equal(t, cfg.Providers.Len(), 0)
334}
335
336func TestConfig_configureProvidersSetProviderID(t *testing.T) {
337 knownProviders := []catwalk.Provider{
338 {
339 ID: "openai",
340 APIKey: "$OPENAI_API_KEY",
341 APIEndpoint: "https://api.openai.com/v1",
342 Models: []catwalk.Model{{
343 ID: "test-model",
344 }},
345 },
346 }
347
348 cfg := &Config{}
349 cfg.setDefaults("/tmp", "")
350 env := env.NewFromMap(map[string]string{
351 "OPENAI_API_KEY": "test-key",
352 })
353 resolver := NewEnvironmentVariableResolver(env)
354 err := cfg.configureProviders(env, resolver, knownProviders)
355 require.NoError(t, err)
356 require.Equal(t, cfg.Providers.Len(), 1)
357
358 // Provider ID should be set
359 pc, _ := cfg.Providers.Get("openai")
360 require.Equal(t, "openai", pc.ID)
361}
362
363func TestConfig_EnabledProviders(t *testing.T) {
364 t.Run("all providers enabled", func(t *testing.T) {
365 cfg := &Config{
366 Providers: csync.NewMapFrom(map[string]ProviderConfig{
367 "openai": {
368 ID: "openai",
369 APIKey: "key1",
370 Disable: false,
371 },
372 "anthropic": {
373 ID: "anthropic",
374 APIKey: "key2",
375 Disable: false,
376 },
377 }),
378 }
379
380 enabled := cfg.EnabledProviders()
381 require.Len(t, enabled, 2)
382 })
383
384 t.Run("some providers disabled", func(t *testing.T) {
385 cfg := &Config{
386 Providers: csync.NewMapFrom(map[string]ProviderConfig{
387 "openai": {
388 ID: "openai",
389 APIKey: "key1",
390 Disable: false,
391 },
392 "anthropic": {
393 ID: "anthropic",
394 APIKey: "key2",
395 Disable: true,
396 },
397 }),
398 }
399
400 enabled := cfg.EnabledProviders()
401 require.Len(t, enabled, 1)
402 require.Equal(t, "openai", enabled[0].ID)
403 })
404
405 t.Run("empty providers map", func(t *testing.T) {
406 cfg := &Config{
407 Providers: csync.NewMap[string, ProviderConfig](),
408 }
409
410 enabled := cfg.EnabledProviders()
411 require.Len(t, enabled, 0)
412 })
413}
414
415func TestConfig_IsConfigured(t *testing.T) {
416 t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
417 cfg := &Config{
418 Providers: csync.NewMapFrom(map[string]ProviderConfig{
419 "openai": {
420 ID: "openai",
421 APIKey: "key1",
422 Disable: false,
423 },
424 }),
425 }
426
427 require.True(t, cfg.IsConfigured())
428 })
429
430 t.Run("returns false when no providers are configured", func(t *testing.T) {
431 cfg := &Config{
432 Providers: csync.NewMap[string, ProviderConfig](),
433 }
434
435 require.False(t, cfg.IsConfigured())
436 })
437
438 t.Run("returns false when all providers are disabled", func(t *testing.T) {
439 cfg := &Config{
440 Providers: csync.NewMapFrom(map[string]ProviderConfig{
441 "openai": {
442 ID: "openai",
443 APIKey: "key1",
444 Disable: true,
445 },
446 "anthropic": {
447 ID: "anthropic",
448 APIKey: "key2",
449 Disable: true,
450 },
451 }),
452 }
453
454 require.False(t, cfg.IsConfigured())
455 })
456}
457
458func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
459 cfg := &Config{
460 Options: &Options{
461 DisabledTools: []string{},
462 },
463 }
464
465 cfg.SetupAgents()
466 coderAgent, ok := cfg.Agents[AgentCoder]
467 require.True(t, ok)
468 assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
469
470 taskAgent, ok := cfg.Agents[AgentTask]
471 require.True(t, ok)
472 assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
473}
474
475func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
476 cfg := &Config{
477 Options: &Options{
478 DisabledTools: []string{
479 "edit",
480 "download",
481 "grep",
482 },
483 },
484 }
485
486 cfg.SetupAgents()
487 coderAgent, ok := cfg.Agents[AgentCoder]
488 require.True(t, ok)
489
490 assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)
491
492 taskAgent, ok := cfg.Agents[AgentTask]
493 require.True(t, ok)
494 assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
495}
496
497func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
498 cfg := &Config{
499 Options: &Options{
500 DisabledTools: []string{
501 "glob",
502 "grep",
503 "ls",
504 "sourcegraph",
505 "view",
506 },
507 },
508 }
509
510 cfg.SetupAgents()
511 coderAgent, ok := cfg.Agents[AgentCoder]
512 require.True(t, ok)
513 assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "write"}, coderAgent.AllowedTools)
514
515 taskAgent, ok := cfg.Agents[AgentTask]
516 require.True(t, ok)
517 assert.Equal(t, []string{}, taskAgent.AllowedTools)
518}
519
520func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
521 knownProviders := []catwalk.Provider{
522 {
523 ID: "openai",
524 APIKey: "$OPENAI_API_KEY",
525 APIEndpoint: "https://api.openai.com/v1",
526 Models: []catwalk.Model{{
527 ID: "test-model",
528 }},
529 },
530 }
531
532 cfg := &Config{
533 Providers: csync.NewMapFrom(map[string]ProviderConfig{
534 "openai": {
535 Disable: true,
536 },
537 }),
538 }
539 cfg.setDefaults("/tmp", "")
540
541 env := env.NewFromMap(map[string]string{
542 "OPENAI_API_KEY": "test-key",
543 })
544 resolver := NewEnvironmentVariableResolver(env)
545 err := cfg.configureProviders(env, resolver, knownProviders)
546 require.NoError(t, err)
547
548 require.Equal(t, cfg.Providers.Len(), 1)
549 prov, exists := cfg.Providers.Get("openai")
550 require.True(t, exists)
551 require.True(t, prov.Disable)
552}
553
554func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
555 t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
556 cfg := &Config{
557 Providers: csync.NewMapFrom(map[string]ProviderConfig{
558 "custom": {
559 BaseURL: "https://api.custom.com/v1",
560 Models: []catwalk.Model{{
561 ID: "test-model",
562 }},
563 },
564 "openai": {
565 APIKey: "$MISSING",
566 },
567 }),
568 }
569 cfg.setDefaults("/tmp", "")
570
571 env := env.NewFromMap(map[string]string{})
572 resolver := NewEnvironmentVariableResolver(env)
573 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
574 require.NoError(t, err)
575
576 require.Equal(t, cfg.Providers.Len(), 1)
577 _, exists := cfg.Providers.Get("custom")
578 require.True(t, exists)
579 })
580
581 t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
582 cfg := &Config{
583 Providers: csync.NewMapFrom(map[string]ProviderConfig{
584 "custom": {
585 APIKey: "test-key",
586 Models: []catwalk.Model{{
587 ID: "test-model",
588 }},
589 },
590 }),
591 }
592 cfg.setDefaults("/tmp", "")
593
594 env := env.NewFromMap(map[string]string{})
595 resolver := NewEnvironmentVariableResolver(env)
596 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
597 require.NoError(t, err)
598
599 require.Equal(t, cfg.Providers.Len(), 0)
600 _, exists := cfg.Providers.Get("custom")
601 require.False(t, exists)
602 })
603
604 t.Run("custom provider with no models is removed", func(t *testing.T) {
605 cfg := &Config{
606 Providers: csync.NewMapFrom(map[string]ProviderConfig{
607 "custom": {
608 APIKey: "test-key",
609 BaseURL: "https://api.custom.com/v1",
610 Models: []catwalk.Model{},
611 },
612 }),
613 }
614 cfg.setDefaults("/tmp", "")
615
616 env := env.NewFromMap(map[string]string{})
617 resolver := NewEnvironmentVariableResolver(env)
618 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
619 require.NoError(t, err)
620
621 require.Equal(t, cfg.Providers.Len(), 0)
622 _, exists := cfg.Providers.Get("custom")
623 require.False(t, exists)
624 })
625
626 t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
627 cfg := &Config{
628 Providers: csync.NewMapFrom(map[string]ProviderConfig{
629 "custom": {
630 APIKey: "test-key",
631 BaseURL: "https://api.custom.com/v1",
632 Type: "unsupported",
633 Models: []catwalk.Model{{
634 ID: "test-model",
635 }},
636 },
637 }),
638 }
639 cfg.setDefaults("/tmp", "")
640
641 env := env.NewFromMap(map[string]string{})
642 resolver := NewEnvironmentVariableResolver(env)
643 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
644 require.NoError(t, err)
645
646 require.Equal(t, cfg.Providers.Len(), 0)
647 _, exists := cfg.Providers.Get("custom")
648 require.False(t, exists)
649 })
650
651 t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
652 cfg := &Config{
653 Providers: csync.NewMapFrom(map[string]ProviderConfig{
654 "custom": {
655 APIKey: "test-key",
656 BaseURL: "https://api.custom.com/v1",
657 Type: catwalk.TypeOpenAI,
658 Models: []catwalk.Model{{
659 ID: "test-model",
660 }},
661 },
662 }),
663 }
664 cfg.setDefaults("/tmp", "")
665
666 env := env.NewFromMap(map[string]string{})
667 resolver := NewEnvironmentVariableResolver(env)
668 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
669 require.NoError(t, err)
670
671 require.Equal(t, cfg.Providers.Len(), 1)
672 customProvider, exists := cfg.Providers.Get("custom")
673 require.True(t, exists)
674 require.Equal(t, "custom", customProvider.ID)
675 require.Equal(t, "test-key", customProvider.APIKey)
676 require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
677 })
678
679 t.Run("custom anthropic provider is supported", func(t *testing.T) {
680 cfg := &Config{
681 Providers: csync.NewMapFrom(map[string]ProviderConfig{
682 "custom-anthropic": {
683 APIKey: "test-key",
684 BaseURL: "https://api.anthropic.com/v1",
685 Type: catwalk.TypeAnthropic,
686 Models: []catwalk.Model{{
687 ID: "claude-3-sonnet",
688 }},
689 },
690 }),
691 }
692 cfg.setDefaults("/tmp", "")
693
694 env := env.NewFromMap(map[string]string{})
695 resolver := NewEnvironmentVariableResolver(env)
696 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
697 require.NoError(t, err)
698
699 require.Equal(t, cfg.Providers.Len(), 1)
700 customProvider, exists := cfg.Providers.Get("custom-anthropic")
701 require.True(t, exists)
702 require.Equal(t, "custom-anthropic", customProvider.ID)
703 require.Equal(t, "test-key", customProvider.APIKey)
704 require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
705 require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
706 })
707
708 t.Run("disabled custom provider is removed", func(t *testing.T) {
709 cfg := &Config{
710 Providers: csync.NewMapFrom(map[string]ProviderConfig{
711 "custom": {
712 APIKey: "test-key",
713 BaseURL: "https://api.custom.com/v1",
714 Type: catwalk.TypeOpenAI,
715 Disable: true,
716 Models: []catwalk.Model{{
717 ID: "test-model",
718 }},
719 },
720 }),
721 }
722 cfg.setDefaults("/tmp", "")
723
724 env := env.NewFromMap(map[string]string{})
725 resolver := NewEnvironmentVariableResolver(env)
726 err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
727 require.NoError(t, err)
728
729 require.Equal(t, cfg.Providers.Len(), 0)
730 _, exists := cfg.Providers.Get("custom")
731 require.False(t, exists)
732 })
733}
734
735func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
736 t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
737 knownProviders := []catwalk.Provider{
738 {
739 ID: catwalk.InferenceProviderVertexAI,
740 APIKey: "",
741 APIEndpoint: "",
742 Models: []catwalk.Model{{
743 ID: "gemini-pro",
744 }},
745 },
746 }
747
748 cfg := &Config{
749 Providers: csync.NewMapFrom(map[string]ProviderConfig{
750 "vertexai": {
751 BaseURL: "custom-url",
752 },
753 }),
754 }
755 cfg.setDefaults("/tmp", "")
756
757 env := env.NewFromMap(map[string]string{
758 "GOOGLE_GENAI_USE_VERTEXAI": "false",
759 })
760 resolver := NewEnvironmentVariableResolver(env)
761 err := cfg.configureProviders(env, resolver, knownProviders)
762 require.NoError(t, err)
763
764 require.Equal(t, cfg.Providers.Len(), 0)
765 _, exists := cfg.Providers.Get("vertexai")
766 require.False(t, exists)
767 })
768
769 t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
770 knownProviders := []catwalk.Provider{
771 {
772 ID: catwalk.InferenceProviderBedrock,
773 APIKey: "",
774 APIEndpoint: "",
775 Models: []catwalk.Model{{
776 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
777 }},
778 },
779 }
780
781 cfg := &Config{
782 Providers: csync.NewMapFrom(map[string]ProviderConfig{
783 "bedrock": {
784 BaseURL: "custom-url",
785 },
786 }),
787 }
788 cfg.setDefaults("/tmp", "")
789
790 env := env.NewFromMap(map[string]string{})
791 resolver := NewEnvironmentVariableResolver(env)
792 err := cfg.configureProviders(env, resolver, knownProviders)
793 require.NoError(t, err)
794
795 require.Equal(t, cfg.Providers.Len(), 0)
796 _, exists := cfg.Providers.Get("bedrock")
797 require.False(t, exists)
798 })
799
800 t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
801 knownProviders := []catwalk.Provider{
802 {
803 ID: "openai",
804 APIKey: "$MISSING_API_KEY",
805 APIEndpoint: "https://api.openai.com/v1",
806 Models: []catwalk.Model{{
807 ID: "test-model",
808 }},
809 },
810 }
811
812 cfg := &Config{
813 Providers: csync.NewMapFrom(map[string]ProviderConfig{
814 "openai": {
815 BaseURL: "custom-url",
816 },
817 }),
818 }
819 cfg.setDefaults("/tmp", "")
820
821 env := env.NewFromMap(map[string]string{})
822 resolver := NewEnvironmentVariableResolver(env)
823 err := cfg.configureProviders(env, resolver, knownProviders)
824 require.NoError(t, err)
825
826 require.Equal(t, cfg.Providers.Len(), 0)
827 _, exists := cfg.Providers.Get("openai")
828 require.False(t, exists)
829 })
830
831 t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
832 knownProviders := []catwalk.Provider{
833 {
834 ID: "openai",
835 APIKey: "$OPENAI_API_KEY",
836 APIEndpoint: "$MISSING_ENDPOINT",
837 Models: []catwalk.Model{{
838 ID: "test-model",
839 }},
840 },
841 }
842
843 cfg := &Config{
844 Providers: csync.NewMapFrom(map[string]ProviderConfig{
845 "openai": {
846 APIKey: "test-key",
847 },
848 }),
849 }
850 cfg.setDefaults("/tmp", "")
851
852 env := env.NewFromMap(map[string]string{
853 "OPENAI_API_KEY": "test-key",
854 })
855 resolver := NewEnvironmentVariableResolver(env)
856 err := cfg.configureProviders(env, resolver, knownProviders)
857 require.NoError(t, err)
858
859 require.Equal(t, cfg.Providers.Len(), 1)
860 _, exists := cfg.Providers.Get("openai")
861 require.True(t, exists)
862 })
863}
864
865func TestConfig_defaultModelSelection(t *testing.T) {
866 t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
867 knownProviders := []catwalk.Provider{
868 {
869 ID: "openai",
870 APIKey: "abc",
871 DefaultLargeModelID: "large-model",
872 DefaultSmallModelID: "small-model",
873 Models: []catwalk.Model{
874 {
875 ID: "large-model",
876 DefaultMaxTokens: 1000,
877 },
878 {
879 ID: "small-model",
880 DefaultMaxTokens: 500,
881 },
882 },
883 },
884 }
885
886 cfg := &Config{}
887 cfg.setDefaults("/tmp", "")
888 env := env.NewFromMap(map[string]string{})
889 resolver := NewEnvironmentVariableResolver(env)
890 err := cfg.configureProviders(env, resolver, knownProviders)
891 require.NoError(t, err)
892
893 large, small, err := cfg.defaultModelSelection(knownProviders)
894 require.NoError(t, err)
895 require.Equal(t, "large-model", large.Model)
896 require.Equal(t, "openai", large.Provider)
897 require.Equal(t, int64(1000), large.MaxTokens)
898 require.Equal(t, "small-model", small.Model)
899 require.Equal(t, "openai", small.Provider)
900 require.Equal(t, int64(500), small.MaxTokens)
901 })
902 t.Run("should error if no providers configured", func(t *testing.T) {
903 knownProviders := []catwalk.Provider{
904 {
905 ID: "openai",
906 APIKey: "$MISSING_KEY",
907 DefaultLargeModelID: "large-model",
908 DefaultSmallModelID: "small-model",
909 Models: []catwalk.Model{
910 {
911 ID: "large-model",
912 DefaultMaxTokens: 1000,
913 },
914 {
915 ID: "small-model",
916 DefaultMaxTokens: 500,
917 },
918 },
919 },
920 }
921
922 cfg := &Config{}
923 cfg.setDefaults("/tmp", "")
924 env := env.NewFromMap(map[string]string{})
925 resolver := NewEnvironmentVariableResolver(env)
926 err := cfg.configureProviders(env, resolver, knownProviders)
927 require.NoError(t, err)
928
929 _, _, err = cfg.defaultModelSelection(knownProviders)
930 require.Error(t, err)
931 })
932 t.Run("should error if model is missing", func(t *testing.T) {
933 knownProviders := []catwalk.Provider{
934 {
935 ID: "openai",
936 APIKey: "abc",
937 DefaultLargeModelID: "large-model",
938 DefaultSmallModelID: "small-model",
939 Models: []catwalk.Model{
940 {
941 ID: "not-large-model",
942 DefaultMaxTokens: 1000,
943 },
944 {
945 ID: "small-model",
946 DefaultMaxTokens: 500,
947 },
948 },
949 },
950 }
951
952 cfg := &Config{}
953 cfg.setDefaults("/tmp", "")
954 env := env.NewFromMap(map[string]string{})
955 resolver := NewEnvironmentVariableResolver(env)
956 err := cfg.configureProviders(env, resolver, knownProviders)
957 require.NoError(t, err)
958 _, _, err = cfg.defaultModelSelection(knownProviders)
959 require.Error(t, err)
960 })
961
962 t.Run("should configure the default models with a custom provider", func(t *testing.T) {
963 knownProviders := []catwalk.Provider{
964 {
965 ID: "openai",
966 APIKey: "$MISSING", // will not be included in the config
967 DefaultLargeModelID: "large-model",
968 DefaultSmallModelID: "small-model",
969 Models: []catwalk.Model{
970 {
971 ID: "not-large-model",
972 DefaultMaxTokens: 1000,
973 },
974 {
975 ID: "small-model",
976 DefaultMaxTokens: 500,
977 },
978 },
979 },
980 }
981
982 cfg := &Config{
983 Providers: csync.NewMapFrom(map[string]ProviderConfig{
984 "custom": {
985 APIKey: "test-key",
986 BaseURL: "https://api.custom.com/v1",
987 Models: []catwalk.Model{
988 {
989 ID: "model",
990 DefaultMaxTokens: 600,
991 },
992 },
993 },
994 }),
995 }
996 cfg.setDefaults("/tmp", "")
997 env := env.NewFromMap(map[string]string{})
998 resolver := NewEnvironmentVariableResolver(env)
999 err := cfg.configureProviders(env, resolver, knownProviders)
1000 require.NoError(t, err)
1001 large, small, err := cfg.defaultModelSelection(knownProviders)
1002 require.NoError(t, err)
1003 require.Equal(t, "model", large.Model)
1004 require.Equal(t, "custom", large.Provider)
1005 require.Equal(t, int64(600), large.MaxTokens)
1006 require.Equal(t, "model", small.Model)
1007 require.Equal(t, "custom", small.Provider)
1008 require.Equal(t, int64(600), small.MaxTokens)
1009 })
1010
1011 t.Run("should fail if no model configured", func(t *testing.T) {
1012 knownProviders := []catwalk.Provider{
1013 {
1014 ID: "openai",
1015 APIKey: "$MISSING", // will not be included in the config
1016 DefaultLargeModelID: "large-model",
1017 DefaultSmallModelID: "small-model",
1018 Models: []catwalk.Model{
1019 {
1020 ID: "not-large-model",
1021 DefaultMaxTokens: 1000,
1022 },
1023 {
1024 ID: "small-model",
1025 DefaultMaxTokens: 500,
1026 },
1027 },
1028 },
1029 }
1030
1031 cfg := &Config{
1032 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1033 "custom": {
1034 APIKey: "test-key",
1035 BaseURL: "https://api.custom.com/v1",
1036 Models: []catwalk.Model{},
1037 },
1038 }),
1039 }
1040 cfg.setDefaults("/tmp", "")
1041 env := env.NewFromMap(map[string]string{})
1042 resolver := NewEnvironmentVariableResolver(env)
1043 err := cfg.configureProviders(env, resolver, knownProviders)
1044 require.NoError(t, err)
1045 _, _, err = cfg.defaultModelSelection(knownProviders)
1046 require.Error(t, err)
1047 })
1048 t.Run("should use the default provider first", func(t *testing.T) {
1049 knownProviders := []catwalk.Provider{
1050 {
1051 ID: "openai",
1052 APIKey: "set",
1053 DefaultLargeModelID: "large-model",
1054 DefaultSmallModelID: "small-model",
1055 Models: []catwalk.Model{
1056 {
1057 ID: "large-model",
1058 DefaultMaxTokens: 1000,
1059 },
1060 {
1061 ID: "small-model",
1062 DefaultMaxTokens: 500,
1063 },
1064 },
1065 },
1066 }
1067
1068 cfg := &Config{
1069 Providers: csync.NewMapFrom(map[string]ProviderConfig{
1070 "custom": {
1071 APIKey: "test-key",
1072 BaseURL: "https://api.custom.com/v1",
1073 Models: []catwalk.Model{
1074 {
1075 ID: "large-model",
1076 DefaultMaxTokens: 1000,
1077 },
1078 },
1079 },
1080 }),
1081 }
1082 cfg.setDefaults("/tmp", "")
1083 env := env.NewFromMap(map[string]string{})
1084 resolver := NewEnvironmentVariableResolver(env)
1085 err := cfg.configureProviders(env, resolver, knownProviders)
1086 require.NoError(t, err)
1087 large, small, err := cfg.defaultModelSelection(knownProviders)
1088 require.NoError(t, err)
1089 require.Equal(t, "large-model", large.Model)
1090 require.Equal(t, "openai", large.Provider)
1091 require.Equal(t, int64(1000), large.MaxTokens)
1092 require.Equal(t, "small-model", small.Model)
1093 require.Equal(t, "openai", small.Provider)
1094 require.Equal(t, int64(500), small.MaxTokens)
1095 })
1096}
1097
1098func TestConfig_configureSelectedModels(t *testing.T) {
1099 t.Run("should override defaults", func(t *testing.T) {
1100 knownProviders := []catwalk.Provider{
1101 {
1102 ID: "openai",
1103 APIKey: "abc",
1104 DefaultLargeModelID: "large-model",
1105 DefaultSmallModelID: "small-model",
1106 Models: []catwalk.Model{
1107 {
1108 ID: "larger-model",
1109 DefaultMaxTokens: 2000,
1110 },
1111 {
1112 ID: "large-model",
1113 DefaultMaxTokens: 1000,
1114 },
1115 {
1116 ID: "small-model",
1117 DefaultMaxTokens: 500,
1118 },
1119 },
1120 },
1121 }
1122
1123 cfg := &Config{
1124 Models: map[SelectedModelType]SelectedModel{
1125 "large": {
1126 Model: "larger-model",
1127 },
1128 },
1129 }
1130 cfg.setDefaults("/tmp", "")
1131 env := env.NewFromMap(map[string]string{})
1132 resolver := NewEnvironmentVariableResolver(env)
1133 err := cfg.configureProviders(env, resolver, knownProviders)
1134 require.NoError(t, err)
1135
1136 err = cfg.configureSelectedModels(knownProviders)
1137 require.NoError(t, err)
1138 large := cfg.Models[SelectedModelTypeLarge]
1139 small := cfg.Models[SelectedModelTypeSmall]
1140 require.Equal(t, "larger-model", large.Model)
1141 require.Equal(t, "openai", large.Provider)
1142 require.Equal(t, int64(2000), large.MaxTokens)
1143 require.Equal(t, "small-model", small.Model)
1144 require.Equal(t, "openai", small.Provider)
1145 require.Equal(t, int64(500), small.MaxTokens)
1146 })
1147 t.Run("should be possible to use multiple providers", func(t *testing.T) {
1148 knownProviders := []catwalk.Provider{
1149 {
1150 ID: "openai",
1151 APIKey: "abc",
1152 DefaultLargeModelID: "large-model",
1153 DefaultSmallModelID: "small-model",
1154 Models: []catwalk.Model{
1155 {
1156 ID: "large-model",
1157 DefaultMaxTokens: 1000,
1158 },
1159 {
1160 ID: "small-model",
1161 DefaultMaxTokens: 500,
1162 },
1163 },
1164 },
1165 {
1166 ID: "anthropic",
1167 APIKey: "abc",
1168 DefaultLargeModelID: "a-large-model",
1169 DefaultSmallModelID: "a-small-model",
1170 Models: []catwalk.Model{
1171 {
1172 ID: "a-large-model",
1173 DefaultMaxTokens: 1000,
1174 },
1175 {
1176 ID: "a-small-model",
1177 DefaultMaxTokens: 200,
1178 },
1179 },
1180 },
1181 }
1182
1183 cfg := &Config{
1184 Models: map[SelectedModelType]SelectedModel{
1185 "small": {
1186 Model: "a-small-model",
1187 Provider: "anthropic",
1188 MaxTokens: 300,
1189 },
1190 },
1191 }
1192 cfg.setDefaults("/tmp", "")
1193 env := env.NewFromMap(map[string]string{})
1194 resolver := NewEnvironmentVariableResolver(env)
1195 err := cfg.configureProviders(env, resolver, knownProviders)
1196 require.NoError(t, err)
1197
1198 err = cfg.configureSelectedModels(knownProviders)
1199 require.NoError(t, err)
1200 large := cfg.Models[SelectedModelTypeLarge]
1201 small := cfg.Models[SelectedModelTypeSmall]
1202 require.Equal(t, "large-model", large.Model)
1203 require.Equal(t, "openai", large.Provider)
1204 require.Equal(t, int64(1000), large.MaxTokens)
1205 require.Equal(t, "a-small-model", small.Model)
1206 require.Equal(t, "anthropic", small.Provider)
1207 require.Equal(t, int64(300), small.MaxTokens)
1208 })
1209
1210 t.Run("should override the max tokens only", func(t *testing.T) {
1211 knownProviders := []catwalk.Provider{
1212 {
1213 ID: "openai",
1214 APIKey: "abc",
1215 DefaultLargeModelID: "large-model",
1216 DefaultSmallModelID: "small-model",
1217 Models: []catwalk.Model{
1218 {
1219 ID: "large-model",
1220 DefaultMaxTokens: 1000,
1221 },
1222 {
1223 ID: "small-model",
1224 DefaultMaxTokens: 500,
1225 },
1226 },
1227 },
1228 }
1229
1230 cfg := &Config{
1231 Models: map[SelectedModelType]SelectedModel{
1232 "large": {
1233 MaxTokens: 100,
1234 },
1235 },
1236 }
1237 cfg.setDefaults("/tmp", "")
1238 env := env.NewFromMap(map[string]string{})
1239 resolver := NewEnvironmentVariableResolver(env)
1240 err := cfg.configureProviders(env, resolver, knownProviders)
1241 require.NoError(t, err)
1242
1243 err = cfg.configureSelectedModels(knownProviders)
1244 require.NoError(t, err)
1245 large := cfg.Models[SelectedModelTypeLarge]
1246 require.Equal(t, "large-model", large.Model)
1247 require.Equal(t, "openai", large.Provider)
1248 require.Equal(t, int64(100), large.MaxTokens)
1249 })
1250}