load_test.go

   1package config
   2
   3import (
   4	"io"
   5	"log/slog"
   6	"os"
   7	"path/filepath"
   8	"testing"
   9
  10	"charm.land/catwalk/pkg/catwalk"
  11	"github.com/charmbracelet/crush/internal/env"
  12	"github.com/stretchr/testify/assert"
  13	"github.com/stretchr/testify/require"
  14)
  15
  16func serviceFor(cfg *Config) *Service {
  17	return &Service{cfg: cfg}
  18}
  19
  20func TestMain(m *testing.M) {
  21	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
  22
  23	exitVal := m.Run()
  24	os.Exit(exitVal)
  25}
  26
  27func TestConfig_LoadFromBytes(t *testing.T) {
  28	data1 := []byte(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
  29	data2 := []byte(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
  30	data3 := []byte(`{"providers": {"openai": {}}}`)
  31
  32	loadedConfig, err := loadFromBytes([][]byte{data1, data2, data3})
  33
  34	require.NoError(t, err)
  35	require.NotNil(t, loadedConfig)
  36	require.Equal(t, 1, len(loadedConfig.Providers))
  37	pc, _ := loadedConfig.Providers["openai"]
  38	require.Equal(t, "key2", pc.APIKey)
  39	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  40}
  41
  42func TestConfig_setDefaults(t *testing.T) {
  43	cfg := &Config{}
  44
  45	cfg.setDefaults("/tmp", "")
  46
  47	require.NotNil(t, cfg.Options)
  48	require.NotNil(t, cfg.Options.TUI)
  49	require.NotNil(t, cfg.Options.ContextPaths)
  50	require.NotNil(t, cfg.Providers)
  51	require.NotNil(t, cfg.Models)
  52	require.NotNil(t, cfg.LSP)
  53	require.NotNil(t, cfg.MCP)
  54	require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  55	require.Equal(t, "AGENTS.md", cfg.Options.InitializeAs)
  56	for _, path := range defaultContextPaths {
  57		require.Contains(t, cfg.Options.ContextPaths, path)
  58	}
  59}
  60
  61func TestConfig_configureProviders(t *testing.T) {
  62	knownProviders := []catwalk.Provider{
  63		{
  64			ID:          "openai",
  65			APIKey:      "$OPENAI_API_KEY",
  66			APIEndpoint: "https://api.openai.com/v1",
  67			Models: []catwalk.Model{{
  68				ID: "test-model",
  69			}},
  70		},
  71	}
  72
  73	cfg := &Config{}
  74	cfg.setDefaults("/tmp", "")
  75	env := env.NewFromMap(map[string]string{
  76		"OPENAI_API_KEY": "test-key",
  77	})
  78	resolver := NewEnvironmentVariableResolver(env)
  79	err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
  80	require.NoError(t, err)
  81	require.Equal(t, 1, len(cfg.Providers))
  82
  83	// We want to make sure that we keep the configured API key as a placeholder
  84	pc, _ := cfg.Providers["openai"]
  85	require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  86}
  87
  88func TestConfig_configureProvidersWithOverride(t *testing.T) {
  89	knownProviders := []catwalk.Provider{
  90		{
  91			ID:          "openai",
  92			APIKey:      "$OPENAI_API_KEY",
  93			APIEndpoint: "https://api.openai.com/v1",
  94			Models: []catwalk.Model{{
  95				ID: "test-model",
  96			}},
  97		},
  98	}
  99
 100	cfg := &Config{
 101		Providers: make(map[string]ProviderConfig),
 102	}
 103	cfg.Providers["openai"] = ProviderConfig{
 104		APIKey:  "xyz",
 105		BaseURL: "https://api.openai.com/v2",
 106		Models: []catwalk.Model{
 107			{
 108				ID:   "test-model",
 109				Name: "Updated",
 110			},
 111			{
 112				ID: "another-model",
 113			},
 114		},
 115	}
 116	cfg.setDefaults("/tmp", "")
 117
 118	env := env.NewFromMap(map[string]string{
 119		"OPENAI_API_KEY": "test-key",
 120	})
 121	resolver := NewEnvironmentVariableResolver(env)
 122	err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 123	require.NoError(t, err)
 124	require.Equal(t, 1, len(cfg.Providers))
 125
 126	// We want to make sure that we keep the configured API key as a placeholder
 127	pc, _ := cfg.Providers["openai"]
 128	require.Equal(t, "xyz", pc.APIKey)
 129	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
 130	require.Len(t, pc.Models, 2)
 131	require.Equal(t, "Updated", pc.Models[0].Name)
 132}
 133
 134func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
 135	knownProviders := []catwalk.Provider{
 136		{
 137			ID:          "openai",
 138			APIKey:      "$OPENAI_API_KEY",
 139			APIEndpoint: "https://api.openai.com/v1",
 140			Models: []catwalk.Model{{
 141				ID: "test-model",
 142			}},
 143		},
 144	}
 145
 146	cfg := &Config{
 147		Providers: map[string]ProviderConfig{
 148			"custom": {
 149				APIKey:  "xyz",
 150				BaseURL: "https://api.someendpoint.com/v2",
 151				Models: []catwalk.Model{
 152					{
 153						ID: "test-model",
 154					},
 155				},
 156			},
 157		},
 158	}
 159	cfg.setDefaults("/tmp", "")
 160	env := env.NewFromMap(map[string]string{
 161		"OPENAI_API_KEY": "test-key",
 162	})
 163	resolver := NewEnvironmentVariableResolver(env)
 164	err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 165	require.NoError(t, err)
 166	// Should be to because of the env variable
 167	require.Equal(t, len(cfg.Providers), 2)
 168
 169	// We want to make sure that we keep the configured API key as a placeholder
 170	pc, _ := cfg.Providers["custom"]
 171	require.Equal(t, "xyz", pc.APIKey)
 172	// Make sure we set the ID correctly
 173	require.Equal(t, "custom", pc.ID)
 174	require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
 175	require.Len(t, pc.Models, 1)
 176
 177	_, ok := cfg.Providers["openai"]
 178	require.True(t, ok, "OpenAI provider should still be present")
 179}
 180
 181func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
 182	knownProviders := []catwalk.Provider{
 183		{
 184			ID:          catwalk.InferenceProviderBedrock,
 185			APIKey:      "",
 186			APIEndpoint: "",
 187			Models: []catwalk.Model{{
 188				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 189			}},
 190		},
 191	}
 192
 193	cfg := &Config{}
 194	cfg.setDefaults("/tmp", "")
 195	env := env.NewFromMap(map[string]string{
 196		"AWS_ACCESS_KEY_ID":     "test-key-id",
 197		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 198	})
 199	resolver := NewEnvironmentVariableResolver(env)
 200	err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 201	require.NoError(t, err)
 202	require.Equal(t, len(cfg.Providers), 1)
 203
 204	bedrockProvider, ok := cfg.Providers["bedrock"]
 205	require.True(t, ok, "Bedrock provider should be present")
 206	require.Len(t, bedrockProvider.Models, 1)
 207	require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
 208}
 209
 210func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
 211	knownProviders := []catwalk.Provider{
 212		{
 213			ID:          catwalk.InferenceProviderBedrock,
 214			APIKey:      "",
 215			APIEndpoint: "",
 216			Models: []catwalk.Model{{
 217				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 218			}},
 219		},
 220	}
 221
 222	cfg := &Config{}
 223	cfg.setDefaults("/tmp", "")
 224	env := env.NewFromMap(map[string]string{})
 225	resolver := NewEnvironmentVariableResolver(env)
 226	err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 227	require.NoError(t, err)
 228	// Provider should not be configured without credentials
 229	require.Equal(t, len(cfg.Providers), 0)
 230}
 231
 232func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
 233	knownProviders := []catwalk.Provider{
 234		{
 235			ID:          catwalk.InferenceProviderBedrock,
 236			APIKey:      "",
 237			APIEndpoint: "",
 238			Models: []catwalk.Model{{
 239				ID: "some-random-model",
 240			}},
 241		},
 242	}
 243
 244	cfg := &Config{}
 245	cfg.setDefaults("/tmp", "")
 246	env := env.NewFromMap(map[string]string{
 247		"AWS_ACCESS_KEY_ID":     "test-key-id",
 248		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 249	})
 250	resolver := NewEnvironmentVariableResolver(env)
 251	err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 252	require.Error(t, err)
 253}
 254
 255func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
 256	knownProviders := []catwalk.Provider{
 257		{
 258			ID:          catwalk.InferenceProviderVertexAI,
 259			APIKey:      "",
 260			APIEndpoint: "",
 261			Models: []catwalk.Model{{
 262				ID: "gemini-pro",
 263			}},
 264		},
 265	}
 266
 267	cfg := &Config{}
 268	cfg.setDefaults("/tmp", "")
 269	env := env.NewFromMap(map[string]string{
 270		"VERTEXAI_PROJECT":  "test-project",
 271		"VERTEXAI_LOCATION": "us-central1",
 272	})
 273	resolver := NewEnvironmentVariableResolver(env)
 274	err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 275	require.NoError(t, err)
 276	require.Equal(t, len(cfg.Providers), 1)
 277
 278	vertexProvider, ok := cfg.Providers["vertexai"]
 279	require.True(t, ok, "VertexAI provider should be present")
 280	require.Len(t, vertexProvider.Models, 1)
 281	require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
 282	require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
 283	require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
 284}
 285
 286func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
 287	knownProviders := []catwalk.Provider{
 288		{
 289			ID:          catwalk.InferenceProviderVertexAI,
 290			APIKey:      "",
 291			APIEndpoint: "",
 292			Models: []catwalk.Model{{
 293				ID: "gemini-pro",
 294			}},
 295		},
 296	}
 297
 298	cfg := &Config{}
 299	cfg.setDefaults("/tmp", "")
 300	env := env.NewFromMap(map[string]string{
 301		"GOOGLE_GENAI_USE_VERTEXAI": "false",
 302		"GOOGLE_CLOUD_PROJECT":      "test-project",
 303		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 304	})
 305	resolver := NewEnvironmentVariableResolver(env)
 306	err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 307	require.NoError(t, err)
 308	// Provider should not be configured without proper credentials
 309	require.Equal(t, len(cfg.Providers), 0)
 310}
 311
 312func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
 313	knownProviders := []catwalk.Provider{
 314		{
 315			ID:          catwalk.InferenceProviderVertexAI,
 316			APIKey:      "",
 317			APIEndpoint: "",
 318			Models: []catwalk.Model{{
 319				ID: "gemini-pro",
 320			}},
 321		},
 322	}
 323
 324	cfg := &Config{}
 325	cfg.setDefaults("/tmp", "")
 326	env := env.NewFromMap(map[string]string{
 327		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 328		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 329	})
 330	resolver := NewEnvironmentVariableResolver(env)
 331	err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 332	require.NoError(t, err)
 333	// Provider should not be configured without project
 334	require.Equal(t, len(cfg.Providers), 0)
 335}
 336
 337func TestConfig_configureProvidersSetProviderID(t *testing.T) {
 338	knownProviders := []catwalk.Provider{
 339		{
 340			ID:          "openai",
 341			APIKey:      "$OPENAI_API_KEY",
 342			APIEndpoint: "https://api.openai.com/v1",
 343			Models: []catwalk.Model{{
 344				ID: "test-model",
 345			}},
 346		},
 347	}
 348
 349	cfg := &Config{}
 350	cfg.setDefaults("/tmp", "")
 351	env := env.NewFromMap(map[string]string{
 352		"OPENAI_API_KEY": "test-key",
 353	})
 354	resolver := NewEnvironmentVariableResolver(env)
 355	err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 356	require.NoError(t, err)
 357	require.Equal(t, len(cfg.Providers), 1)
 358
 359	// Provider ID should be set
 360	pc, _ := cfg.Providers["openai"]
 361	require.Equal(t, "openai", pc.ID)
 362}
 363
 364func TestConfig_EnabledProviders(t *testing.T) {
 365	t.Run("all providers enabled", func(t *testing.T) {
 366		cfg := &Config{
 367			Providers: map[string]ProviderConfig{
 368				"openai": {
 369					ID:      "openai",
 370					APIKey:  "key1",
 371					Disable: false,
 372				},
 373				"anthropic": {
 374					ID:      "anthropic",
 375					APIKey:  "key2",
 376					Disable: false,
 377				},
 378			},
 379		}
 380
 381		enabled := cfg.EnabledProviders()
 382		require.Len(t, enabled, 2)
 383	})
 384
 385	t.Run("some providers disabled", func(t *testing.T) {
 386		cfg := &Config{
 387			Providers: map[string]ProviderConfig{
 388				"openai": {
 389					ID:      "openai",
 390					APIKey:  "key1",
 391					Disable: false,
 392				},
 393				"anthropic": {
 394					ID:      "anthropic",
 395					APIKey:  "key2",
 396					Disable: true,
 397				},
 398			},
 399		}
 400
 401		enabled := cfg.EnabledProviders()
 402		require.Len(t, enabled, 1)
 403		require.Equal(t, "openai", enabled[0].ID)
 404	})
 405
 406	t.Run("empty providers map", func(t *testing.T) {
 407		cfg := &Config{
 408			Providers: make(map[string]ProviderConfig),
 409		}
 410
 411		enabled := cfg.EnabledProviders()
 412		require.Len(t, enabled, 0)
 413	})
 414}
 415
 416func TestConfig_IsConfigured(t *testing.T) {
 417	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
 418		cfg := &Config{
 419			Providers: map[string]ProviderConfig{
 420				"openai": {
 421					ID:      "openai",
 422					APIKey:  "key1",
 423					Disable: false,
 424				},
 425			},
 426		}
 427
 428		require.True(t, cfg.IsConfigured())
 429	})
 430
 431	t.Run("returns false when no providers are configured", func(t *testing.T) {
 432		cfg := &Config{
 433			Providers: make(map[string]ProviderConfig),
 434		}
 435
 436		require.False(t, cfg.IsConfigured())
 437	})
 438
 439	t.Run("returns false when all providers are disabled", func(t *testing.T) {
 440		cfg := &Config{
 441			Providers: map[string]ProviderConfig{
 442				"openai": {
 443					ID:      "openai",
 444					APIKey:  "key1",
 445					Disable: true,
 446				},
 447				"anthropic": {
 448					ID:      "anthropic",
 449					APIKey:  "key2",
 450					Disable: true,
 451				},
 452			},
 453		}
 454
 455		require.False(t, cfg.IsConfigured())
 456	})
 457}
 458
 459func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
 460	cfg := &Config{
 461		Options: &Options{
 462			DisabledTools: []string{},
 463		},
 464	}
 465
 466	svc := serviceFor(cfg)
 467	svc.SetupAgents()
 468	coderAgent, ok := svc.Agents()[AgentCoder]
 469	require.True(t, ok)
 470	assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
 471
 472	taskAgent, ok := svc.Agents()[AgentTask]
 473	require.True(t, ok)
 474	assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 475}
 476
 477func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
 478	cfg := &Config{
 479		Options: &Options{
 480			DisabledTools: []string{
 481				"edit",
 482				"download",
 483				"grep",
 484			},
 485		},
 486	}
 487
 488	svc := serviceFor(cfg)
 489	svc.SetupAgents()
 490	coderAgent, ok := svc.Agents()[AgentCoder]
 491	require.True(t, ok)
 492
 493	assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "multiedit", "lsp_diagnostics", "lsp_references", "lsp_restart", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "todos", "view", "write"}, coderAgent.AllowedTools)
 494
 495	taskAgent, ok := svc.Agents()[AgentTask]
 496	require.True(t, ok)
 497	assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
 498}
 499
 500func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
 501	cfg := &Config{
 502		Options: &Options{
 503			DisabledTools: []string{
 504				"glob",
 505				"grep",
 506				"ls",
 507				"sourcegraph",
 508				"view",
 509			},
 510		},
 511	}
 512
 513	svc := serviceFor(cfg)
 514	svc.SetupAgents()
 515	coderAgent, ok := svc.Agents()[AgentCoder]
 516	require.True(t, ok)
 517	assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "lsp_restart", "fetch", "agentic_fetch", "todos", "write"}, coderAgent.AllowedTools)
 518
 519	taskAgent, ok := svc.Agents()[AgentTask]
 520	require.True(t, ok)
 521	assert.Equal(t, []string{}, taskAgent.AllowedTools)
 522}
 523
 524func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
 525	knownProviders := []catwalk.Provider{
 526		{
 527			ID:          "openai",
 528			APIKey:      "$OPENAI_API_KEY",
 529			APIEndpoint: "https://api.openai.com/v1",
 530			Models: []catwalk.Model{{
 531				ID: "test-model",
 532			}},
 533		},
 534	}
 535
 536	cfg := &Config{
 537		Providers: map[string]ProviderConfig{
 538			"openai": {
 539				Disable: true,
 540			},
 541		},
 542	}
 543	cfg.setDefaults("/tmp", "")
 544
 545	env := env.NewFromMap(map[string]string{
 546		"OPENAI_API_KEY": "test-key",
 547	})
 548	resolver := NewEnvironmentVariableResolver(env)
 549	err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 550	require.NoError(t, err)
 551
 552	require.Equal(t, len(cfg.Providers), 1)
 553	prov, exists := cfg.Providers["openai"]
 554	require.True(t, exists)
 555	require.True(t, prov.Disable)
 556}
 557
 558func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 559	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
 560		cfg := &Config{
 561			Providers: map[string]ProviderConfig{
 562				"custom": {
 563					BaseURL: "https://api.custom.com/v1",
 564					Models: []catwalk.Model{{
 565						ID: "test-model",
 566					}},
 567				},
 568				"openai": {
 569					APIKey: "$MISSING",
 570				},
 571			},
 572		}
 573		cfg.setDefaults("/tmp", "")
 574
 575		env := env.NewFromMap(map[string]string{})
 576		resolver := NewEnvironmentVariableResolver(env)
 577		err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
 578		require.NoError(t, err)
 579
 580		require.Equal(t, len(cfg.Providers), 1)
 581		_, exists := cfg.Providers["custom"]
 582		require.True(t, exists)
 583	})
 584
 585	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
 586		cfg := &Config{
 587			Providers: map[string]ProviderConfig{
 588				"custom": {
 589					APIKey: "test-key",
 590					Models: []catwalk.Model{{
 591						ID: "test-model",
 592					}},
 593				},
 594			},
 595		}
 596		cfg.setDefaults("/tmp", "")
 597
 598		env := env.NewFromMap(map[string]string{})
 599		resolver := NewEnvironmentVariableResolver(env)
 600		err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
 601		require.NoError(t, err)
 602
 603		require.Equal(t, len(cfg.Providers), 0)
 604		_, exists := cfg.Providers["custom"]
 605		require.False(t, exists)
 606	})
 607
 608	t.Run("custom provider with no models is removed", func(t *testing.T) {
 609		cfg := &Config{
 610			Providers: map[string]ProviderConfig{
 611				"custom": {
 612					APIKey:  "test-key",
 613					BaseURL: "https://api.custom.com/v1",
 614					Models:  []catwalk.Model{},
 615				},
 616			},
 617		}
 618		cfg.setDefaults("/tmp", "")
 619
 620		env := env.NewFromMap(map[string]string{})
 621		resolver := NewEnvironmentVariableResolver(env)
 622		err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
 623		require.NoError(t, err)
 624
 625		require.Equal(t, len(cfg.Providers), 0)
 626		_, exists := cfg.Providers["custom"]
 627		require.False(t, exists)
 628	})
 629
 630	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
 631		cfg := &Config{
 632			Providers: map[string]ProviderConfig{
 633				"custom": {
 634					APIKey:  "test-key",
 635					BaseURL: "https://api.custom.com/v1",
 636					Type:    "unsupported",
 637					Models: []catwalk.Model{{
 638						ID: "test-model",
 639					}},
 640				},
 641			},
 642		}
 643		cfg.setDefaults("/tmp", "")
 644
 645		env := env.NewFromMap(map[string]string{})
 646		resolver := NewEnvironmentVariableResolver(env)
 647		err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
 648		require.NoError(t, err)
 649
 650		require.Equal(t, len(cfg.Providers), 0)
 651		_, exists := cfg.Providers["custom"]
 652		require.False(t, exists)
 653	})
 654
 655	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
 656		cfg := &Config{
 657			Providers: map[string]ProviderConfig{
 658				"custom": {
 659					APIKey:  "test-key",
 660					BaseURL: "https://api.custom.com/v1",
 661					Type:    catwalk.TypeOpenAI,
 662					Models: []catwalk.Model{{
 663						ID: "test-model",
 664					}},
 665				},
 666			},
 667		}
 668		cfg.setDefaults("/tmp", "")
 669
 670		env := env.NewFromMap(map[string]string{})
 671		resolver := NewEnvironmentVariableResolver(env)
 672		err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
 673		require.NoError(t, err)
 674
 675		require.Equal(t, len(cfg.Providers), 1)
 676		customProvider, exists := cfg.Providers["custom"]
 677		require.True(t, exists)
 678		require.Equal(t, "custom", customProvider.ID)
 679		require.Equal(t, "test-key", customProvider.APIKey)
 680		require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
 681	})
 682
 683	t.Run("custom anthropic provider is supported", func(t *testing.T) {
 684		cfg := &Config{
 685			Providers: map[string]ProviderConfig{
 686				"custom-anthropic": {
 687					APIKey:  "test-key",
 688					BaseURL: "https://api.anthropic.com/v1",
 689					Type:    catwalk.TypeAnthropic,
 690					Models: []catwalk.Model{{
 691						ID: "claude-3-sonnet",
 692					}},
 693				},
 694			},
 695		}
 696		cfg.setDefaults("/tmp", "")
 697
 698		env := env.NewFromMap(map[string]string{})
 699		resolver := NewEnvironmentVariableResolver(env)
 700		err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
 701		require.NoError(t, err)
 702
 703		require.Equal(t, len(cfg.Providers), 1)
 704		customProvider, exists := cfg.Providers["custom-anthropic"]
 705		require.True(t, exists)
 706		require.Equal(t, "custom-anthropic", customProvider.ID)
 707		require.Equal(t, "test-key", customProvider.APIKey)
 708		require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
 709		require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
 710	})
 711
 712	t.Run("disabled custom provider is removed", func(t *testing.T) {
 713		cfg := &Config{
 714			Providers: map[string]ProviderConfig{
 715				"custom": {
 716					APIKey:  "test-key",
 717					BaseURL: "https://api.custom.com/v1",
 718					Type:    catwalk.TypeOpenAI,
 719					Disable: true,
 720					Models: []catwalk.Model{{
 721						ID: "test-model",
 722					}},
 723				},
 724			},
 725		}
 726		cfg.setDefaults("/tmp", "")
 727
 728		env := env.NewFromMap(map[string]string{})
 729		resolver := NewEnvironmentVariableResolver(env)
 730		err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
 731		require.NoError(t, err)
 732
 733		require.Equal(t, len(cfg.Providers), 0)
 734		_, exists := cfg.Providers["custom"]
 735		require.False(t, exists)
 736	})
 737}
 738
 739func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 740	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
 741		knownProviders := []catwalk.Provider{
 742			{
 743				ID:          catwalk.InferenceProviderVertexAI,
 744				APIKey:      "",
 745				APIEndpoint: "",
 746				Models: []catwalk.Model{{
 747					ID: "gemini-pro",
 748				}},
 749			},
 750		}
 751
 752		cfg := &Config{
 753			Providers: map[string]ProviderConfig{
 754				"vertexai": {
 755					BaseURL: "custom-url",
 756				},
 757			},
 758		}
 759		cfg.setDefaults("/tmp", "")
 760
 761		env := env.NewFromMap(map[string]string{
 762			"GOOGLE_GENAI_USE_VERTEXAI": "false",
 763		})
 764		resolver := NewEnvironmentVariableResolver(env)
 765		err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 766		require.NoError(t, err)
 767
 768		require.Equal(t, len(cfg.Providers), 0)
 769		_, exists := cfg.Providers["vertexai"]
 770		require.False(t, exists)
 771	})
 772
 773	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
 774		knownProviders := []catwalk.Provider{
 775			{
 776				ID:          catwalk.InferenceProviderBedrock,
 777				APIKey:      "",
 778				APIEndpoint: "",
 779				Models: []catwalk.Model{{
 780					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
 781				}},
 782			},
 783		}
 784
 785		cfg := &Config{
 786			Providers: map[string]ProviderConfig{
 787				"bedrock": {
 788					BaseURL: "custom-url",
 789				},
 790			},
 791		}
 792		cfg.setDefaults("/tmp", "")
 793
 794		env := env.NewFromMap(map[string]string{})
 795		resolver := NewEnvironmentVariableResolver(env)
 796		err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 797		require.NoError(t, err)
 798
 799		require.Equal(t, len(cfg.Providers), 0)
 800		_, exists := cfg.Providers["bedrock"]
 801		require.False(t, exists)
 802	})
 803
 804	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
 805		knownProviders := []catwalk.Provider{
 806			{
 807				ID:          "openai",
 808				APIKey:      "$MISSING_API_KEY",
 809				APIEndpoint: "https://api.openai.com/v1",
 810				Models: []catwalk.Model{{
 811					ID: "test-model",
 812				}},
 813			},
 814		}
 815
 816		cfg := &Config{
 817			Providers: map[string]ProviderConfig{
 818				"openai": {
 819					BaseURL: "custom-url",
 820				},
 821			},
 822		}
 823		cfg.setDefaults("/tmp", "")
 824
 825		env := env.NewFromMap(map[string]string{})
 826		resolver := NewEnvironmentVariableResolver(env)
 827		err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 828		require.NoError(t, err)
 829
 830		require.Equal(t, len(cfg.Providers), 0)
 831		_, exists := cfg.Providers["openai"]
 832		require.False(t, exists)
 833	})
 834
 835	t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
 836		knownProviders := []catwalk.Provider{
 837			{
 838				ID:          "openai",
 839				APIKey:      "$OPENAI_API_KEY",
 840				APIEndpoint: "$MISSING_ENDPOINT",
 841				Models: []catwalk.Model{{
 842					ID: "test-model",
 843				}},
 844			},
 845		}
 846
 847		cfg := &Config{
 848			Providers: map[string]ProviderConfig{
 849				"openai": {
 850					APIKey: "test-key",
 851				},
 852			},
 853		}
 854		cfg.setDefaults("/tmp", "")
 855
 856		env := env.NewFromMap(map[string]string{
 857			"OPENAI_API_KEY": "test-key",
 858		})
 859		resolver := NewEnvironmentVariableResolver(env)
 860		err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 861		require.NoError(t, err)
 862
 863		require.Equal(t, len(cfg.Providers), 1)
 864		_, exists := cfg.Providers["openai"]
 865		require.True(t, exists)
 866	})
 867}
 868
 869func TestConfig_defaultModelSelection(t *testing.T) {
 870	t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
 871		knownProviders := []catwalk.Provider{
 872			{
 873				ID:                  "openai",
 874				APIKey:              "abc",
 875				DefaultLargeModelID: "large-model",
 876				DefaultSmallModelID: "small-model",
 877				Models: []catwalk.Model{
 878					{
 879						ID:               "large-model",
 880						DefaultMaxTokens: 1000,
 881					},
 882					{
 883						ID:               "small-model",
 884						DefaultMaxTokens: 500,
 885					},
 886				},
 887			},
 888		}
 889
 890		cfg := &Config{}
 891		cfg.setDefaults("/tmp", "")
 892		env := env.NewFromMap(map[string]string{})
 893		resolver := NewEnvironmentVariableResolver(env)
 894		err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 895		require.NoError(t, err)
 896
 897		large, small, err := serviceFor(cfg).defaultModelSelection(knownProviders)
 898		require.NoError(t, err)
 899		require.Equal(t, "large-model", large.Model)
 900		require.Equal(t, "openai", large.Provider)
 901		require.Equal(t, int64(1000), large.MaxTokens)
 902		require.Equal(t, "small-model", small.Model)
 903		require.Equal(t, "openai", small.Provider)
 904		require.Equal(t, int64(500), small.MaxTokens)
 905	})
 906	t.Run("should error if no providers configured", func(t *testing.T) {
 907		knownProviders := []catwalk.Provider{
 908			{
 909				ID:                  "openai",
 910				APIKey:              "$MISSING_KEY",
 911				DefaultLargeModelID: "large-model",
 912				DefaultSmallModelID: "small-model",
 913				Models: []catwalk.Model{
 914					{
 915						ID:               "large-model",
 916						DefaultMaxTokens: 1000,
 917					},
 918					{
 919						ID:               "small-model",
 920						DefaultMaxTokens: 500,
 921					},
 922				},
 923			},
 924		}
 925
 926		cfg := &Config{}
 927		cfg.setDefaults("/tmp", "")
 928		env := env.NewFromMap(map[string]string{})
 929		resolver := NewEnvironmentVariableResolver(env)
 930		err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 931		require.NoError(t, err)
 932
 933		_, _, err = serviceFor(cfg).defaultModelSelection(knownProviders)
 934		require.Error(t, err)
 935	})
 936	t.Run("should error if model is missing", func(t *testing.T) {
 937		knownProviders := []catwalk.Provider{
 938			{
 939				ID:                  "openai",
 940				APIKey:              "abc",
 941				DefaultLargeModelID: "large-model",
 942				DefaultSmallModelID: "small-model",
 943				Models: []catwalk.Model{
 944					{
 945						ID:               "not-large-model",
 946						DefaultMaxTokens: 1000,
 947					},
 948					{
 949						ID:               "small-model",
 950						DefaultMaxTokens: 500,
 951					},
 952				},
 953			},
 954		}
 955
 956		cfg := &Config{}
 957		cfg.setDefaults("/tmp", "")
 958		env := env.NewFromMap(map[string]string{})
 959		resolver := NewEnvironmentVariableResolver(env)
 960		err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
 961		require.NoError(t, err)
 962		_, _, err = serviceFor(cfg).defaultModelSelection(knownProviders)
 963		require.Error(t, err)
 964	})
 965
 966	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
 967		knownProviders := []catwalk.Provider{
 968			{
 969				ID:                  "openai",
 970				APIKey:              "$MISSING", // will not be included in the config
 971				DefaultLargeModelID: "large-model",
 972				DefaultSmallModelID: "small-model",
 973				Models: []catwalk.Model{
 974					{
 975						ID:               "not-large-model",
 976						DefaultMaxTokens: 1000,
 977					},
 978					{
 979						ID:               "small-model",
 980						DefaultMaxTokens: 500,
 981					},
 982				},
 983			},
 984		}
 985
 986		cfg := &Config{
 987			Providers: map[string]ProviderConfig{
 988				"custom": {
 989					APIKey:  "test-key",
 990					BaseURL: "https://api.custom.com/v1",
 991					Models: []catwalk.Model{
 992						{
 993							ID:               "model",
 994							DefaultMaxTokens: 600,
 995						},
 996					},
 997				},
 998			},
 999		}
1000		cfg.setDefaults("/tmp", "")
1001		env := env.NewFromMap(map[string]string{})
1002		resolver := NewEnvironmentVariableResolver(env)
1003		err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
1004		require.NoError(t, err)
1005		large, small, err := serviceFor(cfg).defaultModelSelection(knownProviders)
1006		require.NoError(t, err)
1007		require.Equal(t, "model", large.Model)
1008		require.Equal(t, "custom", large.Provider)
1009		require.Equal(t, int64(600), large.MaxTokens)
1010		require.Equal(t, "model", small.Model)
1011		require.Equal(t, "custom", small.Provider)
1012		require.Equal(t, int64(600), small.MaxTokens)
1013	})
1014
1015	t.Run("should fail if no model configured", func(t *testing.T) {
1016		knownProviders := []catwalk.Provider{
1017			{
1018				ID:                  "openai",
1019				APIKey:              "$MISSING", // will not be included in the config
1020				DefaultLargeModelID: "large-model",
1021				DefaultSmallModelID: "small-model",
1022				Models: []catwalk.Model{
1023					{
1024						ID:               "not-large-model",
1025						DefaultMaxTokens: 1000,
1026					},
1027					{
1028						ID:               "small-model",
1029						DefaultMaxTokens: 500,
1030					},
1031				},
1032			},
1033		}
1034
1035		cfg := &Config{
1036			Providers: map[string]ProviderConfig{
1037				"custom": {
1038					APIKey:  "test-key",
1039					BaseURL: "https://api.custom.com/v1",
1040					Models:  []catwalk.Model{},
1041				},
1042			},
1043		}
1044		cfg.setDefaults("/tmp", "")
1045		env := env.NewFromMap(map[string]string{})
1046		resolver := NewEnvironmentVariableResolver(env)
1047		err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
1048		require.NoError(t, err)
1049		_, _, err = serviceFor(cfg).defaultModelSelection(knownProviders)
1050		require.Error(t, err)
1051	})
1052	t.Run("should use the default provider first", func(t *testing.T) {
1053		knownProviders := []catwalk.Provider{
1054			{
1055				ID:                  "openai",
1056				APIKey:              "set",
1057				DefaultLargeModelID: "large-model",
1058				DefaultSmallModelID: "small-model",
1059				Models: []catwalk.Model{
1060					{
1061						ID:               "large-model",
1062						DefaultMaxTokens: 1000,
1063					},
1064					{
1065						ID:               "small-model",
1066						DefaultMaxTokens: 500,
1067					},
1068				},
1069			},
1070		}
1071
1072		cfg := &Config{
1073			Providers: map[string]ProviderConfig{
1074				"custom": {
1075					APIKey:  "test-key",
1076					BaseURL: "https://api.custom.com/v1",
1077					Models: []catwalk.Model{
1078						{
1079							ID:               "large-model",
1080							DefaultMaxTokens: 1000,
1081						},
1082					},
1083				},
1084			},
1085		}
1086		cfg.setDefaults("/tmp", "")
1087		env := env.NewFromMap(map[string]string{})
1088		resolver := NewEnvironmentVariableResolver(env)
1089		err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
1090		require.NoError(t, err)
1091		large, small, err := serviceFor(cfg).defaultModelSelection(knownProviders)
1092		require.NoError(t, err)
1093		require.Equal(t, "large-model", large.Model)
1094		require.Equal(t, "openai", large.Provider)
1095		require.Equal(t, int64(1000), large.MaxTokens)
1096		require.Equal(t, "small-model", small.Model)
1097		require.Equal(t, "openai", small.Provider)
1098		require.Equal(t, int64(500), small.MaxTokens)
1099	})
1100}
1101
// TestConfig_configureProvidersDisableDefaultProviders covers how
// configureProviders treats known (catwalk) providers and user-defined
// providers depending on Options.DisableDefaultProviders.
//
// Provider counts are asserted with require.Len / require.Empty rather than
// comparing len() results, so a failure prints the offending map instead of
// just two integers.
func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
	t.Run("when enabled, ignores all default providers and requires full specification", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          "openai",
				APIKey:      "$OPENAI_API_KEY",
				APIEndpoint: "https://api.openai.com/v1",
				Models: []catwalk.Model{{
					ID: "gpt-4",
				}},
			},
		}

		// User references openai but doesn't fully specify it (no base_url, no
		// models). This should be rejected because disable_default_providers
		// treats all providers as custom.
		cfg := &Config{
			Options: &Options{
				DisableDefaultProviders: true,
			},
			Providers: map[string]ProviderConfig{
				"openai": {
					APIKey: "$OPENAI_API_KEY",
				},
			},
		}
		cfg.setDefaults("/tmp", "")

		testEnv := env.NewFromMap(map[string]string{
			"OPENAI_API_KEY": "test-key",
		})
		resolver := NewEnvironmentVariableResolver(testEnv)
		err := serviceFor(cfg).configureProviders(testEnv, resolver, knownProviders)
		require.NoError(t, err)

		// openai should NOT be present because it lacks base_url and models.
		require.Empty(t, cfg.Providers)
		_, exists := cfg.Providers["openai"]
		require.False(t, exists, "openai should not be present without full specification")
	})

	t.Run("when enabled, fully specified providers work", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          "openai",
				APIKey:      "$OPENAI_API_KEY",
				APIEndpoint: "https://api.openai.com/v1",
				Models: []catwalk.Model{{
					ID: "gpt-4",
				}},
			},
		}

		// User fully specifies their provider.
		cfg := &Config{
			Options: &Options{
				DisableDefaultProviders: true,
			},
			Providers: map[string]ProviderConfig{
				"my-llm": {
					APIKey:  "$MY_API_KEY",
					BaseURL: "https://my-llm.example.com/v1",
					Models: []catwalk.Model{{
						ID: "my-model",
					}},
				},
			},
		}
		cfg.setDefaults("/tmp", "")

		testEnv := env.NewFromMap(map[string]string{
			"MY_API_KEY":     "test-key",
			"OPENAI_API_KEY": "test-key",
		})
		resolver := NewEnvironmentVariableResolver(testEnv)
		err := serviceFor(cfg).configureProviders(testEnv, resolver, knownProviders)
		require.NoError(t, err)

		// Only the fully specified provider should be present.
		require.Len(t, cfg.Providers, 1)
		provider, exists := cfg.Providers["my-llm"]
		require.True(t, exists, "my-llm should be present")
		require.Equal(t, "https://my-llm.example.com/v1", provider.BaseURL)
		require.Len(t, provider.Models, 1)

		// Default openai should NOT be present.
		_, exists = cfg.Providers["openai"]
		require.False(t, exists, "openai should not be present")
	})

	t.Run("when disabled, includes all known providers with valid credentials", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          "openai",
				APIKey:      "$OPENAI_API_KEY",
				APIEndpoint: "https://api.openai.com/v1",
				Models: []catwalk.Model{{
					ID: "gpt-4",
				}},
			},
			{
				ID:          "anthropic",
				APIKey:      "$ANTHROPIC_API_KEY",
				APIEndpoint: "https://api.anthropic.com/v1",
				Models: []catwalk.Model{{
					ID: "claude-3",
				}},
			},
		}

		// User only configures openai, both API keys are available, but the
		// option is disabled.
		cfg := &Config{
			Options: &Options{
				DisableDefaultProviders: false,
			},
			Providers: map[string]ProviderConfig{
				"openai": {
					APIKey: "$OPENAI_API_KEY",
				},
			},
		}
		cfg.setDefaults("/tmp", "")

		testEnv := env.NewFromMap(map[string]string{
			"OPENAI_API_KEY":    "test-key",
			"ANTHROPIC_API_KEY": "test-key",
		})
		resolver := NewEnvironmentVariableResolver(testEnv)
		err := serviceFor(cfg).configureProviders(testEnv, resolver, knownProviders)
		require.NoError(t, err)

		// Both providers should be present.
		require.Len(t, cfg.Providers, 2)
		_, exists := cfg.Providers["openai"]
		require.True(t, exists, "openai should be present")
		_, exists = cfg.Providers["anthropic"]
		require.True(t, exists, "anthropic should be present")
	})

	t.Run("when enabled, provider missing models is rejected", func(t *testing.T) {
		cfg := &Config{
			Options: &Options{
				DisableDefaultProviders: true,
			},
			Providers: map[string]ProviderConfig{
				"my-llm": {
					APIKey:  "test-key",
					BaseURL: "https://my-llm.example.com/v1",
					Models:  []catwalk.Model{}, // No models.
				},
			},
		}
		cfg.setDefaults("/tmp", "")

		testEnv := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(testEnv)
		err := serviceFor(cfg).configureProviders(testEnv, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		// Provider should be rejected for missing models.
		require.Empty(t, cfg.Providers)
	})

	t.Run("when enabled, provider missing base_url is rejected", func(t *testing.T) {
		cfg := &Config{
			Options: &Options{
				DisableDefaultProviders: true,
			},
			Providers: map[string]ProviderConfig{
				"my-llm": {
					APIKey: "test-key",
					Models: []catwalk.Model{{ID: "model"}},
					// No BaseURL.
				},
			},
		}
		cfg.setDefaults("/tmp", "")

		testEnv := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(testEnv)
		err := serviceFor(cfg).configureProviders(testEnv, resolver, []catwalk.Provider{})
		require.NoError(t, err)

		// Provider should be rejected for missing base_url.
		require.Empty(t, cfg.Providers)
	})
}
1290
1291func TestConfig_setDefaultsDisableDefaultProvidersEnvVar(t *testing.T) {
1292	t.Run("sets option from environment variable", func(t *testing.T) {
1293		t.Setenv("CRUSH_DISABLE_DEFAULT_PROVIDERS", "true")
1294
1295		cfg := &Config{}
1296		cfg.setDefaults("/tmp", "")
1297
1298		require.True(t, cfg.Options.DisableDefaultProviders)
1299	})
1300
1301	t.Run("does not override when env var is not set", func(t *testing.T) {
1302		cfg := &Config{
1303			Options: &Options{
1304				DisableDefaultProviders: true,
1305			},
1306		}
1307		cfg.setDefaults("/tmp", "")
1308
1309		require.True(t, cfg.Options.DisableDefaultProviders)
1310	})
1311}
1312
// TestConfig_configureSelectedModels verifies that user-supplied entries in
// Config.Models override the provider defaults field by field, while model
// types the user did not configure fall back to the provider's defaults.
//
// The setup keys Config.Models with the SelectedModelType constants, the same
// ones the assertions use for lookups. Raw "large"/"small" literals would let
// setup and assertions drift apart silently.
func TestConfig_configureSelectedModels(t *testing.T) {
	t.Run("should override defaults", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "larger-model",
						DefaultMaxTokens: 2000,
					},
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				SelectedModelTypeLarge: {
					Model: "larger-model",
				},
			},
		}
		cfg.setDefaults("/tmp", "")
		testEnv := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(testEnv)
		err := serviceFor(cfg).configureProviders(testEnv, resolver, knownProviders)
		require.NoError(t, err)

		err = serviceFor(cfg).configureSelectedModels(knownProviders)
		require.NoError(t, err)
		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		require.Equal(t, "larger-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(2000), large.MaxTokens)
		require.Equal(t, "small-model", small.Model)
		require.Equal(t, "openai", small.Provider)
		require.Equal(t, int64(500), small.MaxTokens)
	})
	t.Run("should be possible to use multiple providers", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
			{
				ID:                  "anthropic",
				APIKey:              "abc",
				DefaultLargeModelID: "a-large-model",
				DefaultSmallModelID: "a-small-model",
				Models: []catwalk.Model{
					{
						ID:               "a-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "a-small-model",
						DefaultMaxTokens: 200,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				SelectedModelTypeSmall: {
					Model:     "a-small-model",
					Provider:  "anthropic",
					MaxTokens: 300,
				},
			},
		}
		cfg.setDefaults("/tmp", "")
		testEnv := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(testEnv)
		err := serviceFor(cfg).configureProviders(testEnv, resolver, knownProviders)
		require.NoError(t, err)

		err = serviceFor(cfg).configureSelectedModels(knownProviders)
		require.NoError(t, err)
		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(1000), large.MaxTokens)
		require.Equal(t, "a-small-model", small.Model)
		require.Equal(t, "anthropic", small.Provider)
		require.Equal(t, int64(300), small.MaxTokens)
	})

	t.Run("should override the max tokens only", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				SelectedModelTypeLarge: {
					MaxTokens: 100,
				},
			},
		}
		cfg.setDefaults("/tmp", "")
		testEnv := env.NewFromMap(map[string]string{})
		resolver := NewEnvironmentVariableResolver(testEnv)
		err := serviceFor(cfg).configureProviders(testEnv, resolver, knownProviders)
		require.NoError(t, err)

		err = serviceFor(cfg).configureSelectedModels(knownProviders)
		require.NoError(t, err)
		large := cfg.Models[SelectedModelTypeLarge]
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(100), large.MaxTokens)
	})
}