load_test.go

   1package config
   2
   3import (
   4	"io"
   5	"log/slog"
   6	"os"
   7	"path/filepath"
   8	"strings"
   9	"testing"
  10
  11	"github.com/charmbracelet/catwalk/pkg/catwalk"
  12	"github.com/charmbracelet/crush/internal/csync"
  13	"github.com/stretchr/testify/assert"
  14	"github.com/stretchr/testify/require"
  15)
  16
  17func TestMain(m *testing.M) {
  18	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
  19
  20	exitVal := m.Run()
  21	os.Exit(exitVal)
  22}
  23
  24func TestConfig_LoadFromReaders(t *testing.T) {
  25	data1 := strings.NewReader(`{"providers": {"openai": {"api_key=key1", "base_url": "https://api.openai.com/v1"}}}`)
  26	data2 := strings.NewReader(`{"providers": {"openai": {"api_key=key2", "base_url": "https://api.openai.com/v2"}}}`)
  27	data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
  28
  29	loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
  30
  31	require.NoError(t, err)
  32	require.NotNil(t, loadedConfig)
  33	require.Equal(t, 1, loadedConfig.Providers.Len())
  34	pc, _ := loadedConfig.Providers.Get("openai")
  35	require.Equal(t, "key2", pc.APIKey)
  36	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  37}
  38
  39func TestConfig_setDefaults(t *testing.T) {
  40	cfg := &Config{}
  41
  42	cfg.setDefaults("/tmp", "")
  43
  44	require.NotNil(t, cfg.Options)
  45	require.NotNil(t, cfg.Options.TUI)
  46	require.NotNil(t, cfg.Options.ContextPaths)
  47	require.NotNil(t, cfg.Providers)
  48	require.NotNil(t, cfg.Models)
  49	require.NotNil(t, cfg.LSP)
  50	require.NotNil(t, cfg.MCP)
  51	require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  52	for _, path := range defaultContextPaths {
  53		require.Contains(t, cfg.Options.ContextPaths, path)
  54	}
  55	require.Equal(t, "/tmp", cfg.workingDir)
  56}
  57
  58func TestConfig_configureProviders(t *testing.T) {
  59	knownProviders := []catwalk.Provider{
  60		{
  61			ID:          "openai",
  62			APIKey:      "$OPENAI_API_KEY",
  63			APIEndpoint: "https://api.openai.com/v1",
  64			Models: []catwalk.Model{{
  65				ID: "test-model",
  66			}},
  67		},
  68	}
  69
  70	cfg := &Config{}
  71	cfg.setDefaults("/tmp", "")
  72	env := environ([]string{
  73		"OPENAI_API_KEY=test-key",
  74	})
  75	resolver := NewEnvironmentVariableResolver(env)
  76	err := cfg.configureProviders(env, resolver, knownProviders)
  77	require.NoError(t, err)
  78	require.Equal(t, 1, cfg.Providers.Len())
  79
  80	// We want to make sure that we keep the configured API key as a placeholder
  81	pc, _ := cfg.Providers.Get("openai")
  82	require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  83}
  84
  85func TestConfig_configureProvidersWithOverride(t *testing.T) {
  86	knownProviders := []catwalk.Provider{
  87		{
  88			ID:          "openai",
  89			APIKey:      "$OPENAI_API_KEY",
  90			APIEndpoint: "https://api.openai.com/v1",
  91			Models: []catwalk.Model{{
  92				ID: "test-model",
  93			}},
  94		},
  95	}
  96
  97	cfg := &Config{
  98		Providers: csync.NewMap[string, ProviderConfig](),
  99	}
 100	cfg.Providers.Set("openai", ProviderConfig{
 101		APIKey:  "xyz",
 102		BaseURL: "https://api.openai.com/v2",
 103		Models: []catwalk.Model{
 104			{
 105				ID:   "test-model",
 106				Name: "Updated",
 107			},
 108			{
 109				ID: "another-model",
 110			},
 111		},
 112	})
 113	cfg.setDefaults("/tmp", "")
 114
 115	env := environ([]string{
 116		"OPENAI_API_KEY=test-key",
 117	})
 118	resolver := NewEnvironmentVariableResolver(env)
 119	err := cfg.configureProviders(env, resolver, knownProviders)
 120	require.NoError(t, err)
 121	require.Equal(t, 1, cfg.Providers.Len())
 122
 123	// We want to make sure that we keep the configured API key as a placeholder
 124	pc, _ := cfg.Providers.Get("openai")
 125	require.Equal(t, "xyz", pc.APIKey)
 126	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
 127	require.Len(t, pc.Models, 2)
 128	require.Equal(t, "Updated", pc.Models[0].Name)
 129}
 130
 131func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
 132	knownProviders := []catwalk.Provider{
 133		{
 134			ID:          "openai",
 135			APIKey:      "$OPENAI_API_KEY",
 136			APIEndpoint: "https://api.openai.com/v1",
 137			Models: []catwalk.Model{{
 138				ID: "test-model",
 139			}},
 140		},
 141	}
 142
 143	cfg := &Config{
 144		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 145			"custom": {
 146				APIKey:  "xyz",
 147				BaseURL: "https://api.someendpoint.com/v2",
 148				Models: []catwalk.Model{
 149					{
 150						ID: "test-model",
 151					},
 152				},
 153			},
 154		}),
 155	}
 156	cfg.setDefaults("/tmp", "")
 157	env := environ([]string{
 158		"OPENAI_API_KEY=test-key",
 159	})
 160	resolver := NewEnvironmentVariableResolver(env)
 161	err := cfg.configureProviders(env, resolver, knownProviders)
 162	require.NoError(t, err)
 163	// Should be to because of the env variable
 164	require.Equal(t, cfg.Providers.Len(), 2)
 165
 166	// We want to make sure that we keep the configured API key as a placeholder
 167	pc, _ := cfg.Providers.Get("custom")
 168	require.Equal(t, "xyz", pc.APIKey)
 169	// Make sure we set the ID correctly
 170	require.Equal(t, "custom", pc.ID)
 171	require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
 172	require.Len(t, pc.Models, 1)
 173
 174	_, ok := cfg.Providers.Get("openai")
 175	require.True(t, ok, "OpenAI provider should still be present")
 176}
 177
 178func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
 179	knownProviders := []catwalk.Provider{
 180		{
 181			ID:          catwalk.InferenceProviderBedrock,
 182			APIKey:      "",
 183			APIEndpoint: "",
 184			Models: []catwalk.Model{{
 185				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 186			}},
 187		},
 188	}
 189
 190	cfg := &Config{}
 191	cfg.setDefaults("/tmp", "")
 192	env := environ([]string{
 193		"AWS_ACCESS_KEY_ID=test-key-id",
 194		"AWS_SECRET_ACCESS_KEY=test-secret-key",
 195	})
 196	resolver := NewEnvironmentVariableResolver(env)
 197	err := cfg.configureProviders(env, resolver, knownProviders)
 198	require.NoError(t, err)
 199	require.Equal(t, cfg.Providers.Len(), 1)
 200
 201	bedrockProvider, ok := cfg.Providers.Get("bedrock")
 202	require.True(t, ok, "Bedrock provider should be present")
 203	require.Len(t, bedrockProvider.Models, 1)
 204	require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
 205}
 206
 207func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
 208	knownProviders := []catwalk.Provider{
 209		{
 210			ID:          catwalk.InferenceProviderBedrock,
 211			APIKey:      "",
 212			APIEndpoint: "",
 213			Models: []catwalk.Model{{
 214				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 215			}},
 216		},
 217	}
 218
 219	cfg := &Config{}
 220	cfg.setDefaults("/tmp", "")
 221	env := environ([]string{})
 222	resolver := NewEnvironmentVariableResolver(env)
 223	err := cfg.configureProviders(env, resolver, knownProviders)
 224	require.NoError(t, err)
 225	// Provider should not be configured without credentials
 226	require.Equal(t, cfg.Providers.Len(), 0)
 227}
 228
 229func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
 230	knownProviders := []catwalk.Provider{
 231		{
 232			ID:          catwalk.InferenceProviderBedrock,
 233			APIKey:      "",
 234			APIEndpoint: "",
 235			Models: []catwalk.Model{{
 236				ID: "some-random-model",
 237			}},
 238		},
 239	}
 240
 241	cfg := &Config{}
 242	cfg.setDefaults("/tmp", "")
 243	env := environ([]string{
 244		"AWS_ACCESS_KEY_ID=test-key-id",
 245		"AWS_SECRET_ACCESS_KEY=test-secret-key",
 246	})
 247	resolver := NewEnvironmentVariableResolver(env)
 248	err := cfg.configureProviders(env, resolver, knownProviders)
 249	require.Error(t, err)
 250}
 251
 252func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
 253	knownProviders := []catwalk.Provider{
 254		{
 255			ID:          catwalk.InferenceProviderVertexAI,
 256			APIKey:      "",
 257			APIEndpoint: "",
 258			Models: []catwalk.Model{{
 259				ID: "gemini-pro",
 260			}},
 261		},
 262	}
 263
 264	cfg := &Config{}
 265	cfg.setDefaults("/tmp", "")
 266	env := environ([]string{
 267		"VERTEXAI_PROJECT=test-project",
 268		"VERTEXAI_LOCATION=us-central1",
 269	})
 270	resolver := NewEnvironmentVariableResolver(env)
 271	err := cfg.configureProviders(env, resolver, knownProviders)
 272	require.NoError(t, err)
 273	require.Equal(t, cfg.Providers.Len(), 1)
 274
 275	vertexProvider, ok := cfg.Providers.Get("vertexai")
 276	require.True(t, ok, "VertexAI provider should be present")
 277	require.Len(t, vertexProvider.Models, 1)
 278	require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
 279	require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
 280	require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
 281}
 282
 283func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
 284	knownProviders := []catwalk.Provider{
 285		{
 286			ID:          catwalk.InferenceProviderVertexAI,
 287			APIKey:      "",
 288			APIEndpoint: "",
 289			Models: []catwalk.Model{{
 290				ID: "gemini-pro",
 291			}},
 292		},
 293	}
 294
 295	cfg := &Config{}
 296	cfg.setDefaults("/tmp", "")
 297	env := environ([]string{
 298		"GOOGLE_GENAI_USE_VERTEXAI=false",
 299		"GOOGLE_CLOUD_PROJECT=test-project",
 300		"GOOGLE_CLOUD_LOCATION=us-central1",
 301	})
 302	resolver := NewEnvironmentVariableResolver(env)
 303	err := cfg.configureProviders(env, resolver, knownProviders)
 304	require.NoError(t, err)
 305	// Provider should not be configured without proper credentials
 306	require.Equal(t, cfg.Providers.Len(), 0)
 307}
 308
 309func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
 310	knownProviders := []catwalk.Provider{
 311		{
 312			ID:          catwalk.InferenceProviderVertexAI,
 313			APIKey:      "",
 314			APIEndpoint: "",
 315			Models: []catwalk.Model{{
 316				ID: "gemini-pro",
 317			}},
 318		},
 319	}
 320
 321	cfg := &Config{}
 322	cfg.setDefaults("/tmp", "")
 323	env := environ([]string{
 324		"GOOGLE_GENAI_USE_VERTEXAI=true",
 325		"GOOGLE_CLOUD_LOCATION=us-central1",
 326	})
 327	resolver := NewEnvironmentVariableResolver(env)
 328	err := cfg.configureProviders(env, resolver, knownProviders)
 329	require.NoError(t, err)
 330	// Provider should not be configured without project
 331	require.Equal(t, cfg.Providers.Len(), 0)
 332}
 333
 334func TestConfig_configureProvidersSetProviderID(t *testing.T) {
 335	knownProviders := []catwalk.Provider{
 336		{
 337			ID:          "openai",
 338			APIKey:      "$OPENAI_API_KEY",
 339			APIEndpoint: "https://api.openai.com/v1",
 340			Models: []catwalk.Model{{
 341				ID: "test-model",
 342			}},
 343		},
 344	}
 345
 346	cfg := &Config{}
 347	cfg.setDefaults("/tmp", "")
 348	env := environ([]string{
 349		"OPENAI_API_KEY=test-key",
 350	})
 351	resolver := NewEnvironmentVariableResolver(env)
 352	err := cfg.configureProviders(env, resolver, knownProviders)
 353	require.NoError(t, err)
 354	require.Equal(t, cfg.Providers.Len(), 1)
 355
 356	// Provider ID should be set
 357	pc, _ := cfg.Providers.Get("openai")
 358	require.Equal(t, "openai", pc.ID)
 359}
 360
 361func TestConfig_EnabledProviders(t *testing.T) {
 362	t.Run("all providers enabled", func(t *testing.T) {
 363		cfg := &Config{
 364			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 365				"openai": {
 366					ID:      "openai",
 367					APIKey:  "key1",
 368					Disable: false,
 369				},
 370				"anthropic": {
 371					ID:      "anthropic",
 372					APIKey:  "key2",
 373					Disable: false,
 374				},
 375			}),
 376		}
 377
 378		enabled := cfg.EnabledProviders()
 379		require.Len(t, enabled, 2)
 380	})
 381
 382	t.Run("some providers disabled", func(t *testing.T) {
 383		cfg := &Config{
 384			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 385				"openai": {
 386					ID:      "openai",
 387					APIKey:  "key1",
 388					Disable: false,
 389				},
 390				"anthropic": {
 391					ID:      "anthropic",
 392					APIKey:  "key2",
 393					Disable: true,
 394				},
 395			}),
 396		}
 397
 398		enabled := cfg.EnabledProviders()
 399		require.Len(t, enabled, 1)
 400		require.Equal(t, "openai", enabled[0].ID)
 401	})
 402
 403	t.Run("empty providers map", func(t *testing.T) {
 404		cfg := &Config{
 405			Providers: csync.NewMap[string, ProviderConfig](),
 406		}
 407
 408		enabled := cfg.EnabledProviders()
 409		require.Len(t, enabled, 0)
 410	})
 411}
 412
 413func TestConfig_IsConfigured(t *testing.T) {
 414	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
 415		cfg := &Config{
 416			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 417				"openai": {
 418					ID:      "openai",
 419					APIKey:  "key1",
 420					Disable: false,
 421				},
 422			}),
 423		}
 424
 425		require.True(t, cfg.IsConfigured())
 426	})
 427
 428	t.Run("returns false when no providers are configured", func(t *testing.T) {
 429		cfg := &Config{
 430			Providers: csync.NewMap[string, ProviderConfig](),
 431		}
 432
 433		require.False(t, cfg.IsConfigured())
 434	})
 435
 436	t.Run("returns false when all providers are disabled", func(t *testing.T) {
 437		cfg := &Config{
 438			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 439				"openai": {
 440					ID:      "openai",
 441					APIKey:  "key1",
 442					Disable: true,
 443				},
 444				"anthropic": {
 445					ID:      "anthropic",
 446					APIKey:  "key2",
 447					Disable: true,
 448				},
 449			}),
 450		}
 451
 452		require.False(t, cfg.IsConfigured())
 453	})
 454}
 455
 456func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
 457	cfg := &Config{
 458		Options: &Options{
 459			DisabledTools: []string{},
 460		},
 461	}
 462
 463	cfg.SetupAgents()
 464	coderAgent, ok := cfg.Agents["coder"]
 465	require.True(t, ok)
 466	assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
 467
 468	taskAgent, ok := cfg.Agents["task"]
 469	require.True(t, ok)
 470	assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 471}
 472
 473func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
 474	cfg := &Config{
 475		Options: &Options{
 476			DisabledTools: []string{
 477				"edit",
 478				"download",
 479				"grep",
 480			},
 481		},
 482	}
 483
 484	cfg.SetupAgents()
 485	coderAgent, ok := cfg.Agents["coder"]
 486	require.True(t, ok)
 487	assert.Equal(t, []string{"agent", "bash", "multiedit", "fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)
 488
 489	taskAgent, ok := cfg.Agents["task"]
 490	require.True(t, ok)
 491	assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 492}
 493
 494func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
 495	cfg := &Config{
 496		Options: &Options{
 497			DisabledTools: []string{
 498				"glob",
 499				"grep",
 500				"ls",
 501				"sourcegraph",
 502				"view",
 503			},
 504		},
 505	}
 506
 507	cfg.SetupAgents()
 508	coderAgent, ok := cfg.Agents["coder"]
 509	require.True(t, ok)
 510	assert.Equal(t, []string{"agent", "bash", "download", "edit", "multiedit", "fetch", "write"}, coderAgent.AllowedTools)
 511
 512	taskAgent, ok := cfg.Agents["task"]
 513	require.True(t, ok)
 514	assert.Equal(t, []string{}, taskAgent.AllowedTools)
 515}
 516
 517func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
 518	knownProviders := []catwalk.Provider{
 519		{
 520			ID:          "openai",
 521			APIKey:      "$OPENAI_API_KEY",
 522			APIEndpoint: "https://api.openai.com/v1",
 523			Models: []catwalk.Model{{
 524				ID: "test-model",
 525			}},
 526		},
 527	}
 528
 529	cfg := &Config{
 530		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 531			"openai": {
 532				Disable: true,
 533			},
 534		}),
 535	}
 536	cfg.setDefaults("/tmp", "")
 537
 538	env := environ([]string{
 539		"OPENAI_API_KEY=test-key",
 540	})
 541	resolver := NewEnvironmentVariableResolver(env)
 542	err := cfg.configureProviders(env, resolver, knownProviders)
 543	require.NoError(t, err)
 544
 545	require.Equal(t, cfg.Providers.Len(), 1)
 546	prov, exists := cfg.Providers.Get("openai")
 547	require.True(t, exists)
 548	require.True(t, prov.Disable)
 549}
 550
 551func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 552	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
 553		cfg := &Config{
 554			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 555				"custom": {
 556					BaseURL: "https://api.custom.com/v1",
 557					Models: []catwalk.Model{{
 558						ID: "test-model",
 559					}},
 560				},
 561				"openai": {
 562					APIKey: "$MISSING",
 563				},
 564			}),
 565		}
 566		cfg.setDefaults("/tmp", "")
 567
 568		env := environ([]string{})
 569		resolver := NewEnvironmentVariableResolver(env)
 570		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 571		require.NoError(t, err)
 572
 573		require.Equal(t, cfg.Providers.Len(), 1)
 574		_, exists := cfg.Providers.Get("custom")
 575		require.True(t, exists)
 576	})
 577
 578	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
 579		cfg := &Config{
 580			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 581				"custom": {
 582					APIKey: "test-key",
 583					Models: []catwalk.Model{{
 584						ID: "test-model",
 585					}},
 586				},
 587			}),
 588		}
 589		cfg.setDefaults("/tmp", "")
 590
 591		env := environ([]string{})
 592		resolver := NewEnvironmentVariableResolver(env)
 593		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 594		require.NoError(t, err)
 595
 596		require.Equal(t, cfg.Providers.Len(), 0)
 597		_, exists := cfg.Providers.Get("custom")
 598		require.False(t, exists)
 599	})
 600
 601	t.Run("custom provider with no models is removed", func(t *testing.T) {
 602		cfg := &Config{
 603			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 604				"custom": {
 605					APIKey:  "test-key",
 606					BaseURL: "https://api.custom.com/v1",
 607					Models:  []catwalk.Model{},
 608				},
 609			}),
 610		}
 611		cfg.setDefaults("/tmp", "")
 612
 613		env := environ([]string{})
 614		resolver := NewEnvironmentVariableResolver(env)
 615		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 616		require.NoError(t, err)
 617
 618		require.Equal(t, cfg.Providers.Len(), 0)
 619		_, exists := cfg.Providers.Get("custom")
 620		require.False(t, exists)
 621	})
 622
 623	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
 624		cfg := &Config{
 625			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 626				"custom": {
 627					APIKey:  "test-key",
 628					BaseURL: "https://api.custom.com/v1",
 629					Type:    "unsupported",
 630					Models: []catwalk.Model{{
 631						ID: "test-model",
 632					}},
 633				},
 634			}),
 635		}
 636		cfg.setDefaults("/tmp", "")
 637
 638		env := environ([]string{})
 639		resolver := NewEnvironmentVariableResolver(env)
 640		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 641		require.NoError(t, err)
 642
 643		require.Equal(t, cfg.Providers.Len(), 0)
 644		_, exists := cfg.Providers.Get("custom")
 645		require.False(t, exists)
 646	})
 647
 648	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
 649		cfg := &Config{
 650			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 651				"custom": {
 652					APIKey:  "test-key",
 653					BaseURL: "https://api.custom.com/v1",
 654					Type:    catwalk.TypeOpenAI,
 655					Models: []catwalk.Model{{
 656						ID: "test-model",
 657					}},
 658				},
 659			}),
 660		}
 661		cfg.setDefaults("/tmp", "")
 662
 663		env := environ([]string{})
 664		resolver := NewEnvironmentVariableResolver(env)
 665		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 666		require.NoError(t, err)
 667
 668		require.Equal(t, cfg.Providers.Len(), 1)
 669		customProvider, exists := cfg.Providers.Get("custom")
 670		require.True(t, exists)
 671		require.Equal(t, "custom", customProvider.ID)
 672		require.Equal(t, "test-key", customProvider.APIKey)
 673		require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
 674	})
 675
 676	t.Run("custom anthropic provider is supported", func(t *testing.T) {
 677		cfg := &Config{
 678			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 679				"custom-anthropic": {
 680					APIKey:  "test-key",
 681					BaseURL: "https://api.anthropic.com/v1",
 682					Type:    catwalk.TypeAnthropic,
 683					Models: []catwalk.Model{{
 684						ID: "claude-3-sonnet",
 685					}},
 686				},
 687			}),
 688		}
 689		cfg.setDefaults("/tmp", "")
 690
 691		env := environ([]string{})
 692		resolver := NewEnvironmentVariableResolver(env)
 693		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 694		require.NoError(t, err)
 695
 696		require.Equal(t, cfg.Providers.Len(), 1)
 697		customProvider, exists := cfg.Providers.Get("custom-anthropic")
 698		require.True(t, exists)
 699		require.Equal(t, "custom-anthropic", customProvider.ID)
 700		require.Equal(t, "test-key", customProvider.APIKey)
 701		require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
 702		require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
 703	})
 704
 705	t.Run("disabled custom provider is removed", func(t *testing.T) {
 706		cfg := &Config{
 707			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 708				"custom": {
 709					APIKey:  "test-key",
 710					BaseURL: "https://api.custom.com/v1",
 711					Type:    catwalk.TypeOpenAI,
 712					Disable: true,
 713					Models: []catwalk.Model{{
 714						ID: "test-model",
 715					}},
 716				},
 717			}),
 718		}
 719		cfg.setDefaults("/tmp", "")
 720
 721		env := environ([]string{})
 722		resolver := NewEnvironmentVariableResolver(env)
 723		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 724		require.NoError(t, err)
 725
 726		require.Equal(t, cfg.Providers.Len(), 0)
 727		_, exists := cfg.Providers.Get("custom")
 728		require.False(t, exists)
 729	})
 730}
 731
 732func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 733	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
 734		knownProviders := []catwalk.Provider{
 735			{
 736				ID:          catwalk.InferenceProviderVertexAI,
 737				APIKey:      "",
 738				APIEndpoint: "",
 739				Models: []catwalk.Model{{
 740					ID: "gemini-pro",
 741				}},
 742			},
 743		}
 744
 745		cfg := &Config{
 746			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 747				"vertexai": {
 748					BaseURL: "custom-url",
 749				},
 750			}),
 751		}
 752		cfg.setDefaults("/tmp", "")
 753
 754		env := environ([]string{
 755			"GOOGLE_GENAI_USE_VERTEXAI=false",
 756		})
 757		resolver := NewEnvironmentVariableResolver(env)
 758		err := cfg.configureProviders(env, resolver, knownProviders)
 759		require.NoError(t, err)
 760
 761		require.Equal(t, cfg.Providers.Len(), 0)
 762		_, exists := cfg.Providers.Get("vertexai")
 763		require.False(t, exists)
 764	})
 765
 766	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
 767		knownProviders := []catwalk.Provider{
 768			{
 769				ID:          catwalk.InferenceProviderBedrock,
 770				APIKey:      "",
 771				APIEndpoint: "",
 772				Models: []catwalk.Model{{
 773					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 774				}},
 775			},
 776		}
 777
 778		cfg := &Config{
 779			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 780				"bedrock": {
 781					BaseURL: "custom-url",
 782				},
 783			}),
 784		}
 785		cfg.setDefaults("/tmp", "")
 786
 787		env := environ([]string{})
 788		resolver := NewEnvironmentVariableResolver(env)
 789		err := cfg.configureProviders(env, resolver, knownProviders)
 790		require.NoError(t, err)
 791
 792		require.Equal(t, cfg.Providers.Len(), 0)
 793		_, exists := cfg.Providers.Get("bedrock")
 794		require.False(t, exists)
 795	})
 796
 797	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
 798		knownProviders := []catwalk.Provider{
 799			{
 800				ID:          "openai",
 801				APIKey:      "$MISSING_API_KEY",
 802				APIEndpoint: "https://api.openai.com/v1",
 803				Models: []catwalk.Model{{
 804					ID: "test-model",
 805				}},
 806			},
 807		}
 808
 809		cfg := &Config{
 810			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 811				"openai": {
 812					BaseURL: "custom-url",
 813				},
 814			}),
 815		}
 816		cfg.setDefaults("/tmp", "")
 817
 818		env := environ([]string{})
 819		resolver := NewEnvironmentVariableResolver(env)
 820		err := cfg.configureProviders(env, resolver, knownProviders)
 821		require.NoError(t, err)
 822
 823		require.Equal(t, cfg.Providers.Len(), 0)
 824		_, exists := cfg.Providers.Get("openai")
 825		require.False(t, exists)
 826	})
 827
 828	t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
 829		knownProviders := []catwalk.Provider{
 830			{
 831				ID:          "openai",
 832				APIKey:      "$OPENAI_API_KEY",
 833				APIEndpoint: "$MISSING_ENDPOINT",
 834				Models: []catwalk.Model{{
 835					ID: "test-model",
 836				}},
 837			},
 838		}
 839
 840		cfg := &Config{
 841			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 842				"openai": {
 843					APIKey: "test-key",
 844				},
 845			}),
 846		}
 847		cfg.setDefaults("/tmp", "")
 848
 849		env := environ([]string{
 850			"OPENAI_API_KEY=test-key",
 851		})
 852		resolver := NewEnvironmentVariableResolver(env)
 853		err := cfg.configureProviders(env, resolver, knownProviders)
 854		require.NoError(t, err)
 855
 856		require.Equal(t, cfg.Providers.Len(), 1)
 857		_, exists := cfg.Providers.Get("openai")
 858		require.True(t, exists)
 859	})
 860}
 861
 862func TestConfig_defaultModelSelection(t *testing.T) {
 863	t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
 864		knownProviders := []catwalk.Provider{
 865			{
 866				ID:                  "openai",
 867				APIKey:              "abc",
 868				DefaultLargeModelID: "large-model",
 869				DefaultSmallModelID: "small-model",
 870				Models: []catwalk.Model{
 871					{
 872						ID:               "large-model",
 873						DefaultMaxTokens: 1000,
 874					},
 875					{
 876						ID:               "small-model",
 877						DefaultMaxTokens: 500,
 878					},
 879				},
 880			},
 881		}
 882
 883		cfg := &Config{}
 884		cfg.setDefaults("/tmp", "")
 885		env := environ([]string{})
 886		resolver := NewEnvironmentVariableResolver(env)
 887		err := cfg.configureProviders(env, resolver, knownProviders)
 888		require.NoError(t, err)
 889
 890		large, small, err := cfg.defaultModelSelection(knownProviders)
 891		require.NoError(t, err)
 892		require.Equal(t, "large-model", large.Model)
 893		require.Equal(t, "openai", large.Provider)
 894		require.Equal(t, int64(1000), large.MaxTokens)
 895		require.Equal(t, "small-model", small.Model)
 896		require.Equal(t, "openai", small.Provider)
 897		require.Equal(t, int64(500), small.MaxTokens)
 898	})
 899	t.Run("should error if no providers configured", func(t *testing.T) {
 900		knownProviders := []catwalk.Provider{
 901			{
 902				ID:                  "openai",
 903				APIKey:              "$MISSING_KEY",
 904				DefaultLargeModelID: "large-model",
 905				DefaultSmallModelID: "small-model",
 906				Models: []catwalk.Model{
 907					{
 908						ID:               "large-model",
 909						DefaultMaxTokens: 1000,
 910					},
 911					{
 912						ID:               "small-model",
 913						DefaultMaxTokens: 500,
 914					},
 915				},
 916			},
 917		}
 918
 919		cfg := &Config{}
 920		cfg.setDefaults("/tmp", "")
 921		env := environ([]string{})
 922		resolver := NewEnvironmentVariableResolver(env)
 923		err := cfg.configureProviders(env, resolver, knownProviders)
 924		require.NoError(t, err)
 925
 926		_, _, err = cfg.defaultModelSelection(knownProviders)
 927		require.Error(t, err)
 928	})
 929	t.Run("should error if model is missing", func(t *testing.T) {
 930		knownProviders := []catwalk.Provider{
 931			{
 932				ID:                  "openai",
 933				APIKey:              "abc",
 934				DefaultLargeModelID: "large-model",
 935				DefaultSmallModelID: "small-model",
 936				Models: []catwalk.Model{
 937					{
 938						ID:               "not-large-model",
 939						DefaultMaxTokens: 1000,
 940					},
 941					{
 942						ID:               "small-model",
 943						DefaultMaxTokens: 500,
 944					},
 945				},
 946			},
 947		}
 948
 949		cfg := &Config{}
 950		cfg.setDefaults("/tmp", "")
 951		env := environ([]string{})
 952		resolver := NewEnvironmentVariableResolver(env)
 953		err := cfg.configureProviders(env, resolver, knownProviders)
 954		require.NoError(t, err)
 955		_, _, err = cfg.defaultModelSelection(knownProviders)
 956		require.Error(t, err)
 957	})
 958
 959	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
 960		knownProviders := []catwalk.Provider{
 961			{
 962				ID:                  "openai",
 963				APIKey:              "$MISSING", // will not be included in the config
 964				DefaultLargeModelID: "large-model",
 965				DefaultSmallModelID: "small-model",
 966				Models: []catwalk.Model{
 967					{
 968						ID:               "not-large-model",
 969						DefaultMaxTokens: 1000,
 970					},
 971					{
 972						ID:               "small-model",
 973						DefaultMaxTokens: 500,
 974					},
 975				},
 976			},
 977		}
 978
 979		cfg := &Config{
 980			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 981				"custom": {
 982					APIKey:  "test-key",
 983					BaseURL: "https://api.custom.com/v1",
 984					Models: []catwalk.Model{
 985						{
 986							ID:               "model",
 987							DefaultMaxTokens: 600,
 988						},
 989					},
 990				},
 991			}),
 992		}
 993		cfg.setDefaults("/tmp", "")
 994		env := environ([]string{})
 995		resolver := NewEnvironmentVariableResolver(env)
 996		err := cfg.configureProviders(env, resolver, knownProviders)
 997		require.NoError(t, err)
 998		large, small, err := cfg.defaultModelSelection(knownProviders)
 999		require.NoError(t, err)
1000		require.Equal(t, "model", large.Model)
1001		require.Equal(t, "custom", large.Provider)
1002		require.Equal(t, int64(600), large.MaxTokens)
1003		require.Equal(t, "model", small.Model)
1004		require.Equal(t, "custom", small.Provider)
1005		require.Equal(t, int64(600), small.MaxTokens)
1006	})
1007
1008	t.Run("should fail if no model configured", func(t *testing.T) {
1009		knownProviders := []catwalk.Provider{
1010			{
1011				ID:                  "openai",
1012				APIKey:              "$MISSING", // will not be included in the config
1013				DefaultLargeModelID: "large-model",
1014				DefaultSmallModelID: "small-model",
1015				Models: []catwalk.Model{
1016					{
1017						ID:               "not-large-model",
1018						DefaultMaxTokens: 1000,
1019					},
1020					{
1021						ID:               "small-model",
1022						DefaultMaxTokens: 500,
1023					},
1024				},
1025			},
1026		}
1027
1028		cfg := &Config{
1029			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1030				"custom": {
1031					APIKey:  "test-key",
1032					BaseURL: "https://api.custom.com/v1",
1033					Models:  []catwalk.Model{},
1034				},
1035			}),
1036		}
1037		cfg.setDefaults("/tmp", "")
1038		env := environ([]string{})
1039		resolver := NewEnvironmentVariableResolver(env)
1040		err := cfg.configureProviders(env, resolver, knownProviders)
1041		require.NoError(t, err)
1042		_, _, err = cfg.defaultModelSelection(knownProviders)
1043		require.Error(t, err)
1044	})
1045	t.Run("should use the default provider first", func(t *testing.T) {
1046		knownProviders := []catwalk.Provider{
1047			{
1048				ID:                  "openai",
1049				APIKey:              "set",
1050				DefaultLargeModelID: "large-model",
1051				DefaultSmallModelID: "small-model",
1052				Models: []catwalk.Model{
1053					{
1054						ID:               "large-model",
1055						DefaultMaxTokens: 1000,
1056					},
1057					{
1058						ID:               "small-model",
1059						DefaultMaxTokens: 500,
1060					},
1061				},
1062			},
1063		}
1064
1065		cfg := &Config{
1066			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1067				"custom": {
1068					APIKey:  "test-key",
1069					BaseURL: "https://api.custom.com/v1",
1070					Models: []catwalk.Model{
1071						{
1072							ID:               "large-model",
1073							DefaultMaxTokens: 1000,
1074						},
1075					},
1076				},
1077			}),
1078		}
1079		cfg.setDefaults("/tmp", "")
1080		env := environ([]string{})
1081		resolver := NewEnvironmentVariableResolver(env)
1082		err := cfg.configureProviders(env, resolver, knownProviders)
1083		require.NoError(t, err)
1084		large, small, err := cfg.defaultModelSelection(knownProviders)
1085		require.NoError(t, err)
1086		require.Equal(t, "large-model", large.Model)
1087		require.Equal(t, "openai", large.Provider)
1088		require.Equal(t, int64(1000), large.MaxTokens)
1089		require.Equal(t, "small-model", small.Model)
1090		require.Equal(t, "openai", small.Provider)
1091		require.Equal(t, int64(500), small.MaxTokens)
1092	})
1093}
1094
1095func TestConfig_configureSelectedModels(t *testing.T) {
1096	t.Run("should override defaults", func(t *testing.T) {
1097		knownProviders := []catwalk.Provider{
1098			{
1099				ID:                  "openai",
1100				APIKey:              "abc",
1101				DefaultLargeModelID: "large-model",
1102				DefaultSmallModelID: "small-model",
1103				Models: []catwalk.Model{
1104					{
1105						ID:               "larger-model",
1106						DefaultMaxTokens: 2000,
1107					},
1108					{
1109						ID:               "large-model",
1110						DefaultMaxTokens: 1000,
1111					},
1112					{
1113						ID:               "small-model",
1114						DefaultMaxTokens: 500,
1115					},
1116				},
1117			},
1118		}
1119
1120		cfg := &Config{
1121			Models: map[SelectedModelType]SelectedModel{
1122				"large": {
1123					Model: "larger-model",
1124				},
1125			},
1126		}
1127		cfg.setDefaults("/tmp", "")
1128		env := environ([]string{})
1129		resolver := NewEnvironmentVariableResolver(env)
1130		err := cfg.configureProviders(env, resolver, knownProviders)
1131		require.NoError(t, err)
1132
1133		err = cfg.configureSelectedModels(knownProviders)
1134		require.NoError(t, err)
1135		large := cfg.Models[SelectedModelTypeLarge]
1136		small := cfg.Models[SelectedModelTypeSmall]
1137		require.Equal(t, "larger-model", large.Model)
1138		require.Equal(t, "openai", large.Provider)
1139		require.Equal(t, int64(2000), large.MaxTokens)
1140		require.Equal(t, "small-model", small.Model)
1141		require.Equal(t, "openai", small.Provider)
1142		require.Equal(t, int64(500), small.MaxTokens)
1143	})
1144	t.Run("should be possible to use multiple providers", func(t *testing.T) {
1145		knownProviders := []catwalk.Provider{
1146			{
1147				ID:                  "openai",
1148				APIKey:              "abc",
1149				DefaultLargeModelID: "large-model",
1150				DefaultSmallModelID: "small-model",
1151				Models: []catwalk.Model{
1152					{
1153						ID:               "large-model",
1154						DefaultMaxTokens: 1000,
1155					},
1156					{
1157						ID:               "small-model",
1158						DefaultMaxTokens: 500,
1159					},
1160				},
1161			},
1162			{
1163				ID:                  "anthropic",
1164				APIKey:              "abc",
1165				DefaultLargeModelID: "a-large-model",
1166				DefaultSmallModelID: "a-small-model",
1167				Models: []catwalk.Model{
1168					{
1169						ID:               "a-large-model",
1170						DefaultMaxTokens: 1000,
1171					},
1172					{
1173						ID:               "a-small-model",
1174						DefaultMaxTokens: 200,
1175					},
1176				},
1177			},
1178		}
1179
1180		cfg := &Config{
1181			Models: map[SelectedModelType]SelectedModel{
1182				"small": {
1183					Model:     "a-small-model",
1184					Provider:  "anthropic",
1185					MaxTokens: 300,
1186				},
1187			},
1188		}
1189		cfg.setDefaults("/tmp", "")
1190		env := environ([]string{})
1191		resolver := NewEnvironmentVariableResolver(env)
1192		err := cfg.configureProviders(env, resolver, knownProviders)
1193		require.NoError(t, err)
1194
1195		err = cfg.configureSelectedModels(knownProviders)
1196		require.NoError(t, err)
1197		large := cfg.Models[SelectedModelTypeLarge]
1198		small := cfg.Models[SelectedModelTypeSmall]
1199		require.Equal(t, "large-model", large.Model)
1200		require.Equal(t, "openai", large.Provider)
1201		require.Equal(t, int64(1000), large.MaxTokens)
1202		require.Equal(t, "a-small-model", small.Model)
1203		require.Equal(t, "anthropic", small.Provider)
1204		require.Equal(t, int64(300), small.MaxTokens)
1205	})
1206
1207	t.Run("should override the max tokens only", func(t *testing.T) {
1208		knownProviders := []catwalk.Provider{
1209			{
1210				ID:                  "openai",
1211				APIKey:              "abc",
1212				DefaultLargeModelID: "large-model",
1213				DefaultSmallModelID: "small-model",
1214				Models: []catwalk.Model{
1215					{
1216						ID:               "large-model",
1217						DefaultMaxTokens: 1000,
1218					},
1219					{
1220						ID:               "small-model",
1221						DefaultMaxTokens: 500,
1222					},
1223				},
1224			},
1225		}
1226
1227		cfg := &Config{
1228			Models: map[SelectedModelType]SelectedModel{
1229				"large": {
1230					MaxTokens: 100,
1231				},
1232			},
1233		}
1234		cfg.setDefaults("/tmp", "")
1235		env := environ([]string{})
1236		resolver := NewEnvironmentVariableResolver(env)
1237		err := cfg.configureProviders(env, resolver, knownProviders)
1238		require.NoError(t, err)
1239
1240		err = cfg.configureSelectedModels(knownProviders)
1241		require.NoError(t, err)
1242		large := cfg.Models[SelectedModelTypeLarge]
1243		require.Equal(t, "large-model", large.Model)
1244		require.Equal(t, "openai", large.Provider)
1245		require.Equal(t, int64(100), large.MaxTokens)
1246	})
1247}