1package config
2
3import (
4 "io"
5 "log/slog"
6 "os"
7 "path/filepath"
8 "strings"
9 "testing"
10
11 "github.com/charmbracelet/crush/internal/env"
12 "github.com/charmbracelet/crush/internal/fur/provider"
13 "github.com/stretchr/testify/assert"
14)
15
// TestMain routes the package-wide default slog logger to io.Discard so the
// configuration code under test does not write log noise into test output,
// then runs the tests and exits with their status code.
func TestMain(m *testing.M) {
	silent := slog.NewTextHandler(io.Discard, nil)
	slog.SetDefault(slog.New(silent))
	os.Exit(m.Run())
}
22
23func TestConfig_LoadFromReaders(t *testing.T) {
24 data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
25 data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
26 data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
27
28 loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
29
30 assert.NoError(t, err)
31 assert.NotNil(t, loadedConfig)
32 assert.Len(t, loadedConfig.Providers, 1)
33 assert.Equal(t, "key2", loadedConfig.Providers["openai"].APIKey)
34 assert.Equal(t, "https://api.openai.com/v2", loadedConfig.Providers["openai"].BaseURL)
35}
36
37func TestConfig_setDefaults(t *testing.T) {
38 cfg := &Config{}
39
40 cfg.setDefaults("/tmp")
41
42 assert.NotNil(t, cfg.Options)
43 assert.NotNil(t, cfg.Options.TUI)
44 assert.NotNil(t, cfg.Options.ContextPaths)
45 assert.NotNil(t, cfg.Providers)
46 assert.NotNil(t, cfg.Models)
47 assert.NotNil(t, cfg.LSP)
48 assert.NotNil(t, cfg.MCP)
49 assert.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
50 for _, path := range defaultContextPaths {
51 assert.Contains(t, cfg.Options.ContextPaths, path)
52 }
53 assert.Equal(t, "/tmp", cfg.workingDir)
54}
55
56func TestConfig_configureProviders(t *testing.T) {
57 knownProviders := []provider.Provider{
58 {
59 ID: "openai",
60 APIKey: "$OPENAI_API_KEY",
61 APIEndpoint: "https://api.openai.com/v1",
62 Models: []provider.Model{{
63 ID: "test-model",
64 }},
65 },
66 }
67
68 cfg := &Config{}
69 cfg.setDefaults("/tmp")
70 env := env.NewFromMap(map[string]string{
71 "OPENAI_API_KEY": "test-key",
72 })
73 resolver := NewEnvironmentVariableResolver(env)
74 err := cfg.configureProviders(env, resolver, knownProviders)
75 assert.NoError(t, err)
76 assert.Len(t, cfg.Providers, 1)
77
78 // We want to make sure that we keep the configured API key as a placeholder
79 assert.Equal(t, "$OPENAI_API_KEY", cfg.Providers["openai"].APIKey)
80}
81
82func TestConfig_configureProvidersWithOverride(t *testing.T) {
83 knownProviders := []provider.Provider{
84 {
85 ID: "openai",
86 APIKey: "$OPENAI_API_KEY",
87 APIEndpoint: "https://api.openai.com/v1",
88 Models: []provider.Model{{
89 ID: "test-model",
90 }},
91 },
92 }
93
94 cfg := &Config{
95 Providers: map[string]ProviderConfig{
96 "openai": {
97 APIKey: "xyz",
98 BaseURL: "https://api.openai.com/v2",
99 Models: []provider.Model{
100 {
101 ID: "test-model",
102 Name: "Updated",
103 },
104 {
105 ID: "another-model",
106 },
107 },
108 },
109 },
110 }
111 cfg.setDefaults("/tmp")
112
113 env := env.NewFromMap(map[string]string{
114 "OPENAI_API_KEY": "test-key",
115 })
116 resolver := NewEnvironmentVariableResolver(env)
117 err := cfg.configureProviders(env, resolver, knownProviders)
118 assert.NoError(t, err)
119 assert.Len(t, cfg.Providers, 1)
120
121 // We want to make sure that we keep the configured API key as a placeholder
122 assert.Equal(t, "xyz", cfg.Providers["openai"].APIKey)
123 assert.Equal(t, "https://api.openai.com/v2", cfg.Providers["openai"].BaseURL)
124 assert.Len(t, cfg.Providers["openai"].Models, 2)
125 assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Name)
126}
127
128func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
129 knownProviders := []provider.Provider{
130 {
131 ID: "openai",
132 APIKey: "$OPENAI_API_KEY",
133 APIEndpoint: "https://api.openai.com/v1",
134 Models: []provider.Model{{
135 ID: "test-model",
136 }},
137 },
138 }
139
140 cfg := &Config{
141 Providers: map[string]ProviderConfig{
142 "custom": {
143 APIKey: "xyz",
144 BaseURL: "https://api.someendpoint.com/v2",
145 Models: []provider.Model{
146 {
147 ID: "test-model",
148 },
149 },
150 },
151 },
152 }
153 cfg.setDefaults("/tmp")
154 env := env.NewFromMap(map[string]string{
155 "OPENAI_API_KEY": "test-key",
156 })
157 resolver := NewEnvironmentVariableResolver(env)
158 err := cfg.configureProviders(env, resolver, knownProviders)
159 assert.NoError(t, err)
160 // Should be to because of the env variable
161 assert.Len(t, cfg.Providers, 2)
162
163 // We want to make sure that we keep the configured API key as a placeholder
164 assert.Equal(t, "xyz", cfg.Providers["custom"].APIKey)
165 // Make sure we set the ID correctly
166 assert.Equal(t, "custom", cfg.Providers["custom"].ID)
167 assert.Equal(t, "https://api.someendpoint.com/v2", cfg.Providers["custom"].BaseURL)
168 assert.Len(t, cfg.Providers["custom"].Models, 1)
169
170 _, ok := cfg.Providers["openai"]
171 assert.True(t, ok, "OpenAI provider should still be present")
172}
173
174func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
175 knownProviders := []provider.Provider{
176 {
177 ID: provider.InferenceProviderBedrock,
178 APIKey: "",
179 APIEndpoint: "",
180 Models: []provider.Model{{
181 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
182 }},
183 },
184 }
185
186 cfg := &Config{}
187 cfg.setDefaults("/tmp")
188 env := env.NewFromMap(map[string]string{
189 "AWS_ACCESS_KEY_ID": "test-key-id",
190 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
191 })
192 resolver := NewEnvironmentVariableResolver(env)
193 err := cfg.configureProviders(env, resolver, knownProviders)
194 assert.NoError(t, err)
195 assert.Len(t, cfg.Providers, 1)
196
197 bedrockProvider, ok := cfg.Providers["bedrock"]
198 assert.True(t, ok, "Bedrock provider should be present")
199 assert.Len(t, bedrockProvider.Models, 1)
200 assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
201}
202
203func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
204 knownProviders := []provider.Provider{
205 {
206 ID: provider.InferenceProviderBedrock,
207 APIKey: "",
208 APIEndpoint: "",
209 Models: []provider.Model{{
210 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
211 }},
212 },
213 }
214
215 cfg := &Config{}
216 cfg.setDefaults("/tmp")
217 env := env.NewFromMap(map[string]string{})
218 resolver := NewEnvironmentVariableResolver(env)
219 err := cfg.configureProviders(env, resolver, knownProviders)
220 assert.NoError(t, err)
221 // Provider should not be configured without credentials
222 assert.Len(t, cfg.Providers, 0)
223}
224
225func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
226 knownProviders := []provider.Provider{
227 {
228 ID: provider.InferenceProviderBedrock,
229 APIKey: "",
230 APIEndpoint: "",
231 Models: []provider.Model{{
232 ID: "some-random-model",
233 }},
234 },
235 }
236
237 cfg := &Config{}
238 cfg.setDefaults("/tmp")
239 env := env.NewFromMap(map[string]string{
240 "AWS_ACCESS_KEY_ID": "test-key-id",
241 "AWS_SECRET_ACCESS_KEY": "test-secret-key",
242 })
243 resolver := NewEnvironmentVariableResolver(env)
244 err := cfg.configureProviders(env, resolver, knownProviders)
245 assert.Error(t, err)
246}
247
248func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
249 knownProviders := []provider.Provider{
250 {
251 ID: provider.InferenceProviderVertexAI,
252 APIKey: "",
253 APIEndpoint: "",
254 Models: []provider.Model{{
255 ID: "gemini-pro",
256 }},
257 },
258 }
259
260 cfg := &Config{}
261 cfg.setDefaults("/tmp")
262 env := env.NewFromMap(map[string]string{
263 "GOOGLE_GENAI_USE_VERTEXAI": "true",
264 "GOOGLE_CLOUD_PROJECT": "test-project",
265 "GOOGLE_CLOUD_LOCATION": "us-central1",
266 })
267 resolver := NewEnvironmentVariableResolver(env)
268 err := cfg.configureProviders(env, resolver, knownProviders)
269 assert.NoError(t, err)
270 assert.Len(t, cfg.Providers, 1)
271
272 vertexProvider, ok := cfg.Providers["vertexai"]
273 assert.True(t, ok, "VertexAI provider should be present")
274 assert.Len(t, vertexProvider.Models, 1)
275 assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
276 assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
277 assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
278}
279
280func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
281 knownProviders := []provider.Provider{
282 {
283 ID: provider.InferenceProviderVertexAI,
284 APIKey: "",
285 APIEndpoint: "",
286 Models: []provider.Model{{
287 ID: "gemini-pro",
288 }},
289 },
290 }
291
292 cfg := &Config{}
293 cfg.setDefaults("/tmp")
294 env := env.NewFromMap(map[string]string{
295 "GOOGLE_GENAI_USE_VERTEXAI": "false",
296 "GOOGLE_CLOUD_PROJECT": "test-project",
297 "GOOGLE_CLOUD_LOCATION": "us-central1",
298 })
299 resolver := NewEnvironmentVariableResolver(env)
300 err := cfg.configureProviders(env, resolver, knownProviders)
301 assert.NoError(t, err)
302 // Provider should not be configured without proper credentials
303 assert.Len(t, cfg.Providers, 0)
304}
305
306func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
307 knownProviders := []provider.Provider{
308 {
309 ID: provider.InferenceProviderVertexAI,
310 APIKey: "",
311 APIEndpoint: "",
312 Models: []provider.Model{{
313 ID: "gemini-pro",
314 }},
315 },
316 }
317
318 cfg := &Config{}
319 cfg.setDefaults("/tmp")
320 env := env.NewFromMap(map[string]string{
321 "GOOGLE_GENAI_USE_VERTEXAI": "true",
322 "GOOGLE_CLOUD_LOCATION": "us-central1",
323 })
324 resolver := NewEnvironmentVariableResolver(env)
325 err := cfg.configureProviders(env, resolver, knownProviders)
326 assert.NoError(t, err)
327 // Provider should not be configured without project
328 assert.Len(t, cfg.Providers, 0)
329}
330
331func TestConfig_configureProvidersSetProviderID(t *testing.T) {
332 knownProviders := []provider.Provider{
333 {
334 ID: "openai",
335 APIKey: "$OPENAI_API_KEY",
336 APIEndpoint: "https://api.openai.com/v1",
337 Models: []provider.Model{{
338 ID: "test-model",
339 }},
340 },
341 }
342
343 cfg := &Config{}
344 cfg.setDefaults("/tmp")
345 env := env.NewFromMap(map[string]string{
346 "OPENAI_API_KEY": "test-key",
347 })
348 resolver := NewEnvironmentVariableResolver(env)
349 err := cfg.configureProviders(env, resolver, knownProviders)
350 assert.NoError(t, err)
351 assert.Len(t, cfg.Providers, 1)
352
353 // Provider ID should be set
354 assert.Equal(t, "openai", cfg.Providers["openai"].ID)
355}
356
357func TestConfig_EnabledProviders(t *testing.T) {
358 t.Run("all providers enabled", func(t *testing.T) {
359 cfg := &Config{
360 Providers: map[string]ProviderConfig{
361 "openai": {
362 ID: "openai",
363 APIKey: "key1",
364 Disable: false,
365 },
366 "anthropic": {
367 ID: "anthropic",
368 APIKey: "key2",
369 Disable: false,
370 },
371 },
372 }
373
374 enabled := cfg.EnabledProviders()
375 assert.Len(t, enabled, 2)
376 })
377
378 t.Run("some providers disabled", func(t *testing.T) {
379 cfg := &Config{
380 Providers: map[string]ProviderConfig{
381 "openai": {
382 ID: "openai",
383 APIKey: "key1",
384 Disable: false,
385 },
386 "anthropic": {
387 ID: "anthropic",
388 APIKey: "key2",
389 Disable: true,
390 },
391 },
392 }
393
394 enabled := cfg.EnabledProviders()
395 assert.Len(t, enabled, 1)
396 assert.Equal(t, "openai", enabled[0].ID)
397 })
398
399 t.Run("empty providers map", func(t *testing.T) {
400 cfg := &Config{
401 Providers: map[string]ProviderConfig{},
402 }
403
404 enabled := cfg.EnabledProviders()
405 assert.Len(t, enabled, 0)
406 })
407}
408
409func TestConfig_IsConfigured(t *testing.T) {
410 t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
411 cfg := &Config{
412 Providers: map[string]ProviderConfig{
413 "openai": {
414 ID: "openai",
415 APIKey: "key1",
416 Disable: false,
417 },
418 },
419 }
420
421 assert.True(t, cfg.IsConfigured())
422 })
423
424 t.Run("returns false when no providers are configured", func(t *testing.T) {
425 cfg := &Config{
426 Providers: map[string]ProviderConfig{},
427 }
428
429 assert.False(t, cfg.IsConfigured())
430 })
431
432 t.Run("returns false when all providers are disabled", func(t *testing.T) {
433 cfg := &Config{
434 Providers: map[string]ProviderConfig{
435 "openai": {
436 ID: "openai",
437 APIKey: "key1",
438 Disable: true,
439 },
440 "anthropic": {
441 ID: "anthropic",
442 APIKey: "key2",
443 Disable: true,
444 },
445 },
446 }
447
448 assert.False(t, cfg.IsConfigured())
449 })
450}
451
452func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
453 knownProviders := []provider.Provider{
454 {
455 ID: "openai",
456 APIKey: "$OPENAI_API_KEY",
457 APIEndpoint: "https://api.openai.com/v1",
458 Models: []provider.Model{{
459 ID: "test-model",
460 }},
461 },
462 }
463
464 cfg := &Config{
465 Providers: map[string]ProviderConfig{
466 "openai": {
467 Disable: true,
468 },
469 },
470 }
471 cfg.setDefaults("/tmp")
472
473 env := env.NewFromMap(map[string]string{
474 "OPENAI_API_KEY": "test-key",
475 })
476 resolver := NewEnvironmentVariableResolver(env)
477 err := cfg.configureProviders(env, resolver, knownProviders)
478 assert.NoError(t, err)
479
480 // Provider should be removed from config when disabled
481 assert.Len(t, cfg.Providers, 0)
482 _, exists := cfg.Providers["openai"]
483 assert.False(t, exists)
484}
485
486func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
487 t.Run("custom provider with missing API key is removed", func(t *testing.T) {
488 cfg := &Config{
489 Providers: map[string]ProviderConfig{
490 "custom": {
491 BaseURL: "https://api.custom.com/v1",
492 Models: []provider.Model{{
493 ID: "test-model",
494 }},
495 },
496 },
497 }
498 cfg.setDefaults("/tmp")
499
500 env := env.NewFromMap(map[string]string{})
501 resolver := NewEnvironmentVariableResolver(env)
502 err := cfg.configureProviders(env, resolver, []provider.Provider{})
503 assert.NoError(t, err)
504
505 assert.Len(t, cfg.Providers, 0)
506 _, exists := cfg.Providers["custom"]
507 assert.False(t, exists)
508 })
509
510 t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
511 cfg := &Config{
512 Providers: map[string]ProviderConfig{
513 "custom": {
514 APIKey: "test-key",
515 Models: []provider.Model{{
516 ID: "test-model",
517 }},
518 },
519 },
520 }
521 cfg.setDefaults("/tmp")
522
523 env := env.NewFromMap(map[string]string{})
524 resolver := NewEnvironmentVariableResolver(env)
525 err := cfg.configureProviders(env, resolver, []provider.Provider{})
526 assert.NoError(t, err)
527
528 assert.Len(t, cfg.Providers, 0)
529 _, exists := cfg.Providers["custom"]
530 assert.False(t, exists)
531 })
532
533 t.Run("custom provider with no models is removed", func(t *testing.T) {
534 cfg := &Config{
535 Providers: map[string]ProviderConfig{
536 "custom": {
537 APIKey: "test-key",
538 BaseURL: "https://api.custom.com/v1",
539 Models: []provider.Model{},
540 },
541 },
542 }
543 cfg.setDefaults("/tmp")
544
545 env := env.NewFromMap(map[string]string{})
546 resolver := NewEnvironmentVariableResolver(env)
547 err := cfg.configureProviders(env, resolver, []provider.Provider{})
548 assert.NoError(t, err)
549
550 assert.Len(t, cfg.Providers, 0)
551 _, exists := cfg.Providers["custom"]
552 assert.False(t, exists)
553 })
554
555 t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
556 cfg := &Config{
557 Providers: map[string]ProviderConfig{
558 "custom": {
559 APIKey: "test-key",
560 BaseURL: "https://api.custom.com/v1",
561 Type: "unsupported",
562 Models: []provider.Model{{
563 ID: "test-model",
564 }},
565 },
566 },
567 }
568 cfg.setDefaults("/tmp")
569
570 env := env.NewFromMap(map[string]string{})
571 resolver := NewEnvironmentVariableResolver(env)
572 err := cfg.configureProviders(env, resolver, []provider.Provider{})
573 assert.NoError(t, err)
574
575 assert.Len(t, cfg.Providers, 0)
576 _, exists := cfg.Providers["custom"]
577 assert.False(t, exists)
578 })
579
580 t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
581 cfg := &Config{
582 Providers: map[string]ProviderConfig{
583 "custom": {
584 APIKey: "test-key",
585 BaseURL: "https://api.custom.com/v1",
586 Type: provider.TypeOpenAI,
587 Models: []provider.Model{{
588 ID: "test-model",
589 }},
590 },
591 },
592 }
593 cfg.setDefaults("/tmp")
594
595 env := env.NewFromMap(map[string]string{})
596 resolver := NewEnvironmentVariableResolver(env)
597 err := cfg.configureProviders(env, resolver, []provider.Provider{})
598 assert.NoError(t, err)
599
600 assert.Len(t, cfg.Providers, 1)
601 customProvider, exists := cfg.Providers["custom"]
602 assert.True(t, exists)
603 assert.Equal(t, "custom", customProvider.ID)
604 assert.Equal(t, "test-key", customProvider.APIKey)
605 assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
606 })
607
608 t.Run("disabled custom provider is removed", func(t *testing.T) {
609 cfg := &Config{
610 Providers: map[string]ProviderConfig{
611 "custom": {
612 APIKey: "test-key",
613 BaseURL: "https://api.custom.com/v1",
614 Type: provider.TypeOpenAI,
615 Disable: true,
616 Models: []provider.Model{{
617 ID: "test-model",
618 }},
619 },
620 },
621 }
622 cfg.setDefaults("/tmp")
623
624 env := env.NewFromMap(map[string]string{})
625 resolver := NewEnvironmentVariableResolver(env)
626 err := cfg.configureProviders(env, resolver, []provider.Provider{})
627 assert.NoError(t, err)
628
629 assert.Len(t, cfg.Providers, 0)
630 _, exists := cfg.Providers["custom"]
631 assert.False(t, exists)
632 })
633}
634
635func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
636 t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
637 knownProviders := []provider.Provider{
638 {
639 ID: provider.InferenceProviderVertexAI,
640 APIKey: "",
641 APIEndpoint: "",
642 Models: []provider.Model{{
643 ID: "gemini-pro",
644 }},
645 },
646 }
647
648 cfg := &Config{
649 Providers: map[string]ProviderConfig{
650 "vertexai": {
651 BaseURL: "custom-url",
652 },
653 },
654 }
655 cfg.setDefaults("/tmp")
656
657 env := env.NewFromMap(map[string]string{
658 "GOOGLE_GENAI_USE_VERTEXAI": "false",
659 })
660 resolver := NewEnvironmentVariableResolver(env)
661 err := cfg.configureProviders(env, resolver, knownProviders)
662 assert.NoError(t, err)
663
664 assert.Len(t, cfg.Providers, 0)
665 _, exists := cfg.Providers["vertexai"]
666 assert.False(t, exists)
667 })
668
669 t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
670 knownProviders := []provider.Provider{
671 {
672 ID: provider.InferenceProviderBedrock,
673 APIKey: "",
674 APIEndpoint: "",
675 Models: []provider.Model{{
676 ID: "anthropic.claude-sonnet-4-20250514-v1:0",
677 }},
678 },
679 }
680
681 cfg := &Config{
682 Providers: map[string]ProviderConfig{
683 "bedrock": {
684 BaseURL: "custom-url",
685 },
686 },
687 }
688 cfg.setDefaults("/tmp")
689
690 env := env.NewFromMap(map[string]string{})
691 resolver := NewEnvironmentVariableResolver(env)
692 err := cfg.configureProviders(env, resolver, knownProviders)
693 assert.NoError(t, err)
694
695 assert.Len(t, cfg.Providers, 0)
696 _, exists := cfg.Providers["bedrock"]
697 assert.False(t, exists)
698 })
699
700 t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
701 knownProviders := []provider.Provider{
702 {
703 ID: "openai",
704 APIKey: "$MISSING_API_KEY",
705 APIEndpoint: "https://api.openai.com/v1",
706 Models: []provider.Model{{
707 ID: "test-model",
708 }},
709 },
710 }
711
712 cfg := &Config{
713 Providers: map[string]ProviderConfig{
714 "openai": {
715 BaseURL: "custom-url",
716 },
717 },
718 }
719 cfg.setDefaults("/tmp")
720
721 env := env.NewFromMap(map[string]string{})
722 resolver := NewEnvironmentVariableResolver(env)
723 err := cfg.configureProviders(env, resolver, knownProviders)
724 assert.NoError(t, err)
725
726 assert.Len(t, cfg.Providers, 0)
727 _, exists := cfg.Providers["openai"]
728 assert.False(t, exists)
729 })
730
731 t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
732 knownProviders := []provider.Provider{
733 {
734 ID: "openai",
735 APIKey: "$OPENAI_API_KEY",
736 APIEndpoint: "$MISSING_ENDPOINT",
737 Models: []provider.Model{{
738 ID: "test-model",
739 }},
740 },
741 }
742
743 cfg := &Config{
744 Providers: map[string]ProviderConfig{
745 "openai": {
746 APIKey: "test-key",
747 },
748 },
749 }
750 cfg.setDefaults("/tmp")
751
752 env := env.NewFromMap(map[string]string{
753 "OPENAI_API_KEY": "test-key",
754 })
755 resolver := NewEnvironmentVariableResolver(env)
756 err := cfg.configureProviders(env, resolver, knownProviders)
757 assert.NoError(t, err)
758
759 assert.Len(t, cfg.Providers, 1)
760 _, exists := cfg.Providers["openai"]
761 assert.True(t, exists)
762 })
763}
764
765func TestConfig_defaultModelSelection(t *testing.T) {
766 t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
767 knownProviders := []provider.Provider{
768 {
769 ID: "openai",
770 APIKey: "abc",
771 DefaultLargeModelID: "large-model",
772 DefaultSmallModelID: "small-model",
773 Models: []provider.Model{
774 {
775 ID: "large-model",
776 DefaultMaxTokens: 1000,
777 },
778 {
779 ID: "small-model",
780 DefaultMaxTokens: 500,
781 },
782 },
783 },
784 }
785
786 cfg := &Config{}
787 cfg.setDefaults("/tmp")
788 env := env.NewFromMap(map[string]string{})
789 resolver := NewEnvironmentVariableResolver(env)
790 err := cfg.configureProviders(env, resolver, knownProviders)
791 assert.NoError(t, err)
792
793 large, small, err := cfg.defaultModelSelection(knownProviders)
794 assert.NoError(t, err)
795 assert.Equal(t, "large-model", large.Model)
796 assert.Equal(t, "openai", large.Provider)
797 assert.Equal(t, int64(1000), large.MaxTokens)
798 assert.Equal(t, "small-model", small.Model)
799 assert.Equal(t, "openai", small.Provider)
800 assert.Equal(t, int64(500), small.MaxTokens)
801 })
802 t.Run("should error if no providers configured", func(t *testing.T) {
803 knownProviders := []provider.Provider{
804 {
805 ID: "openai",
806 APIKey: "$MISSING_KEY",
807 DefaultLargeModelID: "large-model",
808 DefaultSmallModelID: "small-model",
809 Models: []provider.Model{
810 {
811 ID: "large-model",
812 DefaultMaxTokens: 1000,
813 },
814 {
815 ID: "small-model",
816 DefaultMaxTokens: 500,
817 },
818 },
819 },
820 }
821
822 cfg := &Config{}
823 cfg.setDefaults("/tmp")
824 env := env.NewFromMap(map[string]string{})
825 resolver := NewEnvironmentVariableResolver(env)
826 err := cfg.configureProviders(env, resolver, knownProviders)
827 assert.NoError(t, err)
828
829 _, _, err = cfg.defaultModelSelection(knownProviders)
830 assert.Error(t, err)
831 })
832 t.Run("should error if model is missing", func(t *testing.T) {
833 knownProviders := []provider.Provider{
834 {
835 ID: "openai",
836 APIKey: "abc",
837 DefaultLargeModelID: "large-model",
838 DefaultSmallModelID: "small-model",
839 Models: []provider.Model{
840 {
841 ID: "not-large-model",
842 DefaultMaxTokens: 1000,
843 },
844 {
845 ID: "small-model",
846 DefaultMaxTokens: 500,
847 },
848 },
849 },
850 }
851
852 cfg := &Config{}
853 cfg.setDefaults("/tmp")
854 env := env.NewFromMap(map[string]string{})
855 resolver := NewEnvironmentVariableResolver(env)
856 err := cfg.configureProviders(env, resolver, knownProviders)
857 assert.NoError(t, err)
858 _, _, err = cfg.defaultModelSelection(knownProviders)
859 assert.Error(t, err)
860 })
861
862 t.Run("should configure the default models with a custom provider", func(t *testing.T) {
863 knownProviders := []provider.Provider{
864 {
865 ID: "openai",
866 APIKey: "$MISSING", // will not be included in the config
867 DefaultLargeModelID: "large-model",
868 DefaultSmallModelID: "small-model",
869 Models: []provider.Model{
870 {
871 ID: "not-large-model",
872 DefaultMaxTokens: 1000,
873 },
874 {
875 ID: "small-model",
876 DefaultMaxTokens: 500,
877 },
878 },
879 },
880 }
881
882 cfg := &Config{
883 Providers: map[string]ProviderConfig{
884 "custom": {
885 APIKey: "test-key",
886 BaseURL: "https://api.custom.com/v1",
887 Models: []provider.Model{
888 {
889 ID: "model",
890 DefaultMaxTokens: 600,
891 },
892 },
893 },
894 },
895 }
896 cfg.setDefaults("/tmp")
897 env := env.NewFromMap(map[string]string{})
898 resolver := NewEnvironmentVariableResolver(env)
899 err := cfg.configureProviders(env, resolver, knownProviders)
900 assert.NoError(t, err)
901 large, small, err := cfg.defaultModelSelection(knownProviders)
902 assert.NoError(t, err)
903 assert.Equal(t, "model", large.Model)
904 assert.Equal(t, "custom", large.Provider)
905 assert.Equal(t, int64(600), large.MaxTokens)
906 assert.Equal(t, "model", small.Model)
907 assert.Equal(t, "custom", small.Provider)
908 assert.Equal(t, int64(600), small.MaxTokens)
909 })
910
911 t.Run("should fail if no model configured", func(t *testing.T) {
912 knownProviders := []provider.Provider{
913 {
914 ID: "openai",
915 APIKey: "$MISSING", // will not be included in the config
916 DefaultLargeModelID: "large-model",
917 DefaultSmallModelID: "small-model",
918 Models: []provider.Model{
919 {
920 ID: "not-large-model",
921 DefaultMaxTokens: 1000,
922 },
923 {
924 ID: "small-model",
925 DefaultMaxTokens: 500,
926 },
927 },
928 },
929 }
930
931 cfg := &Config{
932 Providers: map[string]ProviderConfig{
933 "custom": {
934 APIKey: "test-key",
935 BaseURL: "https://api.custom.com/v1",
936 Models: []provider.Model{},
937 },
938 },
939 }
940 cfg.setDefaults("/tmp")
941 env := env.NewFromMap(map[string]string{})
942 resolver := NewEnvironmentVariableResolver(env)
943 err := cfg.configureProviders(env, resolver, knownProviders)
944 assert.NoError(t, err)
945 _, _, err = cfg.defaultModelSelection(knownProviders)
946 assert.Error(t, err)
947 })
948 t.Run("should use the default provider first", func(t *testing.T) {
949 knownProviders := []provider.Provider{
950 {
951 ID: "openai",
952 APIKey: "set",
953 DefaultLargeModelID: "large-model",
954 DefaultSmallModelID: "small-model",
955 Models: []provider.Model{
956 {
957 ID: "large-model",
958 DefaultMaxTokens: 1000,
959 },
960 {
961 ID: "small-model",
962 DefaultMaxTokens: 500,
963 },
964 },
965 },
966 }
967
968 cfg := &Config{
969 Providers: map[string]ProviderConfig{
970 "custom": {
971 APIKey: "test-key",
972 BaseURL: "https://api.custom.com/v1",
973 Models: []provider.Model{
974 {
975 ID: "large-model",
976 DefaultMaxTokens: 1000,
977 },
978 },
979 },
980 },
981 }
982 cfg.setDefaults("/tmp")
983 env := env.NewFromMap(map[string]string{})
984 resolver := NewEnvironmentVariableResolver(env)
985 err := cfg.configureProviders(env, resolver, knownProviders)
986 assert.NoError(t, err)
987 large, small, err := cfg.defaultModelSelection(knownProviders)
988 assert.NoError(t, err)
989 assert.Equal(t, "large-model", large.Model)
990 assert.Equal(t, "openai", large.Provider)
991 assert.Equal(t, int64(1000), large.MaxTokens)
992 assert.Equal(t, "small-model", small.Model)
993 assert.Equal(t, "openai", small.Provider)
994 assert.Equal(t, int64(500), small.MaxTokens)
995 })
996}
997
998func TestConfig_configureSelectedModels(t *testing.T) {
999 t.Run("should override defaults", func(t *testing.T) {
1000 knownProviders := []provider.Provider{
1001 {
1002 ID: "openai",
1003 APIKey: "abc",
1004 DefaultLargeModelID: "large-model",
1005 DefaultSmallModelID: "small-model",
1006 Models: []provider.Model{
1007 {
1008 ID: "larger-model",
1009 DefaultMaxTokens: 2000,
1010 },
1011 {
1012 ID: "large-model",
1013 DefaultMaxTokens: 1000,
1014 },
1015 {
1016 ID: "small-model",
1017 DefaultMaxTokens: 500,
1018 },
1019 },
1020 },
1021 }
1022
1023 cfg := &Config{
1024 Models: map[SelectedModelType]SelectedModel{
1025 "large": {
1026 Model: "larger-model",
1027 },
1028 },
1029 }
1030 cfg.setDefaults("/tmp")
1031 env := env.NewFromMap(map[string]string{})
1032 resolver := NewEnvironmentVariableResolver(env)
1033 err := cfg.configureProviders(env, resolver, knownProviders)
1034 assert.NoError(t, err)
1035
1036 err = cfg.configureSelectedModels(knownProviders)
1037 assert.NoError(t, err)
1038 large := cfg.Models[SelectedModelTypeLarge]
1039 small := cfg.Models[SelectedModelTypeSmall]
1040 assert.Equal(t, "larger-model", large.Model)
1041 assert.Equal(t, "openai", large.Provider)
1042 assert.Equal(t, int64(2000), large.MaxTokens)
1043 assert.Equal(t, "small-model", small.Model)
1044 assert.Equal(t, "openai", small.Provider)
1045 assert.Equal(t, int64(500), small.MaxTokens)
1046 })
1047 t.Run("should be possible to use multiple providers", func(t *testing.T) {
1048 knownProviders := []provider.Provider{
1049 {
1050 ID: "openai",
1051 APIKey: "abc",
1052 DefaultLargeModelID: "large-model",
1053 DefaultSmallModelID: "small-model",
1054 Models: []provider.Model{
1055 {
1056 ID: "large-model",
1057 DefaultMaxTokens: 1000,
1058 },
1059 {
1060 ID: "small-model",
1061 DefaultMaxTokens: 500,
1062 },
1063 },
1064 },
1065 {
1066 ID: "anthropic",
1067 APIKey: "abc",
1068 DefaultLargeModelID: "a-large-model",
1069 DefaultSmallModelID: "a-small-model",
1070 Models: []provider.Model{
1071 {
1072 ID: "a-large-model",
1073 DefaultMaxTokens: 1000,
1074 },
1075 {
1076 ID: "a-small-model",
1077 DefaultMaxTokens: 200,
1078 },
1079 },
1080 },
1081 }
1082
1083 cfg := &Config{
1084 Models: map[SelectedModelType]SelectedModel{
1085 "small": {
1086 Model: "a-small-model",
1087 Provider: "anthropic",
1088 MaxTokens: 300,
1089 },
1090 },
1091 }
1092 cfg.setDefaults("/tmp")
1093 env := env.NewFromMap(map[string]string{})
1094 resolver := NewEnvironmentVariableResolver(env)
1095 err := cfg.configureProviders(env, resolver, knownProviders)
1096 assert.NoError(t, err)
1097
1098 err = cfg.configureSelectedModels(knownProviders)
1099 assert.NoError(t, err)
1100 large := cfg.Models[SelectedModelTypeLarge]
1101 small := cfg.Models[SelectedModelTypeSmall]
1102 assert.Equal(t, "large-model", large.Model)
1103 assert.Equal(t, "openai", large.Provider)
1104 assert.Equal(t, int64(1000), large.MaxTokens)
1105 assert.Equal(t, "a-small-model", small.Model)
1106 assert.Equal(t, "anthropic", small.Provider)
1107 assert.Equal(t, int64(300), small.MaxTokens)
1108 })
1109
1110 t.Run("should override the max tokens only", func(t *testing.T) {
1111 knownProviders := []provider.Provider{
1112 {
1113 ID: "openai",
1114 APIKey: "abc",
1115 DefaultLargeModelID: "large-model",
1116 DefaultSmallModelID: "small-model",
1117 Models: []provider.Model{
1118 {
1119 ID: "large-model",
1120 DefaultMaxTokens: 1000,
1121 },
1122 {
1123 ID: "small-model",
1124 DefaultMaxTokens: 500,
1125 },
1126 },
1127 },
1128 }
1129
1130 cfg := &Config{
1131 Models: map[SelectedModelType]SelectedModel{
1132 "large": {
1133 MaxTokens: 100,
1134 },
1135 },
1136 }
1137 cfg.setDefaults("/tmp")
1138 env := env.NewFromMap(map[string]string{})
1139 resolver := NewEnvironmentVariableResolver(env)
1140 err := cfg.configureProviders(env, resolver, knownProviders)
1141 assert.NoError(t, err)
1142
1143 err = cfg.configureSelectedModels(knownProviders)
1144 assert.NoError(t, err)
1145 large := cfg.Models[SelectedModelTypeLarge]
1146 assert.Equal(t, "large-model", large.Model)
1147 assert.Equal(t, "openai", large.Provider)
1148 assert.Equal(t, int64(100), large.MaxTokens)
1149 })
1150}