load_test.go

   1package config
   2
   3import (
   4	"io"
   5	"log/slog"
   6	"os"
   7	"path/filepath"
   8	"strings"
   9	"testing"
  10
  11	"github.com/charmbracelet/catwalk/pkg/catwalk"
  12	"github.com/charmbracelet/crush/internal/csync"
  13	"github.com/charmbracelet/crush/internal/env"
  14	"github.com/stretchr/testify/assert"
  15	"github.com/stretchr/testify/require"
  16)
  17
  18func TestMain(m *testing.M) {
  19	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
  20
  21	exitVal := m.Run()
  22	os.Exit(exitVal)
  23}
  24
  25func TestConfig_LoadFromReaders(t *testing.T) {
  26	data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
  27	data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
  28	data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
  29
  30	loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
  31
  32	require.NoError(t, err)
  33	require.NotNil(t, loadedConfig)
  34	require.Equal(t, 1, loadedConfig.Providers.Len())
  35	pc, _ := loadedConfig.Providers.Get("openai")
  36	require.Equal(t, "key2", pc.APIKey)
  37	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  38}
  39
  40func TestConfig_setDefaults(t *testing.T) {
  41	cfg := &Config{}
  42
  43	cfg.setDefaults("/tmp", "")
  44
  45	require.NotNil(t, cfg.Options)
  46	require.NotNil(t, cfg.Options.TUI)
  47	require.NotNil(t, cfg.Options.ContextPaths)
  48	require.NotNil(t, cfg.Providers)
  49	require.NotNil(t, cfg.Models)
  50	require.NotNil(t, cfg.LSP)
  51	require.NotNil(t, cfg.MCP)
  52	require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  53	for _, path := range defaultContextPaths {
  54		require.Contains(t, cfg.Options.ContextPaths, path)
  55	}
  56	require.Equal(t, "/tmp", cfg.workingDir)
  57}
  58
  59func TestConfig_configureProviders(t *testing.T) {
  60	knownProviders := []catwalk.Provider{
  61		{
  62			ID:          "openai",
  63			APIKey:      "$OPENAI_API_KEY",
  64			APIEndpoint: "https://api.openai.com/v1",
  65			Models: []catwalk.Model{{
  66				ID: "test-model",
  67			}},
  68		},
  69	}
  70
  71	cfg := &Config{}
  72	cfg.setDefaults("/tmp", "")
  73	env := env.NewFromMap(map[string]string{
  74		"OPENAI_API_KEY": "test-key",
  75	})
  76	resolver := NewEnvironmentVariableResolver(env)
  77	err := cfg.configureProviders(env, resolver, knownProviders)
  78	require.NoError(t, err)
  79	require.Equal(t, 1, cfg.Providers.Len())
  80
  81	// We want to make sure that we keep the configured API key as a placeholder
  82	pc, _ := cfg.Providers.Get("openai")
  83	require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  84}
  85
  86func TestConfig_configureProvidersWithOverride(t *testing.T) {
  87	knownProviders := []catwalk.Provider{
  88		{
  89			ID:          "openai",
  90			APIKey:      "$OPENAI_API_KEY",
  91			APIEndpoint: "https://api.openai.com/v1",
  92			Models: []catwalk.Model{{
  93				ID: "test-model",
  94			}},
  95		},
  96	}
  97
  98	cfg := &Config{
  99		Providers: csync.NewMap[string, ProviderConfig](),
 100	}
 101	cfg.Providers.Set("openai", ProviderConfig{
 102		APIKey:  "xyz",
 103		BaseURL: "https://api.openai.com/v2",
 104		Models: []catwalk.Model{
 105			{
 106				ID:   "test-model",
 107				Name: "Updated",
 108			},
 109			{
 110				ID: "another-model",
 111			},
 112		},
 113	})
 114	cfg.setDefaults("/tmp", "")
 115
 116	env := env.NewFromMap(map[string]string{
 117		"OPENAI_API_KEY": "test-key",
 118	})
 119	resolver := NewEnvironmentVariableResolver(env)
 120	err := cfg.configureProviders(env, resolver, knownProviders)
 121	require.NoError(t, err)
 122	require.Equal(t, 1, cfg.Providers.Len())
 123
 124	// We want to make sure that we keep the configured API key as a placeholder
 125	pc, _ := cfg.Providers.Get("openai")
 126	require.Equal(t, "xyz", pc.APIKey)
 127	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
 128	require.Len(t, pc.Models, 2)
 129	require.Equal(t, "Updated", pc.Models[0].Name)
 130}
 131
 132func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
 133	knownProviders := []catwalk.Provider{
 134		{
 135			ID:          "openai",
 136			APIKey:      "$OPENAI_API_KEY",
 137			APIEndpoint: "https://api.openai.com/v1",
 138			Models: []catwalk.Model{{
 139				ID: "test-model",
 140			}},
 141		},
 142	}
 143
 144	cfg := &Config{
 145		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 146			"custom": {
 147				APIKey:  "xyz",
 148				BaseURL: "https://api.someendpoint.com/v2",
 149				Models: []catwalk.Model{
 150					{
 151						ID: "test-model",
 152					},
 153				},
 154			},
 155		}),
 156	}
 157	cfg.setDefaults("/tmp", "")
 158	env := env.NewFromMap(map[string]string{
 159		"OPENAI_API_KEY": "test-key",
 160	})
 161	resolver := NewEnvironmentVariableResolver(env)
 162	err := cfg.configureProviders(env, resolver, knownProviders)
 163	require.NoError(t, err)
 164	// Should be to because of the env variable
 165	require.Equal(t, cfg.Providers.Len(), 2)
 166
 167	// We want to make sure that we keep the configured API key as a placeholder
 168	pc, _ := cfg.Providers.Get("custom")
 169	require.Equal(t, "xyz", pc.APIKey)
 170	// Make sure we set the ID correctly
 171	require.Equal(t, "custom", pc.ID)
 172	require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
 173	require.Len(t, pc.Models, 1)
 174
 175	_, ok := cfg.Providers.Get("openai")
 176	require.True(t, ok, "OpenAI provider should still be present")
 177}
 178
 179func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
 180	knownProviders := []catwalk.Provider{
 181		{
 182			ID:          catwalk.InferenceProviderBedrock,
 183			APIKey:      "",
 184			APIEndpoint: "",
 185			Models: []catwalk.Model{{
 186				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 187			}},
 188		},
 189	}
 190
 191	cfg := &Config{}
 192	cfg.setDefaults("/tmp", "")
 193	env := env.NewFromMap(map[string]string{
 194		"AWS_ACCESS_KEY_ID":     "test-key-id",
 195		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 196	})
 197	resolver := NewEnvironmentVariableResolver(env)
 198	err := cfg.configureProviders(env, resolver, knownProviders)
 199	require.NoError(t, err)
 200	require.Equal(t, cfg.Providers.Len(), 1)
 201
 202	bedrockProvider, ok := cfg.Providers.Get("bedrock")
 203	require.True(t, ok, "Bedrock provider should be present")
 204	require.Len(t, bedrockProvider.Models, 1)
 205	require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
 206}
 207
 208func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
 209	knownProviders := []catwalk.Provider{
 210		{
 211			ID:          catwalk.InferenceProviderBedrock,
 212			APIKey:      "",
 213			APIEndpoint: "",
 214			Models: []catwalk.Model{{
 215				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 216			}},
 217		},
 218	}
 219
 220	cfg := &Config{}
 221	cfg.setDefaults("/tmp", "")
 222	env := env.NewFromMap(map[string]string{})
 223	resolver := NewEnvironmentVariableResolver(env)
 224	err := cfg.configureProviders(env, resolver, knownProviders)
 225	require.NoError(t, err)
 226	// Provider should not be configured without credentials
 227	require.Equal(t, cfg.Providers.Len(), 0)
 228}
 229
 230func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
 231	knownProviders := []catwalk.Provider{
 232		{
 233			ID:          catwalk.InferenceProviderBedrock,
 234			APIKey:      "",
 235			APIEndpoint: "",
 236			Models: []catwalk.Model{{
 237				ID: "some-random-model",
 238			}},
 239		},
 240	}
 241
 242	cfg := &Config{}
 243	cfg.setDefaults("/tmp", "")
 244	env := env.NewFromMap(map[string]string{
 245		"AWS_ACCESS_KEY_ID":     "test-key-id",
 246		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 247	})
 248	resolver := NewEnvironmentVariableResolver(env)
 249	err := cfg.configureProviders(env, resolver, knownProviders)
 250	require.Error(t, err)
 251}
 252
 253func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
 254	knownProviders := []catwalk.Provider{
 255		{
 256			ID:          catwalk.InferenceProviderVertexAI,
 257			APIKey:      "",
 258			APIEndpoint: "",
 259			Models: []catwalk.Model{{
 260				ID: "gemini-pro",
 261			}},
 262		},
 263	}
 264
 265	cfg := &Config{}
 266	cfg.setDefaults("/tmp", "")
 267	env := env.NewFromMap(map[string]string{
 268		"VERTEXAI_PROJECT":  "test-project",
 269		"VERTEXAI_LOCATION": "us-central1",
 270	})
 271	resolver := NewEnvironmentVariableResolver(env)
 272	err := cfg.configureProviders(env, resolver, knownProviders)
 273	require.NoError(t, err)
 274	require.Equal(t, cfg.Providers.Len(), 1)
 275
 276	vertexProvider, ok := cfg.Providers.Get("vertexai")
 277	require.True(t, ok, "VertexAI provider should be present")
 278	require.Len(t, vertexProvider.Models, 1)
 279	require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
 280	require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
 281	require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
 282}
 283
 284func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
 285	knownProviders := []catwalk.Provider{
 286		{
 287			ID:          catwalk.InferenceProviderVertexAI,
 288			APIKey:      "",
 289			APIEndpoint: "",
 290			Models: []catwalk.Model{{
 291				ID: "gemini-pro",
 292			}},
 293		},
 294	}
 295
 296	cfg := &Config{}
 297	cfg.setDefaults("/tmp", "")
 298	env := env.NewFromMap(map[string]string{
 299		"GOOGLE_GENAI_USE_VERTEXAI": "false",
 300		"GOOGLE_CLOUD_PROJECT":      "test-project",
 301		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 302	})
 303	resolver := NewEnvironmentVariableResolver(env)
 304	err := cfg.configureProviders(env, resolver, knownProviders)
 305	require.NoError(t, err)
 306	// Provider should not be configured without proper credentials
 307	require.Equal(t, cfg.Providers.Len(), 0)
 308}
 309
 310func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
 311	knownProviders := []catwalk.Provider{
 312		{
 313			ID:          catwalk.InferenceProviderVertexAI,
 314			APIKey:      "",
 315			APIEndpoint: "",
 316			Models: []catwalk.Model{{
 317				ID: "gemini-pro",
 318			}},
 319		},
 320	}
 321
 322	cfg := &Config{}
 323	cfg.setDefaults("/tmp", "")
 324	env := env.NewFromMap(map[string]string{
 325		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 326		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 327	})
 328	resolver := NewEnvironmentVariableResolver(env)
 329	err := cfg.configureProviders(env, resolver, knownProviders)
 330	require.NoError(t, err)
 331	// Provider should not be configured without project
 332	require.Equal(t, cfg.Providers.Len(), 0)
 333}
 334
 335func TestConfig_configureProvidersSetProviderID(t *testing.T) {
 336	knownProviders := []catwalk.Provider{
 337		{
 338			ID:          "openai",
 339			APIKey:      "$OPENAI_API_KEY",
 340			APIEndpoint: "https://api.openai.com/v1",
 341			Models: []catwalk.Model{{
 342				ID: "test-model",
 343			}},
 344		},
 345	}
 346
 347	cfg := &Config{}
 348	cfg.setDefaults("/tmp", "")
 349	env := env.NewFromMap(map[string]string{
 350		"OPENAI_API_KEY": "test-key",
 351	})
 352	resolver := NewEnvironmentVariableResolver(env)
 353	err := cfg.configureProviders(env, resolver, knownProviders)
 354	require.NoError(t, err)
 355	require.Equal(t, cfg.Providers.Len(), 1)
 356
 357	// Provider ID should be set
 358	pc, _ := cfg.Providers.Get("openai")
 359	require.Equal(t, "openai", pc.ID)
 360}
 361
 362func TestConfig_EnabledProviders(t *testing.T) {
 363	t.Run("all providers enabled", func(t *testing.T) {
 364		cfg := &Config{
 365			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 366				"openai": {
 367					ID:      "openai",
 368					APIKey:  "key1",
 369					Disable: false,
 370				},
 371				"anthropic": {
 372					ID:      "anthropic",
 373					APIKey:  "key2",
 374					Disable: false,
 375				},
 376			}),
 377		}
 378
 379		enabled := cfg.EnabledProviders()
 380		require.Len(t, enabled, 2)
 381	})
 382
 383	t.Run("some providers disabled", func(t *testing.T) {
 384		cfg := &Config{
 385			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 386				"openai": {
 387					ID:      "openai",
 388					APIKey:  "key1",
 389					Disable: false,
 390				},
 391				"anthropic": {
 392					ID:      "anthropic",
 393					APIKey:  "key2",
 394					Disable: true,
 395				},
 396			}),
 397		}
 398
 399		enabled := cfg.EnabledProviders()
 400		require.Len(t, enabled, 1)
 401		require.Equal(t, "openai", enabled[0].ID)
 402	})
 403
 404	t.Run("empty providers map", func(t *testing.T) {
 405		cfg := &Config{
 406			Providers: csync.NewMap[string, ProviderConfig](),
 407		}
 408
 409		enabled := cfg.EnabledProviders()
 410		require.Len(t, enabled, 0)
 411	})
 412}
 413
 414func TestConfig_IsConfigured(t *testing.T) {
 415	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
 416		cfg := &Config{
 417			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 418				"openai": {
 419					ID:      "openai",
 420					APIKey:  "key1",
 421					Disable: false,
 422				},
 423			}),
 424		}
 425
 426		require.True(t, cfg.IsConfigured())
 427	})
 428
 429	t.Run("returns false when no providers are configured", func(t *testing.T) {
 430		cfg := &Config{
 431			Providers: csync.NewMap[string, ProviderConfig](),
 432		}
 433
 434		require.False(t, cfg.IsConfigured())
 435	})
 436
 437	t.Run("returns false when all providers are disabled", func(t *testing.T) {
 438		cfg := &Config{
 439			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 440				"openai": {
 441					ID:      "openai",
 442					APIKey:  "key1",
 443					Disable: true,
 444				},
 445				"anthropic": {
 446					ID:      "anthropic",
 447					APIKey:  "key2",
 448					Disable: true,
 449				},
 450			}),
 451		}
 452
 453		require.False(t, cfg.IsConfigured())
 454	})
 455}
 456
 457func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
 458	cfg := &Config{
 459		Options: &Options{
 460			DisabledTools: []string{},
 461		},
 462	}
 463
 464	cfg.SetupAgents()
 465	coderAgent, ok := cfg.Agents[AgentCoder]
 466	require.True(t, ok)
 467	assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
 468
 469	taskAgent, ok := cfg.Agents[AgentTask]
 470	require.True(t, ok)
 471	assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 472}
 473
 474func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
 475	cfg := &Config{
 476		Options: &Options{
 477			DisabledTools: []string{
 478				"edit",
 479				"download",
 480				"grep",
 481			},
 482		},
 483	}
 484
 485	cfg.SetupAgents()
 486	coderAgent, ok := cfg.Agents[AgentCoder]
 487	require.True(t, ok)
 488	assert.Equal(t, []string{"agent", "bash", "bash_output", "bash_kill", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)
 489
 490	taskAgent, ok := cfg.Agents[AgentTask]
 491	require.True(t, ok)
 492	assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 493}
 494
 495func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
 496	cfg := &Config{
 497		Options: &Options{
 498			DisabledTools: []string{
 499				"glob",
 500				"grep",
 501				"ls",
 502				"sourcegraph",
 503				"view",
 504			},
 505		},
 506	}
 507
 508	cfg.SetupAgents()
 509	coderAgent, ok := cfg.Agents[AgentCoder]
 510	require.True(t, ok)
 511	assert.Equal(t, []string{"agent", "bash", "bash_output", "bash_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "write"}, coderAgent.AllowedTools)
 512
 513	taskAgent, ok := cfg.Agents[AgentTask]
 514	require.True(t, ok)
 515	assert.Equal(t, []string{}, taskAgent.AllowedTools)
 516}
 517
 518func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
 519	knownProviders := []catwalk.Provider{
 520		{
 521			ID:          "openai",
 522			APIKey:      "$OPENAI_API_KEY",
 523			APIEndpoint: "https://api.openai.com/v1",
 524			Models: []catwalk.Model{{
 525				ID: "test-model",
 526			}},
 527		},
 528	}
 529
 530	cfg := &Config{
 531		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 532			"openai": {
 533				Disable: true,
 534			},
 535		}),
 536	}
 537	cfg.setDefaults("/tmp", "")
 538
 539	env := env.NewFromMap(map[string]string{
 540		"OPENAI_API_KEY": "test-key",
 541	})
 542	resolver := NewEnvironmentVariableResolver(env)
 543	err := cfg.configureProviders(env, resolver, knownProviders)
 544	require.NoError(t, err)
 545
 546	require.Equal(t, cfg.Providers.Len(), 1)
 547	prov, exists := cfg.Providers.Get("openai")
 548	require.True(t, exists)
 549	require.True(t, prov.Disable)
 550}
 551
 552func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 553	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
 554		cfg := &Config{
 555			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 556				"custom": {
 557					BaseURL: "https://api.custom.com/v1",
 558					Models: []catwalk.Model{{
 559						ID: "test-model",
 560					}},
 561				},
 562				"openai": {
 563					APIKey: "$MISSING",
 564				},
 565			}),
 566		}
 567		cfg.setDefaults("/tmp", "")
 568
 569		env := env.NewFromMap(map[string]string{})
 570		resolver := NewEnvironmentVariableResolver(env)
 571		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 572		require.NoError(t, err)
 573
 574		require.Equal(t, cfg.Providers.Len(), 1)
 575		_, exists := cfg.Providers.Get("custom")
 576		require.True(t, exists)
 577	})
 578
 579	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
 580		cfg := &Config{
 581			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 582				"custom": {
 583					APIKey: "test-key",
 584					Models: []catwalk.Model{{
 585						ID: "test-model",
 586					}},
 587				},
 588			}),
 589		}
 590		cfg.setDefaults("/tmp", "")
 591
 592		env := env.NewFromMap(map[string]string{})
 593		resolver := NewEnvironmentVariableResolver(env)
 594		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 595		require.NoError(t, err)
 596
 597		require.Equal(t, cfg.Providers.Len(), 0)
 598		_, exists := cfg.Providers.Get("custom")
 599		require.False(t, exists)
 600	})
 601
 602	t.Run("custom provider with no models is removed", func(t *testing.T) {
 603		cfg := &Config{
 604			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 605				"custom": {
 606					APIKey:  "test-key",
 607					BaseURL: "https://api.custom.com/v1",
 608					Models:  []catwalk.Model{},
 609				},
 610			}),
 611		}
 612		cfg.setDefaults("/tmp", "")
 613
 614		env := env.NewFromMap(map[string]string{})
 615		resolver := NewEnvironmentVariableResolver(env)
 616		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 617		require.NoError(t, err)
 618
 619		require.Equal(t, cfg.Providers.Len(), 0)
 620		_, exists := cfg.Providers.Get("custom")
 621		require.False(t, exists)
 622	})
 623
 624	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
 625		cfg := &Config{
 626			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 627				"custom": {
 628					APIKey:  "test-key",
 629					BaseURL: "https://api.custom.com/v1",
 630					Type:    "unsupported",
 631					Models: []catwalk.Model{{
 632						ID: "test-model",
 633					}},
 634				},
 635			}),
 636		}
 637		cfg.setDefaults("/tmp", "")
 638
 639		env := env.NewFromMap(map[string]string{})
 640		resolver := NewEnvironmentVariableResolver(env)
 641		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 642		require.NoError(t, err)
 643
 644		require.Equal(t, cfg.Providers.Len(), 0)
 645		_, exists := cfg.Providers.Get("custom")
 646		require.False(t, exists)
 647	})
 648
 649	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
 650		cfg := &Config{
 651			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 652				"custom": {
 653					APIKey:  "test-key",
 654					BaseURL: "https://api.custom.com/v1",
 655					Type:    catwalk.TypeOpenAI,
 656					Models: []catwalk.Model{{
 657						ID: "test-model",
 658					}},
 659				},
 660			}),
 661		}
 662		cfg.setDefaults("/tmp", "")
 663
 664		env := env.NewFromMap(map[string]string{})
 665		resolver := NewEnvironmentVariableResolver(env)
 666		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 667		require.NoError(t, err)
 668
 669		require.Equal(t, cfg.Providers.Len(), 1)
 670		customProvider, exists := cfg.Providers.Get("custom")
 671		require.True(t, exists)
 672		require.Equal(t, "custom", customProvider.ID)
 673		require.Equal(t, "test-key", customProvider.APIKey)
 674		require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
 675	})
 676
 677	t.Run("custom anthropic provider is supported", func(t *testing.T) {
 678		cfg := &Config{
 679			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 680				"custom-anthropic": {
 681					APIKey:  "test-key",
 682					BaseURL: "https://api.anthropic.com/v1",
 683					Type:    catwalk.TypeAnthropic,
 684					Models: []catwalk.Model{{
 685						ID: "claude-3-sonnet",
 686					}},
 687				},
 688			}),
 689		}
 690		cfg.setDefaults("/tmp", "")
 691
 692		env := env.NewFromMap(map[string]string{})
 693		resolver := NewEnvironmentVariableResolver(env)
 694		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 695		require.NoError(t, err)
 696
 697		require.Equal(t, cfg.Providers.Len(), 1)
 698		customProvider, exists := cfg.Providers.Get("custom-anthropic")
 699		require.True(t, exists)
 700		require.Equal(t, "custom-anthropic", customProvider.ID)
 701		require.Equal(t, "test-key", customProvider.APIKey)
 702		require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
 703		require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
 704	})
 705
 706	t.Run("disabled custom provider is removed", func(t *testing.T) {
 707		cfg := &Config{
 708			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 709				"custom": {
 710					APIKey:  "test-key",
 711					BaseURL: "https://api.custom.com/v1",
 712					Type:    catwalk.TypeOpenAI,
 713					Disable: true,
 714					Models: []catwalk.Model{{
 715						ID: "test-model",
 716					}},
 717				},
 718			}),
 719		}
 720		cfg.setDefaults("/tmp", "")
 721
 722		env := env.NewFromMap(map[string]string{})
 723		resolver := NewEnvironmentVariableResolver(env)
 724		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 725		require.NoError(t, err)
 726
 727		require.Equal(t, cfg.Providers.Len(), 0)
 728		_, exists := cfg.Providers.Get("custom")
 729		require.False(t, exists)
 730	})
 731}
 732
 733func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 734	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
 735		knownProviders := []catwalk.Provider{
 736			{
 737				ID:          catwalk.InferenceProviderVertexAI,
 738				APIKey:      "",
 739				APIEndpoint: "",
 740				Models: []catwalk.Model{{
 741					ID: "gemini-pro",
 742				}},
 743			},
 744		}
 745
 746		cfg := &Config{
 747			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 748				"vertexai": {
 749					BaseURL: "custom-url",
 750				},
 751			}),
 752		}
 753		cfg.setDefaults("/tmp", "")
 754
 755		env := env.NewFromMap(map[string]string{
 756			"GOOGLE_GENAI_USE_VERTEXAI": "false",
 757		})
 758		resolver := NewEnvironmentVariableResolver(env)
 759		err := cfg.configureProviders(env, resolver, knownProviders)
 760		require.NoError(t, err)
 761
 762		require.Equal(t, cfg.Providers.Len(), 0)
 763		_, exists := cfg.Providers.Get("vertexai")
 764		require.False(t, exists)
 765	})
 766
 767	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
 768		knownProviders := []catwalk.Provider{
 769			{
 770				ID:          catwalk.InferenceProviderBedrock,
 771				APIKey:      "",
 772				APIEndpoint: "",
 773				Models: []catwalk.Model{{
 774					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 775				}},
 776			},
 777		}
 778
 779		cfg := &Config{
 780			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 781				"bedrock": {
 782					BaseURL: "custom-url",
 783				},
 784			}),
 785		}
 786		cfg.setDefaults("/tmp", "")
 787
 788		env := env.NewFromMap(map[string]string{})
 789		resolver := NewEnvironmentVariableResolver(env)
 790		err := cfg.configureProviders(env, resolver, knownProviders)
 791		require.NoError(t, err)
 792
 793		require.Equal(t, cfg.Providers.Len(), 0)
 794		_, exists := cfg.Providers.Get("bedrock")
 795		require.False(t, exists)
 796	})
 797
 798	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
 799		knownProviders := []catwalk.Provider{
 800			{
 801				ID:          "openai",
 802				APIKey:      "$MISSING_API_KEY",
 803				APIEndpoint: "https://api.openai.com/v1",
 804				Models: []catwalk.Model{{
 805					ID: "test-model",
 806				}},
 807			},
 808		}
 809
 810		cfg := &Config{
 811			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 812				"openai": {
 813					BaseURL: "custom-url",
 814				},
 815			}),
 816		}
 817		cfg.setDefaults("/tmp", "")
 818
 819		env := env.NewFromMap(map[string]string{})
 820		resolver := NewEnvironmentVariableResolver(env)
 821		err := cfg.configureProviders(env, resolver, knownProviders)
 822		require.NoError(t, err)
 823
 824		require.Equal(t, cfg.Providers.Len(), 0)
 825		_, exists := cfg.Providers.Get("openai")
 826		require.False(t, exists)
 827	})
 828
 829	t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
 830		knownProviders := []catwalk.Provider{
 831			{
 832				ID:          "openai",
 833				APIKey:      "$OPENAI_API_KEY",
 834				APIEndpoint: "$MISSING_ENDPOINT",
 835				Models: []catwalk.Model{{
 836					ID: "test-model",
 837				}},
 838			},
 839		}
 840
 841		cfg := &Config{
 842			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 843				"openai": {
 844					APIKey: "test-key",
 845				},
 846			}),
 847		}
 848		cfg.setDefaults("/tmp", "")
 849
 850		env := env.NewFromMap(map[string]string{
 851			"OPENAI_API_KEY": "test-key",
 852		})
 853		resolver := NewEnvironmentVariableResolver(env)
 854		err := cfg.configureProviders(env, resolver, knownProviders)
 855		require.NoError(t, err)
 856
 857		require.Equal(t, cfg.Providers.Len(), 1)
 858		_, exists := cfg.Providers.Get("openai")
 859		require.True(t, exists)
 860	})
 861}
 862
 863func TestConfig_defaultModelSelection(t *testing.T) {
 864	t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
 865		knownProviders := []catwalk.Provider{
 866			{
 867				ID:                  "openai",
 868				APIKey:              "abc",
 869				DefaultLargeModelID: "large-model",
 870				DefaultSmallModelID: "small-model",
 871				Models: []catwalk.Model{
 872					{
 873						ID:               "large-model",
 874						DefaultMaxTokens: 1000,
 875					},
 876					{
 877						ID:               "small-model",
 878						DefaultMaxTokens: 500,
 879					},
 880				},
 881			},
 882		}
 883
 884		cfg := &Config{}
 885		cfg.setDefaults("/tmp", "")
 886		env := env.NewFromMap(map[string]string{})
 887		resolver := NewEnvironmentVariableResolver(env)
 888		err := cfg.configureProviders(env, resolver, knownProviders)
 889		require.NoError(t, err)
 890
 891		large, small, err := cfg.defaultModelSelection(knownProviders)
 892		require.NoError(t, err)
 893		require.Equal(t, "large-model", large.Model)
 894		require.Equal(t, "openai", large.Provider)
 895		require.Equal(t, int64(1000), large.MaxTokens)
 896		require.Equal(t, "small-model", small.Model)
 897		require.Equal(t, "openai", small.Provider)
 898		require.Equal(t, int64(500), small.MaxTokens)
 899	})
 900	t.Run("should error if no providers configured", func(t *testing.T) {
 901		knownProviders := []catwalk.Provider{
 902			{
 903				ID:                  "openai",
 904				APIKey:              "$MISSING_KEY",
 905				DefaultLargeModelID: "large-model",
 906				DefaultSmallModelID: "small-model",
 907				Models: []catwalk.Model{
 908					{
 909						ID:               "large-model",
 910						DefaultMaxTokens: 1000,
 911					},
 912					{
 913						ID:               "small-model",
 914						DefaultMaxTokens: 500,
 915					},
 916				},
 917			},
 918		}
 919
 920		cfg := &Config{}
 921		cfg.setDefaults("/tmp", "")
 922		env := env.NewFromMap(map[string]string{})
 923		resolver := NewEnvironmentVariableResolver(env)
 924		err := cfg.configureProviders(env, resolver, knownProviders)
 925		require.NoError(t, err)
 926
 927		_, _, err = cfg.defaultModelSelection(knownProviders)
 928		require.Error(t, err)
 929	})
 930	t.Run("should error if model is missing", func(t *testing.T) {
 931		knownProviders := []catwalk.Provider{
 932			{
 933				ID:                  "openai",
 934				APIKey:              "abc",
 935				DefaultLargeModelID: "large-model",
 936				DefaultSmallModelID: "small-model",
 937				Models: []catwalk.Model{
 938					{
 939						ID:               "not-large-model",
 940						DefaultMaxTokens: 1000,
 941					},
 942					{
 943						ID:               "small-model",
 944						DefaultMaxTokens: 500,
 945					},
 946				},
 947			},
 948		}
 949
 950		cfg := &Config{}
 951		cfg.setDefaults("/tmp", "")
 952		env := env.NewFromMap(map[string]string{})
 953		resolver := NewEnvironmentVariableResolver(env)
 954		err := cfg.configureProviders(env, resolver, knownProviders)
 955		require.NoError(t, err)
 956		_, _, err = cfg.defaultModelSelection(knownProviders)
 957		require.Error(t, err)
 958	})
 959
 960	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
 961		knownProviders := []catwalk.Provider{
 962			{
 963				ID:                  "openai",
 964				APIKey:              "$MISSING", // will not be included in the config
 965				DefaultLargeModelID: "large-model",
 966				DefaultSmallModelID: "small-model",
 967				Models: []catwalk.Model{
 968					{
 969						ID:               "not-large-model",
 970						DefaultMaxTokens: 1000,
 971					},
 972					{
 973						ID:               "small-model",
 974						DefaultMaxTokens: 500,
 975					},
 976				},
 977			},
 978		}
 979
 980		cfg := &Config{
 981			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 982				"custom": {
 983					APIKey:  "test-key",
 984					BaseURL: "https://api.custom.com/v1",
 985					Models: []catwalk.Model{
 986						{
 987							ID:               "model",
 988							DefaultMaxTokens: 600,
 989						},
 990					},
 991				},
 992			}),
 993		}
 994		cfg.setDefaults("/tmp", "")
 995		env := env.NewFromMap(map[string]string{})
 996		resolver := NewEnvironmentVariableResolver(env)
 997		err := cfg.configureProviders(env, resolver, knownProviders)
 998		require.NoError(t, err)
 999		large, small, err := cfg.defaultModelSelection(knownProviders)
1000		require.NoError(t, err)
1001		require.Equal(t, "model", large.Model)
1002		require.Equal(t, "custom", large.Provider)
1003		require.Equal(t, int64(600), large.MaxTokens)
1004		require.Equal(t, "model", small.Model)
1005		require.Equal(t, "custom", small.Provider)
1006		require.Equal(t, int64(600), small.MaxTokens)
1007	})
1008
1009	t.Run("should fail if no model configured", func(t *testing.T) {
1010		knownProviders := []catwalk.Provider{
1011			{
1012				ID:                  "openai",
1013				APIKey:              "$MISSING", // will not be included in the config
1014				DefaultLargeModelID: "large-model",
1015				DefaultSmallModelID: "small-model",
1016				Models: []catwalk.Model{
1017					{
1018						ID:               "not-large-model",
1019						DefaultMaxTokens: 1000,
1020					},
1021					{
1022						ID:               "small-model",
1023						DefaultMaxTokens: 500,
1024					},
1025				},
1026			},
1027		}
1028
1029		cfg := &Config{
1030			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1031				"custom": {
1032					APIKey:  "test-key",
1033					BaseURL: "https://api.custom.com/v1",
1034					Models:  []catwalk.Model{},
1035				},
1036			}),
1037		}
1038		cfg.setDefaults("/tmp", "")
1039		env := env.NewFromMap(map[string]string{})
1040		resolver := NewEnvironmentVariableResolver(env)
1041		err := cfg.configureProviders(env, resolver, knownProviders)
1042		require.NoError(t, err)
1043		_, _, err = cfg.defaultModelSelection(knownProviders)
1044		require.Error(t, err)
1045	})
1046	t.Run("should use the default provider first", func(t *testing.T) {
1047		knownProviders := []catwalk.Provider{
1048			{
1049				ID:                  "openai",
1050				APIKey:              "set",
1051				DefaultLargeModelID: "large-model",
1052				DefaultSmallModelID: "small-model",
1053				Models: []catwalk.Model{
1054					{
1055						ID:               "large-model",
1056						DefaultMaxTokens: 1000,
1057					},
1058					{
1059						ID:               "small-model",
1060						DefaultMaxTokens: 500,
1061					},
1062				},
1063			},
1064		}
1065
1066		cfg := &Config{
1067			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1068				"custom": {
1069					APIKey:  "test-key",
1070					BaseURL: "https://api.custom.com/v1",
1071					Models: []catwalk.Model{
1072						{
1073							ID:               "large-model",
1074							DefaultMaxTokens: 1000,
1075						},
1076					},
1077				},
1078			}),
1079		}
1080		cfg.setDefaults("/tmp", "")
1081		env := env.NewFromMap(map[string]string{})
1082		resolver := NewEnvironmentVariableResolver(env)
1083		err := cfg.configureProviders(env, resolver, knownProviders)
1084		require.NoError(t, err)
1085		large, small, err := cfg.defaultModelSelection(knownProviders)
1086		require.NoError(t, err)
1087		require.Equal(t, "large-model", large.Model)
1088		require.Equal(t, "openai", large.Provider)
1089		require.Equal(t, int64(1000), large.MaxTokens)
1090		require.Equal(t, "small-model", small.Model)
1091		require.Equal(t, "openai", small.Provider)
1092		require.Equal(t, int64(500), small.MaxTokens)
1093	})
1094}
1095
1096func TestConfig_configureSelectedModels(t *testing.T) {
1097	t.Run("should override defaults", func(t *testing.T) {
1098		knownProviders := []catwalk.Provider{
1099			{
1100				ID:                  "openai",
1101				APIKey:              "abc",
1102				DefaultLargeModelID: "large-model",
1103				DefaultSmallModelID: "small-model",
1104				Models: []catwalk.Model{
1105					{
1106						ID:               "larger-model",
1107						DefaultMaxTokens: 2000,
1108					},
1109					{
1110						ID:               "large-model",
1111						DefaultMaxTokens: 1000,
1112					},
1113					{
1114						ID:               "small-model",
1115						DefaultMaxTokens: 500,
1116					},
1117				},
1118			},
1119		}
1120
1121		cfg := &Config{
1122			Models: map[SelectedModelType]SelectedModel{
1123				"large": {
1124					Model: "larger-model",
1125				},
1126			},
1127		}
1128		cfg.setDefaults("/tmp", "")
1129		env := env.NewFromMap(map[string]string{})
1130		resolver := NewEnvironmentVariableResolver(env)
1131		err := cfg.configureProviders(env, resolver, knownProviders)
1132		require.NoError(t, err)
1133
1134		err = cfg.configureSelectedModels(knownProviders)
1135		require.NoError(t, err)
1136		large := cfg.Models[SelectedModelTypeLarge]
1137		small := cfg.Models[SelectedModelTypeSmall]
1138		require.Equal(t, "larger-model", large.Model)
1139		require.Equal(t, "openai", large.Provider)
1140		require.Equal(t, int64(2000), large.MaxTokens)
1141		require.Equal(t, "small-model", small.Model)
1142		require.Equal(t, "openai", small.Provider)
1143		require.Equal(t, int64(500), small.MaxTokens)
1144	})
1145	t.Run("should be possible to use multiple providers", func(t *testing.T) {
1146		knownProviders := []catwalk.Provider{
1147			{
1148				ID:                  "openai",
1149				APIKey:              "abc",
1150				DefaultLargeModelID: "large-model",
1151				DefaultSmallModelID: "small-model",
1152				Models: []catwalk.Model{
1153					{
1154						ID:               "large-model",
1155						DefaultMaxTokens: 1000,
1156					},
1157					{
1158						ID:               "small-model",
1159						DefaultMaxTokens: 500,
1160					},
1161				},
1162			},
1163			{
1164				ID:                  "anthropic",
1165				APIKey:              "abc",
1166				DefaultLargeModelID: "a-large-model",
1167				DefaultSmallModelID: "a-small-model",
1168				Models: []catwalk.Model{
1169					{
1170						ID:               "a-large-model",
1171						DefaultMaxTokens: 1000,
1172					},
1173					{
1174						ID:               "a-small-model",
1175						DefaultMaxTokens: 200,
1176					},
1177				},
1178			},
1179		}
1180
1181		cfg := &Config{
1182			Models: map[SelectedModelType]SelectedModel{
1183				"small": {
1184					Model:     "a-small-model",
1185					Provider:  "anthropic",
1186					MaxTokens: 300,
1187				},
1188			},
1189		}
1190		cfg.setDefaults("/tmp", "")
1191		env := env.NewFromMap(map[string]string{})
1192		resolver := NewEnvironmentVariableResolver(env)
1193		err := cfg.configureProviders(env, resolver, knownProviders)
1194		require.NoError(t, err)
1195
1196		err = cfg.configureSelectedModels(knownProviders)
1197		require.NoError(t, err)
1198		large := cfg.Models[SelectedModelTypeLarge]
1199		small := cfg.Models[SelectedModelTypeSmall]
1200		require.Equal(t, "large-model", large.Model)
1201		require.Equal(t, "openai", large.Provider)
1202		require.Equal(t, int64(1000), large.MaxTokens)
1203		require.Equal(t, "a-small-model", small.Model)
1204		require.Equal(t, "anthropic", small.Provider)
1205		require.Equal(t, int64(300), small.MaxTokens)
1206	})
1207
1208	t.Run("should override the max tokens only", func(t *testing.T) {
1209		knownProviders := []catwalk.Provider{
1210			{
1211				ID:                  "openai",
1212				APIKey:              "abc",
1213				DefaultLargeModelID: "large-model",
1214				DefaultSmallModelID: "small-model",
1215				Models: []catwalk.Model{
1216					{
1217						ID:               "large-model",
1218						DefaultMaxTokens: 1000,
1219					},
1220					{
1221						ID:               "small-model",
1222						DefaultMaxTokens: 500,
1223					},
1224				},
1225			},
1226		}
1227
1228		cfg := &Config{
1229			Models: map[SelectedModelType]SelectedModel{
1230				"large": {
1231					MaxTokens: 100,
1232				},
1233			},
1234		}
1235		cfg.setDefaults("/tmp", "")
1236		env := env.NewFromMap(map[string]string{})
1237		resolver := NewEnvironmentVariableResolver(env)
1238		err := cfg.configureProviders(env, resolver, knownProviders)
1239		require.NoError(t, err)
1240
1241		err = cfg.configureSelectedModels(knownProviders)
1242		require.NoError(t, err)
1243		large := cfg.Models[SelectedModelTypeLarge]
1244		require.Equal(t, "large-model", large.Model)
1245		require.Equal(t, "openai", large.Provider)
1246		require.Equal(t, int64(100), large.MaxTokens)
1247	})
1248}