load_test.go

   1package config
   2
   3import (
   4	"io"
   5	"log/slog"
   6	"os"
   7	"path/filepath"
   8	"strings"
   9	"testing"
  10
  11	"github.com/charmbracelet/catwalk/pkg/catwalk"
  12	"github.com/charmbracelet/crush/internal/csync"
  13	"github.com/charmbracelet/crush/internal/env"
  14	"github.com/stretchr/testify/assert"
  15	"github.com/stretchr/testify/require"
  16)
  17
  18func TestMain(m *testing.M) {
  19	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
  20
  21	exitVal := m.Run()
  22	os.Exit(exitVal)
  23}
  24
  25func TestConfig_LoadFromReaders(t *testing.T) {
  26	data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
  27	data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
  28	data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
  29
  30	loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
  31
  32	require.NoError(t, err)
  33	require.NotNil(t, loadedConfig)
  34	require.Equal(t, 1, loadedConfig.Providers.Len())
  35	pc, _ := loadedConfig.Providers.Get("openai")
  36	require.Equal(t, "key2", pc.APIKey)
  37	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  38}
  39
  40func TestConfig_setDefaults(t *testing.T) {
  41	cfg := &Config{}
  42
  43	cfg.setDefaults("/tmp", "")
  44
  45	require.NotNil(t, cfg.Options)
  46	require.NotNil(t, cfg.Options.TUI)
  47	require.NotNil(t, cfg.Options.ContextPaths)
  48	require.NotNil(t, cfg.Providers)
  49	require.NotNil(t, cfg.Models)
  50	require.NotNil(t, cfg.LSP)
  51	require.NotNil(t, cfg.MCP)
  52	require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  53	for _, path := range defaultContextPaths {
  54		require.Contains(t, cfg.Options.ContextPaths, path)
  55	}
  56	require.Equal(t, "/tmp", cfg.workingDir)
  57}
  58
  59func TestConfig_configureProviders(t *testing.T) {
  60	knownProviders := []catwalk.Provider{
  61		{
  62			ID:          "openai",
  63			APIKey:      "$OPENAI_API_KEY",
  64			APIEndpoint: "https://api.openai.com/v1",
  65			Models: []catwalk.Model{{
  66				ID: "test-model",
  67			}},
  68		},
  69	}
  70
  71	cfg := &Config{}
  72	cfg.setDefaults("/tmp", "")
  73	env := env.NewFromMap(map[string]string{
  74		"OPENAI_API_KEY": "test-key",
  75	})
  76	resolver := NewEnvironmentVariableResolver(env)
  77	err := cfg.configureProviders(env, resolver, knownProviders)
  78	require.NoError(t, err)
  79	require.Equal(t, 1, cfg.Providers.Len())
  80
  81	// We want to make sure that we keep the configured API key as a placeholder
  82	pc, _ := cfg.Providers.Get("openai")
  83	require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  84}
  85
  86func TestConfig_configureProvidersWithOverride(t *testing.T) {
  87	knownProviders := []catwalk.Provider{
  88		{
  89			ID:          "openai",
  90			APIKey:      "$OPENAI_API_KEY",
  91			APIEndpoint: "https://api.openai.com/v1",
  92			Models: []catwalk.Model{{
  93				ID: "test-model",
  94			}},
  95		},
  96	}
  97
  98	cfg := &Config{
  99		Providers: csync.NewMap[string, ProviderConfig](),
 100	}
 101	cfg.Providers.Set("openai", ProviderConfig{
 102		APIKey:  "xyz",
 103		BaseURL: "https://api.openai.com/v2",
 104		Models: []catwalk.Model{
 105			{
 106				ID:   "test-model",
 107				Name: "Updated",
 108			},
 109			{
 110				ID: "another-model",
 111			},
 112		},
 113	})
 114	cfg.setDefaults("/tmp", "")
 115
 116	env := env.NewFromMap(map[string]string{
 117		"OPENAI_API_KEY": "test-key",
 118	})
 119	resolver := NewEnvironmentVariableResolver(env)
 120	err := cfg.configureProviders(env, resolver, knownProviders)
 121	require.NoError(t, err)
 122	require.Equal(t, 1, cfg.Providers.Len())
 123
 124	// We want to make sure that we keep the configured API key as a placeholder
 125	pc, _ := cfg.Providers.Get("openai")
 126	require.Equal(t, "xyz", pc.APIKey)
 127	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
 128	require.Len(t, pc.Models, 2)
 129	require.Equal(t, "Updated", pc.Models[0].Name)
 130}
 131
 132func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
 133	knownProviders := []catwalk.Provider{
 134		{
 135			ID:          "openai",
 136			APIKey:      "$OPENAI_API_KEY",
 137			APIEndpoint: "https://api.openai.com/v1",
 138			Models: []catwalk.Model{{
 139				ID: "test-model",
 140			}},
 141		},
 142	}
 143
 144	cfg := &Config{
 145		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 146			"custom": {
 147				APIKey:  "xyz",
 148				BaseURL: "https://api.someendpoint.com/v2",
 149				Models: []catwalk.Model{
 150					{
 151						ID: "test-model",
 152					},
 153				},
 154			},
 155		}),
 156	}
 157	cfg.setDefaults("/tmp", "")
 158	env := env.NewFromMap(map[string]string{
 159		"OPENAI_API_KEY": "test-key",
 160	})
 161	resolver := NewEnvironmentVariableResolver(env)
 162	err := cfg.configureProviders(env, resolver, knownProviders)
 163	require.NoError(t, err)
 164	// Should be to because of the env variable
 165	require.Equal(t, cfg.Providers.Len(), 2)
 166
 167	// We want to make sure that we keep the configured API key as a placeholder
 168	pc, _ := cfg.Providers.Get("custom")
 169	require.Equal(t, "xyz", pc.APIKey)
 170	// Make sure we set the ID correctly
 171	require.Equal(t, "custom", pc.ID)
 172	require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
 173	require.Len(t, pc.Models, 1)
 174
 175	_, ok := cfg.Providers.Get("openai")
 176	require.True(t, ok, "OpenAI provider should still be present")
 177}
 178
 179func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
 180	knownProviders := []catwalk.Provider{
 181		{
 182			ID:          catwalk.InferenceProviderBedrock,
 183			APIKey:      "",
 184			APIEndpoint: "",
 185			Models: []catwalk.Model{{
 186				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 187			}},
 188		},
 189	}
 190
 191	cfg := &Config{}
 192	cfg.setDefaults("/tmp", "")
 193	env := env.NewFromMap(map[string]string{
 194		"AWS_ACCESS_KEY_ID":     "test-key-id",
 195		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 196	})
 197	resolver := NewEnvironmentVariableResolver(env)
 198	err := cfg.configureProviders(env, resolver, knownProviders)
 199	require.NoError(t, err)
 200	require.Equal(t, cfg.Providers.Len(), 1)
 201
 202	bedrockProvider, ok := cfg.Providers.Get("bedrock")
 203	require.True(t, ok, "Bedrock provider should be present")
 204	require.Len(t, bedrockProvider.Models, 1)
 205	require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
 206}
 207
 208func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
 209	knownProviders := []catwalk.Provider{
 210		{
 211			ID:          catwalk.InferenceProviderBedrock,
 212			APIKey:      "",
 213			APIEndpoint: "",
 214			Models: []catwalk.Model{{
 215				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 216			}},
 217		},
 218	}
 219
 220	cfg := &Config{}
 221	cfg.setDefaults("/tmp", "")
 222	env := env.NewFromMap(map[string]string{})
 223	resolver := NewEnvironmentVariableResolver(env)
 224	err := cfg.configureProviders(env, resolver, knownProviders)
 225	require.NoError(t, err)
 226	// Provider should not be configured without credentials
 227	require.Equal(t, cfg.Providers.Len(), 0)
 228}
 229
 230func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
 231	knownProviders := []catwalk.Provider{
 232		{
 233			ID:          catwalk.InferenceProviderBedrock,
 234			APIKey:      "",
 235			APIEndpoint: "",
 236			Models: []catwalk.Model{{
 237				ID: "some-random-model",
 238			}},
 239		},
 240	}
 241
 242	cfg := &Config{}
 243	cfg.setDefaults("/tmp", "")
 244	env := env.NewFromMap(map[string]string{
 245		"AWS_ACCESS_KEY_ID":     "test-key-id",
 246		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 247	})
 248	resolver := NewEnvironmentVariableResolver(env)
 249	err := cfg.configureProviders(env, resolver, knownProviders)
 250	require.Error(t, err)
 251}
 252
 253func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
 254	knownProviders := []catwalk.Provider{
 255		{
 256			ID:          catwalk.InferenceProviderVertexAI,
 257			APIKey:      "",
 258			APIEndpoint: "",
 259			Models: []catwalk.Model{{
 260				ID: "gemini-pro",
 261			}},
 262		},
 263	}
 264
 265	cfg := &Config{}
 266	cfg.setDefaults("/tmp", "")
 267	env := env.NewFromMap(map[string]string{
 268		"VERTEXAI_PROJECT":  "test-project",
 269		"VERTEXAI_LOCATION": "us-central1",
 270	})
 271	resolver := NewEnvironmentVariableResolver(env)
 272	err := cfg.configureProviders(env, resolver, knownProviders)
 273	require.NoError(t, err)
 274	require.Equal(t, cfg.Providers.Len(), 1)
 275
 276	vertexProvider, ok := cfg.Providers.Get("vertexai")
 277	require.True(t, ok, "VertexAI provider should be present")
 278	require.Len(t, vertexProvider.Models, 1)
 279	require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
 280	require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
 281	require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
 282}
 283
 284func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
 285	knownProviders := []catwalk.Provider{
 286		{
 287			ID:          catwalk.InferenceProviderVertexAI,
 288			APIKey:      "",
 289			APIEndpoint: "",
 290			Models: []catwalk.Model{{
 291				ID: "gemini-pro",
 292			}},
 293		},
 294	}
 295
 296	cfg := &Config{}
 297	cfg.setDefaults("/tmp", "")
 298	env := env.NewFromMap(map[string]string{
 299		"GOOGLE_GENAI_USE_VERTEXAI": "false",
 300		"GOOGLE_CLOUD_PROJECT":      "test-project",
 301		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 302	})
 303	resolver := NewEnvironmentVariableResolver(env)
 304	err := cfg.configureProviders(env, resolver, knownProviders)
 305	require.NoError(t, err)
 306	// Provider should not be configured without proper credentials
 307	require.Equal(t, cfg.Providers.Len(), 0)
 308}
 309
 310func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
 311	knownProviders := []catwalk.Provider{
 312		{
 313			ID:          catwalk.InferenceProviderVertexAI,
 314			APIKey:      "",
 315			APIEndpoint: "",
 316			Models: []catwalk.Model{{
 317				ID: "gemini-pro",
 318			}},
 319		},
 320	}
 321
 322	cfg := &Config{}
 323	cfg.setDefaults("/tmp", "")
 324	env := env.NewFromMap(map[string]string{
 325		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 326		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 327	})
 328	resolver := NewEnvironmentVariableResolver(env)
 329	err := cfg.configureProviders(env, resolver, knownProviders)
 330	require.NoError(t, err)
 331	// Provider should not be configured without project
 332	require.Equal(t, cfg.Providers.Len(), 0)
 333}
 334
 335func TestConfig_configureProvidersSetProviderID(t *testing.T) {
 336	knownProviders := []catwalk.Provider{
 337		{
 338			ID:          "openai",
 339			APIKey:      "$OPENAI_API_KEY",
 340			APIEndpoint: "https://api.openai.com/v1",
 341			Models: []catwalk.Model{{
 342				ID: "test-model",
 343			}},
 344		},
 345	}
 346
 347	cfg := &Config{}
 348	cfg.setDefaults("/tmp", "")
 349	env := env.NewFromMap(map[string]string{
 350		"OPENAI_API_KEY": "test-key",
 351	})
 352	resolver := NewEnvironmentVariableResolver(env)
 353	err := cfg.configureProviders(env, resolver, knownProviders)
 354	require.NoError(t, err)
 355	require.Equal(t, cfg.Providers.Len(), 1)
 356
 357	// Provider ID should be set
 358	pc, _ := cfg.Providers.Get("openai")
 359	require.Equal(t, "openai", pc.ID)
 360}
 361
 362func TestConfig_EnabledProviders(t *testing.T) {
 363	t.Run("all providers enabled", func(t *testing.T) {
 364		cfg := &Config{
 365			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 366				"openai": {
 367					ID:      "openai",
 368					APIKey:  "key1",
 369					Disable: false,
 370				},
 371				"anthropic": {
 372					ID:      "anthropic",
 373					APIKey:  "key2",
 374					Disable: false,
 375				},
 376			}),
 377		}
 378
 379		enabled := cfg.EnabledProviders()
 380		require.Len(t, enabled, 2)
 381	})
 382
 383	t.Run("some providers disabled", func(t *testing.T) {
 384		cfg := &Config{
 385			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 386				"openai": {
 387					ID:      "openai",
 388					APIKey:  "key1",
 389					Disable: false,
 390				},
 391				"anthropic": {
 392					ID:      "anthropic",
 393					APIKey:  "key2",
 394					Disable: true,
 395				},
 396			}),
 397		}
 398
 399		enabled := cfg.EnabledProviders()
 400		require.Len(t, enabled, 1)
 401		require.Equal(t, "openai", enabled[0].ID)
 402	})
 403
 404	t.Run("empty providers map", func(t *testing.T) {
 405		cfg := &Config{
 406			Providers: csync.NewMap[string, ProviderConfig](),
 407		}
 408
 409		enabled := cfg.EnabledProviders()
 410		require.Len(t, enabled, 0)
 411	})
 412}
 413
 414func TestConfig_IsConfigured(t *testing.T) {
 415	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
 416		cfg := &Config{
 417			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 418				"openai": {
 419					ID:      "openai",
 420					APIKey:  "key1",
 421					Disable: false,
 422				},
 423			}),
 424		}
 425
 426		require.True(t, cfg.IsConfigured())
 427	})
 428
 429	t.Run("returns false when no providers are configured", func(t *testing.T) {
 430		cfg := &Config{
 431			Providers: csync.NewMap[string, ProviderConfig](),
 432		}
 433
 434		require.False(t, cfg.IsConfigured())
 435	})
 436
 437	t.Run("returns false when all providers are disabled", func(t *testing.T) {
 438		cfg := &Config{
 439			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 440				"openai": {
 441					ID:      "openai",
 442					APIKey:  "key1",
 443					Disable: true,
 444				},
 445				"anthropic": {
 446					ID:      "anthropic",
 447					APIKey:  "key2",
 448					Disable: true,
 449				},
 450			}),
 451		}
 452
 453		require.False(t, cfg.IsConfigured())
 454	})
 455}
 456
 457func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
 458	cfg := &Config{
 459		Options: &Options{
 460			DisabledTools: []string{},
 461		},
 462	}
 463
 464	cfg.SetupAgents()
 465	coderAgent, ok := cfg.Agents[AgentCoder]
 466	require.True(t, ok)
 467	assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
 468
 469	taskAgent, ok := cfg.Agents[AgentTask]
 470	require.True(t, ok)
 471	assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 472
 473	fetchAgent, ok := cfg.Agents[AgentFetch]
 474	require.True(t, ok)
 475	assert.Equal(t, "fetch", fetchAgent.ID)
 476	assert.Equal(t, SelectedModelTypeSmall, fetchAgent.Model)
 477	assert.Equal(t, []string{"web_fetch", "grep", "glob", "view"}, fetchAgent.AllowedTools)
 478	assert.Empty(t, fetchAgent.ContextPaths)
 479}
 480
 481func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
 482	cfg := &Config{
 483		Options: &Options{
 484			DisabledTools: []string{
 485				"edit",
 486				"download",
 487				"grep",
 488			},
 489		},
 490	}
 491
 492	cfg.SetupAgents()
 493	coderAgent, ok := cfg.Agents[AgentCoder]
 494	require.True(t, ok)
 495	assert.Equal(t, []string{"agent", "bash", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)
 496
 497	taskAgent, ok := cfg.Agents[AgentTask]
 498	require.True(t, ok)
 499	assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 500}
 501
 502func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
 503	cfg := &Config{
 504		Options: &Options{
 505			DisabledTools: []string{
 506				"glob",
 507				"grep",
 508				"ls",
 509				"sourcegraph",
 510				"view",
 511			},
 512		},
 513	}
 514
 515	cfg.SetupAgents()
 516	coderAgent, ok := cfg.Agents[AgentCoder]
 517	require.True(t, ok)
 518	assert.Equal(t, []string{"agent", "bash", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "write"}, coderAgent.AllowedTools)
 519
 520	taskAgent, ok := cfg.Agents[AgentTask]
 521	require.True(t, ok)
 522	assert.Equal(t, []string{}, taskAgent.AllowedTools)
 523}
 524
 525func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
 526	knownProviders := []catwalk.Provider{
 527		{
 528			ID:          "openai",
 529			APIKey:      "$OPENAI_API_KEY",
 530			APIEndpoint: "https://api.openai.com/v1",
 531			Models: []catwalk.Model{{
 532				ID: "test-model",
 533			}},
 534		},
 535	}
 536
 537	cfg := &Config{
 538		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 539			"openai": {
 540				Disable: true,
 541			},
 542		}),
 543	}
 544	cfg.setDefaults("/tmp", "")
 545
 546	env := env.NewFromMap(map[string]string{
 547		"OPENAI_API_KEY": "test-key",
 548	})
 549	resolver := NewEnvironmentVariableResolver(env)
 550	err := cfg.configureProviders(env, resolver, knownProviders)
 551	require.NoError(t, err)
 552
 553	require.Equal(t, cfg.Providers.Len(), 1)
 554	prov, exists := cfg.Providers.Get("openai")
 555	require.True(t, exists)
 556	require.True(t, prov.Disable)
 557}
 558
 559func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 560	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
 561		cfg := &Config{
 562			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 563				"custom": {
 564					BaseURL: "https://api.custom.com/v1",
 565					Models: []catwalk.Model{{
 566						ID: "test-model",
 567					}},
 568				},
 569				"openai": {
 570					APIKey: "$MISSING",
 571				},
 572			}),
 573		}
 574		cfg.setDefaults("/tmp", "")
 575
 576		env := env.NewFromMap(map[string]string{})
 577		resolver := NewEnvironmentVariableResolver(env)
 578		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 579		require.NoError(t, err)
 580
 581		require.Equal(t, cfg.Providers.Len(), 1)
 582		_, exists := cfg.Providers.Get("custom")
 583		require.True(t, exists)
 584	})
 585
 586	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
 587		cfg := &Config{
 588			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 589				"custom": {
 590					APIKey: "test-key",
 591					Models: []catwalk.Model{{
 592						ID: "test-model",
 593					}},
 594				},
 595			}),
 596		}
 597		cfg.setDefaults("/tmp", "")
 598
 599		env := env.NewFromMap(map[string]string{})
 600		resolver := NewEnvironmentVariableResolver(env)
 601		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 602		require.NoError(t, err)
 603
 604		require.Equal(t, cfg.Providers.Len(), 0)
 605		_, exists := cfg.Providers.Get("custom")
 606		require.False(t, exists)
 607	})
 608
 609	t.Run("custom provider with no models is removed", func(t *testing.T) {
 610		cfg := &Config{
 611			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 612				"custom": {
 613					APIKey:  "test-key",
 614					BaseURL: "https://api.custom.com/v1",
 615					Models:  []catwalk.Model{},
 616				},
 617			}),
 618		}
 619		cfg.setDefaults("/tmp", "")
 620
 621		env := env.NewFromMap(map[string]string{})
 622		resolver := NewEnvironmentVariableResolver(env)
 623		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 624		require.NoError(t, err)
 625
 626		require.Equal(t, cfg.Providers.Len(), 0)
 627		_, exists := cfg.Providers.Get("custom")
 628		require.False(t, exists)
 629	})
 630
 631	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
 632		cfg := &Config{
 633			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 634				"custom": {
 635					APIKey:  "test-key",
 636					BaseURL: "https://api.custom.com/v1",
 637					Type:    "unsupported",
 638					Models: []catwalk.Model{{
 639						ID: "test-model",
 640					}},
 641				},
 642			}),
 643		}
 644		cfg.setDefaults("/tmp", "")
 645
 646		env := env.NewFromMap(map[string]string{})
 647		resolver := NewEnvironmentVariableResolver(env)
 648		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 649		require.NoError(t, err)
 650
 651		require.Equal(t, cfg.Providers.Len(), 0)
 652		_, exists := cfg.Providers.Get("custom")
 653		require.False(t, exists)
 654	})
 655
 656	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
 657		cfg := &Config{
 658			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 659				"custom": {
 660					APIKey:  "test-key",
 661					BaseURL: "https://api.custom.com/v1",
 662					Type:    catwalk.TypeOpenAI,
 663					Models: []catwalk.Model{{
 664						ID: "test-model",
 665					}},
 666				},
 667			}),
 668		}
 669		cfg.setDefaults("/tmp", "")
 670
 671		env := env.NewFromMap(map[string]string{})
 672		resolver := NewEnvironmentVariableResolver(env)
 673		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 674		require.NoError(t, err)
 675
 676		require.Equal(t, cfg.Providers.Len(), 1)
 677		customProvider, exists := cfg.Providers.Get("custom")
 678		require.True(t, exists)
 679		require.Equal(t, "custom", customProvider.ID)
 680		require.Equal(t, "test-key", customProvider.APIKey)
 681		require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
 682	})
 683
 684	t.Run("custom anthropic provider is supported", func(t *testing.T) {
 685		cfg := &Config{
 686			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 687				"custom-anthropic": {
 688					APIKey:  "test-key",
 689					BaseURL: "https://api.anthropic.com/v1",
 690					Type:    catwalk.TypeAnthropic,
 691					Models: []catwalk.Model{{
 692						ID: "claude-3-sonnet",
 693					}},
 694				},
 695			}),
 696		}
 697		cfg.setDefaults("/tmp", "")
 698
 699		env := env.NewFromMap(map[string]string{})
 700		resolver := NewEnvironmentVariableResolver(env)
 701		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 702		require.NoError(t, err)
 703
 704		require.Equal(t, cfg.Providers.Len(), 1)
 705		customProvider, exists := cfg.Providers.Get("custom-anthropic")
 706		require.True(t, exists)
 707		require.Equal(t, "custom-anthropic", customProvider.ID)
 708		require.Equal(t, "test-key", customProvider.APIKey)
 709		require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
 710		require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
 711	})
 712
 713	t.Run("disabled custom provider is removed", func(t *testing.T) {
 714		cfg := &Config{
 715			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 716				"custom": {
 717					APIKey:  "test-key",
 718					BaseURL: "https://api.custom.com/v1",
 719					Type:    catwalk.TypeOpenAI,
 720					Disable: true,
 721					Models: []catwalk.Model{{
 722						ID: "test-model",
 723					}},
 724				},
 725			}),
 726		}
 727		cfg.setDefaults("/tmp", "")
 728
 729		env := env.NewFromMap(map[string]string{})
 730		resolver := NewEnvironmentVariableResolver(env)
 731		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 732		require.NoError(t, err)
 733
 734		require.Equal(t, cfg.Providers.Len(), 0)
 735		_, exists := cfg.Providers.Get("custom")
 736		require.False(t, exists)
 737	})
 738}
 739
 740func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 741	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
 742		knownProviders := []catwalk.Provider{
 743			{
 744				ID:          catwalk.InferenceProviderVertexAI,
 745				APIKey:      "",
 746				APIEndpoint: "",
 747				Models: []catwalk.Model{{
 748					ID: "gemini-pro",
 749				}},
 750			},
 751		}
 752
 753		cfg := &Config{
 754			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 755				"vertexai": {
 756					BaseURL: "custom-url",
 757				},
 758			}),
 759		}
 760		cfg.setDefaults("/tmp", "")
 761
 762		env := env.NewFromMap(map[string]string{
 763			"GOOGLE_GENAI_USE_VERTEXAI": "false",
 764		})
 765		resolver := NewEnvironmentVariableResolver(env)
 766		err := cfg.configureProviders(env, resolver, knownProviders)
 767		require.NoError(t, err)
 768
 769		require.Equal(t, cfg.Providers.Len(), 0)
 770		_, exists := cfg.Providers.Get("vertexai")
 771		require.False(t, exists)
 772	})
 773
 774	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
 775		knownProviders := []catwalk.Provider{
 776			{
 777				ID:          catwalk.InferenceProviderBedrock,
 778				APIKey:      "",
 779				APIEndpoint: "",
 780				Models: []catwalk.Model{{
 781					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 782				}},
 783			},
 784		}
 785
 786		cfg := &Config{
 787			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 788				"bedrock": {
 789					BaseURL: "custom-url",
 790				},
 791			}),
 792		}
 793		cfg.setDefaults("/tmp", "")
 794
 795		env := env.NewFromMap(map[string]string{})
 796		resolver := NewEnvironmentVariableResolver(env)
 797		err := cfg.configureProviders(env, resolver, knownProviders)
 798		require.NoError(t, err)
 799
 800		require.Equal(t, cfg.Providers.Len(), 0)
 801		_, exists := cfg.Providers.Get("bedrock")
 802		require.False(t, exists)
 803	})
 804
 805	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
 806		knownProviders := []catwalk.Provider{
 807			{
 808				ID:          "openai",
 809				APIKey:      "$MISSING_API_KEY",
 810				APIEndpoint: "https://api.openai.com/v1",
 811				Models: []catwalk.Model{{
 812					ID: "test-model",
 813				}},
 814			},
 815		}
 816
 817		cfg := &Config{
 818			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 819				"openai": {
 820					BaseURL: "custom-url",
 821				},
 822			}),
 823		}
 824		cfg.setDefaults("/tmp", "")
 825
 826		env := env.NewFromMap(map[string]string{})
 827		resolver := NewEnvironmentVariableResolver(env)
 828		err := cfg.configureProviders(env, resolver, knownProviders)
 829		require.NoError(t, err)
 830
 831		require.Equal(t, cfg.Providers.Len(), 0)
 832		_, exists := cfg.Providers.Get("openai")
 833		require.False(t, exists)
 834	})
 835
 836	t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
 837		knownProviders := []catwalk.Provider{
 838			{
 839				ID:          "openai",
 840				APIKey:      "$OPENAI_API_KEY",
 841				APIEndpoint: "$MISSING_ENDPOINT",
 842				Models: []catwalk.Model{{
 843					ID: "test-model",
 844				}},
 845			},
 846		}
 847
 848		cfg := &Config{
 849			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 850				"openai": {
 851					APIKey: "test-key",
 852				},
 853			}),
 854		}
 855		cfg.setDefaults("/tmp", "")
 856
 857		env := env.NewFromMap(map[string]string{
 858			"OPENAI_API_KEY": "test-key",
 859		})
 860		resolver := NewEnvironmentVariableResolver(env)
 861		err := cfg.configureProviders(env, resolver, knownProviders)
 862		require.NoError(t, err)
 863
 864		require.Equal(t, cfg.Providers.Len(), 1)
 865		_, exists := cfg.Providers.Get("openai")
 866		require.True(t, exists)
 867	})
 868}
 869
 870func TestConfig_defaultModelSelection(t *testing.T) {
 871	t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
 872		knownProviders := []catwalk.Provider{
 873			{
 874				ID:                  "openai",
 875				APIKey:              "abc",
 876				DefaultLargeModelID: "large-model",
 877				DefaultSmallModelID: "small-model",
 878				Models: []catwalk.Model{
 879					{
 880						ID:               "large-model",
 881						DefaultMaxTokens: 1000,
 882					},
 883					{
 884						ID:               "small-model",
 885						DefaultMaxTokens: 500,
 886					},
 887				},
 888			},
 889		}
 890
 891		cfg := &Config{}
 892		cfg.setDefaults("/tmp", "")
 893		env := env.NewFromMap(map[string]string{})
 894		resolver := NewEnvironmentVariableResolver(env)
 895		err := cfg.configureProviders(env, resolver, knownProviders)
 896		require.NoError(t, err)
 897
 898		large, small, err := cfg.defaultModelSelection(knownProviders)
 899		require.NoError(t, err)
 900		require.Equal(t, "large-model", large.Model)
 901		require.Equal(t, "openai", large.Provider)
 902		require.Equal(t, int64(1000), large.MaxTokens)
 903		require.Equal(t, "small-model", small.Model)
 904		require.Equal(t, "openai", small.Provider)
 905		require.Equal(t, int64(500), small.MaxTokens)
 906	})
 907	t.Run("should error if no providers configured", func(t *testing.T) {
 908		knownProviders := []catwalk.Provider{
 909			{
 910				ID:                  "openai",
 911				APIKey:              "$MISSING_KEY",
 912				DefaultLargeModelID: "large-model",
 913				DefaultSmallModelID: "small-model",
 914				Models: []catwalk.Model{
 915					{
 916						ID:               "large-model",
 917						DefaultMaxTokens: 1000,
 918					},
 919					{
 920						ID:               "small-model",
 921						DefaultMaxTokens: 500,
 922					},
 923				},
 924			},
 925		}
 926
 927		cfg := &Config{}
 928		cfg.setDefaults("/tmp", "")
 929		env := env.NewFromMap(map[string]string{})
 930		resolver := NewEnvironmentVariableResolver(env)
 931		err := cfg.configureProviders(env, resolver, knownProviders)
 932		require.NoError(t, err)
 933
 934		_, _, err = cfg.defaultModelSelection(knownProviders)
 935		require.Error(t, err)
 936	})
 937	t.Run("should error if model is missing", func(t *testing.T) {
 938		knownProviders := []catwalk.Provider{
 939			{
 940				ID:                  "openai",
 941				APIKey:              "abc",
 942				DefaultLargeModelID: "large-model",
 943				DefaultSmallModelID: "small-model",
 944				Models: []catwalk.Model{
 945					{
 946						ID:               "not-large-model",
 947						DefaultMaxTokens: 1000,
 948					},
 949					{
 950						ID:               "small-model",
 951						DefaultMaxTokens: 500,
 952					},
 953				},
 954			},
 955		}
 956
 957		cfg := &Config{}
 958		cfg.setDefaults("/tmp", "")
 959		env := env.NewFromMap(map[string]string{})
 960		resolver := NewEnvironmentVariableResolver(env)
 961		err := cfg.configureProviders(env, resolver, knownProviders)
 962		require.NoError(t, err)
 963		_, _, err = cfg.defaultModelSelection(knownProviders)
 964		require.Error(t, err)
 965	})
 966
 967	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
 968		knownProviders := []catwalk.Provider{
 969			{
 970				ID:                  "openai",
 971				APIKey:              "$MISSING", // will not be included in the config
 972				DefaultLargeModelID: "large-model",
 973				DefaultSmallModelID: "small-model",
 974				Models: []catwalk.Model{
 975					{
 976						ID:               "not-large-model",
 977						DefaultMaxTokens: 1000,
 978					},
 979					{
 980						ID:               "small-model",
 981						DefaultMaxTokens: 500,
 982					},
 983				},
 984			},
 985		}
 986
 987		cfg := &Config{
 988			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 989				"custom": {
 990					APIKey:  "test-key",
 991					BaseURL: "https://api.custom.com/v1",
 992					Models: []catwalk.Model{
 993						{
 994							ID:               "model",
 995							DefaultMaxTokens: 600,
 996						},
 997					},
 998				},
 999			}),
1000		}
1001		cfg.setDefaults("/tmp", "")
1002		env := env.NewFromMap(map[string]string{})
1003		resolver := NewEnvironmentVariableResolver(env)
1004		err := cfg.configureProviders(env, resolver, knownProviders)
1005		require.NoError(t, err)
1006		large, small, err := cfg.defaultModelSelection(knownProviders)
1007		require.NoError(t, err)
1008		require.Equal(t, "model", large.Model)
1009		require.Equal(t, "custom", large.Provider)
1010		require.Equal(t, int64(600), large.MaxTokens)
1011		require.Equal(t, "model", small.Model)
1012		require.Equal(t, "custom", small.Provider)
1013		require.Equal(t, int64(600), small.MaxTokens)
1014	})
1015
1016	t.Run("should fail if no model configured", func(t *testing.T) {
1017		knownProviders := []catwalk.Provider{
1018			{
1019				ID:                  "openai",
1020				APIKey:              "$MISSING", // will not be included in the config
1021				DefaultLargeModelID: "large-model",
1022				DefaultSmallModelID: "small-model",
1023				Models: []catwalk.Model{
1024					{
1025						ID:               "not-large-model",
1026						DefaultMaxTokens: 1000,
1027					},
1028					{
1029						ID:               "small-model",
1030						DefaultMaxTokens: 500,
1031					},
1032				},
1033			},
1034		}
1035
1036		cfg := &Config{
1037			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1038				"custom": {
1039					APIKey:  "test-key",
1040					BaseURL: "https://api.custom.com/v1",
1041					Models:  []catwalk.Model{},
1042				},
1043			}),
1044		}
1045		cfg.setDefaults("/tmp", "")
1046		env := env.NewFromMap(map[string]string{})
1047		resolver := NewEnvironmentVariableResolver(env)
1048		err := cfg.configureProviders(env, resolver, knownProviders)
1049		require.NoError(t, err)
1050		_, _, err = cfg.defaultModelSelection(knownProviders)
1051		require.Error(t, err)
1052	})
1053	t.Run("should use the default provider first", func(t *testing.T) {
1054		knownProviders := []catwalk.Provider{
1055			{
1056				ID:                  "openai",
1057				APIKey:              "set",
1058				DefaultLargeModelID: "large-model",
1059				DefaultSmallModelID: "small-model",
1060				Models: []catwalk.Model{
1061					{
1062						ID:               "large-model",
1063						DefaultMaxTokens: 1000,
1064					},
1065					{
1066						ID:               "small-model",
1067						DefaultMaxTokens: 500,
1068					},
1069				},
1070			},
1071		}
1072
1073		cfg := &Config{
1074			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1075				"custom": {
1076					APIKey:  "test-key",
1077					BaseURL: "https://api.custom.com/v1",
1078					Models: []catwalk.Model{
1079						{
1080							ID:               "large-model",
1081							DefaultMaxTokens: 1000,
1082						},
1083					},
1084				},
1085			}),
1086		}
1087		cfg.setDefaults("/tmp", "")
1088		env := env.NewFromMap(map[string]string{})
1089		resolver := NewEnvironmentVariableResolver(env)
1090		err := cfg.configureProviders(env, resolver, knownProviders)
1091		require.NoError(t, err)
1092		large, small, err := cfg.defaultModelSelection(knownProviders)
1093		require.NoError(t, err)
1094		require.Equal(t, "large-model", large.Model)
1095		require.Equal(t, "openai", large.Provider)
1096		require.Equal(t, int64(1000), large.MaxTokens)
1097		require.Equal(t, "small-model", small.Model)
1098		require.Equal(t, "openai", small.Provider)
1099		require.Equal(t, int64(500), small.MaxTokens)
1100	})
1101}
1102
1103func TestConfig_configureSelectedModels(t *testing.T) {
1104	t.Run("should override defaults", func(t *testing.T) {
1105		knownProviders := []catwalk.Provider{
1106			{
1107				ID:                  "openai",
1108				APIKey:              "abc",
1109				DefaultLargeModelID: "large-model",
1110				DefaultSmallModelID: "small-model",
1111				Models: []catwalk.Model{
1112					{
1113						ID:               "larger-model",
1114						DefaultMaxTokens: 2000,
1115					},
1116					{
1117						ID:               "large-model",
1118						DefaultMaxTokens: 1000,
1119					},
1120					{
1121						ID:               "small-model",
1122						DefaultMaxTokens: 500,
1123					},
1124				},
1125			},
1126		}
1127
1128		cfg := &Config{
1129			Models: map[SelectedModelType]SelectedModel{
1130				"large": {
1131					Model: "larger-model",
1132				},
1133			},
1134		}
1135		cfg.setDefaults("/tmp", "")
1136		env := env.NewFromMap(map[string]string{})
1137		resolver := NewEnvironmentVariableResolver(env)
1138		err := cfg.configureProviders(env, resolver, knownProviders)
1139		require.NoError(t, err)
1140
1141		err = cfg.configureSelectedModels(knownProviders)
1142		require.NoError(t, err)
1143		large := cfg.Models[SelectedModelTypeLarge]
1144		small := cfg.Models[SelectedModelTypeSmall]
1145		require.Equal(t, "larger-model", large.Model)
1146		require.Equal(t, "openai", large.Provider)
1147		require.Equal(t, int64(2000), large.MaxTokens)
1148		require.Equal(t, "small-model", small.Model)
1149		require.Equal(t, "openai", small.Provider)
1150		require.Equal(t, int64(500), small.MaxTokens)
1151	})
1152	t.Run("should be possible to use multiple providers", func(t *testing.T) {
1153		knownProviders := []catwalk.Provider{
1154			{
1155				ID:                  "openai",
1156				APIKey:              "abc",
1157				DefaultLargeModelID: "large-model",
1158				DefaultSmallModelID: "small-model",
1159				Models: []catwalk.Model{
1160					{
1161						ID:               "large-model",
1162						DefaultMaxTokens: 1000,
1163					},
1164					{
1165						ID:               "small-model",
1166						DefaultMaxTokens: 500,
1167					},
1168				},
1169			},
1170			{
1171				ID:                  "anthropic",
1172				APIKey:              "abc",
1173				DefaultLargeModelID: "a-large-model",
1174				DefaultSmallModelID: "a-small-model",
1175				Models: []catwalk.Model{
1176					{
1177						ID:               "a-large-model",
1178						DefaultMaxTokens: 1000,
1179					},
1180					{
1181						ID:               "a-small-model",
1182						DefaultMaxTokens: 200,
1183					},
1184				},
1185			},
1186		}
1187
1188		cfg := &Config{
1189			Models: map[SelectedModelType]SelectedModel{
1190				"small": {
1191					Model:     "a-small-model",
1192					Provider:  "anthropic",
1193					MaxTokens: 300,
1194				},
1195			},
1196		}
1197		cfg.setDefaults("/tmp", "")
1198		env := env.NewFromMap(map[string]string{})
1199		resolver := NewEnvironmentVariableResolver(env)
1200		err := cfg.configureProviders(env, resolver, knownProviders)
1201		require.NoError(t, err)
1202
1203		err = cfg.configureSelectedModels(knownProviders)
1204		require.NoError(t, err)
1205		large := cfg.Models[SelectedModelTypeLarge]
1206		small := cfg.Models[SelectedModelTypeSmall]
1207		require.Equal(t, "large-model", large.Model)
1208		require.Equal(t, "openai", large.Provider)
1209		require.Equal(t, int64(1000), large.MaxTokens)
1210		require.Equal(t, "a-small-model", small.Model)
1211		require.Equal(t, "anthropic", small.Provider)
1212		require.Equal(t, int64(300), small.MaxTokens)
1213	})
1214
1215	t.Run("should override the max tokens only", func(t *testing.T) {
1216		knownProviders := []catwalk.Provider{
1217			{
1218				ID:                  "openai",
1219				APIKey:              "abc",
1220				DefaultLargeModelID: "large-model",
1221				DefaultSmallModelID: "small-model",
1222				Models: []catwalk.Model{
1223					{
1224						ID:               "large-model",
1225						DefaultMaxTokens: 1000,
1226					},
1227					{
1228						ID:               "small-model",
1229						DefaultMaxTokens: 500,
1230					},
1231				},
1232			},
1233		}
1234
1235		cfg := &Config{
1236			Models: map[SelectedModelType]SelectedModel{
1237				"large": {
1238					MaxTokens: 100,
1239				},
1240			},
1241		}
1242		cfg.setDefaults("/tmp", "")
1243		env := env.NewFromMap(map[string]string{})
1244		resolver := NewEnvironmentVariableResolver(env)
1245		err := cfg.configureProviders(env, resolver, knownProviders)
1246		require.NoError(t, err)
1247
1248		err = cfg.configureSelectedModels(knownProviders)
1249		require.NoError(t, err)
1250		large := cfg.Models[SelectedModelTypeLarge]
1251		require.Equal(t, "large-model", large.Model)
1252		require.Equal(t, "openai", large.Provider)
1253		require.Equal(t, int64(100), large.MaxTokens)
1254	})
1255}