1package config
2
3import (
4 "io"
5 "log/slog"
6 "os"
7 "path/filepath"
8 "strings"
9 "testing"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/csync"
13 "github.com/charmbracelet/crush/internal/env"
14 "github.com/stretchr/testify/assert"
15)
16
// TestMain silences the package-wide default slog logger so that the
// configuration code under test does not write log output during tests,
// then runs the tests and exits with their status code.
func TestMain(m *testing.M) {
	discard := slog.NewTextHandler(io.Discard, nil)
	slog.SetDefault(slog.New(discard))
	os.Exit(m.Run())
}
23
// TestConfig_LoadFromReaders checks that configs read from several readers
// are merged in order: later readers override earlier ones, and a later
// empty provider object does not clear values merged from earlier readers.
func TestConfig_LoadFromReaders(t *testing.T) {
	data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
	data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
	data3 := strings.NewReader(`{"providers": {"openai": {}}}`)

	loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})

	// Stop here on failure: the assertions below dereference loadedConfig.
	if !assert.NoError(t, err) || !assert.NotNil(t, loadedConfig) {
		return
	}
	assert.Equal(t, 1, loadedConfig.Providers.Len())
	pc, ok := loadedConfig.Providers.Get("openai")
	assert.True(t, ok, "openai provider should be present after merging")
	assert.Equal(t, "key2", pc.APIKey)
	assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
}
38
39func TestConfig_setDefaults(t *testing.T) {
40 cfg := &Config{}
41
42 cfg.setDefaults("/tmp")
43
44 assert.NotNil(t, cfg.Options)
45 assert.NotNil(t, cfg.Options.TUI)
46 assert.NotNil(t, cfg.Options.ContextPaths)
47 assert.NotNil(t, cfg.Providers)
48 assert.NotNil(t, cfg.Models)
49 assert.NotNil(t, cfg.LSP)
50 assert.NotNil(t, cfg.MCP)
51 assert.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
52 for _, path := range defaultContextPaths {
53 assert.Contains(t, cfg.Options.ContextPaths, path)
54 }
55 assert.Equal(t, "/tmp", cfg.workingDir)
56}
57
// TestConfig_configureProviders checks that a known provider whose API key
// placeholder resolves from the environment gets configured, and that the
// stored APIKey keeps the unresolved placeholder rather than the secret value.
func TestConfig_configureProviders(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	// Named testEnv so the env package is not shadowed for the rest of the test.
	testEnv := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(testEnv)
	err := cfg.configureProviders(testEnv, resolver, knownProviders)
	assert.NoError(t, err)
	assert.Equal(t, 1, cfg.Providers.Len())

	pc, ok := cfg.Providers.Get("openai")
	assert.True(t, ok, "openai provider should be configured")
	// We want to make sure that we keep the configured API key as a placeholder
	assert.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
}
84
// TestConfig_configureProvidersWithOverride checks that values set by the user
// for a known provider (API key, base URL, models) take precedence over the
// values from the known provider definition.
func TestConfig_configureProvidersWithOverride(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: csync.NewMap[string, ProviderConfig](),
	}
	cfg.Providers.Set("openai", ProviderConfig{
		APIKey:  "xyz",
		BaseURL: "https://api.openai.com/v2",
		Models: []catwalk.Model{
			{
				ID:   "test-model",
				Name: "Updated",
			},
			{
				ID: "another-model",
			},
		},
	})
	cfg.setDefaults("/tmp")

	testEnv := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(testEnv)
	err := cfg.configureProviders(testEnv, resolver, knownProviders)
	assert.NoError(t, err)
	assert.Equal(t, 1, cfg.Providers.Len())

	pc, ok := cfg.Providers.Get("openai")
	assert.True(t, ok, "openai provider should be configured")
	// The user-supplied values must win over the known provider's defaults.
	assert.Equal(t, "xyz", pc.APIKey)
	assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
	// Only index into Models when the length check passed, to avoid a panic.
	if assert.Len(t, pc.Models, 2) {
		assert.Equal(t, "Updated", pc.Models[0].Name)
	}
}
130
// TestConfig_configureProvidersWithNewProvider checks that a user-defined
// provider that is not among the known providers is kept as configured, with
// its ID filled in from its map key. It also checks that a known provider
// enabled through the environment is still added alongside it.
func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: csync.NewMapFrom(map[string]ProviderConfig{
			"custom": {
				APIKey:  "xyz",
				BaseURL: "https://api.someendpoint.com/v2",
				Models: []catwalk.Model{
					{
						ID: "test-model",
					},
				},
			},
		}),
	}
	cfg.setDefaults("/tmp")
	testEnv := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(testEnv)
	err := cfg.configureProviders(testEnv, resolver, knownProviders)
	assert.NoError(t, err)
	// Should be two: the custom provider plus openai, enabled by OPENAI_API_KEY.
	assert.Equal(t, 2, cfg.Providers.Len())

	pc, ok := cfg.Providers.Get("custom")
	assert.True(t, ok, "custom provider should be present")
	// The custom provider's user-supplied values must be kept as-is.
	assert.Equal(t, "xyz", pc.APIKey)
	// Make sure we set the ID correctly
	assert.Equal(t, "custom", pc.ID)
	assert.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
	assert.Len(t, pc.Models, 1)

	_, ok = cfg.Providers.Get("openai")
	assert.True(t, ok, "OpenAI provider should still be present")
}
177
// TestConfig_configureProvidersBedrockWithCredentials checks that the Bedrock
// provider is configured when AWS access-key variables are present in the
// environment, even though its known-provider entry has no API key or endpoint.
func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderBedrock,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	testEnv := env.NewFromMap(map[string]string{
		"AWS_ACCESS_KEY_ID":     "test-key-id",
		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
	})
	resolver := NewEnvironmentVariableResolver(testEnv)
	err := cfg.configureProviders(testEnv, resolver, knownProviders)
	assert.NoError(t, err)
	assert.Equal(t, 1, cfg.Providers.Len())

	bedrockProvider, ok := cfg.Providers.Get("bedrock")
	assert.True(t, ok, "Bedrock provider should be present")
	// Only index into Models when the length check passed, to avoid a panic.
	if assert.Len(t, bedrockProvider.Models, 1) {
		assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
	}
}
206
// TestConfig_configureProvidersBedrockWithoutCredentials checks that the
// Bedrock provider is not configured when the environment has no AWS
// credentials, and that this is not treated as an error.
func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderBedrock,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	testEnv := env.NewFromMap(map[string]string{})
	resolver := NewEnvironmentVariableResolver(testEnv)
	err := cfg.configureProviders(testEnv, resolver, knownProviders)
	assert.NoError(t, err)
	// Provider should not be configured without credentials
	assert.Equal(t, 0, cfg.Providers.Len())
}
228
// TestConfig_configureProvidersBedrockWithoutUnsupportedModel checks that
// configureProviders returns an error for a Bedrock provider whose only model
// is "some-random-model", even with AWS credentials present. Presumably this
// is because that model is not one the Bedrock integration accepts — confirm
// against configureProviders.
// NOTE(review): despite "Without" in its name, this test covers a provider
// *with* an unsupported model.
func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
	unsupported := catwalk.Model{ID: "some-random-model"}
	knownProviders := []catwalk.Provider{{
		ID:          catwalk.InferenceProviderBedrock,
		APIKey:      "",
		APIEndpoint: "",
		Models:      []catwalk.Model{unsupported},
	}}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	awsEnv := env.NewFromMap(map[string]string{
		"AWS_ACCESS_KEY_ID":     "test-key-id",
		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
	})
	err := cfg.configureProviders(awsEnv, NewEnvironmentVariableResolver(awsEnv), knownProviders)
	assert.Error(t, err)
}
251
// TestConfig_configureProvidersVertexAIWithCredentials checks that the
// VertexAI provider is configured when GOOGLE_GENAI_USE_VERTEXAI is "true" and
// project/location are set. It also checks that project and location are
// copied into ExtraParams.
func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	testEnv := env.NewFromMap(map[string]string{
		"GOOGLE_GENAI_USE_VERTEXAI": "true",
		"GOOGLE_CLOUD_PROJECT":      "test-project",
		"GOOGLE_CLOUD_LOCATION":     "us-central1",
	})
	resolver := NewEnvironmentVariableResolver(testEnv)
	err := cfg.configureProviders(testEnv, resolver, knownProviders)
	assert.NoError(t, err)
	assert.Equal(t, 1, cfg.Providers.Len())

	vertexProvider, ok := cfg.Providers.Get("vertexai")
	assert.True(t, ok, "VertexAI provider should be present")
	// Only index into Models when the length check passed, to avoid a panic.
	if assert.Len(t, vertexProvider.Models, 1) {
		assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
	}
	assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
	assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
}
283
// TestConfig_configureProvidersVertexAIWithoutCredentials checks that the
// VertexAI provider is not configured when GOOGLE_GENAI_USE_VERTEXAI is
// "false", even though project and location are set.
func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	testEnv := env.NewFromMap(map[string]string{
		"GOOGLE_GENAI_USE_VERTEXAI": "false",
		"GOOGLE_CLOUD_PROJECT":      "test-project",
		"GOOGLE_CLOUD_LOCATION":     "us-central1",
	})
	resolver := NewEnvironmentVariableResolver(testEnv)
	err := cfg.configureProviders(testEnv, resolver, knownProviders)
	assert.NoError(t, err)
	// Provider should not be configured without proper credentials
	assert.Equal(t, 0, cfg.Providers.Len())
}
309
// TestConfig_configureProvidersVertexAIMissingProject checks that the VertexAI
// provider is not configured when GOOGLE_CLOUD_PROJECT is absent, even though
// VertexAI is enabled and a location is set.
func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	testEnv := env.NewFromMap(map[string]string{
		"GOOGLE_GENAI_USE_VERTEXAI": "true",
		"GOOGLE_CLOUD_LOCATION":     "us-central1",
	})
	resolver := NewEnvironmentVariableResolver(testEnv)
	err := cfg.configureProviders(testEnv, resolver, knownProviders)
	assert.NoError(t, err)
	// Provider should not be configured without project
	assert.Equal(t, 0, cfg.Providers.Len())
}
334
// TestConfig_configureProvidersSetProviderID checks that a provider added from
// the known-provider list gets its ID field set.
func TestConfig_configureProvidersSetProviderID(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	testEnv := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(testEnv)
	err := cfg.configureProviders(testEnv, resolver, knownProviders)
	assert.NoError(t, err)
	assert.Equal(t, 1, cfg.Providers.Len())

	// Provider ID should be set
	pc, ok := cfg.Providers.Get("openai")
	assert.True(t, ok, "openai provider should be configured")
	assert.Equal(t, "openai", pc.ID)
}
361
// TestConfig_EnabledProviders checks that EnabledProviders returns exactly the
// providers whose Disable flag is false.
func TestConfig_EnabledProviders(t *testing.T) {
	t.Run("all providers enabled", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: false,
				},
				"anthropic": {
					ID:      "anthropic",
					APIKey:  "key2",
					Disable: false,
				},
			}),
		}

		enabled := cfg.EnabledProviders()
		assert.Len(t, enabled, 2)
	})

	t.Run("some providers disabled", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: false,
				},
				"anthropic": {
					ID:      "anthropic",
					APIKey:  "key2",
					Disable: true,
				},
			}),
		}

		enabled := cfg.EnabledProviders()
		// Only index the result when the length check passed, to avoid a panic.
		if assert.Len(t, enabled, 1) {
			assert.Equal(t, "openai", enabled[0].ID)
		}
	})

	t.Run("empty providers map", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMap[string, ProviderConfig](),
		}

		enabled := cfg.EnabledProviders()
		assert.Len(t, enabled, 0)
	})
}
413
// TestConfig_IsConfigured checks that IsConfigured reports true only when at
// least one provider in the config is enabled.
func TestConfig_IsConfigured(t *testing.T) {
	cases := []struct {
		name string
		cfg  *Config
		want bool
	}{
		{
			name: "returns true when at least one provider is enabled",
			cfg: &Config{
				Providers: csync.NewMapFrom(map[string]ProviderConfig{
					"openai": {ID: "openai", APIKey: "key1", Disable: false},
				}),
			},
			want: true,
		},
		{
			name: "returns false when no providers are configured",
			cfg: &Config{
				Providers: csync.NewMap[string, ProviderConfig](),
			},
			want: false,
		},
		{
			name: "returns false when all providers are disabled",
			cfg: &Config{
				Providers: csync.NewMapFrom(map[string]ProviderConfig{
					"openai":    {ID: "openai", APIKey: "key1", Disable: true},
					"anthropic": {ID: "anthropic", APIKey: "key2", Disable: true},
				}),
			},
			want: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, tc.cfg.IsConfigured())
		})
	}
}
456
// TestConfig_configureProvidersWithDisabledProvider checks that a known provider
// that the user disabled is removed from the config. This holds even though
// its API key resolves from the environment.
func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: csync.NewMapFrom(map[string]ProviderConfig{
			"openai": {
				Disable: true,
			},
		}),
	}
	cfg.setDefaults("/tmp")

	testEnv := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(testEnv)
	err := cfg.configureProviders(testEnv, resolver, knownProviders)
	assert.NoError(t, err)

	// Provider should be removed from config when disabled
	assert.Equal(t, 0, cfg.Providers.Len())
	_, exists := cfg.Providers.Get("openai")
	assert.False(t, exists)
}
490
// TestConfig_configureProvidersCustomProviderValidation checks which
// user-defined (custom) providers survive configureProviders. Custom providers
// missing a base URL or models, with an unsupported type, or explicitly
// disabled are removed. Valid ones are kept and get their ID set from the map
// key.
func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
	// configure builds a Config from the given user providers, applies defaults
	// and runs configureProviders with an empty environment and no known
	// providers, so only custom-provider validation is exercised. Any error from
	// configureProviders is reported as a test failure.
	configure := func(t *testing.T, providers map[string]ProviderConfig) *Config {
		t.Helper()
		cfg := &Config{Providers: csync.NewMapFrom(providers)}
		cfg.setDefaults("/tmp")
		testEnv := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(testEnv)
		err := cfg.configureProviders(testEnv, resolver, []catwalk.Provider{})
		assert.NoError(t, err)
		return cfg
	}

	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
		cfg := configure(t, map[string]ProviderConfig{
			"custom": {
				BaseURL: "https://api.custom.com/v1",
				Models: []catwalk.Model{{
					ID: "test-model",
				}},
			},
			"openai": {
				APIKey: "$MISSING",
			},
		})

		assert.Equal(t, 1, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		assert.True(t, exists)
	})

	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
		cfg := configure(t, map[string]ProviderConfig{
			"custom": {
				APIKey: "test-key",
				Models: []catwalk.Model{{
					ID: "test-model",
				}},
			},
		})

		assert.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		assert.False(t, exists)
	})

	t.Run("custom provider with no models is removed", func(t *testing.T) {
		cfg := configure(t, map[string]ProviderConfig{
			"custom": {
				APIKey:  "test-key",
				BaseURL: "https://api.custom.com/v1",
				Models:  []catwalk.Model{},
			},
		})

		assert.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		assert.False(t, exists)
	})

	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
		cfg := configure(t, map[string]ProviderConfig{
			"custom": {
				APIKey:  "test-key",
				BaseURL: "https://api.custom.com/v1",
				Type:    "unsupported",
				Models: []catwalk.Model{{
					ID: "test-model",
				}},
			},
		})

		assert.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		assert.False(t, exists)
	})

	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
		cfg := configure(t, map[string]ProviderConfig{
			"custom": {
				APIKey:  "test-key",
				BaseURL: "https://api.custom.com/v1",
				Type:    catwalk.TypeOpenAI,
				Models: []catwalk.Model{{
					ID: "test-model",
				}},
			},
		})

		assert.Equal(t, 1, cfg.Providers.Len())
		customProvider, exists := cfg.Providers.Get("custom")
		assert.True(t, exists)
		assert.Equal(t, "custom", customProvider.ID)
		assert.Equal(t, "test-key", customProvider.APIKey)
		assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
	})

	t.Run("disabled custom provider is removed", func(t *testing.T) {
		cfg := configure(t, map[string]ProviderConfig{
			"custom": {
				APIKey:  "test-key",
				BaseURL: "https://api.custom.com/v1",
				Type:    catwalk.TypeOpenAI,
				Disable: true,
				Models: []catwalk.Model{{
					ID: "test-model",
				}},
			},
		})

		assert.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		assert.False(t, exists)
	})
}
642
// TestConfig_configureProvidersEnhancedCredentialValidation checks that a
// known provider the user has partially configured is removed when its
// credentials are missing from the environment. It also checks that a known
// provider whose endpoint placeholder does not resolve is still kept.
func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
	// configure builds a Config from the given user providers, applies defaults
	// and runs configureProviders with the given environment variables and
	// known providers. Any error from configureProviders is reported as a test
	// failure.
	configure := func(t *testing.T, known []catwalk.Provider, providers map[string]ProviderConfig, vars map[string]string) *Config {
		t.Helper()
		cfg := &Config{Providers: csync.NewMapFrom(providers)}
		cfg.setDefaults("/tmp")
		testEnv := env.NewFromMap(vars)
		resolver := NewEnvironmentVariableResolver(testEnv)
		err := cfg.configureProviders(testEnv, resolver, known)
		assert.NoError(t, err)
		return cfg
	}

	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          catwalk.InferenceProviderVertexAI,
				APIKey:      "",
				APIEndpoint: "",
				Models: []catwalk.Model{{
					ID: "gemini-pro",
				}},
			},
		}

		cfg := configure(t, knownProviders, map[string]ProviderConfig{
			"vertexai": {
				BaseURL: "custom-url",
			},
		}, map[string]string{
			"GOOGLE_GENAI_USE_VERTEXAI": "false",
		})

		assert.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("vertexai")
		assert.False(t, exists)
	})

	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          catwalk.InferenceProviderBedrock,
				APIKey:      "",
				APIEndpoint: "",
				Models: []catwalk.Model{{
					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
				}},
			},
		}

		cfg := configure(t, knownProviders, map[string]ProviderConfig{
			"bedrock": {
				BaseURL: "custom-url",
			},
		}, map[string]string{})

		assert.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("bedrock")
		assert.False(t, exists)
	})

	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          "openai",
				APIKey:      "$MISSING_API_KEY",
				APIEndpoint: "https://api.openai.com/v1",
				Models: []catwalk.Model{{
					ID: "test-model",
				}},
			},
		}

		cfg := configure(t, knownProviders, map[string]ProviderConfig{
			"openai": {
				BaseURL: "custom-url",
			},
		}, map[string]string{})

		assert.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("openai")
		assert.False(t, exists)
	})

	t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          "openai",
				APIKey:      "$OPENAI_API_KEY",
				APIEndpoint: "$MISSING_ENDPOINT",
				Models: []catwalk.Model{{
					ID: "test-model",
				}},
			},
		}

		cfg := configure(t, knownProviders, map[string]ProviderConfig{
			"openai": {
				APIKey: "test-key",
			},
		}, map[string]string{
			"OPENAI_API_KEY": "test-key",
		})

		assert.Equal(t, 1, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("openai")
		assert.True(t, exists)
	})
}
772
773func TestConfig_defaultModelSelection(t *testing.T) {
774 t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
775 knownProviders := []catwalk.Provider{
776 {
777 ID: "openai",
778 APIKey: "abc",
779 DefaultLargeModelID: "large-model",
780 DefaultSmallModelID: "small-model",
781 Models: []catwalk.Model{
782 {
783 ID: "large-model",
784 DefaultMaxTokens: 1000,
785 },
786 {
787 ID: "small-model",
788 DefaultMaxTokens: 500,
789 },
790 },
791 },
792 }
793
794 cfg := &Config{}
795 cfg.setDefaults("/tmp")
796 env := env.NewFromMap(map[string]string{})
797 resolver := NewEnvironmentVariableResolver(env)
798 err := cfg.configureProviders(env, resolver, knownProviders)
799 assert.NoError(t, err)
800
801 large, small, err := cfg.defaultModelSelection(knownProviders)
802 assert.NoError(t, err)
803 assert.Equal(t, "large-model", large.Model)
804 assert.Equal(t, "openai", large.Provider)
805 assert.Equal(t, int64(1000), large.MaxTokens)
806 assert.Equal(t, "small-model", small.Model)
807 assert.Equal(t, "openai", small.Provider)
808 assert.Equal(t, int64(500), small.MaxTokens)
809 })
810 t.Run("should error if no providers configured", func(t *testing.T) {
811 knownProviders := []catwalk.Provider{
812 {
813 ID: "openai",
814 APIKey: "$MISSING_KEY",
815 DefaultLargeModelID: "large-model",
816 DefaultSmallModelID: "small-model",
817 Models: []catwalk.Model{
818 {
819 ID: "large-model",
820 DefaultMaxTokens: 1000,
821 },
822 {
823 ID: "small-model",
824 DefaultMaxTokens: 500,
825 },
826 },
827 },
828 }
829
830 cfg := &Config{}
831 cfg.setDefaults("/tmp")
832 env := env.NewFromMap(map[string]string{})
833 resolver := NewEnvironmentVariableResolver(env)
834 err := cfg.configureProviders(env, resolver, knownProviders)
835 assert.NoError(t, err)
836
837 _, _, err = cfg.defaultModelSelection(knownProviders)
838 assert.Error(t, err)
839 })
840 t.Run("should error if model is missing", func(t *testing.T) {
841 knownProviders := []catwalk.Provider{
842 {
843 ID: "openai",
844 APIKey: "abc",
845 DefaultLargeModelID: "large-model",
846 DefaultSmallModelID: "small-model",
847 Models: []catwalk.Model{
848 {
849 ID: "not-large-model",
850 DefaultMaxTokens: 1000,
851 },
852 {
853 ID: "small-model",
854 DefaultMaxTokens: 500,
855 },
856 },
857 },
858 }
859
860 cfg := &Config{}
861 cfg.setDefaults("/tmp")
862 env := env.NewFromMap(map[string]string{})
863 resolver := NewEnvironmentVariableResolver(env)
864 err := cfg.configureProviders(env, resolver, knownProviders)
865 assert.NoError(t, err)
866 _, _, err = cfg.defaultModelSelection(knownProviders)
867 assert.Error(t, err)
868 })
869
870 t.Run("should configure the default models with a custom provider", func(t *testing.T) {
871 knownProviders := []catwalk.Provider{
872 {
873 ID: "openai",
874 APIKey: "$MISSING", // will not be included in the config
875 DefaultLargeModelID: "large-model",
876 DefaultSmallModelID: "small-model",
877 Models: []catwalk.Model{
878 {
879 ID: "not-large-model",
880 DefaultMaxTokens: 1000,
881 },
882 {
883 ID: "small-model",
884 DefaultMaxTokens: 500,
885 },
886 },
887 },
888 }
889
890 cfg := &Config{
891 Providers: csync.NewMapFrom(map[string]ProviderConfig{
892 "custom": {
893 APIKey: "test-key",
894 BaseURL: "https://api.custom.com/v1",
895 Models: []catwalk.Model{
896 {
897 ID: "model",
898 DefaultMaxTokens: 600,
899 },
900 },
901 },
902 }),
903 }
904 cfg.setDefaults("/tmp")
905 env := env.NewFromMap(map[string]string{})
906 resolver := NewEnvironmentVariableResolver(env)
907 err := cfg.configureProviders(env, resolver, knownProviders)
908 assert.NoError(t, err)
909 large, small, err := cfg.defaultModelSelection(knownProviders)
910 assert.NoError(t, err)
911 assert.Equal(t, "model", large.Model)
912 assert.Equal(t, "custom", large.Provider)
913 assert.Equal(t, int64(600), large.MaxTokens)
914 assert.Equal(t, "model", small.Model)
915 assert.Equal(t, "custom", small.Provider)
916 assert.Equal(t, int64(600), small.MaxTokens)
917 })
918
919 t.Run("should fail if no model configured", func(t *testing.T) {
920 knownProviders := []catwalk.Provider{
921 {
922 ID: "openai",
923 APIKey: "$MISSING", // will not be included in the config
924 DefaultLargeModelID: "large-model",
925 DefaultSmallModelID: "small-model",
926 Models: []catwalk.Model{
927 {
928 ID: "not-large-model",
929 DefaultMaxTokens: 1000,
930 },
931 {
932 ID: "small-model",
933 DefaultMaxTokens: 500,
934 },
935 },
936 },
937 }
938
939 cfg := &Config{
940 Providers: csync.NewMapFrom(map[string]ProviderConfig{
941 "custom": {
942 APIKey: "test-key",
943 BaseURL: "https://api.custom.com/v1",
944 Models: []catwalk.Model{},
945 },
946 }),
947 }
948 cfg.setDefaults("/tmp")
949 env := env.NewFromMap(map[string]string{})
950 resolver := NewEnvironmentVariableResolver(env)
951 err := cfg.configureProviders(env, resolver, knownProviders)
952 assert.NoError(t, err)
953 _, _, err = cfg.defaultModelSelection(knownProviders)
954 assert.Error(t, err)
955 })
956 t.Run("should use the default provider first", func(t *testing.T) {
957 knownProviders := []catwalk.Provider{
958 {
959 ID: "openai",
960 APIKey: "set",
961 DefaultLargeModelID: "large-model",
962 DefaultSmallModelID: "small-model",
963 Models: []catwalk.Model{
964 {
965 ID: "large-model",
966 DefaultMaxTokens: 1000,
967 },
968 {
969 ID: "small-model",
970 DefaultMaxTokens: 500,
971 },
972 },
973 },
974 }
975
976 cfg := &Config{
977 Providers: csync.NewMapFrom(map[string]ProviderConfig{
978 "custom": {
979 APIKey: "test-key",
980 BaseURL: "https://api.custom.com/v1",
981 Models: []catwalk.Model{
982 {
983 ID: "large-model",
984 DefaultMaxTokens: 1000,
985 },
986 },
987 },
988 }),
989 }
990 cfg.setDefaults("/tmp")
991 env := env.NewFromMap(map[string]string{})
992 resolver := NewEnvironmentVariableResolver(env)
993 err := cfg.configureProviders(env, resolver, knownProviders)
994 assert.NoError(t, err)
995 large, small, err := cfg.defaultModelSelection(knownProviders)
996 assert.NoError(t, err)
997 assert.Equal(t, "large-model", large.Model)
998 assert.Equal(t, "openai", large.Provider)
999 assert.Equal(t, int64(1000), large.MaxTokens)
1000 assert.Equal(t, "small-model", small.Model)
1001 assert.Equal(t, "openai", small.Provider)
1002 assert.Equal(t, int64(500), small.MaxTokens)
1003 })
1004}
1005
// TestConfig_configureSelectedModels checks how user-selected models are
// merged with provider defaults by configureSelectedModels. Three cases are
// covered: overriding the model, mixing providers between large and small,
// and overriding only the max tokens.
func TestConfig_configureSelectedModels(t *testing.T) {
	// setup builds a Config with the given user model selection, applies
	// defaults, and runs configureProviders (with an empty environment) and
	// then configureSelectedModels. Any error is reported as a test failure.
	setup := func(t *testing.T, known []catwalk.Provider, selected map[SelectedModelType]SelectedModel) *Config {
		t.Helper()
		cfg := &Config{Models: selected}
		cfg.setDefaults("/tmp")
		emptyEnv := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(emptyEnv)
		err := cfg.configureProviders(emptyEnv, resolver, known)
		assert.NoError(t, err)
		err = cfg.configureSelectedModels(known)
		assert.NoError(t, err)
		return cfg
	}

	t.Run("should override defaults", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{ID: "larger-model", DefaultMaxTokens: 2000},
					{ID: "large-model", DefaultMaxTokens: 1000},
					{ID: "small-model", DefaultMaxTokens: 500},
				},
			},
		}

		// Key the selection with the same constants the assertions read back.
		cfg := setup(t, knownProviders, map[SelectedModelType]SelectedModel{
			SelectedModelTypeLarge: {
				Model: "larger-model",
			},
		})

		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		assert.Equal(t, "larger-model", large.Model)
		assert.Equal(t, "openai", large.Provider)
		assert.Equal(t, int64(2000), large.MaxTokens)
		assert.Equal(t, "small-model", small.Model)
		assert.Equal(t, "openai", small.Provider)
		assert.Equal(t, int64(500), small.MaxTokens)
	})

	t.Run("should be possible to use multiple providers", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{ID: "large-model", DefaultMaxTokens: 1000},
					{ID: "small-model", DefaultMaxTokens: 500},
				},
			},
			{
				ID:                  "anthropic",
				APIKey:              "abc",
				DefaultLargeModelID: "a-large-model",
				DefaultSmallModelID: "a-small-model",
				Models: []catwalk.Model{
					{ID: "a-large-model", DefaultMaxTokens: 1000},
					{ID: "a-small-model", DefaultMaxTokens: 200},
				},
			},
		}

		cfg := setup(t, knownProviders, map[SelectedModelType]SelectedModel{
			SelectedModelTypeSmall: {
				Model:     "a-small-model",
				Provider:  "anthropic",
				MaxTokens: 300,
			},
		})

		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		assert.Equal(t, "large-model", large.Model)
		assert.Equal(t, "openai", large.Provider)
		assert.Equal(t, int64(1000), large.MaxTokens)
		assert.Equal(t, "a-small-model", small.Model)
		assert.Equal(t, "anthropic", small.Provider)
		assert.Equal(t, int64(300), small.MaxTokens)
	})

	t.Run("should override the max tokens only", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{ID: "large-model", DefaultMaxTokens: 1000},
					{ID: "small-model", DefaultMaxTokens: 500},
				},
			},
		}

		cfg := setup(t, knownProviders, map[SelectedModelType]SelectedModel{
			SelectedModelTypeLarge: {
				MaxTokens: 100,
			},
		})

		large := cfg.Models[SelectedModelTypeLarge]
		assert.Equal(t, "large-model", large.Model)
		assert.Equal(t, "openai", large.Provider)
		assert.Equal(t, int64(100), large.MaxTokens)
	})
}