1package config
2
import (
	"context"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/env"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/tidwall/sjson"
	"golang.org/x/exp/slog"
)
17
// Package-level defaults. None of them is referenced in the part of the file
// shown here; the descriptions below follow from the names and values only.
const (
	// appName is the application's name.
	appName = "crush"
	// defaultDataDirectory is presumably the fallback for Options.DataDirectory
	// (which is relative to the cwd) — confirm against the config loader.
	defaultDataDirectory = ".crush"
	// defaultLogLevel is presumably the log level used when none is configured
	// — confirm against the logging setup.
	defaultLogLevel = "info"
)
23
// defaultContextPaths lists well-known instruction/context files and
// directories of several coding assistants, including differently cased
// spellings of the crush file names. An entry ending in "/" names a directory.
// Presumably this is the fallback for Options.ContextPaths — confirm against
// the config loader, which is not part of this section.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
39
// SelectedModelType is the key under which a selected model is stored in
// Config.Models.
type SelectedModelType string

// The model slots currently supported as keys of Config.Models.
const (
	// SelectedModelTypeLarge is the "large" model slot.
	SelectedModelTypeLarge SelectedModelType = "large"
	// SelectedModelTypeSmall is the "small" model slot.
	SelectedModelTypeSmall SelectedModelType = "small"
)
46
// SelectedModel identifies the model chosen for one model slot, together with
// optional per-selection overrides.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default model configuration.
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty"`
}
64
// ProviderConfig is the configuration of a single model provider, stored in
// Config.Providers under the provider's id.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty"`
	// The provider's API endpoint.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
	Type provider.Type `json:"type,omitempty"`
	// The provider's API key. TestConnection passes it through a
	// VariableResolver before use, so it may hold a variable reference.
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled.
	Disable bool `json:"disable,omitempty"`

	// Extra headers to send with each request to the provider.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Extra body
	ExtraBody map[string]any `json:"extra_body,omitempty"`

	// Used to pass extra parameters to the provider.
	// Excluded from JSON (tag "-").
	ExtraParams map[string]string `json:"-"`

	// The provider models
	Models []provider.Model `json:"models,omitempty"`
}
90
// MCPType is the value of the "type" field of an MCP server entry.
// Judging by the names it selects how the server is reached — confirm against
// the MCP client code.
type MCPType string

// The MCP types declared by this package.
const (
	// MCPStdio is the "stdio" type.
	MCPStdio MCPType = "stdio"
	// MCPSse is the "sse" type.
	MCPSse MCPType = "sse"
	// MCPHttp is the "http" type.
	MCPHttp MCPType = "http"
)
98
// MCPConfig describes one MCP server entry of the configuration.
// Which fields apply to which Type is not visible in this file; presumably
// Command/Args/Env are for stdio servers and URL/Headers for sse/http ones
// — confirm against the MCP client code.
type MCPConfig struct {
	// Command is the configured command.
	Command string `json:"command,omitempty" `
	// Env maps variable names to values; ResolvedEnv passes the values through
	// a shell variable resolver.
	Env map[string]string `json:"env,omitempty"`
	// Args is the configured argument list.
	Args []string `json:"args,omitempty"`
	// Type is the server type; unlike the other fields it is always serialized.
	Type MCPType `json:"type"`
	// URL is the configured server URL.
	URL string `json:"url,omitempty"`
	// Disabled marks the entry as disabled.
	Disabled bool `json:"disabled,omitempty"`

	// Headers maps header names to values; ResolvedHeaders passes the values
	// through a shell variable resolver.
	// TODO: maybe make it possible to get the value from the env
	Headers map[string]string `json:"headers,omitempty"`
}
110
// LSPConfig describes one language server entry of the configuration.
type LSPConfig struct {
	// Disabled marks the entry as disabled.
	// NOTE(review): the field is serialized under the JSON key "enabled", so
	// `"enabled": true` in a config file sets Disabled to true. Confirm whether
	// the key was meant to be "disabled"; changing it alters the file format.
	Disabled bool `json:"enabled,omitempty"`
	// Command is the configured language server command.
	Command string `json:"command"`
	// Args is the configured argument list.
	Args []string `json:"args,omitempty"`
	// Options holds arbitrary, untyped options; their consumer is not visible
	// in this file.
	Options any `json:"options,omitempty"`
}
117
// TUIOptions groups the options specific to the terminal UI.
type TUIOptions struct {
	// CompactMode is the flag written by Config.SetCompactMode.
	CompactMode bool `json:"compact_mode,omitempty"`
	// Here we can add themes later or any TUI related options
}
122
// Options holds the general, non-provider settings of the configuration.
type Options struct {
	// ContextPaths is the list of context paths; SetupAgents copies it into the
	// built-in agents.
	ContextPaths []string `json:"context_paths,omitempty"`
	// TUI holds the terminal UI options; it may be nil.
	TUI *TUIOptions `json:"tui,omitempty"`
	// Debug and DebugLSP are debug switches; their effect is defined outside
	// this file.
	Debug    bool `json:"debug,omitempty"`
	DebugLSP bool `json:"debug_lsp,omitempty"`
	// DisableAutoSummarize is a switch whose consumer is outside this file.
	DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
	DataDirectory        string `json:"data_directory,omitempty"` // Relative to the cwd
	SkipPermissionsRequests bool `json:"-"` // Automatically accept all permissions (YOLO mode)
}
132
// MCPs maps an MCP server's name to its configuration.
type MCPs map[string]MCPConfig

// MCP pairs an MCP server's name with its configuration; it is the element
// type returned by MCPs.Sorted.
type MCP struct {
	Name string    `json:"name"`
	MCP  MCPConfig `json:"mcp"`
}
139
140func (m MCPs) Sorted() []MCP {
141 sorted := make([]MCP, 0, len(m))
142 for k, v := range m {
143 sorted = append(sorted, MCP{
144 Name: k,
145 MCP: v,
146 })
147 }
148 slices.SortFunc(sorted, func(a, b MCP) int {
149 return strings.Compare(a.Name, b.Name)
150 })
151 return sorted
152}
153
// LSPs maps a language server's name to its configuration.
type LSPs map[string]LSPConfig

// LSP pairs a language server's name with its configuration; it is the element
// type returned by LSPs.Sorted.
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}
160
161func (l LSPs) Sorted() []LSP {
162 sorted := make([]LSP, 0, len(l))
163 for k, v := range l {
164 sorted = append(sorted, LSP{
165 Name: k,
166 LSP: v,
167 })
168 }
169 slices.SortFunc(sorted, func(a, b LSP) int {
170 return strings.Compare(a.Name, b.Name)
171 })
172 return sorted
173}
174
175func (m MCPConfig) ResolvedEnv() []string {
176 resolver := NewShellVariableResolver(env.New())
177 for e, v := range m.Env {
178 var err error
179 m.Env[e], err = resolver.ResolveValue(v)
180 if err != nil {
181 slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
182 continue
183 }
184 }
185
186 env := make([]string, 0, len(m.Env))
187 for k, v := range m.Env {
188 env = append(env, fmt.Sprintf("%s=%s", k, v))
189 }
190 return env
191}
192
193func (m MCPConfig) ResolvedHeaders() map[string]string {
194 resolver := NewShellVariableResolver(env.New())
195 for e, v := range m.Headers {
196 var err error
197 m.Headers[e], err = resolver.ResolveValue(v)
198 if err != nil {
199 slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
200 continue
201 }
202 }
203 return m.Headers
204}
205
// Agent describes an agent: which model slot it uses and which tools, MCPs,
// LSPs and context paths it may access. The built-in agents are created by
// Config.SetupAgents.
type Agent struct {
	ID          string `json:"id,omitempty"`
	Name        string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// Disabled marks the agent as disabled.
	// NOTE(review): the comment previously attached here described the id of
	// the agent's system prompt, which no field of this struct holds.
	Disabled bool `json:"disabled,omitempty"`

	// Model is the model slot (large/small) the agent uses.
	Model SelectedModelType `json:"model"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// this tells us which MCPs are available for this agent
	// if this is empty all mcps are available
	// the string array is the list of tools from the AllowedMCP the agent has available
	// if the string array is nil, all tools from the AllowedMCP are available
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
232
// Config holds the configuration for crush.
type Config struct {
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`

	// The providers that are configured
	Providers map[string]ProviderConfig `json:"providers,omitempty"`

	// MCP holds the configured MCP servers, keyed by name.
	MCP MCPs `json:"mcp,omitempty"`

	// LSP holds the configured language servers, keyed by name.
	LSP LSPs `json:"lsp,omitempty"`

	// Options holds the general settings; it may be nil.
	Options *Options `json:"options,omitempty"`

	// Internal
	// workingDir is the value returned by WorkingDir.
	workingDir string `json:"-"`
	// TODO: most likely remove this concept when I come back to it
	// Agents is filled in by SetupAgents and is not serialized.
	Agents map[string]Agent `json:"-"`
	// TODO: find a better way to do this this should probably not be part of the config
	// resolver backs Resolve and Resolver; it may be nil.
	resolver VariableResolver
	// dataConfigDir is, despite its name, used as the path of the config file
	// that SetConfigField reads and rewrites.
	dataConfigDir string `json:"-"`
	// knownProviders is the list SetProviderAPIKey searches when the provider
	// is not configured yet.
	knownProviders []provider.Provider `json:"-"`
}
256
// WorkingDir returns the working directory stored in the configuration.
func (c *Config) WorkingDir() string {
	return c.workingDir
}
260
261func (c *Config) EnabledProviders() []ProviderConfig {
262 enabled := make([]ProviderConfig, 0, len(c.Providers))
263 for _, p := range c.Providers {
264 if !p.Disable {
265 enabled = append(enabled, p)
266 }
267 }
268 return enabled
269}
270
271// IsConfigured return true if at least one provider is configured
272func (c *Config) IsConfigured() bool {
273 return len(c.EnabledProviders()) > 0
274}
275
276func (c *Config) GetModel(provider, model string) *provider.Model {
277 if providerConfig, ok := c.Providers[provider]; ok {
278 for _, m := range providerConfig.Models {
279 if m.ID == model {
280 return &m
281 }
282 }
283 }
284 return nil
285}
286
287func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
288 model, ok := c.Models[modelType]
289 if !ok {
290 return nil
291 }
292 if providerConfig, ok := c.Providers[model.Provider]; ok {
293 return &providerConfig
294 }
295 return nil
296}
297
298func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
299 model, ok := c.Models[modelType]
300 if !ok {
301 return nil
302 }
303 return c.GetModel(model.Provider, model.Model)
304}
305
306func (c *Config) LargeModel() *provider.Model {
307 model, ok := c.Models[SelectedModelTypeLarge]
308 if !ok {
309 return nil
310 }
311 return c.GetModel(model.Provider, model.Model)
312}
313
314func (c *Config) SmallModel() *provider.Model {
315 model, ok := c.Models[SelectedModelTypeSmall]
316 if !ok {
317 return nil
318 }
319 return c.GetModel(model.Provider, model.Model)
320}
321
322func (c *Config) SetCompactMode(enabled bool) error {
323 if c.Options == nil {
324 c.Options = &Options{}
325 }
326 c.Options.TUI.CompactMode = enabled
327 return c.SetConfigField("options.tui.compact_mode", enabled)
328}
329
330func (c *Config) Resolve(key string) (string, error) {
331 if c.resolver == nil {
332 return "", fmt.Errorf("no variable resolver configured")
333 }
334 return c.resolver.ResolveValue(key)
335}
336
337func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
338 c.Models[modelType] = model
339 if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
340 return fmt.Errorf("failed to update preferred model: %w", err)
341 }
342 return nil
343}
344
345func (c *Config) SetConfigField(key string, value any) error {
346 // read the data
347 data, err := os.ReadFile(c.dataConfigDir)
348 if err != nil {
349 if os.IsNotExist(err) {
350 data = []byte("{}")
351 } else {
352 return fmt.Errorf("failed to read config file: %w", err)
353 }
354 }
355
356 newValue, err := sjson.Set(string(data), key, value)
357 if err != nil {
358 return fmt.Errorf("failed to set config field %s: %w", key, err)
359 }
360 if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
361 return fmt.Errorf("failed to write config file: %w", err)
362 }
363 return nil
364}
365
366func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
367 // First save to the config file
368 err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
369 if err != nil {
370 return fmt.Errorf("failed to save API key to config file: %w", err)
371 }
372
373 if c.Providers == nil {
374 c.Providers = make(map[string]ProviderConfig)
375 }
376
377 providerConfig, exists := c.Providers[providerID]
378 if exists {
379 providerConfig.APIKey = apiKey
380 c.Providers[providerID] = providerConfig
381 return nil
382 }
383
384 var foundProvider *provider.Provider
385 for _, p := range c.knownProviders {
386 if string(p.ID) == providerID {
387 foundProvider = &p
388 break
389 }
390 }
391
392 if foundProvider != nil {
393 // Create new provider config based on known provider
394 providerConfig = ProviderConfig{
395 ID: providerID,
396 Name: foundProvider.Name,
397 BaseURL: foundProvider.APIEndpoint,
398 Type: foundProvider.Type,
399 APIKey: apiKey,
400 Disable: false,
401 ExtraHeaders: make(map[string]string),
402 ExtraParams: make(map[string]string),
403 Models: foundProvider.Models,
404 }
405 } else {
406 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
407 }
408 // Store the updated provider config
409 c.Providers[providerID] = providerConfig
410 return nil
411}
412
413func (c *Config) SetupAgents() {
414 agents := map[string]Agent{
415 "coder": {
416 ID: "coder",
417 Name: "Coder",
418 Description: "An agent that helps with executing coding tasks.",
419 Model: SelectedModelTypeLarge,
420 ContextPaths: c.Options.ContextPaths,
421 // All tools allowed
422 },
423 "task": {
424 ID: "task",
425 Name: "Task",
426 Description: "An agent that helps with searching for context and finding implementation details.",
427 Model: SelectedModelTypeLarge,
428 ContextPaths: c.Options.ContextPaths,
429 AllowedTools: []string{
430 "glob",
431 "grep",
432 "ls",
433 "sourcegraph",
434 "view",
435 },
436 // NO MCPs or LSPs by default
437 AllowedMCP: map[string][]string{},
438 AllowedLSP: []string{},
439 },
440 }
441 c.Agents = agents
442}
443
// Resolver returns the configuration's variable resolver, which may be nil.
func (c *Config) Resolver() VariableResolver {
	return c.resolver
}
447
448func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
449 testURL := ""
450 headers := make(map[string]string)
451 apiKey, _ := resolver.ResolveValue(c.APIKey)
452 switch c.Type {
453 case provider.TypeOpenAI:
454 baseURL, _ := resolver.ResolveValue(c.BaseURL)
455 if baseURL == "" {
456 baseURL = "https://api.openai.com/v1"
457 }
458 testURL = baseURL + "/models"
459 headers["Authorization"] = "Bearer " + apiKey
460 case provider.TypeAnthropic:
461 baseURL, _ := resolver.ResolveValue(c.BaseURL)
462 if baseURL == "" {
463 baseURL = "https://api.anthropic.com/v1"
464 }
465 testURL = baseURL + "/models"
466 headers["x-api-key"] = apiKey
467 headers["anthropic-version"] = "2023-06-01"
468 }
469 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
470 defer cancel()
471 client := &http.Client{}
472 req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
473 if err != nil {
474 return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
475 }
476 for k, v := range headers {
477 req.Header.Set(k, v)
478 }
479 for k, v := range c.ExtraHeaders {
480 req.Header.Set(k, v)
481 }
482 b, err := client.Do(req)
483 if err != nil {
484 return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
485 }
486 if b.StatusCode != http.StatusOK {
487 return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
488 }
489 _ = b.Body.Close()
490 return nil
491}