1package config
2
import (
	"fmt"
	"os"
	"path/filepath"
	"slices"
	"strings"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/agent"
	"github.com/charmbracelet/crush/internal/llm/provider"
	"github.com/charmbracelet/crush/internal/resolver"
	"github.com/tidwall/sjson"
)
16
// appName is the application's name.
const appName = "crush"

// Fallback values for settings; where they are applied is outside this file
// chunk.
const (
	defaultDataDirectory = ".crush"
	defaultLogLevel      = "info"
)
22
// defaultContextPaths is the built-in list of instruction/context file names
// that appear to be conventional for various AI coding assistants (Copilot,
// Cursor, Claude, Gemini) plus crush's own variants.
//
// Several case variants of the same name are listed explicitly, presumably
// so lookups work on case-sensitive file systems — confirm against the
// loader. The entry ending in "/" appears to name a directory rather than a
// file. Keep the order as-is unless the consumer is known not to depend on it.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
38
// SelectedModelType names a model slot in Config.Models. Only the "large"
// and "small" slots are defined.
type SelectedModelType string

// The supported model slots.
const (
	SelectedModelTypeLarge = SelectedModelType("large")
	SelectedModelTypeSmall = SelectedModelType("small")
)
45
// LSPConfig describes how to launch a single language server.
type LSPConfig struct {
	// Disabled turns this language server off. It is stored under the
	// "disabled" key so the JSON name matches the field's meaning: a true
	// value disables the server.
	Disabled bool `json:"disabled,omitempty"`
	// Command is the executable to run.
	Command string `json:"command"`
	// Args are passed to Command.
	Args []string `json:"args,omitempty"`
	// Options is free-form, server-specific configuration. Its
	// interpretation is left to the consumer (not visible here).
	Options any `json:"options,omitempty"`
}
52
// TUIOptions holds settings for the terminal user interface.
type TUIOptions struct {
	// CompactMode is persisted under the options.tui.compact_mode key.
	CompactMode bool `json:"compact_mode,omitempty"`
	// Further TUI-related settings (for example, themes) can be added here.
}
57
// Permissions controls which tool invocations require a permission prompt.
type Permissions struct {
	AllowedTools []string `json:"allowed_tools,omitempty"` // Tools that don't require permission prompts
	// SkipRequests is tagged json:"-", so it is never read from or written
	// to the config file; it can only be set in code.
	SkipRequests bool `json:"-"` // Automatically accept all permissions (YOLO mode)
}
62
// Options holds general application settings stored under the "options" key.
type Options struct {
	// ContextPaths lists extra context files/directories. Presumably it is
	// combined with defaultContextPaths by the loader — confirm.
	ContextPaths []string `json:"context_paths,omitempty"`
	// TUI holds terminal-UI settings and may be nil.
	TUI *TUIOptions `json:"tui,omitempty"`
	// Debug and DebugLSP are debug toggles (general and LSP-specific).
	Debug    bool `json:"debug,omitempty"`
	DebugLSP bool `json:"debug_lsp,omitempty"`
	// DisableAutoSummarize opts out of automatic summarization.
	DisableAutoSummarize bool   `json:"disable_auto_summarize,omitempty"`
	DataDirectory        string `json:"data_directory,omitempty"` // Relative to the cwd
}
71
72type MCPs map[string]agent.MCPConfig
73
74type MCP struct {
75 Name string `json:"name"`
76 MCP agent.MCPConfig `json:"mcp"`
77}
78
79func (m MCPs) Sorted() []MCP {
80 sorted := make([]MCP, 0, len(m))
81 for k, v := range m {
82 sorted = append(sorted, MCP{
83 Name: k,
84 MCP: v,
85 })
86 }
87 slices.SortFunc(sorted, func(a, b MCP) int {
88 return strings.Compare(a.Name, b.Name)
89 })
90 return sorted
91}
92
93type LSPs map[string]LSPConfig
94
95type LSP struct {
96 Name string `json:"name"`
97 LSP LSPConfig `json:"lsp"`
98}
99
100func (l LSPs) Sorted() []LSP {
101 sorted := make([]LSP, 0, len(l))
102 for k, v := range l {
103 sorted = append(sorted, LSP{
104 Name: k,
105 LSP: v,
106 })
107 }
108 slices.SortFunc(sorted, func(a, b LSP) int {
109 return strings.Compare(a.Name, b.Name)
110 })
111 return sorted
112}
113
// Config holds the configuration for crush.
//
// Exported fields are serialized through their json tags. The unexported
// fields below "Internal" are runtime-only state; encoding/json ignores
// unexported fields, so they are never marshaled.
type Config struct {
	// Models maps a model slot to the model chosen for it.
	// We currently only support large/small as values here.
	Models map[SelectedModelType]agent.Model `json:"models,omitempty"`

	// The providers that are configured, keyed by provider ID.
	Providers *csync.Map[string, provider.Config] `json:"providers,omitempty"`

	// MCP holds the MCP server configurations, keyed by server name.
	MCP MCPs `json:"mcp,omitempty"`

	// LSP holds the language-server configurations, keyed by server name.
	LSP LSPs `json:"lsp,omitempty"`

	// Options holds general settings; it may be nil.
	Options *Options `json:"options,omitempty"`

	// Permissions holds tool-permission settings.
	Permissions *Permissions `json:"permissions,omitempty"`

	// Internal
	// workingDir is the value returned by WorkingDir.
	workingDir string `json:"-"`
	// TODO: find a better way to do this; this should probably not be part of the config.
	// resolver backs Resolve; when it is nil, Resolve returns an error.
	resolver resolver.Resolver
	// dataConfigDir is, despite its name, used as the path of the JSON file
	// that SetConfigField reads and rewrites.
	dataConfigDir string `json:"-"`
	// knownProviders is consulted by SetProviderAPIKey for providers that
	// are not yet present in Providers.
	knownProviders []catwalk.Provider `json:"-"`
}
137
// WorkingDir returns the working directory recorded on this Config. It only
// reads the stored field and performs no I/O.
func (c *Config) WorkingDir() string {
	return c.workingDir
}
141
142func (c *Config) EnabledProviders() []provider.Config {
143 var enabled []provider.Config
144 for p := range c.Providers.Seq() {
145 if !p.Disable {
146 enabled = append(enabled, p)
147 }
148 }
149 return enabled
150}
151
152// IsConfigured return true if at least one provider is configured
153func (c *Config) IsConfigured() bool {
154 return len(c.EnabledProviders()) > 0
155}
156
157func (c *Config) GetModel(provider, model string) *catwalk.Model {
158 if providerConfig, ok := c.Providers.Get(provider); ok {
159 for _, m := range providerConfig.Models {
160 if m.ID == model {
161 return &m
162 }
163 }
164 }
165 return nil
166}
167
168func (c *Config) GetProviderForModel(modelType SelectedModelType) *provider.Config {
169 model, ok := c.Models[modelType]
170 if !ok {
171 return nil
172 }
173 if providerConfig, ok := c.Providers.Get(model.Provider); ok {
174 return &providerConfig
175 }
176 return nil
177}
178
179func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
180 model, ok := c.Models[modelType]
181 if !ok {
182 return nil
183 }
184 return c.GetModel(model.Provider, model.Model)
185}
186
187func (c *Config) LargeModel() *catwalk.Model {
188 model, ok := c.Models[SelectedModelTypeLarge]
189 if !ok {
190 return nil
191 }
192 return c.GetModel(model.Provider, model.Model)
193}
194
195func (c *Config) SmallModel() *catwalk.Model {
196 model, ok := c.Models[SelectedModelTypeSmall]
197 if !ok {
198 return nil
199 }
200 return c.GetModel(model.Provider, model.Model)
201}
202
203func (c *Config) SetCompactMode(enabled bool) error {
204 if c.Options == nil {
205 c.Options = &Options{}
206 }
207 c.Options.TUI.CompactMode = enabled
208 return c.SetConfigField("options.tui.compact_mode", enabled)
209}
210
211func (c *Config) Resolve(key string) (string, error) {
212 if c.resolver == nil {
213 return "", fmt.Errorf("no variable resolver configured")
214 }
215 return c.resolver.ResolveValue(key)
216}
217
218func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model agent.Model) error {
219 c.Models[modelType] = model
220 if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
221 return fmt.Errorf("failed to update preferred model: %w", err)
222 }
223 return nil
224}
225
226func (c *Config) SetConfigField(key string, value any) error {
227 // read the data
228 data, err := os.ReadFile(c.dataConfigDir)
229 if err != nil {
230 if os.IsNotExist(err) {
231 data = []byte("{}")
232 } else {
233 return fmt.Errorf("failed to read config file: %w", err)
234 }
235 }
236
237 newValue, err := sjson.Set(string(data), key, value)
238 if err != nil {
239 return fmt.Errorf("failed to set config field %s: %w", key, err)
240 }
241 if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
242 return fmt.Errorf("failed to write config file: %w", err)
243 }
244 return nil
245}
246
247func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
248 // First save to the config file
249 err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
250 if err != nil {
251 return fmt.Errorf("failed to save API key to config file: %w", err)
252 }
253
254 providerConfig, exists := c.Providers.Get(providerID)
255 if exists {
256 providerConfig.APIKey = apiKey
257 c.Providers.Set(providerID, providerConfig)
258 return nil
259 }
260
261 var foundProvider *catwalk.Provider
262 for _, p := range c.knownProviders {
263 if string(p.ID) == providerID {
264 foundProvider = &p
265 break
266 }
267 }
268
269 if foundProvider != nil {
270 // Create new provider config based on known provider
271 providerConfig = provider.Config{
272 ID: providerID,
273 Name: foundProvider.Name,
274 BaseURL: foundProvider.APIEndpoint,
275 Type: foundProvider.Type,
276 APIKey: apiKey,
277 Disable: false,
278 ExtraHeaders: make(map[string]string),
279 ExtraParams: make(map[string]string),
280 Models: foundProvider.Models,
281 }
282 } else {
283 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
284 }
285 // Store the updated provider config
286 c.Providers.Set(providerID, providerConfig)
287 return nil
288}
289
// Resolver returns the variable resolver that backs Resolve. It may be nil,
// in which case Resolve returns an error.
func (c *Config) Resolver() resolver.Resolver {
	return c.resolver
}