1package config
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "net/http"
8 "net/url"
9 "os"
10 "slices"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 "github.com/charmbracelet/crush/internal/csync"
16 "github.com/charmbracelet/crush/internal/env"
17 "github.com/tidwall/sjson"
18)
19
// Package-level string defaults. Only the literal values are established
// here; where each one is consumed is decided elsewhere in the package.
const (
	// appName is the application's name.
	appName = "crush"
	// defaultDataDirectory is the default data directory name
	// (see Options.DataDirectory, which is relative to the cwd).
	defaultDataDirectory = ".crush"
	// defaultLogLevel is the default log level name.
	defaultLogLevel = "info"
)
25
// defaultContextPaths lists the file and directory names treated as context
// sources by default. Several casings of the crush and gemini file names are
// listed explicitly.
// NOTE(review): presumably used to seed Options.ContextPaths when the user
// sets none — confirm against the config loading code.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
41
// SelectedModelType names a model slot. It is the key type of Config.Models
// and the type of Agent.Model.
type SelectedModelType string

// The model slots currently supported.
const (
	// SelectedModelTypeLarge is the "large" model slot.
	SelectedModelTypeLarge SelectedModelType = "large"
	// SelectedModelTypeSmall is the "small" model slot.
	SelectedModelTypeSmall SelectedModelType = "small"
)
48
// SelectedModel identifies one chosen model (provider plus model id) together
// with optional per-selection overrides. It is the value type of
// Config.Models.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default model configuration.
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty"`
}
66
// ProviderConfig describes a single model provider: how to reach it, how to
// authenticate, and which models it offers.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty"`
	// The provider's API endpoint.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai", "anthropic", etc. If empty it defaults to openai.
	Type catwalk.Type `json:"type,omitempty"`
	// The provider's API key.
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled.
	Disable bool `json:"disable,omitempty"`

	// Custom system prompt prefix.
	SystemPromptPrefix string `json:"system_prompt_prefix,omitempty"`

	// Extra headers to send with each request to the provider.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Extra body fields.
	ExtraBody map[string]any `json:"extra_body,omitempty"`

	// Used to pass extra parameters to the provider.
	// Excluded from JSON (tag "-").
	ExtraParams map[string]string `json:"-"`

	// The provider's models.
	Models []catwalk.Model `json:"models,omitempty"`
}
95
// MCPType is the value of the "type" field of an MCPConfig.
type MCPType string

// The MCPType values declared by this package.
const (
	// MCPStdio is the "stdio" type.
	MCPStdio MCPType = "stdio"
	// MCPSse is the "sse" type.
	MCPSse MCPType = "sse"
	// MCPHttp is the "http" type.
	MCPHttp MCPType = "http"
)
103
// MCPConfig is the configuration of one MCP entry.
// NOTE(review): Command/Args/Env presumably apply to the stdio type and
// URL/Headers to the sse/http types — confirm against the MCP client code.
type MCPConfig struct {
	// Command to run.
	Command string `json:"command,omitempty" `
	// Environment variables; values are passed through a shell variable
	// resolver by ResolvedEnv.
	Env map[string]string `json:"env,omitempty"`
	// Arguments for Command.
	Args []string `json:"args,omitempty"`
	// Type of the MCP, one of the MCPType constants.
	Type MCPType `json:"type"`
	// URL of the MCP.
	URL string `json:"url,omitempty"`
	// Marks the MCP as disabled.
	Disabled bool `json:"disabled,omitempty"`

	// Headers; values are passed through a shell variable resolver by
	// ResolvedHeaders.
	// TODO: maybe make it possible to get the value from the env
	Headers map[string]string `json:"headers,omitempty"`
}
115
// LSPConfig is the configuration of one LSP entry.
type LSPConfig struct {
	// Disabled is (de)serialized under the JSON key "enabled".
	// NOTE(review): the key reads as the opposite of the field name, so
	// `"enabled": true` in a config file sets Disabled. This looks like a
	// tag mistake, but renaming the key changes the config file format —
	// confirm before changing.
	Disabled bool `json:"enabled,omitempty"`
	// Command to run.
	Command string `json:"command"`
	// Arguments for Command.
	Args []string `json:"args,omitempty"`
	// Free-form options; the shape is not constrained by this type.
	Options any `json:"options,omitempty"`
}
122
// TUIOptions groups the TUI-related settings stored under Options.TUI.
type TUIOptions struct {
	// CompactMode is the persisted compact-mode flag (see Config.SetCompactMode).
	CompactMode bool `json:"compact_mode,omitempty"`
	// Here we can add themes later or any TUI related options
}
127
// Permissions holds the permission-related settings of the configuration.
// SkipRequests is excluded from JSON (tag "-").
type Permissions struct {
	AllowedTools []string `json:"allowed_tools,omitempty"` // Tools that don't require permission prompts
	SkipRequests bool     `json:"-"`                       // Automatically accept all permissions (YOLO mode)
}
132
// Options holds the general application options stored under the "options"
// key of the configuration.
type Options struct {
	// Context paths; copied into the default agents by Config.SetupAgents.
	ContextPaths []string `json:"context_paths,omitempty"`
	// TUI-related options; may be nil.
	TUI *TUIOptions `json:"tui,omitempty"`
	// Debug flag.
	Debug bool `json:"debug,omitempty"`
	// LSP debug flag.
	DebugLSP bool `json:"debug_lsp,omitempty"`
	// Disables auto summarize.
	DisableAutoSummarize bool   `json:"disable_auto_summarize,omitempty"`
	DataDirectory        string `json:"data_directory,omitempty"` // Relative to the cwd
}
141
// MCPs maps an MCP name to its configuration.
type MCPs map[string]MCPConfig

// MCP pairs a name with its MCPConfig. It is the element type returned by
// MCPs.Sorted.
type MCP struct {
	Name string    `json:"name"`
	MCP  MCPConfig `json:"mcp"`
}
148
149func (m MCPs) Sorted() []MCP {
150 sorted := make([]MCP, 0, len(m))
151 for k, v := range m {
152 sorted = append(sorted, MCP{
153 Name: k,
154 MCP: v,
155 })
156 }
157 slices.SortFunc(sorted, func(a, b MCP) int {
158 return strings.Compare(a.Name, b.Name)
159 })
160 return sorted
161}
162
// LSPs maps an LSP name to its configuration.
type LSPs map[string]LSPConfig

// LSP pairs a name with its LSPConfig. It is the element type returned by
// LSPs.Sorted.
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}
169
170func (l LSPs) Sorted() []LSP {
171 sorted := make([]LSP, 0, len(l))
172 for k, v := range l {
173 sorted = append(sorted, LSP{
174 Name: k,
175 LSP: v,
176 })
177 }
178 slices.SortFunc(sorted, func(a, b LSP) int {
179 return strings.Compare(a.Name, b.Name)
180 })
181 return sorted
182}
183
184func (m MCPConfig) ResolvedEnv() []string {
185 resolver := NewShellVariableResolver(env.New())
186 for e, v := range m.Env {
187 var err error
188 m.Env[e], err = resolver.ResolveValue(v)
189 if err != nil {
190 slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
191 continue
192 }
193 }
194
195 env := make([]string, 0, len(m.Env))
196 for k, v := range m.Env {
197 env = append(env, fmt.Sprintf("%s=%s", k, v))
198 }
199 return env
200}
201
202func (m MCPConfig) ResolvedHeaders() map[string]string {
203 resolver := NewShellVariableResolver(env.New())
204 for e, v := range m.Headers {
205 var err error
206 m.Headers[e], err = resolver.ResolveValue(v)
207 if err != nil {
208 slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
209 continue
210 }
211 }
212 return m.Headers
213}
214
// Agent describes one agent: which model slot it uses and which tools, MCPs,
// LSPs and context paths it has access to. The default agents are built by
// Config.SetupAgents.
type Agent struct {
	ID          string `json:"id,omitempty"`
	Name        string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// Disabled flag for the agent.
	// NOTE(review): the previous comment here described a system prompt id,
	// which does not match this field; it appears to be stale.
	Disabled bool `json:"disabled,omitempty"`

	// The model slot (large/small) the agent uses.
	Model SelectedModelType `json:"model"`

	// The available tools for the agent.
	// If this is nil, all tools are available.
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// This tells us which MCPs are available for this agent.
	// If this is empty, all MCPs are available.
	// The string array is the list of tools from the AllowedMCP the agent has available.
	// If the string array is nil, all tools from the AllowedMCP are available.
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use.
	// If this is nil, all LSPs are available.
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent.
	ContextPaths []string `json:"context_paths,omitempty"`
}
241
// Config holds the configuration for crush.
//
// The exported fields with JSON keys make up the serialized configuration.
// Agents and the unexported fields are internal state and are not serialized.
type Config struct {
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`

	// The providers that are configured, keyed by provider id.
	Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty"`

	// The configured MCPs, keyed by name.
	MCP MCPs `json:"mcp,omitempty"`

	// The configured LSPs, keyed by name.
	LSP LSPs `json:"lsp,omitempty"`

	// General options; may be nil.
	Options *Options `json:"options,omitempty"`

	// Permission settings; may be nil.
	Permissions *Permissions `json:"permissions,omitempty"`

	// Internal
	workingDir string `json:"-"`
	// TODO: most likely remove this concept when I come back to it
	Agents map[string]Agent `json:"-"`
	// TODO: find a better way to do this; this should probably not be part of the config
	resolver VariableResolver
	// dataConfigDir is passed to os.ReadFile/os.WriteFile by SetConfigField,
	// i.e. despite the name it is the path of a config *file*.
	dataConfigDir string `json:"-"`
	// knownProviders is searched by SetProviderAPIKey when the provider is
	// not configured yet.
	knownProviders []catwalk.Provider `json:"-"`
}
267
268func (c *Config) WorkingDir() string {
269 return c.workingDir
270}
271
272func (c *Config) EnabledProviders() []ProviderConfig {
273 var enabled []ProviderConfig
274 for p := range c.Providers.Seq() {
275 if !p.Disable {
276 enabled = append(enabled, p)
277 }
278 }
279 return enabled
280}
281
282// IsConfigured return true if at least one provider is configured
283func (c *Config) IsConfigured() bool {
284 return len(c.EnabledProviders()) > 0
285}
286
287func (c *Config) GetModel(provider, model string) *catwalk.Model {
288 if providerConfig, ok := c.Providers.Get(provider); ok {
289 for _, m := range providerConfig.Models {
290 if m.ID == model {
291 return &m
292 }
293 }
294 }
295 return nil
296}
297
298func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
299 model, ok := c.Models[modelType]
300 if !ok {
301 return nil
302 }
303 if providerConfig, ok := c.Providers.Get(model.Provider); ok {
304 return &providerConfig
305 }
306 return nil
307}
308
309func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
310 model, ok := c.Models[modelType]
311 if !ok {
312 return nil
313 }
314 return c.GetModel(model.Provider, model.Model)
315}
316
317func (c *Config) LargeModel() *catwalk.Model {
318 model, ok := c.Models[SelectedModelTypeLarge]
319 if !ok {
320 return nil
321 }
322 return c.GetModel(model.Provider, model.Model)
323}
324
325func (c *Config) SmallModel() *catwalk.Model {
326 model, ok := c.Models[SelectedModelTypeSmall]
327 if !ok {
328 return nil
329 }
330 return c.GetModel(model.Provider, model.Model)
331}
332
333func (c *Config) SetCompactMode(enabled bool) error {
334 if c.Options == nil {
335 c.Options = &Options{}
336 }
337 c.Options.TUI.CompactMode = enabled
338 return c.SetConfigField("options.tui.compact_mode", enabled)
339}
340
341func (c *Config) Resolve(key string) (string, error) {
342 if c.resolver == nil {
343 return "", fmt.Errorf("no variable resolver configured")
344 }
345 return c.resolver.ResolveValue(key)
346}
347
348func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
349 c.Models[modelType] = model
350 if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
351 return fmt.Errorf("failed to update preferred model: %w", err)
352 }
353 return nil
354}
355
356func (c *Config) SetConfigField(key string, value any) error {
357 // read the data
358 data, err := os.ReadFile(c.dataConfigDir)
359 if err != nil {
360 if os.IsNotExist(err) {
361 data = []byte("{}")
362 } else {
363 return fmt.Errorf("failed to read config file: %w", err)
364 }
365 }
366
367 newValue, err := sjson.Set(string(data), key, value)
368 if err != nil {
369 return fmt.Errorf("failed to set config field %s: %w", key, err)
370 }
371 if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
372 return fmt.Errorf("failed to write config file: %w", err)
373 }
374 return nil
375}
376
377func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
378 // First save to the config file
379 err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
380 if err != nil {
381 return fmt.Errorf("failed to save API key to config file: %w", err)
382 }
383
384 providerConfig, exists := c.Providers.Get(providerID)
385 if exists {
386 providerConfig.APIKey = apiKey
387 c.Providers.Set(providerID, providerConfig)
388 return nil
389 }
390
391 var foundProvider *catwalk.Provider
392 for _, p := range c.knownProviders {
393 if string(p.ID) == providerID {
394 foundProvider = &p
395 break
396 }
397 }
398
399 if foundProvider != nil {
400 // Create new provider config based on known provider
401 providerConfig = ProviderConfig{
402 ID: providerID,
403 Name: foundProvider.Name,
404 BaseURL: foundProvider.APIEndpoint,
405 Type: foundProvider.Type,
406 APIKey: apiKey,
407 Disable: false,
408 ExtraHeaders: make(map[string]string),
409 ExtraParams: make(map[string]string),
410 Models: foundProvider.Models,
411 }
412 } else {
413 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
414 }
415 // Store the updated provider config
416 c.Providers.Set(providerID, providerConfig)
417 return nil
418}
419
420func (c *Config) SetupAgents() {
421 agents := map[string]Agent{
422 "coder": {
423 ID: "coder",
424 Name: "Coder",
425 Description: "An agent that helps with executing coding tasks.",
426 Model: SelectedModelTypeLarge,
427 ContextPaths: c.Options.ContextPaths,
428 // All tools allowed
429 },
430 "task": {
431 ID: "task",
432 Name: "Task",
433 Description: "An agent that helps with searching for context and finding implementation details.",
434 Model: SelectedModelTypeLarge,
435 ContextPaths: c.Options.ContextPaths,
436 AllowedTools: []string{
437 "glob",
438 "grep",
439 "ls",
440 "sourcegraph",
441 "view",
442 },
443 // NO MCPs or LSPs by default
444 AllowedMCP: map[string][]string{},
445 AllowedLSP: []string{},
446 },
447 }
448 c.Agents = agents
449}
450
451func (c *Config) Resolver() VariableResolver {
452 return c.resolver
453}
454
455func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
456 testURL := ""
457 headers := make(map[string]string)
458 apiKey, _ := resolver.ResolveValue(c.APIKey)
459 switch c.Type {
460 case catwalk.TypeOpenAI:
461 baseURL, _ := resolver.ResolveValue(c.BaseURL)
462 if baseURL == "" {
463 baseURL = "https://api.openai.com/v1"
464 }
465 testURL = baseURL + "/models"
466 headers["Authorization"] = "Bearer " + apiKey
467 case catwalk.TypeAnthropic:
468 baseURL, _ := resolver.ResolveValue(c.BaseURL)
469 if baseURL == "" {
470 baseURL = "https://api.anthropic.com/v1"
471 }
472 testURL = baseURL + "/models"
473 headers["x-api-key"] = apiKey
474 headers["anthropic-version"] = "2023-06-01"
475 case catwalk.TypeGemini:
476 baseURL, _ := resolver.ResolveValue(c.BaseURL)
477 if baseURL == "" {
478 baseURL = "https://generativelanguage.googleapis.com"
479 }
480 testURL = baseURL + "/v1beta/models?key=" + url.QueryEscape(apiKey)
481 }
482 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
483 defer cancel()
484 client := &http.Client{}
485 req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
486 if err != nil {
487 return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
488 }
489 for k, v := range headers {
490 req.Header.Set(k, v)
491 }
492 for k, v := range c.ExtraHeaders {
493 req.Header.Set(k, v)
494 }
495 b, err := client.Do(req)
496 if err != nil {
497 return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
498 }
499 if b.StatusCode != http.StatusOK {
500 return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
501 }
502 _ = b.Body.Close()
503 return nil
504}