1package config
   2
   3import (
   4	"io"
   5	"log/slog"
   6	"os"
   7	"path/filepath"
   8	"strings"
   9	"testing"
  10
  11	"github.com/charmbracelet/catwalk/pkg/catwalk"
  12	"github.com/charmbracelet/crush/internal/csync"
  13	"github.com/charmbracelet/crush/internal/env"
  14	"github.com/stretchr/testify/require"
  15)
  16
  17func TestMain(m *testing.M) {
  18	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
  19
  20	exitVal := m.Run()
  21	os.Exit(exitVal)
  22}
  23
  24func TestConfig_LoadFromReaders(t *testing.T) {
  25	data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
  26	data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
  27	data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
  28
  29	loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
  30
  31	require.NoError(t, err)
  32	require.NotNil(t, loadedConfig)
  33	require.Equal(t, 1, loadedConfig.Providers.Len())
  34	pc, _ := loadedConfig.Providers.Get("openai")
  35	require.Equal(t, "key2", pc.APIKey)
  36	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  37}
  38
  39func TestConfig_setDefaults(t *testing.T) {
  40	cfg := &Config{}
  41
  42	cfg.setDefaults("/tmp")
  43
  44	require.NotNil(t, cfg.Options)
  45	require.NotNil(t, cfg.Options.TUI)
  46	require.NotNil(t, cfg.Options.ContextPaths)
  47	require.NotNil(t, cfg.Providers)
  48	require.NotNil(t, cfg.Models)
  49	require.NotNil(t, cfg.LSP)
  50	require.NotNil(t, cfg.MCP)
  51	require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  52	for _, path := range defaultContextPaths {
  53		require.Contains(t, cfg.Options.ContextPaths, path)
  54	}
  55	require.Equal(t, "/tmp", cfg.workingDir)
  56}
  57
  58func TestConfig_configureProviders(t *testing.T) {
  59	knownProviders := []catwalk.Provider{
  60		{
  61			ID:          "openai",
  62			APIKey:      "$OPENAI_API_KEY",
  63			APIEndpoint: "https://api.openai.com/v1",
  64			Models: []catwalk.Model{{
  65				ID: "test-model",
  66			}},
  67		},
  68	}
  69
  70	cfg := &Config{}
  71	cfg.setDefaults("/tmp")
  72	env := env.NewFromMap(map[string]string{
  73		"OPENAI_API_KEY": "test-key",
  74	})
  75	resolver := NewEnvironmentVariableResolver(env)
  76	err := cfg.configureProviders(env, resolver, knownProviders)
  77	require.NoError(t, err)
  78	require.Equal(t, 1, cfg.Providers.Len())
  79
  80	// We want to make sure that we keep the configured API key as a placeholder
  81	pc, _ := cfg.Providers.Get("openai")
  82	require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  83}
  84
  85func TestConfig_configureProvidersWithOverride(t *testing.T) {
  86	knownProviders := []catwalk.Provider{
  87		{
  88			ID:          "openai",
  89			APIKey:      "$OPENAI_API_KEY",
  90			APIEndpoint: "https://api.openai.com/v1",
  91			Models: []catwalk.Model{{
  92				ID: "test-model",
  93			}},
  94		},
  95	}
  96
  97	cfg := &Config{
  98		Providers: csync.NewMap[string, ProviderConfig](),
  99	}
 100	cfg.Providers.Set("openai", ProviderConfig{
 101		APIKey:  "xyz",
 102		BaseURL: "https://api.openai.com/v2",
 103		Models: []catwalk.Model{
 104			{
 105				ID:   "test-model",
 106				Name: "Updated",
 107			},
 108			{
 109				ID: "another-model",
 110			},
 111		},
 112	})
 113	cfg.setDefaults("/tmp")
 114
 115	env := env.NewFromMap(map[string]string{
 116		"OPENAI_API_KEY": "test-key",
 117	})
 118	resolver := NewEnvironmentVariableResolver(env)
 119	err := cfg.configureProviders(env, resolver, knownProviders)
 120	require.NoError(t, err)
 121	require.Equal(t, 1, cfg.Providers.Len())
 122
 123	// We want to make sure that we keep the configured API key as a placeholder
 124	pc, _ := cfg.Providers.Get("openai")
 125	require.Equal(t, "xyz", pc.APIKey)
 126	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
 127	require.Len(t, pc.Models, 2)
 128	require.Equal(t, "Updated", pc.Models[0].Name)
 129}
 130
 131func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
 132	knownProviders := []catwalk.Provider{
 133		{
 134			ID:          "openai",
 135			APIKey:      "$OPENAI_API_KEY",
 136			APIEndpoint: "https://api.openai.com/v1",
 137			Models: []catwalk.Model{{
 138				ID: "test-model",
 139			}},
 140		},
 141	}
 142
 143	cfg := &Config{
 144		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 145			"custom": {
 146				APIKey:  "xyz",
 147				BaseURL: "https://api.someendpoint.com/v2",
 148				Models: []catwalk.Model{
 149					{
 150						ID: "test-model",
 151					},
 152				},
 153			},
 154		}),
 155	}
 156	cfg.setDefaults("/tmp")
 157	env := env.NewFromMap(map[string]string{
 158		"OPENAI_API_KEY": "test-key",
 159	})
 160	resolver := NewEnvironmentVariableResolver(env)
 161	err := cfg.configureProviders(env, resolver, knownProviders)
 162	require.NoError(t, err)
 163	// Should be to because of the env variable
 164	require.Equal(t, cfg.Providers.Len(), 2)
 165
 166	// We want to make sure that we keep the configured API key as a placeholder
 167	pc, _ := cfg.Providers.Get("custom")
 168	require.Equal(t, "xyz", pc.APIKey)
 169	// Make sure we set the ID correctly
 170	require.Equal(t, "custom", pc.ID)
 171	require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
 172	require.Len(t, pc.Models, 1)
 173
 174	_, ok := cfg.Providers.Get("openai")
 175	require.True(t, ok, "OpenAI provider should still be present")
 176}
 177
 178func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
 179	knownProviders := []catwalk.Provider{
 180		{
 181			ID:          catwalk.InferenceProviderBedrock,
 182			APIKey:      "",
 183			APIEndpoint: "",
 184			Models: []catwalk.Model{{
 185				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 186			}},
 187		},
 188	}
 189
 190	cfg := &Config{}
 191	cfg.setDefaults("/tmp")
 192	env := env.NewFromMap(map[string]string{
 193		"AWS_ACCESS_KEY_ID":     "test-key-id",
 194		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 195	})
 196	resolver := NewEnvironmentVariableResolver(env)
 197	err := cfg.configureProviders(env, resolver, knownProviders)
 198	require.NoError(t, err)
 199	require.Equal(t, cfg.Providers.Len(), 1)
 200
 201	bedrockProvider, ok := cfg.Providers.Get("bedrock")
 202	require.True(t, ok, "Bedrock provider should be present")
 203	require.Len(t, bedrockProvider.Models, 1)
 204	require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
 205}
 206
 207func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
 208	knownProviders := []catwalk.Provider{
 209		{
 210			ID:          catwalk.InferenceProviderBedrock,
 211			APIKey:      "",
 212			APIEndpoint: "",
 213			Models: []catwalk.Model{{
 214				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 215			}},
 216		},
 217	}
 218
 219	cfg := &Config{}
 220	cfg.setDefaults("/tmp")
 221	env := env.NewFromMap(map[string]string{})
 222	resolver := NewEnvironmentVariableResolver(env)
 223	err := cfg.configureProviders(env, resolver, knownProviders)
 224	require.NoError(t, err)
 225	// Provider should not be configured without credentials
 226	require.Equal(t, cfg.Providers.Len(), 0)
 227}
 228
 229func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
 230	knownProviders := []catwalk.Provider{
 231		{
 232			ID:          catwalk.InferenceProviderBedrock,
 233			APIKey:      "",
 234			APIEndpoint: "",
 235			Models: []catwalk.Model{{
 236				ID: "some-random-model",
 237			}},
 238		},
 239	}
 240
 241	cfg := &Config{}
 242	cfg.setDefaults("/tmp")
 243	env := env.NewFromMap(map[string]string{
 244		"AWS_ACCESS_KEY_ID":     "test-key-id",
 245		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 246	})
 247	resolver := NewEnvironmentVariableResolver(env)
 248	err := cfg.configureProviders(env, resolver, knownProviders)
 249	require.Error(t, err)
 250}
 251
 252func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
 253	knownProviders := []catwalk.Provider{
 254		{
 255			ID:          catwalk.InferenceProviderVertexAI,
 256			APIKey:      "",
 257			APIEndpoint: "",
 258			Models: []catwalk.Model{{
 259				ID: "gemini-pro",
 260			}},
 261		},
 262	}
 263
 264	cfg := &Config{}
 265	cfg.setDefaults("/tmp")
 266	env := env.NewFromMap(map[string]string{
 267		"VERTEXAI_PROJECT":  "test-project",
 268		"VERTEXAI_LOCATION": "us-central1",
 269	})
 270	resolver := NewEnvironmentVariableResolver(env)
 271	err := cfg.configureProviders(env, resolver, knownProviders)
 272	require.NoError(t, err)
 273	require.Equal(t, cfg.Providers.Len(), 1)
 274
 275	vertexProvider, ok := cfg.Providers.Get("vertexai")
 276	require.True(t, ok, "VertexAI provider should be present")
 277	require.Len(t, vertexProvider.Models, 1)
 278	require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
 279	require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
 280	require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
 281}
 282
 283func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
 284	knownProviders := []catwalk.Provider{
 285		{
 286			ID:          catwalk.InferenceProviderVertexAI,
 287			APIKey:      "",
 288			APIEndpoint: "",
 289			Models: []catwalk.Model{{
 290				ID: "gemini-pro",
 291			}},
 292		},
 293	}
 294
 295	cfg := &Config{}
 296	cfg.setDefaults("/tmp")
 297	env := env.NewFromMap(map[string]string{
 298		"GOOGLE_GENAI_USE_VERTEXAI": "false",
 299		"GOOGLE_CLOUD_PROJECT":      "test-project",
 300		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 301	})
 302	resolver := NewEnvironmentVariableResolver(env)
 303	err := cfg.configureProviders(env, resolver, knownProviders)
 304	require.NoError(t, err)
 305	// Provider should not be configured without proper credentials
 306	require.Equal(t, cfg.Providers.Len(), 0)
 307}
 308
 309func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
 310	knownProviders := []catwalk.Provider{
 311		{
 312			ID:          catwalk.InferenceProviderVertexAI,
 313			APIKey:      "",
 314			APIEndpoint: "",
 315			Models: []catwalk.Model{{
 316				ID: "gemini-pro",
 317			}},
 318		},
 319	}
 320
 321	cfg := &Config{}
 322	cfg.setDefaults("/tmp")
 323	env := env.NewFromMap(map[string]string{
 324		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 325		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 326	})
 327	resolver := NewEnvironmentVariableResolver(env)
 328	err := cfg.configureProviders(env, resolver, knownProviders)
 329	require.NoError(t, err)
 330	// Provider should not be configured without project
 331	require.Equal(t, cfg.Providers.Len(), 0)
 332}
 333
 334func TestConfig_configureProvidersSetProviderID(t *testing.T) {
 335	knownProviders := []catwalk.Provider{
 336		{
 337			ID:          "openai",
 338			APIKey:      "$OPENAI_API_KEY",
 339			APIEndpoint: "https://api.openai.com/v1",
 340			Models: []catwalk.Model{{
 341				ID: "test-model",
 342			}},
 343		},
 344	}
 345
 346	cfg := &Config{}
 347	cfg.setDefaults("/tmp")
 348	env := env.NewFromMap(map[string]string{
 349		"OPENAI_API_KEY": "test-key",
 350	})
 351	resolver := NewEnvironmentVariableResolver(env)
 352	err := cfg.configureProviders(env, resolver, knownProviders)
 353	require.NoError(t, err)
 354	require.Equal(t, cfg.Providers.Len(), 1)
 355
 356	// Provider ID should be set
 357	pc, _ := cfg.Providers.Get("openai")
 358	require.Equal(t, "openai", pc.ID)
 359}
 360
 361func TestConfig_EnabledProviders(t *testing.T) {
 362	t.Run("all providers enabled", func(t *testing.T) {
 363		cfg := &Config{
 364			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 365				"openai": {
 366					ID:      "openai",
 367					APIKey:  "key1",
 368					Disable: false,
 369				},
 370				"anthropic": {
 371					ID:      "anthropic",
 372					APIKey:  "key2",
 373					Disable: false,
 374				},
 375			}),
 376		}
 377
 378		enabled := cfg.EnabledProviders()
 379		require.Len(t, enabled, 2)
 380	})
 381
 382	t.Run("some providers disabled", func(t *testing.T) {
 383		cfg := &Config{
 384			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 385				"openai": {
 386					ID:      "openai",
 387					APIKey:  "key1",
 388					Disable: false,
 389				},
 390				"anthropic": {
 391					ID:      "anthropic",
 392					APIKey:  "key2",
 393					Disable: true,
 394				},
 395			}),
 396		}
 397
 398		enabled := cfg.EnabledProviders()
 399		require.Len(t, enabled, 1)
 400		require.Equal(t, "openai", enabled[0].ID)
 401	})
 402
 403	t.Run("empty providers map", func(t *testing.T) {
 404		cfg := &Config{
 405			Providers: csync.NewMap[string, ProviderConfig](),
 406		}
 407
 408		enabled := cfg.EnabledProviders()
 409		require.Len(t, enabled, 0)
 410	})
 411}
 412
 413func TestConfig_IsConfigured(t *testing.T) {
 414	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
 415		cfg := &Config{
 416			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 417				"openai": {
 418					ID:      "openai",
 419					APIKey:  "key1",
 420					Disable: false,
 421				},
 422			}),
 423		}
 424
 425		require.True(t, cfg.IsConfigured())
 426	})
 427
 428	t.Run("returns false when no providers are configured", func(t *testing.T) {
 429		cfg := &Config{
 430			Providers: csync.NewMap[string, ProviderConfig](),
 431		}
 432
 433		require.False(t, cfg.IsConfigured())
 434	})
 435
 436	t.Run("returns false when all providers are disabled", func(t *testing.T) {
 437		cfg := &Config{
 438			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 439				"openai": {
 440					ID:      "openai",
 441					APIKey:  "key1",
 442					Disable: true,
 443				},
 444				"anthropic": {
 445					ID:      "anthropic",
 446					APIKey:  "key2",
 447					Disable: true,
 448				},
 449			}),
 450		}
 451
 452		require.False(t, cfg.IsConfigured())
 453	})
 454}
 455
 456func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
 457	knownProviders := []catwalk.Provider{
 458		{
 459			ID:          "openai",
 460			APIKey:      "$OPENAI_API_KEY",
 461			APIEndpoint: "https://api.openai.com/v1",
 462			Models: []catwalk.Model{{
 463				ID: "test-model",
 464			}},
 465		},
 466	}
 467
 468	cfg := &Config{
 469		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 470			"openai": {
 471				Disable: true,
 472			},
 473		}),
 474	}
 475	cfg.setDefaults("/tmp")
 476
 477	env := env.NewFromMap(map[string]string{
 478		"OPENAI_API_KEY": "test-key",
 479	})
 480	resolver := NewEnvironmentVariableResolver(env)
 481	err := cfg.configureProviders(env, resolver, knownProviders)
 482	require.NoError(t, err)
 483
 484	// Provider should be removed from config when disabled
 485	require.Equal(t, cfg.Providers.Len(), 0)
 486	_, exists := cfg.Providers.Get("openai")
 487	require.False(t, exists)
 488}
 489
 490func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 491	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
 492		cfg := &Config{
 493			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 494				"custom": {
 495					BaseURL: "https://api.custom.com/v1",
 496					Models: []catwalk.Model{{
 497						ID: "test-model",
 498					}},
 499				},
 500				"openai": {
 501					APIKey: "$MISSING",
 502				},
 503			}),
 504		}
 505		cfg.setDefaults("/tmp")
 506
 507		env := env.NewFromMap(map[string]string{})
 508		resolver := NewEnvironmentVariableResolver(env)
 509		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 510		require.NoError(t, err)
 511
 512		require.Equal(t, cfg.Providers.Len(), 1)
 513		_, exists := cfg.Providers.Get("custom")
 514		require.True(t, exists)
 515	})
 516
 517	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
 518		cfg := &Config{
 519			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 520				"custom": {
 521					APIKey: "test-key",
 522					Models: []catwalk.Model{{
 523						ID: "test-model",
 524					}},
 525				},
 526			}),
 527		}
 528		cfg.setDefaults("/tmp")
 529
 530		env := env.NewFromMap(map[string]string{})
 531		resolver := NewEnvironmentVariableResolver(env)
 532		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 533		require.NoError(t, err)
 534
 535		require.Equal(t, cfg.Providers.Len(), 0)
 536		_, exists := cfg.Providers.Get("custom")
 537		require.False(t, exists)
 538	})
 539
 540	t.Run("custom provider with no models is removed", func(t *testing.T) {
 541		cfg := &Config{
 542			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 543				"custom": {
 544					APIKey:  "test-key",
 545					BaseURL: "https://api.custom.com/v1",
 546					Models:  []catwalk.Model{},
 547				},
 548			}),
 549		}
 550		cfg.setDefaults("/tmp")
 551
 552		env := env.NewFromMap(map[string]string{})
 553		resolver := NewEnvironmentVariableResolver(env)
 554		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 555		require.NoError(t, err)
 556
 557		require.Equal(t, cfg.Providers.Len(), 0)
 558		_, exists := cfg.Providers.Get("custom")
 559		require.False(t, exists)
 560	})
 561
 562	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
 563		cfg := &Config{
 564			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 565				"custom": {
 566					APIKey:  "test-key",
 567					BaseURL: "https://api.custom.com/v1",
 568					Type:    "unsupported",
 569					Models: []catwalk.Model{{
 570						ID: "test-model",
 571					}},
 572				},
 573			}),
 574		}
 575		cfg.setDefaults("/tmp")
 576
 577		env := env.NewFromMap(map[string]string{})
 578		resolver := NewEnvironmentVariableResolver(env)
 579		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 580		require.NoError(t, err)
 581
 582		require.Equal(t, cfg.Providers.Len(), 0)
 583		_, exists := cfg.Providers.Get("custom")
 584		require.False(t, exists)
 585	})
 586
 587	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
 588		cfg := &Config{
 589			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 590				"custom": {
 591					APIKey:  "test-key",
 592					BaseURL: "https://api.custom.com/v1",
 593					Type:    catwalk.TypeOpenAI,
 594					Models: []catwalk.Model{{
 595						ID: "test-model",
 596					}},
 597				},
 598			}),
 599		}
 600		cfg.setDefaults("/tmp")
 601
 602		env := env.NewFromMap(map[string]string{})
 603		resolver := NewEnvironmentVariableResolver(env)
 604		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 605		require.NoError(t, err)
 606
 607		require.Equal(t, cfg.Providers.Len(), 1)
 608		customProvider, exists := cfg.Providers.Get("custom")
 609		require.True(t, exists)
 610		require.Equal(t, "custom", customProvider.ID)
 611		require.Equal(t, "test-key", customProvider.APIKey)
 612		require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
 613	})
 614
 615	t.Run("custom anthropic provider is supported", func(t *testing.T) {
 616		cfg := &Config{
 617			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 618				"custom-anthropic": {
 619					APIKey:  "test-key",
 620					BaseURL: "https://api.anthropic.com/v1",
 621					Type:    catwalk.TypeAnthropic,
 622					Models: []catwalk.Model{{
 623						ID: "claude-3-sonnet",
 624					}},
 625				},
 626			}),
 627		}
 628		cfg.setDefaults("/tmp")
 629
 630		env := env.NewFromMap(map[string]string{})
 631		resolver := NewEnvironmentVariableResolver(env)
 632		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 633		require.NoError(t, err)
 634
 635		require.Equal(t, cfg.Providers.Len(), 1)
 636		customProvider, exists := cfg.Providers.Get("custom-anthropic")
 637		require.True(t, exists)
 638		require.Equal(t, "custom-anthropic", customProvider.ID)
 639		require.Equal(t, "test-key", customProvider.APIKey)
 640		require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
 641		require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
 642	})
 643
 644	t.Run("disabled custom provider is removed", func(t *testing.T) {
 645		cfg := &Config{
 646			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 647				"custom": {
 648					APIKey:  "test-key",
 649					BaseURL: "https://api.custom.com/v1",
 650					Type:    catwalk.TypeOpenAI,
 651					Disable: true,
 652					Models: []catwalk.Model{{
 653						ID: "test-model",
 654					}},
 655				},
 656			}),
 657		}
 658		cfg.setDefaults("/tmp")
 659
 660		env := env.NewFromMap(map[string]string{})
 661		resolver := NewEnvironmentVariableResolver(env)
 662		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 663		require.NoError(t, err)
 664
 665		require.Equal(t, cfg.Providers.Len(), 0)
 666		_, exists := cfg.Providers.Get("custom")
 667		require.False(t, exists)
 668	})
 669}
 670
 671func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 672	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
 673		knownProviders := []catwalk.Provider{
 674			{
 675				ID:          catwalk.InferenceProviderVertexAI,
 676				APIKey:      "",
 677				APIEndpoint: "",
 678				Models: []catwalk.Model{{
 679					ID: "gemini-pro",
 680				}},
 681			},
 682		}
 683
 684		cfg := &Config{
 685			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 686				"vertexai": {
 687					BaseURL: "custom-url",
 688				},
 689			}),
 690		}
 691		cfg.setDefaults("/tmp")
 692
 693		env := env.NewFromMap(map[string]string{
 694			"GOOGLE_GENAI_USE_VERTEXAI": "false",
 695		})
 696		resolver := NewEnvironmentVariableResolver(env)
 697		err := cfg.configureProviders(env, resolver, knownProviders)
 698		require.NoError(t, err)
 699
 700		require.Equal(t, cfg.Providers.Len(), 0)
 701		_, exists := cfg.Providers.Get("vertexai")
 702		require.False(t, exists)
 703	})
 704
 705	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
 706		knownProviders := []catwalk.Provider{
 707			{
 708				ID:          catwalk.InferenceProviderBedrock,
 709				APIKey:      "",
 710				APIEndpoint: "",
 711				Models: []catwalk.Model{{
 712					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 713				}},
 714			},
 715		}
 716
 717		cfg := &Config{
 718			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 719				"bedrock": {
 720					BaseURL: "custom-url",
 721				},
 722			}),
 723		}
 724		cfg.setDefaults("/tmp")
 725
 726		env := env.NewFromMap(map[string]string{})
 727		resolver := NewEnvironmentVariableResolver(env)
 728		err := cfg.configureProviders(env, resolver, knownProviders)
 729		require.NoError(t, err)
 730
 731		require.Equal(t, cfg.Providers.Len(), 0)
 732		_, exists := cfg.Providers.Get("bedrock")
 733		require.False(t, exists)
 734	})
 735
 736	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
 737		knownProviders := []catwalk.Provider{
 738			{
 739				ID:          "openai",
 740				APIKey:      "$MISSING_API_KEY",
 741				APIEndpoint: "https://api.openai.com/v1",
 742				Models: []catwalk.Model{{
 743					ID: "test-model",
 744				}},
 745			},
 746		}
 747
 748		cfg := &Config{
 749			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 750				"openai": {
 751					BaseURL: "custom-url",
 752				},
 753			}),
 754		}
 755		cfg.setDefaults("/tmp")
 756
 757		env := env.NewFromMap(map[string]string{})
 758		resolver := NewEnvironmentVariableResolver(env)
 759		err := cfg.configureProviders(env, resolver, knownProviders)
 760		require.NoError(t, err)
 761
 762		require.Equal(t, cfg.Providers.Len(), 0)
 763		_, exists := cfg.Providers.Get("openai")
 764		require.False(t, exists)
 765	})
 766
 767	t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
 768		knownProviders := []catwalk.Provider{
 769			{
 770				ID:          "openai",
 771				APIKey:      "$OPENAI_API_KEY",
 772				APIEndpoint: "$MISSING_ENDPOINT",
 773				Models: []catwalk.Model{{
 774					ID: "test-model",
 775				}},
 776			},
 777		}
 778
 779		cfg := &Config{
 780			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 781				"openai": {
 782					APIKey: "test-key",
 783				},
 784			}),
 785		}
 786		cfg.setDefaults("/tmp")
 787
 788		env := env.NewFromMap(map[string]string{
 789			"OPENAI_API_KEY": "test-key",
 790		})
 791		resolver := NewEnvironmentVariableResolver(env)
 792		err := cfg.configureProviders(env, resolver, knownProviders)
 793		require.NoError(t, err)
 794
 795		require.Equal(t, cfg.Providers.Len(), 1)
 796		_, exists := cfg.Providers.Get("openai")
 797		require.True(t, exists)
 798	})
 799}
 800
 801func TestConfig_defaultModelSelection(t *testing.T) {
 802	t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
 803		knownProviders := []catwalk.Provider{
 804			{
 805				ID:                  "openai",
 806				APIKey:              "abc",
 807				DefaultLargeModelID: "large-model",
 808				DefaultSmallModelID: "small-model",
 809				Models: []catwalk.Model{
 810					{
 811						ID:               "large-model",
 812						DefaultMaxTokens: 1000,
 813					},
 814					{
 815						ID:               "small-model",
 816						DefaultMaxTokens: 500,
 817					},
 818				},
 819			},
 820		}
 821
 822		cfg := &Config{}
 823		cfg.setDefaults("/tmp")
 824		env := env.NewFromMap(map[string]string{})
 825		resolver := NewEnvironmentVariableResolver(env)
 826		err := cfg.configureProviders(env, resolver, knownProviders)
 827		require.NoError(t, err)
 828
 829		large, small, err := cfg.defaultModelSelection(knownProviders)
 830		require.NoError(t, err)
 831		require.Equal(t, "large-model", large.Model)
 832		require.Equal(t, "openai", large.Provider)
 833		require.Equal(t, int64(1000), large.MaxTokens)
 834		require.Equal(t, "small-model", small.Model)
 835		require.Equal(t, "openai", small.Provider)
 836		require.Equal(t, int64(500), small.MaxTokens)
 837	})
 838	t.Run("should error if no providers configured", func(t *testing.T) {
 839		knownProviders := []catwalk.Provider{
 840			{
 841				ID:                  "openai",
 842				APIKey:              "$MISSING_KEY",
 843				DefaultLargeModelID: "large-model",
 844				DefaultSmallModelID: "small-model",
 845				Models: []catwalk.Model{
 846					{
 847						ID:               "large-model",
 848						DefaultMaxTokens: 1000,
 849					},
 850					{
 851						ID:               "small-model",
 852						DefaultMaxTokens: 500,
 853					},
 854				},
 855			},
 856		}
 857
 858		cfg := &Config{}
 859		cfg.setDefaults("/tmp")
 860		env := env.NewFromMap(map[string]string{})
 861		resolver := NewEnvironmentVariableResolver(env)
 862		err := cfg.configureProviders(env, resolver, knownProviders)
 863		require.NoError(t, err)
 864
 865		_, _, err = cfg.defaultModelSelection(knownProviders)
 866		require.Error(t, err)
 867	})
 868	t.Run("should error if model is missing", func(t *testing.T) {
 869		knownProviders := []catwalk.Provider{
 870			{
 871				ID:                  "openai",
 872				APIKey:              "abc",
 873				DefaultLargeModelID: "large-model",
 874				DefaultSmallModelID: "small-model",
 875				Models: []catwalk.Model{
 876					{
 877						ID:               "not-large-model",
 878						DefaultMaxTokens: 1000,
 879					},
 880					{
 881						ID:               "small-model",
 882						DefaultMaxTokens: 500,
 883					},
 884				},
 885			},
 886		}
 887
 888		cfg := &Config{}
 889		cfg.setDefaults("/tmp")
 890		env := env.NewFromMap(map[string]string{})
 891		resolver := NewEnvironmentVariableResolver(env)
 892		err := cfg.configureProviders(env, resolver, knownProviders)
 893		require.NoError(t, err)
 894		_, _, err = cfg.defaultModelSelection(knownProviders)
 895		require.Error(t, err)
 896	})
 897
 898	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
 899		knownProviders := []catwalk.Provider{
 900			{
 901				ID:                  "openai",
 902				APIKey:              "$MISSING", // will not be included in the config
 903				DefaultLargeModelID: "large-model",
 904				DefaultSmallModelID: "small-model",
 905				Models: []catwalk.Model{
 906					{
 907						ID:               "not-large-model",
 908						DefaultMaxTokens: 1000,
 909					},
 910					{
 911						ID:               "small-model",
 912						DefaultMaxTokens: 500,
 913					},
 914				},
 915			},
 916		}
 917
 918		cfg := &Config{
 919			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 920				"custom": {
 921					APIKey:  "test-key",
 922					BaseURL: "https://api.custom.com/v1",
 923					Models: []catwalk.Model{
 924						{
 925							ID:               "model",
 926							DefaultMaxTokens: 600,
 927						},
 928					},
 929				},
 930			}),
 931		}
 932		cfg.setDefaults("/tmp")
 933		env := env.NewFromMap(map[string]string{})
 934		resolver := NewEnvironmentVariableResolver(env)
 935		err := cfg.configureProviders(env, resolver, knownProviders)
 936		require.NoError(t, err)
 937		large, small, err := cfg.defaultModelSelection(knownProviders)
 938		require.NoError(t, err)
 939		require.Equal(t, "model", large.Model)
 940		require.Equal(t, "custom", large.Provider)
 941		require.Equal(t, int64(600), large.MaxTokens)
 942		require.Equal(t, "model", small.Model)
 943		require.Equal(t, "custom", small.Provider)
 944		require.Equal(t, int64(600), small.MaxTokens)
 945	})
 946
 947	t.Run("should fail if no model configured", func(t *testing.T) {
 948		knownProviders := []catwalk.Provider{
 949			{
 950				ID:                  "openai",
 951				APIKey:              "$MISSING", // will not be included in the config
 952				DefaultLargeModelID: "large-model",
 953				DefaultSmallModelID: "small-model",
 954				Models: []catwalk.Model{
 955					{
 956						ID:               "not-large-model",
 957						DefaultMaxTokens: 1000,
 958					},
 959					{
 960						ID:               "small-model",
 961						DefaultMaxTokens: 500,
 962					},
 963				},
 964			},
 965		}
 966
 967		cfg := &Config{
 968			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 969				"custom": {
 970					APIKey:  "test-key",
 971					BaseURL: "https://api.custom.com/v1",
 972					Models:  []catwalk.Model{},
 973				},
 974			}),
 975		}
 976		cfg.setDefaults("/tmp")
 977		env := env.NewFromMap(map[string]string{})
 978		resolver := NewEnvironmentVariableResolver(env)
 979		err := cfg.configureProviders(env, resolver, knownProviders)
 980		require.NoError(t, err)
 981		_, _, err = cfg.defaultModelSelection(knownProviders)
 982		require.Error(t, err)
 983	})
 984	t.Run("should use the default provider first", func(t *testing.T) {
 985		knownProviders := []catwalk.Provider{
 986			{
 987				ID:                  "openai",
 988				APIKey:              "set",
 989				DefaultLargeModelID: "large-model",
 990				DefaultSmallModelID: "small-model",
 991				Models: []catwalk.Model{
 992					{
 993						ID:               "large-model",
 994						DefaultMaxTokens: 1000,
 995					},
 996					{
 997						ID:               "small-model",
 998						DefaultMaxTokens: 500,
 999					},
1000				},
1001			},
1002		}
1003
1004		cfg := &Config{
1005			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1006				"custom": {
1007					APIKey:  "test-key",
1008					BaseURL: "https://api.custom.com/v1",
1009					Models: []catwalk.Model{
1010						{
1011							ID:               "large-model",
1012							DefaultMaxTokens: 1000,
1013						},
1014					},
1015				},
1016			}),
1017		}
1018		cfg.setDefaults("/tmp")
1019		env := env.NewFromMap(map[string]string{})
1020		resolver := NewEnvironmentVariableResolver(env)
1021		err := cfg.configureProviders(env, resolver, knownProviders)
1022		require.NoError(t, err)
1023		large, small, err := cfg.defaultModelSelection(knownProviders)
1024		require.NoError(t, err)
1025		require.Equal(t, "large-model", large.Model)
1026		require.Equal(t, "openai", large.Provider)
1027		require.Equal(t, int64(1000), large.MaxTokens)
1028		require.Equal(t, "small-model", small.Model)
1029		require.Equal(t, "openai", small.Provider)
1030		require.Equal(t, int64(500), small.MaxTokens)
1031	})
1032}
1033
1034func TestConfig_configureSelectedModels(t *testing.T) {
1035	t.Run("should override defaults", func(t *testing.T) {
1036		knownProviders := []catwalk.Provider{
1037			{
1038				ID:                  "openai",
1039				APIKey:              "abc",
1040				DefaultLargeModelID: "large-model",
1041				DefaultSmallModelID: "small-model",
1042				Models: []catwalk.Model{
1043					{
1044						ID:               "larger-model",
1045						DefaultMaxTokens: 2000,
1046					},
1047					{
1048						ID:               "large-model",
1049						DefaultMaxTokens: 1000,
1050					},
1051					{
1052						ID:               "small-model",
1053						DefaultMaxTokens: 500,
1054					},
1055				},
1056			},
1057		}
1058
1059		cfg := &Config{
1060			Models: map[SelectedModelType]SelectedModel{
1061				"large": {
1062					Model: "larger-model",
1063				},
1064			},
1065		}
1066		cfg.setDefaults("/tmp")
1067		env := env.NewFromMap(map[string]string{})
1068		resolver := NewEnvironmentVariableResolver(env)
1069		err := cfg.configureProviders(env, resolver, knownProviders)
1070		require.NoError(t, err)
1071
1072		err = cfg.configureSelectedModels(knownProviders)
1073		require.NoError(t, err)
1074		large := cfg.Models[SelectedModelTypeLarge]
1075		small := cfg.Models[SelectedModelTypeSmall]
1076		require.Equal(t, "larger-model", large.Model)
1077		require.Equal(t, "openai", large.Provider)
1078		require.Equal(t, int64(2000), large.MaxTokens)
1079		require.Equal(t, "small-model", small.Model)
1080		require.Equal(t, "openai", small.Provider)
1081		require.Equal(t, int64(500), small.MaxTokens)
1082	})
1083	t.Run("should be possible to use multiple providers", func(t *testing.T) {
1084		knownProviders := []catwalk.Provider{
1085			{
1086				ID:                  "openai",
1087				APIKey:              "abc",
1088				DefaultLargeModelID: "large-model",
1089				DefaultSmallModelID: "small-model",
1090				Models: []catwalk.Model{
1091					{
1092						ID:               "large-model",
1093						DefaultMaxTokens: 1000,
1094					},
1095					{
1096						ID:               "small-model",
1097						DefaultMaxTokens: 500,
1098					},
1099				},
1100			},
1101			{
1102				ID:                  "anthropic",
1103				APIKey:              "abc",
1104				DefaultLargeModelID: "a-large-model",
1105				DefaultSmallModelID: "a-small-model",
1106				Models: []catwalk.Model{
1107					{
1108						ID:               "a-large-model",
1109						DefaultMaxTokens: 1000,
1110					},
1111					{
1112						ID:               "a-small-model",
1113						DefaultMaxTokens: 200,
1114					},
1115				},
1116			},
1117		}
1118
1119		cfg := &Config{
1120			Models: map[SelectedModelType]SelectedModel{
1121				"small": {
1122					Model:     "a-small-model",
1123					Provider:  "anthropic",
1124					MaxTokens: 300,
1125				},
1126			},
1127		}
1128		cfg.setDefaults("/tmp")
1129		env := env.NewFromMap(map[string]string{})
1130		resolver := NewEnvironmentVariableResolver(env)
1131		err := cfg.configureProviders(env, resolver, knownProviders)
1132		require.NoError(t, err)
1133
1134		err = cfg.configureSelectedModels(knownProviders)
1135		require.NoError(t, err)
1136		large := cfg.Models[SelectedModelTypeLarge]
1137		small := cfg.Models[SelectedModelTypeSmall]
1138		require.Equal(t, "large-model", large.Model)
1139		require.Equal(t, "openai", large.Provider)
1140		require.Equal(t, int64(1000), large.MaxTokens)
1141		require.Equal(t, "a-small-model", small.Model)
1142		require.Equal(t, "anthropic", small.Provider)
1143		require.Equal(t, int64(300), small.MaxTokens)
1144	})
1145
1146	t.Run("should override the max tokens only", func(t *testing.T) {
1147		knownProviders := []catwalk.Provider{
1148			{
1149				ID:                  "openai",
1150				APIKey:              "abc",
1151				DefaultLargeModelID: "large-model",
1152				DefaultSmallModelID: "small-model",
1153				Models: []catwalk.Model{
1154					{
1155						ID:               "large-model",
1156						DefaultMaxTokens: 1000,
1157					},
1158					{
1159						ID:               "small-model",
1160						DefaultMaxTokens: 500,
1161					},
1162				},
1163			},
1164		}
1165
1166		cfg := &Config{
1167			Models: map[SelectedModelType]SelectedModel{
1168				"large": {
1169					MaxTokens: 100,
1170				},
1171			},
1172		}
1173		cfg.setDefaults("/tmp")
1174		env := env.NewFromMap(map[string]string{})
1175		resolver := NewEnvironmentVariableResolver(env)
1176		err := cfg.configureProviders(env, resolver, knownProviders)
1177		require.NoError(t, err)
1178
1179		err = cfg.configureSelectedModels(knownProviders)
1180		require.NoError(t, err)
1181		large := cfg.Models[SelectedModelTypeLarge]
1182		require.Equal(t, "large-model", large.Model)
1183		require.Equal(t, "openai", large.Provider)
1184		require.Equal(t, int64(100), large.MaxTokens)
1185	})
1186}