load_test.go

   1package config
   2
   3import (
   4	"io"
   5	"log/slog"
   6	"os"
   7	"path/filepath"
   8	"testing"
   9
  10	"github.com/charmbracelet/catwalk/pkg/catwalk"
  11	"github.com/charmbracelet/crush/internal/csync"
  12	"github.com/charmbracelet/crush/internal/env"
  13	"github.com/stretchr/testify/assert"
  14	"github.com/stretchr/testify/require"
  15)
  16
  17func TestMain(m *testing.M) {
  18	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
  19
  20	exitVal := m.Run()
  21	os.Exit(exitVal)
  22}
  23
  24func TestConfig_setDefaults(t *testing.T) {
  25	cfg := &Config{}
  26
  27	cfg.setDefaults("/tmp", "")
  28
  29	require.NotNil(t, cfg.Options)
  30	require.NotNil(t, cfg.Options.TUI)
  31	require.NotNil(t, cfg.Options.ContextPaths)
  32	require.NotNil(t, cfg.Providers)
  33	require.NotNil(t, cfg.Models)
  34	require.NotNil(t, cfg.LSP)
  35	require.NotNil(t, cfg.MCP)
  36	require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  37	for _, path := range defaultContextPaths {
  38		require.Contains(t, cfg.Options.ContextPaths, path)
  39	}
  40	require.Equal(t, "/tmp", cfg.workingDir)
  41}
  42
  43func TestConfig_configureProviders(t *testing.T) {
  44	knownProviders := []catwalk.Provider{
  45		{
  46			ID:          "openai",
  47			APIKey:      "$OPENAI_API_KEY",
  48			APIEndpoint: "https://api.openai.com/v1",
  49			Models: []catwalk.Model{{
  50				ID: "test-model",
  51			}},
  52		},
  53	}
  54
  55	cfg := &Config{}
  56	cfg.setDefaults("/tmp", "")
  57	env := env.NewFromMap(map[string]string{
  58		"OPENAI_API_KEY": "test-key",
  59	})
  60	resolver := NewEnvironmentVariableResolver(env)
  61	err := cfg.configureProviders(env, resolver, knownProviders)
  62	require.NoError(t, err)
  63	require.Equal(t, 1, cfg.Providers.Len())
  64
  65	// We want to make sure that we keep the configured API key as a placeholder
  66	pc, _ := cfg.Providers.Get("openai")
  67	require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  68}
  69
  70func TestConfig_configureProvidersWithOverride(t *testing.T) {
  71	knownProviders := []catwalk.Provider{
  72		{
  73			ID:          "openai",
  74			APIKey:      "$OPENAI_API_KEY",
  75			APIEndpoint: "https://api.openai.com/v1",
  76			Models: []catwalk.Model{{
  77				ID: "test-model",
  78			}},
  79		},
  80	}
  81
  82	cfg := &Config{
  83		Providers: csync.NewMap[string, ProviderConfig](),
  84	}
  85	cfg.Providers.Set("openai", ProviderConfig{
  86		APIKey:  "xyz",
  87		BaseURL: "https://api.openai.com/v2",
  88		Models: []catwalk.Model{
  89			{
  90				ID:   "test-model",
  91				Name: "Updated",
  92			},
  93			{
  94				ID: "another-model",
  95			},
  96		},
  97	})
  98	cfg.setDefaults("/tmp", "")
  99
 100	env := env.NewFromMap(map[string]string{
 101		"OPENAI_API_KEY": "test-key",
 102	})
 103	resolver := NewEnvironmentVariableResolver(env)
 104	err := cfg.configureProviders(env, resolver, knownProviders)
 105	require.NoError(t, err)
 106	require.Equal(t, 1, cfg.Providers.Len())
 107
 108	// We want to make sure that we keep the configured API key as a placeholder
 109	pc, _ := cfg.Providers.Get("openai")
 110	require.Equal(t, "xyz", pc.APIKey)
 111	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
 112	require.Len(t, pc.Models, 2)
 113	require.Equal(t, "Updated", pc.Models[0].Name)
 114}
 115
 116func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
 117	knownProviders := []catwalk.Provider{
 118		{
 119			ID:          "openai",
 120			APIKey:      "$OPENAI_API_KEY",
 121			APIEndpoint: "https://api.openai.com/v1",
 122			Models: []catwalk.Model{{
 123				ID: "test-model",
 124			}},
 125		},
 126	}
 127
 128	cfg := &Config{
 129		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 130			"custom": {
 131				APIKey:  "xyz",
 132				BaseURL: "https://api.someendpoint.com/v2",
 133				Models: []catwalk.Model{
 134					{
 135						ID: "test-model",
 136					},
 137				},
 138			},
 139		}),
 140	}
 141	cfg.setDefaults("/tmp", "")
 142	env := env.NewFromMap(map[string]string{
 143		"OPENAI_API_KEY": "test-key",
 144	})
 145	resolver := NewEnvironmentVariableResolver(env)
 146	err := cfg.configureProviders(env, resolver, knownProviders)
 147	require.NoError(t, err)
 148	// Should be to because of the env variable
 149	require.Equal(t, cfg.Providers.Len(), 2)
 150
 151	// We want to make sure that we keep the configured API key as a placeholder
 152	pc, _ := cfg.Providers.Get("custom")
 153	require.Equal(t, "xyz", pc.APIKey)
 154	// Make sure we set the ID correctly
 155	require.Equal(t, "custom", pc.ID)
 156	require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
 157	require.Len(t, pc.Models, 1)
 158
 159	_, ok := cfg.Providers.Get("openai")
 160	require.True(t, ok, "OpenAI provider should still be present")
 161}
 162
 163func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
 164	knownProviders := []catwalk.Provider{
 165		{
 166			ID:          catwalk.InferenceProviderBedrock,
 167			APIKey:      "",
 168			APIEndpoint: "",
 169			Models: []catwalk.Model{{
 170				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 171			}},
 172		},
 173	}
 174
 175	cfg := &Config{}
 176	cfg.setDefaults("/tmp", "")
 177	env := env.NewFromMap(map[string]string{
 178		"AWS_ACCESS_KEY_ID":     "test-key-id",
 179		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 180	})
 181	resolver := NewEnvironmentVariableResolver(env)
 182	err := cfg.configureProviders(env, resolver, knownProviders)
 183	require.NoError(t, err)
 184	require.Equal(t, cfg.Providers.Len(), 1)
 185
 186	bedrockProvider, ok := cfg.Providers.Get("bedrock")
 187	require.True(t, ok, "Bedrock provider should be present")
 188	require.Len(t, bedrockProvider.Models, 1)
 189	require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
 190}
 191
 192func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
 193	knownProviders := []catwalk.Provider{
 194		{
 195			ID:          catwalk.InferenceProviderBedrock,
 196			APIKey:      "",
 197			APIEndpoint: "",
 198			Models: []catwalk.Model{{
 199				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 200			}},
 201		},
 202	}
 203
 204	cfg := &Config{}
 205	cfg.setDefaults("/tmp", "")
 206	env := env.NewFromMap(map[string]string{})
 207	resolver := NewEnvironmentVariableResolver(env)
 208	err := cfg.configureProviders(env, resolver, knownProviders)
 209	require.NoError(t, err)
 210	// Provider should not be configured without credentials
 211	require.Equal(t, cfg.Providers.Len(), 0)
 212}
 213
 214func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
 215	knownProviders := []catwalk.Provider{
 216		{
 217			ID:          catwalk.InferenceProviderBedrock,
 218			APIKey:      "",
 219			APIEndpoint: "",
 220			Models: []catwalk.Model{{
 221				ID: "some-random-model",
 222			}},
 223		},
 224	}
 225
 226	cfg := &Config{}
 227	cfg.setDefaults("/tmp", "")
 228	env := env.NewFromMap(map[string]string{
 229		"AWS_ACCESS_KEY_ID":     "test-key-id",
 230		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 231	})
 232	resolver := NewEnvironmentVariableResolver(env)
 233	err := cfg.configureProviders(env, resolver, knownProviders)
 234	require.Error(t, err)
 235}
 236
 237func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
 238	knownProviders := []catwalk.Provider{
 239		{
 240			ID:          catwalk.InferenceProviderVertexAI,
 241			APIKey:      "",
 242			APIEndpoint: "",
 243			Models: []catwalk.Model{{
 244				ID: "gemini-pro",
 245			}},
 246		},
 247	}
 248
 249	cfg := &Config{}
 250	cfg.setDefaults("/tmp", "")
 251	env := env.NewFromMap(map[string]string{
 252		"VERTEXAI_PROJECT":  "test-project",
 253		"VERTEXAI_LOCATION": "us-central1",
 254	})
 255	resolver := NewEnvironmentVariableResolver(env)
 256	err := cfg.configureProviders(env, resolver, knownProviders)
 257	require.NoError(t, err)
 258	require.Equal(t, cfg.Providers.Len(), 1)
 259
 260	vertexProvider, ok := cfg.Providers.Get("vertexai")
 261	require.True(t, ok, "VertexAI provider should be present")
 262	require.Len(t, vertexProvider.Models, 1)
 263	require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
 264	require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
 265	require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
 266}
 267
 268func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
 269	knownProviders := []catwalk.Provider{
 270		{
 271			ID:          catwalk.InferenceProviderVertexAI,
 272			APIKey:      "",
 273			APIEndpoint: "",
 274			Models: []catwalk.Model{{
 275				ID: "gemini-pro",
 276			}},
 277		},
 278	}
 279
 280	cfg := &Config{}
 281	cfg.setDefaults("/tmp", "")
 282	env := env.NewFromMap(map[string]string{
 283		"GOOGLE_GENAI_USE_VERTEXAI": "false",
 284		"GOOGLE_CLOUD_PROJECT":      "test-project",
 285		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 286	})
 287	resolver := NewEnvironmentVariableResolver(env)
 288	err := cfg.configureProviders(env, resolver, knownProviders)
 289	require.NoError(t, err)
 290	// Provider should not be configured without proper credentials
 291	require.Equal(t, cfg.Providers.Len(), 0)
 292}
 293
 294func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
 295	knownProviders := []catwalk.Provider{
 296		{
 297			ID:          catwalk.InferenceProviderVertexAI,
 298			APIKey:      "",
 299			APIEndpoint: "",
 300			Models: []catwalk.Model{{
 301				ID: "gemini-pro",
 302			}},
 303		},
 304	}
 305
 306	cfg := &Config{}
 307	cfg.setDefaults("/tmp", "")
 308	env := env.NewFromMap(map[string]string{
 309		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 310		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 311	})
 312	resolver := NewEnvironmentVariableResolver(env)
 313	err := cfg.configureProviders(env, resolver, knownProviders)
 314	require.NoError(t, err)
 315	// Provider should not be configured without project
 316	require.Equal(t, cfg.Providers.Len(), 0)
 317}
 318
 319func TestConfig_configureProvidersSetProviderID(t *testing.T) {
 320	knownProviders := []catwalk.Provider{
 321		{
 322			ID:          "openai",
 323			APIKey:      "$OPENAI_API_KEY",
 324			APIEndpoint: "https://api.openai.com/v1",
 325			Models: []catwalk.Model{{
 326				ID: "test-model",
 327			}},
 328		},
 329	}
 330
 331	cfg := &Config{}
 332	cfg.setDefaults("/tmp", "")
 333	env := env.NewFromMap(map[string]string{
 334		"OPENAI_API_KEY": "test-key",
 335	})
 336	resolver := NewEnvironmentVariableResolver(env)
 337	err := cfg.configureProviders(env, resolver, knownProviders)
 338	require.NoError(t, err)
 339	require.Equal(t, cfg.Providers.Len(), 1)
 340
 341	// Provider ID should be set
 342	pc, _ := cfg.Providers.Get("openai")
 343	require.Equal(t, "openai", pc.ID)
 344}
 345
 346func TestConfig_EnabledProviders(t *testing.T) {
 347	t.Run("all providers enabled", func(t *testing.T) {
 348		cfg := &Config{
 349			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 350				"openai": {
 351					ID:      "openai",
 352					APIKey:  "key1",
 353					Disable: false,
 354				},
 355				"anthropic": {
 356					ID:      "anthropic",
 357					APIKey:  "key2",
 358					Disable: false,
 359				},
 360			}),
 361		}
 362
 363		enabled := cfg.EnabledProviders()
 364		require.Len(t, enabled, 2)
 365	})
 366
 367	t.Run("some providers disabled", func(t *testing.T) {
 368		cfg := &Config{
 369			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 370				"openai": {
 371					ID:      "openai",
 372					APIKey:  "key1",
 373					Disable: false,
 374				},
 375				"anthropic": {
 376					ID:      "anthropic",
 377					APIKey:  "key2",
 378					Disable: true,
 379				},
 380			}),
 381		}
 382
 383		enabled := cfg.EnabledProviders()
 384		require.Len(t, enabled, 1)
 385		require.Equal(t, "openai", enabled[0].ID)
 386	})
 387
 388	t.Run("empty providers map", func(t *testing.T) {
 389		cfg := &Config{
 390			Providers: csync.NewMap[string, ProviderConfig](),
 391		}
 392
 393		enabled := cfg.EnabledProviders()
 394		require.Len(t, enabled, 0)
 395	})
 396}
 397
 398func TestConfig_IsConfigured(t *testing.T) {
 399	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
 400		cfg := &Config{
 401			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 402				"openai": {
 403					ID:      "openai",
 404					APIKey:  "key1",
 405					Disable: false,
 406				},
 407			}),
 408		}
 409
 410		require.True(t, cfg.IsConfigured())
 411	})
 412
 413	t.Run("returns false when no providers are configured", func(t *testing.T) {
 414		cfg := &Config{
 415			Providers: csync.NewMap[string, ProviderConfig](),
 416		}
 417
 418		require.False(t, cfg.IsConfigured())
 419	})
 420
 421	t.Run("returns false when all providers are disabled", func(t *testing.T) {
 422		cfg := &Config{
 423			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 424				"openai": {
 425					ID:      "openai",
 426					APIKey:  "key1",
 427					Disable: true,
 428				},
 429				"anthropic": {
 430					ID:      "anthropic",
 431					APIKey:  "key2",
 432					Disable: true,
 433				},
 434			}),
 435		}
 436
 437		require.False(t, cfg.IsConfigured())
 438	})
 439}
 440
 441func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
 442	cfg := &Config{
 443		Options: &Options{
 444			DisabledTools: []string{},
 445		},
 446	}
 447
 448	cfg.SetupAgents()
 449	coderAgent, ok := cfg.Agents["coder"]
 450	require.True(t, ok)
 451	assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
 452
 453	taskAgent, ok := cfg.Agents["task"]
 454	require.True(t, ok)
 455	assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 456}
 457
 458func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
 459	cfg := &Config{
 460		Options: &Options{
 461			DisabledTools: []string{
 462				"edit",
 463				"download",
 464				"grep",
 465			},
 466		},
 467	}
 468
 469	cfg.SetupAgents()
 470	coderAgent, ok := cfg.Agents["coder"]
 471	require.True(t, ok)
 472	assert.Equal(t, []string{"agent", "bash", "multiedit", "fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)
 473
 474	taskAgent, ok := cfg.Agents["task"]
 475	require.True(t, ok)
 476	assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 477}
 478
 479func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
 480	cfg := &Config{
 481		Options: &Options{
 482			DisabledTools: []string{
 483				"glob",
 484				"grep",
 485				"ls",
 486				"sourcegraph",
 487				"view",
 488			},
 489		},
 490	}
 491
 492	cfg.SetupAgents()
 493	coderAgent, ok := cfg.Agents["coder"]
 494	require.True(t, ok)
 495	assert.Equal(t, []string{"agent", "bash", "download", "edit", "multiedit", "fetch", "write"}, coderAgent.AllowedTools)
 496
 497	taskAgent, ok := cfg.Agents["task"]
 498	require.True(t, ok)
 499	assert.Equal(t, []string{}, taskAgent.AllowedTools)
 500}
 501
 502func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
 503	knownProviders := []catwalk.Provider{
 504		{
 505			ID:          "openai",
 506			APIKey:      "$OPENAI_API_KEY",
 507			APIEndpoint: "https://api.openai.com/v1",
 508			Models: []catwalk.Model{{
 509				ID: "test-model",
 510			}},
 511		},
 512	}
 513
 514	cfg := &Config{
 515		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 516			"openai": {
 517				Disable: true,
 518			},
 519		}),
 520	}
 521	cfg.setDefaults("/tmp", "")
 522
 523	env := env.NewFromMap(map[string]string{
 524		"OPENAI_API_KEY": "test-key",
 525	})
 526	resolver := NewEnvironmentVariableResolver(env)
 527	err := cfg.configureProviders(env, resolver, knownProviders)
 528	require.NoError(t, err)
 529
 530	require.Equal(t, cfg.Providers.Len(), 1)
 531	prov, exists := cfg.Providers.Get("openai")
 532	require.True(t, exists)
 533	require.True(t, prov.Disable)
 534}
 535
 536func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 537	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
 538		cfg := &Config{
 539			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 540				"custom": {
 541					BaseURL: "https://api.custom.com/v1",
 542					Models: []catwalk.Model{{
 543						ID: "test-model",
 544					}},
 545				},
 546				"openai": {
 547					APIKey: "$MISSING",
 548				},
 549			}),
 550		}
 551		cfg.setDefaults("/tmp", "")
 552
 553		env := env.NewFromMap(map[string]string{})
 554		resolver := NewEnvironmentVariableResolver(env)
 555		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 556		require.NoError(t, err)
 557
 558		require.Equal(t, cfg.Providers.Len(), 1)
 559		_, exists := cfg.Providers.Get("custom")
 560		require.True(t, exists)
 561	})
 562
 563	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
 564		cfg := &Config{
 565			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 566				"custom": {
 567					APIKey: "test-key",
 568					Models: []catwalk.Model{{
 569						ID: "test-model",
 570					}},
 571				},
 572			}),
 573		}
 574		cfg.setDefaults("/tmp", "")
 575
 576		env := env.NewFromMap(map[string]string{})
 577		resolver := NewEnvironmentVariableResolver(env)
 578		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 579		require.NoError(t, err)
 580
 581		require.Equal(t, cfg.Providers.Len(), 0)
 582		_, exists := cfg.Providers.Get("custom")
 583		require.False(t, exists)
 584	})
 585
 586	t.Run("custom provider with no models is removed", func(t *testing.T) {
 587		cfg := &Config{
 588			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 589				"custom": {
 590					APIKey:  "test-key",
 591					BaseURL: "https://api.custom.com/v1",
 592					Models:  []catwalk.Model{},
 593				},
 594			}),
 595		}
 596		cfg.setDefaults("/tmp", "")
 597
 598		env := env.NewFromMap(map[string]string{})
 599		resolver := NewEnvironmentVariableResolver(env)
 600		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 601		require.NoError(t, err)
 602
 603		require.Equal(t, cfg.Providers.Len(), 0)
 604		_, exists := cfg.Providers.Get("custom")
 605		require.False(t, exists)
 606	})
 607
 608	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
 609		cfg := &Config{
 610			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 611				"custom": {
 612					APIKey:  "test-key",
 613					BaseURL: "https://api.custom.com/v1",
 614					Type:    "unsupported",
 615					Models: []catwalk.Model{{
 616						ID: "test-model",
 617					}},
 618				},
 619			}),
 620		}
 621		cfg.setDefaults("/tmp", "")
 622
 623		env := env.NewFromMap(map[string]string{})
 624		resolver := NewEnvironmentVariableResolver(env)
 625		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 626		require.NoError(t, err)
 627
 628		require.Equal(t, cfg.Providers.Len(), 0)
 629		_, exists := cfg.Providers.Get("custom")
 630		require.False(t, exists)
 631	})
 632
 633	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
 634		cfg := &Config{
 635			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 636				"custom": {
 637					APIKey:  "test-key",
 638					BaseURL: "https://api.custom.com/v1",
 639					Type:    catwalk.TypeOpenAI,
 640					Models: []catwalk.Model{{
 641						ID: "test-model",
 642					}},
 643				},
 644			}),
 645		}
 646		cfg.setDefaults("/tmp", "")
 647
 648		env := env.NewFromMap(map[string]string{})
 649		resolver := NewEnvironmentVariableResolver(env)
 650		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 651		require.NoError(t, err)
 652
 653		require.Equal(t, cfg.Providers.Len(), 1)
 654		customProvider, exists := cfg.Providers.Get("custom")
 655		require.True(t, exists)
 656		require.Equal(t, "custom", customProvider.ID)
 657		require.Equal(t, "test-key", customProvider.APIKey)
 658		require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
 659	})
 660
 661	t.Run("custom anthropic provider is supported", func(t *testing.T) {
 662		cfg := &Config{
 663			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 664				"custom-anthropic": {
 665					APIKey:  "test-key",
 666					BaseURL: "https://api.anthropic.com/v1",
 667					Type:    catwalk.TypeAnthropic,
 668					Models: []catwalk.Model{{
 669						ID: "claude-3-sonnet",
 670					}},
 671				},
 672			}),
 673		}
 674		cfg.setDefaults("/tmp", "")
 675
 676		env := env.NewFromMap(map[string]string{})
 677		resolver := NewEnvironmentVariableResolver(env)
 678		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 679		require.NoError(t, err)
 680
 681		require.Equal(t, cfg.Providers.Len(), 1)
 682		customProvider, exists := cfg.Providers.Get("custom-anthropic")
 683		require.True(t, exists)
 684		require.Equal(t, "custom-anthropic", customProvider.ID)
 685		require.Equal(t, "test-key", customProvider.APIKey)
 686		require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
 687		require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
 688	})
 689
 690	t.Run("disabled custom provider is removed", func(t *testing.T) {
 691		cfg := &Config{
 692			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 693				"custom": {
 694					APIKey:  "test-key",
 695					BaseURL: "https://api.custom.com/v1",
 696					Type:    catwalk.TypeOpenAI,
 697					Disable: true,
 698					Models: []catwalk.Model{{
 699						ID: "test-model",
 700					}},
 701				},
 702			}),
 703		}
 704		cfg.setDefaults("/tmp", "")
 705
 706		env := env.NewFromMap(map[string]string{})
 707		resolver := NewEnvironmentVariableResolver(env)
 708		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 709		require.NoError(t, err)
 710
 711		require.Equal(t, cfg.Providers.Len(), 0)
 712		_, exists := cfg.Providers.Get("custom")
 713		require.False(t, exists)
 714	})
 715}
 716
 717func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 718	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
 719		knownProviders := []catwalk.Provider{
 720			{
 721				ID:          catwalk.InferenceProviderVertexAI,
 722				APIKey:      "",
 723				APIEndpoint: "",
 724				Models: []catwalk.Model{{
 725					ID: "gemini-pro",
 726				}},
 727			},
 728		}
 729
 730		cfg := &Config{
 731			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 732				"vertexai": {
 733					BaseURL: "custom-url",
 734				},
 735			}),
 736		}
 737		cfg.setDefaults("/tmp", "")
 738
 739		env := env.NewFromMap(map[string]string{
 740			"GOOGLE_GENAI_USE_VERTEXAI": "false",
 741		})
 742		resolver := NewEnvironmentVariableResolver(env)
 743		err := cfg.configureProviders(env, resolver, knownProviders)
 744		require.NoError(t, err)
 745
 746		require.Equal(t, cfg.Providers.Len(), 0)
 747		_, exists := cfg.Providers.Get("vertexai")
 748		require.False(t, exists)
 749	})
 750
 751	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
 752		knownProviders := []catwalk.Provider{
 753			{
 754				ID:          catwalk.InferenceProviderBedrock,
 755				APIKey:      "",
 756				APIEndpoint: "",
 757				Models: []catwalk.Model{{
 758					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 759				}},
 760			},
 761		}
 762
 763		cfg := &Config{
 764			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 765				"bedrock": {
 766					BaseURL: "custom-url",
 767				},
 768			}),
 769		}
 770		cfg.setDefaults("/tmp", "")
 771
 772		env := env.NewFromMap(map[string]string{})
 773		resolver := NewEnvironmentVariableResolver(env)
 774		err := cfg.configureProviders(env, resolver, knownProviders)
 775		require.NoError(t, err)
 776
 777		require.Equal(t, cfg.Providers.Len(), 0)
 778		_, exists := cfg.Providers.Get("bedrock")
 779		require.False(t, exists)
 780	})
 781
 782	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
 783		knownProviders := []catwalk.Provider{
 784			{
 785				ID:          "openai",
 786				APIKey:      "$MISSING_API_KEY",
 787				APIEndpoint: "https://api.openai.com/v1",
 788				Models: []catwalk.Model{{
 789					ID: "test-model",
 790				}},
 791			},
 792		}
 793
 794		cfg := &Config{
 795			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 796				"openai": {
 797					BaseURL: "custom-url",
 798				},
 799			}),
 800		}
 801		cfg.setDefaults("/tmp", "")
 802
 803		env := env.NewFromMap(map[string]string{})
 804		resolver := NewEnvironmentVariableResolver(env)
 805		err := cfg.configureProviders(env, resolver, knownProviders)
 806		require.NoError(t, err)
 807
 808		require.Equal(t, cfg.Providers.Len(), 0)
 809		_, exists := cfg.Providers.Get("openai")
 810		require.False(t, exists)
 811	})
 812
 813	t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
 814		knownProviders := []catwalk.Provider{
 815			{
 816				ID:          "openai",
 817				APIKey:      "$OPENAI_API_KEY",
 818				APIEndpoint: "$MISSING_ENDPOINT",
 819				Models: []catwalk.Model{{
 820					ID: "test-model",
 821				}},
 822			},
 823		}
 824
 825		cfg := &Config{
 826			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 827				"openai": {
 828					APIKey: "test-key",
 829				},
 830			}),
 831		}
 832		cfg.setDefaults("/tmp", "")
 833
 834		env := env.NewFromMap(map[string]string{
 835			"OPENAI_API_KEY": "test-key",
 836		})
 837		resolver := NewEnvironmentVariableResolver(env)
 838		err := cfg.configureProviders(env, resolver, knownProviders)
 839		require.NoError(t, err)
 840
 841		require.Equal(t, cfg.Providers.Len(), 1)
 842		_, exists := cfg.Providers.Get("openai")
 843		require.True(t, exists)
 844	})
 845}
 846
 847func TestConfig_defaultModelSelection(t *testing.T) {
 848	t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
 849		knownProviders := []catwalk.Provider{
 850			{
 851				ID:                  "openai",
 852				APIKey:              "abc",
 853				DefaultLargeModelID: "large-model",
 854				DefaultSmallModelID: "small-model",
 855				Models: []catwalk.Model{
 856					{
 857						ID:               "large-model",
 858						DefaultMaxTokens: 1000,
 859					},
 860					{
 861						ID:               "small-model",
 862						DefaultMaxTokens: 500,
 863					},
 864				},
 865			},
 866		}
 867
 868		cfg := &Config{}
 869		cfg.setDefaults("/tmp", "")
 870		env := env.NewFromMap(map[string]string{})
 871		resolver := NewEnvironmentVariableResolver(env)
 872		err := cfg.configureProviders(env, resolver, knownProviders)
 873		require.NoError(t, err)
 874
 875		large, small, err := cfg.defaultModelSelection(knownProviders)
 876		require.NoError(t, err)
 877		require.Equal(t, "large-model", large.Model)
 878		require.Equal(t, "openai", large.Provider)
 879		require.Equal(t, int64(1000), large.MaxTokens)
 880		require.Equal(t, "small-model", small.Model)
 881		require.Equal(t, "openai", small.Provider)
 882		require.Equal(t, int64(500), small.MaxTokens)
 883	})
 884	t.Run("should error if no providers configured", func(t *testing.T) {
 885		knownProviders := []catwalk.Provider{
 886			{
 887				ID:                  "openai",
 888				APIKey:              "$MISSING_KEY",
 889				DefaultLargeModelID: "large-model",
 890				DefaultSmallModelID: "small-model",
 891				Models: []catwalk.Model{
 892					{
 893						ID:               "large-model",
 894						DefaultMaxTokens: 1000,
 895					},
 896					{
 897						ID:               "small-model",
 898						DefaultMaxTokens: 500,
 899					},
 900				},
 901			},
 902		}
 903
 904		cfg := &Config{}
 905		cfg.setDefaults("/tmp", "")
 906		env := env.NewFromMap(map[string]string{})
 907		resolver := NewEnvironmentVariableResolver(env)
 908		err := cfg.configureProviders(env, resolver, knownProviders)
 909		require.NoError(t, err)
 910
 911		_, _, err = cfg.defaultModelSelection(knownProviders)
 912		require.Error(t, err)
 913	})
 914	t.Run("should error if model is missing", func(t *testing.T) {
 915		knownProviders := []catwalk.Provider{
 916			{
 917				ID:                  "openai",
 918				APIKey:              "abc",
 919				DefaultLargeModelID: "large-model",
 920				DefaultSmallModelID: "small-model",
 921				Models: []catwalk.Model{
 922					{
 923						ID:               "not-large-model",
 924						DefaultMaxTokens: 1000,
 925					},
 926					{
 927						ID:               "small-model",
 928						DefaultMaxTokens: 500,
 929					},
 930				},
 931			},
 932		}
 933
 934		cfg := &Config{}
 935		cfg.setDefaults("/tmp", "")
 936		env := env.NewFromMap(map[string]string{})
 937		resolver := NewEnvironmentVariableResolver(env)
 938		err := cfg.configureProviders(env, resolver, knownProviders)
 939		require.NoError(t, err)
 940		_, _, err = cfg.defaultModelSelection(knownProviders)
 941		require.Error(t, err)
 942	})
 943
 944	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
 945		knownProviders := []catwalk.Provider{
 946			{
 947				ID:                  "openai",
 948				APIKey:              "$MISSING", // will not be included in the config
 949				DefaultLargeModelID: "large-model",
 950				DefaultSmallModelID: "small-model",
 951				Models: []catwalk.Model{
 952					{
 953						ID:               "not-large-model",
 954						DefaultMaxTokens: 1000,
 955					},
 956					{
 957						ID:               "small-model",
 958						DefaultMaxTokens: 500,
 959					},
 960				},
 961			},
 962		}
 963
 964		cfg := &Config{
 965			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 966				"custom": {
 967					APIKey:  "test-key",
 968					BaseURL: "https://api.custom.com/v1",
 969					Models: []catwalk.Model{
 970						{
 971							ID:               "model",
 972							DefaultMaxTokens: 600,
 973						},
 974					},
 975				},
 976			}),
 977		}
 978		cfg.setDefaults("/tmp", "")
 979		env := env.NewFromMap(map[string]string{})
 980		resolver := NewEnvironmentVariableResolver(env)
 981		err := cfg.configureProviders(env, resolver, knownProviders)
 982		require.NoError(t, err)
 983		large, small, err := cfg.defaultModelSelection(knownProviders)
 984		require.NoError(t, err)
 985		require.Equal(t, "model", large.Model)
 986		require.Equal(t, "custom", large.Provider)
 987		require.Equal(t, int64(600), large.MaxTokens)
 988		require.Equal(t, "model", small.Model)
 989		require.Equal(t, "custom", small.Provider)
 990		require.Equal(t, int64(600), small.MaxTokens)
 991	})
 992
 993	t.Run("should fail if no model configured", func(t *testing.T) {
 994		knownProviders := []catwalk.Provider{
 995			{
 996				ID:                  "openai",
 997				APIKey:              "$MISSING", // will not be included in the config
 998				DefaultLargeModelID: "large-model",
 999				DefaultSmallModelID: "small-model",
1000				Models: []catwalk.Model{
1001					{
1002						ID:               "not-large-model",
1003						DefaultMaxTokens: 1000,
1004					},
1005					{
1006						ID:               "small-model",
1007						DefaultMaxTokens: 500,
1008					},
1009				},
1010			},
1011		}
1012
1013		cfg := &Config{
1014			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1015				"custom": {
1016					APIKey:  "test-key",
1017					BaseURL: "https://api.custom.com/v1",
1018					Models:  []catwalk.Model{},
1019				},
1020			}),
1021		}
1022		cfg.setDefaults("/tmp", "")
1023		env := env.NewFromMap(map[string]string{})
1024		resolver := NewEnvironmentVariableResolver(env)
1025		err := cfg.configureProviders(env, resolver, knownProviders)
1026		require.NoError(t, err)
1027		_, _, err = cfg.defaultModelSelection(knownProviders)
1028		require.Error(t, err)
1029	})
1030	t.Run("should use the default provider first", func(t *testing.T) {
1031		knownProviders := []catwalk.Provider{
1032			{
1033				ID:                  "openai",
1034				APIKey:              "set",
1035				DefaultLargeModelID: "large-model",
1036				DefaultSmallModelID: "small-model",
1037				Models: []catwalk.Model{
1038					{
1039						ID:               "large-model",
1040						DefaultMaxTokens: 1000,
1041					},
1042					{
1043						ID:               "small-model",
1044						DefaultMaxTokens: 500,
1045					},
1046				},
1047			},
1048		}
1049
1050		cfg := &Config{
1051			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1052				"custom": {
1053					APIKey:  "test-key",
1054					BaseURL: "https://api.custom.com/v1",
1055					Models: []catwalk.Model{
1056						{
1057							ID:               "large-model",
1058							DefaultMaxTokens: 1000,
1059						},
1060					},
1061				},
1062			}),
1063		}
1064		cfg.setDefaults("/tmp", "")
1065		env := env.NewFromMap(map[string]string{})
1066		resolver := NewEnvironmentVariableResolver(env)
1067		err := cfg.configureProviders(env, resolver, knownProviders)
1068		require.NoError(t, err)
1069		large, small, err := cfg.defaultModelSelection(knownProviders)
1070		require.NoError(t, err)
1071		require.Equal(t, "large-model", large.Model)
1072		require.Equal(t, "openai", large.Provider)
1073		require.Equal(t, int64(1000), large.MaxTokens)
1074		require.Equal(t, "small-model", small.Model)
1075		require.Equal(t, "openai", small.Provider)
1076		require.Equal(t, int64(500), small.MaxTokens)
1077	})
1078}
1079
1080func TestConfig_configureSelectedModels(t *testing.T) {
1081	t.Run("should override defaults", func(t *testing.T) {
1082		knownProviders := []catwalk.Provider{
1083			{
1084				ID:                  "openai",
1085				APIKey:              "abc",
1086				DefaultLargeModelID: "large-model",
1087				DefaultSmallModelID: "small-model",
1088				Models: []catwalk.Model{
1089					{
1090						ID:               "larger-model",
1091						DefaultMaxTokens: 2000,
1092					},
1093					{
1094						ID:               "large-model",
1095						DefaultMaxTokens: 1000,
1096					},
1097					{
1098						ID:               "small-model",
1099						DefaultMaxTokens: 500,
1100					},
1101				},
1102			},
1103		}
1104
1105		cfg := &Config{
1106			Models: map[SelectedModelType]SelectedModel{
1107				"large": {
1108					Model: "larger-model",
1109				},
1110			},
1111		}
1112		cfg.setDefaults("/tmp", "")
1113		env := env.NewFromMap(map[string]string{})
1114		resolver := NewEnvironmentVariableResolver(env)
1115		err := cfg.configureProviders(env, resolver, knownProviders)
1116		require.NoError(t, err)
1117
1118		err = cfg.configureSelectedModels(knownProviders)
1119		require.NoError(t, err)
1120		large := cfg.Models[SelectedModelTypeLarge]
1121		small := cfg.Models[SelectedModelTypeSmall]
1122		require.Equal(t, "larger-model", large.Model)
1123		require.Equal(t, "openai", large.Provider)
1124		require.Equal(t, int64(2000), large.MaxTokens)
1125		require.Equal(t, "small-model", small.Model)
1126		require.Equal(t, "openai", small.Provider)
1127		require.Equal(t, int64(500), small.MaxTokens)
1128	})
1129	t.Run("should be possible to use multiple providers", func(t *testing.T) {
1130		knownProviders := []catwalk.Provider{
1131			{
1132				ID:                  "openai",
1133				APIKey:              "abc",
1134				DefaultLargeModelID: "large-model",
1135				DefaultSmallModelID: "small-model",
1136				Models: []catwalk.Model{
1137					{
1138						ID:               "large-model",
1139						DefaultMaxTokens: 1000,
1140					},
1141					{
1142						ID:               "small-model",
1143						DefaultMaxTokens: 500,
1144					},
1145				},
1146			},
1147			{
1148				ID:                  "anthropic",
1149				APIKey:              "abc",
1150				DefaultLargeModelID: "a-large-model",
1151				DefaultSmallModelID: "a-small-model",
1152				Models: []catwalk.Model{
1153					{
1154						ID:               "a-large-model",
1155						DefaultMaxTokens: 1000,
1156					},
1157					{
1158						ID:               "a-small-model",
1159						DefaultMaxTokens: 200,
1160					},
1161				},
1162			},
1163		}
1164
1165		cfg := &Config{
1166			Models: map[SelectedModelType]SelectedModel{
1167				"small": {
1168					Model:     "a-small-model",
1169					Provider:  "anthropic",
1170					MaxTokens: 300,
1171				},
1172			},
1173		}
1174		cfg.setDefaults("/tmp", "")
1175		env := env.NewFromMap(map[string]string{})
1176		resolver := NewEnvironmentVariableResolver(env)
1177		err := cfg.configureProviders(env, resolver, knownProviders)
1178		require.NoError(t, err)
1179
1180		err = cfg.configureSelectedModels(knownProviders)
1181		require.NoError(t, err)
1182		large := cfg.Models[SelectedModelTypeLarge]
1183		small := cfg.Models[SelectedModelTypeSmall]
1184		require.Equal(t, "large-model", large.Model)
1185		require.Equal(t, "openai", large.Provider)
1186		require.Equal(t, int64(1000), large.MaxTokens)
1187		require.Equal(t, "a-small-model", small.Model)
1188		require.Equal(t, "anthropic", small.Provider)
1189		require.Equal(t, int64(300), small.MaxTokens)
1190	})
1191
1192	t.Run("should override the max tokens only", func(t *testing.T) {
1193		knownProviders := []catwalk.Provider{
1194			{
1195				ID:                  "openai",
1196				APIKey:              "abc",
1197				DefaultLargeModelID: "large-model",
1198				DefaultSmallModelID: "small-model",
1199				Models: []catwalk.Model{
1200					{
1201						ID:               "large-model",
1202						DefaultMaxTokens: 1000,
1203					},
1204					{
1205						ID:               "small-model",
1206						DefaultMaxTokens: 500,
1207					},
1208				},
1209			},
1210		}
1211
1212		cfg := &Config{
1213			Models: map[SelectedModelType]SelectedModel{
1214				"large": {
1215					MaxTokens: 100,
1216				},
1217			},
1218		}
1219		cfg.setDefaults("/tmp", "")
1220		env := env.NewFromMap(map[string]string{})
1221		resolver := NewEnvironmentVariableResolver(env)
1222		err := cfg.configureProviders(env, resolver, knownProviders)
1223		require.NoError(t, err)
1224
1225		err = cfg.configureSelectedModels(knownProviders)
1226		require.NoError(t, err)
1227		large := cfg.Models[SelectedModelTypeLarge]
1228		require.Equal(t, "large-model", large.Model)
1229		require.Equal(t, "openai", large.Provider)
1230		require.Equal(t, int64(100), large.MaxTokens)
1231	})
1232}