1package config
2
3import (
4 "fmt"
5 "os"
6 "slices"
7 "strings"
8
9 "github.com/charmbracelet/crush/internal/env"
10 "github.com/charmbracelet/crush/internal/fur/provider"
11 "github.com/tidwall/sjson"
12 "golang.org/x/exp/slog"
13)
14
// Package-level defaults. None of them is referenced in this file, so the
// notes below describe the values only.
const (
	// appName is the application name.
	appName = "crush"
	// defaultDataDirectory is presumably the fallback for
	// Options.DataDirectory — confirm against the config loader.
	defaultDataDirectory = ".crush"
	// defaultLogLevel is the default log level name.
	defaultLogLevel = "info"
)
20
// defaultContextPaths lists relative paths of context/instruction files. It
// contains files named after other tools (Copilot, Cursor, Claude, Gemini)
// as well as crush's own files in several letter-case variants. An entry
// ending in "/" names a directory.
//
// NOTE(review): presumably the default for Options.ContextPaths; the code
// that applies it is not in this file — confirm.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
36
// SelectedModelType is the key type of Config.Models and names the slot a
// selected model occupies.
type SelectedModelType string

// The model slots currently defined.
const (
	// SelectedModelTypeLarge is the slot SetupAgents assigns to the built-in agents.
	SelectedModelTypeLarge SelectedModelType = "large"
	// SelectedModelTypeSmall is the second slot; see Config.SmallModel.
	SelectedModelTypeSmall SelectedModelType = "small"
)
43
// SelectedModel identifies the model chosen for one SelectedModelType slot,
// together with per-selection overrides. It is the value type of
// Config.Models.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default model configuration.
	// A zero value is omitted from the JSON output.
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty"`
}
61
// ProviderConfig describes a single provider; it is the value type of
// Config.Providers.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty"`
	// The provider's API endpoint.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai", "anthropic", etc. If empty it defaults to openai.
	Type provider.Type `json:"type,omitempty"`
	// The provider's API key.
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled. Disabled providers are left out of
	// Config.EnabledProviders.
	Disable bool `json:"disable,omitempty"`

	// Extra headers to send with each request to the provider.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Extra body fields.
	// NOTE(review): presumably merged into the request body sent to the
	// provider — confirm against the provider client.
	ExtraBody map[string]any `json:"extra_body,omitempty"`

	// Used to pass extra parameters to the provider.
	// Excluded from JSON (tag "-"), so it is only ever set in code.
	ExtraParams map[string]string `json:"-"`

	// The provider models. Config.GetModel searches this list by model id.
	Models []provider.Model `json:"models,omitempty"`
}
87
// MCPType is the value type of MCPConfig.Type.
type MCPType string

// The MCP types defined in this file.
// NOTE(review): judging by the names these select how an MCP server is
// reached (stdio process vs. SSE/HTTP endpoint) — confirm against the MCP
// client code.
const (
	MCPStdio MCPType = "stdio"
	MCPSse   MCPType = "sse"
	MCPHttp  MCPType = "http"
)
95
// MCPConfig is the configuration of a single MCP server; it is the value type
// of the MCPs map.
type MCPConfig struct {
	// Command to run. Presumably only relevant for MCPStdio — confirm.
	Command string `json:"command,omitempty" `
	// Environment variables for the server. Values are run through a shell
	// variable resolver by ResolvedEnv.
	Env map[string]string `json:"env,omitempty"`
	// Arguments passed to Command.
	Args []string `json:"args,omitempty"`
	// One of the MCPType constants. Always present in the JSON output.
	Type MCPType `json:"type"`
	// Server URL. Presumably only relevant for MCPSse and MCPHttp — confirm.
	URL string `json:"url,omitempty"`
	// Marks the server as disabled.
	Disabled bool `json:"disabled,omitempty"`

	// Headers for the server. Values are run through a shell variable
	// resolver by ResolvedHeaders.
	// TODO: maybe make it possible to get the value from the env
	Headers map[string]string `json:"headers,omitempty"`
}
107
// LSPConfig is the configuration of a single language server; it is the value
// type of the LSPs map.
type LSPConfig struct {
	// Marks the language server as disabled.
	//
	// The JSON key is "disabled", in line with MCPConfig and Agent. It used to
	// be "enabled", which inverted the meaning: `"enabled": true` switched the
	// server off. A leftover "enabled" key is now ignored, so the server stays
	// enabled, which is what such a config asked for.
	Disabled bool `json:"disabled,omitempty"`
	// Command that starts the language server. Always present in the JSON output.
	Command string `json:"command"`
	// Arguments passed to Command.
	Args []string `json:"args,omitempty"`
	// Free-form, server-specific options.
	// NOTE(review): presumably forwarded to the language server — confirm.
	Options any `json:"options,omitempty"`
}
114
// TUIOptions holds the TUI-related options, stored under Options.TUI.
type TUIOptions struct {
	// Compact mode flag. Config.SetCompactMode updates it and persists it
	// under "options.tui.compact_mode".
	CompactMode bool `json:"compact_mode,omitempty"`
	// Here we can add themes later or any TUI related options
}
119
// Options holds the general settings stored under Config.Options.
type Options struct {
	// Context file paths. SetupAgents copies this list into the built-in agents.
	ContextPaths []string `json:"context_paths,omitempty"`
	// TUI-related options; a pointer, so nil when not set.
	TUI *TUIOptions `json:"tui,omitempty"`
	// Debug flag. NOTE(review): presumably enables debug logging — confirm.
	Debug bool `json:"debug,omitempty"`
	// Debug flag for LSP. NOTE(review): presumably enables LSP debug output — confirm.
	DebugLSP bool `json:"debug_lsp,omitempty"`
	// Turns off automatic summarization (the consumer is not in this file).
	DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
	// Data directory, relative to the cwd.
	DataDirectory string `json:"data_directory,omitempty"`
	// Automatically accept all permissions (YOLO mode).
	// Excluded from JSON (tag "-").
	SkipPermissionsRequests bool `json:"-"`
}
129
// MCPs maps an MCP server name to its configuration.
type MCPs map[string]MCPConfig

// MCP pairs a server name with its configuration; it is the element type
// returned by MCPs.Sorted.
type MCP struct {
	Name string    `json:"name"`
	MCP  MCPConfig `json:"mcp"`
}
136
137func (m MCPs) Sorted() []MCP {
138 sorted := make([]MCP, 0, len(m))
139 for k, v := range m {
140 sorted = append(sorted, MCP{
141 Name: k,
142 MCP: v,
143 })
144 }
145 slices.SortFunc(sorted, func(a, b MCP) int {
146 return strings.Compare(a.Name, b.Name)
147 })
148 return sorted
149}
150
// LSPs maps a language server name to its configuration.
type LSPs map[string]LSPConfig

// LSP pairs a language server name with its configuration; it is the element
// type returned by LSPs.Sorted.
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}
157
158func (l LSPs) Sorted() []LSP {
159 sorted := make([]LSP, 0, len(l))
160 for k, v := range l {
161 sorted = append(sorted, LSP{
162 Name: k,
163 LSP: v,
164 })
165 }
166 slices.SortFunc(sorted, func(a, b LSP) int {
167 return strings.Compare(a.Name, b.Name)
168 })
169 return sorted
170}
171
172func (m MCPConfig) ResolvedEnv() []string {
173 resolver := NewShellVariableResolver(env.New())
174 for e, v := range m.Env {
175 var err error
176 m.Env[e], err = resolver.ResolveValue(v)
177 if err != nil {
178 slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
179 continue
180 }
181 }
182
183 env := make([]string, 0, len(m.Env))
184 for k, v := range m.Env {
185 env = append(env, fmt.Sprintf("%s=%s", k, v))
186 }
187 return env
188}
189
190func (m MCPConfig) ResolvedHeaders() map[string]string {
191 resolver := NewShellVariableResolver(env.New())
192 for e, v := range m.Headers {
193 var err error
194 m.Headers[e], err = resolver.ResolveValue(v)
195 if err != nil {
196 slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
197 continue
198 }
199 }
200 return m.Headers
201}
202
// Agent describes one agent: which model slot it uses and which tools, MCPs,
// LSPs and context paths it may access. The built-in agents are created by
// Config.SetupAgents.
type Agent struct {
	ID          string `json:"id,omitempty"`
	Name        string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// Marks the agent as disabled.
	Disabled bool `json:"disabled,omitempty"`

	// The model slot (large/small) this agent uses.
	Model SelectedModelType `json:"model"`

	// The available tools for the agent.
	// If this is nil, all tools are available.
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// This tells us which MCPs are available for this agent.
	// If this is empty all MCPs are available.
	// The string array is the list of tools from the AllowedMCP the agent has available.
	// If the string array is nil, all tools from the AllowedMCP are available.
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use.
	// If this is nil, all LSPs are available.
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent.
	ContextPaths []string `json:"context_paths,omitempty"`
}
229
230// Config holds the configuration for crush.
// Config holds the configuration for crush.
type Config struct {
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`

	// The providers that are configured, keyed by provider id.
	Providers map[string]ProviderConfig `json:"providers,omitempty"`

	// MCP servers, keyed by name.
	MCP MCPs `json:"mcp,omitempty"`

	// Language servers, keyed by name.
	LSP LSPs `json:"lsp,omitempty"`

	// General options; a pointer, so nil when not set.
	Options *Options `json:"options,omitempty"`

	// Internal

	// Returned by WorkingDir.
	workingDir string `json:"-"`
	// Populated by SetupAgents; excluded from JSON.
	// TODO: most likely remove this concept when I come back to it
	Agents map[string]Agent `json:"-"`
	// Used by Resolve; may be nil.
	// TODO: find a better way to do this; this should probably not be part of the config
	resolver VariableResolver
	// Path that SetConfigField reads and rewrites. Despite the name it is
	// used as a file path, not a directory.
	dataConfigDir string `json:"-"`
	// Providers that SetProviderAPIKey falls back to when the id is not yet
	// present in Providers.
	knownProviders []provider.Provider `json:"-"`
}
253
// WorkingDir returns the working directory stored in the config.
func (c *Config) WorkingDir() string {
	return c.workingDir
}
257
258func (c *Config) EnabledProviders() []ProviderConfig {
259 enabled := make([]ProviderConfig, 0, len(c.Providers))
260 for _, p := range c.Providers {
261 if !p.Disable {
262 enabled = append(enabled, p)
263 }
264 }
265 return enabled
266}
267
268// IsConfigured return true if at least one provider is configured
269func (c *Config) IsConfigured() bool {
270 return len(c.EnabledProviders()) > 0
271}
272
273func (c *Config) GetModel(provider, model string) *provider.Model {
274 if providerConfig, ok := c.Providers[provider]; ok {
275 for _, m := range providerConfig.Models {
276 if m.ID == model {
277 return &m
278 }
279 }
280 }
281 return nil
282}
283
284func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
285 model, ok := c.Models[modelType]
286 if !ok {
287 return nil
288 }
289 if providerConfig, ok := c.Providers[model.Provider]; ok {
290 return &providerConfig
291 }
292 return nil
293}
294
295func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
296 model, ok := c.Models[modelType]
297 if !ok {
298 return nil
299 }
300 return c.GetModel(model.Provider, model.Model)
301}
302
303func (c *Config) LargeModel() *provider.Model {
304 model, ok := c.Models[SelectedModelTypeLarge]
305 if !ok {
306 return nil
307 }
308 return c.GetModel(model.Provider, model.Model)
309}
310
311func (c *Config) SmallModel() *provider.Model {
312 model, ok := c.Models[SelectedModelTypeSmall]
313 if !ok {
314 return nil
315 }
316 return c.GetModel(model.Provider, model.Model)
317}
318
319func (c *Config) SetCompactMode(enabled bool) error {
320 if c.Options == nil {
321 c.Options = &Options{}
322 }
323 c.Options.TUI.CompactMode = enabled
324 return c.SetConfigField("options.tui.compact_mode", enabled)
325}
326
327func (c *Config) Resolve(key string) (string, error) {
328 if c.resolver == nil {
329 return "", fmt.Errorf("no variable resolver configured")
330 }
331 return c.resolver.ResolveValue(key)
332}
333
334func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
335 c.Models[modelType] = model
336 if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
337 return fmt.Errorf("failed to update preferred model: %w", err)
338 }
339 return nil
340}
341
342func (c *Config) SetConfigField(key string, value any) error {
343 // read the data
344 data, err := os.ReadFile(c.dataConfigDir)
345 if err != nil {
346 if os.IsNotExist(err) {
347 data = []byte("{}")
348 } else {
349 return fmt.Errorf("failed to read config file: %w", err)
350 }
351 }
352
353 newValue, err := sjson.Set(string(data), key, value)
354 if err != nil {
355 return fmt.Errorf("failed to set config field %s: %w", key, err)
356 }
357 if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
358 return fmt.Errorf("failed to write config file: %w", err)
359 }
360 return nil
361}
362
363func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
364 // First save to the config file
365 err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
366 if err != nil {
367 return fmt.Errorf("failed to save API key to config file: %w", err)
368 }
369
370 if c.Providers == nil {
371 c.Providers = make(map[string]ProviderConfig)
372 }
373
374 providerConfig, exists := c.Providers[providerID]
375 if exists {
376 providerConfig.APIKey = apiKey
377 c.Providers[providerID] = providerConfig
378 return nil
379 }
380
381 var foundProvider *provider.Provider
382 for _, p := range c.knownProviders {
383 if string(p.ID) == providerID {
384 foundProvider = &p
385 break
386 }
387 }
388
389 if foundProvider != nil {
390 // Create new provider config based on known provider
391 providerConfig = ProviderConfig{
392 ID: providerID,
393 Name: foundProvider.Name,
394 BaseURL: foundProvider.APIEndpoint,
395 Type: foundProvider.Type,
396 APIKey: apiKey,
397 Disable: false,
398 ExtraHeaders: make(map[string]string),
399 ExtraParams: make(map[string]string),
400 Models: foundProvider.Models,
401 }
402 } else {
403 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
404 }
405 // Store the updated provider config
406 c.Providers[providerID] = providerConfig
407 return nil
408}
409
410func (c *Config) SetupAgents() {
411 agents := map[string]Agent{
412 "coder": {
413 ID: "coder",
414 Name: "Coder",
415 Description: "An agent that helps with executing coding tasks.",
416 Model: SelectedModelTypeLarge,
417 ContextPaths: c.Options.ContextPaths,
418 // All tools allowed
419 },
420 "task": {
421 ID: "task",
422 Name: "Task",
423 Description: "An agent that helps with searching for context and finding implementation details.",
424 Model: SelectedModelTypeLarge,
425 ContextPaths: c.Options.ContextPaths,
426 AllowedTools: []string{
427 "glob",
428 "grep",
429 "ls",
430 "sourcegraph",
431 "view",
432 },
433 // NO MCPs or LSPs by default
434 AllowedMCP: map[string][]string{},
435 AllowedLSP: []string{},
436 },
437 }
438 c.Agents = agents
439}