1package config
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "net/http"
8 "os"
9 "slices"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/env"
16 "github.com/tidwall/sjson"
17)
18
// Application-wide defaults.
const (
	appName              = "crush"   // application name, used for paths/branding
	defaultDataDirectory = ".crush"  // default data dir, relative to the working directory
	defaultLogLevel      = "info"
)
24
// defaultContextPaths lists the files and directories searched for
// project context/instructions. It covers the conventions of several
// AI coding tools (Copilot, Cursor, Claude, Gemini) plus crush's own
// crush.md variants in common capitalizations.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
40
// SelectedModelType identifies which configured model slot an agent or
// feature should use ("large" or "small").
type SelectedModelType string

const (
	// SelectedModelTypeLarge selects the "large" model slot.
	SelectedModelTypeLarge SelectedModelType = "large"
	// SelectedModelTypeSmall selects the "small" model slot.
	SelectedModelTypeSmall SelectedModelType = "small"
)
47
// SelectedModel describes a user-selected model: which provider serves
// it, its provider-side ID, and optional per-model overrides.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model" jsonschema:"required,description=The model ID as used by the provider API,example=gpt-4o"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider" jsonschema:"required,description=The model provider ID that matches a key in the providers config,example=openai"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"description=Reasoning effort level for OpenAI models that support it,enum=low,enum=medium,enum=high"`

	// Overrides the default model configuration.
	MaxTokens int64 `json:"max_tokens,omitempty" jsonschema:"description=Maximum number of tokens for model responses,minimum=1,maximum=200000,example=4096"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty" jsonschema:"description=Enable thinking mode for Anthropic models that support reasoning"`
}
65
// ProviderConfig holds the connection and model details for a single
// AI provider entry in the config.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty" jsonschema:"description=Unique identifier for the provider,example=openai"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty" jsonschema:"description=Human-readable name for the provider,example=OpenAI"`
	// The provider's API endpoint.
	BaseURL string `json:"base_url,omitempty" jsonschema:"description=Base URL for the provider's API,format=uri,example=https://api.openai.com/v1"`
	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
	Type catwalk.Type `json:"type,omitempty" jsonschema:"description=Provider type that determines the API format,enum=openai,enum=anthropic,enum=gemini,enum=azure,enum=vertexai,default=openai"`
	// The provider's API key. May reference an environment variable
	// (e.g. $OPENAI_API_KEY) that is resolved at use time.
	APIKey string `json:"api_key,omitempty" jsonschema:"description=API key for authentication with the provider,example=$OPENAI_API_KEY"`
	// Marks the provider as disabled.
	Disable bool `json:"disable,omitempty" jsonschema:"description=Whether this provider is disabled,default=false"`

	// Custom system prompt prefix.
	SystemPromptPrefix string `json:"system_prompt_prefix,omitempty" jsonschema:"description=Custom prefix to add to system prompts for this provider"`

	// Extra headers to send with each request to the provider.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty" jsonschema:"description=Additional HTTP headers to send with requests"`
	// Extra fields merged into each request body.
	ExtraBody map[string]any `json:"extra_body,omitempty" jsonschema:"description=Additional fields to include in request bodies"`

	// Used to pass extra parameters to the provider. Internal only,
	// never serialized.
	ExtraParams map[string]string `json:"-"`

	// The models available from this provider.
	Models []catwalk.Model `json:"models,omitempty" jsonschema:"description=List of models available from this provider"`
}
94
// MCPType is the transport used to talk to an MCP server.
type MCPType string

const (
	// MCPStdio communicates with a locally spawned process over stdio.
	MCPStdio MCPType = "stdio"
	// MCPSse communicates over HTTP server-sent events.
	MCPSse MCPType = "sse"
	// MCPHttp communicates over plain HTTP.
	MCPHttp MCPType = "http"
)
102
// MCPConfig describes a single Model Context Protocol server: how to
// launch it (stdio) or reach it (http/sse), and whether it is enabled.
type MCPConfig struct {
	Command  string            `json:"command,omitempty" jsonschema:"description=Command to execute for stdio MCP servers,example=npx"`
	Env      map[string]string `json:"env,omitempty" jsonschema:"description=Environment variables to set for the MCP server"`
	Args     []string          `json:"args,omitempty" jsonschema:"description=Arguments to pass to the MCP server command"`
	Type     MCPType           `json:"type" jsonschema:"required,description=Type of MCP connection,enum=stdio,enum=sse,enum=http,default=stdio"`
	URL      string            `json:"url,omitempty" jsonschema:"description=URL for HTTP or SSE MCP servers,format=uri,example=http://localhost:3000/mcp"`
	Disabled bool              `json:"disabled,omitempty" jsonschema:"description=Whether this MCP server is disabled,default=false"`

	// TODO: maybe make it possible to get the value from the env
	Headers map[string]string `json:"headers,omitempty" jsonschema:"description=HTTP headers for HTTP/SSE MCP servers"`
}
114
115type LSPConfig struct {
116 Disabled bool `json:"enabled,omitempty" jsonschema:"description=Whether this LSP server is disabled,default=false"`
117 Command string `json:"command" jsonschema:"required,description=Command to execute for the LSP server,example=gopls"`
118 Args []string `json:"args,omitempty" jsonschema:"description=Arguments to pass to the LSP server command"`
119 Options any `json:"options,omitempty" jsonschema:"description=LSP server-specific configuration options"`
120}
121
// TUIOptions holds terminal user interface settings.
type TUIOptions struct {
	CompactMode bool `json:"compact_mode,omitempty" jsonschema:"description=Enable compact mode for the TUI interface,default=false"`
	// Here we can add themes later or any TUI related options
}
126
// Permissions controls which tool invocations require interactive
// permission prompts.
type Permissions struct {
	AllowedTools []string `json:"allowed_tools,omitempty" jsonschema:"description=List of tools that don't require permission prompts,example=bash,example=view"` // Tools that don't require permission prompts
	SkipRequests bool     `json:"-"`                                                                                                                             // Automatically accept all permissions (YOLO mode); internal only, never serialized
}
131
// Options groups general, non-provider application settings.
type Options struct {
	ContextPaths         []string    `json:"context_paths,omitempty" jsonschema:"description=Paths to files containing context information for the AI,example=.cursorrules,example=CRUSH.md"`
	TUI                  *TUIOptions `json:"tui,omitempty" jsonschema:"description=Terminal user interface options"`
	Debug                bool        `json:"debug,omitempty" jsonschema:"description=Enable debug logging,default=false"`
	DebugLSP             bool        `json:"debug_lsp,omitempty" jsonschema:"description=Enable debug logging for LSP servers,default=false"`
	DisableAutoSummarize bool        `json:"disable_auto_summarize,omitempty" jsonschema:"description=Disable automatic conversation summarization,default=false"`
	DataDirectory        string      `json:"data_directory,omitempty" jsonschema:"description=Directory for storing application data (relative to working directory),default=.crush,example=.crush"` // Relative to the cwd
}
140
// MCPs maps MCP server names to their configurations.
type MCPs map[string]MCPConfig

// MCP is a named MCP configuration, used when a stable, ordered
// representation of the MCPs map is needed (see MCPs.Sorted).
type MCP struct {
	Name string    `json:"name"`
	MCP  MCPConfig `json:"mcp"`
}
147
148func (m MCPs) Sorted() []MCP {
149 sorted := make([]MCP, 0, len(m))
150 for k, v := range m {
151 sorted = append(sorted, MCP{
152 Name: k,
153 MCP: v,
154 })
155 }
156 slices.SortFunc(sorted, func(a, b MCP) int {
157 return strings.Compare(a.Name, b.Name)
158 })
159 return sorted
160}
161
// LSPs maps LSP server names to their configurations.
type LSPs map[string]LSPConfig

// LSP is a named LSP configuration, used when a stable, ordered
// representation of the LSPs map is needed (see LSPs.Sorted).
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}
168
169func (l LSPs) Sorted() []LSP {
170 sorted := make([]LSP, 0, len(l))
171 for k, v := range l {
172 sorted = append(sorted, LSP{
173 Name: k,
174 LSP: v,
175 })
176 }
177 slices.SortFunc(sorted, func(a, b LSP) int {
178 return strings.Compare(a.Name, b.Name)
179 })
180 return sorted
181}
182
183func (m MCPConfig) ResolvedEnv() []string {
184 resolver := NewShellVariableResolver(env.New())
185 for e, v := range m.Env {
186 var err error
187 m.Env[e], err = resolver.ResolveValue(v)
188 if err != nil {
189 slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
190 continue
191 }
192 }
193
194 env := make([]string, 0, len(m.Env))
195 for k, v := range m.Env {
196 env = append(env, fmt.Sprintf("%s=%s", k, v))
197 }
198 return env
199}
200
201func (m MCPConfig) ResolvedHeaders() map[string]string {
202 resolver := NewShellVariableResolver(env.New())
203 for e, v := range m.Headers {
204 var err error
205 m.Headers[e], err = resolver.ResolveValue(v)
206 if err != nil {
207 slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
208 continue
209 }
210 }
211 return m.Headers
212}
213
// Agent describes one of the application's agents: its identity, which
// model slot it uses, and which tools, MCPs, and LSPs it may access.
type Agent struct {
	ID          string `json:"id,omitempty"`
	Name        string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// Whether the agent is disabled.
	Disabled bool `json:"disabled,omitempty"`

	Model SelectedModelType `json:"model" jsonschema:"required,description=The model type to use for this agent,enum=large,enum=small,default=large"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// this tells us which MCPs are available for this agent
	// if this is empty all mcps are available
	// the string array is the list of tools from the AllowedMCP the agent has available
	// if the string array is nil, all tools from the AllowedMCP are available
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
240
// Config holds the configuration for crush.
type Config struct {
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty" jsonschema:"description=Model configurations for different model types,example={\"large\":{\"model\":\"gpt-4o\",\"provider\":\"openai\"}}"`

	// The providers that are configured. Stored in a concurrency-safe
	// map since providers can be updated at runtime (see SetProviderAPIKey).
	Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty" jsonschema:"description=AI provider configurations"`

	MCP MCPs `json:"mcp,omitempty" jsonschema:"description=Model Context Protocol server configurations"`

	LSP LSPs `json:"lsp,omitempty" jsonschema:"description=Language Server Protocol configurations"`

	Options *Options `json:"options,omitempty" jsonschema:"description=General application options"`

	Permissions *Permissions `json:"permissions,omitempty" jsonschema:"description=Permission settings for tool usage"`

	// Internal (not serialized).
	// workingDir is the directory crush operates in.
	workingDir string `json:"-"`
	// TODO: most likely remove this concept when I come back to it
	Agents map[string]Agent `json:"-"`
	// TODO: find a better way to do this this should probably not be part of the config
	resolver VariableResolver
	// dataConfigDir is the path of the JSON file that SetConfigField
	// reads and writes.
	dataConfigDir string `json:"-"`
	// knownProviders is the provider catalog used to bootstrap configs
	// for providers the user has not configured yet.
	knownProviders []catwalk.Provider `json:"-"`
}
266
// WorkingDir returns the directory crush operates in.
func (c *Config) WorkingDir() string {
	return c.workingDir
}
270
271func (c *Config) EnabledProviders() []ProviderConfig {
272 var enabled []ProviderConfig
273 for p := range c.Providers.Seq() {
274 if !p.Disable {
275 enabled = append(enabled, p)
276 }
277 }
278 return enabled
279}
280
// IsConfigured reports whether at least one provider is enabled, i.e.
// whether crush has enough configuration to run.
func (c *Config) IsConfigured() bool {
	return len(c.EnabledProviders()) > 0
}
285
286func (c *Config) GetModel(provider, model string) *catwalk.Model {
287 if providerConfig, ok := c.Providers.Get(provider); ok {
288 for _, m := range providerConfig.Models {
289 if m.ID == model {
290 return &m
291 }
292 }
293 }
294 return nil
295}
296
297func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
298 model, ok := c.Models[modelType]
299 if !ok {
300 return nil
301 }
302 if providerConfig, ok := c.Providers.Get(model.Provider); ok {
303 return &providerConfig
304 }
305 return nil
306}
307
308func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
309 model, ok := c.Models[modelType]
310 if !ok {
311 return nil
312 }
313 return c.GetModel(model.Provider, model.Model)
314}
315
316func (c *Config) LargeModel() *catwalk.Model {
317 model, ok := c.Models[SelectedModelTypeLarge]
318 if !ok {
319 return nil
320 }
321 return c.GetModel(model.Provider, model.Model)
322}
323
324func (c *Config) SmallModel() *catwalk.Model {
325 model, ok := c.Models[SelectedModelTypeSmall]
326 if !ok {
327 return nil
328 }
329 return c.GetModel(model.Provider, model.Model)
330}
331
332func (c *Config) SetCompactMode(enabled bool) error {
333 if c.Options == nil {
334 c.Options = &Options{}
335 }
336 c.Options.TUI.CompactMode = enabled
337 return c.SetConfigField("options.tui.compact_mode", enabled)
338}
339
340func (c *Config) Resolve(key string) (string, error) {
341 if c.resolver == nil {
342 return "", fmt.Errorf("no variable resolver configured")
343 }
344 return c.resolver.ResolveValue(key)
345}
346
347func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
348 c.Models[modelType] = model
349 if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
350 return fmt.Errorf("failed to update preferred model: %w", err)
351 }
352 return nil
353}
354
355func (c *Config) SetConfigField(key string, value any) error {
356 // read the data
357 data, err := os.ReadFile(c.dataConfigDir)
358 if err != nil {
359 if os.IsNotExist(err) {
360 data = []byte("{}")
361 } else {
362 return fmt.Errorf("failed to read config file: %w", err)
363 }
364 }
365
366 newValue, err := sjson.Set(string(data), key, value)
367 if err != nil {
368 return fmt.Errorf("failed to set config field %s: %w", key, err)
369 }
370 if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
371 return fmt.Errorf("failed to write config file: %w", err)
372 }
373 return nil
374}
375
376func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
377 // First save to the config file
378 err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
379 if err != nil {
380 return fmt.Errorf("failed to save API key to config file: %w", err)
381 }
382
383 providerConfig, exists := c.Providers.Get(providerID)
384 if exists {
385 providerConfig.APIKey = apiKey
386 c.Providers.Set(providerID, providerConfig)
387 return nil
388 }
389
390 var foundProvider *catwalk.Provider
391 for _, p := range c.knownProviders {
392 if string(p.ID) == providerID {
393 foundProvider = &p
394 break
395 }
396 }
397
398 if foundProvider != nil {
399 // Create new provider config based on known provider
400 providerConfig = ProviderConfig{
401 ID: providerID,
402 Name: foundProvider.Name,
403 BaseURL: foundProvider.APIEndpoint,
404 Type: foundProvider.Type,
405 APIKey: apiKey,
406 Disable: false,
407 ExtraHeaders: make(map[string]string),
408 ExtraParams: make(map[string]string),
409 Models: foundProvider.Models,
410 }
411 } else {
412 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
413 }
414 // Store the updated provider config
415 c.Providers.Set(providerID, providerConfig)
416 return nil
417}
418
419func (c *Config) SetupAgents() {
420 agents := map[string]Agent{
421 "coder": {
422 ID: "coder",
423 Name: "Coder",
424 Description: "An agent that helps with executing coding tasks.",
425 Model: SelectedModelTypeLarge,
426 ContextPaths: c.Options.ContextPaths,
427 // All tools allowed
428 },
429 "task": {
430 ID: "task",
431 Name: "Task",
432 Description: "An agent that helps with searching for context and finding implementation details.",
433 Model: SelectedModelTypeLarge,
434 ContextPaths: c.Options.ContextPaths,
435 AllowedTools: []string{
436 "glob",
437 "grep",
438 "ls",
439 "sourcegraph",
440 "view",
441 },
442 // NO MCPs or LSPs by default
443 AllowedMCP: map[string][]string{},
444 AllowedLSP: []string{},
445 },
446 }
447 c.Agents = agents
448}
449
// Resolver returns the configured variable resolver; may be nil when
// none has been set up.
func (c *Config) Resolver() VariableResolver {
	return c.resolver
}
453
454func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
455 testURL := ""
456 headers := make(map[string]string)
457 apiKey, _ := resolver.ResolveValue(c.APIKey)
458 switch c.Type {
459 case catwalk.TypeOpenAI:
460 baseURL, _ := resolver.ResolveValue(c.BaseURL)
461 if baseURL == "" {
462 baseURL = "https://api.openai.com/v1"
463 }
464 testURL = baseURL + "/models"
465 headers["Authorization"] = "Bearer " + apiKey
466 case catwalk.TypeAnthropic:
467 baseURL, _ := resolver.ResolveValue(c.BaseURL)
468 if baseURL == "" {
469 baseURL = "https://api.anthropic.com/v1"
470 }
471 testURL = baseURL + "/models"
472 headers["x-api-key"] = apiKey
473 headers["anthropic-version"] = "2023-06-01"
474 }
475 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
476 defer cancel()
477 client := &http.Client{}
478 req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
479 if err != nil {
480 return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
481 }
482 for k, v := range headers {
483 req.Header.Set(k, v)
484 }
485 for k, v := range c.ExtraHeaders {
486 req.Header.Set(k, v)
487 }
488 b, err := client.Do(req)
489 if err != nil {
490 return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
491 }
492 if b.StatusCode != http.StatusOK {
493 return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
494 }
495 _ = b.Body.Close()
496 return nil
497}