load_test.go

   1package config
   2
   3import (
   4	"io"
   5	"log/slog"
   6	"os"
   7	"path/filepath"
   8	"strings"
   9	"testing"
  10
  11	"github.com/charmbracelet/catwalk/pkg/catwalk"
  12	"github.com/charmbracelet/crush/internal/csync"
  13	"github.com/charmbracelet/crush/internal/env"
  14	"github.com/stretchr/testify/assert"
  15	"github.com/stretchr/testify/require"
  16)
  17
  18func TestMain(m *testing.M) {
  19	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
  20
  21	exitVal := m.Run()
  22	os.Exit(exitVal)
  23}
  24
  25func TestConfig_LoadFromReaders(t *testing.T) {
  26	data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
  27	data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
  28	data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
  29
  30	loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
  31
  32	require.NoError(t, err)
  33	require.NotNil(t, loadedConfig)
  34	require.Equal(t, 1, loadedConfig.Providers.Len())
  35	pc, _ := loadedConfig.Providers.Get("openai")
  36	require.Equal(t, "key2", pc.APIKey)
  37	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  38}
  39
  40func TestConfig_setDefaults(t *testing.T) {
  41	cfg := &Config{}
  42
  43	cfg.setDefaults("/tmp", "")
  44
  45	require.NotNil(t, cfg.Options)
  46	require.NotNil(t, cfg.Options.TUI)
  47	require.NotNil(t, cfg.Options.ContextPaths)
  48	require.NotNil(t, cfg.Providers)
  49	require.NotNil(t, cfg.Models)
  50	require.NotNil(t, cfg.LSP)
  51	require.NotNil(t, cfg.MCP)
  52	require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  53	for _, path := range defaultContextPaths {
  54		require.Contains(t, cfg.Options.ContextPaths, path)
  55	}
  56	require.Equal(t, "/tmp", cfg.workingDir)
  57}
  58
  59func TestConfig_configureProviders(t *testing.T) {
  60	knownProviders := []catwalk.Provider{
  61		{
  62			ID:          "openai",
  63			APIKey:      "$OPENAI_API_KEY",
  64			APIEndpoint: "https://api.openai.com/v1",
  65			Models: []catwalk.Model{{
  66				ID: "test-model",
  67			}},
  68		},
  69	}
  70
  71	cfg := &Config{}
  72	cfg.setDefaults("/tmp", "")
  73	env := env.NewFromMap(map[string]string{
  74		"OPENAI_API_KEY": "test-key",
  75	})
  76	resolver := NewEnvironmentVariableResolver(env)
  77	err := cfg.configureProviders(env, resolver, knownProviders)
  78	require.NoError(t, err)
  79	require.Equal(t, 1, cfg.Providers.Len())
  80
  81	// We want to make sure that we keep the configured API key as a placeholder
  82	pc, _ := cfg.Providers.Get("openai")
  83	require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  84}
  85
  86func TestConfig_configureProvidersWithOverride(t *testing.T) {
  87	knownProviders := []catwalk.Provider{
  88		{
  89			ID:          "openai",
  90			APIKey:      "$OPENAI_API_KEY",
  91			APIEndpoint: "https://api.openai.com/v1",
  92			Models: []catwalk.Model{{
  93				ID: "test-model",
  94			}},
  95		},
  96	}
  97
  98	cfg := &Config{
  99		Providers: csync.NewMap[string, ProviderConfig](),
 100	}
 101	cfg.Providers.Set("openai", ProviderConfig{
 102		APIKey:  "xyz",
 103		BaseURL: "https://api.openai.com/v2",
 104		Models: []catwalk.Model{
 105			{
 106				ID:   "test-model",
 107				Name: "Updated",
 108			},
 109			{
 110				ID: "another-model",
 111			},
 112		},
 113	})
 114	cfg.setDefaults("/tmp", "")
 115
 116	env := env.NewFromMap(map[string]string{
 117		"OPENAI_API_KEY": "test-key",
 118	})
 119	resolver := NewEnvironmentVariableResolver(env)
 120	err := cfg.configureProviders(env, resolver, knownProviders)
 121	require.NoError(t, err)
 122	require.Equal(t, 1, cfg.Providers.Len())
 123
 124	// We want to make sure that we keep the configured API key as a placeholder
 125	pc, _ := cfg.Providers.Get("openai")
 126	require.Equal(t, "xyz", pc.APIKey)
 127	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
 128	require.Len(t, pc.Models, 2)
 129	require.Equal(t, "Updated", pc.Models[0].Name)
 130}
 131
 132func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
 133	knownProviders := []catwalk.Provider{
 134		{
 135			ID:          "openai",
 136			APIKey:      "$OPENAI_API_KEY",
 137			APIEndpoint: "https://api.openai.com/v1",
 138			Models: []catwalk.Model{{
 139				ID: "test-model",
 140			}},
 141		},
 142	}
 143
 144	cfg := &Config{
 145		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 146			"custom": {
 147				APIKey:  "xyz",
 148				BaseURL: "https://api.someendpoint.com/v2",
 149				Models: []catwalk.Model{
 150					{
 151						ID: "test-model",
 152					},
 153				},
 154			},
 155		}),
 156	}
 157	cfg.setDefaults("/tmp", "")
 158	env := env.NewFromMap(map[string]string{
 159		"OPENAI_API_KEY": "test-key",
 160	})
 161	resolver := NewEnvironmentVariableResolver(env)
 162	err := cfg.configureProviders(env, resolver, knownProviders)
 163	require.NoError(t, err)
 164	// Should be to because of the env variable
 165	require.Equal(t, cfg.Providers.Len(), 2)
 166
 167	// We want to make sure that we keep the configured API key as a placeholder
 168	pc, _ := cfg.Providers.Get("custom")
 169	require.Equal(t, "xyz", pc.APIKey)
 170	// Make sure we set the ID correctly
 171	require.Equal(t, "custom", pc.ID)
 172	require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
 173	require.Len(t, pc.Models, 1)
 174
 175	_, ok := cfg.Providers.Get("openai")
 176	require.True(t, ok, "OpenAI provider should still be present")
 177}
 178
 179func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
 180	knownProviders := []catwalk.Provider{
 181		{
 182			ID:          catwalk.InferenceProviderBedrock,
 183			APIKey:      "",
 184			APIEndpoint: "",
 185			Models: []catwalk.Model{{
 186				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 187			}},
 188		},
 189	}
 190
 191	cfg := &Config{}
 192	cfg.setDefaults("/tmp", "")
 193	env := env.NewFromMap(map[string]string{
 194		"AWS_ACCESS_KEY_ID":     "test-key-id",
 195		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 196	})
 197	resolver := NewEnvironmentVariableResolver(env)
 198	err := cfg.configureProviders(env, resolver, knownProviders)
 199	require.NoError(t, err)
 200	require.Equal(t, cfg.Providers.Len(), 1)
 201
 202	bedrockProvider, ok := cfg.Providers.Get("bedrock")
 203	require.True(t, ok, "Bedrock provider should be present")
 204	require.Len(t, bedrockProvider.Models, 1)
 205	require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
 206}
 207
 208func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
 209	knownProviders := []catwalk.Provider{
 210		{
 211			ID:          catwalk.InferenceProviderBedrock,
 212			APIKey:      "",
 213			APIEndpoint: "",
 214			Models: []catwalk.Model{{
 215				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 216			}},
 217		},
 218	}
 219
 220	cfg := &Config{}
 221	cfg.setDefaults("/tmp", "")
 222	env := env.NewFromMap(map[string]string{})
 223	resolver := NewEnvironmentVariableResolver(env)
 224	err := cfg.configureProviders(env, resolver, knownProviders)
 225	require.NoError(t, err)
 226	// Provider should not be configured without credentials
 227	require.Equal(t, cfg.Providers.Len(), 0)
 228}
 229
 230func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
 231	knownProviders := []catwalk.Provider{
 232		{
 233			ID:          catwalk.InferenceProviderBedrock,
 234			APIKey:      "",
 235			APIEndpoint: "",
 236			Models: []catwalk.Model{{
 237				ID: "some-random-model",
 238			}},
 239		},
 240	}
 241
 242	cfg := &Config{}
 243	cfg.setDefaults("/tmp", "")
 244	env := env.NewFromMap(map[string]string{
 245		"AWS_ACCESS_KEY_ID":     "test-key-id",
 246		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 247	})
 248	resolver := NewEnvironmentVariableResolver(env)
 249	err := cfg.configureProviders(env, resolver, knownProviders)
 250	require.Error(t, err)
 251}
 252
 253func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
 254	knownProviders := []catwalk.Provider{
 255		{
 256			ID:          catwalk.InferenceProviderVertexAI,
 257			APIKey:      "",
 258			APIEndpoint: "",
 259			Models: []catwalk.Model{{
 260				ID: "gemini-pro",
 261			}},
 262		},
 263	}
 264
 265	cfg := &Config{}
 266	cfg.setDefaults("/tmp", "")
 267	env := env.NewFromMap(map[string]string{
 268		"VERTEXAI_PROJECT":  "test-project",
 269		"VERTEXAI_LOCATION": "us-central1",
 270	})
 271	resolver := NewEnvironmentVariableResolver(env)
 272	err := cfg.configureProviders(env, resolver, knownProviders)
 273	require.NoError(t, err)
 274	require.Equal(t, cfg.Providers.Len(), 1)
 275
 276	vertexProvider, ok := cfg.Providers.Get("vertexai")
 277	require.True(t, ok, "VertexAI provider should be present")
 278	require.Len(t, vertexProvider.Models, 1)
 279	require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
 280	require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
 281	require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
 282}
 283
 284func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
 285	knownProviders := []catwalk.Provider{
 286		{
 287			ID:          catwalk.InferenceProviderVertexAI,
 288			APIKey:      "",
 289			APIEndpoint: "",
 290			Models: []catwalk.Model{{
 291				ID: "gemini-pro",
 292			}},
 293		},
 294	}
 295
 296	cfg := &Config{}
 297	cfg.setDefaults("/tmp", "")
 298	env := env.NewFromMap(map[string]string{
 299		"GOOGLE_GENAI_USE_VERTEXAI": "false",
 300		"GOOGLE_CLOUD_PROJECT":      "test-project",
 301		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 302	})
 303	resolver := NewEnvironmentVariableResolver(env)
 304	err := cfg.configureProviders(env, resolver, knownProviders)
 305	require.NoError(t, err)
 306	// Provider should not be configured without proper credentials
 307	require.Equal(t, cfg.Providers.Len(), 0)
 308}
 309
 310func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
 311	knownProviders := []catwalk.Provider{
 312		{
 313			ID:          catwalk.InferenceProviderVertexAI,
 314			APIKey:      "",
 315			APIEndpoint: "",
 316			Models: []catwalk.Model{{
 317				ID: "gemini-pro",
 318			}},
 319		},
 320	}
 321
 322	cfg := &Config{}
 323	cfg.setDefaults("/tmp", "")
 324	env := env.NewFromMap(map[string]string{
 325		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 326		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 327	})
 328	resolver := NewEnvironmentVariableResolver(env)
 329	err := cfg.configureProviders(env, resolver, knownProviders)
 330	require.NoError(t, err)
 331	// Provider should not be configured without project
 332	require.Equal(t, cfg.Providers.Len(), 0)
 333}
 334
 335func TestConfig_configureProvidersSetProviderID(t *testing.T) {
 336	knownProviders := []catwalk.Provider{
 337		{
 338			ID:          "openai",
 339			APIKey:      "$OPENAI_API_KEY",
 340			APIEndpoint: "https://api.openai.com/v1",
 341			Models: []catwalk.Model{{
 342				ID: "test-model",
 343			}},
 344		},
 345	}
 346
 347	cfg := &Config{}
 348	cfg.setDefaults("/tmp", "")
 349	env := env.NewFromMap(map[string]string{
 350		"OPENAI_API_KEY": "test-key",
 351	})
 352	resolver := NewEnvironmentVariableResolver(env)
 353	err := cfg.configureProviders(env, resolver, knownProviders)
 354	require.NoError(t, err)
 355	require.Equal(t, cfg.Providers.Len(), 1)
 356
 357	// Provider ID should be set
 358	pc, _ := cfg.Providers.Get("openai")
 359	require.Equal(t, "openai", pc.ID)
 360}
 361
 362func TestConfig_EnabledProviders(t *testing.T) {
 363	t.Run("all providers enabled", func(t *testing.T) {
 364		cfg := &Config{
 365			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 366				"openai": {
 367					ID:      "openai",
 368					APIKey:  "key1",
 369					Disable: false,
 370				},
 371				"anthropic": {
 372					ID:      "anthropic",
 373					APIKey:  "key2",
 374					Disable: false,
 375				},
 376			}),
 377		}
 378
 379		enabled := cfg.EnabledProviders()
 380		require.Len(t, enabled, 2)
 381	})
 382
 383	t.Run("some providers disabled", func(t *testing.T) {
 384		cfg := &Config{
 385			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 386				"openai": {
 387					ID:      "openai",
 388					APIKey:  "key1",
 389					Disable: false,
 390				},
 391				"anthropic": {
 392					ID:      "anthropic",
 393					APIKey:  "key2",
 394					Disable: true,
 395				},
 396			}),
 397		}
 398
 399		enabled := cfg.EnabledProviders()
 400		require.Len(t, enabled, 1)
 401		require.Equal(t, "openai", enabled[0].ID)
 402	})
 403
 404	t.Run("empty providers map", func(t *testing.T) {
 405		cfg := &Config{
 406			Providers: csync.NewMap[string, ProviderConfig](),
 407		}
 408
 409		enabled := cfg.EnabledProviders()
 410		require.Len(t, enabled, 0)
 411	})
 412}
 413
 414func TestConfig_IsConfigured(t *testing.T) {
 415	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
 416		cfg := &Config{
 417			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 418				"openai": {
 419					ID:      "openai",
 420					APIKey:  "key1",
 421					Disable: false,
 422				},
 423			}),
 424		}
 425
 426		require.True(t, cfg.IsConfigured())
 427	})
 428
 429	t.Run("returns false when no providers are configured", func(t *testing.T) {
 430		cfg := &Config{
 431			Providers: csync.NewMap[string, ProviderConfig](),
 432		}
 433
 434		require.False(t, cfg.IsConfigured())
 435	})
 436
 437	t.Run("returns false when all providers are disabled", func(t *testing.T) {
 438		cfg := &Config{
 439			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 440				"openai": {
 441					ID:      "openai",
 442					APIKey:  "key1",
 443					Disable: true,
 444				},
 445				"anthropic": {
 446					ID:      "anthropic",
 447					APIKey:  "key2",
 448					Disable: true,
 449				},
 450			}),
 451		}
 452
 453		require.False(t, cfg.IsConfigured())
 454	})
 455}
 456
 457func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
 458	cfg := &Config{
 459		Options: &Options{
 460			DisabledTools: []string{},
 461		},
 462	}
 463
 464	cfg.SetupAgents()
 465	coderAgent, ok := cfg.Agents[AgentCoder]
 466	require.True(t, ok)
 467	assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
 468
 469	taskAgent, ok := cfg.Agents[AgentTask]
 470	require.True(t, ok)
 471	assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 472}
 473
 474func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
 475	cfg := &Config{
 476		Options: &Options{
 477			DisabledTools: []string{
 478				"edit",
 479				"download",
 480				"grep",
 481			},
 482		},
 483	}
 484
 485	cfg.SetupAgents()
 486	coderAgent, ok := cfg.Agents[AgentCoder]
 487	require.True(t, ok)
 488
 489	assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)
 490
 491	taskAgent, ok := cfg.Agents[AgentTask]
 492	require.True(t, ok)
 493	assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 494}
 495
 496func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
 497	cfg := &Config{
 498		Options: &Options{
 499			DisabledTools: []string{
 500				"glob",
 501				"grep",
 502				"ls",
 503				"sourcegraph",
 504				"view",
 505			},
 506		},
 507	}
 508
 509	cfg.SetupAgents()
 510	coderAgent, ok := cfg.Agents[AgentCoder]
 511	require.True(t, ok)
 512	assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "write"}, coderAgent.AllowedTools)
 513
 514	taskAgent, ok := cfg.Agents[AgentTask]
 515	require.True(t, ok)
 516	assert.Equal(t, []string{}, taskAgent.AllowedTools)
 517}
 518
 519func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
 520	knownProviders := []catwalk.Provider{
 521		{
 522			ID:          "openai",
 523			APIKey:      "$OPENAI_API_KEY",
 524			APIEndpoint: "https://api.openai.com/v1",
 525			Models: []catwalk.Model{{
 526				ID: "test-model",
 527			}},
 528		},
 529	}
 530
 531	cfg := &Config{
 532		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 533			"openai": {
 534				Disable: true,
 535			},
 536		}),
 537	}
 538	cfg.setDefaults("/tmp", "")
 539
 540	env := env.NewFromMap(map[string]string{
 541		"OPENAI_API_KEY": "test-key",
 542	})
 543	resolver := NewEnvironmentVariableResolver(env)
 544	err := cfg.configureProviders(env, resolver, knownProviders)
 545	require.NoError(t, err)
 546
 547	require.Equal(t, cfg.Providers.Len(), 1)
 548	prov, exists := cfg.Providers.Get("openai")
 549	require.True(t, exists)
 550	require.True(t, prov.Disable)
 551}
 552
 553func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 554	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
 555		cfg := &Config{
 556			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 557				"custom": {
 558					BaseURL: "https://api.custom.com/v1",
 559					Models: []catwalk.Model{{
 560						ID: "test-model",
 561					}},
 562				},
 563				"openai": {
 564					APIKey: "$MISSING",
 565				},
 566			}),
 567		}
 568		cfg.setDefaults("/tmp", "")
 569
 570		env := env.NewFromMap(map[string]string{})
 571		resolver := NewEnvironmentVariableResolver(env)
 572		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 573		require.NoError(t, err)
 574
 575		require.Equal(t, cfg.Providers.Len(), 1)
 576		_, exists := cfg.Providers.Get("custom")
 577		require.True(t, exists)
 578	})
 579
 580	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
 581		cfg := &Config{
 582			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 583				"custom": {
 584					APIKey: "test-key",
 585					Models: []catwalk.Model{{
 586						ID: "test-model",
 587					}},
 588				},
 589			}),
 590		}
 591		cfg.setDefaults("/tmp", "")
 592
 593		env := env.NewFromMap(map[string]string{})
 594		resolver := NewEnvironmentVariableResolver(env)
 595		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 596		require.NoError(t, err)
 597
 598		require.Equal(t, cfg.Providers.Len(), 0)
 599		_, exists := cfg.Providers.Get("custom")
 600		require.False(t, exists)
 601	})
 602
 603	t.Run("custom provider with no models is removed", func(t *testing.T) {
 604		cfg := &Config{
 605			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 606				"custom": {
 607					APIKey:  "test-key",
 608					BaseURL: "https://api.custom.com/v1",
 609					Models:  []catwalk.Model{},
 610				},
 611			}),
 612		}
 613		cfg.setDefaults("/tmp", "")
 614
 615		env := env.NewFromMap(map[string]string{})
 616		resolver := NewEnvironmentVariableResolver(env)
 617		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 618		require.NoError(t, err)
 619
 620		require.Equal(t, cfg.Providers.Len(), 0)
 621		_, exists := cfg.Providers.Get("custom")
 622		require.False(t, exists)
 623	})
 624
 625	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
 626		cfg := &Config{
 627			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 628				"custom": {
 629					APIKey:  "test-key",
 630					BaseURL: "https://api.custom.com/v1",
 631					Type:    "unsupported",
 632					Models: []catwalk.Model{{
 633						ID: "test-model",
 634					}},
 635				},
 636			}),
 637		}
 638		cfg.setDefaults("/tmp", "")
 639
 640		env := env.NewFromMap(map[string]string{})
 641		resolver := NewEnvironmentVariableResolver(env)
 642		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 643		require.NoError(t, err)
 644
 645		require.Equal(t, cfg.Providers.Len(), 0)
 646		_, exists := cfg.Providers.Get("custom")
 647		require.False(t, exists)
 648	})
 649
 650	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
 651		cfg := &Config{
 652			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 653				"custom": {
 654					APIKey:  "test-key",
 655					BaseURL: "https://api.custom.com/v1",
 656					Type:    catwalk.TypeOpenAI,
 657					Models: []catwalk.Model{{
 658						ID: "test-model",
 659					}},
 660				},
 661			}),
 662		}
 663		cfg.setDefaults("/tmp", "")
 664
 665		env := env.NewFromMap(map[string]string{})
 666		resolver := NewEnvironmentVariableResolver(env)
 667		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 668		require.NoError(t, err)
 669
 670		require.Equal(t, cfg.Providers.Len(), 1)
 671		customProvider, exists := cfg.Providers.Get("custom")
 672		require.True(t, exists)
 673		require.Equal(t, "custom", customProvider.ID)
 674		require.Equal(t, "test-key", customProvider.APIKey)
 675		require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
 676	})
 677
 678	t.Run("custom anthropic provider is supported", func(t *testing.T) {
 679		cfg := &Config{
 680			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 681				"custom-anthropic": {
 682					APIKey:  "test-key",
 683					BaseURL: "https://api.anthropic.com/v1",
 684					Type:    catwalk.TypeAnthropic,
 685					Models: []catwalk.Model{{
 686						ID: "claude-3-sonnet",
 687					}},
 688				},
 689			}),
 690		}
 691		cfg.setDefaults("/tmp", "")
 692
 693		env := env.NewFromMap(map[string]string{})
 694		resolver := NewEnvironmentVariableResolver(env)
 695		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 696		require.NoError(t, err)
 697
 698		require.Equal(t, cfg.Providers.Len(), 1)
 699		customProvider, exists := cfg.Providers.Get("custom-anthropic")
 700		require.True(t, exists)
 701		require.Equal(t, "custom-anthropic", customProvider.ID)
 702		require.Equal(t, "test-key", customProvider.APIKey)
 703		require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
 704		require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
 705	})
 706
 707	t.Run("disabled custom provider is removed", func(t *testing.T) {
 708		cfg := &Config{
 709			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 710				"custom": {
 711					APIKey:  "test-key",
 712					BaseURL: "https://api.custom.com/v1",
 713					Type:    catwalk.TypeOpenAI,
 714					Disable: true,
 715					Models: []catwalk.Model{{
 716						ID: "test-model",
 717					}},
 718				},
 719			}),
 720		}
 721		cfg.setDefaults("/tmp", "")
 722
 723		env := env.NewFromMap(map[string]string{})
 724		resolver := NewEnvironmentVariableResolver(env)
 725		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 726		require.NoError(t, err)
 727
 728		require.Equal(t, cfg.Providers.Len(), 0)
 729		_, exists := cfg.Providers.Get("custom")
 730		require.False(t, exists)
 731	})
 732}
 733
 734func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 735	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
 736		knownProviders := []catwalk.Provider{
 737			{
 738				ID:          catwalk.InferenceProviderVertexAI,
 739				APIKey:      "",
 740				APIEndpoint: "",
 741				Models: []catwalk.Model{{
 742					ID: "gemini-pro",
 743				}},
 744			},
 745		}
 746
 747		cfg := &Config{
 748			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 749				"vertexai": {
 750					BaseURL: "custom-url",
 751				},
 752			}),
 753		}
 754		cfg.setDefaults("/tmp", "")
 755
 756		env := env.NewFromMap(map[string]string{
 757			"GOOGLE_GENAI_USE_VERTEXAI": "false",
 758		})
 759		resolver := NewEnvironmentVariableResolver(env)
 760		err := cfg.configureProviders(env, resolver, knownProviders)
 761		require.NoError(t, err)
 762
 763		require.Equal(t, cfg.Providers.Len(), 0)
 764		_, exists := cfg.Providers.Get("vertexai")
 765		require.False(t, exists)
 766	})
 767
 768	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
 769		knownProviders := []catwalk.Provider{
 770			{
 771				ID:          catwalk.InferenceProviderBedrock,
 772				APIKey:      "",
 773				APIEndpoint: "",
 774				Models: []catwalk.Model{{
 775					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 776				}},
 777			},
 778		}
 779
 780		cfg := &Config{
 781			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 782				"bedrock": {
 783					BaseURL: "custom-url",
 784				},
 785			}),
 786		}
 787		cfg.setDefaults("/tmp", "")
 788
 789		env := env.NewFromMap(map[string]string{})
 790		resolver := NewEnvironmentVariableResolver(env)
 791		err := cfg.configureProviders(env, resolver, knownProviders)
 792		require.NoError(t, err)
 793
 794		require.Equal(t, cfg.Providers.Len(), 0)
 795		_, exists := cfg.Providers.Get("bedrock")
 796		require.False(t, exists)
 797	})
 798
 799	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
 800		knownProviders := []catwalk.Provider{
 801			{
 802				ID:          "openai",
 803				APIKey:      "$MISSING_API_KEY",
 804				APIEndpoint: "https://api.openai.com/v1",
 805				Models: []catwalk.Model{{
 806					ID: "test-model",
 807				}},
 808			},
 809		}
 810
 811		cfg := &Config{
 812			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 813				"openai": {
 814					BaseURL: "custom-url",
 815				},
 816			}),
 817		}
 818		cfg.setDefaults("/tmp", "")
 819
 820		env := env.NewFromMap(map[string]string{})
 821		resolver := NewEnvironmentVariableResolver(env)
 822		err := cfg.configureProviders(env, resolver, knownProviders)
 823		require.NoError(t, err)
 824
 825		require.Equal(t, cfg.Providers.Len(), 0)
 826		_, exists := cfg.Providers.Get("openai")
 827		require.False(t, exists)
 828	})
 829
 830	t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
 831		knownProviders := []catwalk.Provider{
 832			{
 833				ID:          "openai",
 834				APIKey:      "$OPENAI_API_KEY",
 835				APIEndpoint: "$MISSING_ENDPOINT",
 836				Models: []catwalk.Model{{
 837					ID: "test-model",
 838				}},
 839			},
 840		}
 841
 842		cfg := &Config{
 843			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 844				"openai": {
 845					APIKey: "test-key",
 846				},
 847			}),
 848		}
 849		cfg.setDefaults("/tmp", "")
 850
 851		env := env.NewFromMap(map[string]string{
 852			"OPENAI_API_KEY": "test-key",
 853		})
 854		resolver := NewEnvironmentVariableResolver(env)
 855		err := cfg.configureProviders(env, resolver, knownProviders)
 856		require.NoError(t, err)
 857
 858		require.Equal(t, cfg.Providers.Len(), 1)
 859		_, exists := cfg.Providers.Get("openai")
 860		require.True(t, exists)
 861	})
 862}
 863
 864func TestConfig_defaultModelSelection(t *testing.T) {
 865	t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
 866		knownProviders := []catwalk.Provider{
 867			{
 868				ID:                  "openai",
 869				APIKey:              "abc",
 870				DefaultLargeModelID: "large-model",
 871				DefaultSmallModelID: "small-model",
 872				Models: []catwalk.Model{
 873					{
 874						ID:               "large-model",
 875						DefaultMaxTokens: 1000,
 876					},
 877					{
 878						ID:               "small-model",
 879						DefaultMaxTokens: 500,
 880					},
 881				},
 882			},
 883		}
 884
 885		cfg := &Config{}
 886		cfg.setDefaults("/tmp", "")
 887		env := env.NewFromMap(map[string]string{})
 888		resolver := NewEnvironmentVariableResolver(env)
 889		err := cfg.configureProviders(env, resolver, knownProviders)
 890		require.NoError(t, err)
 891
 892		large, small, err := cfg.defaultModelSelection(knownProviders)
 893		require.NoError(t, err)
 894		require.Equal(t, "large-model", large.Model)
 895		require.Equal(t, "openai", large.Provider)
 896		require.Equal(t, int64(1000), large.MaxTokens)
 897		require.Equal(t, "small-model", small.Model)
 898		require.Equal(t, "openai", small.Provider)
 899		require.Equal(t, int64(500), small.MaxTokens)
 900	})
 901	t.Run("should error if no providers configured", func(t *testing.T) {
 902		knownProviders := []catwalk.Provider{
 903			{
 904				ID:                  "openai",
 905				APIKey:              "$MISSING_KEY",
 906				DefaultLargeModelID: "large-model",
 907				DefaultSmallModelID: "small-model",
 908				Models: []catwalk.Model{
 909					{
 910						ID:               "large-model",
 911						DefaultMaxTokens: 1000,
 912					},
 913					{
 914						ID:               "small-model",
 915						DefaultMaxTokens: 500,
 916					},
 917				},
 918			},
 919		}
 920
 921		cfg := &Config{}
 922		cfg.setDefaults("/tmp", "")
 923		env := env.NewFromMap(map[string]string{})
 924		resolver := NewEnvironmentVariableResolver(env)
 925		err := cfg.configureProviders(env, resolver, knownProviders)
 926		require.NoError(t, err)
 927
 928		_, _, err = cfg.defaultModelSelection(knownProviders)
 929		require.Error(t, err)
 930	})
 931	t.Run("should error if model is missing", func(t *testing.T) {
 932		knownProviders := []catwalk.Provider{
 933			{
 934				ID:                  "openai",
 935				APIKey:              "abc",
 936				DefaultLargeModelID: "large-model",
 937				DefaultSmallModelID: "small-model",
 938				Models: []catwalk.Model{
 939					{
 940						ID:               "not-large-model",
 941						DefaultMaxTokens: 1000,
 942					},
 943					{
 944						ID:               "small-model",
 945						DefaultMaxTokens: 500,
 946					},
 947				},
 948			},
 949		}
 950
 951		cfg := &Config{}
 952		cfg.setDefaults("/tmp", "")
 953		env := env.NewFromMap(map[string]string{})
 954		resolver := NewEnvironmentVariableResolver(env)
 955		err := cfg.configureProviders(env, resolver, knownProviders)
 956		require.NoError(t, err)
 957		_, _, err = cfg.defaultModelSelection(knownProviders)
 958		require.Error(t, err)
 959	})
 960
 961	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
 962		knownProviders := []catwalk.Provider{
 963			{
 964				ID:                  "openai",
 965				APIKey:              "$MISSING", // will not be included in the config
 966				DefaultLargeModelID: "large-model",
 967				DefaultSmallModelID: "small-model",
 968				Models: []catwalk.Model{
 969					{
 970						ID:               "not-large-model",
 971						DefaultMaxTokens: 1000,
 972					},
 973					{
 974						ID:               "small-model",
 975						DefaultMaxTokens: 500,
 976					},
 977				},
 978			},
 979		}
 980
 981		cfg := &Config{
 982			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 983				"custom": {
 984					APIKey:  "test-key",
 985					BaseURL: "https://api.custom.com/v1",
 986					Models: []catwalk.Model{
 987						{
 988							ID:               "model",
 989							DefaultMaxTokens: 600,
 990						},
 991					},
 992				},
 993			}),
 994		}
 995		cfg.setDefaults("/tmp", "")
 996		env := env.NewFromMap(map[string]string{})
 997		resolver := NewEnvironmentVariableResolver(env)
 998		err := cfg.configureProviders(env, resolver, knownProviders)
 999		require.NoError(t, err)
1000		large, small, err := cfg.defaultModelSelection(knownProviders)
1001		require.NoError(t, err)
1002		require.Equal(t, "model", large.Model)
1003		require.Equal(t, "custom", large.Provider)
1004		require.Equal(t, int64(600), large.MaxTokens)
1005		require.Equal(t, "model", small.Model)
1006		require.Equal(t, "custom", small.Provider)
1007		require.Equal(t, int64(600), small.MaxTokens)
1008	})
1009
1010	t.Run("should fail if no model configured", func(t *testing.T) {
1011		knownProviders := []catwalk.Provider{
1012			{
1013				ID:                  "openai",
1014				APIKey:              "$MISSING", // will not be included in the config
1015				DefaultLargeModelID: "large-model",
1016				DefaultSmallModelID: "small-model",
1017				Models: []catwalk.Model{
1018					{
1019						ID:               "not-large-model",
1020						DefaultMaxTokens: 1000,
1021					},
1022					{
1023						ID:               "small-model",
1024						DefaultMaxTokens: 500,
1025					},
1026				},
1027			},
1028		}
1029
1030		cfg := &Config{
1031			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1032				"custom": {
1033					APIKey:  "test-key",
1034					BaseURL: "https://api.custom.com/v1",
1035					Models:  []catwalk.Model{},
1036				},
1037			}),
1038		}
1039		cfg.setDefaults("/tmp", "")
1040		env := env.NewFromMap(map[string]string{})
1041		resolver := NewEnvironmentVariableResolver(env)
1042		err := cfg.configureProviders(env, resolver, knownProviders)
1043		require.NoError(t, err)
1044		_, _, err = cfg.defaultModelSelection(knownProviders)
1045		require.Error(t, err)
1046	})
1047	t.Run("should use the default provider first", func(t *testing.T) {
1048		knownProviders := []catwalk.Provider{
1049			{
1050				ID:                  "openai",
1051				APIKey:              "set",
1052				DefaultLargeModelID: "large-model",
1053				DefaultSmallModelID: "small-model",
1054				Models: []catwalk.Model{
1055					{
1056						ID:               "large-model",
1057						DefaultMaxTokens: 1000,
1058					},
1059					{
1060						ID:               "small-model",
1061						DefaultMaxTokens: 500,
1062					},
1063				},
1064			},
1065		}
1066
1067		cfg := &Config{
1068			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1069				"custom": {
1070					APIKey:  "test-key",
1071					BaseURL: "https://api.custom.com/v1",
1072					Models: []catwalk.Model{
1073						{
1074							ID:               "large-model",
1075							DefaultMaxTokens: 1000,
1076						},
1077					},
1078				},
1079			}),
1080		}
1081		cfg.setDefaults("/tmp", "")
1082		env := env.NewFromMap(map[string]string{})
1083		resolver := NewEnvironmentVariableResolver(env)
1084		err := cfg.configureProviders(env, resolver, knownProviders)
1085		require.NoError(t, err)
1086		large, small, err := cfg.defaultModelSelection(knownProviders)
1087		require.NoError(t, err)
1088		require.Equal(t, "large-model", large.Model)
1089		require.Equal(t, "openai", large.Provider)
1090		require.Equal(t, int64(1000), large.MaxTokens)
1091		require.Equal(t, "small-model", small.Model)
1092		require.Equal(t, "openai", small.Provider)
1093		require.Equal(t, int64(500), small.MaxTokens)
1094	})
1095}
1096
1097func TestConfig_configureSelectedModels(t *testing.T) {
1098	t.Run("should override defaults", func(t *testing.T) {
1099		knownProviders := []catwalk.Provider{
1100			{
1101				ID:                  "openai",
1102				APIKey:              "abc",
1103				DefaultLargeModelID: "large-model",
1104				DefaultSmallModelID: "small-model",
1105				Models: []catwalk.Model{
1106					{
1107						ID:               "larger-model",
1108						DefaultMaxTokens: 2000,
1109					},
1110					{
1111						ID:               "large-model",
1112						DefaultMaxTokens: 1000,
1113					},
1114					{
1115						ID:               "small-model",
1116						DefaultMaxTokens: 500,
1117					},
1118				},
1119			},
1120		}
1121
1122		cfg := &Config{
1123			Models: map[SelectedModelType]SelectedModel{
1124				"large": {
1125					Model: "larger-model",
1126				},
1127			},
1128		}
1129		cfg.setDefaults("/tmp", "")
1130		env := env.NewFromMap(map[string]string{})
1131		resolver := NewEnvironmentVariableResolver(env)
1132		err := cfg.configureProviders(env, resolver, knownProviders)
1133		require.NoError(t, err)
1134
1135		err = cfg.configureSelectedModels(knownProviders)
1136		require.NoError(t, err)
1137		large := cfg.Models[SelectedModelTypeLarge]
1138		small := cfg.Models[SelectedModelTypeSmall]
1139		require.Equal(t, "larger-model", large.Model)
1140		require.Equal(t, "openai", large.Provider)
1141		require.Equal(t, int64(2000), large.MaxTokens)
1142		require.Equal(t, "small-model", small.Model)
1143		require.Equal(t, "openai", small.Provider)
1144		require.Equal(t, int64(500), small.MaxTokens)
1145	})
1146	t.Run("should be possible to use multiple providers", func(t *testing.T) {
1147		knownProviders := []catwalk.Provider{
1148			{
1149				ID:                  "openai",
1150				APIKey:              "abc",
1151				DefaultLargeModelID: "large-model",
1152				DefaultSmallModelID: "small-model",
1153				Models: []catwalk.Model{
1154					{
1155						ID:               "large-model",
1156						DefaultMaxTokens: 1000,
1157					},
1158					{
1159						ID:               "small-model",
1160						DefaultMaxTokens: 500,
1161					},
1162				},
1163			},
1164			{
1165				ID:                  "anthropic",
1166				APIKey:              "abc",
1167				DefaultLargeModelID: "a-large-model",
1168				DefaultSmallModelID: "a-small-model",
1169				Models: []catwalk.Model{
1170					{
1171						ID:               "a-large-model",
1172						DefaultMaxTokens: 1000,
1173					},
1174					{
1175						ID:               "a-small-model",
1176						DefaultMaxTokens: 200,
1177					},
1178				},
1179			},
1180		}
1181
1182		cfg := &Config{
1183			Models: map[SelectedModelType]SelectedModel{
1184				"small": {
1185					Model:     "a-small-model",
1186					Provider:  "anthropic",
1187					MaxTokens: 300,
1188				},
1189			},
1190		}
1191		cfg.setDefaults("/tmp", "")
1192		env := env.NewFromMap(map[string]string{})
1193		resolver := NewEnvironmentVariableResolver(env)
1194		err := cfg.configureProviders(env, resolver, knownProviders)
1195		require.NoError(t, err)
1196
1197		err = cfg.configureSelectedModels(knownProviders)
1198		require.NoError(t, err)
1199		large := cfg.Models[SelectedModelTypeLarge]
1200		small := cfg.Models[SelectedModelTypeSmall]
1201		require.Equal(t, "large-model", large.Model)
1202		require.Equal(t, "openai", large.Provider)
1203		require.Equal(t, int64(1000), large.MaxTokens)
1204		require.Equal(t, "a-small-model", small.Model)
1205		require.Equal(t, "anthropic", small.Provider)
1206		require.Equal(t, int64(300), small.MaxTokens)
1207	})
1208
1209	t.Run("should override the max tokens only", func(t *testing.T) {
1210		knownProviders := []catwalk.Provider{
1211			{
1212				ID:                  "openai",
1213				APIKey:              "abc",
1214				DefaultLargeModelID: "large-model",
1215				DefaultSmallModelID: "small-model",
1216				Models: []catwalk.Model{
1217					{
1218						ID:               "large-model",
1219						DefaultMaxTokens: 1000,
1220					},
1221					{
1222						ID:               "small-model",
1223						DefaultMaxTokens: 500,
1224					},
1225				},
1226			},
1227		}
1228
1229		cfg := &Config{
1230			Models: map[SelectedModelType]SelectedModel{
1231				"large": {
1232					MaxTokens: 100,
1233				},
1234			},
1235		}
1236		cfg.setDefaults("/tmp", "")
1237		env := env.NewFromMap(map[string]string{})
1238		resolver := NewEnvironmentVariableResolver(env)
1239		err := cfg.configureProviders(env, resolver, knownProviders)
1240		require.NoError(t, err)
1241
1242		err = cfg.configureSelectedModels(knownProviders)
1243		require.NoError(t, err)
1244		large := cfg.Models[SelectedModelTypeLarge]
1245		require.Equal(t, "large-model", large.Model)
1246		require.Equal(t, "openai", large.Provider)
1247		require.Equal(t, int64(100), large.MaxTokens)
1248	})
1249}