1package config
2
3import (
4 "fmt"
5 "os"
6 "slices"
7 "strings"
8
9 "github.com/charmbracelet/crush/internal/fur/provider"
10 "github.com/tidwall/sjson"
11)
12
// appName is the application name.
const appName = "crush"

// defaultDataDirectory is the fallback data directory name.
const defaultDataDirectory = ".crush"

// defaultLogLevel is the fallback log level name.
const defaultLogLevel = "info"
18
// defaultContextPaths is the default, ordered list of context files and
// directories (entries ending in "/" are directories).
var defaultContextPaths = []string{
	// Other assistants' / editors' instruction files.
	".github/copilot-instructions.md", ".cursorrules", ".cursor/rules/",
	// Claude and Gemini instruction files.
	"CLAUDE.md", "CLAUDE.local.md",
	"GEMINI.md", "gemini.md",
	// Crush instruction files in each accepted capitalisation.
	"crush.md", "crush.local.md",
	"Crush.md", "Crush.local.md",
	"CRUSH.md", "CRUSH.local.md",
}
34
// SelectedModelType names a model slot in Config.Models. Only the "large" and
// "small" slots are currently supported.
type SelectedModelType string

// SelectedModelTypeLarge is the "large" model slot.
const SelectedModelTypeLarge SelectedModelType = "large"

// SelectedModelTypeSmall is the "small" model slot.
const SelectedModelTypeSmall SelectedModelType = "small"
41
// SelectedModel is the model chosen for one model slot (the value type of
// Config.Models, keyed by SelectedModelType) together with optional
// per-selection overrides.
type SelectedModel struct {
	// The model id as used by the provider API; GetModel matches it against
	// the ID of the models listed for the provider.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config
	// (i.e. a key of Config.Providers).
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default model configuration (presumably the model's
	// default max-token limit — confirm at the consumer).
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty"`
}
59
// ProviderConfig is the configuration of a single model provider, stored as a
// value of Config.Providers.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty"`
	// The provider's API endpoint.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
	Type provider.Type `json:"type,omitempty"`
	// The provider's API key (set and persisted by SetProviderAPIKey).
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled; EnabledProviders skips disabled providers.
	Disable bool `json:"disable,omitempty"`

	// Extra headers to send with each request to the provider.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`

	// Used to pass extra parameters to the provider. Tagged json:"-", so
	// encoding/json never reads it from or writes it to the config file; it
	// has to be populated in code.
	ExtraParams map[string]string `json:"-"`

	// The provider models; GetModel searches this list by model ID.
	Models []provider.Model `json:"models,omitempty"`
}
83
// MCPType is the transport used to reach an MCP server.
type MCPType string

const (
	// MCPStdio is the stdio transport (presumably used together with
	// MCPConfig.Command/Args/Env).
	MCPStdio MCPType = "stdio"
	// MCPSse is the server-sent-events transport (presumably used with
	// MCPConfig.URL).
	MCPSse MCPType = "sse"
	// MCPHttp is the HTTP transport (presumably used with MCPConfig.URL).
	MCPHttp MCPType = "http"
)

// MCPConfig describes how to reach a single MCP server. It is the value type
// of the MCPs map, keyed by server name.
type MCPConfig struct {
	// Command to execute for the server.
	// The struct tag previously had a stray trailing space inside the
	// backticks; it is now a clean key:"value" tag.
	Command string `json:"command,omitempty"`
	// Environment entries for Command (presumably "KEY=VALUE" — confirm).
	Env []string `json:"env,omitempty"`
	// Arguments passed to Command.
	Args []string `json:"args,omitempty"`
	// Transport type; always serialized (no omitempty).
	Type MCPType `json:"type"`
	// Server URL.
	URL string `json:"url,omitempty"`
	// Disabled marks the server as disabled.
	Disabled bool `json:"disabled,omitempty"`

	// Headers for the server connection (presumably only meaningful for the
	// URL-based transports).
	// TODO: maybe make it possible to get the value from the env
	Headers map[string]string `json:"headers,omitempty"`
}
103
// LSPConfig describes a language server. It is the value type of the LSPs
// map, keyed by server name.
type LSPConfig struct {
	// Disabled marks this language server as disabled. It is read from the
	// "disabled" key, consistent with MCPConfig.Disabled. Tagging this field
	// "enabled" would invert its meaning: `"enabled": true` would disable
	// the server.
	Disabled bool `json:"disabled,omitempty"`
	// Command is the executable that runs the language server.
	Command string `json:"command"`
	// Args are passed to Command.
	Args []string `json:"args,omitempty"`
	// Options is an arbitrary JSON value (presumably forwarded to the server
	// as initialization options — confirm at the call site).
	Options any `json:"options,omitempty"`
}
110
// TUIOptions groups settings for the terminal user interface; it is stored
// under "options.tui" in the config file.
type TUIOptions struct {
	// CompactMode toggles the TUI's compact mode; SetCompactMode updates it
	// and persists it as "options.tui.compact_mode".
	CompactMode bool `json:"compact_mode,omitempty"`
	// Here we can add themes later or any TUI related options
}
115
// Options holds general settings, stored under "options" in the config file.
type Options struct {
	// Context file/directory paths; SetupAgents hands this list to every agent.
	ContextPaths []string `json:"context_paths,omitempty"`
	// TUI settings. This is a pointer and may be nil, so it must be allocated
	// before its fields are written.
	TUI *TUIOptions `json:"tui,omitempty"`
	// Debug flag (its effect is defined by consumers outside this file).
	Debug bool `json:"debug,omitempty"`
	// LSP debug flag (its effect is defined by consumers outside this file).
	DebugLSP bool `json:"debug_lsp,omitempty"`
	// Turns off automatic summarization.
	DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
	DataDirectory string `json:"data_directory,omitempty"` // Relative to the cwd
	// Tagged json:"-", so it is never loaded from or saved to the config file;
	// presumably set from a command-line flag — confirm.
	SkipPermissionsRequests bool `json:"-"` // Automatically accept all permissions (YOLO mode)
}
125
126type MCPs map[string]MCPConfig
127
128type MCP struct {
129 Name string `json:"name"`
130 MCP MCPConfig `json:"mcp"`
131}
132
133func (m MCPs) Sorted() []MCP {
134 sorted := make([]MCP, 0, len(m))
135 for k, v := range m {
136 sorted = append(sorted, MCP{
137 Name: k,
138 MCP: v,
139 })
140 }
141 slices.SortFunc(sorted, func(a, b MCP) int {
142 return strings.Compare(a.Name, b.Name)
143 })
144 return sorted
145}
146
147type LSPs map[string]LSPConfig
148
149type LSP struct {
150 Name string `json:"name"`
151 LSP LSPConfig `json:"lsp"`
152}
153
154func (l LSPs) Sorted() []LSP {
155 sorted := make([]LSP, 0, len(l))
156 for k, v := range l {
157 sorted = append(sorted, LSP{
158 Name: k,
159 LSP: v,
160 })
161 }
162 slices.SortFunc(sorted, func(a, b LSP) int {
163 return strings.Compare(a.Name, b.Name)
164 })
165 return sorted
166}
167
// Agent describes an agent profile: which model slot it uses and which tools,
// MCP servers and LSP servers it may access. Agents are not loaded from the
// config file (Config.Agents is tagged json:"-"); SetupAgents builds them.
type Agent struct {
	ID string `json:"id,omitempty"`
	Name string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// Disabled marks the agent as disabled.
	// NOTE(review): a stale comment here referred to "the id of the system
	// prompt used by the agent"; no such field exists on Agent — confirm
	// whether one was intended.
	Disabled bool `json:"disabled,omitempty"`

	// Model is the model slot (large/small) this agent uses.
	Model SelectedModelType `json:"model"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// this tells us which MCPs are available for this agent
	// if this is empty all mcps are available
	// the string array is the list of tools from the AllowedMCP the agent has available
	// if the string array is nil, all tools from the AllowedMCP are available
	// NOTE(review): SetupAgents gives the "task" agent an empty non-nil map
	// with the comment "NO MCPs", which suggests nil means "all" and empty
	// means "none" — confirm against the consumer and fix the line above.
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available
	// (SetupAgents uses an empty non-nil slice to mean "no LSPs".)
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
194
// Config holds the configuration for crush.
type Config struct {
	// Selected models keyed by model slot.
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`

	// The providers that are configured, keyed by provider id (the key that
	// SelectedModel.Provider and SetProviderAPIKey refer to).
	Providers map[string]ProviderConfig `json:"providers,omitempty"`

	// MCP servers keyed by server name.
	MCP MCPs `json:"mcp,omitempty"`

	// Language servers keyed by server name.
	LSP LSPs `json:"lsp,omitempty"`

	// General options; may be nil.
	Options *Options `json:"options,omitempty"`

	// Internal
	// (encoding/json ignores unexported fields whatever their tags say.)
	workingDir string `json:"-"`
	// Built by SetupAgents; never loaded from JSON.
	// TODO: most likely remove this concept when I come back to it
	Agents map[string]Agent `json:"-"`
	// Used by Resolve.
	// TODO: find a better way to do this this should probably not be part of the config
	resolver VariableResolver
	// Path of the JSON file that SetConfigField reads and rewrites. Despite
	// its name it is used as a file path (os.ReadFile/os.WriteFile), not a
	// directory.
	dataConfigDir string `json:"-"`
	// Provider catalog that SetProviderAPIKey uses to build a ProviderConfig
	// for a provider not yet present in Providers.
	knownProviders []provider.Provider `json:"-"`
}
218
219func (c *Config) WorkingDir() string {
220 return c.workingDir
221}
222
223func (c *Config) EnabledProviders() []ProviderConfig {
224 enabled := make([]ProviderConfig, 0, len(c.Providers))
225 for _, p := range c.Providers {
226 if !p.Disable {
227 enabled = append(enabled, p)
228 }
229 }
230 return enabled
231}
232
233// IsConfigured return true if at least one provider is configured
234func (c *Config) IsConfigured() bool {
235 return len(c.EnabledProviders()) > 0
236}
237
238func (c *Config) GetModel(provider, model string) *provider.Model {
239 if providerConfig, ok := c.Providers[provider]; ok {
240 for _, m := range providerConfig.Models {
241 if m.ID == model {
242 return &m
243 }
244 }
245 }
246 return nil
247}
248
249func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
250 model, ok := c.Models[modelType]
251 if !ok {
252 return nil
253 }
254 if providerConfig, ok := c.Providers[model.Provider]; ok {
255 return &providerConfig
256 }
257 return nil
258}
259
260func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
261 model, ok := c.Models[modelType]
262 if !ok {
263 return nil
264 }
265 return c.GetModel(model.Provider, model.Model)
266}
267
268func (c *Config) LargeModel() *provider.Model {
269 model, ok := c.Models[SelectedModelTypeLarge]
270 if !ok {
271 return nil
272 }
273 return c.GetModel(model.Provider, model.Model)
274}
275
276func (c *Config) SmallModel() *provider.Model {
277 model, ok := c.Models[SelectedModelTypeSmall]
278 if !ok {
279 return nil
280 }
281 return c.GetModel(model.Provider, model.Model)
282}
283
284func (c *Config) SetCompactMode(enabled bool) error {
285 if c.Options == nil {
286 c.Options = &Options{}
287 }
288 c.Options.TUI.CompactMode = enabled
289 return c.SetConfigField("options.tui.compact_mode", enabled)
290}
291
292func (c *Config) Resolve(key string) (string, error) {
293 if c.resolver == nil {
294 return "", fmt.Errorf("no variable resolver configured")
295 }
296 return c.resolver.ResolveValue(key)
297}
298
299func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
300 c.Models[modelType] = model
301 if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
302 return fmt.Errorf("failed to update preferred model: %w", err)
303 }
304 return nil
305}
306
307func (c *Config) SetConfigField(key string, value any) error {
308 // read the data
309 data, err := os.ReadFile(c.dataConfigDir)
310 if err != nil {
311 if os.IsNotExist(err) {
312 data = []byte("{}")
313 } else {
314 return fmt.Errorf("failed to read config file: %w", err)
315 }
316 }
317
318 newValue, err := sjson.Set(string(data), key, value)
319 if err != nil {
320 return fmt.Errorf("failed to set config field %s: %w", key, err)
321 }
322 if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
323 return fmt.Errorf("failed to write config file: %w", err)
324 }
325 return nil
326}
327
328func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
329 // First save to the config file
330 err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
331 if err != nil {
332 return fmt.Errorf("failed to save API key to config file: %w", err)
333 }
334
335 if c.Providers == nil {
336 c.Providers = make(map[string]ProviderConfig)
337 }
338
339 providerConfig, exists := c.Providers[providerID]
340 if exists {
341 providerConfig.APIKey = apiKey
342 c.Providers[providerID] = providerConfig
343 return nil
344 }
345
346 var foundProvider *provider.Provider
347 for _, p := range c.knownProviders {
348 if string(p.ID) == providerID {
349 foundProvider = &p
350 break
351 }
352 }
353
354 if foundProvider != nil {
355 // Create new provider config based on known provider
356 providerConfig = ProviderConfig{
357 ID: providerID,
358 Name: foundProvider.Name,
359 BaseURL: foundProvider.APIEndpoint,
360 Type: foundProvider.Type,
361 APIKey: apiKey,
362 Disable: false,
363 ExtraHeaders: make(map[string]string),
364 ExtraParams: make(map[string]string),
365 Models: foundProvider.Models,
366 }
367 } else {
368 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
369 }
370 // Store the updated provider config
371 c.Providers[providerID] = providerConfig
372 return nil
373}
374
375func (c *Config) SetupAgents() {
376 agents := map[string]Agent{
377 "coder": {
378 ID: "coder",
379 Name: "Coder",
380 Description: "An agent that helps with executing coding tasks.",
381 Model: SelectedModelTypeLarge,
382 ContextPaths: c.Options.ContextPaths,
383 // All tools allowed
384 },
385 "task": {
386 ID: "task",
387 Name: "Task",
388 Description: "An agent that helps with searching for context and finding implementation details.",
389 Model: SelectedModelTypeLarge,
390 ContextPaths: c.Options.ContextPaths,
391 AllowedTools: []string{
392 "glob",
393 "grep",
394 "ls",
395 "sourcegraph",
396 "view",
397 },
398 // NO MCPs or LSPs by default
399 AllowedMCP: map[string][]string{},
400 AllowedLSP: []string{},
401 },
402 }
403 c.Agents = agents
404}