load_test.go

   1package config
   2
   3import (
   4	"io"
   5	"log/slog"
   6	"os"
   7	"path/filepath"
   8	"strings"
   9	"testing"
  10
  11	"github.com/charmbracelet/catwalk/pkg/catwalk"
  12	"github.com/charmbracelet/crush/internal/csync"
  13	"github.com/charmbracelet/crush/internal/env"
  14	"github.com/stretchr/testify/assert"
  15)
  16
  17func TestMain(m *testing.M) {
  18	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
  19
  20	exitVal := m.Run()
  21	os.Exit(exitVal)
  22}
  23
  24func TestConfig_LoadFromReaders(t *testing.T) {
  25	data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
  26	data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
  27	data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
  28
  29	loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
  30
  31	assert.NoError(t, err)
  32	assert.NotNil(t, loadedConfig)
  33	assert.Equal(t, 1, loadedConfig.Providers.Len())
  34	pc, _ := loadedConfig.Providers.Get("openai")
  35	assert.Equal(t, "key2", pc.APIKey)
  36	assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  37}
  38
  39func TestConfig_setDefaults(t *testing.T) {
  40	cfg := &Config{}
  41
  42	cfg.setDefaults("/tmp")
  43
  44	assert.NotNil(t, cfg.Options)
  45	assert.NotNil(t, cfg.Options.TUI)
  46	assert.NotNil(t, cfg.Options.ContextPaths)
  47	assert.NotNil(t, cfg.Providers)
  48	assert.NotNil(t, cfg.Models)
  49	assert.NotNil(t, cfg.LSP)
  50	assert.NotNil(t, cfg.MCP)
  51	assert.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  52	for _, path := range defaultContextPaths {
  53		assert.Contains(t, cfg.Options.ContextPaths, path)
  54	}
  55	assert.Equal(t, "/tmp", cfg.workingDir)
  56}
  57
  58func TestConfig_configureProviders(t *testing.T) {
  59	knownProviders := []catwalk.Provider{
  60		{
  61			ID:          "openai",
  62			APIKey:      "$OPENAI_API_KEY",
  63			APIEndpoint: "https://api.openai.com/v1",
  64			Models: []catwalk.Model{{
  65				ID: "test-model",
  66			}},
  67		},
  68	}
  69
  70	cfg := &Config{}
  71	cfg.setDefaults("/tmp")
  72	env := env.NewFromMap(map[string]string{
  73		"OPENAI_API_KEY": "test-key",
  74	})
  75	resolver := NewEnvironmentVariableResolver(env)
  76	err := cfg.configureProviders(env, resolver, knownProviders)
  77	assert.NoError(t, err)
  78	assert.Equal(t, 1, cfg.Providers.Len())
  79
  80	// We want to make sure that we keep the configured API key as a placeholder
  81	pc, _ := cfg.Providers.Get("openai")
  82	assert.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  83}
  84
  85func TestConfig_configureProvidersWithOverride(t *testing.T) {
  86	knownProviders := []catwalk.Provider{
  87		{
  88			ID:          "openai",
  89			APIKey:      "$OPENAI_API_KEY",
  90			APIEndpoint: "https://api.openai.com/v1",
  91			Models: []catwalk.Model{{
  92				ID: "test-model",
  93			}},
  94		},
  95	}
  96
  97	cfg := &Config{
  98		Providers: csync.NewMap[string, ProviderConfig](),
  99	}
 100	cfg.Providers.Set("openai", ProviderConfig{
 101		APIKey:  "xyz",
 102		BaseURL: "https://api.openai.com/v2",
 103		Models: []catwalk.Model{
 104			{
 105				ID:   "test-model",
 106				Name: "Updated",
 107			},
 108			{
 109				ID: "another-model",
 110			},
 111		},
 112	})
 113	cfg.setDefaults("/tmp")
 114
 115	env := env.NewFromMap(map[string]string{
 116		"OPENAI_API_KEY": "test-key",
 117	})
 118	resolver := NewEnvironmentVariableResolver(env)
 119	err := cfg.configureProviders(env, resolver, knownProviders)
 120	assert.NoError(t, err)
 121	assert.Equal(t, 1, cfg.Providers.Len())
 122
 123	// We want to make sure that we keep the configured API key as a placeholder
 124	pc, _ := cfg.Providers.Get("openai")
 125	assert.Equal(t, "xyz", pc.APIKey)
 126	assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
 127	assert.Len(t, pc.Models, 2)
 128	assert.Equal(t, "Updated", pc.Models[0].Name)
 129}
 130
 131func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
 132	knownProviders := []catwalk.Provider{
 133		{
 134			ID:          "openai",
 135			APIKey:      "$OPENAI_API_KEY",
 136			APIEndpoint: "https://api.openai.com/v1",
 137			Models: []catwalk.Model{{
 138				ID: "test-model",
 139			}},
 140		},
 141	}
 142
 143	cfg := &Config{
 144		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 145			"custom": {
 146				APIKey:  "xyz",
 147				BaseURL: "https://api.someendpoint.com/v2",
 148				Models: []catwalk.Model{
 149					{
 150						ID: "test-model",
 151					},
 152				},
 153			},
 154		}),
 155	}
 156	cfg.setDefaults("/tmp")
 157	env := env.NewFromMap(map[string]string{
 158		"OPENAI_API_KEY": "test-key",
 159	})
 160	resolver := NewEnvironmentVariableResolver(env)
 161	err := cfg.configureProviders(env, resolver, knownProviders)
 162	assert.NoError(t, err)
 163	// Should be to because of the env variable
 164	assert.Equal(t, cfg.Providers.Len(), 2)
 165
 166	// We want to make sure that we keep the configured API key as a placeholder
 167	pc, _ := cfg.Providers.Get("custom")
 168	assert.Equal(t, "xyz", pc.APIKey)
 169	// Make sure we set the ID correctly
 170	assert.Equal(t, "custom", pc.ID)
 171	assert.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
 172	assert.Len(t, pc.Models, 1)
 173
 174	_, ok := cfg.Providers.Get("openai")
 175	assert.True(t, ok, "OpenAI provider should still be present")
 176}
 177
 178func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
 179	knownProviders := []catwalk.Provider{
 180		{
 181			ID:          catwalk.InferenceProviderBedrock,
 182			APIKey:      "",
 183			APIEndpoint: "",
 184			Models: []catwalk.Model{{
 185				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 186			}},
 187		},
 188	}
 189
 190	cfg := &Config{}
 191	cfg.setDefaults("/tmp")
 192	env := env.NewFromMap(map[string]string{
 193		"AWS_ACCESS_KEY_ID":     "test-key-id",
 194		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 195	})
 196	resolver := NewEnvironmentVariableResolver(env)
 197	err := cfg.configureProviders(env, resolver, knownProviders)
 198	assert.NoError(t, err)
 199	assert.Equal(t, cfg.Providers.Len(), 1)
 200
 201	bedrockProvider, ok := cfg.Providers.Get("bedrock")
 202	assert.True(t, ok, "Bedrock provider should be present")
 203	assert.Len(t, bedrockProvider.Models, 1)
 204	assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
 205}
 206
 207func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
 208	knownProviders := []catwalk.Provider{
 209		{
 210			ID:          catwalk.InferenceProviderBedrock,
 211			APIKey:      "",
 212			APIEndpoint: "",
 213			Models: []catwalk.Model{{
 214				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 215			}},
 216		},
 217	}
 218
 219	cfg := &Config{}
 220	cfg.setDefaults("/tmp")
 221	env := env.NewFromMap(map[string]string{})
 222	resolver := NewEnvironmentVariableResolver(env)
 223	err := cfg.configureProviders(env, resolver, knownProviders)
 224	assert.NoError(t, err)
 225	// Provider should not be configured without credentials
 226	assert.Equal(t, cfg.Providers.Len(), 0)
 227}
 228
 229func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
 230	knownProviders := []catwalk.Provider{
 231		{
 232			ID:          catwalk.InferenceProviderBedrock,
 233			APIKey:      "",
 234			APIEndpoint: "",
 235			Models: []catwalk.Model{{
 236				ID: "some-random-model",
 237			}},
 238		},
 239	}
 240
 241	cfg := &Config{}
 242	cfg.setDefaults("/tmp")
 243	env := env.NewFromMap(map[string]string{
 244		"AWS_ACCESS_KEY_ID":     "test-key-id",
 245		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 246	})
 247	resolver := NewEnvironmentVariableResolver(env)
 248	err := cfg.configureProviders(env, resolver, knownProviders)
 249	assert.Error(t, err)
 250}
 251
 252func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
 253	knownProviders := []catwalk.Provider{
 254		{
 255			ID:          catwalk.InferenceProviderVertexAI,
 256			APIKey:      "",
 257			APIEndpoint: "",
 258			Models: []catwalk.Model{{
 259				ID: "gemini-pro",
 260			}},
 261		},
 262	}
 263
 264	cfg := &Config{}
 265	cfg.setDefaults("/tmp")
 266	env := env.NewFromMap(map[string]string{
 267		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 268		"GOOGLE_CLOUD_PROJECT":      "test-project",
 269		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 270	})
 271	resolver := NewEnvironmentVariableResolver(env)
 272	err := cfg.configureProviders(env, resolver, knownProviders)
 273	assert.NoError(t, err)
 274	assert.Equal(t, cfg.Providers.Len(), 1)
 275
 276	vertexProvider, ok := cfg.Providers.Get("vertexai")
 277	assert.True(t, ok, "VertexAI provider should be present")
 278	assert.Len(t, vertexProvider.Models, 1)
 279	assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
 280	assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
 281	assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
 282}
 283
 284func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
 285	knownProviders := []catwalk.Provider{
 286		{
 287			ID:          catwalk.InferenceProviderVertexAI,
 288			APIKey:      "",
 289			APIEndpoint: "",
 290			Models: []catwalk.Model{{
 291				ID: "gemini-pro",
 292			}},
 293		},
 294	}
 295
 296	cfg := &Config{}
 297	cfg.setDefaults("/tmp")
 298	env := env.NewFromMap(map[string]string{
 299		"GOOGLE_GENAI_USE_VERTEXAI": "false",
 300		"GOOGLE_CLOUD_PROJECT":      "test-project",
 301		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 302	})
 303	resolver := NewEnvironmentVariableResolver(env)
 304	err := cfg.configureProviders(env, resolver, knownProviders)
 305	assert.NoError(t, err)
 306	// Provider should not be configured without proper credentials
 307	assert.Equal(t, cfg.Providers.Len(), 0)
 308}
 309
 310func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
 311	knownProviders := []catwalk.Provider{
 312		{
 313			ID:          catwalk.InferenceProviderVertexAI,
 314			APIKey:      "",
 315			APIEndpoint: "",
 316			Models: []catwalk.Model{{
 317				ID: "gemini-pro",
 318			}},
 319		},
 320	}
 321
 322	cfg := &Config{}
 323	cfg.setDefaults("/tmp")
 324	env := env.NewFromMap(map[string]string{
 325		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 326		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 327	})
 328	resolver := NewEnvironmentVariableResolver(env)
 329	err := cfg.configureProviders(env, resolver, knownProviders)
 330	assert.NoError(t, err)
 331	// Provider should not be configured without project
 332	assert.Equal(t, cfg.Providers.Len(), 0)
 333}
 334
 335func TestConfig_configureProvidersSetProviderID(t *testing.T) {
 336	knownProviders := []catwalk.Provider{
 337		{
 338			ID:          "openai",
 339			APIKey:      "$OPENAI_API_KEY",
 340			APIEndpoint: "https://api.openai.com/v1",
 341			Models: []catwalk.Model{{
 342				ID: "test-model",
 343			}},
 344		},
 345	}
 346
 347	cfg := &Config{}
 348	cfg.setDefaults("/tmp")
 349	env := env.NewFromMap(map[string]string{
 350		"OPENAI_API_KEY": "test-key",
 351	})
 352	resolver := NewEnvironmentVariableResolver(env)
 353	err := cfg.configureProviders(env, resolver, knownProviders)
 354	assert.NoError(t, err)
 355	assert.Equal(t, cfg.Providers.Len(), 1)
 356
 357	// Provider ID should be set
 358	pc, _ := cfg.Providers.Get("openai")
 359	assert.Equal(t, "openai", pc.ID)
 360}
 361
 362func TestConfig_EnabledProviders(t *testing.T) {
 363	t.Run("all providers enabled", func(t *testing.T) {
 364		cfg := &Config{
 365			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 366				"openai": {
 367					ID:      "openai",
 368					APIKey:  "key1",
 369					Disable: false,
 370				},
 371				"anthropic": {
 372					ID:      "anthropic",
 373					APIKey:  "key2",
 374					Disable: false,
 375				},
 376			}),
 377		}
 378
 379		enabled := cfg.EnabledProviders()
 380		assert.Len(t, enabled, 2)
 381	})
 382
 383	t.Run("some providers disabled", func(t *testing.T) {
 384		cfg := &Config{
 385			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 386				"openai": {
 387					ID:      "openai",
 388					APIKey:  "key1",
 389					Disable: false,
 390				},
 391				"anthropic": {
 392					ID:      "anthropic",
 393					APIKey:  "key2",
 394					Disable: true,
 395				},
 396			}),
 397		}
 398
 399		enabled := cfg.EnabledProviders()
 400		assert.Len(t, enabled, 1)
 401		assert.Equal(t, "openai", enabled[0].ID)
 402	})
 403
 404	t.Run("empty providers map", func(t *testing.T) {
 405		cfg := &Config{
 406			Providers: csync.NewMap[string, ProviderConfig](),
 407		}
 408
 409		enabled := cfg.EnabledProviders()
 410		assert.Len(t, enabled, 0)
 411	})
 412}
 413
 414func TestConfig_IsConfigured(t *testing.T) {
 415	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
 416		cfg := &Config{
 417			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 418				"openai": {
 419					ID:      "openai",
 420					APIKey:  "key1",
 421					Disable: false,
 422				},
 423			}),
 424		}
 425
 426		assert.True(t, cfg.IsConfigured())
 427	})
 428
 429	t.Run("returns false when no providers are configured", func(t *testing.T) {
 430		cfg := &Config{
 431			Providers: csync.NewMap[string, ProviderConfig](),
 432		}
 433
 434		assert.False(t, cfg.IsConfigured())
 435	})
 436
 437	t.Run("returns false when all providers are disabled", func(t *testing.T) {
 438		cfg := &Config{
 439			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 440				"openai": {
 441					ID:      "openai",
 442					APIKey:  "key1",
 443					Disable: true,
 444				},
 445				"anthropic": {
 446					ID:      "anthropic",
 447					APIKey:  "key2",
 448					Disable: true,
 449				},
 450			}),
 451		}
 452
 453		assert.False(t, cfg.IsConfigured())
 454	})
 455}
 456
 457func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
 458	knownProviders := []catwalk.Provider{
 459		{
 460			ID:          "openai",
 461			APIKey:      "$OPENAI_API_KEY",
 462			APIEndpoint: "https://api.openai.com/v1",
 463			Models: []catwalk.Model{{
 464				ID: "test-model",
 465			}},
 466		},
 467	}
 468
 469	cfg := &Config{
 470		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 471			"openai": {
 472				Disable: true,
 473			},
 474		}),
 475	}
 476	cfg.setDefaults("/tmp")
 477
 478	env := env.NewFromMap(map[string]string{
 479		"OPENAI_API_KEY": "test-key",
 480	})
 481	resolver := NewEnvironmentVariableResolver(env)
 482	err := cfg.configureProviders(env, resolver, knownProviders)
 483	assert.NoError(t, err)
 484
 485	// Provider should be removed from config when disabled
 486	assert.Equal(t, cfg.Providers.Len(), 0)
 487	_, exists := cfg.Providers.Get("openai")
 488	assert.False(t, exists)
 489}
 490
 491func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 492	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
 493		cfg := &Config{
 494			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 495				"custom": {
 496					BaseURL: "https://api.custom.com/v1",
 497					Models: []catwalk.Model{{
 498						ID: "test-model",
 499					}},
 500				},
 501				"openai": {
 502					APIKey: "$MISSING",
 503				},
 504			}),
 505		}
 506		cfg.setDefaults("/tmp")
 507
 508		env := env.NewFromMap(map[string]string{})
 509		resolver := NewEnvironmentVariableResolver(env)
 510		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 511		assert.NoError(t, err)
 512
 513		assert.Equal(t, cfg.Providers.Len(), 1)
 514		_, exists := cfg.Providers.Get("custom")
 515		assert.True(t, exists)
 516	})
 517
 518	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
 519		cfg := &Config{
 520			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 521				"custom": {
 522					APIKey: "test-key",
 523					Models: []catwalk.Model{{
 524						ID: "test-model",
 525					}},
 526				},
 527			}),
 528		}
 529		cfg.setDefaults("/tmp")
 530
 531		env := env.NewFromMap(map[string]string{})
 532		resolver := NewEnvironmentVariableResolver(env)
 533		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 534		assert.NoError(t, err)
 535
 536		assert.Equal(t, cfg.Providers.Len(), 0)
 537		_, exists := cfg.Providers.Get("custom")
 538		assert.False(t, exists)
 539	})
 540
 541	t.Run("custom provider with no models is removed", func(t *testing.T) {
 542		cfg := &Config{
 543			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 544				"custom": {
 545					APIKey:  "test-key",
 546					BaseURL: "https://api.custom.com/v1",
 547					Models:  []catwalk.Model{},
 548				},
 549			}),
 550		}
 551		cfg.setDefaults("/tmp")
 552
 553		env := env.NewFromMap(map[string]string{})
 554		resolver := NewEnvironmentVariableResolver(env)
 555		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 556		assert.NoError(t, err)
 557
 558		assert.Equal(t, cfg.Providers.Len(), 0)
 559		_, exists := cfg.Providers.Get("custom")
 560		assert.False(t, exists)
 561	})
 562
 563	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
 564		cfg := &Config{
 565			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 566				"custom": {
 567					APIKey:  "test-key",
 568					BaseURL: "https://api.custom.com/v1",
 569					Type:    "unsupported",
 570					Models: []catwalk.Model{{
 571						ID: "test-model",
 572					}},
 573				},
 574			}),
 575		}
 576		cfg.setDefaults("/tmp")
 577
 578		env := env.NewFromMap(map[string]string{})
 579		resolver := NewEnvironmentVariableResolver(env)
 580		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 581		assert.NoError(t, err)
 582
 583		assert.Equal(t, cfg.Providers.Len(), 0)
 584		_, exists := cfg.Providers.Get("custom")
 585		assert.False(t, exists)
 586	})
 587
 588	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
 589		cfg := &Config{
 590			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 591				"custom": {
 592					APIKey:  "test-key",
 593					BaseURL: "https://api.custom.com/v1",
 594					Type:    catwalk.TypeOpenAI,
 595					Models: []catwalk.Model{{
 596						ID: "test-model",
 597					}},
 598				},
 599			}),
 600		}
 601		cfg.setDefaults("/tmp")
 602
 603		env := env.NewFromMap(map[string]string{})
 604		resolver := NewEnvironmentVariableResolver(env)
 605		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 606		assert.NoError(t, err)
 607
 608		assert.Equal(t, cfg.Providers.Len(), 1)
 609		customProvider, exists := cfg.Providers.Get("custom")
 610		assert.True(t, exists)
 611		assert.Equal(t, "custom", customProvider.ID)
 612		assert.Equal(t, "test-key", customProvider.APIKey)
 613		assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
 614	})
 615
 616	t.Run("custom anthropic provider is supported", func(t *testing.T) {
 617		cfg := &Config{
 618			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 619				"custom-anthropic": {
 620					APIKey:  "test-key",
 621					BaseURL: "https://api.anthropic.com/v1",
 622					Type:    catwalk.TypeAnthropic,
 623					Models: []catwalk.Model{{
 624						ID: "claude-3-sonnet",
 625					}},
 626				},
 627			}),
 628		}
 629		cfg.setDefaults("/tmp")
 630
 631		env := env.NewFromMap(map[string]string{})
 632		resolver := NewEnvironmentVariableResolver(env)
 633		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 634		assert.NoError(t, err)
 635
 636		assert.Equal(t, cfg.Providers.Len(), 1)
 637		customProvider, exists := cfg.Providers.Get("custom-anthropic")
 638		assert.True(t, exists)
 639		assert.Equal(t, "custom-anthropic", customProvider.ID)
 640		assert.Equal(t, "test-key", customProvider.APIKey)
 641		assert.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
 642		assert.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
 643	})
 644
 645	t.Run("disabled custom provider is removed", func(t *testing.T) {
 646		cfg := &Config{
 647			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 648				"custom": {
 649					APIKey:  "test-key",
 650					BaseURL: "https://api.custom.com/v1",
 651					Type:    catwalk.TypeOpenAI,
 652					Disable: true,
 653					Models: []catwalk.Model{{
 654						ID: "test-model",
 655					}},
 656				},
 657			}),
 658		}
 659		cfg.setDefaults("/tmp")
 660
 661		env := env.NewFromMap(map[string]string{})
 662		resolver := NewEnvironmentVariableResolver(env)
 663		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 664		assert.NoError(t, err)
 665
 666		assert.Equal(t, cfg.Providers.Len(), 0)
 667		_, exists := cfg.Providers.Get("custom")
 668		assert.False(t, exists)
 669	})
 670}
 671
 672func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 673	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
 674		knownProviders := []catwalk.Provider{
 675			{
 676				ID:          catwalk.InferenceProviderVertexAI,
 677				APIKey:      "",
 678				APIEndpoint: "",
 679				Models: []catwalk.Model{{
 680					ID: "gemini-pro",
 681				}},
 682			},
 683		}
 684
 685		cfg := &Config{
 686			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 687				"vertexai": {
 688					BaseURL: "custom-url",
 689				},
 690			}),
 691		}
 692		cfg.setDefaults("/tmp")
 693
 694		env := env.NewFromMap(map[string]string{
 695			"GOOGLE_GENAI_USE_VERTEXAI": "false",
 696		})
 697		resolver := NewEnvironmentVariableResolver(env)
 698		err := cfg.configureProviders(env, resolver, knownProviders)
 699		assert.NoError(t, err)
 700
 701		assert.Equal(t, cfg.Providers.Len(), 0)
 702		_, exists := cfg.Providers.Get("vertexai")
 703		assert.False(t, exists)
 704	})
 705
 706	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
 707		knownProviders := []catwalk.Provider{
 708			{
 709				ID:          catwalk.InferenceProviderBedrock,
 710				APIKey:      "",
 711				APIEndpoint: "",
 712				Models: []catwalk.Model{{
 713					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 714				}},
 715			},
 716		}
 717
 718		cfg := &Config{
 719			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 720				"bedrock": {
 721					BaseURL: "custom-url",
 722				},
 723			}),
 724		}
 725		cfg.setDefaults("/tmp")
 726
 727		env := env.NewFromMap(map[string]string{})
 728		resolver := NewEnvironmentVariableResolver(env)
 729		err := cfg.configureProviders(env, resolver, knownProviders)
 730		assert.NoError(t, err)
 731
 732		assert.Equal(t, cfg.Providers.Len(), 0)
 733		_, exists := cfg.Providers.Get("bedrock")
 734		assert.False(t, exists)
 735	})
 736
 737	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
 738		knownProviders := []catwalk.Provider{
 739			{
 740				ID:          "openai",
 741				APIKey:      "$MISSING_API_KEY",
 742				APIEndpoint: "https://api.openai.com/v1",
 743				Models: []catwalk.Model{{
 744					ID: "test-model",
 745				}},
 746			},
 747		}
 748
 749		cfg := &Config{
 750			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 751				"openai": {
 752					BaseURL: "custom-url",
 753				},
 754			}),
 755		}
 756		cfg.setDefaults("/tmp")
 757
 758		env := env.NewFromMap(map[string]string{})
 759		resolver := NewEnvironmentVariableResolver(env)
 760		err := cfg.configureProviders(env, resolver, knownProviders)
 761		assert.NoError(t, err)
 762
 763		assert.Equal(t, cfg.Providers.Len(), 0)
 764		_, exists := cfg.Providers.Get("openai")
 765		assert.False(t, exists)
 766	})
 767
 768	t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
 769		knownProviders := []catwalk.Provider{
 770			{
 771				ID:          "openai",
 772				APIKey:      "$OPENAI_API_KEY",
 773				APIEndpoint: "$MISSING_ENDPOINT",
 774				Models: []catwalk.Model{{
 775					ID: "test-model",
 776				}},
 777			},
 778		}
 779
 780		cfg := &Config{
 781			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 782				"openai": {
 783					APIKey: "test-key",
 784				},
 785			}),
 786		}
 787		cfg.setDefaults("/tmp")
 788
 789		env := env.NewFromMap(map[string]string{
 790			"OPENAI_API_KEY": "test-key",
 791		})
 792		resolver := NewEnvironmentVariableResolver(env)
 793		err := cfg.configureProviders(env, resolver, knownProviders)
 794		assert.NoError(t, err)
 795
 796		assert.Equal(t, cfg.Providers.Len(), 1)
 797		_, exists := cfg.Providers.Get("openai")
 798		assert.True(t, exists)
 799	})
 800}
 801
 802func TestConfig_defaultModelSelection(t *testing.T) {
 803	t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
 804		knownProviders := []catwalk.Provider{
 805			{
 806				ID:                  "openai",
 807				APIKey:              "abc",
 808				DefaultLargeModelID: "large-model",
 809				DefaultSmallModelID: "small-model",
 810				Models: []catwalk.Model{
 811					{
 812						ID:               "large-model",
 813						DefaultMaxTokens: 1000,
 814					},
 815					{
 816						ID:               "small-model",
 817						DefaultMaxTokens: 500,
 818					},
 819				},
 820			},
 821		}
 822
 823		cfg := &Config{}
 824		cfg.setDefaults("/tmp")
 825		env := env.NewFromMap(map[string]string{})
 826		resolver := NewEnvironmentVariableResolver(env)
 827		err := cfg.configureProviders(env, resolver, knownProviders)
 828		assert.NoError(t, err)
 829
 830		large, small, err := cfg.defaultModelSelection(knownProviders)
 831		assert.NoError(t, err)
 832		assert.Equal(t, "large-model", large.Model)
 833		assert.Equal(t, "openai", large.Provider)
 834		assert.Equal(t, int64(1000), large.MaxTokens)
 835		assert.Equal(t, "small-model", small.Model)
 836		assert.Equal(t, "openai", small.Provider)
 837		assert.Equal(t, int64(500), small.MaxTokens)
 838	})
 839	t.Run("should error if no providers configured", func(t *testing.T) {
 840		knownProviders := []catwalk.Provider{
 841			{
 842				ID:                  "openai",
 843				APIKey:              "$MISSING_KEY",
 844				DefaultLargeModelID: "large-model",
 845				DefaultSmallModelID: "small-model",
 846				Models: []catwalk.Model{
 847					{
 848						ID:               "large-model",
 849						DefaultMaxTokens: 1000,
 850					},
 851					{
 852						ID:               "small-model",
 853						DefaultMaxTokens: 500,
 854					},
 855				},
 856			},
 857		}
 858
 859		cfg := &Config{}
 860		cfg.setDefaults("/tmp")
 861		env := env.NewFromMap(map[string]string{})
 862		resolver := NewEnvironmentVariableResolver(env)
 863		err := cfg.configureProviders(env, resolver, knownProviders)
 864		assert.NoError(t, err)
 865
 866		_, _, err = cfg.defaultModelSelection(knownProviders)
 867		assert.Error(t, err)
 868	})
 869	t.Run("should error if model is missing", func(t *testing.T) {
 870		knownProviders := []catwalk.Provider{
 871			{
 872				ID:                  "openai",
 873				APIKey:              "abc",
 874				DefaultLargeModelID: "large-model",
 875				DefaultSmallModelID: "small-model",
 876				Models: []catwalk.Model{
 877					{
 878						ID:               "not-large-model",
 879						DefaultMaxTokens: 1000,
 880					},
 881					{
 882						ID:               "small-model",
 883						DefaultMaxTokens: 500,
 884					},
 885				},
 886			},
 887		}
 888
 889		cfg := &Config{}
 890		cfg.setDefaults("/tmp")
 891		env := env.NewFromMap(map[string]string{})
 892		resolver := NewEnvironmentVariableResolver(env)
 893		err := cfg.configureProviders(env, resolver, knownProviders)
 894		assert.NoError(t, err)
 895		_, _, err = cfg.defaultModelSelection(knownProviders)
 896		assert.Error(t, err)
 897	})
 898
 899	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
 900		knownProviders := []catwalk.Provider{
 901			{
 902				ID:                  "openai",
 903				APIKey:              "$MISSING", // will not be included in the config
 904				DefaultLargeModelID: "large-model",
 905				DefaultSmallModelID: "small-model",
 906				Models: []catwalk.Model{
 907					{
 908						ID:               "not-large-model",
 909						DefaultMaxTokens: 1000,
 910					},
 911					{
 912						ID:               "small-model",
 913						DefaultMaxTokens: 500,
 914					},
 915				},
 916			},
 917		}
 918
 919		cfg := &Config{
 920			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 921				"custom": {
 922					APIKey:  "test-key",
 923					BaseURL: "https://api.custom.com/v1",
 924					Models: []catwalk.Model{
 925						{
 926							ID:               "model",
 927							DefaultMaxTokens: 600,
 928						},
 929					},
 930				},
 931			}),
 932		}
 933		cfg.setDefaults("/tmp")
 934		env := env.NewFromMap(map[string]string{})
 935		resolver := NewEnvironmentVariableResolver(env)
 936		err := cfg.configureProviders(env, resolver, knownProviders)
 937		assert.NoError(t, err)
 938		large, small, err := cfg.defaultModelSelection(knownProviders)
 939		assert.NoError(t, err)
 940		assert.Equal(t, "model", large.Model)
 941		assert.Equal(t, "custom", large.Provider)
 942		assert.Equal(t, int64(600), large.MaxTokens)
 943		assert.Equal(t, "model", small.Model)
 944		assert.Equal(t, "custom", small.Provider)
 945		assert.Equal(t, int64(600), small.MaxTokens)
 946	})
 947
 948	t.Run("should fail if no model configured", func(t *testing.T) {
 949		knownProviders := []catwalk.Provider{
 950			{
 951				ID:                  "openai",
 952				APIKey:              "$MISSING", // will not be included in the config
 953				DefaultLargeModelID: "large-model",
 954				DefaultSmallModelID: "small-model",
 955				Models: []catwalk.Model{
 956					{
 957						ID:               "not-large-model",
 958						DefaultMaxTokens: 1000,
 959					},
 960					{
 961						ID:               "small-model",
 962						DefaultMaxTokens: 500,
 963					},
 964				},
 965			},
 966		}
 967
 968		cfg := &Config{
 969			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 970				"custom": {
 971					APIKey:  "test-key",
 972					BaseURL: "https://api.custom.com/v1",
 973					Models:  []catwalk.Model{},
 974				},
 975			}),
 976		}
 977		cfg.setDefaults("/tmp")
 978		env := env.NewFromMap(map[string]string{})
 979		resolver := NewEnvironmentVariableResolver(env)
 980		err := cfg.configureProviders(env, resolver, knownProviders)
 981		assert.NoError(t, err)
 982		_, _, err = cfg.defaultModelSelection(knownProviders)
 983		assert.Error(t, err)
 984	})
	t.Run("should use the default provider first", func(t *testing.T) {
		// Both the known provider (openai) and a user-defined provider
		// (custom) expose a model with ID "large-model". The selection is
		// expected to come from the known provider's declared defaults, with
		// the known provider's max-token values.
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "set",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models: []catwalk.Model{
						{
							ID:               "large-model",
							DefaultMaxTokens: 1000,
						},
					},
				},
			}),
		}
		cfg.setDefaults("/tmp")

		// Named testEnv (not env) so the imported env package is not shadowed.
		testEnv := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(testEnv)
		err := cfg.configureProviders(testEnv, resolver, knownProviders)
		if !assert.NoError(t, err) {
			return
		}

		large, small, err := cfg.defaultModelSelection(knownProviders)
		// On error the selections are zero values; stop here rather than
		// report a cascade of misleading mismatches below.
		if !assert.NoError(t, err) {
			return
		}
		assert.Equal(t, "large-model", large.Model)
		assert.Equal(t, "openai", large.Provider)
		assert.Equal(t, int64(1000), large.MaxTokens)
		assert.Equal(t, "small-model", small.Model)
		assert.Equal(t, "openai", small.Provider)
		assert.Equal(t, int64(500), small.MaxTokens)
	})
1033}
1034
// TestConfig_configureSelectedModels checks how entries the user puts in
// Config.Models are combined with the model defaults declared by the known
// providers after configureSelectedModels runs.
func TestConfig_configureSelectedModels(t *testing.T) {
	// setup applies defaults to cfg, configures providers from knownProviders
	// against an empty environment, then runs configureSelectedModels. It
	// reports whether every step succeeded so callers can stop early instead
	// of asserting on a half-configured Config.
	setup := func(t *testing.T, cfg *Config, knownProviders []catwalk.Provider) bool {
		t.Helper()
		cfg.setDefaults("/tmp")
		// Named testEnv (not env) so the imported env package is not shadowed.
		testEnv := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(testEnv)
		if !assert.NoError(t, cfg.configureProviders(testEnv, resolver, knownProviders)) {
			return false
		}
		return assert.NoError(t, cfg.configureSelectedModels(knownProviders))
	}

	t.Run("should override defaults", func(t *testing.T) {
		// The user picks "larger-model" for the large slot; the small slot is
		// left unset and should fall back to the provider's default.
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "larger-model",
						DefaultMaxTokens: 2000,
					},
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				SelectedModelTypeLarge: {
					Model: "larger-model",
				},
			},
		}
		if !setup(t, cfg, knownProviders) {
			return
		}

		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		assert.Equal(t, "larger-model", large.Model)
		assert.Equal(t, "openai", large.Provider)
		assert.Equal(t, int64(2000), large.MaxTokens)
		assert.Equal(t, "small-model", small.Model)
		assert.Equal(t, "openai", small.Provider)
		assert.Equal(t, int64(500), small.MaxTokens)
	})

	t.Run("should be possible to use multiple providers", func(t *testing.T) {
		// The small slot is pinned to an anthropic model with an explicit
		// MaxTokens; the large slot is unset and is expected to resolve to
		// openai's default large model.
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
			{
				ID:                  "anthropic",
				APIKey:              "abc",
				DefaultLargeModelID: "a-large-model",
				DefaultSmallModelID: "a-small-model",
				Models: []catwalk.Model{
					{
						ID:               "a-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "a-small-model",
						DefaultMaxTokens: 200,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				SelectedModelTypeSmall: {
					Model:     "a-small-model",
					Provider:  "anthropic",
					MaxTokens: 300,
				},
			},
		}
		if !setup(t, cfg, knownProviders) {
			return
		}

		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		assert.Equal(t, "large-model", large.Model)
		assert.Equal(t, "openai", large.Provider)
		assert.Equal(t, int64(1000), large.MaxTokens)
		assert.Equal(t, "a-small-model", small.Model)
		assert.Equal(t, "anthropic", small.Provider)
		// The user-supplied MaxTokens (300) must win over the model's
		// DefaultMaxTokens (200).
		assert.Equal(t, int64(300), small.MaxTokens)
	})

	t.Run("should override the max tokens only", func(t *testing.T) {
		// Only MaxTokens is set for the large slot; model and provider are
		// expected to be filled in from openai's defaults while the custom
		// MaxTokens is kept.
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				SelectedModelTypeLarge: {
					MaxTokens: 100,
				},
			},
		}
		if !setup(t, cfg, knownProviders) {
			return
		}

		large := cfg.Models[SelectedModelTypeLarge]
		assert.Equal(t, "large-model", large.Model)
		assert.Equal(t, "openai", large.Provider)
		assert.Equal(t, int64(100), large.MaxTokens)
	})
}