config.go

  1package config
  2
  3import (
  4	"cmp"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"maps"
 10	"net/http"
 11	"net/url"
 12	"regexp"
 13	"slices"
 14	"strings"
 15	"time"
 16
 17	"charm.land/catwalk/pkg/catwalk"
 18	"github.com/charmbracelet/crush/internal/csync"
 19	"github.com/charmbracelet/crush/internal/env"
 20	"github.com/charmbracelet/crush/internal/oauth"
 21	"github.com/charmbracelet/crush/internal/oauth/copilot"
 22	"github.com/invopop/jsonschema"
 23)
 24
// Package-level defaults:
//   - appName is the application name.
//   - defaultDataDirectory matches the schema default of
//     options.data_directory.
//   - defaultInitializeAs matches the schema default of
//     options.initialize_as.
const (
	appName              = "crush"
	defaultDataDirectory = ".crush"
	defaultInitializeAs  = "AGENTS.md"
)

// defaultContextPaths lists well-known instruction/context files and
// directories from various AI coding tools. Presumably used as the default
// for Options.ContextPaths when the user sets none — confirm at the load
// site. Case variants are listed explicitly because lookups may be
// case-sensitive.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
	"AGENTS.md",
	"agents.md",
	"Agents.md",
}
 49
 50type SelectedModelType string
 51
 52// String returns the string representation of the [SelectedModelType].
 53func (s SelectedModelType) String() string {
 54	return string(s)
 55}
 56
 57const (
 58	SelectedModelTypeLarge SelectedModelType = "large"
 59	SelectedModelTypeSmall SelectedModelType = "small"
 60)
 61
// IDs of the built-in agents; SetupAgents uses them as keys of
// Config.Agents and as each Agent's ID.
const (
	AgentCoder string = "coder"
	AgentTask  string = "task"
)
 66
// SelectedModel is the provider/model choice for one model slot (see
// Config.Models and SelectedModelType), plus optional overrides of that
// model's defaults.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model" jsonschema:"required,description=The model ID as used by the provider API,example=gpt-4o"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider" jsonschema:"required,description=The model provider ID that matches a key in the providers config,example=openai"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"description=Reasoning effort level for OpenAI models that support it,enum=low,enum=medium,enum=high"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty" jsonschema:"description=Enable thinking mode for Anthropic models that support reasoning"`

	// Overrides the default model configuration. The pointer-typed fields
	// distinguish "not set" (nil) from an explicit zero value.
	MaxTokens        int64    `json:"max_tokens,omitempty" jsonschema:"description=Maximum number of tokens for model responses,maximum=200000,example=4096"`
	Temperature      *float64 `json:"temperature,omitempty" jsonschema:"description=Sampling temperature,minimum=0,maximum=1,example=0.7"`
	TopP             *float64 `json:"top_p,omitempty" jsonschema:"description=Top-p (nucleus) sampling parameter,minimum=0,maximum=1,example=0.9"`
	TopK             *int64   `json:"top_k,omitempty" jsonschema:"description=Top-k sampling parameter"`
	FrequencyPenalty *float64 `json:"frequency_penalty,omitempty" jsonschema:"description=Frequency penalty to reduce repetition"`
	PresencePenalty  *float64 `json:"presence_penalty,omitempty" jsonschema:"description=Presence penalty to increase topic diversity"`

	// Override provider specific options.
	ProviderOptions map[string]any `json:"provider_options,omitempty" jsonschema:"description=Additional provider-specific options for the model"`
}
 92
// ProviderConfig is the user-facing configuration of one AI provider:
// endpoint, credentials, request customizations and the models it offers.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty" jsonschema:"description=Unique identifier for the provider,example=openai"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty" jsonschema:"description=Human-readable name for the provider,example=OpenAI"`
	// The provider's API endpoint.
	BaseURL string `json:"base_url,omitempty" jsonschema:"description=Base URL for the provider's API,format=uri,example=https://api.openai.com/v1"`
	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
	Type catwalk.Type `json:"type,omitempty" jsonschema:"description=Provider type that determines the API format,enum=openai,enum=openai-compat,enum=anthropic,enum=gemini,enum=azure,enum=vertexai,default=openai"`
	// The provider's API key. May be a variable reference such as
	// $OPENAI_API_KEY (see the schema example); TestConnection resolves it.
	APIKey string `json:"api_key,omitempty" jsonschema:"description=API key for authentication with the provider,example=$OPENAI_API_KEY"`
	// The original API key template before resolution (for re-resolution on auth errors).
	APIKeyTemplate string `json:"-"`
	// OAuthToken for providers that use OAuth2 authentication.
	OAuthToken *oauth.Token `json:"oauth,omitempty" jsonschema:"description=OAuth2 token for authentication with the provider"`
	// Marks the provider as disabled.
	Disable bool `json:"disable,omitempty" jsonschema:"description=Whether this provider is disabled,default=false"`

	// Custom system prompt prefix.
	SystemPromptPrefix string `json:"system_prompt_prefix,omitempty" jsonschema:"description=Custom prefix to add to system prompts for this provider"`

	// Extra headers to send with each request to the provider. May be nil.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty" jsonschema:"description=Additional HTTP headers to send with requests"`
	// Extra fields to include in request bodies (openai-compatible
	// providers only, per the schema description).
	ExtraBody map[string]any `json:"extra_body,omitempty" jsonschema:"description=Additional fields to include in request bodies, only works with openai-compatible providers"`

	// Provider-specific options that apply to this provider as a whole.
	ProviderOptions map[string]any `json:"provider_options,omitempty" jsonschema:"description=Additional provider-specific options for this provider"`

	// Used to pass extra parameters to the provider.
	ExtraParams map[string]string `json:"-"`

	// The provider models
	Models []catwalk.Model `json:"models,omitempty" jsonschema:"description=List of models available from this provider"`
}
127
128// ToProvider converts the [ProviderConfig] to a [catwalk.Provider].
129func (c *ProviderConfig) ToProvider() catwalk.Provider {
130	// Convert config provider to provider.Provider format
131	provider := catwalk.Provider{
132		Name:   c.Name,
133		ID:     catwalk.InferenceProvider(c.ID),
134		Models: make([]catwalk.Model, len(c.Models)),
135	}
136
137	// Convert models
138	for i, model := range c.Models {
139		provider.Models[i] = catwalk.Model{
140			ID:                     model.ID,
141			Name:                   model.Name,
142			CostPer1MIn:            model.CostPer1MIn,
143			CostPer1MOut:           model.CostPer1MOut,
144			CostPer1MInCached:      model.CostPer1MInCached,
145			CostPer1MOutCached:     model.CostPer1MOutCached,
146			ContextWindow:          model.ContextWindow,
147			DefaultMaxTokens:       model.DefaultMaxTokens,
148			CanReason:              model.CanReason,
149			ReasoningLevels:        model.ReasoningLevels,
150			DefaultReasoningEffort: model.DefaultReasoningEffort,
151			SupportsImages:         model.SupportsImages,
152		}
153	}
154
155	return provider
156}
157
158func (c *ProviderConfig) SetupGitHubCopilot() {
159	maps.Copy(c.ExtraHeaders, copilot.Headers())
160}
161
// MCPType is the transport used to talk to an MCP server.
type MCPType string

// Supported MCP transports.
const (
	MCPStdio MCPType = "stdio"
	MCPSSE   MCPType = "sse"
	MCPHttp  MCPType = "http"
)

// MCPConfig configures one Model Context Protocol server. Command, Args
// and Env apply to stdio servers. URL and Headers apply to HTTP/SSE
// servers (per the schema descriptions). Env and Headers values may contain
// variable references; see ResolvedEnv and ResolvedHeaders.
type MCPConfig struct {
	Command       string            `json:"command,omitempty" jsonschema:"description=Command to execute for stdio MCP servers,example=npx"`
	Env           map[string]string `json:"env,omitempty" jsonschema:"description=Environment variables to set for the MCP server"`
	Args          []string          `json:"args,omitempty" jsonschema:"description=Arguments to pass to the MCP server command"`
	Type          MCPType           `json:"type" jsonschema:"required,description=Type of MCP connection,enum=stdio,enum=sse,enum=http,default=stdio"`
	URL           string            `json:"url,omitempty" jsonschema:"description=URL for HTTP or SSE MCP servers,format=uri,example=http://localhost:3000/mcp"`
	Disabled      bool              `json:"disabled,omitempty" jsonschema:"description=Whether this MCP server is disabled,default=false"`
	DisabledTools []string          `json:"disabled_tools,omitempty" jsonschema:"description=List of tools from this MCP server to disable,example=get-library-doc"`
	Timeout       int               `json:"timeout,omitempty" jsonschema:"description=Timeout in seconds for MCP server connections,default=15,example=30,example=60,example=120"`

	// TODO: maybe make it possible to get the value from the env
	Headers map[string]string `json:"headers,omitempty" jsonschema:"description=HTTP headers for HTTP/SSE MCP servers"`
}
183
// LSPConfig configures one Language Server Protocol server: how to launch
// it, which files it handles, and the options passed to it. Env values may
// contain variable references; see ResolvedEnv.
type LSPConfig struct {
	Disabled    bool              `json:"disabled,omitempty" jsonschema:"description=Whether this LSP server is disabled,default=false"`
	Command     string            `json:"command,omitempty" jsonschema:"description=Command to execute for the LSP server,example=gopls"`
	Args        []string          `json:"args,omitempty" jsonschema:"description=Arguments to pass to the LSP server command"`
	Env         map[string]string `json:"env,omitempty" jsonschema:"description=Environment variables to set to the LSP server command"`
	FileTypes   []string          `json:"filetypes,omitempty" jsonschema:"description=File types this LSP server handles,example=go,example=mod,example=rs,example=c,example=js,example=ts"`
	RootMarkers []string          `json:"root_markers,omitempty" jsonschema:"description=Files or directories that indicate the project root,example=go.mod,example=package.json,example=Cargo.toml"`
	InitOptions map[string]any    `json:"init_options,omitempty" jsonschema:"description=Initialization options passed to the LSP server during initialize request"`
	Options     map[string]any    `json:"options,omitempty" jsonschema:"description=LSP server-specific settings passed during initialization"`
	Timeout     int               `json:"timeout,omitempty" jsonschema:"description=Timeout in seconds for LSP server initialization,default=30,example=60,example=120"`
}
195
// TUIOptions holds terminal user interface preferences.
type TUIOptions struct {
	CompactMode bool   `json:"compact_mode,omitempty" jsonschema:"description=Enable compact mode for the TUI interface,default=false"`
	DiffMode    string `json:"diff_mode,omitempty" jsonschema:"description=Diff mode for the TUI interface,enum=unified,enum=split"`
	// Further TUI-related options (e.g. themes) can be added here.

	Completions Completions `json:"completions,omitzero" jsonschema:"description=Completions UI options"`
	Transparent *bool       `json:"transparent,omitempty" jsonschema:"description=Enable transparent background for the TUI interface,default=false"`
}
205
// Completions defines options for the completions UI.
//
// NOTE(review): the jsonschema descriptions below mention "the ls tool".
// They look copied from ToolLs — confirm and reword for the completions UI.
type Completions struct {
	MaxDepth *int `json:"max_depth,omitempty" jsonschema:"description=Maximum depth for the ls tool,default=0,example=10"`
	MaxItems *int `json:"max_items,omitempty" jsonschema:"description=Maximum number of items to return for the ls tool,default=1000,example=100"`
}

// Limits reports the configured maximum depth and maximum item count.
// Either value is 0 when it is not set.
func (c Completions) Limits() (depth, items int) {
	switch {
	case c.MaxDepth != nil && c.MaxItems != nil:
		return *c.MaxDepth, *c.MaxItems
	case c.MaxDepth != nil:
		return *c.MaxDepth, 0
	case c.MaxItems != nil:
		return 0, *c.MaxItems
	default:
		return 0, 0
	}
}
215
// Permissions controls which tool invocations skip the interactive
// permission prompt.
type Permissions struct {
	AllowedTools []string `json:"allowed_tools,omitempty" jsonschema:"description=List of tools that don't require permission prompts,example=bash,example=view"`
}
219
// TrailerStyle selects which attribution trailer, if any, is added to
// commits.
type TrailerStyle string

// Supported trailer styles.
const (
	TrailerStyleNone         TrailerStyle = "none"
	TrailerStyleCoAuthoredBy TrailerStyle = "co-authored-by"
	TrailerStyleAssistedBy   TrailerStyle = "assisted-by"
)

// Attribution controls how generated content is attributed: the commit
// trailer style and the "Generated with Crush" line.
//
// CoAuthoredBy is a legacy setting superseded by TrailerStyle. Its schema
// property is marked deprecated by JSONSchemaExtend.
type Attribution struct {
	TrailerStyle  TrailerStyle `json:"trailer_style,omitempty" jsonschema:"description=Style of attribution trailer to add to commits,enum=none,enum=co-authored-by,enum=assisted-by,default=assisted-by"`
	CoAuthoredBy  *bool        `json:"co_authored_by,omitempty" jsonschema:"description=Deprecated: use trailer_style instead"`
	GeneratedWith bool         `json:"generated_with,omitempty" jsonschema:"description=Add Generated with Crush line to commit messages and issues and PRs,default=true"`
}
233
234// JSONSchemaExtend marks the co_authored_by field as deprecated in the schema.
235func (Attribution) JSONSchemaExtend(schema *jsonschema.Schema) {
236	if schema.Properties != nil {
237		if prop, ok := schema.Properties.Get("co_authored_by"); ok {
238			prop.Deprecated = true
239		}
240	}
241}
242
// Options holds general application settings (the "options" config key).
// Per the schema descriptions, DataDirectory is relative to the working
// directory. DisableDefaultProviders requires every provider to be fully
// specified in the config file.
type Options struct {
	ContextPaths              []string     `json:"context_paths,omitempty" jsonschema:"description=Paths to files containing context information for the AI,example=.cursorrules,example=CRUSH.md"`
	SkillsPaths               []string     `json:"skills_paths,omitempty" jsonschema:"description=Paths to directories containing Agent Skills (folders with SKILL.md files),example=~/.config/crush/skills,example=./skills"`
	TUI                       *TUIOptions  `json:"tui,omitempty" jsonschema:"description=Terminal user interface options"`
	Debug                     bool         `json:"debug,omitempty" jsonschema:"description=Enable debug logging,default=false"`
	DebugLSP                  bool         `json:"debug_lsp,omitempty" jsonschema:"description=Enable debug logging for LSP servers,default=false"`
	DisableAutoSummarize      bool         `json:"disable_auto_summarize,omitempty" jsonschema:"description=Disable automatic conversation summarization,default=false"`
	DataDirectory             string       `json:"data_directory,omitempty" jsonschema:"description=Directory for storing application data (relative to working directory),default=.crush,example=.crush"` // Relative to the cwd
	DisabledTools             []string     `json:"disabled_tools,omitempty" jsonschema:"description=List of built-in tools to disable and hide from the agent,example=bash,example=sourcegraph"`
	DisableProviderAutoUpdate bool         `json:"disable_provider_auto_update,omitempty" jsonschema:"description=Disable providers auto-update,default=false"`
	DisableDefaultProviders   bool         `json:"disable_default_providers,omitempty" jsonschema:"description=Ignore all default/embedded providers. When enabled, providers must be fully specified in the config file with base_url, models, and api_key - no merging with defaults occurs,default=false"`
	Attribution               *Attribution `json:"attribution,omitempty" jsonschema:"description=Attribution settings for generated content"`
	DisableMetrics            bool         `json:"disable_metrics,omitempty" jsonschema:"description=Disable sending metrics,default=false"`
	InitializeAs              string       `json:"initialize_as,omitempty" jsonschema:"description=Name of the context file to create/update during project initialization,default=AGENTS.md,example=AGENTS.md,example=CRUSH.md,example=CLAUDE.md,example=docs/LLMs.md"`
	AutoLSP                   *bool        `json:"auto_lsp,omitempty" jsonschema:"description=Automatically setup LSPs based on root markers,default=true"`
	Progress                  *bool        `json:"progress,omitempty" jsonschema:"description=Show indeterminate progress updates during long operations,default=true"`
	DisableNotifications      bool         `json:"disable_notifications,omitempty" jsonschema:"description=Disable desktop notifications,default=false"`
	DisabledSkills            []string     `json:"disabled_skills,omitempty" jsonschema:"description=List of skill names to disable and hide from the agent,example=crush-config"`
}
262
263type MCPs map[string]MCPConfig
264
265type MCP struct {
266	Name string    `json:"name"`
267	MCP  MCPConfig `json:"mcp"`
268}
269
270func (m MCPs) Sorted() []MCP {
271	sorted := make([]MCP, 0, len(m))
272	for k, v := range m {
273		sorted = append(sorted, MCP{
274			Name: k,
275			MCP:  v,
276		})
277	}
278	slices.SortFunc(sorted, func(a, b MCP) int {
279		return strings.Compare(a.Name, b.Name)
280	})
281	return sorted
282}
283
284type LSPs map[string]LSPConfig
285
286type LSP struct {
287	Name string    `json:"name"`
288	LSP  LSPConfig `json:"lsp"`
289}
290
291func (l LSPs) Sorted() []LSP {
292	sorted := make([]LSP, 0, len(l))
293	for k, v := range l {
294		sorted = append(sorted, LSP{
295			Name: k,
296			LSP:  v,
297		})
298	}
299	slices.SortFunc(sorted, func(a, b LSP) int {
300		return strings.Compare(a.Name, b.Name)
301	})
302	return sorted
303}
304
305func (l LSPConfig) ResolvedEnv() []string {
306	return resolveEnvs(l.Env)
307}
308
309func (m MCPConfig) ResolvedEnv() []string {
310	return resolveEnvs(m.Env)
311}
312
313func (m MCPConfig) ResolvedHeaders() map[string]string {
314	resolver := NewShellVariableResolver(env.New())
315	for e, v := range m.Headers {
316		var err error
317		m.Headers[e], err = resolver.ResolveValue(v)
318		if err != nil {
319			slog.Error("Error resolving header variable", "error", err, "variable", e, "value", v)
320			continue
321		}
322	}
323	return m.Headers
324}
325
// Agent describes an agent: which model slot it uses and which tools, MCP
// servers and context files it may access. The built-in agents are built
// by Config.SetupAgents.
type Agent struct {
	// Identity and display metadata.
	ID          string `json:"id,omitempty"`
	Name        string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// NOTE(review): an earlier comment here read "This is the id of the
	// system prompt used by the agent", which does not describe Disabled.
	// It probably referred to ID — confirm.
	Disabled bool `json:"disabled,omitempty"`

	Model SelectedModelType `json:"model" jsonschema:"required,description=The model type to use for this agent,enum=large,enum=small,default=large"`

	// The available tools for the agent
	//  if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// this tells us which MCPs are available for this agent
	//  if this is empty all mcps are available
	//  the string array is the list of tools from the AllowedMCP the agent has available
	//  if the string array is nil, all tools from the AllowedMCP are available
	// NOTE(review): SetupAgents gives the task agent an empty, non-nil map to
	// mean "no MCPs", so presumably only a nil map means "all MCPs" — confirm.
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
348
// Tools holds per-tool configuration for built-in tools.
type Tools struct {
	Ls   ToolLs   `json:"ls,omitzero"`
	Grep ToolGrep `json:"grep,omitzero"`
}

// ToolLs configures the ls tool.
type ToolLs struct {
	MaxDepth *int `json:"max_depth,omitempty" jsonschema:"description=Maximum depth for the ls tool,default=0,example=10"`
	MaxItems *int `json:"max_items,omitempty" jsonschema:"description=Maximum number of items to return for the ls tool,default=1000,example=100"`
}

// Limits returns the configured maximum depth and maximum item count for
// the ls tool. An unset limit is reported as 0.
//
// NOTE(review): the schema advertises default=1000 for max_items. A 0 here
// presumably means "use the tool's own default" — confirm at the call site.
func (t ToolLs) Limits() (depth, items int) {
	if t.MaxDepth != nil {
		depth = *t.MaxDepth
	}
	if t.MaxItems != nil {
		items = *t.MaxItems
	}
	return depth, items
}

// ToolGrep configures the grep tool.
type ToolGrep struct {
	Timeout *time.Duration `json:"timeout,omitempty" jsonschema:"description=Timeout for the grep tool call,default=5s,example=10s"`
}

// GetTimeout returns the configured grep timeout, or 5 seconds when unset.
func (t ToolGrep) GetTimeout() time.Duration {
	const defaultGrepTimeout = 5 * time.Second
	if t.Timeout == nil {
		return defaultGrepTimeout
	}
	return *t.Timeout
}
372
// HookConfig defines a user-configured shell command that fires on a hook
// event (e.g. PreToolUse).
type HookConfig struct {
	// Regex pattern tested against the tool name. Empty means match all.
	Matcher string `json:"matcher,omitempty" jsonschema:"description=Regex pattern tested against the tool name. Empty means match all tools."`
	// Shell command to execute.
	Command string `json:"command" jsonschema:"required,description=Shell command to execute when the hook fires"`
	// Timeout in seconds. Non-positive values fall back to 30.
	Timeout int `json:"timeout,omitempty" jsonschema:"description=Timeout in seconds for the hook command,default=30"`

	// Compiled form of Matcher. It is not serialized and is expected to be
	// populated elsewhere, as nothing here assigns it.
	matcherRegex *regexp.Regexp
}

// MatcherRegex returns the compiled matcher regex. It is nil until it has
// been populated.
func (h *HookConfig) MatcherRegex() *regexp.Regexp {
	return h.matcherRegex
}

// TimeoutDuration converts Timeout (seconds) to a time.Duration. Zero or
// negative values yield the 30 second default.
func (h *HookConfig) TimeoutDuration() time.Duration {
	const defaultHookTimeout = 30 * time.Second
	if h.Timeout > 0 {
		return time.Duration(h.Timeout) * time.Second
	}
	return defaultHookTimeout
}
401
// Config holds the configuration for crush.
type Config struct {
	// JSON schema reference, serialized as "$schema".
	Schema string `json:"$schema,omitempty"`

	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty" jsonschema:"description=Model configurations for different model types,example={\"large\":{\"model\":\"gpt-4o\",\"provider\":\"openai\"}}"`

	// Recently used models stored in the data directory config.
	RecentModels map[SelectedModelType][]SelectedModel `json:"recent_models,omitempty" jsonschema:"-"`

	// The providers that are configured, keyed by provider ID (see GetModel
	// and GetProviderForModel).
	Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty" jsonschema:"description=AI provider configurations"`

	// MCP servers keyed by name.
	MCP MCPs `json:"mcp,omitempty" jsonschema:"description=Model Context Protocol server configurations"`

	// LSP servers keyed by name.
	LSP LSPs `json:"lsp,omitempty" jsonschema:"description=Language Server Protocol configurations"`

	// General application options. SetupAgents dereferences this, so it
	// must be non-nil by then.
	Options *Options `json:"options,omitempty" jsonschema:"description=General application options"`

	Permissions *Permissions `json:"permissions,omitempty" jsonschema:"description=Permission settings for tool usage"`

	Tools Tools `json:"tools,omitzero" jsonschema:"description=Tool configurations"`

	// Hook commands keyed by hook event name (e.g. PreToolUse).
	Hooks map[string][]HookConfig `json:"hooks,omitempty" jsonschema:"description=User-defined shell commands that fire on hook events (e.g. PreToolUse)"`

	// Agents is populated by SetupAgents and is not part of the JSON config.
	Agents map[string]Agent `json:"-"`
}
429
430func (c *Config) EnabledProviders() []ProviderConfig {
431	var enabled []ProviderConfig
432	for p := range c.Providers.Seq() {
433		if !p.Disable {
434			enabled = append(enabled, p)
435		}
436	}
437	return enabled
438}
439
440// IsConfigured  return true if at least one provider is configured
441func (c *Config) IsConfigured() bool {
442	return len(c.EnabledProviders()) > 0
443}
444
445func (c *Config) GetModel(provider, model string) *catwalk.Model {
446	if providerConfig, ok := c.Providers.Get(provider); ok {
447		for _, m := range providerConfig.Models {
448			if m.ID == model {
449				return &m
450			}
451		}
452	}
453	return nil
454}
455
456func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
457	model, ok := c.Models[modelType]
458	if !ok {
459		return nil
460	}
461	if providerConfig, ok := c.Providers.Get(model.Provider); ok {
462		return &providerConfig
463	}
464	return nil
465}
466
467func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
468	model, ok := c.Models[modelType]
469	if !ok {
470		return nil
471	}
472	return c.GetModel(model.Provider, model.Model)
473}
474
475func (c *Config) LargeModel() *catwalk.Model {
476	model, ok := c.Models[SelectedModelTypeLarge]
477	if !ok {
478		return nil
479	}
480	return c.GetModel(model.Provider, model.Model)
481}
482
483func (c *Config) SmallModel() *catwalk.Model {
484	model, ok := c.Models[SelectedModelTypeSmall]
485	if !ok {
486		return nil
487	}
488	return c.GetModel(model.Provider, model.Model)
489}
490
// maxRecentModelsPerType is the maximum number of entries per model type,
// presumably for Config.RecentModels. The cap is enforced outside this
// view — confirm.
const maxRecentModelsPerType = 5
492
// allToolNames returns the names of every built-in tool in a stable order.
// A new slice is built on each call, so callers may modify the result.
func allToolNames() []string {
	names := []string{
		// agents and introspection
		"agent", "bash", "crush_info", "crush_logs",
		// background jobs
		"job_output", "job_kill",
		// file editing and fetching
		"download", "edit", "multiedit",
		// language servers
		"lsp_diagnostics", "lsp_references", "lsp_restart",
		// web
		"fetch", "agentic_fetch",
		// search and navigation
		"glob", "grep", "ls", "sourcegraph",
		// misc
		"todos", "view", "write",
		// MCP resources
		"list_mcp_resources", "read_mcp_resource",
	}
	return names
}
520
// resolveAllowedTools returns allTools minus every name in disabledTools,
// preserving order. When disabledTools is nil, allTools itself is returned.
func resolveAllowedTools(allTools []string, disabledTools []string) []string {
	if disabledTools == nil {
		return allTools
	}
	// Exclude mode: keep tools that are NOT disabled.
	return filterSlice(allTools, disabledTools, false)
}

// resolveReadOnlyTools narrows tools to the read-only subset (glob, grep,
// ls, sourcegraph, view), preserving the order of tools.
func resolveReadOnlyTools(tools []string) []string {
	readOnlyTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
	// Include mode: keep only tools that are in the read-only list.
	return filterSlice(tools, readOnlyTools, true)
}

// filterSlice returns, in order, the elements of data whose presence in
// mask equals include. With include=true it keeps elements found in mask;
// with include=false it keeps elements absent from mask.
//
// The result is never nil, even when nothing matches. These results become
// Agent.AllowedTools, where nil is documented as "all tools are
// available". A nil result would therefore turn "every tool disabled" into
// "every tool allowed".
func filterSlice(data []string, mask []string, include bool) []string {
	filtered := make([]string, 0, len(data))
	for _, s := range data {
		if slices.Contains(mask, s) == include {
			filtered = append(filtered, s)
		}
	}
	return filtered
}
546
547func (c *Config) SetupAgents() {
548	allowedTools := resolveAllowedTools(allToolNames(), c.Options.DisabledTools)
549
550	agents := map[string]Agent{
551		AgentCoder: {
552			ID:           AgentCoder,
553			Name:         "Coder",
554			Description:  "An agent that helps with executing coding tasks.",
555			Model:        SelectedModelTypeLarge,
556			ContextPaths: c.Options.ContextPaths,
557			AllowedTools: allowedTools,
558		},
559
560		AgentTask: {
561			ID:           AgentTask,
562			Name:         "Task",
563			Description:  "An agent that helps with searching for context and finding implementation details.",
564			Model:        SelectedModelTypeLarge,
565			ContextPaths: c.Options.ContextPaths,
566			AllowedTools: resolveReadOnlyTools(allowedTools),
567			// NO MCPs or LSPs by default
568			AllowedMCP: map[string][]string{},
569		},
570	}
571	c.Agents = agents
572}
573
574func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
575	var (
576		providerID = catwalk.InferenceProvider(c.ID)
577		testURL    = ""
578		headers    = make(map[string]string)
579		apiKey, _  = resolver.ResolveValue(c.APIKey)
580	)
581
582	switch providerID {
583	case catwalk.InferenceProviderMiniMax, catwalk.InferenceProviderMiniMaxChina:
584		// NOTE: MiniMax has no good endpoint we can use to validate the API key.
585		return nil
586	}
587
588	switch c.Type {
589	case catwalk.TypeOpenAI, catwalk.TypeOpenAICompat, catwalk.TypeOpenRouter:
590		baseURL, _ := resolver.ResolveValue(c.BaseURL)
591		baseURL = cmp.Or(baseURL, "https://api.openai.com/v1")
592
593		switch providerID {
594		case catwalk.InferenceProviderOpenRouter:
595			testURL = baseURL + "/credits"
596		case catwalk.InferenceProviderOpenCodeGo:
597			testURL = strings.Replace(baseURL, "/go", "", 1) + "/models"
598		default:
599			testURL = baseURL + "/models"
600		}
601
602		headers["Authorization"] = "Bearer " + apiKey
603	case catwalk.TypeAnthropic:
604		baseURL, _ := resolver.ResolveValue(c.BaseURL)
605		baseURL = cmp.Or(baseURL, "https://api.anthropic.com/v1")
606
607		switch providerID {
608		case catwalk.InferenceKimiCoding:
609			testURL = baseURL + "/v1/models"
610		default:
611			testURL = baseURL + "/models"
612		}
613
614		headers["x-api-key"] = apiKey
615		headers["anthropic-version"] = "2023-06-01"
616	case catwalk.TypeGoogle:
617		baseURL, _ := resolver.ResolveValue(c.BaseURL)
618		baseURL = cmp.Or(baseURL, "https://generativelanguage.googleapis.com")
619		testURL = baseURL + "/v1beta/models?key=" + url.QueryEscape(apiKey)
620	case catwalk.TypeBedrock:
621		// NOTE: Bedrock has a `/foundation-models` endpoint that we could in
622		// theory use, but apparently the authorization is region-specific,
623		// so it's not so trivial.
624		if strings.HasPrefix(apiKey, "ABSK") { // Bedrock API keys
625			return nil
626		}
627		return errors.New("not a valid bedrock api key")
628	case catwalk.TypeVercel:
629		// NOTE: Vercel does not validate API keys on the `/models` endpoint.
630		if strings.HasPrefix(apiKey, "vck_") { // Vercel API keys
631			return nil
632		}
633		return errors.New("not a valid vercel api key")
634	}
635
636	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
637	defer cancel()
638
639	client := &http.Client{}
640	req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
641	if err != nil {
642		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
643	}
644	for k, v := range headers {
645		req.Header.Set(k, v)
646	}
647	for k, v := range c.ExtraHeaders {
648		req.Header.Set(k, v)
649	}
650
651	resp, err := client.Do(req)
652	if err != nil {
653		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
654	}
655	defer resp.Body.Close()
656
657	switch providerID {
658	case catwalk.InferenceProviderZAI:
659		if resp.StatusCode == http.StatusUnauthorized {
660			return fmt.Errorf("failed to connect to provider %s: %s", c.ID, resp.Status)
661		}
662	default:
663		if resp.StatusCode != http.StatusOK {
664			return fmt.Errorf("failed to connect to provider %s: %s", c.ID, resp.Status)
665		}
666	}
667	return nil
668}
669
670func resolveEnvs(envs map[string]string) []string {
671	resolver := NewShellVariableResolver(env.New())
672	for e, v := range envs {
673		var err error
674		envs[e], err = resolver.ResolveValue(v)
675		if err != nil {
676			slog.Error("Error resolving environment variable", "error", err, "variable", e, "value", v)
677			continue
678		}
679	}
680
681	res := make([]string, 0, len(envs))
682	for k, v := range envs {
683		res = append(res, fmt.Sprintf("%s=%s", k, v))
684	}
685	return res
686}
687
// ptrValOr dereferences t, or returns el when t is nil.
func ptrValOr[T any](t *T, el T) T {
	if t != nil {
		return *t
	}
	return el
}