1package config
2
3import (
4 "context"
5 "fmt"
6 "net/http"
7 "os"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/env"
14 "github.com/tidwall/sjson"
15 "golang.org/x/exp/slog"
16)
17
18const (
19 appName = "crush"
20 defaultDataDirectory = ".crush"
21 defaultLogLevel = "info"
22)
23
24var defaultContextPaths = []string{
25 ".github/copilot-instructions.md",
26 ".cursorrules",
27 ".cursor/rules/",
28 "CLAUDE.md",
29 "CLAUDE.local.md",
30 "GEMINI.md",
31 "gemini.md",
32 "crush.md",
33 "crush.local.md",
34 "Crush.md",
35 "Crush.local.md",
36 "CRUSH.md",
37 "CRUSH.local.md",
38}
39
40type SelectedModelType string
41
42const (
43 SelectedModelTypeLarge SelectedModelType = "large"
44 SelectedModelTypeSmall SelectedModelType = "small"
45)
46
47type SelectedModel struct {
48 // The model id as used by the provider API.
49 // Required.
50 Model string `json:"model"`
51 // The model provider, same as the key/id used in the providers config.
52 // Required.
53 Provider string `json:"provider"`
54
55 // Only used by models that use the openai provider and need this set.
56 ReasoningEffort string `json:"reasoning_effort,omitempty"`
57
58 // Overrides the default model configuration.
59 MaxTokens int64 `json:"max_tokens,omitempty"`
60
61 // Used by anthropic models that can reason to indicate if the model should think.
62 Think bool `json:"think,omitempty"`
63}
64
65type ProviderConfig struct {
66 // The provider's id.
67 ID string `json:"id,omitempty"`
68 // The provider's name, used for display purposes.
69 Name string `json:"name,omitempty"`
70 // The provider's API endpoint.
71 BaseURL string `json:"base_url,omitempty"`
72 // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
73 Type catwalk.Type `json:"type,omitempty"`
74 // The provider's API key.
75 APIKey string `json:"api_key,omitempty"`
76 // Marks the provider as disabled.
77 Disable bool `json:"disable,omitempty"`
78
79 // Extra headers to send with each request to the provider.
80 ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
81 // Extra body
82 ExtraBody map[string]any `json:"extra_body,omitempty"`
83
84 // Used to pass extra parameters to the provider.
85 ExtraParams map[string]string `json:"-"`
86
87 // The provider models
88 Models []catwalk.Model `json:"models,omitempty"`
89}
90
91type MCPType string
92
93const (
94 MCPStdio MCPType = "stdio"
95 MCPSse MCPType = "sse"
96 MCPHttp MCPType = "http"
97)
98
99type MCPConfig struct {
100 Command string `json:"command,omitempty" `
101 Env map[string]string `json:"env,omitempty"`
102 Args []string `json:"args,omitempty"`
103 Type MCPType `json:"type"`
104 URL string `json:"url,omitempty"`
105 Disabled bool `json:"disabled,omitempty"`
106
107 // TODO: maybe make it possible to get the value from the env
108 Headers map[string]string `json:"headers,omitempty"`
109}
110
111type LSPConfig struct {
112 Disabled bool `json:"enabled,omitempty"`
113 Command string `json:"command"`
114 Args []string `json:"args,omitempty"`
115 Options any `json:"options,omitempty"`
116}
117
118type TUIOptions struct {
119 CompactMode bool `json:"compact_mode,omitempty"`
120 // Here we can add themes later or any TUI related options
121}
122
123type Permissions struct {
124 AllowedTools []string `json:"allowed_tools,omitempty"` // Tools that don't require permission prompts
125 SkipRequests bool `json:"-"` // Automatically accept all permissions (YOLO mode)
126}
127
128type Options struct {
129 ContextPaths []string `json:"context_paths,omitempty"`
130 TUI *TUIOptions `json:"tui,omitempty"`
131 Debug bool `json:"debug,omitempty"`
132 DebugLSP bool `json:"debug_lsp,omitempty"`
133 DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
134 DataDirectory string `json:"data_directory,omitempty"` // Relative to the cwd
135}
136
137type MCPs map[string]MCPConfig
138
139type MCP struct {
140 Name string `json:"name"`
141 MCP MCPConfig `json:"mcp"`
142}
143
144func (m MCPs) Sorted() []MCP {
145 sorted := make([]MCP, 0, len(m))
146 for k, v := range m {
147 sorted = append(sorted, MCP{
148 Name: k,
149 MCP: v,
150 })
151 }
152 slices.SortFunc(sorted, func(a, b MCP) int {
153 return strings.Compare(a.Name, b.Name)
154 })
155 return sorted
156}
157
158type LSPs map[string]LSPConfig
159
160type LSP struct {
161 Name string `json:"name"`
162 LSP LSPConfig `json:"lsp"`
163}
164
165func (l LSPs) Sorted() []LSP {
166 sorted := make([]LSP, 0, len(l))
167 for k, v := range l {
168 sorted = append(sorted, LSP{
169 Name: k,
170 LSP: v,
171 })
172 }
173 slices.SortFunc(sorted, func(a, b LSP) int {
174 return strings.Compare(a.Name, b.Name)
175 })
176 return sorted
177}
178
179func (m MCPConfig) ResolvedEnv() []string {
180 resolver := NewShellVariableResolver(env.New())
181 for e, v := range m.Env {
182 var err error
183 m.Env[e], err = resolver.ResolveValue(v)
184 if err != nil {
185 slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
186 continue
187 }
188 }
189
190 env := make([]string, 0, len(m.Env))
191 for k, v := range m.Env {
192 env = append(env, fmt.Sprintf("%s=%s", k, v))
193 }
194 return env
195}
196
197func (m MCPConfig) ResolvedHeaders() map[string]string {
198 resolver := NewShellVariableResolver(env.New())
199 for e, v := range m.Headers {
200 var err error
201 m.Headers[e], err = resolver.ResolveValue(v)
202 if err != nil {
203 slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
204 continue
205 }
206 }
207 return m.Headers
208}
209
210type Agent struct {
211 ID string `json:"id,omitempty"`
212 Name string `json:"name,omitempty"`
213 Description string `json:"description,omitempty"`
214 // This is the id of the system prompt used by the agent
215 Disabled bool `json:"disabled,omitempty"`
216
217 Model SelectedModelType `json:"model"`
218
219 // The available tools for the agent
220 // if this is nil, all tools are available
221 AllowedTools []string `json:"allowed_tools,omitempty"`
222
223 // this tells us which MCPs are available for this agent
224 // if this is empty all mcps are available
225 // the string array is the list of tools from the AllowedMCP the agent has available
226 // if the string array is nil, all tools from the AllowedMCP are available
227 AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`
228
229 // The list of LSPs that this agent can use
230 // if this is nil, all LSPs are available
231 AllowedLSP []string `json:"allowed_lsp,omitempty"`
232
233 // Overrides the context paths for this agent
234 ContextPaths []string `json:"context_paths,omitempty"`
235}
236
237// Config holds the configuration for crush.
238type Config struct {
239 // We currently only support large/small as values here.
240 Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`
241
242 // The providers that are configured
243 Providers map[string]ProviderConfig `json:"providers,omitempty"`
244
245 MCP MCPs `json:"mcp,omitempty"`
246
247 LSP LSPs `json:"lsp,omitempty"`
248
249 Options *Options `json:"options,omitempty"`
250
251 Permissions *Permissions `json:"permissions,omitempty"`
252
253 // Internal
254 workingDir string `json:"-"`
255 // TODO: most likely remove this concept when I come back to it
256 Agents map[string]Agent `json:"-"`
257 // TODO: find a better way to do this this should probably not be part of the config
258 resolver VariableResolver
259 dataConfigDir string `json:"-"`
260 knownProviders []catwalk.Provider `json:"-"`
261}
262
263func (c *Config) WorkingDir() string {
264 return c.workingDir
265}
266
267func (c *Config) EnabledProviders() []ProviderConfig {
268 enabled := make([]ProviderConfig, 0, len(c.Providers))
269 for _, p := range c.Providers {
270 if !p.Disable {
271 enabled = append(enabled, p)
272 }
273 }
274 return enabled
275}
276
277// IsConfigured return true if at least one provider is configured
278func (c *Config) IsConfigured() bool {
279 return len(c.EnabledProviders()) > 0
280}
281
282func (c *Config) GetModel(provider, model string) *catwalk.Model {
283 if providerConfig, ok := c.Providers[provider]; ok {
284 for _, m := range providerConfig.Models {
285 if m.ID == model {
286 return &m
287 }
288 }
289 }
290 return nil
291}
292
293func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
294 model, ok := c.Models[modelType]
295 if !ok {
296 return nil
297 }
298 if providerConfig, ok := c.Providers[model.Provider]; ok {
299 return &providerConfig
300 }
301 return nil
302}
303
304func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
305 model, ok := c.Models[modelType]
306 if !ok {
307 return nil
308 }
309 return c.GetModel(model.Provider, model.Model)
310}
311
312func (c *Config) LargeModel() *catwalk.Model {
313 model, ok := c.Models[SelectedModelTypeLarge]
314 if !ok {
315 return nil
316 }
317 return c.GetModel(model.Provider, model.Model)
318}
319
320func (c *Config) SmallModel() *catwalk.Model {
321 model, ok := c.Models[SelectedModelTypeSmall]
322 if !ok {
323 return nil
324 }
325 return c.GetModel(model.Provider, model.Model)
326}
327
328func (c *Config) SetCompactMode(enabled bool) error {
329 if c.Options == nil {
330 c.Options = &Options{}
331 }
332 c.Options.TUI.CompactMode = enabled
333 return c.SetConfigField("options.tui.compact_mode", enabled)
334}
335
336func (c *Config) Resolve(key string) (string, error) {
337 if c.resolver == nil {
338 return "", fmt.Errorf("no variable resolver configured")
339 }
340 return c.resolver.ResolveValue(key)
341}
342
343func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
344 c.Models[modelType] = model
345 if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
346 return fmt.Errorf("failed to update preferred model: %w", err)
347 }
348 return nil
349}
350
351func (c *Config) SetConfigField(key string, value any) error {
352 // read the data
353 data, err := os.ReadFile(c.dataConfigDir)
354 if err != nil {
355 if os.IsNotExist(err) {
356 data = []byte("{}")
357 } else {
358 return fmt.Errorf("failed to read config file: %w", err)
359 }
360 }
361
362 newValue, err := sjson.Set(string(data), key, value)
363 if err != nil {
364 return fmt.Errorf("failed to set config field %s: %w", key, err)
365 }
366 if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
367 return fmt.Errorf("failed to write config file: %w", err)
368 }
369 return nil
370}
371
372func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
373 // First save to the config file
374 err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
375 if err != nil {
376 return fmt.Errorf("failed to save API key to config file: %w", err)
377 }
378
379 if c.Providers == nil {
380 c.Providers = make(map[string]ProviderConfig)
381 }
382
383 providerConfig, exists := c.Providers[providerID]
384 if exists {
385 providerConfig.APIKey = apiKey
386 c.Providers[providerID] = providerConfig
387 return nil
388 }
389
390 var foundProvider *catwalk.Provider
391 for _, p := range c.knownProviders {
392 if string(p.ID) == providerID {
393 foundProvider = &p
394 break
395 }
396 }
397
398 if foundProvider != nil {
399 // Create new provider config based on known provider
400 providerConfig = ProviderConfig{
401 ID: providerID,
402 Name: foundProvider.Name,
403 BaseURL: foundProvider.APIEndpoint,
404 Type: foundProvider.Type,
405 APIKey: apiKey,
406 Disable: false,
407 ExtraHeaders: make(map[string]string),
408 ExtraParams: make(map[string]string),
409 Models: foundProvider.Models,
410 }
411 } else {
412 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
413 }
414 // Store the updated provider config
415 c.Providers[providerID] = providerConfig
416 return nil
417}
418
419func (c *Config) SetupAgents() {
420 agents := map[string]Agent{
421 "coder": {
422 ID: "coder",
423 Name: "Coder",
424 Description: "An agent that helps with executing coding tasks.",
425 Model: SelectedModelTypeLarge,
426 ContextPaths: c.Options.ContextPaths,
427 // All tools allowed
428 },
429 "task": {
430 ID: "task",
431 Name: "Task",
432 Description: "An agent that helps with searching for context and finding implementation details.",
433 Model: SelectedModelTypeLarge,
434 ContextPaths: c.Options.ContextPaths,
435 AllowedTools: []string{
436 "glob",
437 "grep",
438 "ls",
439 "sourcegraph",
440 "view",
441 },
442 // NO MCPs or LSPs by default
443 AllowedMCP: map[string][]string{},
444 AllowedLSP: []string{},
445 },
446 }
447 c.Agents = agents
448}
449
450func (c *Config) Resolver() VariableResolver {
451 return c.resolver
452}
453
454func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
455 testURL := ""
456 headers := make(map[string]string)
457 apiKey, _ := resolver.ResolveValue(c.APIKey)
458 switch c.Type {
459 case catwalk.TypeOpenAI:
460 baseURL, _ := resolver.ResolveValue(c.BaseURL)
461 if baseURL == "" {
462 baseURL = "https://api.openai.com/v1"
463 }
464 testURL = baseURL + "/models"
465 headers["Authorization"] = "Bearer " + apiKey
466 case catwalk.TypeAnthropic:
467 baseURL, _ := resolver.ResolveValue(c.BaseURL)
468 if baseURL == "" {
469 baseURL = "https://api.anthropic.com/v1"
470 }
471 testURL = baseURL + "/models"
472 headers["x-api-key"] = apiKey
473 headers["anthropic-version"] = "2023-06-01"
474 }
475 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
476 defer cancel()
477 client := &http.Client{}
478 req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
479 if err != nil {
480 return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
481 }
482 for k, v := range headers {
483 req.Header.Set(k, v)
484 }
485 for k, v := range c.ExtraHeaders {
486 req.Header.Set(k, v)
487 }
488 b, err := client.Do(req)
489 if err != nil {
490 return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
491 }
492 if b.StatusCode != http.StatusOK {
493 return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
494 }
495 _ = b.Body.Close()
496 return nil
497}