1package config
2
import (
	"fmt"
	"os"
	"path/filepath"
	"slices"
	"strings"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/agent"
	"github.com/charmbracelet/crush/internal/llm/provider"
	"github.com/charmbracelet/crush/internal/resolver"
	"github.com/tidwall/sjson"
)
16
// Package-wide default values.
const (
	// defaultLogLevel is the default log level name.
	defaultLogLevel = "info"

	// defaultDataDirectory is the default name of the data directory.
	defaultDataDirectory = ".crush"

	// appName is the application name.
	appName = "crush"
)
22
// defaultContextPaths lists the instruction/context files (and one directory,
// ".cursor/rules/", marked by its trailing slash) that are considered by
// default. It covers files used by several AI coding tools as well as Crush's
// own CRUSH.md variants in different capitalizations.
// NOTE(review): presumably the fallback for Options.ContextPaths; consumers
// are not visible here, so confirm whether the order matters (e.g. precedence)
// before reordering entries.
var defaultContextPaths = []string{
	// GitHub Copilot
	".github/copilot-instructions.md",
	// Cursor
	".cursorrules",
	".cursor/rules/",
	// Claude
	"CLAUDE.md",
	"CLAUDE.local.md",
	// Gemini
	"GEMINI.md",
	"gemini.md",
	// Crush
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
38
// SelectedModelType names a model slot in Config.Models. Only the large and
// small slots are currently supported.
type SelectedModelType string

// The supported model slots.
const (
	SelectedModelTypeLarge = SelectedModelType("large")
	SelectedModelTypeSmall = SelectedModelType("small")
)
45
// LSPConfig is the configuration of a single language server entry.
type LSPConfig struct {
	// Disabled turns this language server off. It is serialized under the
	// "disabled" key so the JSON name matches the field's meaning; tagging it
	// "enabled" would make `"enabled": true` disable the server.
	Disabled bool `json:"disabled,omitempty"`

	// Command is the language server command.
	Command string `json:"command"`

	// Args are the arguments for Command.
	Args []string `json:"args,omitempty"`

	// Options holds arbitrary, server-specific settings.
	// NOTE(review): presumably forwarded to the server as-is — confirm.
	Options any `json:"options,omitempty"`
}
52
// TUIOptions holds settings for the terminal user interface.
type TUIOptions struct {
	// CompactMode toggles the TUI's compact mode. Config.SetCompactMode
	// persists it under "options.tui.compact_mode".
	CompactMode bool `json:"compact_mode,omitempty"`

	// Further TUI-related options (e.g. themes) can be added here later.
}
57
// Permissions controls which tool invocations require a permission prompt.
type Permissions struct {
	// AllowedTools lists tools that don't require permission prompts.
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// SkipRequests automatically accepts all permissions ("YOLO mode").
	// Tagged json:"-", so it is never read from or written to the config file
	// and can only be set in code (presumably from a CLI flag — confirm).
	SkipRequests bool `json:"-"`
}
62
// Options holds general settings that are not tied to a provider or model.
type Options struct {
	// ContextPaths lists files/directories to use as context.
	// NOTE(review): presumably defaults to defaultContextPaths — confirm in the loader.
	ContextPaths []string `json:"context_paths,omitempty"`

	// TUI holds terminal UI settings. It is a pointer and may be nil.
	TUI *TUIOptions `json:"tui,omitempty"`

	// Debug enables debug mode.
	Debug bool `json:"debug,omitempty"`

	// DebugLSP enables debug mode for language servers.
	DebugLSP bool `json:"debug_lsp,omitempty"`

	// DisableAutoSummarize turns off automatic summarization.
	DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`

	// DataDirectory is the data directory, relative to the current working directory.
	DataDirectory string `json:"data_directory,omitempty"`
}
71
72type MCPs map[string]agent.MCPConfig
73
74type MCP struct {
75 Name string `json:"name"`
76 MCP agent.MCPConfig `json:"mcp"`
77}
78
79func (m MCPs) Sorted() []MCP {
80 sorted := make([]MCP, 0, len(m))
81 for k, v := range m {
82 sorted = append(sorted, MCP{
83 Name: k,
84 MCP: v,
85 })
86 }
87 slices.SortFunc(sorted, func(a, b MCP) int {
88 return strings.Compare(a.Name, b.Name)
89 })
90 return sorted
91}
92
93type LSPs map[string]LSPConfig
94
95type LSP struct {
96 Name string `json:"name"`
97 LSP LSPConfig `json:"lsp"`
98}
99
100func (l LSPs) Sorted() []LSP {
101 sorted := make([]LSP, 0, len(l))
102 for k, v := range l {
103 sorted = append(sorted, LSP{
104 Name: k,
105 LSP: v,
106 })
107 }
108 slices.SortFunc(sorted, func(a, b LSP) int {
109 return strings.Compare(a.Name, b.Name)
110 })
111 return sorted
112}
113
// Config holds the configuration for crush.
type Config struct {
	// Models maps a model slot to the selected model.
	// We currently only support large/small as keys here.
	Models map[SelectedModelType]agent.Model `json:"models,omitempty"`

	// Providers holds the configured providers, keyed by provider ID.
	Providers *csync.Map[string, provider.Config] `json:"providers,omitempty"`

	// MCP holds the MCP configurations, keyed by name.
	MCP MCPs `json:"mcp,omitempty"`

	// LSP holds the language server configurations, keyed by name.
	LSP LSPs `json:"lsp,omitempty"`

	// Options holds general settings; it is a pointer and may be nil.
	Options *Options `json:"options,omitempty"`

	// Permissions holds tool-permission settings; it is a pointer and may be nil.
	Permissions *Permissions `json:"permissions,omitempty"`

	// Internal state below; unexported, so never (de)serialized.

	// workingDir is returned by WorkingDir.
	workingDir string `json:"-"`

	// resolver expands values for Resolve; it may be nil.
	resolver resolver.Resolver

	// dataConfigDir is, despite its name, the path of the config FILE that
	// SetConfigField reads and rewrites.
	dataConfigDir string `json:"-"`

	// knownProviders is consulted by SetProviderAPIKey to create a provider
	// configuration for a provider that is not configured yet.
	knownProviders []catwalk.Provider `json:"-"`
}
136
137func (c *Config) WorkingDir() string {
138 return c.workingDir
139}
140
141func (c *Config) EnabledProviders() []provider.Config {
142 var enabled []provider.Config
143 for p := range c.Providers.Seq() {
144 if !p.Disable {
145 enabled = append(enabled, p)
146 }
147 }
148 return enabled
149}
150
151// IsConfigured return true if at least one provider is configured
152func (c *Config) IsConfigured() bool {
153 return len(c.EnabledProviders()) > 0
154}
155
156func (c *Config) GetModel(provider, model string) *catwalk.Model {
157 if providerConfig, ok := c.Providers.Get(provider); ok {
158 for _, m := range providerConfig.Models {
159 if m.ID == model {
160 return &m
161 }
162 }
163 }
164 return nil
165}
166
167func (c *Config) GetProviderForModel(modelType SelectedModelType) *provider.Config {
168 model, ok := c.Models[modelType]
169 if !ok {
170 return nil
171 }
172 if providerConfig, ok := c.Providers.Get(model.Provider); ok {
173 return &providerConfig
174 }
175 return nil
176}
177
178func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
179 model, ok := c.Models[modelType]
180 if !ok {
181 return nil
182 }
183 return c.GetModel(model.Provider, model.Model)
184}
185
186func (c *Config) LargeModel() *catwalk.Model {
187 model, ok := c.Models[SelectedModelTypeLarge]
188 if !ok {
189 return nil
190 }
191 return c.GetModel(model.Provider, model.Model)
192}
193
194func (c *Config) SmallModel() *catwalk.Model {
195 model, ok := c.Models[SelectedModelTypeSmall]
196 if !ok {
197 return nil
198 }
199 return c.GetModel(model.Provider, model.Model)
200}
201
202func (c *Config) SetCompactMode(enabled bool) error {
203 if c.Options == nil {
204 c.Options = &Options{}
205 }
206 c.Options.TUI.CompactMode = enabled
207 return c.SetConfigField("options.tui.compact_mode", enabled)
208}
209
210func (c *Config) Resolve(key string) (string, error) {
211 if c.resolver == nil {
212 return "", fmt.Errorf("no variable resolver configured")
213 }
214 return c.resolver.ResolveValue(key)
215}
216
217func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model agent.Model) error {
218 c.Models[modelType] = model
219 if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
220 return fmt.Errorf("failed to update preferred model: %w", err)
221 }
222 return nil
223}
224
225func (c *Config) SetConfigField(key string, value any) error {
226 // read the data
227 data, err := os.ReadFile(c.dataConfigDir)
228 if err != nil {
229 if os.IsNotExist(err) {
230 data = []byte("{}")
231 } else {
232 return fmt.Errorf("failed to read config file: %w", err)
233 }
234 }
235
236 newValue, err := sjson.Set(string(data), key, value)
237 if err != nil {
238 return fmt.Errorf("failed to set config field %s: %w", key, err)
239 }
240 if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
241 return fmt.Errorf("failed to write config file: %w", err)
242 }
243 return nil
244}
245
246func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
247 // First save to the config file
248 err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
249 if err != nil {
250 return fmt.Errorf("failed to save API key to config file: %w", err)
251 }
252
253 providerConfig, exists := c.Providers.Get(providerID)
254 if exists {
255 providerConfig.APIKey = apiKey
256 c.Providers.Set(providerID, providerConfig)
257 return nil
258 }
259
260 var foundProvider *catwalk.Provider
261 for _, p := range c.knownProviders {
262 if string(p.ID) == providerID {
263 foundProvider = &p
264 break
265 }
266 }
267
268 if foundProvider != nil {
269 // Create new provider config based on known provider
270 providerConfig = provider.Config{
271 ID: providerID,
272 Name: foundProvider.Name,
273 BaseURL: foundProvider.APIEndpoint,
274 Type: foundProvider.Type,
275 APIKey: apiKey,
276 Disable: false,
277 ExtraHeaders: make(map[string]string),
278 ExtraParams: make(map[string]string),
279 Models: foundProvider.Models,
280 }
281 } else {
282 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
283 }
284 // Store the updated provider config
285 c.Providers.Set(providerID, providerConfig)
286 return nil
287}
288
289func (c *Config) Resolver() resolver.Resolver {
290 return c.resolver
291}