load_test.go

   1package config
   2
   3import (
   4	"io"
   5	"log/slog"
   6	"os"
   7	"path/filepath"
   8	"strings"
   9	"testing"
  10
  11	"github.com/charmbracelet/catwalk/pkg/catwalk"
  12	"github.com/charmbracelet/crush/internal/csync"
  13	"github.com/charmbracelet/crush/internal/env"
  14	"github.com/charmbracelet/crush/internal/llm/agent"
  15	"github.com/charmbracelet/crush/internal/llm/provider"
  16	"github.com/charmbracelet/crush/internal/resolver"
  17	"github.com/stretchr/testify/assert"
  18)
  19
  20func TestMain(m *testing.M) {
  21	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
  22
  23	exitVal := m.Run()
  24	os.Exit(exitVal)
  25}
  26
  27func TestConfig_LoadFromReaders(t *testing.T) {
  28	data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
  29	data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
  30	data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
  31
  32	loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
  33
  34	assert.NoError(t, err)
  35	assert.NotNil(t, loadedConfig)
  36	assert.Equal(t, 1, loadedConfig.Providers.Len())
  37	pc, _ := loadedConfig.Providers.Get("openai")
  38	assert.Equal(t, "key2", pc.APIKey)
  39	assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  40}
  41
  42func TestConfig_setDefaults(t *testing.T) {
  43	cfg := &Config{}
  44
  45	cfg.setDefaults("/tmp")
  46
  47	assert.NotNil(t, cfg.Options)
  48	assert.NotNil(t, cfg.Options.TUI)
  49	assert.NotNil(t, cfg.Options.ContextPaths)
  50	assert.NotNil(t, cfg.Providers)
  51	assert.NotNil(t, cfg.Models)
  52	assert.NotNil(t, cfg.LSP)
  53	assert.NotNil(t, cfg.MCP)
  54	assert.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  55	for _, path := range defaultContextPaths {
  56		assert.Contains(t, cfg.Options.ContextPaths, path)
  57	}
  58	assert.Equal(t, "/tmp", cfg.workingDir)
  59}
  60
  61func TestConfig_configureProviders(t *testing.T) {
  62	knownProviders := []catwalk.Provider{
  63		{
  64			ID:          "openai",
  65			APIKey:      "$OPENAI_API_KEY",
  66			APIEndpoint: "https://api.openai.com/v1",
  67			Models: []catwalk.Model{{
  68				ID: "test-model",
  69			}},
  70		},
  71	}
  72
  73	cfg := &Config{}
  74	cfg.setDefaults("/tmp")
  75	env := env.NewFromMap(map[string]string{
  76		"OPENAI_API_KEY": "test-key",
  77	})
  78	resolver := resolver.NewEnvironmentVariableResolver(env)
  79	err := cfg.configureProviders(env, resolver, knownProviders)
  80	assert.NoError(t, err)
  81	assert.Equal(t, 1, cfg.Providers.Len())
  82
  83	// We want to make sure that we keep the configured API key as a placeholder
  84	pc, _ := cfg.Providers.Get("openai")
  85	assert.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  86}
  87
  88func TestConfig_configureProvidersWithOverride(t *testing.T) {
  89	knownProviders := []catwalk.Provider{
  90		{
  91			ID:          "openai",
  92			APIKey:      "$OPENAI_API_KEY",
  93			APIEndpoint: "https://api.openai.com/v1",
  94			Models: []catwalk.Model{{
  95				ID: "test-model",
  96			}},
  97		},
  98	}
  99
 100	cfg := &Config{
 101		Providers: csync.NewMap[string, provider.Config](),
 102	}
 103	cfg.Providers.Set("openai", provider.Config{
 104		APIKey:  "xyz",
 105		BaseURL: "https://api.openai.com/v2",
 106		Models: []catwalk.Model{
 107			{
 108				ID:   "test-model",
 109				Name: "Updated",
 110			},
 111			{
 112				ID: "another-model",
 113			},
 114		},
 115	})
 116	cfg.setDefaults("/tmp")
 117
 118	env := env.NewFromMap(map[string]string{
 119		"OPENAI_API_KEY": "test-key",
 120	})
 121	resolver := resolver.NewEnvironmentVariableResolver(env)
 122	err := cfg.configureProviders(env, resolver, knownProviders)
 123	assert.NoError(t, err)
 124	assert.Equal(t, 1, cfg.Providers.Len())
 125
 126	// We want to make sure that we keep the configured API key as a placeholder
 127	pc, _ := cfg.Providers.Get("openai")
 128	assert.Equal(t, "xyz", pc.APIKey)
 129	assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
 130	assert.Len(t, pc.Models, 2)
 131	assert.Equal(t, "Updated", pc.Models[0].Name)
 132}
 133
 134func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
 135	knownProviders := []catwalk.Provider{
 136		{
 137			ID:          "openai",
 138			APIKey:      "$OPENAI_API_KEY",
 139			APIEndpoint: "https://api.openai.com/v1",
 140			Models: []catwalk.Model{{
 141				ID: "test-model",
 142			}},
 143		},
 144	}
 145
 146	cfg := &Config{
 147		Providers: csync.NewMapFrom(map[string]provider.Config{
 148			"custom": {
 149				APIKey:  "xyz",
 150				BaseURL: "https://api.someendpoint.com/v2",
 151				Models: []catwalk.Model{
 152					{
 153						ID: "test-model",
 154					},
 155				},
 156			},
 157		}),
 158	}
 159	cfg.setDefaults("/tmp")
 160	env := env.NewFromMap(map[string]string{
 161		"OPENAI_API_KEY": "test-key",
 162	})
 163	resolver := resolver.NewEnvironmentVariableResolver(env)
 164	err := cfg.configureProviders(env, resolver, knownProviders)
 165	assert.NoError(t, err)
 166	// Should be to because of the env variable
 167	assert.Equal(t, cfg.Providers.Len(), 2)
 168
 169	// We want to make sure that we keep the configured API key as a placeholder
 170	pc, _ := cfg.Providers.Get("custom")
 171	assert.Equal(t, "xyz", pc.APIKey)
 172	// Make sure we set the ID correctly
 173	assert.Equal(t, "custom", pc.ID)
 174	assert.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
 175	assert.Len(t, pc.Models, 1)
 176
 177	_, ok := cfg.Providers.Get("openai")
 178	assert.True(t, ok, "OpenAI provider should still be present")
 179}
 180
 181func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
 182	knownProviders := []catwalk.Provider{
 183		{
 184			ID:          catwalk.InferenceProviderBedrock,
 185			APIKey:      "",
 186			APIEndpoint: "",
 187			Models: []catwalk.Model{{
 188				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 189			}},
 190		},
 191	}
 192
 193	cfg := &Config{}
 194	cfg.setDefaults("/tmp")
 195	env := env.NewFromMap(map[string]string{
 196		"AWS_ACCESS_KEY_ID":     "test-key-id",
 197		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 198	})
 199	resolver := resolver.NewEnvironmentVariableResolver(env)
 200	err := cfg.configureProviders(env, resolver, knownProviders)
 201	assert.NoError(t, err)
 202	assert.Equal(t, cfg.Providers.Len(), 1)
 203
 204	bedrockProvider, ok := cfg.Providers.Get("bedrock")
 205	assert.True(t, ok, "Bedrock provider should be present")
 206	assert.Len(t, bedrockProvider.Models, 1)
 207	assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
 208}
 209
 210func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
 211	knownProviders := []catwalk.Provider{
 212		{
 213			ID:          catwalk.InferenceProviderBedrock,
 214			APIKey:      "",
 215			APIEndpoint: "",
 216			Models: []catwalk.Model{{
 217				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 218			}},
 219		},
 220	}
 221
 222	cfg := &Config{}
 223	cfg.setDefaults("/tmp")
 224	env := env.NewFromMap(map[string]string{})
 225	resolver := resolver.NewEnvironmentVariableResolver(env)
 226	err := cfg.configureProviders(env, resolver, knownProviders)
 227	assert.NoError(t, err)
 228	// Provider should not be configured without credentials
 229	assert.Equal(t, cfg.Providers.Len(), 0)
 230}
 231
 232func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
 233	knownProviders := []catwalk.Provider{
 234		{
 235			ID:          catwalk.InferenceProviderBedrock,
 236			APIKey:      "",
 237			APIEndpoint: "",
 238			Models: []catwalk.Model{{
 239				ID: "some-random-model",
 240			}},
 241		},
 242	}
 243
 244	cfg := &Config{}
 245	cfg.setDefaults("/tmp")
 246	env := env.NewFromMap(map[string]string{
 247		"AWS_ACCESS_KEY_ID":     "test-key-id",
 248		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 249	})
 250	resolver := resolver.NewEnvironmentVariableResolver(env)
 251	err := cfg.configureProviders(env, resolver, knownProviders)
 252	assert.Error(t, err)
 253}
 254
 255func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
 256	knownProviders := []catwalk.Provider{
 257		{
 258			ID:          catwalk.InferenceProviderVertexAI,
 259			APIKey:      "",
 260			APIEndpoint: "",
 261			Models: []catwalk.Model{{
 262				ID: "gemini-pro",
 263			}},
 264		},
 265	}
 266
 267	cfg := &Config{}
 268	cfg.setDefaults("/tmp")
 269	env := env.NewFromMap(map[string]string{
 270		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 271		"GOOGLE_CLOUD_PROJECT":      "test-project",
 272		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 273	})
 274	resolver := resolver.NewEnvironmentVariableResolver(env)
 275	err := cfg.configureProviders(env, resolver, knownProviders)
 276	assert.NoError(t, err)
 277	assert.Equal(t, cfg.Providers.Len(), 1)
 278
 279	vertexProvider, ok := cfg.Providers.Get("vertexai")
 280	assert.True(t, ok, "VertexAI provider should be present")
 281	assert.Len(t, vertexProvider.Models, 1)
 282	assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
 283	assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
 284	assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
 285}
 286
 287func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
 288	knownProviders := []catwalk.Provider{
 289		{
 290			ID:          catwalk.InferenceProviderVertexAI,
 291			APIKey:      "",
 292			APIEndpoint: "",
 293			Models: []catwalk.Model{{
 294				ID: "gemini-pro",
 295			}},
 296		},
 297	}
 298
 299	cfg := &Config{}
 300	cfg.setDefaults("/tmp")
 301	env := env.NewFromMap(map[string]string{
 302		"GOOGLE_GENAI_USE_VERTEXAI": "false",
 303		"GOOGLE_CLOUD_PROJECT":      "test-project",
 304		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 305	})
 306	resolver := resolver.NewEnvironmentVariableResolver(env)
 307	err := cfg.configureProviders(env, resolver, knownProviders)
 308	assert.NoError(t, err)
 309	// Provider should not be configured without proper credentials
 310	assert.Equal(t, cfg.Providers.Len(), 0)
 311}
 312
 313func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
 314	knownProviders := []catwalk.Provider{
 315		{
 316			ID:          catwalk.InferenceProviderVertexAI,
 317			APIKey:      "",
 318			APIEndpoint: "",
 319			Models: []catwalk.Model{{
 320				ID: "gemini-pro",
 321			}},
 322		},
 323	}
 324
 325	cfg := &Config{}
 326	cfg.setDefaults("/tmp")
 327	env := env.NewFromMap(map[string]string{
 328		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 329		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 330	})
 331	resolver := resolver.NewEnvironmentVariableResolver(env)
 332	err := cfg.configureProviders(env, resolver, knownProviders)
 333	assert.NoError(t, err)
 334	// Provider should not be configured without project
 335	assert.Equal(t, cfg.Providers.Len(), 0)
 336}
 337
 338func TestConfig_configureProvidersSetProviderID(t *testing.T) {
 339	knownProviders := []catwalk.Provider{
 340		{
 341			ID:          "openai",
 342			APIKey:      "$OPENAI_API_KEY",
 343			APIEndpoint: "https://api.openai.com/v1",
 344			Models: []catwalk.Model{{
 345				ID: "test-model",
 346			}},
 347		},
 348	}
 349
 350	cfg := &Config{}
 351	cfg.setDefaults("/tmp")
 352	env := env.NewFromMap(map[string]string{
 353		"OPENAI_API_KEY": "test-key",
 354	})
 355	resolver := resolver.NewEnvironmentVariableResolver(env)
 356	err := cfg.configureProviders(env, resolver, knownProviders)
 357	assert.NoError(t, err)
 358	assert.Equal(t, cfg.Providers.Len(), 1)
 359
 360	// Provider ID should be set
 361	pc, _ := cfg.Providers.Get("openai")
 362	assert.Equal(t, "openai", pc.ID)
 363}
 364
 365func TestConfig_EnabledProviders(t *testing.T) {
 366	t.Run("all providers enabled", func(t *testing.T) {
 367		cfg := &Config{
 368			Providers: csync.NewMapFrom(map[string]provider.Config{
 369				"openai": {
 370					ID:      "openai",
 371					APIKey:  "key1",
 372					Disable: false,
 373				},
 374				"anthropic": {
 375					ID:      "anthropic",
 376					APIKey:  "key2",
 377					Disable: false,
 378				},
 379			}),
 380		}
 381
 382		enabled := cfg.EnabledProviders()
 383		assert.Len(t, enabled, 2)
 384	})
 385
 386	t.Run("some providers disabled", func(t *testing.T) {
 387		cfg := &Config{
 388			Providers: csync.NewMapFrom(map[string]provider.Config{
 389				"openai": {
 390					ID:      "openai",
 391					APIKey:  "key1",
 392					Disable: false,
 393				},
 394				"anthropic": {
 395					ID:      "anthropic",
 396					APIKey:  "key2",
 397					Disable: true,
 398				},
 399			}),
 400		}
 401
 402		enabled := cfg.EnabledProviders()
 403		assert.Len(t, enabled, 1)
 404		assert.Equal(t, "openai", enabled[0].ID)
 405	})
 406
 407	t.Run("empty providers map", func(t *testing.T) {
 408		cfg := &Config{
 409			Providers: csync.NewMap[string, provider.Config](),
 410		}
 411
 412		enabled := cfg.EnabledProviders()
 413		assert.Len(t, enabled, 0)
 414	})
 415}
 416
 417func TestConfig_IsConfigured(t *testing.T) {
 418	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
 419		cfg := &Config{
 420			Providers: csync.NewMapFrom(map[string]provider.Config{
 421				"openai": {
 422					ID:      "openai",
 423					APIKey:  "key1",
 424					Disable: false,
 425				},
 426			}),
 427		}
 428
 429		assert.True(t, cfg.IsConfigured())
 430	})
 431
 432	t.Run("returns false when no providers are configured", func(t *testing.T) {
 433		cfg := &Config{
 434			Providers: csync.NewMap[string, provider.Config](),
 435		}
 436
 437		assert.False(t, cfg.IsConfigured())
 438	})
 439
 440	t.Run("returns false when all providers are disabled", func(t *testing.T) {
 441		cfg := &Config{
 442			Providers: csync.NewMapFrom(map[string]provider.Config{
 443				"openai": {
 444					ID:      "openai",
 445					APIKey:  "key1",
 446					Disable: true,
 447				},
 448				"anthropic": {
 449					ID:      "anthropic",
 450					APIKey:  "key2",
 451					Disable: true,
 452				},
 453			}),
 454		}
 455
 456		assert.False(t, cfg.IsConfigured())
 457	})
 458}
 459
 460func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
 461	knownProviders := []catwalk.Provider{
 462		{
 463			ID:          "openai",
 464			APIKey:      "$OPENAI_API_KEY",
 465			APIEndpoint: "https://api.openai.com/v1",
 466			Models: []catwalk.Model{{
 467				ID: "test-model",
 468			}},
 469		},
 470	}
 471
 472	cfg := &Config{
 473		Providers: csync.NewMapFrom(map[string]provider.Config{
 474			"openai": {
 475				Disable: true,
 476			},
 477		}),
 478	}
 479	cfg.setDefaults("/tmp")
 480
 481	env := env.NewFromMap(map[string]string{
 482		"OPENAI_API_KEY": "test-key",
 483	})
 484	resolver := resolver.NewEnvironmentVariableResolver(env)
 485	err := cfg.configureProviders(env, resolver, knownProviders)
 486	assert.NoError(t, err)
 487
 488	// Provider should be removed from config when disabled
 489	assert.Equal(t, cfg.Providers.Len(), 0)
 490	_, exists := cfg.Providers.Get("openai")
 491	assert.False(t, exists)
 492}
 493
 494func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 495	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
 496		cfg := &Config{
 497			Providers: csync.NewMapFrom(map[string]provider.Config{
 498				"custom": {
 499					BaseURL: "https://api.custom.com/v1",
 500					Models: []catwalk.Model{{
 501						ID: "test-model",
 502					}},
 503				},
 504				"openai": {
 505					APIKey: "$MISSING",
 506				},
 507			}),
 508		}
 509		cfg.setDefaults("/tmp")
 510
 511		env := env.NewFromMap(map[string]string{})
 512		resolver := resolver.NewEnvironmentVariableResolver(env)
 513		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 514		assert.NoError(t, err)
 515
 516		assert.Equal(t, cfg.Providers.Len(), 1)
 517		_, exists := cfg.Providers.Get("custom")
 518		assert.True(t, exists)
 519	})
 520
 521	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
 522		cfg := &Config{
 523			Providers: csync.NewMapFrom(map[string]provider.Config{
 524				"custom": {
 525					APIKey: "test-key",
 526					Models: []catwalk.Model{{
 527						ID: "test-model",
 528					}},
 529				},
 530			}),
 531		}
 532		cfg.setDefaults("/tmp")
 533
 534		env := env.NewFromMap(map[string]string{})
 535		resolver := resolver.NewEnvironmentVariableResolver(env)
 536		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 537		assert.NoError(t, err)
 538
 539		assert.Equal(t, cfg.Providers.Len(), 0)
 540		_, exists := cfg.Providers.Get("custom")
 541		assert.False(t, exists)
 542	})
 543
 544	t.Run("custom provider with no models is removed", func(t *testing.T) {
 545		cfg := &Config{
 546			Providers: csync.NewMapFrom(map[string]provider.Config{
 547				"custom": {
 548					APIKey:  "test-key",
 549					BaseURL: "https://api.custom.com/v1",
 550					Models:  []catwalk.Model{},
 551				},
 552			}),
 553		}
 554		cfg.setDefaults("/tmp")
 555
 556		env := env.NewFromMap(map[string]string{})
 557		resolver := resolver.NewEnvironmentVariableResolver(env)
 558		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 559		assert.NoError(t, err)
 560
 561		assert.Equal(t, cfg.Providers.Len(), 0)
 562		_, exists := cfg.Providers.Get("custom")
 563		assert.False(t, exists)
 564	})
 565
 566	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
 567		cfg := &Config{
 568			Providers: csync.NewMapFrom(map[string]provider.Config{
 569				"custom": {
 570					APIKey:  "test-key",
 571					BaseURL: "https://api.custom.com/v1",
 572					Type:    "unsupported",
 573					Models: []catwalk.Model{{
 574						ID: "test-model",
 575					}},
 576				},
 577			}),
 578		}
 579		cfg.setDefaults("/tmp")
 580
 581		env := env.NewFromMap(map[string]string{})
 582		resolver := resolver.NewEnvironmentVariableResolver(env)
 583		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 584		assert.NoError(t, err)
 585
 586		assert.Equal(t, cfg.Providers.Len(), 0)
 587		_, exists := cfg.Providers.Get("custom")
 588		assert.False(t, exists)
 589	})
 590
 591	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
 592		cfg := &Config{
 593			Providers: csync.NewMapFrom(map[string]provider.Config{
 594				"custom": {
 595					APIKey:  "test-key",
 596					BaseURL: "https://api.custom.com/v1",
 597					Type:    catwalk.TypeOpenAI,
 598					Models: []catwalk.Model{{
 599						ID: "test-model",
 600					}},
 601				},
 602			}),
 603		}
 604		cfg.setDefaults("/tmp")
 605
 606		env := env.NewFromMap(map[string]string{})
 607		resolver := resolver.NewEnvironmentVariableResolver(env)
 608		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 609		assert.NoError(t, err)
 610
 611		assert.Equal(t, cfg.Providers.Len(), 1)
 612		customProvider, exists := cfg.Providers.Get("custom")
 613		assert.True(t, exists)
 614		assert.Equal(t, "custom", customProvider.ID)
 615		assert.Equal(t, "test-key", customProvider.APIKey)
 616		assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
 617	})
 618
 619	t.Run("custom anthropic provider is supported", func(t *testing.T) {
 620		cfg := &Config{
 621			Providers: csync.NewMapFrom(map[string]provider.Config{
 622				"custom-anthropic": {
 623					APIKey:  "test-key",
 624					BaseURL: "https://api.anthropic.com/v1",
 625					Type:    catwalk.TypeAnthropic,
 626					Models: []catwalk.Model{{
 627						ID: "claude-3-sonnet",
 628					}},
 629				},
 630			}),
 631		}
 632		cfg.setDefaults("/tmp")
 633
 634		env := env.NewFromMap(map[string]string{})
 635		resolver := resolver.NewEnvironmentVariableResolver(env)
 636		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 637		assert.NoError(t, err)
 638
 639		assert.Equal(t, cfg.Providers.Len(), 1)
 640		customProvider, exists := cfg.Providers.Get("custom-anthropic")
 641		assert.True(t, exists)
 642		assert.Equal(t, "custom-anthropic", customProvider.ID)
 643		assert.Equal(t, "test-key", customProvider.APIKey)
 644		assert.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
 645		assert.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
 646	})
 647
 648	t.Run("disabled custom provider is removed", func(t *testing.T) {
 649		cfg := &Config{
 650			Providers: csync.NewMapFrom(map[string]provider.Config{
 651				"custom": {
 652					APIKey:  "test-key",
 653					BaseURL: "https://api.custom.com/v1",
 654					Type:    catwalk.TypeOpenAI,
 655					Disable: true,
 656					Models: []catwalk.Model{{
 657						ID: "test-model",
 658					}},
 659				},
 660			}),
 661		}
 662		cfg.setDefaults("/tmp")
 663
 664		env := env.NewFromMap(map[string]string{})
 665		resolver := resolver.NewEnvironmentVariableResolver(env)
 666		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 667		assert.NoError(t, err)
 668
 669		assert.Equal(t, cfg.Providers.Len(), 0)
 670		_, exists := cfg.Providers.Get("custom")
 671		assert.False(t, exists)
 672	})
 673}
 674
 675func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 676	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
 677		knownProviders := []catwalk.Provider{
 678			{
 679				ID:          catwalk.InferenceProviderVertexAI,
 680				APIKey:      "",
 681				APIEndpoint: "",
 682				Models: []catwalk.Model{{
 683					ID: "gemini-pro",
 684				}},
 685			},
 686		}
 687
 688		cfg := &Config{
 689			Providers: csync.NewMapFrom(map[string]provider.Config{
 690				"vertexai": {
 691					BaseURL: "custom-url",
 692				},
 693			}),
 694		}
 695		cfg.setDefaults("/tmp")
 696
 697		env := env.NewFromMap(map[string]string{
 698			"GOOGLE_GENAI_USE_VERTEXAI": "false",
 699		})
 700		resolver := resolver.NewEnvironmentVariableResolver(env)
 701		err := cfg.configureProviders(env, resolver, knownProviders)
 702		assert.NoError(t, err)
 703
 704		assert.Equal(t, cfg.Providers.Len(), 0)
 705		_, exists := cfg.Providers.Get("vertexai")
 706		assert.False(t, exists)
 707	})
 708
 709	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
 710		knownProviders := []catwalk.Provider{
 711			{
 712				ID:          catwalk.InferenceProviderBedrock,
 713				APIKey:      "",
 714				APIEndpoint: "",
 715				Models: []catwalk.Model{{
 716					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 717				}},
 718			},
 719		}
 720
 721		cfg := &Config{
 722			Providers: csync.NewMapFrom(map[string]provider.Config{
 723				"bedrock": {
 724					BaseURL: "custom-url",
 725				},
 726			}),
 727		}
 728		cfg.setDefaults("/tmp")
 729
 730		env := env.NewFromMap(map[string]string{})
 731		resolver := resolver.NewEnvironmentVariableResolver(env)
 732		err := cfg.configureProviders(env, resolver, knownProviders)
 733		assert.NoError(t, err)
 734
 735		assert.Equal(t, cfg.Providers.Len(), 0)
 736		_, exists := cfg.Providers.Get("bedrock")
 737		assert.False(t, exists)
 738	})
 739
 740	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
 741		knownProviders := []catwalk.Provider{
 742			{
 743				ID:          "openai",
 744				APIKey:      "$MISSING_API_KEY",
 745				APIEndpoint: "https://api.openai.com/v1",
 746				Models: []catwalk.Model{{
 747					ID: "test-model",
 748				}},
 749			},
 750		}
 751
 752		cfg := &Config{
 753			Providers: csync.NewMapFrom(map[string]provider.Config{
 754				"openai": {
 755					BaseURL: "custom-url",
 756				},
 757			}),
 758		}
 759		cfg.setDefaults("/tmp")
 760
 761		env := env.NewFromMap(map[string]string{})
 762		resolver := resolver.NewEnvironmentVariableResolver(env)
 763		err := cfg.configureProviders(env, resolver, knownProviders)
 764		assert.NoError(t, err)
 765
 766		assert.Equal(t, cfg.Providers.Len(), 0)
 767		_, exists := cfg.Providers.Get("openai")
 768		assert.False(t, exists)
 769	})
 770
 771	t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
 772		knownProviders := []catwalk.Provider{
 773			{
 774				ID:          "openai",
 775				APIKey:      "$OPENAI_API_KEY",
 776				APIEndpoint: "$MISSING_ENDPOINT",
 777				Models: []catwalk.Model{{
 778					ID: "test-model",
 779				}},
 780			},
 781		}
 782
 783		cfg := &Config{
 784			Providers: csync.NewMapFrom(map[string]provider.Config{
 785				"openai": {
 786					APIKey: "test-key",
 787				},
 788			}),
 789		}
 790		cfg.setDefaults("/tmp")
 791
 792		env := env.NewFromMap(map[string]string{
 793			"OPENAI_API_KEY": "test-key",
 794		})
 795		resolver := resolver.NewEnvironmentVariableResolver(env)
 796		err := cfg.configureProviders(env, resolver, knownProviders)
 797		assert.NoError(t, err)
 798
 799		assert.Equal(t, cfg.Providers.Len(), 1)
 800		_, exists := cfg.Providers.Get("openai")
 801		assert.True(t, exists)
 802	})
 803}
 804
 805func TestConfig_defaultModelSelection(t *testing.T) {
 806	t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
 807		knownProviders := []catwalk.Provider{
 808			{
 809				ID:                  "openai",
 810				APIKey:              "abc",
 811				DefaultLargeModelID: "large-model",
 812				DefaultSmallModelID: "small-model",
 813				Models: []catwalk.Model{
 814					{
 815						ID:               "large-model",
 816						DefaultMaxTokens: 1000,
 817					},
 818					{
 819						ID:               "small-model",
 820						DefaultMaxTokens: 500,
 821					},
 822				},
 823			},
 824		}
 825
 826		cfg := &Config{}
 827		cfg.setDefaults("/tmp")
 828		env := env.NewFromMap(map[string]string{})
 829		resolver := resolver.NewEnvironmentVariableResolver(env)
 830		err := cfg.configureProviders(env, resolver, knownProviders)
 831		assert.NoError(t, err)
 832
 833		large, small, err := cfg.defaultModelSelection(knownProviders)
 834		assert.NoError(t, err)
 835		assert.Equal(t, "large-model", large.Model)
 836		assert.Equal(t, "openai", large.Provider)
 837		assert.Equal(t, int64(1000), large.MaxTokens)
 838		assert.Equal(t, "small-model", small.Model)
 839		assert.Equal(t, "openai", small.Provider)
 840		assert.Equal(t, int64(500), small.MaxTokens)
 841	})
 842	t.Run("should error if no providers configured", func(t *testing.T) {
 843		knownProviders := []catwalk.Provider{
 844			{
 845				ID:                  "openai",
 846				APIKey:              "$MISSING_KEY",
 847				DefaultLargeModelID: "large-model",
 848				DefaultSmallModelID: "small-model",
 849				Models: []catwalk.Model{
 850					{
 851						ID:               "large-model",
 852						DefaultMaxTokens: 1000,
 853					},
 854					{
 855						ID:               "small-model",
 856						DefaultMaxTokens: 500,
 857					},
 858				},
 859			},
 860		}
 861
 862		cfg := &Config{}
 863		cfg.setDefaults("/tmp")
 864		env := env.NewFromMap(map[string]string{})
 865		resolver := resolver.NewEnvironmentVariableResolver(env)
 866		err := cfg.configureProviders(env, resolver, knownProviders)
 867		assert.NoError(t, err)
 868
 869		_, _, err = cfg.defaultModelSelection(knownProviders)
 870		assert.Error(t, err)
 871	})
 872	t.Run("should error if model is missing", func(t *testing.T) {
 873		knownProviders := []catwalk.Provider{
 874			{
 875				ID:                  "openai",
 876				APIKey:              "abc",
 877				DefaultLargeModelID: "large-model",
 878				DefaultSmallModelID: "small-model",
 879				Models: []catwalk.Model{
 880					{
 881						ID:               "not-large-model",
 882						DefaultMaxTokens: 1000,
 883					},
 884					{
 885						ID:               "small-model",
 886						DefaultMaxTokens: 500,
 887					},
 888				},
 889			},
 890		}
 891
 892		cfg := &Config{}
 893		cfg.setDefaults("/tmp")
 894		env := env.NewFromMap(map[string]string{})
 895		resolver := resolver.NewEnvironmentVariableResolver(env)
 896		err := cfg.configureProviders(env, resolver, knownProviders)
 897		assert.NoError(t, err)
 898		_, _, err = cfg.defaultModelSelection(knownProviders)
 899		assert.Error(t, err)
 900	})
 901
 902	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
 903		knownProviders := []catwalk.Provider{
 904			{
 905				ID:                  "openai",
 906				APIKey:              "$MISSING", // will not be included in the config
 907				DefaultLargeModelID: "large-model",
 908				DefaultSmallModelID: "small-model",
 909				Models: []catwalk.Model{
 910					{
 911						ID:               "not-large-model",
 912						DefaultMaxTokens: 1000,
 913					},
 914					{
 915						ID:               "small-model",
 916						DefaultMaxTokens: 500,
 917					},
 918				},
 919			},
 920		}
 921
 922		cfg := &Config{
 923			Providers: csync.NewMapFrom(map[string]provider.Config{
 924				"custom": {
 925					APIKey:  "test-key",
 926					BaseURL: "https://api.custom.com/v1",
 927					Models: []catwalk.Model{
 928						{
 929							ID:               "model",
 930							DefaultMaxTokens: 600,
 931						},
 932					},
 933				},
 934			}),
 935		}
 936		cfg.setDefaults("/tmp")
 937		env := env.NewFromMap(map[string]string{})
 938		resolver := resolver.NewEnvironmentVariableResolver(env)
 939		err := cfg.configureProviders(env, resolver, knownProviders)
 940		assert.NoError(t, err)
 941		large, small, err := cfg.defaultModelSelection(knownProviders)
 942		assert.NoError(t, err)
 943		assert.Equal(t, "model", large.Model)
 944		assert.Equal(t, "custom", large.Provider)
 945		assert.Equal(t, int64(600), large.MaxTokens)
 946		assert.Equal(t, "model", small.Model)
 947		assert.Equal(t, "custom", small.Provider)
 948		assert.Equal(t, int64(600), small.MaxTokens)
 949	})
 950
 951	t.Run("should fail if no model configured", func(t *testing.T) {
 952		knownProviders := []catwalk.Provider{
 953			{
 954				ID:                  "openai",
 955				APIKey:              "$MISSING", // will not be included in the config
 956				DefaultLargeModelID: "large-model",
 957				DefaultSmallModelID: "small-model",
 958				Models: []catwalk.Model{
 959					{
 960						ID:               "not-large-model",
 961						DefaultMaxTokens: 1000,
 962					},
 963					{
 964						ID:               "small-model",
 965						DefaultMaxTokens: 500,
 966					},
 967				},
 968			},
 969		}
 970
 971		cfg := &Config{
 972			Providers: csync.NewMapFrom(map[string]provider.Config{
 973				"custom": {
 974					APIKey:  "test-key",
 975					BaseURL: "https://api.custom.com/v1",
 976					Models:  []catwalk.Model{},
 977				},
 978			}),
 979		}
 980		cfg.setDefaults("/tmp")
 981		env := env.NewFromMap(map[string]string{})
 982		resolver := resolver.NewEnvironmentVariableResolver(env)
 983		err := cfg.configureProviders(env, resolver, knownProviders)
 984		assert.NoError(t, err)
 985		_, _, err = cfg.defaultModelSelection(knownProviders)
 986		assert.Error(t, err)
 987	})
 988	t.Run("should use the default provider first", func(t *testing.T) {
 989		knownProviders := []catwalk.Provider{
 990			{
 991				ID:                  "openai",
 992				APIKey:              "set",
 993				DefaultLargeModelID: "large-model",
 994				DefaultSmallModelID: "small-model",
 995				Models: []catwalk.Model{
 996					{
 997						ID:               "large-model",
 998						DefaultMaxTokens: 1000,
 999					},
1000					{
1001						ID:               "small-model",
1002						DefaultMaxTokens: 500,
1003					},
1004				},
1005			},
1006		}
1007
1008		cfg := &Config{
1009			Providers: csync.NewMapFrom(map[string]provider.Config{
1010				"custom": {
1011					APIKey:  "test-key",
1012					BaseURL: "https://api.custom.com/v1",
1013					Models: []catwalk.Model{
1014						{
1015							ID:               "large-model",
1016							DefaultMaxTokens: 1000,
1017						},
1018					},
1019				},
1020			}),
1021		}
1022		cfg.setDefaults("/tmp")
1023		env := env.NewFromMap(map[string]string{})
1024		resolver := resolver.NewEnvironmentVariableResolver(env)
1025		err := cfg.configureProviders(env, resolver, knownProviders)
1026		assert.NoError(t, err)
1027		large, small, err := cfg.defaultModelSelection(knownProviders)
1028		assert.NoError(t, err)
1029		assert.Equal(t, "large-model", large.Model)
1030		assert.Equal(t, "openai", large.Provider)
1031		assert.Equal(t, int64(1000), large.MaxTokens)
1032		assert.Equal(t, "small-model", small.Model)
1033		assert.Equal(t, "openai", small.Provider)
1034		assert.Equal(t, int64(500), small.MaxTokens)
1035	})
1036}
1037
1038func TestConfig_configureSelectedModels(t *testing.T) {
1039	t.Run("should override defaults", func(t *testing.T) {
1040		knownProviders := []catwalk.Provider{
1041			{
1042				ID:                  "openai",
1043				APIKey:              "abc",
1044				DefaultLargeModelID: "large-model",
1045				DefaultSmallModelID: "small-model",
1046				Models: []catwalk.Model{
1047					{
1048						ID:               "larger-model",
1049						DefaultMaxTokens: 2000,
1050					},
1051					{
1052						ID:               "large-model",
1053						DefaultMaxTokens: 1000,
1054					},
1055					{
1056						ID:               "small-model",
1057						DefaultMaxTokens: 500,
1058					},
1059				},
1060			},
1061		}
1062
1063		cfg := &Config{
1064			Models: map[SelectedModelType]agent.Model{
1065				"large": {
1066					Model: "larger-model",
1067				},
1068			},
1069		}
1070		cfg.setDefaults("/tmp")
1071		env := env.NewFromMap(map[string]string{})
1072		resolver := resolver.NewEnvironmentVariableResolver(env)
1073		err := cfg.configureProviders(env, resolver, knownProviders)
1074		assert.NoError(t, err)
1075
1076		err = cfg.configureSelectedModels(knownProviders)
1077		assert.NoError(t, err)
1078		large := cfg.Models[SelectedModelTypeLarge]
1079		small := cfg.Models[SelectedModelTypeSmall]
1080		assert.Equal(t, "larger-model", large.Model)
1081		assert.Equal(t, "openai", large.Provider)
1082		assert.Equal(t, int64(2000), large.MaxTokens)
1083		assert.Equal(t, "small-model", small.Model)
1084		assert.Equal(t, "openai", small.Provider)
1085		assert.Equal(t, int64(500), small.MaxTokens)
1086	})
1087	t.Run("should be possible to use multiple providers", func(t *testing.T) {
1088		knownProviders := []catwalk.Provider{
1089			{
1090				ID:                  "openai",
1091				APIKey:              "abc",
1092				DefaultLargeModelID: "large-model",
1093				DefaultSmallModelID: "small-model",
1094				Models: []catwalk.Model{
1095					{
1096						ID:               "large-model",
1097						DefaultMaxTokens: 1000,
1098					},
1099					{
1100						ID:               "small-model",
1101						DefaultMaxTokens: 500,
1102					},
1103				},
1104			},
1105			{
1106				ID:                  "anthropic",
1107				APIKey:              "abc",
1108				DefaultLargeModelID: "a-large-model",
1109				DefaultSmallModelID: "a-small-model",
1110				Models: []catwalk.Model{
1111					{
1112						ID:               "a-large-model",
1113						DefaultMaxTokens: 1000,
1114					},
1115					{
1116						ID:               "a-small-model",
1117						DefaultMaxTokens: 200,
1118					},
1119				},
1120			},
1121		}
1122
1123		cfg := &Config{
1124			Models: map[SelectedModelType]agent.Model{
1125				"small": {
1126					Model:     "a-small-model",
1127					Provider:  "anthropic",
1128					MaxTokens: 300,
1129				},
1130			},
1131		}
1132		cfg.setDefaults("/tmp")
1133		env := env.NewFromMap(map[string]string{})
1134		resolver := resolver.NewEnvironmentVariableResolver(env)
1135		err := cfg.configureProviders(env, resolver, knownProviders)
1136		assert.NoError(t, err)
1137
1138		err = cfg.configureSelectedModels(knownProviders)
1139		assert.NoError(t, err)
1140		large := cfg.Models[SelectedModelTypeLarge]
1141		small := cfg.Models[SelectedModelTypeSmall]
1142		assert.Equal(t, "large-model", large.Model)
1143		assert.Equal(t, "openai", large.Provider)
1144		assert.Equal(t, int64(1000), large.MaxTokens)
1145		assert.Equal(t, "a-small-model", small.Model)
1146		assert.Equal(t, "anthropic", small.Provider)
1147		assert.Equal(t, int64(300), small.MaxTokens)
1148	})
1149
1150	t.Run("should override the max tokens only", func(t *testing.T) {
1151		knownProviders := []catwalk.Provider{
1152			{
1153				ID:                  "openai",
1154				APIKey:              "abc",
1155				DefaultLargeModelID: "large-model",
1156				DefaultSmallModelID: "small-model",
1157				Models: []catwalk.Model{
1158					{
1159						ID:               "large-model",
1160						DefaultMaxTokens: 1000,
1161					},
1162					{
1163						ID:               "small-model",
1164						DefaultMaxTokens: 500,
1165					},
1166				},
1167			},
1168		}
1169
1170		cfg := &Config{
1171			Models: map[SelectedModelType]agent.Model{
1172				"large": {
1173					MaxTokens: 100,
1174				},
1175			},
1176		}
1177		cfg.setDefaults("/tmp")
1178		env := env.NewFromMap(map[string]string{})
1179		resolver := resolver.NewEnvironmentVariableResolver(env)
1180		err := cfg.configureProviders(env, resolver, knownProviders)
1181		assert.NoError(t, err)
1182
1183		err = cfg.configureSelectedModels(knownProviders)
1184		assert.NoError(t, err)
1185		large := cfg.Models[SelectedModelTypeLarge]
1186		assert.Equal(t, "large-model", large.Model)
1187		assert.Equal(t, "openai", large.Provider)
1188		assert.Equal(t, int64(100), large.MaxTokens)
1189	})
1190}