load_test.go

   1package config
   2
   3import (
   4	"io"
   5	"log/slog"
   6	"os"
   7	"path/filepath"
   8	"strings"
   9	"testing"
  10
  11	"github.com/charmbracelet/catwalk/pkg/catwalk"
  12	"github.com/charmbracelet/crush/internal/csync"
  13	"github.com/charmbracelet/crush/internal/env"
  14	"github.com/stretchr/testify/assert"
  15)
  16
  17func TestMain(m *testing.M) {
  18	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
  19
  20	exitVal := m.Run()
  21	os.Exit(exitVal)
  22}
  23
  24func TestConfig_LoadFromReaders(t *testing.T) {
  25	data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
  26	data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
  27	data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
  28
  29	loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
  30
  31	assert.NoError(t, err)
  32	assert.NotNil(t, loadedConfig)
  33	assert.Equal(t, 1, loadedConfig.Providers.Len())
  34	pc, _ := loadedConfig.Providers.Get("openai")
  35	assert.Equal(t, "key2", pc.APIKey)
  36	assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  37}
  38
  39func TestConfig_setDefaults(t *testing.T) {
  40	cfg := &Config{}
  41
  42	cfg.setDefaults("/tmp")
  43
  44	assert.NotNil(t, cfg.Options)
  45	assert.NotNil(t, cfg.Options.TUI)
  46	assert.NotNil(t, cfg.Options.ContextPaths)
  47	assert.NotNil(t, cfg.Providers)
  48	assert.NotNil(t, cfg.Models)
  49	assert.NotNil(t, cfg.LSP)
  50	assert.NotNil(t, cfg.MCP)
  51	assert.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  52	for _, path := range defaultContextPaths {
  53		assert.Contains(t, cfg.Options.ContextPaths, path)
  54	}
  55	assert.Equal(t, "/tmp", cfg.workingDir)
  56}
  57
  58func TestConfig_configureProviders(t *testing.T) {
  59	knownProviders := []catwalk.Provider{
  60		{
  61			ID:          "openai",
  62			APIKey:      "$OPENAI_API_KEY",
  63			APIEndpoint: "https://api.openai.com/v1",
  64			Models: []catwalk.Model{{
  65				ID: "test-model",
  66			}},
  67		},
  68	}
  69
  70	cfg := &Config{}
  71	cfg.setDefaults("/tmp")
  72	env := env.NewFromMap(map[string]string{
  73		"OPENAI_API_KEY": "test-key",
  74	})
  75	resolver := NewEnvironmentVariableResolver(env)
  76	err := cfg.configureProviders(env, resolver, knownProviders)
  77	assert.NoError(t, err)
  78	assert.Equal(t, 1, cfg.Providers.Len())
  79
  80	// We want to make sure that we keep the configured API key as a placeholder
  81	pc, _ := cfg.Providers.Get("openai")
  82	assert.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  83}
  84
  85func TestConfig_configureProvidersWithOverride(t *testing.T) {
  86	knownProviders := []catwalk.Provider{
  87		{
  88			ID:          "openai",
  89			APIKey:      "$OPENAI_API_KEY",
  90			APIEndpoint: "https://api.openai.com/v1",
  91			Models: []catwalk.Model{{
  92				ID: "test-model",
  93			}},
  94		},
  95	}
  96
  97	cfg := &Config{
  98		Providers: csync.NewMap[string, ProviderConfig](),
  99	}
 100	cfg.Providers.Set("openai", ProviderConfig{
 101		APIKey:  "xyz",
 102		BaseURL: "https://api.openai.com/v2",
 103		Models: []catwalk.Model{
 104			{
 105				ID:   "test-model",
 106				Name: "Updated",
 107			},
 108			{
 109				ID: "another-model",
 110			},
 111		},
 112	})
 113	cfg.setDefaults("/tmp")
 114
 115	env := env.NewFromMap(map[string]string{
 116		"OPENAI_API_KEY": "test-key",
 117	})
 118	resolver := NewEnvironmentVariableResolver(env)
 119	err := cfg.configureProviders(env, resolver, knownProviders)
 120	assert.NoError(t, err)
 121	assert.Equal(t, 1, cfg.Providers.Len())
 122
 123	// We want to make sure that we keep the configured API key as a placeholder
 124	pc, _ := cfg.Providers.Get("openai")
 125	assert.Equal(t, "xyz", pc.APIKey)
 126	assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
 127	assert.Len(t, pc.Models, 2)
 128	assert.Equal(t, "Updated", pc.Models[0].Name)
 129}
 130
 131func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
 132	knownProviders := []catwalk.Provider{
 133		{
 134			ID:          "openai",
 135			APIKey:      "$OPENAI_API_KEY",
 136			APIEndpoint: "https://api.openai.com/v1",
 137			Models: []catwalk.Model{{
 138				ID: "test-model",
 139			}},
 140		},
 141	}
 142
 143	cfg := &Config{
 144		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 145			"custom": {
 146				APIKey:  "xyz",
 147				BaseURL: "https://api.someendpoint.com/v2",
 148				Models: []catwalk.Model{
 149					{
 150						ID: "test-model",
 151					},
 152				},
 153			},
 154		}),
 155	}
 156	cfg.setDefaults("/tmp")
 157	env := env.NewFromMap(map[string]string{
 158		"OPENAI_API_KEY": "test-key",
 159	})
 160	resolver := NewEnvironmentVariableResolver(env)
 161	err := cfg.configureProviders(env, resolver, knownProviders)
 162	assert.NoError(t, err)
 163	// Should be to because of the env variable
 164	assert.Equal(t, cfg.Providers.Len(), 2)
 165
 166	// We want to make sure that we keep the configured API key as a placeholder
 167	pc, _ := cfg.Providers.Get("custom")
 168	assert.Equal(t, "xyz", pc.APIKey)
 169	// Make sure we set the ID correctly
 170	assert.Equal(t, "custom", pc.ID)
 171	assert.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
 172	assert.Len(t, pc.Models, 1)
 173
 174	_, ok := cfg.Providers.Get("openai")
 175	assert.True(t, ok, "OpenAI provider should still be present")
 176}
 177
 178func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
 179	knownProviders := []catwalk.Provider{
 180		{
 181			ID:          catwalk.InferenceProviderBedrock,
 182			APIKey:      "",
 183			APIEndpoint: "",
 184			Models: []catwalk.Model{{
 185				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 186			}},
 187		},
 188	}
 189
 190	cfg := &Config{}
 191	cfg.setDefaults("/tmp")
 192	env := env.NewFromMap(map[string]string{
 193		"AWS_ACCESS_KEY_ID":     "test-key-id",
 194		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 195	})
 196	resolver := NewEnvironmentVariableResolver(env)
 197	err := cfg.configureProviders(env, resolver, knownProviders)
 198	assert.NoError(t, err)
 199	assert.Equal(t, cfg.Providers.Len(), 1)
 200
 201	bedrockProvider, ok := cfg.Providers.Get("bedrock")
 202	assert.True(t, ok, "Bedrock provider should be present")
 203	assert.Len(t, bedrockProvider.Models, 1)
 204	assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
 205}
 206
 207func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
 208	knownProviders := []catwalk.Provider{
 209		{
 210			ID:          catwalk.InferenceProviderBedrock,
 211			APIKey:      "",
 212			APIEndpoint: "",
 213			Models: []catwalk.Model{{
 214				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 215			}},
 216		},
 217	}
 218
 219	cfg := &Config{}
 220	cfg.setDefaults("/tmp")
 221	env := env.NewFromMap(map[string]string{})
 222	resolver := NewEnvironmentVariableResolver(env)
 223	err := cfg.configureProviders(env, resolver, knownProviders)
 224	assert.NoError(t, err)
 225	// Provider should not be configured without credentials
 226	assert.Equal(t, cfg.Providers.Len(), 0)
 227}
 228
 229func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
 230	knownProviders := []catwalk.Provider{
 231		{
 232			ID:          catwalk.InferenceProviderBedrock,
 233			APIKey:      "",
 234			APIEndpoint: "",
 235			Models: []catwalk.Model{{
 236				ID: "some-random-model",
 237			}},
 238		},
 239	}
 240
 241	cfg := &Config{}
 242	cfg.setDefaults("/tmp")
 243	env := env.NewFromMap(map[string]string{
 244		"AWS_ACCESS_KEY_ID":     "test-key-id",
 245		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 246	})
 247	resolver := NewEnvironmentVariableResolver(env)
 248	err := cfg.configureProviders(env, resolver, knownProviders)
 249	assert.Error(t, err)
 250}
 251
 252func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
 253	knownProviders := []catwalk.Provider{
 254		{
 255			ID:          catwalk.InferenceProviderVertexAI,
 256			APIKey:      "",
 257			APIEndpoint: "",
 258			Models: []catwalk.Model{{
 259				ID: "gemini-pro",
 260			}},
 261		},
 262	}
 263
 264	cfg := &Config{}
 265	cfg.setDefaults("/tmp")
 266	env := env.NewFromMap(map[string]string{
 267		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 268		"GOOGLE_CLOUD_PROJECT":      "test-project",
 269		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 270	})
 271	resolver := NewEnvironmentVariableResolver(env)
 272	err := cfg.configureProviders(env, resolver, knownProviders)
 273	assert.NoError(t, err)
 274	assert.Equal(t, cfg.Providers.Len(), 1)
 275
 276	vertexProvider, ok := cfg.Providers.Get("vertexai")
 277	assert.True(t, ok, "VertexAI provider should be present")
 278	assert.Len(t, vertexProvider.Models, 1)
 279	assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
 280	assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
 281	assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
 282}
 283
 284func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
 285	knownProviders := []catwalk.Provider{
 286		{
 287			ID:          catwalk.InferenceProviderVertexAI,
 288			APIKey:      "",
 289			APIEndpoint: "",
 290			Models: []catwalk.Model{{
 291				ID: "gemini-pro",
 292			}},
 293		},
 294	}
 295
 296	cfg := &Config{}
 297	cfg.setDefaults("/tmp")
 298	env := env.NewFromMap(map[string]string{
 299		"GOOGLE_GENAI_USE_VERTEXAI": "false",
 300		"GOOGLE_CLOUD_PROJECT":      "test-project",
 301		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 302	})
 303	resolver := NewEnvironmentVariableResolver(env)
 304	err := cfg.configureProviders(env, resolver, knownProviders)
 305	assert.NoError(t, err)
 306	// Provider should not be configured without proper credentials
 307	assert.Equal(t, cfg.Providers.Len(), 0)
 308}
 309
 310func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
 311	knownProviders := []catwalk.Provider{
 312		{
 313			ID:          catwalk.InferenceProviderVertexAI,
 314			APIKey:      "",
 315			APIEndpoint: "",
 316			Models: []catwalk.Model{{
 317				ID: "gemini-pro",
 318			}},
 319		},
 320	}
 321
 322	cfg := &Config{}
 323	cfg.setDefaults("/tmp")
 324	env := env.NewFromMap(map[string]string{
 325		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 326		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 327	})
 328	resolver := NewEnvironmentVariableResolver(env)
 329	err := cfg.configureProviders(env, resolver, knownProviders)
 330	assert.NoError(t, err)
 331	// Provider should not be configured without project
 332	assert.Equal(t, cfg.Providers.Len(), 0)
 333}
 334
 335func TestConfig_configureProvidersSetProviderID(t *testing.T) {
 336	knownProviders := []catwalk.Provider{
 337		{
 338			ID:          "openai",
 339			APIKey:      "$OPENAI_API_KEY",
 340			APIEndpoint: "https://api.openai.com/v1",
 341			Models: []catwalk.Model{{
 342				ID: "test-model",
 343			}},
 344		},
 345	}
 346
 347	cfg := &Config{}
 348	cfg.setDefaults("/tmp")
 349	env := env.NewFromMap(map[string]string{
 350		"OPENAI_API_KEY": "test-key",
 351	})
 352	resolver := NewEnvironmentVariableResolver(env)
 353	err := cfg.configureProviders(env, resolver, knownProviders)
 354	assert.NoError(t, err)
 355	assert.Equal(t, cfg.Providers.Len(), 1)
 356
 357	// Provider ID should be set
 358	pc, _ := cfg.Providers.Get("openai")
 359	assert.Equal(t, "openai", pc.ID)
 360}
 361
 362func TestConfig_EnabledProviders(t *testing.T) {
 363	t.Run("all providers enabled", func(t *testing.T) {
 364		cfg := &Config{
 365			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 366				"openai": {
 367					ID:      "openai",
 368					APIKey:  "key1",
 369					Disable: false,
 370				},
 371				"anthropic": {
 372					ID:      "anthropic",
 373					APIKey:  "key2",
 374					Disable: false,
 375				},
 376			}),
 377		}
 378
 379		enabled := cfg.EnabledProviders()
 380		assert.Len(t, enabled, 2)
 381	})
 382
 383	t.Run("some providers disabled", func(t *testing.T) {
 384		cfg := &Config{
 385			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 386				"openai": {
 387					ID:      "openai",
 388					APIKey:  "key1",
 389					Disable: false,
 390				},
 391				"anthropic": {
 392					ID:      "anthropic",
 393					APIKey:  "key2",
 394					Disable: true,
 395				},
 396			}),
 397		}
 398
 399		enabled := cfg.EnabledProviders()
 400		assert.Len(t, enabled, 1)
 401		assert.Equal(t, "openai", enabled[0].ID)
 402	})
 403
 404	t.Run("empty providers map", func(t *testing.T) {
 405		cfg := &Config{
 406			Providers: csync.NewMap[string, ProviderConfig](),
 407		}
 408
 409		enabled := cfg.EnabledProviders()
 410		assert.Len(t, enabled, 0)
 411	})
 412}
 413
 414func TestConfig_IsConfigured(t *testing.T) {
 415	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
 416		cfg := &Config{
 417			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 418				"openai": {
 419					ID:      "openai",
 420					APIKey:  "key1",
 421					Disable: false,
 422				},
 423			}),
 424		}
 425
 426		assert.True(t, cfg.IsConfigured())
 427	})
 428
 429	t.Run("returns false when no providers are configured", func(t *testing.T) {
 430		cfg := &Config{
 431			Providers: csync.NewMap[string, ProviderConfig](),
 432		}
 433
 434		assert.False(t, cfg.IsConfigured())
 435	})
 436
 437	t.Run("returns false when all providers are disabled", func(t *testing.T) {
 438		cfg := &Config{
 439			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 440				"openai": {
 441					ID:      "openai",
 442					APIKey:  "key1",
 443					Disable: true,
 444				},
 445				"anthropic": {
 446					ID:      "anthropic",
 447					APIKey:  "key2",
 448					Disable: true,
 449				},
 450			}),
 451		}
 452
 453		assert.False(t, cfg.IsConfigured())
 454	})
 455}
 456
 457func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
 458	knownProviders := []catwalk.Provider{
 459		{
 460			ID:          "openai",
 461			APIKey:      "$OPENAI_API_KEY",
 462			APIEndpoint: "https://api.openai.com/v1",
 463			Models: []catwalk.Model{{
 464				ID: "test-model",
 465			}},
 466		},
 467	}
 468
 469	cfg := &Config{
 470		Providers: csync.NewMapFrom(map[string]ProviderConfig{
 471			"openai": {
 472				Disable: true,
 473			},
 474		}),
 475	}
 476	cfg.setDefaults("/tmp")
 477
 478	env := env.NewFromMap(map[string]string{
 479		"OPENAI_API_KEY": "test-key",
 480	})
 481	resolver := NewEnvironmentVariableResolver(env)
 482	err := cfg.configureProviders(env, resolver, knownProviders)
 483	assert.NoError(t, err)
 484
 485	// Provider should be removed from config when disabled
 486	assert.Equal(t, cfg.Providers.Len(), 0)
 487	_, exists := cfg.Providers.Get("openai")
 488	assert.False(t, exists)
 489}
 490
 491func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 492	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
 493		cfg := &Config{
 494			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 495				"custom": {
 496					BaseURL: "https://api.custom.com/v1",
 497					Models: []catwalk.Model{{
 498						ID: "test-model",
 499					}},
 500				},
 501				"openai": {
 502					APIKey: "$MISSING",
 503				},
 504			}),
 505		}
 506		cfg.setDefaults("/tmp")
 507
 508		env := env.NewFromMap(map[string]string{})
 509		resolver := NewEnvironmentVariableResolver(env)
 510		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 511		assert.NoError(t, err)
 512
 513		assert.Equal(t, cfg.Providers.Len(), 1)
 514		_, exists := cfg.Providers.Get("custom")
 515		assert.True(t, exists)
 516	})
 517
 518	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
 519		cfg := &Config{
 520			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 521				"custom": {
 522					APIKey: "test-key",
 523					Models: []catwalk.Model{{
 524						ID: "test-model",
 525					}},
 526				},
 527			}),
 528		}
 529		cfg.setDefaults("/tmp")
 530
 531		env := env.NewFromMap(map[string]string{})
 532		resolver := NewEnvironmentVariableResolver(env)
 533		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 534		assert.NoError(t, err)
 535
 536		assert.Equal(t, cfg.Providers.Len(), 0)
 537		_, exists := cfg.Providers.Get("custom")
 538		assert.False(t, exists)
 539	})
 540
 541	t.Run("custom provider with no models is removed", func(t *testing.T) {
 542		cfg := &Config{
 543			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 544				"custom": {
 545					APIKey:  "test-key",
 546					BaseURL: "https://api.custom.com/v1",
 547					Models:  []catwalk.Model{},
 548				},
 549			}),
 550		}
 551		cfg.setDefaults("/tmp")
 552
 553		env := env.NewFromMap(map[string]string{})
 554		resolver := NewEnvironmentVariableResolver(env)
 555		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 556		assert.NoError(t, err)
 557
 558		assert.Equal(t, cfg.Providers.Len(), 0)
 559		_, exists := cfg.Providers.Get("custom")
 560		assert.False(t, exists)
 561	})
 562
 563	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
 564		cfg := &Config{
 565			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 566				"custom": {
 567					APIKey:  "test-key",
 568					BaseURL: "https://api.custom.com/v1",
 569					Type:    "unsupported",
 570					Models: []catwalk.Model{{
 571						ID: "test-model",
 572					}},
 573				},
 574			}),
 575		}
 576		cfg.setDefaults("/tmp")
 577
 578		env := env.NewFromMap(map[string]string{})
 579		resolver := NewEnvironmentVariableResolver(env)
 580		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 581		assert.NoError(t, err)
 582
 583		assert.Equal(t, cfg.Providers.Len(), 0)
 584		_, exists := cfg.Providers.Get("custom")
 585		assert.False(t, exists)
 586	})
 587
 588	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
 589		cfg := &Config{
 590			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 591				"custom": {
 592					APIKey:  "test-key",
 593					BaseURL: "https://api.custom.com/v1",
 594					Type:    catwalk.TypeOpenAI,
 595					Models: []catwalk.Model{{
 596						ID: "test-model",
 597					}},
 598				},
 599			}),
 600		}
 601		cfg.setDefaults("/tmp")
 602
 603		env := env.NewFromMap(map[string]string{})
 604		resolver := NewEnvironmentVariableResolver(env)
 605		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 606		assert.NoError(t, err)
 607
 608		assert.Equal(t, cfg.Providers.Len(), 1)
 609		customProvider, exists := cfg.Providers.Get("custom")
 610		assert.True(t, exists)
 611		assert.Equal(t, "custom", customProvider.ID)
 612		assert.Equal(t, "test-key", customProvider.APIKey)
 613		assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
 614	})
 615
 616	t.Run("disabled custom provider is removed", func(t *testing.T) {
 617		cfg := &Config{
 618			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 619				"custom": {
 620					APIKey:  "test-key",
 621					BaseURL: "https://api.custom.com/v1",
 622					Type:    catwalk.TypeOpenAI,
 623					Disable: true,
 624					Models: []catwalk.Model{{
 625						ID: "test-model",
 626					}},
 627				},
 628			}),
 629		}
 630		cfg.setDefaults("/tmp")
 631
 632		env := env.NewFromMap(map[string]string{})
 633		resolver := NewEnvironmentVariableResolver(env)
 634		err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
 635		assert.NoError(t, err)
 636
 637		assert.Equal(t, cfg.Providers.Len(), 0)
 638		_, exists := cfg.Providers.Get("custom")
 639		assert.False(t, exists)
 640	})
 641}
 642
 643func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 644	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
 645		knownProviders := []catwalk.Provider{
 646			{
 647				ID:          catwalk.InferenceProviderVertexAI,
 648				APIKey:      "",
 649				APIEndpoint: "",
 650				Models: []catwalk.Model{{
 651					ID: "gemini-pro",
 652				}},
 653			},
 654		}
 655
 656		cfg := &Config{
 657			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 658				"vertexai": {
 659					BaseURL: "custom-url",
 660				},
 661			}),
 662		}
 663		cfg.setDefaults("/tmp")
 664
 665		env := env.NewFromMap(map[string]string{
 666			"GOOGLE_GENAI_USE_VERTEXAI": "false",
 667		})
 668		resolver := NewEnvironmentVariableResolver(env)
 669		err := cfg.configureProviders(env, resolver, knownProviders)
 670		assert.NoError(t, err)
 671
 672		assert.Equal(t, cfg.Providers.Len(), 0)
 673		_, exists := cfg.Providers.Get("vertexai")
 674		assert.False(t, exists)
 675	})
 676
 677	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
 678		knownProviders := []catwalk.Provider{
 679			{
 680				ID:          catwalk.InferenceProviderBedrock,
 681				APIKey:      "",
 682				APIEndpoint: "",
 683				Models: []catwalk.Model{{
 684					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 685				}},
 686			},
 687		}
 688
 689		cfg := &Config{
 690			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 691				"bedrock": {
 692					BaseURL: "custom-url",
 693				},
 694			}),
 695		}
 696		cfg.setDefaults("/tmp")
 697
 698		env := env.NewFromMap(map[string]string{})
 699		resolver := NewEnvironmentVariableResolver(env)
 700		err := cfg.configureProviders(env, resolver, knownProviders)
 701		assert.NoError(t, err)
 702
 703		assert.Equal(t, cfg.Providers.Len(), 0)
 704		_, exists := cfg.Providers.Get("bedrock")
 705		assert.False(t, exists)
 706	})
 707
 708	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
 709		knownProviders := []catwalk.Provider{
 710			{
 711				ID:          "openai",
 712				APIKey:      "$MISSING_API_KEY",
 713				APIEndpoint: "https://api.openai.com/v1",
 714				Models: []catwalk.Model{{
 715					ID: "test-model",
 716				}},
 717			},
 718		}
 719
 720		cfg := &Config{
 721			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 722				"openai": {
 723					BaseURL: "custom-url",
 724				},
 725			}),
 726		}
 727		cfg.setDefaults("/tmp")
 728
 729		env := env.NewFromMap(map[string]string{})
 730		resolver := NewEnvironmentVariableResolver(env)
 731		err := cfg.configureProviders(env, resolver, knownProviders)
 732		assert.NoError(t, err)
 733
 734		assert.Equal(t, cfg.Providers.Len(), 0)
 735		_, exists := cfg.Providers.Get("openai")
 736		assert.False(t, exists)
 737	})
 738
 739	t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
 740		knownProviders := []catwalk.Provider{
 741			{
 742				ID:          "openai",
 743				APIKey:      "$OPENAI_API_KEY",
 744				APIEndpoint: "$MISSING_ENDPOINT",
 745				Models: []catwalk.Model{{
 746					ID: "test-model",
 747				}},
 748			},
 749		}
 750
 751		cfg := &Config{
 752			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 753				"openai": {
 754					APIKey: "test-key",
 755				},
 756			}),
 757		}
 758		cfg.setDefaults("/tmp")
 759
 760		env := env.NewFromMap(map[string]string{
 761			"OPENAI_API_KEY": "test-key",
 762		})
 763		resolver := NewEnvironmentVariableResolver(env)
 764		err := cfg.configureProviders(env, resolver, knownProviders)
 765		assert.NoError(t, err)
 766
 767		assert.Equal(t, cfg.Providers.Len(), 1)
 768		_, exists := cfg.Providers.Get("openai")
 769		assert.True(t, exists)
 770	})
 771}
 772
 773func TestConfig_defaultModelSelection(t *testing.T) {
 774	t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
 775		knownProviders := []catwalk.Provider{
 776			{
 777				ID:                  "openai",
 778				APIKey:              "abc",
 779				DefaultLargeModelID: "large-model",
 780				DefaultSmallModelID: "small-model",
 781				Models: []catwalk.Model{
 782					{
 783						ID:               "large-model",
 784						DefaultMaxTokens: 1000,
 785					},
 786					{
 787						ID:               "small-model",
 788						DefaultMaxTokens: 500,
 789					},
 790				},
 791			},
 792		}
 793
 794		cfg := &Config{}
 795		cfg.setDefaults("/tmp")
 796		env := env.NewFromMap(map[string]string{})
 797		resolver := NewEnvironmentVariableResolver(env)
 798		err := cfg.configureProviders(env, resolver, knownProviders)
 799		assert.NoError(t, err)
 800
 801		large, small, err := cfg.defaultModelSelection(knownProviders)
 802		assert.NoError(t, err)
 803		assert.Equal(t, "large-model", large.Model)
 804		assert.Equal(t, "openai", large.Provider)
 805		assert.Equal(t, int64(1000), large.MaxTokens)
 806		assert.Equal(t, "small-model", small.Model)
 807		assert.Equal(t, "openai", small.Provider)
 808		assert.Equal(t, int64(500), small.MaxTokens)
 809	})
 810	t.Run("should error if no providers configured", func(t *testing.T) {
 811		knownProviders := []catwalk.Provider{
 812			{
 813				ID:                  "openai",
 814				APIKey:              "$MISSING_KEY",
 815				DefaultLargeModelID: "large-model",
 816				DefaultSmallModelID: "small-model",
 817				Models: []catwalk.Model{
 818					{
 819						ID:               "large-model",
 820						DefaultMaxTokens: 1000,
 821					},
 822					{
 823						ID:               "small-model",
 824						DefaultMaxTokens: 500,
 825					},
 826				},
 827			},
 828		}
 829
 830		cfg := &Config{}
 831		cfg.setDefaults("/tmp")
 832		env := env.NewFromMap(map[string]string{})
 833		resolver := NewEnvironmentVariableResolver(env)
 834		err := cfg.configureProviders(env, resolver, knownProviders)
 835		assert.NoError(t, err)
 836
 837		_, _, err = cfg.defaultModelSelection(knownProviders)
 838		assert.Error(t, err)
 839	})
 840	t.Run("should error if model is missing", func(t *testing.T) {
 841		knownProviders := []catwalk.Provider{
 842			{
 843				ID:                  "openai",
 844				APIKey:              "abc",
 845				DefaultLargeModelID: "large-model",
 846				DefaultSmallModelID: "small-model",
 847				Models: []catwalk.Model{
 848					{
 849						ID:               "not-large-model",
 850						DefaultMaxTokens: 1000,
 851					},
 852					{
 853						ID:               "small-model",
 854						DefaultMaxTokens: 500,
 855					},
 856				},
 857			},
 858		}
 859
 860		cfg := &Config{}
 861		cfg.setDefaults("/tmp")
 862		env := env.NewFromMap(map[string]string{})
 863		resolver := NewEnvironmentVariableResolver(env)
 864		err := cfg.configureProviders(env, resolver, knownProviders)
 865		assert.NoError(t, err)
 866		_, _, err = cfg.defaultModelSelection(knownProviders)
 867		assert.Error(t, err)
 868	})
 869
 870	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
 871		knownProviders := []catwalk.Provider{
 872			{
 873				ID:                  "openai",
 874				APIKey:              "$MISSING", // will not be included in the config
 875				DefaultLargeModelID: "large-model",
 876				DefaultSmallModelID: "small-model",
 877				Models: []catwalk.Model{
 878					{
 879						ID:               "not-large-model",
 880						DefaultMaxTokens: 1000,
 881					},
 882					{
 883						ID:               "small-model",
 884						DefaultMaxTokens: 500,
 885					},
 886				},
 887			},
 888		}
 889
 890		cfg := &Config{
 891			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 892				"custom": {
 893					APIKey:  "test-key",
 894					BaseURL: "https://api.custom.com/v1",
 895					Models: []catwalk.Model{
 896						{
 897							ID:               "model",
 898							DefaultMaxTokens: 600,
 899						},
 900					},
 901				},
 902			}),
 903		}
 904		cfg.setDefaults("/tmp")
 905		env := env.NewFromMap(map[string]string{})
 906		resolver := NewEnvironmentVariableResolver(env)
 907		err := cfg.configureProviders(env, resolver, knownProviders)
 908		assert.NoError(t, err)
 909		large, small, err := cfg.defaultModelSelection(knownProviders)
 910		assert.NoError(t, err)
 911		assert.Equal(t, "model", large.Model)
 912		assert.Equal(t, "custom", large.Provider)
 913		assert.Equal(t, int64(600), large.MaxTokens)
 914		assert.Equal(t, "model", small.Model)
 915		assert.Equal(t, "custom", small.Provider)
 916		assert.Equal(t, int64(600), small.MaxTokens)
 917	})
 918
 919	t.Run("should fail if no model configured", func(t *testing.T) {
 920		knownProviders := []catwalk.Provider{
 921			{
 922				ID:                  "openai",
 923				APIKey:              "$MISSING", // will not be included in the config
 924				DefaultLargeModelID: "large-model",
 925				DefaultSmallModelID: "small-model",
 926				Models: []catwalk.Model{
 927					{
 928						ID:               "not-large-model",
 929						DefaultMaxTokens: 1000,
 930					},
 931					{
 932						ID:               "small-model",
 933						DefaultMaxTokens: 500,
 934					},
 935				},
 936			},
 937		}
 938
 939		cfg := &Config{
 940			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 941				"custom": {
 942					APIKey:  "test-key",
 943					BaseURL: "https://api.custom.com/v1",
 944					Models:  []catwalk.Model{},
 945				},
 946			}),
 947		}
 948		cfg.setDefaults("/tmp")
 949		env := env.NewFromMap(map[string]string{})
 950		resolver := NewEnvironmentVariableResolver(env)
 951		err := cfg.configureProviders(env, resolver, knownProviders)
 952		assert.NoError(t, err)
 953		_, _, err = cfg.defaultModelSelection(knownProviders)
 954		assert.Error(t, err)
 955	})
 956	t.Run("should use the default provider first", func(t *testing.T) {
 957		knownProviders := []catwalk.Provider{
 958			{
 959				ID:                  "openai",
 960				APIKey:              "set",
 961				DefaultLargeModelID: "large-model",
 962				DefaultSmallModelID: "small-model",
 963				Models: []catwalk.Model{
 964					{
 965						ID:               "large-model",
 966						DefaultMaxTokens: 1000,
 967					},
 968					{
 969						ID:               "small-model",
 970						DefaultMaxTokens: 500,
 971					},
 972				},
 973			},
 974		}
 975
 976		cfg := &Config{
 977			Providers: csync.NewMapFrom(map[string]ProviderConfig{
 978				"custom": {
 979					APIKey:  "test-key",
 980					BaseURL: "https://api.custom.com/v1",
 981					Models: []catwalk.Model{
 982						{
 983							ID:               "large-model",
 984							DefaultMaxTokens: 1000,
 985						},
 986					},
 987				},
 988			}),
 989		}
 990		cfg.setDefaults("/tmp")
 991		env := env.NewFromMap(map[string]string{})
 992		resolver := NewEnvironmentVariableResolver(env)
 993		err := cfg.configureProviders(env, resolver, knownProviders)
 994		assert.NoError(t, err)
 995		large, small, err := cfg.defaultModelSelection(knownProviders)
 996		assert.NoError(t, err)
 997		assert.Equal(t, "large-model", large.Model)
 998		assert.Equal(t, "openai", large.Provider)
 999		assert.Equal(t, int64(1000), large.MaxTokens)
1000		assert.Equal(t, "small-model", small.Model)
1001		assert.Equal(t, "openai", small.Provider)
1002		assert.Equal(t, int64(500), small.MaxTokens)
1003	})
1004}
1005
1006func TestConfig_configureSelectedModels(t *testing.T) {
1007	t.Run("should override defaults", func(t *testing.T) {
1008		knownProviders := []catwalk.Provider{
1009			{
1010				ID:                  "openai",
1011				APIKey:              "abc",
1012				DefaultLargeModelID: "large-model",
1013				DefaultSmallModelID: "small-model",
1014				Models: []catwalk.Model{
1015					{
1016						ID:               "larger-model",
1017						DefaultMaxTokens: 2000,
1018					},
1019					{
1020						ID:               "large-model",
1021						DefaultMaxTokens: 1000,
1022					},
1023					{
1024						ID:               "small-model",
1025						DefaultMaxTokens: 500,
1026					},
1027				},
1028			},
1029		}
1030
1031		cfg := &Config{
1032			Models: map[SelectedModelType]SelectedModel{
1033				"large": {
1034					Model: "larger-model",
1035				},
1036			},
1037		}
1038		cfg.setDefaults("/tmp")
1039		env := env.NewFromMap(map[string]string{})
1040		resolver := NewEnvironmentVariableResolver(env)
1041		err := cfg.configureProviders(env, resolver, knownProviders)
1042		assert.NoError(t, err)
1043
1044		err = cfg.configureSelectedModels(knownProviders)
1045		assert.NoError(t, err)
1046		large := cfg.Models[SelectedModelTypeLarge]
1047		small := cfg.Models[SelectedModelTypeSmall]
1048		assert.Equal(t, "larger-model", large.Model)
1049		assert.Equal(t, "openai", large.Provider)
1050		assert.Equal(t, int64(2000), large.MaxTokens)
1051		assert.Equal(t, "small-model", small.Model)
1052		assert.Equal(t, "openai", small.Provider)
1053		assert.Equal(t, int64(500), small.MaxTokens)
1054	})
1055	t.Run("should be possible to use multiple providers", func(t *testing.T) {
1056		knownProviders := []catwalk.Provider{
1057			{
1058				ID:                  "openai",
1059				APIKey:              "abc",
1060				DefaultLargeModelID: "large-model",
1061				DefaultSmallModelID: "small-model",
1062				Models: []catwalk.Model{
1063					{
1064						ID:               "large-model",
1065						DefaultMaxTokens: 1000,
1066					},
1067					{
1068						ID:               "small-model",
1069						DefaultMaxTokens: 500,
1070					},
1071				},
1072			},
1073			{
1074				ID:                  "anthropic",
1075				APIKey:              "abc",
1076				DefaultLargeModelID: "a-large-model",
1077				DefaultSmallModelID: "a-small-model",
1078				Models: []catwalk.Model{
1079					{
1080						ID:               "a-large-model",
1081						DefaultMaxTokens: 1000,
1082					},
1083					{
1084						ID:               "a-small-model",
1085						DefaultMaxTokens: 200,
1086					},
1087				},
1088			},
1089		}
1090
1091		cfg := &Config{
1092			Models: map[SelectedModelType]SelectedModel{
1093				"small": {
1094					Model:     "a-small-model",
1095					Provider:  "anthropic",
1096					MaxTokens: 300,
1097				},
1098			},
1099		}
1100		cfg.setDefaults("/tmp")
1101		env := env.NewFromMap(map[string]string{})
1102		resolver := NewEnvironmentVariableResolver(env)
1103		err := cfg.configureProviders(env, resolver, knownProviders)
1104		assert.NoError(t, err)
1105
1106		err = cfg.configureSelectedModels(knownProviders)
1107		assert.NoError(t, err)
1108		large := cfg.Models[SelectedModelTypeLarge]
1109		small := cfg.Models[SelectedModelTypeSmall]
1110		assert.Equal(t, "large-model", large.Model)
1111		assert.Equal(t, "openai", large.Provider)
1112		assert.Equal(t, int64(1000), large.MaxTokens)
1113		assert.Equal(t, "a-small-model", small.Model)
1114		assert.Equal(t, "anthropic", small.Provider)
1115		assert.Equal(t, int64(300), small.MaxTokens)
1116	})
1117
1118	t.Run("should override the max tokens only", func(t *testing.T) {
1119		knownProviders := []catwalk.Provider{
1120			{
1121				ID:                  "openai",
1122				APIKey:              "abc",
1123				DefaultLargeModelID: "large-model",
1124				DefaultSmallModelID: "small-model",
1125				Models: []catwalk.Model{
1126					{
1127						ID:               "large-model",
1128						DefaultMaxTokens: 1000,
1129					},
1130					{
1131						ID:               "small-model",
1132						DefaultMaxTokens: 500,
1133					},
1134				},
1135			},
1136		}
1137
1138		cfg := &Config{
1139			Models: map[SelectedModelType]SelectedModel{
1140				"large": {
1141					MaxTokens: 100,
1142				},
1143			},
1144		}
1145		cfg.setDefaults("/tmp")
1146		env := env.NewFromMap(map[string]string{})
1147		resolver := NewEnvironmentVariableResolver(env)
1148		err := cfg.configureProviders(env, resolver, knownProviders)
1149		assert.NoError(t, err)
1150
1151		err = cfg.configureSelectedModels(knownProviders)
1152		assert.NoError(t, err)
1153		large := cfg.Models[SelectedModelTypeLarge]
1154		assert.Equal(t, "large-model", large.Model)
1155		assert.Equal(t, "openai", large.Provider)
1156		assert.Equal(t, int64(100), large.MaxTokens)
1157	})
1158}