load_test.go

   1package config
   2
   3import (
   4	"io"
   5	"log/slog"
   6	"os"
   7	"path/filepath"
   8	"strings"
   9	"testing"
  10
  11	"github.com/charmbracelet/catwalk/pkg/catwalk"
  12	"github.com/charmbracelet/crush/internal/csync"
  13	"github.com/charmbracelet/crush/internal/env"
  14	"github.com/stretchr/testify/assert"
  15	"github.com/stretchr/testify/require"
  16)
  17
  18func TestMain(m *testing.M) {
  19	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
  20
  21	exitVal := m.Run()
  22	os.Exit(exitVal)
  23}
  24
  25func TestConfig_LoadFromReaders(t *testing.T) {
  26	data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
  27	data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
  28	data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
  29
  30	loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
  31
  32	require.NoError(t, err)
  33	require.NotNil(t, loadedConfig)
  34	require.Equal(t, 1, loadedConfig.Providers.Len())
  35	pc, _ := loadedConfig.Providers.Get("openai")
  36	require.Equal(t, "key2", pc.APIKey)
  37	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  38}
  39
  40func TestConfig_setDefaults(t *testing.T) {
  41	cfg := &Config{}
  42
  43	cfg.setDefaults("/tmp", "")
  44
  45	require.NotNil(t, cfg.Options)
  46	require.NotNil(t, cfg.Options.TUI)
  47	require.NotNil(t, cfg.Options.ContextPaths)
  48	require.NotNil(t, cfg.Providers)
  49	require.NotNil(t, cfg.Models)
  50	require.NotNil(t, cfg.LSP)
  51	require.NotNil(t, cfg.MCP)
  52	require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  53	for _, path := range defaultContextPaths {
  54		require.Contains(t, cfg.Options.ContextPaths, path)
  55	}
  56	require.Equal(t, "/tmp", cfg.workingDir)
  57}
  58
  59func TestConfig_configureProviders(t *testing.T) {
  60	knownProviders := []catwalk.Provider{
  61		{
  62			ID:          "openai",
  63			APIKey:      "$OPENAI_API_KEY",
  64			APIEndpoint: "https://api.openai.com/v1",
  65			Models: []catwalk.Model{{
  66				ID: "test-model",
  67			}},
  68		},
  69	}
  70
  71	cfg := &Config{}
  72	cfg.setDefaults("/tmp", "")
  73	env := env.NewFromMap(map[string]string{
  74		"OPENAI_API_KEY": "test-key",
  75	})
  76	resolver := NewEnvironmentVariableResolver(env)
  77	err := cfg.configureProviders(env, resolver, knownProviders)
  78	require.NoError(t, err)
  79	require.Equal(t, 1, cfg.Providers.Len())
  80
  81	// We want to make sure that we keep the configured API key as a placeholder
  82	pc, _ := cfg.Providers.Get("openai")
  83	require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  84}
  85
  86func TestConfig_configureProvidersWithOverride(t *testing.T) {
  87	knownProviders := []catwalk.Provider{
  88		{
  89			ID:          "openai",
  90			APIKey:      "$OPENAI_API_KEY",
  91			APIEndpoint: "https://api.openai.com/v1",
  92			Models: []catwalk.Model{{
  93				ID: "test-model",
  94			}},
  95		},
  96	}
  97
  98	cfg := &Config{
  99		Providers: csync.NewMap[string, ProviderConfig](),
 100	}
 101	cfg.Providers.Set("openai", ProviderConfig{
 102		APIKey:  "xyz",
 103		BaseURL: "https://api.openai.com/v2",
 104		Models: []catwalk.Model{
 105			{
 106				ID:   "test-model",
 107				Name: "Updated",
 108			},
 109			{
 110				ID: "another-model",
 111			},
 112		},
 113	})
 114	cfg.setDefaults("/tmp", "")
 115
 116	env := env.NewFromMap(map[string]string{
 117		"OPENAI_API_KEY": "test-key",
 118	})
 119	resolver := NewEnvironmentVariableResolver(env)
 120	err := cfg.configureProviders(env, resolver, knownProviders)
 121	require.NoError(t, err)
 122	require.Equal(t, 1, cfg.Providers.Len())
 123
 124	// We want to make sure that we keep the configured API key as a placeholder
 125	pc, _ := cfg.Providers.Get("openai")
 126	require.Equal(t, "xyz", pc.APIKey)
 127	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
 128	require.Len(t, pc.Models, 2)
 129	require.Equal(t, "Updated", pc.Models[0].Name)
 130}
 131
 132func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
 133	knownProviders := []catwalk.Provider{
 134		{
 135			ID:          "openai",
 136			APIKey:      "$OPENAI_API_KEY",
 137			APIEndpoint: "https://api.openai.com/v1",
 138			Models: []catwalk.Model{{
 139				ID: "test-model",
 140			}},
 141		},
 142	}
 143
 144	cfg := &Config{
 145		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 146			"custom": {
 147				APIKey:  "xyz",
 148				BaseURL: "https://api.someendpoint.com/v2",
 149				Models: []catwalk.Model{
 150					{
 151						ID: "test-model",
 152					},
 153				},
 154			},
 155		}),
 156	}
 157	cfg.setDefaults("/tmp", "")
 158	env := env.NewFromMap(map[string]string{
 159		"OPENAI_API_KEY": "test-key",
 160	})
 161	resolver := NewEnvironmentVariableResolver(env)
 162	err := cfg.configureProviders(env, resolver, knownProviders)
 163	require.NoError(t, err)
 164	// Should be to because of the env variable
 165	require.Equal(t, cfg.Providers.Len(), 2)
 166
 167	// We want to make sure that we keep the configured API key as a placeholder
 168	pc, _ := cfg.Providers.Get("custom")
 169	require.Equal(t, "xyz", pc.APIKey)
 170	// Make sure we set the ID correctly
 171	require.Equal(t, "custom", pc.ID)
 172	require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
 173	require.Len(t, pc.Models, 1)
 174
 175	_, ok := cfg.Providers.Get("openai")
 176	require.True(t, ok, "OpenAI provider should still be present")
 177}
 178
 179func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
 180	knownProviders := []catwalk.Provider{
 181		{
 182			ID:          catwalk.InferenceProviderBedrock,
 183			APIKey:      "",
 184			APIEndpoint: "",
 185			Models: []catwalk.Model{{
 186				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 187			}},
 188		},
 189	}
 190
 191	cfg := &Config{}
 192	cfg.setDefaults("/tmp", "")
 193	env := env.NewFromMap(map[string]string{
 194		"AWS_ACCESS_KEY_ID":     "test-key-id",
 195		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 196	})
 197	resolver := NewEnvironmentVariableResolver(env)
 198	err := cfg.configureProviders(env, resolver, knownProviders)
 199	require.NoError(t, err)
 200	require.Equal(t, cfg.Providers.Len(), 1)
 201
 202	bedrockProvider, ok := cfg.Providers.Get("bedrock")
 203	require.True(t, ok, "Bedrock provider should be present")
 204	require.Len(t, bedrockProvider.Models, 1)
 205	require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
 206}
 207
 208func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
 209	knownProviders := []catwalk.Provider{
 210		{
 211			ID:          catwalk.InferenceProviderBedrock,
 212			APIKey:      "",
 213			APIEndpoint: "",
 214			Models: []catwalk.Model{{
 215				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 216			}},
 217		},
 218	}
 219
 220	cfg := &Config{}
 221	cfg.setDefaults("/tmp", "")
 222	env := env.NewFromMap(map[string]string{})
 223	resolver := NewEnvironmentVariableResolver(env)
 224	err := cfg.configureProviders(env, resolver, knownProviders)
 225	require.NoError(t, err)
 226	// Provider should not be configured without credentials
 227	require.Equal(t, cfg.Providers.Len(), 0)
 228}
 229
 230func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
 231	knownProviders := []catwalk.Provider{
 232		{
 233			ID:          catwalk.InferenceProviderBedrock,
 234			APIKey:      "",
 235			APIEndpoint: "",
 236			Models: []catwalk.Model{{
 237				ID: "some-random-model",
 238			}},
 239		},
 240	}
 241
 242	cfg := &Config{}
 243	cfg.setDefaults("/tmp", "")
 244	env := env.NewFromMap(map[string]string{
 245		"AWS_ACCESS_KEY_ID":     "test-key-id",
 246		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 247	})
 248	resolver := NewEnvironmentVariableResolver(env)
 249	err := cfg.configureProviders(env, resolver, knownProviders)
 250	require.Error(t, err)
 251}
 252
 253func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
 254	knownProviders := []catwalk.Provider{
 255		{
 256			ID:          catwalk.InferenceProviderVertexAI,
 257			APIKey:      "",
 258			APIEndpoint: "",
 259			Models: []catwalk.Model{{
 260				ID: "gemini-pro",
 261			}},
 262		},
 263	}
 264
 265	cfg := &Config{}
 266	cfg.setDefaults("/tmp", "")
 267	env := env.NewFromMap(map[string]string{
 268		"VERTEXAI_PROJECT":  "test-project",
 269		"VERTEXAI_LOCATION": "us-central1",
 270	})
 271	resolver := NewEnvironmentVariableResolver(env)
 272	err := cfg.configureProviders(env, resolver, knownProviders)
 273	require.NoError(t, err)
 274	require.Equal(t, cfg.Providers.Len(), 1)
 275
 276	vertexProvider, ok := cfg.Providers.Get("vertexai")
 277	require.True(t, ok, "VertexAI provider should be present")
 278	require.Len(t, vertexProvider.Models, 1)
 279	require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
 280	require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
 281	require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
 282}
 283
 284func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
 285	knownProviders := []catwalk.Provider{
 286		{
 287			ID:          catwalk.InferenceProviderVertexAI,
 288			APIKey:      "",
 289			APIEndpoint: "",
 290			Models: []catwalk.Model{{
 291				ID: "gemini-pro",
 292			}},
 293		},
 294	}
 295
 296	cfg := &Config{}
 297	cfg.setDefaults("/tmp", "")
 298	env := env.NewFromMap(map[string]string{
 299		"GOOGLE_GENAI_USE_VERTEXAI": "false",
 300		"GOOGLE_CLOUD_PROJECT":      "test-project",
 301		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 302	})
 303	resolver := NewEnvironmentVariableResolver(env)
 304	err := cfg.configureProviders(env, resolver, knownProviders)
 305	require.NoError(t, err)
 306	// Provider should not be configured without proper credentials
 307	require.Equal(t, cfg.Providers.Len(), 0)
 308}
 309
 310func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
 311	knownProviders := []catwalk.Provider{
 312		{
 313			ID:          catwalk.InferenceProviderVertexAI,
 314			APIKey:      "",
 315			APIEndpoint: "",
 316			Models: []catwalk.Model{{
 317				ID: "gemini-pro",
 318			}},
 319		},
 320	}
 321
 322	cfg := &Config{}
 323	cfg.setDefaults("/tmp", "")
 324	env := env.NewFromMap(map[string]string{
 325		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 326		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 327	})
 328	resolver := NewEnvironmentVariableResolver(env)
 329	err := cfg.configureProviders(env, resolver, knownProviders)
 330	require.NoError(t, err)
 331	// Provider should not be configured without project
 332	require.Equal(t, cfg.Providers.Len(), 0)
 333}
 334
 335func TestConfig_configureProvidersSetProviderID(t *testing.T) {
 336	knownProviders := []catwalk.Provider{
 337		{
 338			ID:          "openai",
 339			APIKey:      "$OPENAI_API_KEY",
 340			APIEndpoint: "https://api.openai.com/v1",
 341			Models: []catwalk.Model{{
 342				ID: "test-model",
 343			}},
 344		},
 345	}
 346
 347	cfg := &Config{}
 348	cfg.setDefaults("/tmp", "")
 349	env := env.NewFromMap(map[string]string{
 350		"OPENAI_API_KEY": "test-key",
 351	})
 352	resolver := NewEnvironmentVariableResolver(env)
 353	err := cfg.configureProviders(env, resolver, knownProviders)
 354	require.NoError(t, err)
 355	require.Equal(t, cfg.Providers.Len(), 1)
 356
 357	// Provider ID should be set
 358	pc, _ := cfg.Providers.Get("openai")
 359	require.Equal(t, "openai", pc.ID)
 360}
 361
 362func TestConfig_EnabledProviders(t *testing.T) {
 363	t.Run("all providers enabled", func(t *testing.T) {
 364		cfg := &Config{
 365			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 366				"openai": {
 367					ID:      "openai",
 368					APIKey:  "key1",
 369					Disable: false,
 370				},
 371				"anthropic": {
 372					ID:      "anthropic",
 373					APIKey:  "key2",
 374					Disable: false,
 375				},
 376			}),
 377		}
 378
 379		enabled := cfg.EnabledProviders()
 380		require.Len(t, enabled, 2)
 381	})
 382
 383	t.Run("some providers disabled", func(t *testing.T) {
 384		cfg := &Config{
 385			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 386				"openai": {
 387					ID:      "openai",
 388					APIKey:  "key1",
 389					Disable: false,
 390				},
 391				"anthropic": {
 392					ID:      "anthropic",
 393					APIKey:  "key2",
 394					Disable: true,
 395				},
 396			}),
 397		}
 398
 399		enabled := cfg.EnabledProviders()
 400		require.Len(t, enabled, 1)
 401		require.Equal(t, "openai", enabled[0].ID)
 402	})
 403
 404	t.Run("empty providers map", func(t *testing.T) {
 405		cfg := &Config{
 406			Providers: csync.NewMap[string, ProviderConfig](),
 407		}
 408
 409		enabled := cfg.EnabledProviders()
 410		require.Len(t, enabled, 0)
 411	})
 412}
 413
 414func TestConfig_IsConfigured(t *testing.T) {
 415	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
 416		cfg := &Config{
 417			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 418				"openai": {
 419					ID:      "openai",
 420					APIKey:  "key1",
 421					Disable: false,
 422				},
 423			}),
 424		}
 425
 426		require.True(t, cfg.IsConfigured())
 427	})
 428
 429	t.Run("returns false when no providers are configured", func(t *testing.T) {
 430		cfg := &Config{
 431			Providers: csync.NewMap[string, ProviderConfig](),
 432		}
 433
 434		require.False(t, cfg.IsConfigured())
 435	})
 436
 437	t.Run("returns false when all providers are disabled", func(t *testing.T) {
 438		cfg := &Config{
 439			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 440				"openai": {
 441					ID:      "openai",
 442					APIKey:  "key1",
 443					Disable: true,
 444				},
 445				"anthropic": {
 446					ID:      "anthropic",
 447					APIKey:  "key2",
 448					Disable: true,
 449				},
 450			}),
 451		}
 452
 453		require.False(t, cfg.IsConfigured())
 454	})
 455}
 456
 457func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
 458	cfg := &Config{
 459		Options: &Options{
 460			DisabledTools: []string{},
 461		},
 462	}
 463
 464	cfg.SetupAgents()
 465	coderAgent, ok := cfg.Agents[AgentCoder]
 466	require.True(t, ok)
 467	assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
 468
 469	taskAgent, ok := cfg.Agents[AgentTask]
 470	require.True(t, ok)
 471	assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 472}
 473
 474func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
 475	cfg := &Config{
 476		Options: &Options{
 477			DisabledTools: []string{
 478				"edit",
 479				"download",
 480				"grep",
 481			},
 482		},
 483	}
 484
 485	cfg.SetupAgents()
 486	coderAgent, ok := cfg.Agents[AgentCoder]
 487	require.True(t, ok)
 488	assert.Equal(t, []string{
 489		"agent",
 490		"bash",
 491		"multiedit",
 492		"lsp_diagnostics",
 493		"lsp_references",
 494		"fetch",
 495		"glob",
 496		"ls",
 497		"sourcegraph",
 498		"view",
 499		"write",
 500		"read_mcp_resource",
 501		"list_mcp_resources",
 502	}, coderAgent.AllowedTools)
 503
 504	taskAgent, ok := cfg.Agents[AgentTask]
 505	require.True(t, ok)
 506	assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 507}
 508
 509func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
 510	cfg := &Config{
 511		Options: &Options{
 512			DisabledTools: []string{
 513				"glob",
 514				"grep",
 515				"ls",
 516				"sourcegraph",
 517				"view",
 518				"read_mcp_resource",
 519				"list_mcp_resources",
 520			},
 521		},
 522	}
 523
 524	cfg.SetupAgents()
 525	coderAgent, ok := cfg.Agents[AgentCoder]
 526	require.True(t, ok)
 527	assert.Equal(t, []string{"agent", "bash", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "write"}, coderAgent.AllowedTools)
 528
 529	taskAgent, ok := cfg.Agents[AgentTask]
 530	require.True(t, ok)
 531	assert.Equal(t, []string{}, taskAgent.AllowedTools)
 532}
 533
 534func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
 535	knownProviders := []catwalk.Provider{
 536		{
 537			ID:          "openai",
 538			APIKey:      "$OPENAI_API_KEY",
 539			APIEndpoint: "https://api.openai.com/v1",
 540			Models: []catwalk.Model{{
 541				ID: "test-model",
 542			}},
 543		},
 544	}
 545
 546	cfg := &Config{
 547		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 548			"openai": {
 549				Disable: true,
 550			},
 551		}),
 552	}
 553	cfg.setDefaults("/tmp", "")
 554
 555	env := env.NewFromMap(map[string]string{
 556		"OPENAI_API_KEY": "test-key",
 557	})
 558	resolver := NewEnvironmentVariableResolver(env)
 559	err := cfg.configureProviders(env, resolver, knownProviders)
 560	require.NoError(t, err)
 561
 562	require.Equal(t, cfg.Providers.Len(), 1)
 563	prov, exists := cfg.Providers.Get("openai")
 564	require.True(t, exists)
 565	require.True(t, prov.Disable)
 566}
 567
 568func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 569	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
 570		cfg := &Config{
 571			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 572				"custom": {
 573					BaseURL: "https://api.custom.com/v1",
 574					Models: []catwalk.Model{{
 575						ID: "test-model",
 576					}},
 577				},
 578				"openai": {
 579					APIKey: "$MISSING",
 580				},
 581			}),
 582		}
 583		cfg.setDefaults("/tmp", "")
 584
 585		env := env.NewFromMap(map[string]string{})
 586		resolver := NewEnvironmentVariableResolver(env)
 587		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 588		require.NoError(t, err)
 589
 590		require.Equal(t, cfg.Providers.Len(), 1)
 591		_, exists := cfg.Providers.Get("custom")
 592		require.True(t, exists)
 593	})
 594
 595	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
 596		cfg := &Config{
 597			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 598				"custom": {
 599					APIKey: "test-key",
 600					Models: []catwalk.Model{{
 601						ID: "test-model",
 602					}},
 603				},
 604			}),
 605		}
 606		cfg.setDefaults("/tmp", "")
 607
 608		env := env.NewFromMap(map[string]string{})
 609		resolver := NewEnvironmentVariableResolver(env)
 610		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 611		require.NoError(t, err)
 612
 613		require.Equal(t, cfg.Providers.Len(), 0)
 614		_, exists := cfg.Providers.Get("custom")
 615		require.False(t, exists)
 616	})
 617
 618	t.Run("custom provider with no models is removed", func(t *testing.T) {
 619		cfg := &Config{
 620			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 621				"custom": {
 622					APIKey:  "test-key",
 623					BaseURL: "https://api.custom.com/v1",
 624					Models:  []catwalk.Model{},
 625				},
 626			}),
 627		}
 628		cfg.setDefaults("/tmp", "")
 629
 630		env := env.NewFromMap(map[string]string{})
 631		resolver := NewEnvironmentVariableResolver(env)
 632		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 633		require.NoError(t, err)
 634
 635		require.Equal(t, cfg.Providers.Len(), 0)
 636		_, exists := cfg.Providers.Get("custom")
 637		require.False(t, exists)
 638	})
 639
 640	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
 641		cfg := &Config{
 642			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 643				"custom": {
 644					APIKey:  "test-key",
 645					BaseURL: "https://api.custom.com/v1",
 646					Type:    "unsupported",
 647					Models: []catwalk.Model{{
 648						ID: "test-model",
 649					}},
 650				},
 651			}),
 652		}
 653		cfg.setDefaults("/tmp", "")
 654
 655		env := env.NewFromMap(map[string]string{})
 656		resolver := NewEnvironmentVariableResolver(env)
 657		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 658		require.NoError(t, err)
 659
 660		require.Equal(t, cfg.Providers.Len(), 0)
 661		_, exists := cfg.Providers.Get("custom")
 662		require.False(t, exists)
 663	})
 664
 665	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
 666		cfg := &Config{
 667			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 668				"custom": {
 669					APIKey:  "test-key",
 670					BaseURL: "https://api.custom.com/v1",
 671					Type:    catwalk.TypeOpenAI,
 672					Models: []catwalk.Model{{
 673						ID: "test-model",
 674					}},
 675				},
 676			}),
 677		}
 678		cfg.setDefaults("/tmp", "")
 679
 680		env := env.NewFromMap(map[string]string{})
 681		resolver := NewEnvironmentVariableResolver(env)
 682		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 683		require.NoError(t, err)
 684
 685		require.Equal(t, cfg.Providers.Len(), 1)
 686		customProvider, exists := cfg.Providers.Get("custom")
 687		require.True(t, exists)
 688		require.Equal(t, "custom", customProvider.ID)
 689		require.Equal(t, "test-key", customProvider.APIKey)
 690		require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
 691	})
 692
 693	t.Run("custom anthropic provider is supported", func(t *testing.T) {
 694		cfg := &Config{
 695			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 696				"custom-anthropic": {
 697					APIKey:  "test-key",
 698					BaseURL: "https://api.anthropic.com/v1",
 699					Type:    catwalk.TypeAnthropic,
 700					Models: []catwalk.Model{{
 701						ID: "claude-3-sonnet",
 702					}},
 703				},
 704			}),
 705		}
 706		cfg.setDefaults("/tmp", "")
 707
 708		env := env.NewFromMap(map[string]string{})
 709		resolver := NewEnvironmentVariableResolver(env)
 710		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 711		require.NoError(t, err)
 712
 713		require.Equal(t, cfg.Providers.Len(), 1)
 714		customProvider, exists := cfg.Providers.Get("custom-anthropic")
 715		require.True(t, exists)
 716		require.Equal(t, "custom-anthropic", customProvider.ID)
 717		require.Equal(t, "test-key", customProvider.APIKey)
 718		require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
 719		require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
 720	})
 721
 722	t.Run("disabled custom provider is removed", func(t *testing.T) {
 723		cfg := &Config{
 724			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 725				"custom": {
 726					APIKey:  "test-key",
 727					BaseURL: "https://api.custom.com/v1",
 728					Type:    catwalk.TypeOpenAI,
 729					Disable: true,
 730					Models: []catwalk.Model{{
 731						ID: "test-model",
 732					}},
 733				},
 734			}),
 735		}
 736		cfg.setDefaults("/tmp", "")
 737
 738		env := env.NewFromMap(map[string]string{})
 739		resolver := NewEnvironmentVariableResolver(env)
 740		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 741		require.NoError(t, err)
 742
 743		require.Equal(t, cfg.Providers.Len(), 0)
 744		_, exists := cfg.Providers.Get("custom")
 745		require.False(t, exists)
 746	})
 747}
 748
 749func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 750	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
 751		knownProviders := []catwalk.Provider{
 752			{
 753				ID:          catwalk.InferenceProviderVertexAI,
 754				APIKey:      "",
 755				APIEndpoint: "",
 756				Models: []catwalk.Model{{
 757					ID: "gemini-pro",
 758				}},
 759			},
 760		}
 761
 762		cfg := &Config{
 763			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 764				"vertexai": {
 765					BaseURL: "custom-url",
 766				},
 767			}),
 768		}
 769		cfg.setDefaults("/tmp", "")
 770
 771		env := env.NewFromMap(map[string]string{
 772			"GOOGLE_GENAI_USE_VERTEXAI": "false",
 773		})
 774		resolver := NewEnvironmentVariableResolver(env)
 775		err := cfg.configureProviders(env, resolver, knownProviders)
 776		require.NoError(t, err)
 777
 778		require.Equal(t, cfg.Providers.Len(), 0)
 779		_, exists := cfg.Providers.Get("vertexai")
 780		require.False(t, exists)
 781	})
 782
 783	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
 784		knownProviders := []catwalk.Provider{
 785			{
 786				ID:          catwalk.InferenceProviderBedrock,
 787				APIKey:      "",
 788				APIEndpoint: "",
 789				Models: []catwalk.Model{{
 790					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 791				}},
 792			},
 793		}
 794
 795		cfg := &Config{
 796			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 797				"bedrock": {
 798					BaseURL: "custom-url",
 799				},
 800			}),
 801		}
 802		cfg.setDefaults("/tmp", "")
 803
 804		env := env.NewFromMap(map[string]string{})
 805		resolver := NewEnvironmentVariableResolver(env)
 806		err := cfg.configureProviders(env, resolver, knownProviders)
 807		require.NoError(t, err)
 808
 809		require.Equal(t, cfg.Providers.Len(), 0)
 810		_, exists := cfg.Providers.Get("bedrock")
 811		require.False(t, exists)
 812	})
 813
 814	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
 815		knownProviders := []catwalk.Provider{
 816			{
 817				ID:          "openai",
 818				APIKey:      "$MISSING_API_KEY",
 819				APIEndpoint: "https://api.openai.com/v1",
 820				Models: []catwalk.Model{{
 821					ID: "test-model",
 822				}},
 823			},
 824		}
 825
 826		cfg := &Config{
 827			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 828				"openai": {
 829					BaseURL: "custom-url",
 830				},
 831			}),
 832		}
 833		cfg.setDefaults("/tmp", "")
 834
 835		env := env.NewFromMap(map[string]string{})
 836		resolver := NewEnvironmentVariableResolver(env)
 837		err := cfg.configureProviders(env, resolver, knownProviders)
 838		require.NoError(t, err)
 839
 840		require.Equal(t, cfg.Providers.Len(), 0)
 841		_, exists := cfg.Providers.Get("openai")
 842		require.False(t, exists)
 843	})
 844
 845	t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
 846		knownProviders := []catwalk.Provider{
 847			{
 848				ID:          "openai",
 849				APIKey:      "$OPENAI_API_KEY",
 850				APIEndpoint: "$MISSING_ENDPOINT",
 851				Models: []catwalk.Model{{
 852					ID: "test-model",
 853				}},
 854			},
 855		}
 856
 857		cfg := &Config{
 858			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 859				"openai": {
 860					APIKey: "test-key",
 861				},
 862			}),
 863		}
 864		cfg.setDefaults("/tmp", "")
 865
 866		env := env.NewFromMap(map[string]string{
 867			"OPENAI_API_KEY": "test-key",
 868		})
 869		resolver := NewEnvironmentVariableResolver(env)
 870		err := cfg.configureProviders(env, resolver, knownProviders)
 871		require.NoError(t, err)
 872
 873		require.Equal(t, cfg.Providers.Len(), 1)
 874		_, exists := cfg.Providers.Get("openai")
 875		require.True(t, exists)
 876	})
 877}
 878
 879func TestConfig_defaultModelSelection(t *testing.T) {
 880	t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
 881		knownProviders := []catwalk.Provider{
 882			{
 883				ID:                  "openai",
 884				APIKey:              "abc",
 885				DefaultLargeModelID: "large-model",
 886				DefaultSmallModelID: "small-model",
 887				Models: []catwalk.Model{
 888					{
 889						ID:               "large-model",
 890						DefaultMaxTokens: 1000,
 891					},
 892					{
 893						ID:               "small-model",
 894						DefaultMaxTokens: 500,
 895					},
 896				},
 897			},
 898		}
 899
 900		cfg := &Config{}
 901		cfg.setDefaults("/tmp", "")
 902		env := env.NewFromMap(map[string]string{})
 903		resolver := NewEnvironmentVariableResolver(env)
 904		err := cfg.configureProviders(env, resolver, knownProviders)
 905		require.NoError(t, err)
 906
 907		large, small, err := cfg.defaultModelSelection(knownProviders)
 908		require.NoError(t, err)
 909		require.Equal(t, "large-model", large.Model)
 910		require.Equal(t, "openai", large.Provider)
 911		require.Equal(t, int64(1000), large.MaxTokens)
 912		require.Equal(t, "small-model", small.Model)
 913		require.Equal(t, "openai", small.Provider)
 914		require.Equal(t, int64(500), small.MaxTokens)
 915	})
 916	t.Run("should error if no providers configured", func(t *testing.T) {
 917		knownProviders := []catwalk.Provider{
 918			{
 919				ID:                  "openai",
 920				APIKey:              "$MISSING_KEY",
 921				DefaultLargeModelID: "large-model",
 922				DefaultSmallModelID: "small-model",
 923				Models: []catwalk.Model{
 924					{
 925						ID:               "large-model",
 926						DefaultMaxTokens: 1000,
 927					},
 928					{
 929						ID:               "small-model",
 930						DefaultMaxTokens: 500,
 931					},
 932				},
 933			},
 934		}
 935
 936		cfg := &Config{}
 937		cfg.setDefaults("/tmp", "")
 938		env := env.NewFromMap(map[string]string{})
 939		resolver := NewEnvironmentVariableResolver(env)
 940		err := cfg.configureProviders(env, resolver, knownProviders)
 941		require.NoError(t, err)
 942
 943		_, _, err = cfg.defaultModelSelection(knownProviders)
 944		require.Error(t, err)
 945	})
 946	t.Run("should error if model is missing", func(t *testing.T) {
 947		knownProviders := []catwalk.Provider{
 948			{
 949				ID:                  "openai",
 950				APIKey:              "abc",
 951				DefaultLargeModelID: "large-model",
 952				DefaultSmallModelID: "small-model",
 953				Models: []catwalk.Model{
 954					{
 955						ID:               "not-large-model",
 956						DefaultMaxTokens: 1000,
 957					},
 958					{
 959						ID:               "small-model",
 960						DefaultMaxTokens: 500,
 961					},
 962				},
 963			},
 964		}
 965
 966		cfg := &Config{}
 967		cfg.setDefaults("/tmp", "")
 968		env := env.NewFromMap(map[string]string{})
 969		resolver := NewEnvironmentVariableResolver(env)
 970		err := cfg.configureProviders(env, resolver, knownProviders)
 971		require.NoError(t, err)
 972		_, _, err = cfg.defaultModelSelection(knownProviders)
 973		require.Error(t, err)
 974	})
 975
 976	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
 977		knownProviders := []catwalk.Provider{
 978			{
 979				ID:                  "openai",
 980				APIKey:              "$MISSING", // will not be included in the config
 981				DefaultLargeModelID: "large-model",
 982				DefaultSmallModelID: "small-model",
 983				Models: []catwalk.Model{
 984					{
 985						ID:               "not-large-model",
 986						DefaultMaxTokens: 1000,
 987					},
 988					{
 989						ID:               "small-model",
 990						DefaultMaxTokens: 500,
 991					},
 992				},
 993			},
 994		}
 995
 996		cfg := &Config{
 997			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 998				"custom": {
 999					APIKey:  "test-key",
1000					BaseURL: "https://api.custom.com/v1",
1001					Models: []catwalk.Model{
1002						{
1003							ID:               "model",
1004							DefaultMaxTokens: 600,
1005						},
1006					},
1007				},
1008			}),
1009		}
1010		cfg.setDefaults("/tmp", "")
1011		env := env.NewFromMap(map[string]string{})
1012		resolver := NewEnvironmentVariableResolver(env)
1013		err := cfg.configureProviders(env, resolver, knownProviders)
1014		require.NoError(t, err)
1015		large, small, err := cfg.defaultModelSelection(knownProviders)
1016		require.NoError(t, err)
1017		require.Equal(t, "model", large.Model)
1018		require.Equal(t, "custom", large.Provider)
1019		require.Equal(t, int64(600), large.MaxTokens)
1020		require.Equal(t, "model", small.Model)
1021		require.Equal(t, "custom", small.Provider)
1022		require.Equal(t, int64(600), small.MaxTokens)
1023	})
1024
1025	t.Run("should fail if no model configured", func(t *testing.T) {
1026		knownProviders := []catwalk.Provider{
1027			{
1028				ID:                  "openai",
1029				APIKey:              "$MISSING", // will not be included in the config
1030				DefaultLargeModelID: "large-model",
1031				DefaultSmallModelID: "small-model",
1032				Models: []catwalk.Model{
1033					{
1034						ID:               "not-large-model",
1035						DefaultMaxTokens: 1000,
1036					},
1037					{
1038						ID:               "small-model",
1039						DefaultMaxTokens: 500,
1040					},
1041				},
1042			},
1043		}
1044
1045		cfg := &Config{
1046			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1047				"custom": {
1048					APIKey:  "test-key",
1049					BaseURL: "https://api.custom.com/v1",
1050					Models:  []catwalk.Model{},
1051				},
1052			}),
1053		}
1054		cfg.setDefaults("/tmp", "")
1055		env := env.NewFromMap(map[string]string{})
1056		resolver := NewEnvironmentVariableResolver(env)
1057		err := cfg.configureProviders(env, resolver, knownProviders)
1058		require.NoError(t, err)
1059		_, _, err = cfg.defaultModelSelection(knownProviders)
1060		require.Error(t, err)
1061	})
1062	t.Run("should use the default provider first", func(t *testing.T) {
1063		knownProviders := []catwalk.Provider{
1064			{
1065				ID:                  "openai",
1066				APIKey:              "set",
1067				DefaultLargeModelID: "large-model",
1068				DefaultSmallModelID: "small-model",
1069				Models: []catwalk.Model{
1070					{
1071						ID:               "large-model",
1072						DefaultMaxTokens: 1000,
1073					},
1074					{
1075						ID:               "small-model",
1076						DefaultMaxTokens: 500,
1077					},
1078				},
1079			},
1080		}
1081
1082		cfg := &Config{
1083			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1084				"custom": {
1085					APIKey:  "test-key",
1086					BaseURL: "https://api.custom.com/v1",
1087					Models: []catwalk.Model{
1088						{
1089							ID:               "large-model",
1090							DefaultMaxTokens: 1000,
1091						},
1092					},
1093				},
1094			}),
1095		}
1096		cfg.setDefaults("/tmp", "")
1097		env := env.NewFromMap(map[string]string{})
1098		resolver := NewEnvironmentVariableResolver(env)
1099		err := cfg.configureProviders(env, resolver, knownProviders)
1100		require.NoError(t, err)
1101		large, small, err := cfg.defaultModelSelection(knownProviders)
1102		require.NoError(t, err)
1103		require.Equal(t, "large-model", large.Model)
1104		require.Equal(t, "openai", large.Provider)
1105		require.Equal(t, int64(1000), large.MaxTokens)
1106		require.Equal(t, "small-model", small.Model)
1107		require.Equal(t, "openai", small.Provider)
1108		require.Equal(t, int64(500), small.MaxTokens)
1109	})
1110}
1111
// TestConfig_configureSelectedModels checks how configureSelectedModels merges
// user-selected models with provider defaults. Fields set explicitly in
// Config.Models (model, provider, max tokens) win. Anything left unset, and any
// model type with no entry, falls back to the known provider's defaults.
func TestConfig_configureSelectedModels(t *testing.T) {
	// Selecting only a model ID for the large slot must pick up that model's
	// max tokens. The small slot stays on the provider default.
	t.Run("should override defaults", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "larger-model",
						DefaultMaxTokens: 2000,
					},
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				SelectedModelTypeLarge: {
					Model: "larger-model",
				},
			},
		}
		cfg.setDefaults("/tmp", "")
		// Named testEnv (not env) so the imported env package is not shadowed.
		testEnv := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(testEnv)
		err := cfg.configureProviders(testEnv, resolver, knownProviders)
		require.NoError(t, err)

		err = cfg.configureSelectedModels(knownProviders)
		require.NoError(t, err)
		// Check presence first so a missing entry fails clearly instead of
		// being compared as a zero-value SelectedModel.
		require.Contains(t, cfg.Models, SelectedModelTypeLarge)
		require.Contains(t, cfg.Models, SelectedModelTypeSmall)
		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		require.Equal(t, "larger-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(2000), large.MaxTokens)
		require.Equal(t, "small-model", small.Model)
		require.Equal(t, "openai", small.Provider)
		require.Equal(t, int64(500), small.MaxTokens)
	})

	// The large and small slots may come from different providers. A fully
	// specified small selection (including MaxTokens) is kept as-is.
	t.Run("should be possible to use multiple providers", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
			{
				ID:                  "anthropic",
				APIKey:              "abc",
				DefaultLargeModelID: "a-large-model",
				DefaultSmallModelID: "a-small-model",
				Models: []catwalk.Model{
					{
						ID:               "a-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "a-small-model",
						DefaultMaxTokens: 200,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				SelectedModelTypeSmall: {
					Model:     "a-small-model",
					Provider:  "anthropic",
					MaxTokens: 300,
				},
			},
		}
		cfg.setDefaults("/tmp", "")
		testEnv := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(testEnv)
		err := cfg.configureProviders(testEnv, resolver, knownProviders)
		require.NoError(t, err)

		err = cfg.configureSelectedModels(knownProviders)
		require.NoError(t, err)
		require.Contains(t, cfg.Models, SelectedModelTypeLarge)
		require.Contains(t, cfg.Models, SelectedModelTypeSmall)
		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(1000), large.MaxTokens)
		require.Equal(t, "a-small-model", small.Model)
		require.Equal(t, "anthropic", small.Provider)
		require.Equal(t, int64(300), small.MaxTokens)
	})

	// Setting only MaxTokens for the large slot keeps the default model and
	// provider while overriding the token limit. The small slot is untouched.
	t.Run("should override the max tokens only", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				SelectedModelTypeLarge: {
					MaxTokens: 100,
				},
			},
		}
		cfg.setDefaults("/tmp", "")
		testEnv := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(testEnv)
		err := cfg.configureProviders(testEnv, resolver, knownProviders)
		require.NoError(t, err)

		err = cfg.configureSelectedModels(knownProviders)
		require.NoError(t, err)
		require.Contains(t, cfg.Models, SelectedModelTypeLarge)
		require.Contains(t, cfg.Models, SelectedModelTypeSmall)
		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(100), large.MaxTokens)
		// The large-only override must not leak into the small slot.
		require.Equal(t, "small-model", small.Model)
		require.Equal(t, "openai", small.Provider)
		require.Equal(t, int64(500), small.MaxTokens)
	})
}