1package config
2
3import (
4 "io"
5 "log/slog"
6 "os"
7 "path/filepath"
8 "strings"
9 "testing"
10
11 "github.com/charmbracelet/crush/internal/env"
12 "github.com/charmbracelet/crush/internal/fur/provider"
13 "github.com/stretchr/testify/assert"
14)
15
// TestMain routes the default slog logger to io.Discard for the whole package
// so log output from the code under test stays out of test output, then runs
// the tests and exits with their status code.
func TestMain(m *testing.M) {
	quiet := slog.NewTextHandler(io.Discard, nil)
	slog.SetDefault(slog.New(quiet))
	os.Exit(m.Run())
}
22
23func TestConfig_LoadFromReaders(t *testing.T) {
24 data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
25 data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
26 data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
27
28 loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
29
30 assert.NoError(t, err)
31 assert.NotNil(t, loadedConfig)
32 assert.Len(t, loadedConfig.Providers, 1)
33 assert.Equal(t, "key2", loadedConfig.Providers["openai"].APIKey)
34 assert.Equal(t, "https://api.openai.com/v2", loadedConfig.Providers["openai"].BaseURL)
35}
36
37func TestConfig_setDefaults(t *testing.T) {
38 cfg := &Config{}
39
40 cfg.setDefaults("/tmp")
41
42 assert.NotNil(t, cfg.Options)
43 assert.NotNil(t, cfg.Options.TUI)
44 assert.NotNil(t, cfg.Options.ContextPaths)
45 assert.NotNil(t, cfg.Providers)
46 assert.NotNil(t, cfg.Models)
47 assert.NotNil(t, cfg.LSP)
48 assert.NotNil(t, cfg.MCP)
49 assert.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
50 for _, path := range defaultContextPaths {
51 assert.Contains(t, cfg.Options.ContextPaths, path)
52 }
53 assert.Equal(t, "/tmp", cfg.workingDir)
54}
55
56func TestConfig_configureProviders(t *testing.T) {
57 knownProviders := []provider.Provider{
58 {
59 ID: "openai",
60 APIKey: "$OPENAI_API_KEY",
61 APIEndpoint: "https://api.openai.com/v1",
62 Models: []provider.Model{{
63 ID: "test-model",
64 }},
65 },
66 }
67
68 cfg := &Config{}
69 cfg.setDefaults("/tmp")
70 env := env.NewFromMap(map[string]string{
71 "OPENAI_API_KEY": "test-key",
72 })
73 resolver := NewEnvironmentVariableResolver(env)
74 err := cfg.configureProviders(env, resolver, knownProviders)
75 assert.NoError(t, err)
76 assert.Len(t, cfg.Providers, 1)
77
78 // We want to make sure that we keep the configured API key as a placeholder
79 assert.Equal(t, "$OPENAI_API_KEY", cfg.Providers["openai"].APIKey)
80}
81
82func TestConfig_configureProvidersWithOverride(t *testing.T) {
83 knownProviders := []provider.Provider{
84 {
85 ID: "openai",
86 APIKey: "$OPENAI_API_KEY",
87 APIEndpoint: "https://api.openai.com/v1",
88 Models: []provider.Model{{
89 ID: "test-model",
90 }},
91 },
92 }
93
94 cfg := &Config{
95 Providers: map[string]ProviderConfig{
96 "openai": {
97 APIKey: "xyz",
98 BaseURL: "https://api.openai.com/v2",
99 Models: []provider.Model{
100 {
101 ID: "test-model",
102 Model: "Updated",
103 },
104 {
105 ID: "another-model",
106 },
107 },
108 },
109 },
110 }
111 cfg.setDefaults("/tmp")
112
113 env := env.NewFromMap(map[string]string{
114 "OPENAI_API_KEY": "test-key",
115 })
116 resolver := NewEnvironmentVariableResolver(env)
117 err := cfg.configureProviders(env, resolver, knownProviders)
118 assert.NoError(t, err)
119 assert.Len(t, cfg.Providers, 1)
120
121 // We want to make sure that we keep the configured API key as a placeholder
122 assert.Equal(t, "xyz", cfg.Providers["openai"].APIKey)
123 assert.Equal(t, "https://api.openai.com/v2", cfg.Providers["openai"].BaseURL)
124 assert.Len(t, cfg.Providers["openai"].Models, 2)
125 assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Model)
126}
127
128func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
129 knownProviders := []provider.Provider{
130 {
131 ID: "openai",
132 APIKey: "$OPENAI_API_KEY",
133 APIEndpoint: "https://api.openai.com/v1",
134 Models: []provider.Model{{
135 ID: "test-model",
136 }},
137 },
138 }
139
140 cfg := &Config{
141 Providers: map[string]ProviderConfig{
142 "custom": {
143 APIKey: "xyz",
144 BaseURL: "https://api.someendpoint.com/v2",
145 Models: []provider.Model{
146 {
147 ID: "test-model",
148 },
149 },
150 },
151 },
152 }
153 cfg.setDefaults("/tmp")
154 env := env.NewFromMap(map[string]string{
155 "OPENAI_API_KEY": "test-key",
156 })
157 resolver := NewEnvironmentVariableResolver(env)
158 err := cfg.configureProviders(env, resolver, knownProviders)
159 assert.NoError(t, err)
160 // Should be to because of the env variable
161 assert.Len(t, cfg.Providers, 2)
162
163 // We want to make sure that we keep the configured API key as a placeholder
164 assert.Equal(t, "xyz", cfg.Providers["custom"].APIKey)
165 // Make sure we set the ID correctly
166 assert.Equal(t, "custom", cfg.Providers["custom"].ID)
167 assert.Equal(t, "https://api.someendpoint.com/v2", cfg.Providers["custom"].BaseURL)
168 assert.Len(t, cfg.Providers["custom"].Models, 1)
169
170 _, ok := cfg.Providers["openai"]
171 assert.True(t, ok, "OpenAI provider should still be present")
172}
173
174func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
175 knownProviders := []provider.Provider{
176 {
177 ID: provider.InferenceProviderBedrock,
178 APIKey: "",
179 APIEndpoint: "",
180 Models: []provider.Model{{
181 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
182 }},
183 },
184 }
185
186 cfg := &Config{}
187 cfg.setDefaults("/tmp")
188 env := env.NewFromMap(map[string]string{
189 "AWS_ACCESS_KEY_ID": "test-key-id",
190 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
191 })
192 resolver := NewEnvironmentVariableResolver(env)
193 err := cfg.configureProviders(env, resolver, knownProviders)
194 assert.NoError(t, err)
195 assert.Len(t, cfg.Providers, 1)
196
197 bedrockProvider, ok := cfg.Providers["bedrock"]
198 assert.True(t, ok, "Bedrock provider should be present")
199 assert.Len(t, bedrockProvider.Models, 1)
200 assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
201}
202
203func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
204 knownProviders := []provider.Provider{
205 {
206 ID: provider.InferenceProviderBedrock,
207 APIKey: "",
208 APIEndpoint: "",
209 Models: []provider.Model{{
210 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
211 }},
212 },
213 }
214
215 cfg := &Config{}
216 cfg.setDefaults("/tmp")
217 env := env.NewFromMap(map[string]string{})
218 resolver := NewEnvironmentVariableResolver(env)
219 err := cfg.configureProviders(env, resolver, knownProviders)
220 assert.NoError(t, err)
221 // Provider should not be configured without credentials
222 assert.Len(t, cfg.Providers, 0)
223}
224
225func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
226 knownProviders := []provider.Provider{
227 {
228 ID: provider.InferenceProviderBedrock,
229 APIKey: "",
230 APIEndpoint: "",
231 Models: []provider.Model{{
232 ID: "some-random-model",
233 }},
234 },
235 }
236
237 cfg := &Config{}
238 cfg.setDefaults("/tmp")
239 env := env.NewFromMap(map[string]string{
240 "AWS_ACCESS_KEY_ID": "test-key-id",
241 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
242 })
243 resolver := NewEnvironmentVariableResolver(env)
244 err := cfg.configureProviders(env, resolver, knownProviders)
245 assert.Error(t, err)
246}
247
248func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
249 knownProviders := []provider.Provider{
250 {
251 ID: provider.InferenceProviderVertexAI,
252 APIKey: "",
253 APIEndpoint: "",
254 Models: []provider.Model{{
255 ID: "gemini-pro",
256 }},
257 },
258 }
259
260 cfg := &Config{}
261 cfg.setDefaults("/tmp")
262 env := env.NewFromMap(map[string]string{
263 "GOOGLE_GENAI_USE_VERTEXAI": "true",
264 "GOOGLE_CLOUD_PROJECT": "test-project",
265 "GOOGLE_CLOUD_LOCATION": "us-central1",
266 })
267 resolver := NewEnvironmentVariableResolver(env)
268 err := cfg.configureProviders(env, resolver, knownProviders)
269 assert.NoError(t, err)
270 assert.Len(t, cfg.Providers, 1)
271
272 vertexProvider, ok := cfg.Providers["vertexai"]
273 assert.True(t, ok, "VertexAI provider should be present")
274 assert.Len(t, vertexProvider.Models, 1)
275 assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
276 assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
277 assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
278}
279
280func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
281 knownProviders := []provider.Provider{
282 {
283 ID: provider.InferenceProviderVertexAI,
284 APIKey: "",
285 APIEndpoint: "",
286 Models: []provider.Model{{
287 ID: "gemini-pro",
288 }},
289 },
290 }
291
292 cfg := &Config{}
293 cfg.setDefaults("/tmp")
294 env := env.NewFromMap(map[string]string{
295 "GOOGLE_GENAI_USE_VERTEXAI": "false",
296 "GOOGLE_CLOUD_PROJECT": "test-project",
297 "GOOGLE_CLOUD_LOCATION": "us-central1",
298 })
299 resolver := NewEnvironmentVariableResolver(env)
300 err := cfg.configureProviders(env, resolver, knownProviders)
301 assert.NoError(t, err)
302 // Provider should not be configured without proper credentials
303 assert.Len(t, cfg.Providers, 0)
304}
305
306func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
307 knownProviders := []provider.Provider{
308 {
309 ID: provider.InferenceProviderVertexAI,
310 APIKey: "",
311 APIEndpoint: "",
312 Models: []provider.Model{{
313 ID: "gemini-pro",
314 }},
315 },
316 }
317
318 cfg := &Config{}
319 cfg.setDefaults("/tmp")
320 env := env.NewFromMap(map[string]string{
321 "GOOGLE_GENAI_USE_VERTEXAI": "true",
322 "GOOGLE_CLOUD_LOCATION": "us-central1",
323 })
324 resolver := NewEnvironmentVariableResolver(env)
325 err := cfg.configureProviders(env, resolver, knownProviders)
326 assert.NoError(t, err)
327 // Provider should not be configured without project
328 assert.Len(t, cfg.Providers, 0)
329}
330
331func TestConfig_configureProvidersSetProviderID(t *testing.T) {
332 knownProviders := []provider.Provider{
333 {
334 ID: "openai",
335 APIKey: "$OPENAI_API_KEY",
336 APIEndpoint: "https://api.openai.com/v1",
337 Models: []provider.Model{{
338 ID: "test-model",
339 }},
340 },
341 }
342
343 cfg := &Config{}
344 cfg.setDefaults("/tmp")
345 env := env.NewFromMap(map[string]string{
346 "OPENAI_API_KEY": "test-key",
347 })
348 resolver := NewEnvironmentVariableResolver(env)
349 err := cfg.configureProviders(env, resolver, knownProviders)
350 assert.NoError(t, err)
351 assert.Len(t, cfg.Providers, 1)
352
353 // Provider ID should be set
354 assert.Equal(t, "openai", cfg.Providers["openai"].ID)
355}
356
357func TestConfig_EnabledProviders(t *testing.T) {
358 t.Run("all providers enabled", func(t *testing.T) {
359 cfg := &Config{
360 Providers: map[string]ProviderConfig{
361 "openai": {
362 ID: "openai",
363 APIKey: "key1",
364 Disable: false,
365 },
366 "anthropic": {
367 ID: "anthropic",
368 APIKey: "key2",
369 Disable: false,
370 },
371 },
372 }
373
374 enabled := cfg.EnabledProviders()
375 assert.Len(t, enabled, 2)
376 })
377
378 t.Run("some providers disabled", func(t *testing.T) {
379 cfg := &Config{
380 Providers: map[string]ProviderConfig{
381 "openai": {
382 ID: "openai",
383 APIKey: "key1",
384 Disable: false,
385 },
386 "anthropic": {
387 ID: "anthropic",
388 APIKey: "key2",
389 Disable: true,
390 },
391 },
392 }
393
394 enabled := cfg.EnabledProviders()
395 assert.Len(t, enabled, 1)
396 assert.Equal(t, "openai", enabled[0].ID)
397 })
398
399 t.Run("empty providers map", func(t *testing.T) {
400 cfg := &Config{
401 Providers: map[string]ProviderConfig{},
402 }
403
404 enabled := cfg.EnabledProviders()
405 assert.Len(t, enabled, 0)
406 })
407}
408
409func TestConfig_IsConfigured(t *testing.T) {
410 t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
411 cfg := &Config{
412 Providers: map[string]ProviderConfig{
413 "openai": {
414 ID: "openai",
415 APIKey: "key1",
416 Disable: false,
417 },
418 },
419 }
420
421 assert.True(t, cfg.IsConfigured())
422 })
423
424 t.Run("returns false when no providers are configured", func(t *testing.T) {
425 cfg := &Config{
426 Providers: map[string]ProviderConfig{},
427 }
428
429 assert.False(t, cfg.IsConfigured())
430 })
431
432 t.Run("returns false when all providers are disabled", func(t *testing.T) {
433 cfg := &Config{
434 Providers: map[string]ProviderConfig{
435 "openai": {
436 ID: "openai",
437 APIKey: "key1",
438 Disable: true,
439 },
440 "anthropic": {
441 ID: "anthropic",
442 APIKey: "key2",
443 Disable: true,
444 },
445 },
446 }
447
448 assert.False(t, cfg.IsConfigured())
449 })
450}
451
452func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
453 knownProviders := []provider.Provider{
454 {
455 ID: "openai",
456 APIKey: "$OPENAI_API_KEY",
457 APIEndpoint: "https://api.openai.com/v1",
458 Models: []provider.Model{{
459 ID: "test-model",
460 }},
461 },
462 }
463
464 cfg := &Config{
465 Providers: map[string]ProviderConfig{
466 "openai": {
467 Disable: true,
468 },
469 },
470 }
471 cfg.setDefaults("/tmp")
472
473 env := env.NewFromMap(map[string]string{
474 "OPENAI_API_KEY": "test-key",
475 })
476 resolver := NewEnvironmentVariableResolver(env)
477 err := cfg.configureProviders(env, resolver, knownProviders)
478 assert.NoError(t, err)
479
480 // Provider should be removed from config when disabled
481 assert.Len(t, cfg.Providers, 0)
482 _, exists := cfg.Providers["openai"]
483 assert.False(t, exists)
484}
485
486func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
487 t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
488 cfg := &Config{
489 Providers: map[string]ProviderConfig{
490 "custom": {
491 BaseURL: "https://api.custom.com/v1",
492 Models: []provider.Model{{
493 ID: "test-model",
494 }},
495 },
496 "openai": {
497 APIKey: "$MISSING",
498 },
499 },
500 }
501 cfg.setDefaults("/tmp")
502
503 env := env.NewFromMap(map[string]string{})
504 resolver := NewEnvironmentVariableResolver(env)
505 err := cfg.configureProviders(env, resolver, []provider.Provider{})
506 assert.NoError(t, err)
507
508 assert.Len(t, cfg.Providers, 1)
509 _, exists := cfg.Providers["custom"]
510 assert.True(t, exists)
511 })
512
513 t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
514 cfg := &Config{
515 Providers: map[string]ProviderConfig{
516 "custom": {
517 APIKey: "test-key",
518 Models: []provider.Model{{
519 ID: "test-model",
520 }},
521 },
522 },
523 }
524 cfg.setDefaults("/tmp")
525
526 env := env.NewFromMap(map[string]string{})
527 resolver := NewEnvironmentVariableResolver(env)
528 err := cfg.configureProviders(env, resolver, []provider.Provider{})
529 assert.NoError(t, err)
530
531 assert.Len(t, cfg.Providers, 0)
532 _, exists := cfg.Providers["custom"]
533 assert.False(t, exists)
534 })
535
536 t.Run("custom provider with no models is removed", func(t *testing.T) {
537 cfg := &Config{
538 Providers: map[string]ProviderConfig{
539 "custom": {
540 APIKey: "test-key",
541 BaseURL: "https://api.custom.com/v1",
542 Models: []provider.Model{},
543 },
544 },
545 }
546 cfg.setDefaults("/tmp")
547
548 env := env.NewFromMap(map[string]string{})
549 resolver := NewEnvironmentVariableResolver(env)
550 err := cfg.configureProviders(env, resolver, []provider.Provider{})
551 assert.NoError(t, err)
552
553 assert.Len(t, cfg.Providers, 0)
554 _, exists := cfg.Providers["custom"]
555 assert.False(t, exists)
556 })
557
558 t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
559 cfg := &Config{
560 Providers: map[string]ProviderConfig{
561 "custom": {
562 APIKey: "test-key",
563 BaseURL: "https://api.custom.com/v1",
564 Type: "unsupported",
565 Models: []provider.Model{{
566 ID: "test-model",
567 }},
568 },
569 },
570 }
571 cfg.setDefaults("/tmp")
572
573 env := env.NewFromMap(map[string]string{})
574 resolver := NewEnvironmentVariableResolver(env)
575 err := cfg.configureProviders(env, resolver, []provider.Provider{})
576 assert.NoError(t, err)
577
578 assert.Len(t, cfg.Providers, 0)
579 _, exists := cfg.Providers["custom"]
580 assert.False(t, exists)
581 })
582
583 t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
584 cfg := &Config{
585 Providers: map[string]ProviderConfig{
586 "custom": {
587 APIKey: "test-key",
588 BaseURL: "https://api.custom.com/v1",
589 Type: provider.TypeOpenAI,
590 Models: []provider.Model{{
591 ID: "test-model",
592 }},
593 },
594 },
595 }
596 cfg.setDefaults("/tmp")
597
598 env := env.NewFromMap(map[string]string{})
599 resolver := NewEnvironmentVariableResolver(env)
600 err := cfg.configureProviders(env, resolver, []provider.Provider{})
601 assert.NoError(t, err)
602
603 assert.Len(t, cfg.Providers, 1)
604 customProvider, exists := cfg.Providers["custom"]
605 assert.True(t, exists)
606 assert.Equal(t, "custom", customProvider.ID)
607 assert.Equal(t, "test-key", customProvider.APIKey)
608 assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
609 })
610
611 t.Run("disabled custom provider is removed", func(t *testing.T) {
612 cfg := &Config{
613 Providers: map[string]ProviderConfig{
614 "custom": {
615 APIKey: "test-key",
616 BaseURL: "https://api.custom.com/v1",
617 Type: provider.TypeOpenAI,
618 Disable: true,
619 Models: []provider.Model{{
620 ID: "test-model",
621 }},
622 },
623 },
624 }
625 cfg.setDefaults("/tmp")
626
627 env := env.NewFromMap(map[string]string{})
628 resolver := NewEnvironmentVariableResolver(env)
629 err := cfg.configureProviders(env, resolver, []provider.Provider{})
630 assert.NoError(t, err)
631
632 assert.Len(t, cfg.Providers, 0)
633 _, exists := cfg.Providers["custom"]
634 assert.False(t, exists)
635 })
636}
637
638func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
639 t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
640 knownProviders := []provider.Provider{
641 {
642 ID: provider.InferenceProviderVertexAI,
643 APIKey: "",
644 APIEndpoint: "",
645 Models: []provider.Model{{
646 ID: "gemini-pro",
647 }},
648 },
649 }
650
651 cfg := &Config{
652 Providers: map[string]ProviderConfig{
653 "vertexai": {
654 BaseURL: "custom-url",
655 },
656 },
657 }
658 cfg.setDefaults("/tmp")
659
660 env := env.NewFromMap(map[string]string{
661 "GOOGLE_GENAI_USE_VERTEXAI": "false",
662 })
663 resolver := NewEnvironmentVariableResolver(env)
664 err := cfg.configureProviders(env, resolver, knownProviders)
665 assert.NoError(t, err)
666
667 assert.Len(t, cfg.Providers, 0)
668 _, exists := cfg.Providers["vertexai"]
669 assert.False(t, exists)
670 })
671
672 t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
673 knownProviders := []provider.Provider{
674 {
675 ID: provider.InferenceProviderBedrock,
676 APIKey: "",
677 APIEndpoint: "",
678 Models: []provider.Model{{
679 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
680 }},
681 },
682 }
683
684 cfg := &Config{
685 Providers: map[string]ProviderConfig{
686 "bedrock": {
687 BaseURL: "custom-url",
688 },
689 },
690 }
691 cfg.setDefaults("/tmp")
692
693 env := env.NewFromMap(map[string]string{})
694 resolver := NewEnvironmentVariableResolver(env)
695 err := cfg.configureProviders(env, resolver, knownProviders)
696 assert.NoError(t, err)
697
698 assert.Len(t, cfg.Providers, 0)
699 _, exists := cfg.Providers["bedrock"]
700 assert.False(t, exists)
701 })
702
703 t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
704 knownProviders := []provider.Provider{
705 {
706 ID: "openai",
707 APIKey: "$MISSING_API_KEY",
708 APIEndpoint: "https://api.openai.com/v1",
709 Models: []provider.Model{{
710 ID: "test-model",
711 }},
712 },
713 }
714
715 cfg := &Config{
716 Providers: map[string]ProviderConfig{
717 "openai": {
718 BaseURL: "custom-url",
719 },
720 },
721 }
722 cfg.setDefaults("/tmp")
723
724 env := env.NewFromMap(map[string]string{})
725 resolver := NewEnvironmentVariableResolver(env)
726 err := cfg.configureProviders(env, resolver, knownProviders)
727 assert.NoError(t, err)
728
729 assert.Len(t, cfg.Providers, 0)
730 _, exists := cfg.Providers["openai"]
731 assert.False(t, exists)
732 })
733
734 t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
735 knownProviders := []provider.Provider{
736 {
737 ID: "openai",
738 APIKey: "$OPENAI_API_KEY",
739 APIEndpoint: "$MISSING_ENDPOINT",
740 Models: []provider.Model{{
741 ID: "test-model",
742 }},
743 },
744 }
745
746 cfg := &Config{
747 Providers: map[string]ProviderConfig{
748 "openai": {
749 APIKey: "test-key",
750 },
751 },
752 }
753 cfg.setDefaults("/tmp")
754
755 env := env.NewFromMap(map[string]string{
756 "OPENAI_API_KEY": "test-key",
757 })
758 resolver := NewEnvironmentVariableResolver(env)
759 err := cfg.configureProviders(env, resolver, knownProviders)
760 assert.NoError(t, err)
761
762 assert.Len(t, cfg.Providers, 1)
763 _, exists := cfg.Providers["openai"]
764 assert.True(t, exists)
765 })
766}
767
768func TestConfig_defaultModelSelection(t *testing.T) {
769 t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
770 knownProviders := []provider.Provider{
771 {
772 ID: "openai",
773 APIKey: "abc",
774 DefaultLargeModelID: "large-model",
775 DefaultSmallModelID: "small-model",
776 Models: []provider.Model{
777 {
778 ID: "large-model",
779 DefaultMaxTokens: 1000,
780 },
781 {
782 ID: "small-model",
783 DefaultMaxTokens: 500,
784 },
785 },
786 },
787 }
788
789 cfg := &Config{}
790 cfg.setDefaults("/tmp")
791 env := env.NewFromMap(map[string]string{})
792 resolver := NewEnvironmentVariableResolver(env)
793 err := cfg.configureProviders(env, resolver, knownProviders)
794 assert.NoError(t, err)
795
796 large, small, err := cfg.defaultModelSelection(knownProviders)
797 assert.NoError(t, err)
798 assert.Equal(t, "large-model", large.Model)
799 assert.Equal(t, "openai", large.Provider)
800 assert.Equal(t, int64(1000), large.MaxTokens)
801 assert.Equal(t, "small-model", small.Model)
802 assert.Equal(t, "openai", small.Provider)
803 assert.Equal(t, int64(500), small.MaxTokens)
804 })
805 t.Run("should error if no providers configured", func(t *testing.T) {
806 knownProviders := []provider.Provider{
807 {
808 ID: "openai",
809 APIKey: "$MISSING_KEY",
810 DefaultLargeModelID: "large-model",
811 DefaultSmallModelID: "small-model",
812 Models: []provider.Model{
813 {
814 ID: "large-model",
815 DefaultMaxTokens: 1000,
816 },
817 {
818 ID: "small-model",
819 DefaultMaxTokens: 500,
820 },
821 },
822 },
823 }
824
825 cfg := &Config{}
826 cfg.setDefaults("/tmp")
827 env := env.NewFromMap(map[string]string{})
828 resolver := NewEnvironmentVariableResolver(env)
829 err := cfg.configureProviders(env, resolver, knownProviders)
830 assert.NoError(t, err)
831
832 _, _, err = cfg.defaultModelSelection(knownProviders)
833 assert.Error(t, err)
834 })
835 t.Run("should error if model is missing", func(t *testing.T) {
836 knownProviders := []provider.Provider{
837 {
838 ID: "openai",
839 APIKey: "abc",
840 DefaultLargeModelID: "large-model",
841 DefaultSmallModelID: "small-model",
842 Models: []provider.Model{
843 {
844 ID: "not-large-model",
845 DefaultMaxTokens: 1000,
846 },
847 {
848 ID: "small-model",
849 DefaultMaxTokens: 500,
850 },
851 },
852 },
853 }
854
855 cfg := &Config{}
856 cfg.setDefaults("/tmp")
857 env := env.NewFromMap(map[string]string{})
858 resolver := NewEnvironmentVariableResolver(env)
859 err := cfg.configureProviders(env, resolver, knownProviders)
860 assert.NoError(t, err)
861 _, _, err = cfg.defaultModelSelection(knownProviders)
862 assert.Error(t, err)
863 })
864
865 t.Run("should configure the default models with a custom provider", func(t *testing.T) {
866 knownProviders := []provider.Provider{
867 {
868 ID: "openai",
869 APIKey: "$MISSING", // will not be included in the config
870 DefaultLargeModelID: "large-model",
871 DefaultSmallModelID: "small-model",
872 Models: []provider.Model{
873 {
874 ID: "not-large-model",
875 DefaultMaxTokens: 1000,
876 },
877 {
878 ID: "small-model",
879 DefaultMaxTokens: 500,
880 },
881 },
882 },
883 }
884
885 cfg := &Config{
886 Providers: map[string]ProviderConfig{
887 "custom": {
888 APIKey: "test-key",
889 BaseURL: "https://api.custom.com/v1",
890 Models: []provider.Model{
891 {
892 ID: "model",
893 DefaultMaxTokens: 600,
894 },
895 },
896 },
897 },
898 }
899 cfg.setDefaults("/tmp")
900 env := env.NewFromMap(map[string]string{})
901 resolver := NewEnvironmentVariableResolver(env)
902 err := cfg.configureProviders(env, resolver, knownProviders)
903 assert.NoError(t, err)
904 large, small, err := cfg.defaultModelSelection(knownProviders)
905 assert.NoError(t, err)
906 assert.Equal(t, "model", large.Model)
907 assert.Equal(t, "custom", large.Provider)
908 assert.Equal(t, int64(600), large.MaxTokens)
909 assert.Equal(t, "model", small.Model)
910 assert.Equal(t, "custom", small.Provider)
911 assert.Equal(t, int64(600), small.MaxTokens)
912 })
913
914 t.Run("should fail if no model configured", func(t *testing.T) {
915 knownProviders := []provider.Provider{
916 {
917 ID: "openai",
918 APIKey: "$MISSING", // will not be included in the config
919 DefaultLargeModelID: "large-model",
920 DefaultSmallModelID: "small-model",
921 Models: []provider.Model{
922 {
923 ID: "not-large-model",
924 DefaultMaxTokens: 1000,
925 },
926 {
927 ID: "small-model",
928 DefaultMaxTokens: 500,
929 },
930 },
931 },
932 }
933
934 cfg := &Config{
935 Providers: map[string]ProviderConfig{
936 "custom": {
937 APIKey: "test-key",
938 BaseURL: "https://api.custom.com/v1",
939 Models: []provider.Model{},
940 },
941 },
942 }
943 cfg.setDefaults("/tmp")
944 env := env.NewFromMap(map[string]string{})
945 resolver := NewEnvironmentVariableResolver(env)
946 err := cfg.configureProviders(env, resolver, knownProviders)
947 assert.NoError(t, err)
948 _, _, err = cfg.defaultModelSelection(knownProviders)
949 assert.Error(t, err)
950 })
951 t.Run("should use the default provider first", func(t *testing.T) {
952 knownProviders := []provider.Provider{
953 {
954 ID: "openai",
955 APIKey: "set",
956 DefaultLargeModelID: "large-model",
957 DefaultSmallModelID: "small-model",
958 Models: []provider.Model{
959 {
960 ID: "large-model",
961 DefaultMaxTokens: 1000,
962 },
963 {
964 ID: "small-model",
965 DefaultMaxTokens: 500,
966 },
967 },
968 },
969 }
970
971 cfg := &Config{
972 Providers: map[string]ProviderConfig{
973 "custom": {
974 APIKey: "test-key",
975 BaseURL: "https://api.custom.com/v1",
976 Models: []provider.Model{
977 {
978 ID: "large-model",
979 DefaultMaxTokens: 1000,
980 },
981 },
982 },
983 },
984 }
985 cfg.setDefaults("/tmp")
986 env := env.NewFromMap(map[string]string{})
987 resolver := NewEnvironmentVariableResolver(env)
988 err := cfg.configureProviders(env, resolver, knownProviders)
989 assert.NoError(t, err)
990 large, small, err := cfg.defaultModelSelection(knownProviders)
991 assert.NoError(t, err)
992 assert.Equal(t, "large-model", large.Model)
993 assert.Equal(t, "openai", large.Provider)
994 assert.Equal(t, int64(1000), large.MaxTokens)
995 assert.Equal(t, "small-model", small.Model)
996 assert.Equal(t, "openai", small.Provider)
997 assert.Equal(t, int64(500), small.MaxTokens)
998 })
999}
1000
1001func TestConfig_configureSelectedModels(t *testing.T) {
1002 t.Run("should override defaults", func(t *testing.T) {
1003 knownProviders := []provider.Provider{
1004 {
1005 ID: "openai",
1006 APIKey: "abc",
1007 DefaultLargeModelID: "large-model",
1008 DefaultSmallModelID: "small-model",
1009 Models: []provider.Model{
1010 {
1011 ID: "larger-model",
1012 DefaultMaxTokens: 2000,
1013 },
1014 {
1015 ID: "large-model",
1016 DefaultMaxTokens: 1000,
1017 },
1018 {
1019 ID: "small-model",
1020 DefaultMaxTokens: 500,
1021 },
1022 },
1023 },
1024 }
1025
1026 cfg := &Config{
1027 Models: map[SelectedModelType]SelectedModel{
1028 "large": {
1029 Model: "larger-model",
1030 },
1031 },
1032 }
1033 cfg.setDefaults("/tmp")
1034 env := env.NewFromMap(map[string]string{})
1035 resolver := NewEnvironmentVariableResolver(env)
1036 err := cfg.configureProviders(env, resolver, knownProviders)
1037 assert.NoError(t, err)
1038
1039 err = cfg.configureSelectedModels(knownProviders)
1040 assert.NoError(t, err)
1041 large := cfg.Models[SelectedModelTypeLarge]
1042 small := cfg.Models[SelectedModelTypeSmall]
1043 assert.Equal(t, "larger-model", large.Model)
1044 assert.Equal(t, "openai", large.Provider)
1045 assert.Equal(t, int64(2000), large.MaxTokens)
1046 assert.Equal(t, "small-model", small.Model)
1047 assert.Equal(t, "openai", small.Provider)
1048 assert.Equal(t, int64(500), small.MaxTokens)
1049 })
1050 t.Run("should be possible to use multiple providers", func(t *testing.T) {
1051 knownProviders := []provider.Provider{
1052 {
1053 ID: "openai",
1054 APIKey: "abc",
1055 DefaultLargeModelID: "large-model",
1056 DefaultSmallModelID: "small-model",
1057 Models: []provider.Model{
1058 {
1059 ID: "large-model",
1060 DefaultMaxTokens: 1000,
1061 },
1062 {
1063 ID: "small-model",
1064 DefaultMaxTokens: 500,
1065 },
1066 },
1067 },
1068 {
1069 ID: "anthropic",
1070 APIKey: "abc",
1071 DefaultLargeModelID: "a-large-model",
1072 DefaultSmallModelID: "a-small-model",
1073 Models: []provider.Model{
1074 {
1075 ID: "a-large-model",
1076 DefaultMaxTokens: 1000,
1077 },
1078 {
1079 ID: "a-small-model",
1080 DefaultMaxTokens: 200,
1081 },
1082 },
1083 },
1084 }
1085
1086 cfg := &Config{
1087 Models: map[SelectedModelType]SelectedModel{
1088 "small": {
1089 Model: "a-small-model",
1090 Provider: "anthropic",
1091 MaxTokens: 300,
1092 },
1093 },
1094 }
1095 cfg.setDefaults("/tmp")
1096 env := env.NewFromMap(map[string]string{})
1097 resolver := NewEnvironmentVariableResolver(env)
1098 err := cfg.configureProviders(env, resolver, knownProviders)
1099 assert.NoError(t, err)
1100
1101 err = cfg.configureSelectedModels(knownProviders)
1102 assert.NoError(t, err)
1103 large := cfg.Models[SelectedModelTypeLarge]
1104 small := cfg.Models[SelectedModelTypeSmall]
1105 assert.Equal(t, "large-model", large.Model)
1106 assert.Equal(t, "openai", large.Provider)
1107 assert.Equal(t, int64(1000), large.MaxTokens)
1108 assert.Equal(t, "a-small-model", small.Model)
1109 assert.Equal(t, "anthropic", small.Provider)
1110 assert.Equal(t, int64(300), small.MaxTokens)
1111 })
1112
1113 t.Run("should override the max tokens only", func(t *testing.T) {
1114 knownProviders := []provider.Provider{
1115 {
1116 ID: "openai",
1117 APIKey: "abc",
1118 DefaultLargeModelID: "large-model",
1119 DefaultSmallModelID: "small-model",
1120 Models: []provider.Model{
1121 {
1122 ID: "large-model",
1123 DefaultMaxTokens: 1000,
1124 },
1125 {
1126 ID: "small-model",
1127 DefaultMaxTokens: 500,
1128 },
1129 },
1130 },
1131 }
1132
1133 cfg := &Config{
1134 Models: map[SelectedModelType]SelectedModel{
1135 "large": {
1136 MaxTokens: 100,
1137 },
1138 },
1139 }
1140 cfg.setDefaults("/tmp")
1141 env := env.NewFromMap(map[string]string{})
1142 resolver := NewEnvironmentVariableResolver(env)
1143 err := cfg.configureProviders(env, resolver, knownProviders)
1144 assert.NoError(t, err)
1145
1146 err = cfg.configureSelectedModels(knownProviders)
1147 assert.NoError(t, err)
1148 large := cfg.Models[SelectedModelTypeLarge]
1149 assert.Equal(t, "large-model", large.Model)
1150 assert.Equal(t, "openai", large.Provider)
1151 assert.Equal(t, int64(100), large.MaxTokens)
1152 })
1153}