1package config
2
import (
	"fmt"
	"os"
	"path/filepath"
	"slices"
	"strings"

	"github.com/charmbracelet/crush/internal/env"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/tidwall/sjson"
	"golang.org/x/exp/slog"
)
14
// Package-level defaults.
const (
	// appName is the application's name.
	appName = "crush"
	// defaultDataDirectory is the default data directory name.
	// NOTE(review): presumably the fallback for Options.DataDirectory; the
	// code applying it is not in this part of the file — confirm.
	defaultDataDirectory = ".crush"
	// defaultLogLevel is the default log level name.
	defaultLogLevel = "info"
)
20
// defaultContextPaths lists file and directory paths of well-known assistant
// instruction files (Copilot, Cursor, Claude, Gemini and several spellings of
// crush.md).
// NOTE(review): presumably the default for Options.ContextPaths; the code
// applying it is not in this part of the file — confirm.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
36
// SelectedModelType names a model slot; it is the key type of Config.Models.
type SelectedModelType string

// The model slots currently defined.
const (
	// SelectedModelTypeLarge is the "large" model slot.
	SelectedModelTypeLarge SelectedModelType = "large"
	// SelectedModelTypeSmall is the "small" model slot.
	SelectedModelTypeSmall SelectedModelType = "small"
)
43
// SelectedModel describes the model chosen for a model slot, together with
// per-selection overrides. It is the value type of Config.Models.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default model configuration.
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty"`
}
61
// ProviderConfig holds the settings of a single model provider. It is the
// value type of Config.Providers.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty"`
	// The provider's API endpoint.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
	Type provider.Type `json:"type,omitempty"`
	// The provider's API key.
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled.
	Disable bool `json:"disable,omitempty"`

	// Extra headers to send with each request to the provider.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`

	// Used to pass extra parameters to the provider.
	// Excluded from JSON (tag "-"), so it is never read from or written to
	// the config file.
	ExtraParams map[string]string `json:"-"`

	// The provider models
	Models []provider.Model `json:"models,omitempty"`
}
85
// MCPType names the kind of an MCP entry; it is the type of MCPConfig.Type.
type MCPType string

// The MCP types currently defined.
const (
	// MCPStdio is the "stdio" type.
	MCPStdio MCPType = "stdio"
	// MCPSse is the "sse" type.
	MCPSse MCPType = "sse"
	// MCPHttp is the "http" type.
	MCPHttp MCPType = "http"
)
93
// MCPConfig holds the settings of a single MCP entry. It is the value type
// of MCPs.
// NOTE(review): presumably Command/Args/Env apply to the stdio type and
// URL/Headers to the sse/http types — confirm against the code that starts
// MCP clients.
type MCPConfig struct {
	// The command to run.
	Command string `json:"command,omitempty" `
	// Environment variables; values are passed through a shell variable
	// resolver by ResolvedEnv.
	Env map[string]string `json:"env,omitempty"`
	// Arguments for the command.
	Args []string `json:"args,omitempty"`
	// The MCP type (see the MCPType constants). Always serialized, since the
	// tag has no omitempty.
	Type MCPType `json:"type"`
	// The MCP URL.
	URL string `json:"url,omitempty"`
	// Marks the MCP as disabled.
	Disabled bool `json:"disabled,omitempty"`

	// Headers; values are passed through a shell variable resolver by
	// ResolvedHeaders.
	// TODO: maybe make it possible to get the value from the env
	Headers map[string]string `json:"headers,omitempty"`
}
105
// LSPConfig holds the settings of a single LSP entry. It is the value type
// of LSPs.
type LSPConfig struct {
	// Marks the LSP as disabled.
	// NOTE(review): this field is (de)serialized under the JSON key
	// "enabled", so `"enabled": true` in a config file sets Disabled to true.
	// That looks inverted; changing the tag would alter the config file
	// format, so confirm the intended key before fixing.
	Disabled bool `json:"enabled,omitempty"`
	// The command to run. Always serialized (no omitempty).
	Command string `json:"command"`
	// Arguments for the command.
	Args []string `json:"args,omitempty"`
	// Free-form options; any JSON value is accepted.
	Options any `json:"options,omitempty"`
}
112
// TUIOptions holds the TUI-related options. It is referenced through the
// Options.TUI pointer.
type TUIOptions struct {
	// Whether compact mode is on; written by Config.SetCompactMode.
	CompactMode bool `json:"compact_mode,omitempty"`
	// Here we can add themes later or any TUI related options
}
117
// Options holds general application options. It is referenced through the
// Config.Options pointer.
type Options struct {
	// Context paths; copied into each agent by Config.SetupAgents.
	ContextPaths []string `json:"context_paths,omitempty"`
	// TUI options. This is a pointer and may be nil.
	TUI *TUIOptions `json:"tui,omitempty"`
	// Debug flag.
	Debug bool `json:"debug,omitempty"`
	// LSP debug flag.
	DebugLSP bool `json:"debug_lsp,omitempty"`
	// Disables automatic summarization.
	DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
	// Relative to the cwd
	DataDirectory string `json:"data_directory,omitempty"`
	// Automatically accept all permissions (YOLO mode).
	// Excluded from JSON (tag "-"), so it cannot be set via the config file.
	SkipPermissionsRequests bool `json:"-"`
}
127
// MCPs maps an MCP name to its configuration.
type MCPs map[string]MCPConfig

// MCP pairs an MCP configuration with the name it is registered under. It is
// the element type returned by MCPs.Sorted.
type MCP struct {
	// The key of the entry in the MCPs map.
	Name string `json:"name"`
	// The entry's configuration.
	MCP MCPConfig `json:"mcp"`
}
134
135func (m MCPs) Sorted() []MCP {
136 sorted := make([]MCP, 0, len(m))
137 for k, v := range m {
138 sorted = append(sorted, MCP{
139 Name: k,
140 MCP: v,
141 })
142 }
143 slices.SortFunc(sorted, func(a, b MCP) int {
144 return strings.Compare(a.Name, b.Name)
145 })
146 return sorted
147}
148
// LSPs maps an LSP name to its configuration.
type LSPs map[string]LSPConfig

// LSP pairs an LSP configuration with the name it is registered under. It is
// the element type returned by LSPs.Sorted.
type LSP struct {
	// The key of the entry in the LSPs map.
	Name string `json:"name"`
	// The entry's configuration.
	LSP LSPConfig `json:"lsp"`
}
155
156func (l LSPs) Sorted() []LSP {
157 sorted := make([]LSP, 0, len(l))
158 for k, v := range l {
159 sorted = append(sorted, LSP{
160 Name: k,
161 LSP: v,
162 })
163 }
164 slices.SortFunc(sorted, func(a, b LSP) int {
165 return strings.Compare(a.Name, b.Name)
166 })
167 return sorted
168}
169
170func (m MCPConfig) ResolvedEnv() []string {
171 resolver := NewShellVariableResolver(env.New())
172 for e, v := range m.Env {
173 var err error
174 m.Env[e], err = resolver.ResolveValue(v)
175 if err != nil {
176 slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
177 continue
178 }
179 }
180
181 env := make([]string, 0, len(m.Env))
182 for k, v := range m.Env {
183 env = append(env, fmt.Sprintf("%s=%s", k, v))
184 }
185 return env
186}
187
188func (m MCPConfig) ResolvedHeaders() map[string]string {
189 resolver := NewShellVariableResolver(env.New())
190 for e, v := range m.Headers {
191 var err error
192 m.Headers[e], err = resolver.ResolveValue(v)
193 if err != nil {
194 slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
195 continue
196 }
197 }
198 return m.Headers
199}
200
// Agent describes an agent: which model slot it uses and which tools, MCPs,
// LSPs and context paths are available to it. Instances are built by
// Config.SetupAgents.
type Agent struct {
	// The agent's id.
	ID string `json:"id,omitempty"`
	// The agent's name.
	Name string `json:"name,omitempty"`
	// A description of what the agent does.
	Description string `json:"description,omitempty"`
	// Marks the agent as disabled.
	// NOTE(review): this field used to carry a comment about "the id of the
	// system prompt used by the agent", which does not describe it and
	// matches no field in this struct.
	Disabled bool `json:"disabled,omitempty"`

	// The model slot (large/small) the agent uses. Always serialized.
	Model SelectedModelType `json:"model"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// this tells us which MCPs are available for this agent
	// if this is empty all mcps are available
	// the string array is the list of tools from the AllowedMCP the agent has available
	// if the string array is nil, all tools from the AllowedMCP are available
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
227
// Config holds the configuration for crush.
//
// The exported fields with JSON names make up the config file format. The
// remaining fields are runtime state and are never serialized.
type Config struct {
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`

	// The providers that are configured
	Providers map[string]ProviderConfig `json:"providers,omitempty"`

	// The configured MCPs, keyed by name.
	MCP MCPs `json:"mcp,omitempty"`

	// The configured LSPs, keyed by name.
	LSP LSPs `json:"lsp,omitempty"`

	// General options. This is a pointer and may be nil.
	Options *Options `json:"options,omitempty"`

	// Internal
	// The working directory, exposed through WorkingDir.
	workingDir string `json:"-"`
	// The agents, keyed by id; populated by SetupAgents.
	// TODO: most likely remove this concept when I come back to it
	Agents map[string]Agent `json:"-"`
	// The variable resolver used by Resolve; may be nil.
	// TODO: find a better way to do this this should probably not be part of the config
	resolver VariableResolver
	// Despite the name, this is used as the path of the config FILE that
	// SetConfigField reads and rewrites.
	dataConfigDir string `json:"-"`
	// Providers that SetProviderAPIKey can build a ProviderConfig from when
	// the id is not yet in Providers.
	knownProviders []provider.Provider `json:"-"`
}
251
252func (c *Config) WorkingDir() string {
253 return c.workingDir
254}
255
256func (c *Config) EnabledProviders() []ProviderConfig {
257 enabled := make([]ProviderConfig, 0, len(c.Providers))
258 for _, p := range c.Providers {
259 if !p.Disable {
260 enabled = append(enabled, p)
261 }
262 }
263 return enabled
264}
265
266// IsConfigured return true if at least one provider is configured
267func (c *Config) IsConfigured() bool {
268 return len(c.EnabledProviders()) > 0
269}
270
271func (c *Config) GetModel(provider, model string) *provider.Model {
272 if providerConfig, ok := c.Providers[provider]; ok {
273 for _, m := range providerConfig.Models {
274 if m.ID == model {
275 return &m
276 }
277 }
278 }
279 return nil
280}
281
282func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
283 model, ok := c.Models[modelType]
284 if !ok {
285 return nil
286 }
287 if providerConfig, ok := c.Providers[model.Provider]; ok {
288 return &providerConfig
289 }
290 return nil
291}
292
293func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
294 model, ok := c.Models[modelType]
295 if !ok {
296 return nil
297 }
298 return c.GetModel(model.Provider, model.Model)
299}
300
301func (c *Config) LargeModel() *provider.Model {
302 model, ok := c.Models[SelectedModelTypeLarge]
303 if !ok {
304 return nil
305 }
306 return c.GetModel(model.Provider, model.Model)
307}
308
309func (c *Config) SmallModel() *provider.Model {
310 model, ok := c.Models[SelectedModelTypeSmall]
311 if !ok {
312 return nil
313 }
314 return c.GetModel(model.Provider, model.Model)
315}
316
317func (c *Config) SetCompactMode(enabled bool) error {
318 if c.Options == nil {
319 c.Options = &Options{}
320 }
321 c.Options.TUI.CompactMode = enabled
322 return c.SetConfigField("options.tui.compact_mode", enabled)
323}
324
325func (c *Config) Resolve(key string) (string, error) {
326 if c.resolver == nil {
327 return "", fmt.Errorf("no variable resolver configured")
328 }
329 return c.resolver.ResolveValue(key)
330}
331
332func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
333 c.Models[modelType] = model
334 if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
335 return fmt.Errorf("failed to update preferred model: %w", err)
336 }
337 return nil
338}
339
340func (c *Config) SetConfigField(key string, value any) error {
341 // read the data
342 data, err := os.ReadFile(c.dataConfigDir)
343 if err != nil {
344 if os.IsNotExist(err) {
345 data = []byte("{}")
346 } else {
347 return fmt.Errorf("failed to read config file: %w", err)
348 }
349 }
350
351 newValue, err := sjson.Set(string(data), key, value)
352 if err != nil {
353 return fmt.Errorf("failed to set config field %s: %w", key, err)
354 }
355 if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
356 return fmt.Errorf("failed to write config file: %w", err)
357 }
358 return nil
359}
360
361func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
362 // First save to the config file
363 err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
364 if err != nil {
365 return fmt.Errorf("failed to save API key to config file: %w", err)
366 }
367
368 if c.Providers == nil {
369 c.Providers = make(map[string]ProviderConfig)
370 }
371
372 providerConfig, exists := c.Providers[providerID]
373 if exists {
374 providerConfig.APIKey = apiKey
375 c.Providers[providerID] = providerConfig
376 return nil
377 }
378
379 var foundProvider *provider.Provider
380 for _, p := range c.knownProviders {
381 if string(p.ID) == providerID {
382 foundProvider = &p
383 break
384 }
385 }
386
387 if foundProvider != nil {
388 // Create new provider config based on known provider
389 providerConfig = ProviderConfig{
390 ID: providerID,
391 Name: foundProvider.Name,
392 BaseURL: foundProvider.APIEndpoint,
393 Type: foundProvider.Type,
394 APIKey: apiKey,
395 Disable: false,
396 ExtraHeaders: make(map[string]string),
397 ExtraParams: make(map[string]string),
398 Models: foundProvider.Models,
399 }
400 } else {
401 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
402 }
403 // Store the updated provider config
404 c.Providers[providerID] = providerConfig
405 return nil
406}
407
408func (c *Config) SetupAgents() {
409 agents := map[string]Agent{
410 "coder": {
411 ID: "coder",
412 Name: "Coder",
413 Description: "An agent that helps with executing coding tasks.",
414 Model: SelectedModelTypeLarge,
415 ContextPaths: c.Options.ContextPaths,
416 // All tools allowed
417 },
418 "task": {
419 ID: "task",
420 Name: "Task",
421 Description: "An agent that helps with searching for context and finding implementation details.",
422 Model: SelectedModelTypeLarge,
423 ContextPaths: c.Options.ContextPaths,
424 AllowedTools: []string{
425 "glob",
426 "grep",
427 "ls",
428 "sourcegraph",
429 "view",
430 },
431 // NO MCPs or LSPs by default
432 AllowedMCP: map[string][]string{},
433 AllowedLSP: []string{},
434 },
435 }
436 c.Agents = agents
437}