1package configv2
2
import (
	"encoding/json"
	"errors"
	"fmt"
	"maps"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/logging"
)
16
// Package-wide default values.
const (
	// defaultDataDirectory is the value given to Options.DataDirectory
	// before any config file is merged in.
	defaultDataDirectory = ".crush"
	// defaultLogLevel is the default log level name. It is not referenced in
	// the visible part of this file.
	defaultLogLevel = "info"
	// appName is the application name. It is not referenced in the visible
	// part of this file.
	appName = "crush"

	// MaxTokensFallbackDefault is presumably the max-token limit used when a
	// model does not supply its own (TODO confirm against callers).
	MaxTokensFallbackDefault = 4096
)
24
// defaultContextPaths is the list of context file locations placed in
// Options.ContextPaths before any config file is merged in. Both the
// "*.md" and the "*.local.md" variant of each name are listed.
var defaultContextPaths = []string{
	// Paths named after other tools.
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",

	// Crush's own files, in three capitalizations.
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
38
// AgentID is the key under which an agent is stored in Config.Agents.
type AgentID string

// Identifiers of the agents that loadConfig always defines.
const (
	// AgentCoder identifies the agent that executes coding tasks.
	AgentCoder AgentID = "coder"
	// AgentTask identifies the agent that searches for context.
	AgentTask AgentID = "task"
	// AgentTitle identifies the agent that generates session titles.
	AgentTitle AgentID = "title"
	// AgentSummarize identifies the agent that summarizes sessions.
	AgentSummarize AgentID = "summarize"
)
47
// Model describes one model offered by a provider, together with its pricing
// and capability flags.
type Model struct {
	// ID is the model identifier; provider configs are merged by this value.
	ID string `json:"id"`
	// Name is the model's name. Note that its JSON key is "model".
	Name string `json:"model"`
	// CostPer1MIn is the "cost_per_1m_in" price figure.
	CostPer1MIn float64 `json:"cost_per_1m_in"`
	// CostPer1MOut is the "cost_per_1m_out" price figure.
	CostPer1MOut float64 `json:"cost_per_1m_out"`
	// CostPer1MInCached is the "cost_per_1m_in_cached" price figure.
	CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
	// CostPer1MOutCached is the "cost_per_1m_out_cached" price figure.
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	// ContextWindow is the model's context window size.
	ContextWindow int64 `json:"context_window"`
	// DefaultMaxTokens is the model's default max-token setting.
	DefaultMaxTokens int64 `json:"default_max_tokens"`
	// CanReason reports whether the model is flagged as able to reason.
	CanReason bool `json:"can_reason"`
	// ReasoningEffort is the configured reasoning effort, as free-form text.
	ReasoningEffort string `json:"reasoning_effort"`
	// SupportsImages is the image-support flag. Note that its JSON key is
	// "supports_attachments".
	SupportsImages bool `json:"supports_attachments"`
}
61
// VertexAIOptions groups Vertex AI settings. All fields are optional and are
// omitted from JSON when empty. The type is not referenced in the visible
// part of this file.
type VertexAIOptions struct {
	// APIKey is the API key.
	APIKey string `json:"api_key,omitempty"`
	// Project is the project name.
	Project string `json:"project,omitempty"`
	// Location is the location name.
	Location string `json:"location,omitempty"`
}
67
68type ProviderConfig struct {
69 ID provider.InferenceProvider `json:"id"`
70 BaseURL string `json:"base_url,omitempty"`
71 ProviderType provider.Type `json:"provider_type"`
72 APIKey string `json:"api_key,omitempty"`
73 Disabled bool `json:"disabled"`
74 ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
75 // used for e.x for vertex to set the project
76 ExtraParams map[string]string `json:"extra_params,omitempty"`
77
78 DefaultLargeModel string `json:"default_large_model,omitempty"`
79 DefaultSmallModel string `json:"default_small_model,omitempty"`
80
81 Models []Model `json:"models,omitempty"`
82}
83
84type Agent struct {
85 Name string `json:"name"`
86 Description string `json:"description,omitempty"`
87 // This is the id of the system prompt used by the agent
88 Disabled bool `json:"disabled"`
89
90 Provider provider.InferenceProvider `json:"provider"`
91 Model string `json:"model"`
92
93 // The available tools for the agent
94 // if this is nil, all tools are available
95 AllowedTools []string `json:"allowed_tools"`
96
97 // this tells us which MCPs are available for this agent
98 // if this is empty all mcps are available
99 // the string array is the list of tools from the AllowedMCP the agent has available
100 // if the string array is nil, all tools from the AllowedMCP are available
101 AllowedMCP map[string][]string `json:"allowed_mcp"`
102
103 // The list of LSPs that this agent can use
104 // if this is nil, all LSPs are available
105 AllowedLSP []string `json:"allowed_lsp"`
106
107 // Overrides the context paths for this agent
108 ContextPaths []string `json:"context_paths"`
109}
110
// MCPType names the kind of an MCP entry, stored in MCP.Type.
type MCPType string

// MCP kinds recognised by the configuration.
const (
	// MCPStdio is the "stdio" kind.
	MCPStdio MCPType = "stdio"
	// MCPSse is the "sse" kind.
	MCPSse MCPType = "sse"
)
117
118type MCP struct {
119 Command string `json:"command"`
120 Env []string `json:"env"`
121 Args []string `json:"args"`
122 Type MCPType `json:"type"`
123 URL string `json:"url"`
124 Headers map[string]string `json:"headers"`
125}
126
// LSPConfig is the configuration of one language server.
type LSPConfig struct {
	// Disabled turns the language server off. Its JSON key is "disabled",
	// matching ProviderConfig and Agent. The key used to be "enabled", which
	// inverted the meaning: writing "enabled": true switched the server off.
	Disabled bool `json:"disabled"`
	// Command is the language server command.
	Command string `json:"command"`
	// Args lists the arguments for Command.
	Args []string `json:"args"`
	// Options carries free-form, server-specific options.
	Options any `json:"options"`
}
133
// TUIOptions groups the options that concern the terminal UI.
// Themes or further TUI-related options can be added here later.
type TUIOptions struct {
	// CompactMode is the "compact_mode" switch.
	CompactMode bool `json:"compact_mode"`
}
138
139type Options struct {
140 ContextPaths []string `json:"context_paths"`
141 TUI TUIOptions `json:"tui"`
142 Debug bool `json:"debug"`
143 DebugLSP bool `json:"debug_lsp"`
144 DisableAutoSummarize bool `json:"disable_auto_summarize"`
145 // Relative to the cwd
146 DataDirectory string `json:"data_directory"`
147}
148
149type Config struct {
150 // List of configured providers
151 Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`
152
153 // List of configured agents
154 Agents map[AgentID]Agent `json:"agents,omitempty"`
155
156 // List of configured MCPs
157 MCP map[string]MCP `json:"mcp,omitempty"`
158
159 // List of configured LSPs
160 LSP map[string]LSPConfig `json:"lsp,omitempty"`
161
162 // Miscellaneous options
163 Options Options `json:"options"`
164}
165
166var (
167 instance *Config // The single instance of the Singleton
168 cwd string
169 once sync.Once // Ensures the initialization happens only once
170
171)
172
173func loadConfig(cwd string) (*Config, error) {
174 // First read the global config file
175 cfgPath := ConfigPath()
176
177 cfg := defaultConfigBasedOnEnv()
178
179 var globalCfg *Config
180 if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
181 // some other error occurred while checking the file
182 return nil, err
183 } else if err == nil {
184 // config file exists, read it
185 file, err := os.ReadFile(cfgPath)
186 if err != nil {
187 return nil, err
188 }
189 globalCfg = &Config{}
190 if err := json.Unmarshal(file, globalCfg); err != nil {
191 return nil, err
192 }
193 } else {
194 // config file does not exist, create a new one
195 globalCfg = &Config{}
196 }
197
198 var localConfig *Config
199 // Global config loaded, now read the local config file
200 localConfigPath := filepath.Join(cwd, "crush.json")
201 if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
202 // some other error occurred while checking the file
203 return nil, err
204 } else if err == nil {
205 // local config file exists, read it
206 file, err := os.ReadFile(localConfigPath)
207 if err != nil {
208 return nil, err
209 }
210 localConfig = &Config{}
211 if err := json.Unmarshal(file, localConfig); err != nil {
212 return nil, err
213 }
214 }
215
216 // merge options
217 mergeOptions(cfg, globalCfg, localConfig)
218
219 mergeProviderConfigs(cfg, globalCfg, localConfig)
220 // no providers found the app is not initialized yet
221 if len(cfg.Providers) == 0 {
222 return cfg, nil
223 }
224 preferredProvider := getPreferredProvider(cfg.Providers)
225
226 if preferredProvider == nil {
227 return nil, errors.New("no valid providers configured")
228 }
229
230 agents := map[AgentID]Agent{
231 AgentCoder: {
232 Name: "Coder",
233 Description: "An agent that helps with executing coding tasks.",
234 Provider: preferredProvider.ID,
235 Model: preferredProvider.DefaultLargeModel,
236 ContextPaths: cfg.Options.ContextPaths,
237 // All tools allowed
238 },
239 AgentTask: {
240 Name: "Task",
241 Description: "An agent that helps with searching for context and finding implementation details.",
242 Provider: preferredProvider.ID,
243 Model: preferredProvider.DefaultLargeModel,
244 ContextPaths: cfg.Options.ContextPaths,
245 AllowedTools: []string{
246 "glob",
247 "grep",
248 "ls",
249 "sourcegraph",
250 "view",
251 },
252 // NO MCPs or LSPs by default
253 AllowedMCP: map[string][]string{},
254 AllowedLSP: []string{},
255 },
256 AgentTitle: {
257 Name: "Title",
258 Description: "An agent that helps with generating titles for sessions.",
259 Provider: preferredProvider.ID,
260 Model: preferredProvider.DefaultSmallModel,
261 ContextPaths: cfg.Options.ContextPaths,
262 AllowedTools: []string{},
263 // NO MCPs or LSPs by default
264 AllowedMCP: map[string][]string{},
265 AllowedLSP: []string{},
266 },
267 AgentSummarize: {
268 Name: "Summarize",
269 Description: "An agent that helps with summarizing sessions.",
270 Provider: preferredProvider.ID,
271 Model: preferredProvider.DefaultSmallModel,
272 ContextPaths: cfg.Options.ContextPaths,
273 AllowedTools: []string{},
274 // NO MCPs or LSPs by default
275 AllowedMCP: map[string][]string{},
276 AllowedLSP: []string{},
277 },
278 }
279 cfg.Agents = agents
280 mergeAgents(cfg, globalCfg, localConfig)
281 mergeMCPs(cfg, globalCfg, localConfig)
282 mergeLSPs(cfg, globalCfg, localConfig)
283
284 return cfg, nil
285}
286
287func InitConfig(workingDir string) *Config {
288 once.Do(func() {
289 cwd = workingDir
290 cfg, err := loadConfig(cwd)
291 if err != nil {
292 // TODO: Handle this better
293 panic("Failed to load config: " + err.Error())
294 }
295 instance = cfg
296 })
297
298 return instance
299}
300
301func GetConfig() *Config {
302 if instance == nil {
303 // TODO: Handle this better
304 panic("Config not initialized. Call InitConfig first.")
305 }
306 return instance
307}
308
309func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
310 providers := Providers()
311 for _, p := range providers {
312 if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
313 return &providerConfig
314 }
315 }
316 // if none found return the first configured provider
317 for _, providerConfig := range configuredProviders {
318 if !providerConfig.Disabled {
319 return &providerConfig
320 }
321 }
322 return nil
323}
324
325func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
326 if other.APIKey != "" {
327 base.APIKey = other.APIKey
328 }
329 // Only change these options if the provider is not a known provider
330 if !slices.Contains(provider.KnownProviders(), p) {
331 if other.BaseURL != "" {
332 base.BaseURL = other.BaseURL
333 }
334 if other.ProviderType != "" {
335 base.ProviderType = other.ProviderType
336 }
337 if len(base.ExtraHeaders) > 0 {
338 if base.ExtraHeaders == nil {
339 base.ExtraHeaders = make(map[string]string)
340 }
341 maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
342 }
343 if len(other.ExtraParams) > 0 {
344 if base.ExtraParams == nil {
345 base.ExtraParams = make(map[string]string)
346 }
347 maps.Copy(base.ExtraParams, other.ExtraParams)
348 }
349 }
350
351 if other.Disabled {
352 base.Disabled = other.Disabled
353 }
354
355 if other.DefaultLargeModel != "" {
356 base.DefaultLargeModel = other.DefaultLargeModel
357 }
358 // Add new models if they don't exist
359 if other.Models != nil {
360 for _, model := range other.Models {
361 // check if the model already exists
362 exists := false
363 for _, existingModel := range base.Models {
364 if existingModel.ID == model.ID {
365 exists = true
366 break
367 }
368 }
369 if !exists {
370 base.Models = append(base.Models, model)
371 }
372 }
373 }
374
375 return base
376}
377
378func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
379 if !slices.Contains(provider.KnownProviders(), p) {
380 if providerConfig.ProviderType != provider.TypeOpenAI {
381 return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
382 }
383 if providerConfig.BaseURL == "" {
384 return errors.New("base URL must be set for custom providers")
385 }
386 if providerConfig.APIKey == "" {
387 return errors.New("API key must be set for custom providers")
388 }
389 }
390 return nil
391}
392
393func mergeOptions(base, global, local *Config) {
394 for _, cfg := range []*Config{global, local} {
395 if cfg == nil {
396 continue
397 }
398 baseOptions := base.Options
399 other := cfg.Options
400 if len(other.ContextPaths) > 0 {
401 baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
402 }
403
404 if other.TUI.CompactMode {
405 baseOptions.TUI.CompactMode = other.TUI.CompactMode
406 }
407
408 if other.Debug {
409 baseOptions.Debug = other.Debug
410 }
411
412 if other.DebugLSP {
413 baseOptions.DebugLSP = other.DebugLSP
414 }
415
416 if other.DisableAutoSummarize {
417 baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
418 }
419
420 if other.DataDirectory != "" {
421 baseOptions.DataDirectory = other.DataDirectory
422 }
423 base.Options = baseOptions
424 }
425}
426
427func mergeAgents(base, global, local *Config) {
428 for _, cfg := range []*Config{global, local} {
429 if cfg == nil {
430 continue
431 }
432 for agentID, globalAgent := range cfg.Agents {
433 if _, ok := base.Agents[agentID]; !ok {
434 base.Agents[agentID] = globalAgent
435 } else {
436 switch agentID {
437 case AgentCoder:
438 baseAgent := base.Agents[agentID]
439 baseAgent.Model = globalAgent.Model
440 baseAgent.Provider = globalAgent.Provider
441 baseAgent.AllowedMCP = globalAgent.AllowedMCP
442 baseAgent.AllowedLSP = globalAgent.AllowedLSP
443 base.Agents[agentID] = baseAgent
444 case AgentTask:
445 baseAgent := base.Agents[agentID]
446 baseAgent.Model = globalAgent.Model
447 baseAgent.Provider = globalAgent.Provider
448 base.Agents[agentID] = baseAgent
449 case AgentTitle:
450 baseAgent := base.Agents[agentID]
451 baseAgent.Model = globalAgent.Model
452 baseAgent.Provider = globalAgent.Provider
453 base.Agents[agentID] = baseAgent
454 case AgentSummarize:
455 baseAgent := base.Agents[agentID]
456 baseAgent.Model = globalAgent.Model
457 baseAgent.Provider = globalAgent.Provider
458 base.Agents[agentID] = baseAgent
459 default:
460 baseAgent := base.Agents[agentID]
461 baseAgent.Name = globalAgent.Name
462 baseAgent.Description = globalAgent.Description
463 baseAgent.Disabled = globalAgent.Disabled
464 baseAgent.Provider = globalAgent.Provider
465 baseAgent.Model = globalAgent.Model
466 baseAgent.AllowedTools = globalAgent.AllowedTools
467 baseAgent.AllowedMCP = globalAgent.AllowedMCP
468 baseAgent.AllowedLSP = globalAgent.AllowedLSP
469 base.Agents[agentID] = baseAgent
470
471 }
472 }
473 }
474 }
475}
476
477func mergeMCPs(base, global, local *Config) {
478 for _, cfg := range []*Config{global, local} {
479 if cfg == nil {
480 continue
481 }
482 maps.Copy(base.MCP, cfg.MCP)
483 }
484}
485
486func mergeLSPs(base, global, local *Config) {
487 for _, cfg := range []*Config{global, local} {
488 if cfg == nil {
489 continue
490 }
491 maps.Copy(base.LSP, cfg.LSP)
492 }
493}
494
495func mergeProviderConfigs(base, global, local *Config) {
496 for _, cfg := range []*Config{global, local} {
497 if cfg == nil {
498 continue
499 }
500 for providerName, globalProvider := range cfg.Providers {
501 if _, ok := base.Providers[providerName]; !ok {
502 base.Providers[providerName] = globalProvider
503 } else {
504 base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
505 }
506 }
507 }
508
509 finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
510 for providerName, providerConfig := range base.Providers {
511 err := validateProvider(providerName, providerConfig)
512 if err != nil {
513 logging.Warn("Skipping provider", "name", providerName, "error", err)
514 }
515 finalProviders[providerName] = providerConfig
516 }
517 base.Providers = finalProviders
518}
519
520func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
521 switch providerId {
522 case provider.InferenceProviderAnthropic:
523 return ProviderConfig{
524 ID: providerId,
525 ProviderType: provider.TypeAnthropic,
526 }
527 case provider.InferenceProviderOpenAI:
528 return ProviderConfig{
529 ID: providerId,
530 ProviderType: provider.TypeOpenAI,
531 }
532 case provider.InferenceProviderGemini:
533 return ProviderConfig{
534 ID: providerId,
535 ProviderType: provider.TypeGemini,
536 }
537 case provider.InferenceProviderBedrock:
538 return ProviderConfig{
539 ID: providerId,
540 ProviderType: provider.TypeBedrock,
541 }
542 case provider.InferenceProviderAzure:
543 return ProviderConfig{
544 ID: providerId,
545 ProviderType: provider.TypeAzure,
546 }
547 case provider.InferenceProviderOpenRouter:
548 return ProviderConfig{
549 ID: providerId,
550 ProviderType: provider.TypeOpenAI,
551 BaseURL: "https://openrouter.ai/api/v1",
552 ExtraHeaders: map[string]string{
553 "HTTP-Referer": "crush.charm.land",
554 "X-Title": "Crush",
555 },
556 }
557 case provider.InferenceProviderXAI:
558 return ProviderConfig{
559 ID: providerId,
560 ProviderType: provider.TypeXAI,
561 BaseURL: "https://api.x.ai/v1",
562 }
563 case provider.InferenceProviderVertexAI:
564 return ProviderConfig{
565 ID: providerId,
566 ProviderType: provider.TypeVertexAI,
567 }
568 default:
569 return ProviderConfig{
570 ID: providerId,
571 ProviderType: provider.TypeOpenAI,
572 }
573 }
574}
575
576func defaultConfigBasedOnEnv() *Config {
577 cfg := &Config{
578 Options: Options{
579 DataDirectory: defaultDataDirectory,
580 ContextPaths: defaultContextPaths,
581 },
582 Providers: make(map[provider.InferenceProvider]ProviderConfig),
583 }
584
585 providers := Providers()
586
587 for _, p := range providers {
588 if strings.HasPrefix(p.APIKey, "$") {
589 envVar := strings.TrimPrefix(p.APIKey, "$")
590 if apiKey := os.Getenv(envVar); apiKey != "" {
591 providerConfig := providerDefaultConfig(p.ID)
592 providerConfig.APIKey = apiKey
593 providerConfig.DefaultLargeModel = p.DefaultLargeModelID
594 providerConfig.DefaultSmallModel = p.DefaultSmallModelID
595 for _, model := range p.Models {
596 providerConfig.Models = append(providerConfig.Models, Model{
597 ID: model.ID,
598 Name: model.Name,
599 CostPer1MIn: model.CostPer1MIn,
600 CostPer1MOut: model.CostPer1MOut,
601 CostPer1MInCached: model.CostPer1MInCached,
602 CostPer1MOutCached: model.CostPer1MOutCached,
603 ContextWindow: model.ContextWindow,
604 DefaultMaxTokens: model.DefaultMaxTokens,
605 CanReason: model.CanReason,
606 SupportsImages: model.SupportsImages,
607 })
608 }
609 cfg.Providers[p.ID] = providerConfig
610 }
611 }
612 }
613 // TODO: support local models
614
615 if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
616 providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
617 providerConfig.ExtraParams = map[string]string{
618 "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
619 "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
620 }
621 cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
622 }
623
624 if hasAWSCredentials() {
625 providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
626 cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
627 }
628 return cfg
629}
630
// hasAWSCredentials reports whether the environment contains AWS settings.
// It returns true when any of the following holds:
//   - both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are set,
//   - AWS_PROFILE or AWS_DEFAULT_PROFILE is set,
//   - AWS_REGION or AWS_DEFAULT_REGION is set,
//   - AWS_CONTAINER_CREDENTIALS_RELATIVE_URI or
//     AWS_CONTAINER_CREDENTIALS_FULL_URI is set.
//
// A variable counts as set when its value is non-empty.
func hasAWSCredentials() bool {
	isSet := func(name string) bool {
		return os.Getenv(name) != ""
	}

	switch {
	case isSet("AWS_ACCESS_KEY_ID") && isSet("AWS_SECRET_ACCESS_KEY"):
		return true
	case isSet("AWS_PROFILE"), isSet("AWS_DEFAULT_PROFILE"):
		return true
	case isSet("AWS_REGION"), isSet("AWS_DEFAULT_REGION"):
		return true
	case isSet("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"), isSet("AWS_CONTAINER_CREDENTIALS_FULL_URI"):
		return true
	default:
		return false
	}
}
651
652func WorkingDirectory() string {
653 return cwd
654}