1package config
2
3import (
4 "context"
5 "fmt"
6 "net/http"
7 "os"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/env"
13 "github.com/charmbracelet/crush/internal/fur/provider"
14 "github.com/tidwall/sjson"
15 "golang.org/x/exp/slog"
16)
17
const (
	// appName is the application name.
	appName = "crush"
	// defaultDataDirectory is the default data directory name.
	// NOTE(review): presumably interpreted relative to the working directory,
	// like Options.DataDirectory ("Relative to the cwd") — confirm in the loader.
	defaultDataDirectory = ".crush"
	// defaultLogLevel is the default log level name.
	// NOTE(review): consumed outside this file section — confirm where it is applied.
	defaultLogLevel = "info"
)
23
// defaultContextPaths lists the default context paths: individual files plus,
// for entries ending in "/" (".cursor/rules/"), directories. Several spellings
// of the same file name (crush.md / Crush.md / CRUSH.md, ...) are listed
// explicitly.
// NOTE(review): presumably used when Options.ContextPaths is not set — confirm
// in the loader, which also decides how the paths are resolved and read.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
39
// SelectedModelType names a model slot; it is the key type of Config.Models.
// Only the "large" and "small" slots are defined.
type SelectedModelType string

const (
	// SelectedModelTypeLarge is the "large" model slot.
	SelectedModelTypeLarge SelectedModelType = "large"
	// SelectedModelTypeSmall is the "small" model slot.
	SelectedModelTypeSmall SelectedModelType = "small"
)
46
// SelectedModel is the model chosen for one slot of Config.Models, together
// with optional per-selection overrides.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default model configuration.
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty"`
}
64
// ProviderConfig describes one LLM provider entry; Config.Providers stores
// these keyed by provider id.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty"`
	// The provider's API endpoint. TestConnection passes it through a
	// VariableResolver before use, so it may contain variable references.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
	// NOTE(review): that defaulting is not performed in this file — confirm where it happens.
	Type provider.Type `json:"type,omitempty"`
	// The provider's API key. TestConnection passes it through a
	// VariableResolver before use, so it may contain variable references.
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled; EnabledProviders skips disabled entries.
	Disable bool `json:"disable,omitempty"`

	// Extra headers to send with each request to the provider.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`

	// Used to pass extra parameters to the provider. Tagged `json:"-"`, so it
	// is never read from or written to the JSON config.
	ExtraParams map[string]string `json:"-"`

	// The provider models; GetModel searches this list by model ID.
	Models []provider.Model `json:"models,omitempty"`
}
88
// MCPType selects how an MCP server is connected to. The recognized values
// are "stdio", "sse" and "http".
type MCPType string

const (
	// MCPStdio is the "stdio" connection type.
	MCPStdio MCPType = "stdio"
	// MCPSse is the "sse" connection type.
	MCPSse MCPType = "sse"
	// MCPHttp is the "http" connection type.
	MCPHttp MCPType = "http"
)
96
// MCPConfig configures one MCP server entry (the value type of MCPs).
// NOTE(review): Command/Args/Env look aimed at the "stdio" type and URL/Headers
// at "sse"/"http" — confirm against the MCP client code.
type MCPConfig struct {
	// Command to launch.
	Command string `json:"command,omitempty" `
	// Environment variables; values may contain variable references, which
	// ResolvedEnv expands.
	Env map[string]string `json:"env,omitempty"`
	// Arguments for Command.
	Args []string `json:"args,omitempty"`
	// How the server is reached (see MCPType).
	Type MCPType `json:"type"`
	// Server URL.
	URL string `json:"url,omitempty"`
	// Marks this server as disabled.
	Disabled bool `json:"disabled,omitempty"`

	// Extra request headers; values may contain variable references, which
	// ResolvedHeaders expands.
	// NOTE(review): the original TODO ("maybe make it possible to get the value
	// from the env") looks stale given ResolvedHeaders — confirm and remove.
	Headers map[string]string `json:"headers,omitempty"`
}
108
// LSPConfig configures one language server entry (the value type of LSPs).
type LSPConfig struct {
	// Disabled turns this language server off. It is read from the "disabled"
	// key; tagging it "enabled" would make `"enabled": true` disable the server.
	Disabled bool `json:"disabled,omitempty"`
	// Command used to start the server.
	Command string `json:"command"`
	// Extra arguments for Command.
	Args []string `json:"args,omitempty"`
	// Server-specific options; their shape is not constrained here.
	Options any `json:"options,omitempty"`
}
115
// TUIOptions holds settings for the terminal UI.
type TUIOptions struct {
	// CompactMode is persisted under "options.tui.compact_mode" by SetCompactMode.
	CompactMode bool `json:"compact_mode,omitempty"`
	// Here we can add themes later or any TUI related options
}
120
// Options holds general, non-provider settings. Config.Options is a pointer
// and may be nil.
type Options struct {
	// Context paths; SetupAgents assigns this slice to the default agents.
	// Compare defaultContextPaths for the kind of entries expected.
	ContextPaths []string `json:"context_paths,omitempty"`
	// TUI settings; this pointer may be nil.
	TUI                  *TUIOptions `json:"tui,omitempty"`
	Debug                bool        `json:"debug,omitempty"`
	DebugLSP             bool        `json:"debug_lsp,omitempty"`
	DisableAutoSummarize bool        `json:"disable_auto_summarize,omitempty"`
	DataDirectory        string      `json:"data_directory,omitempty"` // Relative to the cwd
	// Automatically accept all permissions (YOLO mode). Tagged `json:"-"`, so
	// it is never read from or written to the JSON config.
	SkipPermissionsRequests bool `json:"-"`
}
130
// MCPs maps an MCP server name to its configuration.
type MCPs map[string]MCPConfig
132
// MCP pairs an MCP server name with its configuration; MCPs.Sorted returns
// a slice of these.
type MCP struct {
	Name string    `json:"name"`
	MCP  MCPConfig `json:"mcp"`
}
137
138func (m MCPs) Sorted() []MCP {
139 sorted := make([]MCP, 0, len(m))
140 for k, v := range m {
141 sorted = append(sorted, MCP{
142 Name: k,
143 MCP: v,
144 })
145 }
146 slices.SortFunc(sorted, func(a, b MCP) int {
147 return strings.Compare(a.Name, b.Name)
148 })
149 return sorted
150}
151
// LSPs maps a language server name to its configuration.
type LSPs map[string]LSPConfig
153
// LSP pairs a language server name with its configuration; LSPs.Sorted
// returns a slice of these.
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}
158
159func (l LSPs) Sorted() []LSP {
160 sorted := make([]LSP, 0, len(l))
161 for k, v := range l {
162 sorted = append(sorted, LSP{
163 Name: k,
164 LSP: v,
165 })
166 }
167 slices.SortFunc(sorted, func(a, b LSP) int {
168 return strings.Compare(a.Name, b.Name)
169 })
170 return sorted
171}
172
173func (m MCPConfig) ResolvedEnv() []string {
174 resolver := NewShellVariableResolver(env.New())
175 for e, v := range m.Env {
176 var err error
177 m.Env[e], err = resolver.ResolveValue(v)
178 if err != nil {
179 slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
180 continue
181 }
182 }
183
184 env := make([]string, 0, len(m.Env))
185 for k, v := range m.Env {
186 env = append(env, fmt.Sprintf("%s=%s", k, v))
187 }
188 return env
189}
190
191func (m MCPConfig) ResolvedHeaders() map[string]string {
192 resolver := NewShellVariableResolver(env.New())
193 for e, v := range m.Headers {
194 var err error
195 m.Headers[e], err = resolver.ResolveValue(v)
196 if err != nil {
197 slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
198 continue
199 }
200 }
201 return m.Headers
202}
203
// Agent describes an agent definition. Config.Agents holds these, and
// SetupAgents populates it.
type Agent struct {
	ID          string `json:"id,omitempty"`
	Name        string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// Disabled marks the agent as disabled; how this is honored is up to the
	// code that consumes Config.Agents.
	Disabled bool `json:"disabled,omitempty"`

	// Which model slot (large/small) the agent uses.
	Model SelectedModelType `json:"model"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// this tells us which MCPs are available for this agent
	// the string array is the list of tools from the AllowedMCP the agent has available
	// if the string array is nil, all tools from the AllowedMCP are available
	// NOTE(review): the original comment said "if this is empty all mcps are
	// available", but SetupAgents uses a non-nil empty map to mean "no MCPs".
	// Presumably nil means all and empty means none — confirm in the consumer.
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available (SetupAgents uses an empty,
	// non-nil slice to mean "none")
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
230
// Config holds the configuration for crush.
type Config struct {
	// Selected models keyed by slot.
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`

	// The providers that are configured, keyed by provider id.
	Providers map[string]ProviderConfig `json:"providers,omitempty"`

	// MCP servers keyed by name.
	MCP MCPs `json:"mcp,omitempty"`

	// Language servers keyed by name.
	LSP LSPs `json:"lsp,omitempty"`

	// General options; this pointer may be nil (SetCompactMode allocates it on demand).
	Options *Options `json:"options,omitempty"`

	// Internal
	// workingDir is what WorkingDir returns.
	workingDir string `json:"-"`
	// Agents is populated by SetupAgents and is never serialized.
	// TODO: most likely remove this concept when I come back to it
	Agents map[string]Agent `json:"-"`
	// resolver expands variable references for Resolve/Resolver; it may be nil.
	// TODO: find a better way to do this this should probably not be part of the config
	resolver VariableResolver
	// dataConfigDir is, despite its name, the path of the config FILE that
	// SetConfigField reads and rewrites.
	dataConfigDir string `json:"-"`
	// knownProviders is the catalog SetProviderAPIKey uses to build a
	// ProviderConfig for a provider that is not configured yet.
	knownProviders []provider.Provider `json:"-"`
}
254
255func (c *Config) WorkingDir() string {
256 return c.workingDir
257}
258
259func (c *Config) EnabledProviders() []ProviderConfig {
260 enabled := make([]ProviderConfig, 0, len(c.Providers))
261 for _, p := range c.Providers {
262 if !p.Disable {
263 enabled = append(enabled, p)
264 }
265 }
266 return enabled
267}
268
269// IsConfigured return true if at least one provider is configured
270func (c *Config) IsConfigured() bool {
271 return len(c.EnabledProviders()) > 0
272}
273
274func (c *Config) GetModel(provider, model string) *provider.Model {
275 if providerConfig, ok := c.Providers[provider]; ok {
276 for _, m := range providerConfig.Models {
277 if m.ID == model {
278 return &m
279 }
280 }
281 }
282 return nil
283}
284
285func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
286 model, ok := c.Models[modelType]
287 if !ok {
288 return nil
289 }
290 if providerConfig, ok := c.Providers[model.Provider]; ok {
291 return &providerConfig
292 }
293 return nil
294}
295
296func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
297 model, ok := c.Models[modelType]
298 if !ok {
299 return nil
300 }
301 return c.GetModel(model.Provider, model.Model)
302}
303
304func (c *Config) LargeModel() *provider.Model {
305 model, ok := c.Models[SelectedModelTypeLarge]
306 if !ok {
307 return nil
308 }
309 return c.GetModel(model.Provider, model.Model)
310}
311
312func (c *Config) SmallModel() *provider.Model {
313 model, ok := c.Models[SelectedModelTypeSmall]
314 if !ok {
315 return nil
316 }
317 return c.GetModel(model.Provider, model.Model)
318}
319
320func (c *Config) SetCompactMode(enabled bool) error {
321 if c.Options == nil {
322 c.Options = &Options{}
323 }
324 c.Options.TUI.CompactMode = enabled
325 return c.SetConfigField("options.tui.compact_mode", enabled)
326}
327
328func (c *Config) Resolve(key string) (string, error) {
329 if c.resolver == nil {
330 return "", fmt.Errorf("no variable resolver configured")
331 }
332 return c.resolver.ResolveValue(key)
333}
334
335func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
336 c.Models[modelType] = model
337 if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
338 return fmt.Errorf("failed to update preferred model: %w", err)
339 }
340 return nil
341}
342
343func (c *Config) SetConfigField(key string, value any) error {
344 // read the data
345 data, err := os.ReadFile(c.dataConfigDir)
346 if err != nil {
347 if os.IsNotExist(err) {
348 data = []byte("{}")
349 } else {
350 return fmt.Errorf("failed to read config file: %w", err)
351 }
352 }
353
354 newValue, err := sjson.Set(string(data), key, value)
355 if err != nil {
356 return fmt.Errorf("failed to set config field %s: %w", key, err)
357 }
358 if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
359 return fmt.Errorf("failed to write config file: %w", err)
360 }
361 return nil
362}
363
364func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
365 // First save to the config file
366 err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
367 if err != nil {
368 return fmt.Errorf("failed to save API key to config file: %w", err)
369 }
370
371 if c.Providers == nil {
372 c.Providers = make(map[string]ProviderConfig)
373 }
374
375 providerConfig, exists := c.Providers[providerID]
376 if exists {
377 providerConfig.APIKey = apiKey
378 c.Providers[providerID] = providerConfig
379 return nil
380 }
381
382 var foundProvider *provider.Provider
383 for _, p := range c.knownProviders {
384 if string(p.ID) == providerID {
385 foundProvider = &p
386 break
387 }
388 }
389
390 if foundProvider != nil {
391 // Create new provider config based on known provider
392 providerConfig = ProviderConfig{
393 ID: providerID,
394 Name: foundProvider.Name,
395 BaseURL: foundProvider.APIEndpoint,
396 Type: foundProvider.Type,
397 APIKey: apiKey,
398 Disable: false,
399 ExtraHeaders: make(map[string]string),
400 ExtraParams: make(map[string]string),
401 Models: foundProvider.Models,
402 }
403 } else {
404 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
405 }
406 // Store the updated provider config
407 c.Providers[providerID] = providerConfig
408 return nil
409}
410
411func (c *Config) SetupAgents() {
412 agents := map[string]Agent{
413 "coder": {
414 ID: "coder",
415 Name: "Coder",
416 Description: "An agent that helps with executing coding tasks.",
417 Model: SelectedModelTypeLarge,
418 ContextPaths: c.Options.ContextPaths,
419 // All tools allowed
420 },
421 "task": {
422 ID: "task",
423 Name: "Task",
424 Description: "An agent that helps with searching for context and finding implementation details.",
425 Model: SelectedModelTypeLarge,
426 ContextPaths: c.Options.ContextPaths,
427 AllowedTools: []string{
428 "glob",
429 "grep",
430 "ls",
431 "sourcegraph",
432 "view",
433 },
434 // NO MCPs or LSPs by default
435 AllowedMCP: map[string][]string{},
436 AllowedLSP: []string{},
437 },
438 }
439 c.Agents = agents
440}
441
442func (c *Config) Resolver() VariableResolver {
443 return c.resolver
444}
445
446func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
447 testURL := ""
448 headers := make(map[string]string)
449 apiKey, _ := resolver.ResolveValue(c.APIKey)
450 switch c.Type {
451 case provider.TypeOpenAI:
452 baseURL, _ := resolver.ResolveValue(c.BaseURL)
453 if baseURL == "" {
454 baseURL = "https://api.openai.com/v1"
455 }
456 testURL = baseURL + "/models"
457 headers["Authorization"] = "Bearer " + apiKey
458 case provider.TypeAnthropic:
459 baseURL, _ := resolver.ResolveValue(c.BaseURL)
460 if baseURL == "" {
461 baseURL = "https://api.anthropic.com/v1"
462 }
463 testURL = baseURL + "/models"
464 headers["x-api-key"] = apiKey
465 headers["anthropic-version"] = "2023-06-01"
466 }
467 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
468 defer cancel()
469 client := &http.Client{}
470 req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
471 if err != nil {
472 return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
473 }
474 for k, v := range headers {
475 req.Header.Set(k, v)
476 }
477 for k, v := range c.ExtraHeaders {
478 req.Header.Set(k, v)
479 }
480 b, err := client.Do(req)
481 if err != nil {
482 return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
483 }
484 if b.StatusCode != http.StatusOK {
485 return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
486 }
487 _ = b.Body.Close()
488 return nil
489}