load_test.go

   1package config
   2
   3import (
   4	"io"
   5	"log/slog"
   6	"os"
   7	"path/filepath"
   8	"strings"
   9	"testing"
  10
  11	"git.secluded.site/crush/internal/csync"
  12	"git.secluded.site/crush/internal/env"
  13	"github.com/charmbracelet/catwalk/pkg/catwalk"
  14	"github.com/stretchr/testify/assert"
  15	"github.com/stretchr/testify/require"
  16)
  17
  18func TestMain(m *testing.M) {
  19	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
  20
  21	exitVal := m.Run()
  22	os.Exit(exitVal)
  23}
  24
  25func TestConfig_LoadFromReaders(t *testing.T) {
  26	data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
  27	data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
  28	data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
  29
  30	loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
  31
  32	require.NoError(t, err)
  33	require.NotNil(t, loadedConfig)
  34	require.Equal(t, 1, loadedConfig.Providers.Len())
  35	pc, _ := loadedConfig.Providers.Get("openai")
  36	require.Equal(t, "key2", pc.APIKey)
  37	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  38}
  39
  40func TestConfig_setDefaults(t *testing.T) {
  41	cfg := &Config{}
  42
  43	cfg.setDefaults("/tmp", "")
  44
  45	require.NotNil(t, cfg.Options)
  46	require.NotNil(t, cfg.Options.TUI)
  47	require.NotNil(t, cfg.Options.ContextPaths)
  48	require.NotNil(t, cfg.Providers)
  49	require.NotNil(t, cfg.Models)
  50	require.NotNil(t, cfg.LSP)
  51	require.NotNil(t, cfg.MCP)
  52	require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  53	require.Equal(t, "AGENTS.md", cfg.Options.InitializeAs)
  54	for _, path := range defaultContextPaths {
  55		require.Contains(t, cfg.Options.ContextPaths, path)
  56	}
  57	require.Equal(t, "/tmp", cfg.workingDir)
  58}
  59
  60func TestConfig_configureProviders(t *testing.T) {
  61	knownProviders := []catwalk.Provider{
  62		{
  63			ID:          "openai",
  64			APIKey:      "$OPENAI_API_KEY",
  65			APIEndpoint: "https://api.openai.com/v1",
  66			Models: []catwalk.Model{{
  67				ID: "test-model",
  68			}},
  69		},
  70	}
  71
  72	cfg := &Config{}
  73	cfg.setDefaults("/tmp", "")
  74	env := env.NewFromMap(map[string]string{
  75		"OPENAI_API_KEY": "test-key",
  76	})
  77	resolver := NewEnvironmentVariableResolver(env)
  78	err := cfg.configureProviders(env, resolver, knownProviders)
  79	require.NoError(t, err)
  80	require.Equal(t, 1, cfg.Providers.Len())
  81
  82	// We want to make sure that we keep the configured API key as a placeholder
  83	pc, _ := cfg.Providers.Get("openai")
  84	require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  85}
  86
  87func TestConfig_configureProvidersWithOverride(t *testing.T) {
  88	knownProviders := []catwalk.Provider{
  89		{
  90			ID:          "openai",
  91			APIKey:      "$OPENAI_API_KEY",
  92			APIEndpoint: "https://api.openai.com/v1",
  93			Models: []catwalk.Model{{
  94				ID: "test-model",
  95			}},
  96		},
  97	}
  98
  99	cfg := &Config{
 100		Providers: csync.NewMap[string, ProviderConfig](),
 101	}
 102	cfg.Providers.Set("openai", ProviderConfig{
 103		APIKey:  "xyz",
 104		BaseURL: "https://api.openai.com/v2",
 105		Models: []catwalk.Model{
 106			{
 107				ID:   "test-model",
 108				Name: "Updated",
 109			},
 110			{
 111				ID: "another-model",
 112			},
 113		},
 114	})
 115	cfg.setDefaults("/tmp", "")
 116
 117	env := env.NewFromMap(map[string]string{
 118		"OPENAI_API_KEY": "test-key",
 119	})
 120	resolver := NewEnvironmentVariableResolver(env)
 121	err := cfg.configureProviders(env, resolver, knownProviders)
 122	require.NoError(t, err)
 123	require.Equal(t, 1, cfg.Providers.Len())
 124
 125	// We want to make sure that we keep the configured API key as a placeholder
 126	pc, _ := cfg.Providers.Get("openai")
 127	require.Equal(t, "xyz", pc.APIKey)
 128	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
 129	require.Len(t, pc.Models, 2)
 130	require.Equal(t, "Updated", pc.Models[0].Name)
 131}
 132
 133func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
 134	knownProviders := []catwalk.Provider{
 135		{
 136			ID:          "openai",
 137			APIKey:      "$OPENAI_API_KEY",
 138			APIEndpoint: "https://api.openai.com/v1",
 139			Models: []catwalk.Model{{
 140				ID: "test-model",
 141			}},
 142		},
 143	}
 144
 145	cfg := &Config{
 146		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 147			"custom": {
 148				APIKey:  "xyz",
 149				BaseURL: "https://api.someendpoint.com/v2",
 150				Models: []catwalk.Model{
 151					{
 152						ID: "test-model",
 153					},
 154				},
 155			},
 156		}),
 157	}
 158	cfg.setDefaults("/tmp", "")
 159	env := env.NewFromMap(map[string]string{
 160		"OPENAI_API_KEY": "test-key",
 161	})
 162	resolver := NewEnvironmentVariableResolver(env)
 163	err := cfg.configureProviders(env, resolver, knownProviders)
 164	require.NoError(t, err)
 165	// Should be to because of the env variable
 166	require.Equal(t, cfg.Providers.Len(), 2)
 167
 168	// We want to make sure that we keep the configured API key as a placeholder
 169	pc, _ := cfg.Providers.Get("custom")
 170	require.Equal(t, "xyz", pc.APIKey)
 171	// Make sure we set the ID correctly
 172	require.Equal(t, "custom", pc.ID)
 173	require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
 174	require.Len(t, pc.Models, 1)
 175
 176	_, ok := cfg.Providers.Get("openai")
 177	require.True(t, ok, "OpenAI provider should still be present")
 178}
 179
 180func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
 181	knownProviders := []catwalk.Provider{
 182		{
 183			ID:          catwalk.InferenceProviderBedrock,
 184			APIKey:      "",
 185			APIEndpoint: "",
 186			Models: []catwalk.Model{{
 187				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 188			}},
 189		},
 190	}
 191
 192	cfg := &Config{}
 193	cfg.setDefaults("/tmp", "")
 194	env := env.NewFromMap(map[string]string{
 195		"AWS_ACCESS_KEY_ID":     "test-key-id",
 196		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 197	})
 198	resolver := NewEnvironmentVariableResolver(env)
 199	err := cfg.configureProviders(env, resolver, knownProviders)
 200	require.NoError(t, err)
 201	require.Equal(t, cfg.Providers.Len(), 1)
 202
 203	bedrockProvider, ok := cfg.Providers.Get("bedrock")
 204	require.True(t, ok, "Bedrock provider should be present")
 205	require.Len(t, bedrockProvider.Models, 1)
 206	require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
 207}
 208
 209func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
 210	knownProviders := []catwalk.Provider{
 211		{
 212			ID:          catwalk.InferenceProviderBedrock,
 213			APIKey:      "",
 214			APIEndpoint: "",
 215			Models: []catwalk.Model{{
 216				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 217			}},
 218		},
 219	}
 220
 221	cfg := &Config{}
 222	cfg.setDefaults("/tmp", "")
 223	env := env.NewFromMap(map[string]string{})
 224	resolver := NewEnvironmentVariableResolver(env)
 225	err := cfg.configureProviders(env, resolver, knownProviders)
 226	require.NoError(t, err)
 227	// Provider should not be configured without credentials
 228	require.Equal(t, cfg.Providers.Len(), 0)
 229}
 230
 231func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
 232	knownProviders := []catwalk.Provider{
 233		{
 234			ID:          catwalk.InferenceProviderBedrock,
 235			APIKey:      "",
 236			APIEndpoint: "",
 237			Models: []catwalk.Model{{
 238				ID: "some-random-model",
 239			}},
 240		},
 241	}
 242
 243	cfg := &Config{}
 244	cfg.setDefaults("/tmp", "")
 245	env := env.NewFromMap(map[string]string{
 246		"AWS_ACCESS_KEY_ID":     "test-key-id",
 247		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 248	})
 249	resolver := NewEnvironmentVariableResolver(env)
 250	err := cfg.configureProviders(env, resolver, knownProviders)
 251	require.Error(t, err)
 252}
 253
 254func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
 255	knownProviders := []catwalk.Provider{
 256		{
 257			ID:          catwalk.InferenceProviderVertexAI,
 258			APIKey:      "",
 259			APIEndpoint: "",
 260			Models: []catwalk.Model{{
 261				ID: "gemini-pro",
 262			}},
 263		},
 264	}
 265
 266	cfg := &Config{}
 267	cfg.setDefaults("/tmp", "")
 268	env := env.NewFromMap(map[string]string{
 269		"VERTEXAI_PROJECT":  "test-project",
 270		"VERTEXAI_LOCATION": "us-central1",
 271	})
 272	resolver := NewEnvironmentVariableResolver(env)
 273	err := cfg.configureProviders(env, resolver, knownProviders)
 274	require.NoError(t, err)
 275	require.Equal(t, cfg.Providers.Len(), 1)
 276
 277	vertexProvider, ok := cfg.Providers.Get("vertexai")
 278	require.True(t, ok, "VertexAI provider should be present")
 279	require.Len(t, vertexProvider.Models, 1)
 280	require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
 281	require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
 282	require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
 283}
 284
 285func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
 286	knownProviders := []catwalk.Provider{
 287		{
 288			ID:          catwalk.InferenceProviderVertexAI,
 289			APIKey:      "",
 290			APIEndpoint: "",
 291			Models: []catwalk.Model{{
 292				ID: "gemini-pro",
 293			}},
 294		},
 295	}
 296
 297	cfg := &Config{}
 298	cfg.setDefaults("/tmp", "")
 299	env := env.NewFromMap(map[string]string{
 300		"GOOGLE_GENAI_USE_VERTEXAI": "false",
 301		"GOOGLE_CLOUD_PROJECT":      "test-project",
 302		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 303	})
 304	resolver := NewEnvironmentVariableResolver(env)
 305	err := cfg.configureProviders(env, resolver, knownProviders)
 306	require.NoError(t, err)
 307	// Provider should not be configured without proper credentials
 308	require.Equal(t, cfg.Providers.Len(), 0)
 309}
 310
 311func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
 312	knownProviders := []catwalk.Provider{
 313		{
 314			ID:          catwalk.InferenceProviderVertexAI,
 315			APIKey:      "",
 316			APIEndpoint: "",
 317			Models: []catwalk.Model{{
 318				ID: "gemini-pro",
 319			}},
 320		},
 321	}
 322
 323	cfg := &Config{}
 324	cfg.setDefaults("/tmp", "")
 325	env := env.NewFromMap(map[string]string{
 326		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 327		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 328	})
 329	resolver := NewEnvironmentVariableResolver(env)
 330	err := cfg.configureProviders(env, resolver, knownProviders)
 331	require.NoError(t, err)
 332	// Provider should not be configured without project
 333	require.Equal(t, cfg.Providers.Len(), 0)
 334}
 335
 336func TestConfig_configureProvidersSetProviderID(t *testing.T) {
 337	knownProviders := []catwalk.Provider{
 338		{
 339			ID:          "openai",
 340			APIKey:      "$OPENAI_API_KEY",
 341			APIEndpoint: "https://api.openai.com/v1",
 342			Models: []catwalk.Model{{
 343				ID: "test-model",
 344			}},
 345		},
 346	}
 347
 348	cfg := &Config{}
 349	cfg.setDefaults("/tmp", "")
 350	env := env.NewFromMap(map[string]string{
 351		"OPENAI_API_KEY": "test-key",
 352	})
 353	resolver := NewEnvironmentVariableResolver(env)
 354	err := cfg.configureProviders(env, resolver, knownProviders)
 355	require.NoError(t, err)
 356	require.Equal(t, cfg.Providers.Len(), 1)
 357
 358	// Provider ID should be set
 359	pc, _ := cfg.Providers.Get("openai")
 360	require.Equal(t, "openai", pc.ID)
 361}
 362
 363func TestConfig_EnabledProviders(t *testing.T) {
 364	t.Run("all providers enabled", func(t *testing.T) {
 365		cfg := &Config{
 366			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 367				"openai": {
 368					ID:      "openai",
 369					APIKey:  "key1",
 370					Disable: false,
 371				},
 372				"anthropic": {
 373					ID:      "anthropic",
 374					APIKey:  "key2",
 375					Disable: false,
 376				},
 377			}),
 378		}
 379
 380		enabled := cfg.EnabledProviders()
 381		require.Len(t, enabled, 2)
 382	})
 383
 384	t.Run("some providers disabled", func(t *testing.T) {
 385		cfg := &Config{
 386			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 387				"openai": {
 388					ID:      "openai",
 389					APIKey:  "key1",
 390					Disable: false,
 391				},
 392				"anthropic": {
 393					ID:      "anthropic",
 394					APIKey:  "key2",
 395					Disable: true,
 396				},
 397			}),
 398		}
 399
 400		enabled := cfg.EnabledProviders()
 401		require.Len(t, enabled, 1)
 402		require.Equal(t, "openai", enabled[0].ID)
 403	})
 404
 405	t.Run("empty providers map", func(t *testing.T) {
 406		cfg := &Config{
 407			Providers: csync.NewMap[string, ProviderConfig](),
 408		}
 409
 410		enabled := cfg.EnabledProviders()
 411		require.Len(t, enabled, 0)
 412	})
 413}
 414
 415func TestConfig_IsConfigured(t *testing.T) {
 416	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
 417		cfg := &Config{
 418			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 419				"openai": {
 420					ID:      "openai",
 421					APIKey:  "key1",
 422					Disable: false,
 423				},
 424			}),
 425		}
 426
 427		require.True(t, cfg.IsConfigured())
 428	})
 429
 430	t.Run("returns false when no providers are configured", func(t *testing.T) {
 431		cfg := &Config{
 432			Providers: csync.NewMap[string, ProviderConfig](),
 433		}
 434
 435		require.False(t, cfg.IsConfigured())
 436	})
 437
 438	t.Run("returns false when all providers are disabled", func(t *testing.T) {
 439		cfg := &Config{
 440			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 441				"openai": {
 442					ID:      "openai",
 443					APIKey:  "key1",
 444					Disable: true,
 445				},
 446				"anthropic": {
 447					ID:      "anthropic",
 448					APIKey:  "key2",
 449					Disable: true,
 450				},
 451			}),
 452		}
 453
 454		require.False(t, cfg.IsConfigured())
 455	})
 456}
 457
 458func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
 459	cfg := &Config{
 460		Options: &Options{
 461			DisabledTools: []string{},
 462		},
 463	}
 464
 465	cfg.SetupAgents()
 466	coderAgent, ok := cfg.Agents[AgentCoder]
 467	require.True(t, ok)
 468	assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
 469
 470	taskAgent, ok := cfg.Agents[AgentTask]
 471	require.True(t, ok)
 472	assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 473}
 474
 475func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
 476	cfg := &Config{
 477		Options: &Options{
 478			DisabledTools: []string{
 479				"edit",
 480				"download",
 481				"grep",
 482			},
 483		},
 484	}
 485
 486	cfg.SetupAgents()
 487	coderAgent, ok := cfg.Agents[AgentCoder]
 488	require.True(t, ok)
 489
 490	assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)
 491
 492	taskAgent, ok := cfg.Agents[AgentTask]
 493	require.True(t, ok)
 494	assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 495}
 496
 497func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
 498	cfg := &Config{
 499		Options: &Options{
 500			DisabledTools: []string{
 501				"glob",
 502				"grep",
 503				"ls",
 504				"sourcegraph",
 505				"view",
 506			},
 507		},
 508	}
 509
 510	cfg.SetupAgents()
 511	coderAgent, ok := cfg.Agents[AgentCoder]
 512	require.True(t, ok)
 513	assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "write"}, coderAgent.AllowedTools)
 514
 515	taskAgent, ok := cfg.Agents[AgentTask]
 516	require.True(t, ok)
 517	assert.Equal(t, []string{}, taskAgent.AllowedTools)
 518}
 519
 520func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
 521	knownProviders := []catwalk.Provider{
 522		{
 523			ID:          "openai",
 524			APIKey:      "$OPENAI_API_KEY",
 525			APIEndpoint: "https://api.openai.com/v1",
 526			Models: []catwalk.Model{{
 527				ID: "test-model",
 528			}},
 529		},
 530	}
 531
 532	cfg := &Config{
 533		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 534			"openai": {
 535				Disable: true,
 536			},
 537		}),
 538	}
 539	cfg.setDefaults("/tmp", "")
 540
 541	env := env.NewFromMap(map[string]string{
 542		"OPENAI_API_KEY": "test-key",
 543	})
 544	resolver := NewEnvironmentVariableResolver(env)
 545	err := cfg.configureProviders(env, resolver, knownProviders)
 546	require.NoError(t, err)
 547
 548	require.Equal(t, cfg.Providers.Len(), 1)
 549	prov, exists := cfg.Providers.Get("openai")
 550	require.True(t, exists)
 551	require.True(t, prov.Disable)
 552}
 553
 554func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 555	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
 556		cfg := &Config{
 557			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 558				"custom": {
 559					BaseURL: "https://api.custom.com/v1",
 560					Models: []catwalk.Model{{
 561						ID: "test-model",
 562					}},
 563				},
 564				"openai": {
 565					APIKey: "$MISSING",
 566				},
 567			}),
 568		}
 569		cfg.setDefaults("/tmp", "")
 570
 571		env := env.NewFromMap(map[string]string{})
 572		resolver := NewEnvironmentVariableResolver(env)
 573		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 574		require.NoError(t, err)
 575
 576		require.Equal(t, cfg.Providers.Len(), 1)
 577		_, exists := cfg.Providers.Get("custom")
 578		require.True(t, exists)
 579	})
 580
 581	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
 582		cfg := &Config{
 583			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 584				"custom": {
 585					APIKey: "test-key",
 586					Models: []catwalk.Model{{
 587						ID: "test-model",
 588					}},
 589				},
 590			}),
 591		}
 592		cfg.setDefaults("/tmp", "")
 593
 594		env := env.NewFromMap(map[string]string{})
 595		resolver := NewEnvironmentVariableResolver(env)
 596		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 597		require.NoError(t, err)
 598
 599		require.Equal(t, cfg.Providers.Len(), 0)
 600		_, exists := cfg.Providers.Get("custom")
 601		require.False(t, exists)
 602	})
 603
 604	t.Run("custom provider with no models is removed", func(t *testing.T) {
 605		cfg := &Config{
 606			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 607				"custom": {
 608					APIKey:  "test-key",
 609					BaseURL: "https://api.custom.com/v1",
 610					Models:  []catwalk.Model{},
 611				},
 612			}),
 613		}
 614		cfg.setDefaults("/tmp", "")
 615
 616		env := env.NewFromMap(map[string]string{})
 617		resolver := NewEnvironmentVariableResolver(env)
 618		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 619		require.NoError(t, err)
 620
 621		require.Equal(t, cfg.Providers.Len(), 0)
 622		_, exists := cfg.Providers.Get("custom")
 623		require.False(t, exists)
 624	})
 625
 626	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
 627		cfg := &Config{
 628			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 629				"custom": {
 630					APIKey:  "test-key",
 631					BaseURL: "https://api.custom.com/v1",
 632					Type:    "unsupported",
 633					Models: []catwalk.Model{{
 634						ID: "test-model",
 635					}},
 636				},
 637			}),
 638		}
 639		cfg.setDefaults("/tmp", "")
 640
 641		env := env.NewFromMap(map[string]string{})
 642		resolver := NewEnvironmentVariableResolver(env)
 643		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 644		require.NoError(t, err)
 645
 646		require.Equal(t, cfg.Providers.Len(), 0)
 647		_, exists := cfg.Providers.Get("custom")
 648		require.False(t, exists)
 649	})
 650
 651	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
 652		cfg := &Config{
 653			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 654				"custom": {
 655					APIKey:  "test-key",
 656					BaseURL: "https://api.custom.com/v1",
 657					Type:    catwalk.TypeOpenAI,
 658					Models: []catwalk.Model{{
 659						ID: "test-model",
 660					}},
 661				},
 662			}),
 663		}
 664		cfg.setDefaults("/tmp", "")
 665
 666		env := env.NewFromMap(map[string]string{})
 667		resolver := NewEnvironmentVariableResolver(env)
 668		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 669		require.NoError(t, err)
 670
 671		require.Equal(t, cfg.Providers.Len(), 1)
 672		customProvider, exists := cfg.Providers.Get("custom")
 673		require.True(t, exists)
 674		require.Equal(t, "custom", customProvider.ID)
 675		require.Equal(t, "test-key", customProvider.APIKey)
 676		require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
 677	})
 678
 679	t.Run("custom anthropic provider is supported", func(t *testing.T) {
 680		cfg := &Config{
 681			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 682				"custom-anthropic": {
 683					APIKey:  "test-key",
 684					BaseURL: "https://api.anthropic.com/v1",
 685					Type:    catwalk.TypeAnthropic,
 686					Models: []catwalk.Model{{
 687						ID: "claude-3-sonnet",
 688					}},
 689				},
 690			}),
 691		}
 692		cfg.setDefaults("/tmp", "")
 693
 694		env := env.NewFromMap(map[string]string{})
 695		resolver := NewEnvironmentVariableResolver(env)
 696		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 697		require.NoError(t, err)
 698
 699		require.Equal(t, cfg.Providers.Len(), 1)
 700		customProvider, exists := cfg.Providers.Get("custom-anthropic")
 701		require.True(t, exists)
 702		require.Equal(t, "custom-anthropic", customProvider.ID)
 703		require.Equal(t, "test-key", customProvider.APIKey)
 704		require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
 705		require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
 706	})
 707
 708	t.Run("disabled custom provider is removed", func(t *testing.T) {
 709		cfg := &Config{
 710			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 711				"custom": {
 712					APIKey:  "test-key",
 713					BaseURL: "https://api.custom.com/v1",
 714					Type:    catwalk.TypeOpenAI,
 715					Disable: true,
 716					Models: []catwalk.Model{{
 717						ID: "test-model",
 718					}},
 719				},
 720			}),
 721		}
 722		cfg.setDefaults("/tmp", "")
 723
 724		env := env.NewFromMap(map[string]string{})
 725		resolver := NewEnvironmentVariableResolver(env)
 726		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 727		require.NoError(t, err)
 728
 729		require.Equal(t, cfg.Providers.Len(), 0)
 730		_, exists := cfg.Providers.Get("custom")
 731		require.False(t, exists)
 732	})
 733}
 734
 735func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 736	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
 737		knownProviders := []catwalk.Provider{
 738			{
 739				ID:          catwalk.InferenceProviderVertexAI,
 740				APIKey:      "",
 741				APIEndpoint: "",
 742				Models: []catwalk.Model{{
 743					ID: "gemini-pro",
 744				}},
 745			},
 746		}
 747
 748		cfg := &Config{
 749			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 750				"vertexai": {
 751					BaseURL: "custom-url",
 752				},
 753			}),
 754		}
 755		cfg.setDefaults("/tmp", "")
 756
 757		env := env.NewFromMap(map[string]string{
 758			"GOOGLE_GENAI_USE_VERTEXAI": "false",
 759		})
 760		resolver := NewEnvironmentVariableResolver(env)
 761		err := cfg.configureProviders(env, resolver, knownProviders)
 762		require.NoError(t, err)
 763
 764		require.Equal(t, cfg.Providers.Len(), 0)
 765		_, exists := cfg.Providers.Get("vertexai")
 766		require.False(t, exists)
 767	})
 768
 769	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
 770		knownProviders := []catwalk.Provider{
 771			{
 772				ID:          catwalk.InferenceProviderBedrock,
 773				APIKey:      "",
 774				APIEndpoint: "",
 775				Models: []catwalk.Model{{
 776					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 777				}},
 778			},
 779		}
 780
 781		cfg := &Config{
 782			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 783				"bedrock": {
 784					BaseURL: "custom-url",
 785				},
 786			}),
 787		}
 788		cfg.setDefaults("/tmp", "")
 789
 790		env := env.NewFromMap(map[string]string{})
 791		resolver := NewEnvironmentVariableResolver(env)
 792		err := cfg.configureProviders(env, resolver, knownProviders)
 793		require.NoError(t, err)
 794
 795		require.Equal(t, cfg.Providers.Len(), 0)
 796		_, exists := cfg.Providers.Get("bedrock")
 797		require.False(t, exists)
 798	})
 799
 800	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
 801		knownProviders := []catwalk.Provider{
 802			{
 803				ID:          "openai",
 804				APIKey:      "$MISSING_API_KEY",
 805				APIEndpoint: "https://api.openai.com/v1",
 806				Models: []catwalk.Model{{
 807					ID: "test-model",
 808				}},
 809			},
 810		}
 811
 812		cfg := &Config{
 813			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 814				"openai": {
 815					BaseURL: "custom-url",
 816				},
 817			}),
 818		}
 819		cfg.setDefaults("/tmp", "")
 820
 821		env := env.NewFromMap(map[string]string{})
 822		resolver := NewEnvironmentVariableResolver(env)
 823		err := cfg.configureProviders(env, resolver, knownProviders)
 824		require.NoError(t, err)
 825
 826		require.Equal(t, cfg.Providers.Len(), 0)
 827		_, exists := cfg.Providers.Get("openai")
 828		require.False(t, exists)
 829	})
 830
 831	t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
 832		knownProviders := []catwalk.Provider{
 833			{
 834				ID:          "openai",
 835				APIKey:      "$OPENAI_API_KEY",
 836				APIEndpoint: "$MISSING_ENDPOINT",
 837				Models: []catwalk.Model{{
 838					ID: "test-model",
 839				}},
 840			},
 841		}
 842
 843		cfg := &Config{
 844			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 845				"openai": {
 846					APIKey: "test-key",
 847				},
 848			}),
 849		}
 850		cfg.setDefaults("/tmp", "")
 851
 852		env := env.NewFromMap(map[string]string{
 853			"OPENAI_API_KEY": "test-key",
 854		})
 855		resolver := NewEnvironmentVariableResolver(env)
 856		err := cfg.configureProviders(env, resolver, knownProviders)
 857		require.NoError(t, err)
 858
 859		require.Equal(t, cfg.Providers.Len(), 1)
 860		_, exists := cfg.Providers.Get("openai")
 861		require.True(t, exists)
 862	})
 863}
 864
 865func TestConfig_defaultModelSelection(t *testing.T) {
 866	t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
 867		knownProviders := []catwalk.Provider{
 868			{
 869				ID:                  "openai",
 870				APIKey:              "abc",
 871				DefaultLargeModelID: "large-model",
 872				DefaultSmallModelID: "small-model",
 873				Models: []catwalk.Model{
 874					{
 875						ID:               "large-model",
 876						DefaultMaxTokens: 1000,
 877					},
 878					{
 879						ID:               "small-model",
 880						DefaultMaxTokens: 500,
 881					},
 882				},
 883			},
 884		}
 885
 886		cfg := &Config{}
 887		cfg.setDefaults("/tmp", "")
 888		env := env.NewFromMap(map[string]string{})
 889		resolver := NewEnvironmentVariableResolver(env)
 890		err := cfg.configureProviders(env, resolver, knownProviders)
 891		require.NoError(t, err)
 892
 893		large, small, err := cfg.defaultModelSelection(knownProviders)
 894		require.NoError(t, err)
 895		require.Equal(t, "large-model", large.Model)
 896		require.Equal(t, "openai", large.Provider)
 897		require.Equal(t, int64(1000), large.MaxTokens)
 898		require.Equal(t, "small-model", small.Model)
 899		require.Equal(t, "openai", small.Provider)
 900		require.Equal(t, int64(500), small.MaxTokens)
 901	})
 902	t.Run("should error if no providers configured", func(t *testing.T) {
 903		knownProviders := []catwalk.Provider{
 904			{
 905				ID:                  "openai",
 906				APIKey:              "$MISSING_KEY",
 907				DefaultLargeModelID: "large-model",
 908				DefaultSmallModelID: "small-model",
 909				Models: []catwalk.Model{
 910					{
 911						ID:               "large-model",
 912						DefaultMaxTokens: 1000,
 913					},
 914					{
 915						ID:               "small-model",
 916						DefaultMaxTokens: 500,
 917					},
 918				},
 919			},
 920		}
 921
 922		cfg := &Config{}
 923		cfg.setDefaults("/tmp", "")
 924		env := env.NewFromMap(map[string]string{})
 925		resolver := NewEnvironmentVariableResolver(env)
 926		err := cfg.configureProviders(env, resolver, knownProviders)
 927		require.NoError(t, err)
 928
 929		_, _, err = cfg.defaultModelSelection(knownProviders)
 930		require.Error(t, err)
 931	})
 932	t.Run("should error if model is missing", func(t *testing.T) {
 933		knownProviders := []catwalk.Provider{
 934			{
 935				ID:                  "openai",
 936				APIKey:              "abc",
 937				DefaultLargeModelID: "large-model",
 938				DefaultSmallModelID: "small-model",
 939				Models: []catwalk.Model{
 940					{
 941						ID:               "not-large-model",
 942						DefaultMaxTokens: 1000,
 943					},
 944					{
 945						ID:               "small-model",
 946						DefaultMaxTokens: 500,
 947					},
 948				},
 949			},
 950		}
 951
 952		cfg := &Config{}
 953		cfg.setDefaults("/tmp", "")
 954		env := env.NewFromMap(map[string]string{})
 955		resolver := NewEnvironmentVariableResolver(env)
 956		err := cfg.configureProviders(env, resolver, knownProviders)
 957		require.NoError(t, err)
 958		_, _, err = cfg.defaultModelSelection(knownProviders)
 959		require.Error(t, err)
 960	})
 961
 962	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
 963		knownProviders := []catwalk.Provider{
 964			{
 965				ID:                  "openai",
 966				APIKey:              "$MISSING", // will not be included in the config
 967				DefaultLargeModelID: "large-model",
 968				DefaultSmallModelID: "small-model",
 969				Models: []catwalk.Model{
 970					{
 971						ID:               "not-large-model",
 972						DefaultMaxTokens: 1000,
 973					},
 974					{
 975						ID:               "small-model",
 976						DefaultMaxTokens: 500,
 977					},
 978				},
 979			},
 980		}
 981
 982		cfg := &Config{
 983			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 984				"custom": {
 985					APIKey:  "test-key",
 986					BaseURL: "https://api.custom.com/v1",
 987					Models: []catwalk.Model{
 988						{
 989							ID:               "model",
 990							DefaultMaxTokens: 600,
 991						},
 992					},
 993				},
 994			}),
 995		}
 996		cfg.setDefaults("/tmp", "")
 997		env := env.NewFromMap(map[string]string{})
 998		resolver := NewEnvironmentVariableResolver(env)
 999		err := cfg.configureProviders(env, resolver, knownProviders)
1000		require.NoError(t, err)
1001		large, small, err := cfg.defaultModelSelection(knownProviders)
1002		require.NoError(t, err)
1003		require.Equal(t, "model", large.Model)
1004		require.Equal(t, "custom", large.Provider)
1005		require.Equal(t, int64(600), large.MaxTokens)
1006		require.Equal(t, "model", small.Model)
1007		require.Equal(t, "custom", small.Provider)
1008		require.Equal(t, int64(600), small.MaxTokens)
1009	})
1010
1011	t.Run("should fail if no model configured", func(t *testing.T) {
1012		knownProviders := []catwalk.Provider{
1013			{
1014				ID:                  "openai",
1015				APIKey:              "$MISSING", // will not be included in the config
1016				DefaultLargeModelID: "large-model",
1017				DefaultSmallModelID: "small-model",
1018				Models: []catwalk.Model{
1019					{
1020						ID:               "not-large-model",
1021						DefaultMaxTokens: 1000,
1022					},
1023					{
1024						ID:               "small-model",
1025						DefaultMaxTokens: 500,
1026					},
1027				},
1028			},
1029		}
1030
1031		cfg := &Config{
1032			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1033				"custom": {
1034					APIKey:  "test-key",
1035					BaseURL: "https://api.custom.com/v1",
1036					Models:  []catwalk.Model{},
1037				},
1038			}),
1039		}
1040		cfg.setDefaults("/tmp", "")
1041		env := env.NewFromMap(map[string]string{})
1042		resolver := NewEnvironmentVariableResolver(env)
1043		err := cfg.configureProviders(env, resolver, knownProviders)
1044		require.NoError(t, err)
1045		_, _, err = cfg.defaultModelSelection(knownProviders)
1046		require.Error(t, err)
1047	})
1048	t.Run("should use the default provider first", func(t *testing.T) {
1049		knownProviders := []catwalk.Provider{
1050			{
1051				ID:                  "openai",
1052				APIKey:              "set",
1053				DefaultLargeModelID: "large-model",
1054				DefaultSmallModelID: "small-model",
1055				Models: []catwalk.Model{
1056					{
1057						ID:               "large-model",
1058						DefaultMaxTokens: 1000,
1059					},
1060					{
1061						ID:               "small-model",
1062						DefaultMaxTokens: 500,
1063					},
1064				},
1065			},
1066		}
1067
1068		cfg := &Config{
1069			Providers: csync.NewMapFrom(map[string]ProviderConfig{
1070				"custom": {
1071					APIKey:  "test-key",
1072					BaseURL: "https://api.custom.com/v1",
1073					Models: []catwalk.Model{
1074						{
1075							ID:               "large-model",
1076							DefaultMaxTokens: 1000,
1077						},
1078					},
1079				},
1080			}),
1081		}
1082		cfg.setDefaults("/tmp", "")
1083		env := env.NewFromMap(map[string]string{})
1084		resolver := NewEnvironmentVariableResolver(env)
1085		err := cfg.configureProviders(env, resolver, knownProviders)
1086		require.NoError(t, err)
1087		large, small, err := cfg.defaultModelSelection(knownProviders)
1088		require.NoError(t, err)
1089		require.Equal(t, "large-model", large.Model)
1090		require.Equal(t, "openai", large.Provider)
1091		require.Equal(t, int64(1000), large.MaxTokens)
1092		require.Equal(t, "small-model", small.Model)
1093		require.Equal(t, "openai", small.Provider)
1094		require.Equal(t, int64(500), small.MaxTokens)
1095	})
1096}
1097
1098func TestConfig_configureSelectedModels(t *testing.T) {
1099	t.Run("should override defaults", func(t *testing.T) {
1100		knownProviders := []catwalk.Provider{
1101			{
1102				ID:                  "openai",
1103				APIKey:              "abc",
1104				DefaultLargeModelID: "large-model",
1105				DefaultSmallModelID: "small-model",
1106				Models: []catwalk.Model{
1107					{
1108						ID:               "larger-model",
1109						DefaultMaxTokens: 2000,
1110					},
1111					{
1112						ID:               "large-model",
1113						DefaultMaxTokens: 1000,
1114					},
1115					{
1116						ID:               "small-model",
1117						DefaultMaxTokens: 500,
1118					},
1119				},
1120			},
1121		}
1122
1123		cfg := &Config{
1124			Models: map[SelectedModelType]SelectedModel{
1125				"large": {
1126					Model: "larger-model",
1127				},
1128			},
1129		}
1130		cfg.setDefaults("/tmp", "")
1131		env := env.NewFromMap(map[string]string{})
1132		resolver := NewEnvironmentVariableResolver(env)
1133		err := cfg.configureProviders(env, resolver, knownProviders)
1134		require.NoError(t, err)
1135
1136		err = cfg.configureSelectedModels(knownProviders)
1137		require.NoError(t, err)
1138		large := cfg.Models[SelectedModelTypeLarge]
1139		small := cfg.Models[SelectedModelTypeSmall]
1140		require.Equal(t, "larger-model", large.Model)
1141		require.Equal(t, "openai", large.Provider)
1142		require.Equal(t, int64(2000), large.MaxTokens)
1143		require.Equal(t, "small-model", small.Model)
1144		require.Equal(t, "openai", small.Provider)
1145		require.Equal(t, int64(500), small.MaxTokens)
1146	})
1147	t.Run("should be possible to use multiple providers", func(t *testing.T) {
1148		knownProviders := []catwalk.Provider{
1149			{
1150				ID:                  "openai",
1151				APIKey:              "abc",
1152				DefaultLargeModelID: "large-model",
1153				DefaultSmallModelID: "small-model",
1154				Models: []catwalk.Model{
1155					{
1156						ID:               "large-model",
1157						DefaultMaxTokens: 1000,
1158					},
1159					{
1160						ID:               "small-model",
1161						DefaultMaxTokens: 500,
1162					},
1163				},
1164			},
1165			{
1166				ID:                  "anthropic",
1167				APIKey:              "abc",
1168				DefaultLargeModelID: "a-large-model",
1169				DefaultSmallModelID: "a-small-model",
1170				Models: []catwalk.Model{
1171					{
1172						ID:               "a-large-model",
1173						DefaultMaxTokens: 1000,
1174					},
1175					{
1176						ID:               "a-small-model",
1177						DefaultMaxTokens: 200,
1178					},
1179				},
1180			},
1181		}
1182
1183		cfg := &Config{
1184			Models: map[SelectedModelType]SelectedModel{
1185				"small": {
1186					Model:     "a-small-model",
1187					Provider:  "anthropic",
1188					MaxTokens: 300,
1189				},
1190			},
1191		}
1192		cfg.setDefaults("/tmp", "")
1193		env := env.NewFromMap(map[string]string{})
1194		resolver := NewEnvironmentVariableResolver(env)
1195		err := cfg.configureProviders(env, resolver, knownProviders)
1196		require.NoError(t, err)
1197
1198		err = cfg.configureSelectedModels(knownProviders)
1199		require.NoError(t, err)
1200		large := cfg.Models[SelectedModelTypeLarge]
1201		small := cfg.Models[SelectedModelTypeSmall]
1202		require.Equal(t, "large-model", large.Model)
1203		require.Equal(t, "openai", large.Provider)
1204		require.Equal(t, int64(1000), large.MaxTokens)
1205		require.Equal(t, "a-small-model", small.Model)
1206		require.Equal(t, "anthropic", small.Provider)
1207		require.Equal(t, int64(300), small.MaxTokens)
1208	})
1209
1210	t.Run("should override the max tokens only", func(t *testing.T) {
1211		knownProviders := []catwalk.Provider{
1212			{
1213				ID:                  "openai",
1214				APIKey:              "abc",
1215				DefaultLargeModelID: "large-model",
1216				DefaultSmallModelID: "small-model",
1217				Models: []catwalk.Model{
1218					{
1219						ID:               "large-model",
1220						DefaultMaxTokens: 1000,
1221					},
1222					{
1223						ID:               "small-model",
1224						DefaultMaxTokens: 500,
1225					},
1226				},
1227			},
1228		}
1229
1230		cfg := &Config{
1231			Models: map[SelectedModelType]SelectedModel{
1232				"large": {
1233					MaxTokens: 100,
1234				},
1235			},
1236		}
1237		cfg.setDefaults("/tmp", "")
1238		env := env.NewFromMap(map[string]string{})
1239		resolver := NewEnvironmentVariableResolver(env)
1240		err := cfg.configureProviders(env, resolver, knownProviders)
1241		require.NoError(t, err)
1242
1243		err = cfg.configureSelectedModels(knownProviders)
1244		require.NoError(t, err)
1245		large := cfg.Models[SelectedModelTypeLarge]
1246		require.Equal(t, "large-model", large.Model)
1247		require.Equal(t, "openai", large.Provider)
1248		require.Equal(t, int64(100), large.MaxTokens)
1249	})
1250}