package config

import (
	"io"
	"log/slog"
	"os"
	"strings"
	"testing"

	"github.com/charmbracelet/crush/internal/env"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/stretchr/testify/assert"
)

func TestMain(m *testing.M) {
	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))

	exitVal := m.Run()
	os.Exit(exitVal)
}

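// TestConfig_LoadFromReaders verifies that configs are merged in reader order:
// later readers override earlier ones, and an empty provider entry does not
// clear previously set fields.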
func TestConfig_LoadFromReaders(t *testing.T) {
	data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
	data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
	data3 := strings.NewReader(`{"providers": {"openai": {}}}`)

	loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})

	assert.NoError(t, err)
	assert.NotNil(t, loadedConfig)
	assert.Len(t, loadedConfig.Providers, 1)
	assert.Equal(t, "key2", loadedConfig.Providers["openai"].APIKey)
	assert.Equal(t, "https://api.openai.com/v2", loadedConfig.Providers["openai"].BaseURL)
}

func TestConfig_setDefaults(t *testing.T) {
	cfg := &Config{}

	cfg.setDefaults("/tmp")

	assert.NotNil(t, cfg.Options)
	assert.NotNil(t, cfg.Options.TUI)
	assert.NotNil(t, cfg.Options.ContextPaths)
	assert.NotNil(t, cfg.Providers)
	assert.NotNil(t, cfg.Models)
	assert.NotNil(t, cfg.LSP)
	assert.NotNil(t, cfg.MCP)
	assert.Equal(t, "/tmp/.crush", cfg.Options.DataDirectory)
	for _, path := range defaultContextPaths {
		assert.Contains(t, cfg.Options.ContextPaths, path)
	}
	assert.Equal(t, "/tmp", cfg.workingDir)
}

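// TestConfig_configureProviders verifies that a known provider is configured
// when its API key environment variable is set, and that the key is kept as
// the "$VAR" placeholder rather than the resolved value.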
func TestConfig_configureProviders(t *testing.T) {
	knownProviders := []provider.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []provider.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	assert.NoError(t, err)
	assert.Len(t, cfg.Providers, 1)

	// We want to make sure that we keep the configured API key as a placeholder
	assert.Equal(t, "$OPENAI_API_KEY", cfg.Providers["openai"].APIKey)
}

func TestConfig_configureProvidersWithOverride(t *testing.T) {
	knownProviders := []provider.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []provider.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: map[string]ProviderConfig{
			"openai": {
				APIKey:  "xyz",
				BaseURL: "https://api.openai.com/v2",
				Models: []provider.Model{
					{
						ID:   "test-model",
						Name: "Updated",
					},
					{
						ID: "another-model",
					},
				},
			},
		},
	}
	cfg.setDefaults("/tmp")

	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	assert.NoError(t, err)
	assert.Len(t, cfg.Providers, 1)

	// The API key, base URL, and models from the config should take precedence
	// over the known provider's defaults.
	assert.Equal(t, "xyz", cfg.Providers["openai"].APIKey)
	assert.Equal(t, "https://api.openai.com/v2", cfg.Providers["openai"].BaseURL)
	assert.Len(t, cfg.Providers["openai"].Models, 2)
	assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Name)
}

func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
	knownProviders := []provider.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []provider.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: map[string]ProviderConfig{
			"custom": {
				APIKey:  "xyz",
				BaseURL: "https://api.someendpoint.com/v2",
				Models: []provider.Model{
					{
						ID: "test-model",
					},
				},
			},
		},
	}
	cfg.setDefaults("/tmp")
	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	assert.NoError(t, err)
	// Should be two because the OPENAI_API_KEY env variable also configures the known provider
	assert.Len(t, cfg.Providers, 2)

	// Make sure the custom provider keeps its configured API key
	assert.Equal(t, "xyz", cfg.Providers["custom"].APIKey)
	// Make sure we set the ID correctly
	assert.Equal(t, "custom", cfg.Providers["custom"].ID)
	assert.Equal(t, "https://api.someendpoint.com/v2", cfg.Providers["custom"].BaseURL)
	assert.Len(t, cfg.Providers["custom"].Models, 1)

	_, ok := cfg.Providers["openai"]
	assert.True(t, ok, "OpenAI provider should still be present")
}

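// The Bedrock tests below verify that the provider is only configured when AWS
// credentials (AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY) are present in the
// environment, since the catalog entry carries no API key.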
func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
	knownProviders := []provider.Provider{
		{
			ID:          provider.InferenceProviderBedrock,
			APIKey:      "",
			APIEndpoint: "",
			Models: []provider.Model{{
				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	env := env.NewFromMap(map[string]string{
		"AWS_ACCESS_KEY_ID":     "test-key-id",
		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	assert.NoError(t, err)
	assert.Len(t, cfg.Providers, 1)

	bedrockProvider, ok := cfg.Providers["bedrock"]
	assert.True(t, ok, "Bedrock provider should be present")
	assert.Len(t, bedrockProvider.Models, 1)
	assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
}

func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
	knownProviders := []provider.Provider{
		{
			ID:          provider.InferenceProviderBedrock,
			APIKey:      "",
			APIEndpoint: "",
			Models: []provider.Model{{
				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	env := env.NewFromMap(map[string]string{})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	assert.NoError(t, err)
	// Provider should not be configured without credentials
	assert.Len(t, cfg.Providers, 0)
}

func TestConfig_configureProvidersBedrockWithUnsupportedModel(t *testing.T) {
	knownProviders := []provider.Provider{
		{
			ID:          provider.InferenceProviderBedrock,
			APIKey:      "",
			APIEndpoint: "",
			Models: []provider.Model{{
				ID: "some-random-model",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	env := env.NewFromMap(map[string]string{
		"AWS_ACCESS_KEY_ID":     "test-key-id",
		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	assert.Error(t, err)
}

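// The VertexAI tests below verify that the provider is only configured when
// GOOGLE_GENAI_USE_VERTEXAI is "true" and both GOOGLE_CLOUD_PROJECT and
// GOOGLE_CLOUD_LOCATION are set; project and location are stored in
// ExtraParams.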
func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
	knownProviders := []provider.Provider{
		{
			ID:          provider.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []provider.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	env := env.NewFromMap(map[string]string{
		"GOOGLE_GENAI_USE_VERTEXAI": "true",
		"GOOGLE_CLOUD_PROJECT":      "test-project",
		"GOOGLE_CLOUD_LOCATION":     "us-central1",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	assert.NoError(t, err)
	assert.Len(t, cfg.Providers, 1)

	vertexProvider, ok := cfg.Providers["vertexai"]
	assert.True(t, ok, "VertexAI provider should be present")
	assert.Len(t, vertexProvider.Models, 1)
	assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
	assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
	assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
}

func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
	knownProviders := []provider.Provider{
		{
			ID:          provider.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []provider.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	env := env.NewFromMap(map[string]string{
		"GOOGLE_GENAI_USE_VERTEXAI": "false",
		"GOOGLE_CLOUD_PROJECT":      "test-project",
		"GOOGLE_CLOUD_LOCATION":     "us-central1",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	assert.NoError(t, err)
	// Provider should not be configured without proper credentials
	assert.Len(t, cfg.Providers, 0)
}

func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
	knownProviders := []provider.Provider{
		{
			ID:          provider.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []provider.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	env := env.NewFromMap(map[string]string{
		"GOOGLE_GENAI_USE_VERTEXAI": "true",
		"GOOGLE_CLOUD_LOCATION":     "us-central1",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	assert.NoError(t, err)
	// Provider should not be configured without project
	assert.Len(t, cfg.Providers, 0)
}

func TestConfig_configureProvidersSetProviderID(t *testing.T) {
	knownProviders := []provider.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []provider.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{}
	cfg.setDefaults("/tmp")
	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	assert.NoError(t, err)
	assert.Len(t, cfg.Providers, 1)

	// Provider ID should be set
	assert.Equal(t, "openai", cfg.Providers["openai"].ID)
}

func TestConfig_EnabledProviders(t *testing.T) {
	t.Run("all providers enabled", func(t *testing.T) {
		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: false,
				},
				"anthropic": {
					ID:      "anthropic",
					APIKey:  "key2",
					Disable: false,
				},
			},
		}

		enabled := cfg.EnabledProviders()
		assert.Len(t, enabled, 2)
	})

	t.Run("some providers disabled", func(t *testing.T) {
		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: false,
				},
				"anthropic": {
					ID:      "anthropic",
					APIKey:  "key2",
					Disable: true,
				},
			},
		}

		enabled := cfg.EnabledProviders()
		assert.Len(t, enabled, 1)
		assert.Equal(t, "openai", enabled[0].ID)
	})

	t.Run("empty providers map", func(t *testing.T) {
		cfg := &Config{
			Providers: map[string]ProviderConfig{},
		}

		enabled := cfg.EnabledProviders()
		assert.Len(t, enabled, 0)
	})
}

func TestConfig_IsConfigured(t *testing.T) {
	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: false,
				},
			},
		}

		assert.True(t, cfg.IsConfigured())
	})

	t.Run("returns false when no providers are configured", func(t *testing.T) {
		cfg := &Config{
			Providers: map[string]ProviderConfig{},
		}

		assert.False(t, cfg.IsConfigured())
	})

	t.Run("returns false when all providers are disabled", func(t *testing.T) {
		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: true,
				},
				"anthropic": {
					ID:      "anthropic",
					APIKey:  "key2",
					Disable: true,
				},
			},
		}

		assert.False(t, cfg.IsConfigured())
	})
}

func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
	knownProviders := []provider.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []provider.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: map[string]ProviderConfig{
			"openai": {
				Disable: true,
			},
		},
	}
	cfg.setDefaults("/tmp")

	env := env.NewFromMap(map[string]string{
		"OPENAI_API_KEY": "test-key",
	})
	resolver := NewEnvironmentVariableResolver(env)
	err := cfg.configureProviders(env, resolver, knownProviders)
	assert.NoError(t, err)

	// Provider should be removed from config when disabled
	assert.Len(t, cfg.Providers, 0)
	_, exists := cfg.Providers["openai"]
	assert.False(t, exists)
}

func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
	t.Run("custom provider with missing API key is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"custom": {
					BaseURL: "https://api.custom.com/v1",
					Models: []provider.Model{{
						ID: "test-model",
					}},
				},
			},
		}
		cfg.setDefaults("/tmp")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []provider.Provider{})
		assert.NoError(t, err)

		assert.Len(t, cfg.Providers, 0)
		_, exists := cfg.Providers["custom"]
		assert.False(t, exists)
	})

	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"custom": {
					APIKey: "test-key",
					Models: []provider.Model{{
						ID: "test-model",
					}},
				},
			},
		}
		cfg.setDefaults("/tmp")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []provider.Provider{})
		assert.NoError(t, err)

		assert.Len(t, cfg.Providers, 0)
		_, exists := cfg.Providers["custom"]
		assert.False(t, exists)
	})

	t.Run("custom provider with no models is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models:  []provider.Model{},
				},
			},
		}
		cfg.setDefaults("/tmp")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []provider.Provider{})
		assert.NoError(t, err)

		assert.Len(t, cfg.Providers, 0)
		_, exists := cfg.Providers["custom"]
		assert.False(t, exists)
	})

	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Type:    "unsupported",
					Models: []provider.Model{{
						ID: "test-model",
					}},
				},
			},
		}
		cfg.setDefaults("/tmp")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []provider.Provider{})
		assert.NoError(t, err)

		assert.Len(t, cfg.Providers, 0)
		_, exists := cfg.Providers["custom"]
		assert.False(t, exists)
	})

	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Type:    provider.TypeOpenAI,
					Models: []provider.Model{{
						ID: "test-model",
					}},
				},
			},
		}
		cfg.setDefaults("/tmp")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []provider.Provider{})
		assert.NoError(t, err)

		assert.Len(t, cfg.Providers, 1)
		customProvider, exists := cfg.Providers["custom"]
		assert.True(t, exists)
		assert.Equal(t, "custom", customProvider.ID)
		assert.Equal(t, "test-key", customProvider.APIKey)
		assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
	})

	t.Run("disabled custom provider is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Type:    provider.TypeOpenAI,
					Disable: true,
					Models: []provider.Model{{
						ID: "test-model",
					}},
				},
			},
		}
		cfg.setDefaults("/tmp")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, []provider.Provider{})
		assert.NoError(t, err)

		assert.Len(t, cfg.Providers, 0)
		_, exists := cfg.Providers["custom"]
		assert.False(t, exists)
	})
}

func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
		knownProviders := []provider.Provider{
			{
				ID:          provider.InferenceProviderVertexAI,
				APIKey:      "",
				APIEndpoint: "",
				Models: []provider.Model{{
					ID: "gemini-pro",
				}},
			},
		}

		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"vertexai": {
					BaseURL: "custom-url",
				},
			},
		}
		cfg.setDefaults("/tmp")

		env := env.NewFromMap(map[string]string{
			"GOOGLE_GENAI_USE_VERTEXAI": "false",
		})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		assert.NoError(t, err)

		assert.Len(t, cfg.Providers, 0)
		_, exists := cfg.Providers["vertexai"]
		assert.False(t, exists)
	})

	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
		knownProviders := []provider.Provider{
			{
				ID:          provider.InferenceProviderBedrock,
				APIKey:      "",
				APIEndpoint: "",
				Models: []provider.Model{{
					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
				}},
			},
		}

		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"bedrock": {
					BaseURL: "custom-url",
				},
			},
		}
		cfg.setDefaults("/tmp")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		assert.NoError(t, err)

		assert.Len(t, cfg.Providers, 0)
		_, exists := cfg.Providers["bedrock"]
		assert.False(t, exists)
	})

	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
		knownProviders := []provider.Provider{
			{
				ID:          "openai",
				APIKey:      "$MISSING_API_KEY",
				APIEndpoint: "https://api.openai.com/v1",
				Models: []provider.Model{{
					ID: "test-model",
				}},
			},
		}

		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"openai": {
					BaseURL: "custom-url",
				},
			},
		}
		cfg.setDefaults("/tmp")

		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		assert.NoError(t, err)

		assert.Len(t, cfg.Providers, 0)
		_, exists := cfg.Providers["openai"]
		assert.False(t, exists)
	})

	t.Run("known provider should still be added if the endpoint is missing; the client will use default endpoints", func(t *testing.T) {
		knownProviders := []provider.Provider{
			{
				ID:          "openai",
				APIKey:      "$OPENAI_API_KEY",
				APIEndpoint: "$MISSING_ENDPOINT",
				Models: []provider.Model{{
					ID: "test-model",
				}},
			},
		}

		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"openai": {
					APIKey: "test-key",
				},
			},
		}
		cfg.setDefaults("/tmp")

		env := env.NewFromMap(map[string]string{
			"OPENAI_API_KEY": "test-key",
		})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		assert.NoError(t, err)

		assert.Len(t, cfg.Providers, 1)
		_, exists := cfg.Providers["openai"]
		assert.True(t, exists)
	})
}

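// TestConfig_defaultModelSelection verifies how the default large and small
// models are picked: a known provider's DefaultLargeModelID and
// DefaultSmallModelID are preferred, and a configured custom provider's models
// are used when no known provider is available.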
func TestConfig_defaultModelSelection(t *testing.T) {
	t.Run("default behavior uses the default models for the given provider", func(t *testing.T) {
		knownProviders := []provider.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []provider.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{}
		cfg.setDefaults("/tmp")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		assert.NoError(t, err)

		large, small, err := cfg.defaultModelSelection(knownProviders)
		assert.NoError(t, err)
		assert.Equal(t, "large-model", large.Model)
		assert.Equal(t, "openai", large.Provider)
		assert.Equal(t, int64(1000), large.MaxTokens)
		assert.Equal(t, "small-model", small.Model)
		assert.Equal(t, "openai", small.Provider)
		assert.Equal(t, int64(500), small.MaxTokens)
	})
	t.Run("should error if no providers are configured", func(t *testing.T) {
		knownProviders := []provider.Provider{
			{
				ID:                  "openai",
				APIKey:              "$MISSING_KEY",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []provider.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{}
		cfg.setDefaults("/tmp")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		assert.NoError(t, err)

		_, _, err = cfg.defaultModelSelection(knownProviders)
		assert.Error(t, err)
	})
	t.Run("should error if the default model is missing", func(t *testing.T) {
		knownProviders := []provider.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []provider.Model{
					{
						ID:               "not-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{}
		cfg.setDefaults("/tmp")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		assert.NoError(t, err)
		_, _, err = cfg.defaultModelSelection(knownProviders)
		assert.Error(t, err)
	})

	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
		knownProviders := []provider.Provider{
			{
				ID:                  "openai",
				APIKey:              "$MISSING", // will not be included in the config
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []provider.Model{
					{
						ID:               "not-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models: []provider.Model{
						{
							ID:               "model",
							DefaultMaxTokens: 600,
						},
					},
				},
			},
		}
		cfg.setDefaults("/tmp")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		assert.NoError(t, err)
		large, small, err := cfg.defaultModelSelection(knownProviders)
		assert.NoError(t, err)
		assert.Equal(t, "model", large.Model)
		assert.Equal(t, "custom", large.Provider)
		assert.Equal(t, int64(600), large.MaxTokens)
		assert.Equal(t, "model", small.Model)
		assert.Equal(t, "custom", small.Provider)
		assert.Equal(t, int64(600), small.MaxTokens)
	})

	t.Run("should fail if no model is configured", func(t *testing.T) {
		knownProviders := []provider.Provider{
			{
				ID:                  "openai",
				APIKey:              "$MISSING", // will not be included in the config
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []provider.Model{
					{
						ID:               "not-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models:  []provider.Model{},
				},
			},
		}
		cfg.setDefaults("/tmp")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		assert.NoError(t, err)
		_, _, err = cfg.defaultModelSelection(knownProviders)
		assert.Error(t, err)
	})
	t.Run("should use the default provider first", func(t *testing.T) {
		knownProviders := []provider.Provider{
			{
				ID:                  "openai",
				APIKey:              "set",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []provider.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Providers: map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models: []provider.Model{
						{
							ID:               "large-model",
							DefaultMaxTokens: 1000,
						},
					},
				},
			},
		}
		cfg.setDefaults("/tmp")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		assert.NoError(t, err)
		large, small, err := cfg.defaultModelSelection(knownProviders)
		assert.NoError(t, err)
		assert.Equal(t, "large-model", large.Model)
		assert.Equal(t, "openai", large.Provider)
		assert.Equal(t, int64(1000), large.MaxTokens)
		assert.Equal(t, "small-model", small.Model)
		assert.Equal(t, "openai", small.Provider)
		assert.Equal(t, int64(500), small.MaxTokens)
	})
}

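// TestConfig_configureSelectedModels verifies that models selected in the
// config override the defaults: an explicit model, a different provider, or
// just a MaxTokens override are all respected.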
func TestConfig_configureSelectedModels(t *testing.T) {
	t.Run("should override defaults", func(t *testing.T) {
		knownProviders := []provider.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []provider.Model{
					{
						ID:               "larger-model",
						DefaultMaxTokens: 2000,
					},
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				"large": {
					Model: "larger-model",
				},
			},
		}
		cfg.setDefaults("/tmp")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		assert.NoError(t, err)

		err = cfg.configureSelectedModels(knownProviders)
		assert.NoError(t, err)
		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		assert.Equal(t, "larger-model", large.Model)
		assert.Equal(t, "openai", large.Provider)
		assert.Equal(t, int64(2000), large.MaxTokens)
		assert.Equal(t, "small-model", small.Model)
		assert.Equal(t, "openai", small.Provider)
		assert.Equal(t, int64(500), small.MaxTokens)
	})
	t.Run("should be possible to use multiple providers", func(t *testing.T) {
		knownProviders := []provider.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []provider.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
			{
				ID:                  "anthropic",
				APIKey:              "abc",
				DefaultLargeModelID: "a-large-model",
				DefaultSmallModelID: "a-small-model",
				Models: []provider.Model{
					{
						ID:               "a-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "a-small-model",
						DefaultMaxTokens: 200,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				"small": {
					Model:     "a-small-model",
					Provider:  "anthropic",
					MaxTokens: 300,
				},
			},
		}
		cfg.setDefaults("/tmp")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		assert.NoError(t, err)

		err = cfg.configureSelectedModels(knownProviders)
		assert.NoError(t, err)
		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		assert.Equal(t, "large-model", large.Model)
		assert.Equal(t, "openai", large.Provider)
		assert.Equal(t, int64(1000), large.MaxTokens)
		assert.Equal(t, "a-small-model", small.Model)
		assert.Equal(t, "anthropic", small.Provider)
		assert.Equal(t, int64(300), small.MaxTokens)
	})

	t.Run("should override the max tokens only", func(t *testing.T) {
		knownProviders := []provider.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []provider.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				"large": {
					MaxTokens: 100,
				},
			},
		}
		cfg.setDefaults("/tmp")
		env := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(env)
		err := cfg.configureProviders(env, resolver, knownProviders)
		assert.NoError(t, err)

		err = cfg.configureSelectedModels(knownProviders)
		assert.NoError(t, err)
		large := cfg.Models[SelectedModelTypeLarge]
		assert.Equal(t, "large-model", large.Model)
		assert.Equal(t, "openai", large.Provider)
		assert.Equal(t, int64(100), large.MaxTokens)
	})
}