1package config
2
import (
	"context"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/env"
	"github.com/invopop/jsonschema"
	"github.com/tidwall/sjson"
)
19
// appName is the application's name.
const appName = "crush"

// defaultDataDirectory is the default directory for application data,
// relative to the working directory.
const defaultDataDirectory = ".crush"

// defaultLogLevel is the log level used when none is configured.
const defaultLogLevel = "info"
25
// defaultContextPaths lists candidate instruction/context files. Judging by
// their names they cover several AI coding tools (Copilot, Cursor, Claude,
// Gemini) as well as Crush's own files, in several capitalizations.
// NOTE(review): presumably this is the default for Options.ContextPaths, and
// ".cursor/rules/" (trailing slash) is meant as a directory; confirm both in
// the config loader.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
41
42type SelectedModelType string
43
44const (
45 SelectedModelTypeLarge SelectedModelType = "large"
46 SelectedModelTypeSmall SelectedModelType = "small"
47)
48
49// JSONSchema returns the JSON schema for SelectedModelType
50func (SelectedModelType) JSONSchema() *jsonschema.Schema {
51 return &jsonschema.Schema{
52 Type: "string",
53 Description: "Model type selection for different use cases",
54 Enum: []any{"large", "small"},
55 Default: "large",
56 }
57}
58
// SelectedModel is the model chosen for one slot of Config.Models, together
// with optional per-model overrides.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model" jsonschema:"required,description=The model ID as used by the provider API,example=gpt-4o"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider" jsonschema:"required,description=The model provider ID that matches a key in the providers config,example=openai"`

	// Only used by models that use the openai provider and need this set.
	// The schema limits it to low/medium/high.
	ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"description=Reasoning effort level for OpenAI models that support it,enum=low,enum=medium,enum=high"`

	// Overrides the default model configuration for the maximum number of
	// response tokens. NOTE(review): presumably 0 (omitted) means "use the
	// model's default"; confirm where this is consumed.
	MaxTokens int64 `json:"max_tokens,omitempty" jsonschema:"description=Maximum number of tokens for model responses,minimum=1,maximum=200000,example=4096"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty" jsonschema:"description=Enable thinking mode for Anthropic models that support reasoning"`
}
76
// ProviderConfig describes one AI provider: how to reach it, how to
// authenticate, and which models it offers. It is stored in Config.Providers
// keyed by provider id.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty" jsonschema:"description=Unique identifier for the provider,example=openai"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty" jsonschema:"description=Human-readable name for the provider,example=OpenAI"`
	// The provider's API endpoint. It may contain variables; TestConnection
	// passes it through a VariableResolver before use.
	BaseURL string `json:"base_url,omitempty" jsonschema:"description=Base URL for the provider's API,format=uri,example=https://api.openai.com/v1"`
	// The provider type, e.g. "openai", "anthropic", etc. If empty it defaults to openai.
	Type catwalk.Type `json:"type,omitempty" jsonschema:"description=Provider type that determines the API format,enum=openai,enum=anthropic,enum=gemini,enum=azure,enum=vertexai,default=openai"`
	// The provider's API key. It may be a variable reference such as
	// $OPENAI_API_KEY (see the schema example), which is resolved before use.
	APIKey string `json:"api_key,omitempty" jsonschema:"description=API key for authentication with the provider,example=$OPENAI_API_KEY"`
	// Marks the provider as disabled; EnabledProviders skips such providers.
	Disable bool `json:"disable,omitempty" jsonschema:"description=Whether this provider is disabled,default=false"`

	// Custom system prompt prefix.
	SystemPromptPrefix string `json:"system_prompt_prefix,omitempty" jsonschema:"description=Custom prefix to add to system prompts for this provider"`

	// Extra headers to send with each request to the provider.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty" jsonschema:"description=Additional HTTP headers to send with requests"`
	// Extra fields to include in request bodies sent to the provider.
	ExtraBody map[string]any `json:"extra_body,omitempty" jsonschema:"description=Additional fields to include in request bodies"`

	// Used to pass extra parameters to the provider. Runtime only: never
	// read from or written to JSON.
	ExtraParams map[string]string `json:"-"`

	// The provider models.
	Models []catwalk.Model `json:"models,omitempty" jsonschema:"description=List of models available from this provider"`
}
105
106type MCPType string
107
108const (
109 MCPStdio MCPType = "stdio"
110 MCPSse MCPType = "sse"
111 MCPHttp MCPType = "http"
112)
113
114// JSONSchema returns the JSON schema for MCPType
115func (MCPType) JSONSchema() *jsonschema.Schema {
116 return &jsonschema.Schema{
117 Type: "string",
118 Description: "Type of MCP connection protocol",
119 Enum: []any{"stdio", "sse", "http"},
120 Default: "stdio",
121 }
122}
123
// MCPConfig describes how to reach one Model Context Protocol server.
// Command/Args/Env apply to stdio servers; URL/Headers apply to HTTP and SSE
// servers. Env and Headers values may contain variables; see ResolvedEnv and
// ResolvedHeaders.
type MCPConfig struct {
	// Command to execute for stdio servers.
	Command string `json:"command,omitempty" jsonschema:"description=Command to execute for stdio MCP servers,example=npx"`
	// Environment variables for the server process (values may be variable references).
	Env map[string]string `json:"env,omitempty" jsonschema:"description=Environment variables to set for the MCP server"`
	// Arguments passed to Command.
	Args []string `json:"args,omitempty" jsonschema:"description=Arguments to pass to the MCP server command"`
	// Transport type; required in the schema.
	Type MCPType `json:"type" jsonschema:"required,description=Type of MCP connection,enum=stdio,enum=sse,enum=http,default=stdio"`
	// Endpoint for HTTP or SSE servers.
	URL string `json:"url,omitempty" jsonschema:"description=URL for HTTP or SSE MCP servers,format=uri,example=http://localhost:3000/mcp"`
	// Marks this server as disabled.
	Disabled bool `json:"disabled,omitempty" jsonschema:"description=Whether this MCP server is disabled,default=false"`

	// TODO: maybe make it possible to get the value from the env
	Headers map[string]string `json:"headers,omitempty" jsonschema:"description=HTTP headers for HTTP/SSE MCP servers"`
}
135
// LSPConfig describes how to launch a Language Server Protocol server.
type LSPConfig struct {
	// Disabled turns this server off. It is serialized under the "disabled"
	// key, matching its meaning and MCPConfig.Disabled; an "enabled" key
	// would invert what users write in their config.
	Disabled bool `json:"disabled,omitempty" jsonschema:"description=Whether this LSP server is disabled,default=false"`

	// Command to execute for the LSP server (required).
	Command string `json:"command" jsonschema:"required,description=Command to execute for the LSP server,example=gopls"`
	// Arguments passed to Command.
	Args []string `json:"args,omitempty" jsonschema:"description=Arguments to pass to the LSP server command"`
	// Server-specific options, passed through as arbitrary JSON.
	Options any `json:"options,omitempty" jsonschema:"description=LSP server-specific configuration options"`
}
142
// TUIOptions holds terminal user interface settings, stored under
// "options.tui" in the config file.
type TUIOptions struct {
	// CompactMode enables the compact TUI layout.
	CompactMode bool `json:"compact_mode,omitempty" jsonschema:"description=Enable compact mode for the TUI interface,default=false"`
	// Here we can add themes later or any TUI related options
}
147
// Permissions controls which tool invocations need user confirmation.
type Permissions struct {
	AllowedTools []string `json:"allowed_tools,omitempty" jsonschema:"description=List of tools that don't require permission prompts,example=bash,example=view"` // Tools that don't require permission prompts
	SkipRequests bool     `json:"-"`                                                                                                                                        // Automatically accept all permissions (YOLO mode); runtime only, never read from or written to JSON
}
152
// Options holds general application settings, stored under "options" in the
// config file.
type Options struct {
	// Files providing context/instructions for the AI.
	ContextPaths []string `json:"context_paths,omitempty" jsonschema:"description=Paths to files containing context information for the AI,example=.cursorrules,example=CRUSH.md"`
	// TUI settings; may be nil when the config has no "tui" section.
	TUI *TUIOptions `json:"tui,omitempty" jsonschema:"description=Terminal user interface options"`
	// Enables debug logging.
	Debug bool `json:"debug,omitempty" jsonschema:"description=Enable debug logging,default=false"`
	// Enables debug logging for LSP servers.
	DebugLSP bool `json:"debug_lsp,omitempty" jsonschema:"description=Enable debug logging for LSP servers,default=false"`
	// Disables automatic conversation summarization.
	DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty" jsonschema:"description=Disable automatic conversation summarization,default=false"`
	// Data directory, relative to the working directory (schema default: .crush).
	DataDirectory string `json:"data_directory,omitempty" jsonschema:"description=Directory for storing application data (relative to working directory),default=.crush,example=.crush"` // Relative to the cwd
}
161
162type MCPs map[string]MCPConfig
163
164type MCP struct {
165 Name string `json:"name"`
166 MCP MCPConfig `json:"mcp"`
167}
168
169func (m MCPs) Sorted() []MCP {
170 sorted := make([]MCP, 0, len(m))
171 for k, v := range m {
172 sorted = append(sorted, MCP{
173 Name: k,
174 MCP: v,
175 })
176 }
177 slices.SortFunc(sorted, func(a, b MCP) int {
178 return strings.Compare(a.Name, b.Name)
179 })
180 return sorted
181}
182
183type LSPs map[string]LSPConfig
184
185type LSP struct {
186 Name string `json:"name"`
187 LSP LSPConfig `json:"lsp"`
188}
189
190func (l LSPs) Sorted() []LSP {
191 sorted := make([]LSP, 0, len(l))
192 for k, v := range l {
193 sorted = append(sorted, LSP{
194 Name: k,
195 LSP: v,
196 })
197 }
198 slices.SortFunc(sorted, func(a, b LSP) int {
199 return strings.Compare(a.Name, b.Name)
200 })
201 return sorted
202}
203
204func (m MCPConfig) ResolvedEnv() []string {
205 resolver := NewShellVariableResolver(env.New())
206 for e, v := range m.Env {
207 var err error
208 m.Env[e], err = resolver.ResolveValue(v)
209 if err != nil {
210 slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
211 continue
212 }
213 }
214
215 env := make([]string, 0, len(m.Env))
216 for k, v := range m.Env {
217 env = append(env, fmt.Sprintf("%s=%s", k, v))
218 }
219 return env
220}
221
222func (m MCPConfig) ResolvedHeaders() map[string]string {
223 resolver := NewShellVariableResolver(env.New())
224 for e, v := range m.Headers {
225 var err error
226 m.Headers[e], err = resolver.ResolveValue(v)
227 if err != nil {
228 slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
229 continue
230 }
231 }
232 return m.Headers
233}
234
// Agent describes an agent: which model slot it uses and which tools, MCP
// servers, LSPs and context files it may access. The built-in agents are
// created by Config.SetupAgents.
type Agent struct {
	// ID is the agent's identifier (e.g. "coder", "task").
	ID string `json:"id,omitempty"`
	// Name is the human-readable agent name.
	Name string `json:"name,omitempty"`
	// Description explains what the agent is for.
	Description string `json:"description,omitempty"`
	// Disabled marks the agent as disabled.
	Disabled bool `json:"disabled,omitempty"`

	// Model is the model slot (large/small) the agent runs on.
	Model SelectedModelType `json:"model"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// this tells us which MCPs are available for this agent
	// the string array is the list of tools from the AllowedMCP the agent has available
	// if the string array is nil, all tools from the AllowedMCP are available
	// NOTE(review): an earlier comment said "if this is empty all mcps are
	// available", but SetupAgents assigns an empty non-nil map to mean "no
	// MCPs". Presumably nil means all and empty means none; confirm against
	// the consumer.
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available (SetupAgents uses an empty,
	// non-nil slice to grant none)
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
261
// Config holds the configuration for crush.
//
// Exported fields with JSON tags mirror the config file. Agents and the
// unexported fields hold runtime state and are excluded from JSON
// (de)serialization.
type Config struct {
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty" jsonschema:"description=Model configurations for different model types,example={\"large\":{\"model\":\"gpt-4o\",\"provider\":\"openai\"}}"`

	// The providers that are configured, keyed by provider id.
	Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty" jsonschema:"description=AI provider configurations"`

	// MCP servers keyed by name.
	MCP MCPs `json:"mcp,omitempty" jsonschema:"description=Model Context Protocol server configurations"`

	// LSP servers keyed by name.
	LSP LSPs `json:"lsp,omitempty" jsonschema:"description=Language Server Protocol configurations"`

	// General options; may be nil (SetCompactMode creates it on demand).
	Options *Options `json:"options,omitempty" jsonschema:"description=General application options"`

	// Tool permission settings.
	Permissions *Permissions `json:"permissions,omitempty" jsonschema:"description=Permission settings for tool usage"`

	// Internal
	workingDir string `json:"-"`
	// TODO: most likely remove this concept when I come back to it
	// Agents is populated by SetupAgents.
	Agents map[string]Agent `json:"-"`
	// TODO: find a better way to do this this should probably not be part of the config
	// resolver is used by Resolve; it may be nil.
	resolver VariableResolver
	// dataConfigDir is, despite its name, the path of the JSON config file
	// that SetConfigField reads and rewrites.
	dataConfigDir string `json:"-"`
	// knownProviders is the provider catalog SetProviderAPIKey falls back to
	// when a provider is not yet configured.
	knownProviders []catwalk.Provider `json:"-"`
}
287
288func (c *Config) WorkingDir() string {
289 return c.workingDir
290}
291
292func (c *Config) EnabledProviders() []ProviderConfig {
293 var enabled []ProviderConfig
294 for p := range c.Providers.Seq() {
295 if !p.Disable {
296 enabled = append(enabled, p)
297 }
298 }
299 return enabled
300}
301
302// IsConfigured return true if at least one provider is configured
303func (c *Config) IsConfigured() bool {
304 return len(c.EnabledProviders()) > 0
305}
306
307func (c *Config) GetModel(provider, model string) *catwalk.Model {
308 if providerConfig, ok := c.Providers.Get(provider); ok {
309 for _, m := range providerConfig.Models {
310 if m.ID == model {
311 return &m
312 }
313 }
314 }
315 return nil
316}
317
318func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
319 model, ok := c.Models[modelType]
320 if !ok {
321 return nil
322 }
323 if providerConfig, ok := c.Providers.Get(model.Provider); ok {
324 return &providerConfig
325 }
326 return nil
327}
328
329func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
330 model, ok := c.Models[modelType]
331 if !ok {
332 return nil
333 }
334 return c.GetModel(model.Provider, model.Model)
335}
336
337func (c *Config) LargeModel() *catwalk.Model {
338 model, ok := c.Models[SelectedModelTypeLarge]
339 if !ok {
340 return nil
341 }
342 return c.GetModel(model.Provider, model.Model)
343}
344
345func (c *Config) SmallModel() *catwalk.Model {
346 model, ok := c.Models[SelectedModelTypeSmall]
347 if !ok {
348 return nil
349 }
350 return c.GetModel(model.Provider, model.Model)
351}
352
353func (c *Config) SetCompactMode(enabled bool) error {
354 if c.Options == nil {
355 c.Options = &Options{}
356 }
357 c.Options.TUI.CompactMode = enabled
358 return c.SetConfigField("options.tui.compact_mode", enabled)
359}
360
361func (c *Config) Resolve(key string) (string, error) {
362 if c.resolver == nil {
363 return "", fmt.Errorf("no variable resolver configured")
364 }
365 return c.resolver.ResolveValue(key)
366}
367
368func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
369 c.Models[modelType] = model
370 if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
371 return fmt.Errorf("failed to update preferred model: %w", err)
372 }
373 return nil
374}
375
376func (c *Config) SetConfigField(key string, value any) error {
377 // read the data
378 data, err := os.ReadFile(c.dataConfigDir)
379 if err != nil {
380 if os.IsNotExist(err) {
381 data = []byte("{}")
382 } else {
383 return fmt.Errorf("failed to read config file: %w", err)
384 }
385 }
386
387 newValue, err := sjson.Set(string(data), key, value)
388 if err != nil {
389 return fmt.Errorf("failed to set config field %s: %w", key, err)
390 }
391 if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
392 return fmt.Errorf("failed to write config file: %w", err)
393 }
394 return nil
395}
396
397func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
398 // First save to the config file
399 err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
400 if err != nil {
401 return fmt.Errorf("failed to save API key to config file: %w", err)
402 }
403
404 providerConfig, exists := c.Providers.Get(providerID)
405 if exists {
406 providerConfig.APIKey = apiKey
407 c.Providers.Set(providerID, providerConfig)
408 return nil
409 }
410
411 var foundProvider *catwalk.Provider
412 for _, p := range c.knownProviders {
413 if string(p.ID) == providerID {
414 foundProvider = &p
415 break
416 }
417 }
418
419 if foundProvider != nil {
420 // Create new provider config based on known provider
421 providerConfig = ProviderConfig{
422 ID: providerID,
423 Name: foundProvider.Name,
424 BaseURL: foundProvider.APIEndpoint,
425 Type: foundProvider.Type,
426 APIKey: apiKey,
427 Disable: false,
428 ExtraHeaders: make(map[string]string),
429 ExtraParams: make(map[string]string),
430 Models: foundProvider.Models,
431 }
432 } else {
433 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
434 }
435 // Store the updated provider config
436 c.Providers.Set(providerID, providerConfig)
437 return nil
438}
439
440func (c *Config) SetupAgents() {
441 agents := map[string]Agent{
442 "coder": {
443 ID: "coder",
444 Name: "Coder",
445 Description: "An agent that helps with executing coding tasks.",
446 Model: SelectedModelTypeLarge,
447 ContextPaths: c.Options.ContextPaths,
448 // All tools allowed
449 },
450 "task": {
451 ID: "task",
452 Name: "Task",
453 Description: "An agent that helps with searching for context and finding implementation details.",
454 Model: SelectedModelTypeLarge,
455 ContextPaths: c.Options.ContextPaths,
456 AllowedTools: []string{
457 "glob",
458 "grep",
459 "ls",
460 "sourcegraph",
461 "view",
462 },
463 // NO MCPs or LSPs by default
464 AllowedMCP: map[string][]string{},
465 AllowedLSP: []string{},
466 },
467 }
468 c.Agents = agents
469}
470
471func (c *Config) Resolver() VariableResolver {
472 return c.resolver
473}
474
475func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
476 testURL := ""
477 headers := make(map[string]string)
478 apiKey, _ := resolver.ResolveValue(c.APIKey)
479 switch c.Type {
480 case catwalk.TypeOpenAI:
481 baseURL, _ := resolver.ResolveValue(c.BaseURL)
482 if baseURL == "" {
483 baseURL = "https://api.openai.com/v1"
484 }
485 testURL = baseURL + "/models"
486 headers["Authorization"] = "Bearer " + apiKey
487 case catwalk.TypeAnthropic:
488 baseURL, _ := resolver.ResolveValue(c.BaseURL)
489 if baseURL == "" {
490 baseURL = "https://api.anthropic.com/v1"
491 }
492 testURL = baseURL + "/models"
493 headers["x-api-key"] = apiKey
494 headers["anthropic-version"] = "2023-06-01"
495 }
496 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
497 defer cancel()
498 client := &http.Client{}
499 req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
500 if err != nil {
501 return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
502 }
503 for k, v := range headers {
504 req.Header.Set(k, v)
505 }
506 for k, v := range c.ExtraHeaders {
507 req.Header.Set(k, v)
508 }
509 b, err := client.Do(req)
510 if err != nil {
511 return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
512 }
513 if b.StatusCode != http.StatusOK {
514 return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
515 }
516 _ = b.Body.Close()
517 return nil
518}