1package configv2
2
3import (
4 "encoding/json"
5 "errors"
6 "maps"
7 "os"
8 "path/filepath"
9 "slices"
10 "strings"
11 "sync"
12
13 "github.com/charmbracelet/crush/internal/fur/provider"
14 "github.com/charmbracelet/crush/internal/logging"
15)
16
// Built-in fallbacks used when no configuration layer overrides them.
const (
	// defaultDataDirectory is the data directory; Options.DataDirectory is
	// documented as relative to the cwd.
	defaultDataDirectory = ".crush"
	// defaultLogLevel is the log level name used by default.
	defaultLogLevel = "info"
	// appName is the application name.
	appName = "crush"
)

// MaxTokensFallbackDefault is a fallback max-token value. NOTE(review):
// presumably used when a model has no DefaultMaxTokens — confirm at the
// call sites, which are not in this file.
const MaxTokensFallbackDefault = 4096
24
// defaultContextPaths is the built-in list of context/instruction files
// that seeds Options.ContextPaths (see defaultConfigBasedOnEnv); config
// layers append to it in mergeOptions. Entries ending in "/" look like
// directories — NOTE(review): confirm how the consumer treats them.
// Treat this slice as read-only: it is package-level shared state.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
38
// AgentID identifies an agent; it is the key type of Config.Agents.
type AgentID string

// IDs of the built-in agents that loadConfig creates whenever at least one
// provider is configured. Config layers may only tweak a subset of their
// fields (see mergeAgents).
const (
	AgentCoder     AgentID = "coder"
	AgentTask      AgentID = "task"
	AgentTitle     AgentID = "title"
	AgentSummarize AgentID = "summarize"
)
47
// Model describes one model offered by a provider. defaultConfigBasedOnEnv
// fills these from the provider catalog; config files may add more models
// (merged by ID in mergeProviderConfig).
type Model struct {
	ID string `json:"id"`
	// NOTE(review): Name is serialized under the "model" key, not "name" —
	// confirm this is intentional.
	Name string `json:"model"`
	// Pricing, presumably per one million input/output tokens (plain and
	// cached); the unit/currency is not established in this file.
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	ContextWindow      int64   `json:"context_window"`
	DefaultMaxTokens   int64   `json:"default_max_tokens"`
	CanReason          bool    `json:"can_reason"`
	// ReasoningEffort is not populated by defaultConfigBasedOnEnv; it can
	// only come from config files.
	ReasoningEffort string `json:"reasoning_effort"`
	// NOTE(review): serialized as "supports_attachments" although the field
	// is named SupportsImages — confirm the intended JSON key.
	SupportsImages bool `json:"supports_attachments"`
}
61
// VertexAIOptions groups Vertex AI connection settings. NOTE(review): it is
// not referenced in this file; defaultConfigBasedOnEnv passes the Vertex
// project/location via ProviderConfig.ExtraParams instead — confirm whether
// this type is still used elsewhere.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}
67
// ProviderConfig is the effective configuration of one inference provider,
// keyed by provider ID in Config.Providers.
type ProviderConfig struct {
	ID provider.InferenceProvider `json:"id"`
	// BaseURL and ProviderType can only be overridden from config files for
	// providers that are not in provider.KnownProviders().
	BaseURL      string        `json:"base_url,omitempty"`
	ProviderType provider.Type `json:"provider_type"`
	APIKey       string        `json:"api_key,omitempty"`
	// Disabled providers are skipped by getPreferredProvider.
	Disabled     bool              `json:"disabled"`
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Provider-specific extra parameters; e.g. Vertex AI stores "project"
	// and "location" here (see defaultConfigBasedOnEnv).
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	// Model IDs assigned to the built-in agents: the large model to Coder
	// and Task, the small one to Title and Summarize (see loadConfig).
	DefaultLargeModel string `json:"default_large_model,omitempty"`
	DefaultSmallModel string `json:"default_small_model,omitempty"`

	Models []Model `json:"models,omitempty"`
}
83
// Agent configures one agent: which provider/model it uses and which tools,
// MCPs, LSPs and context files it may access.
type Agent struct {
	ID          AgentID `json:"id"`
	Name        string  `json:"name"`
	Description string  `json:"description,omitempty"`
	// Disabled marks the agent as turned off.
	Disabled bool `json:"disabled"`

	Provider provider.InferenceProvider `json:"provider"`
	Model    string                     `json:"model"`

	// The available tools for the agent.
	// If this is nil, all tools are available; the built-in Title and
	// Summarize agents use an empty slice to allow none.
	AllowedTools []string `json:"allowed_tools"`

	// Which MCPs are available for this agent, keyed by MCP name.
	// If the map is nil, all MCPs are available; the built-in Task, Title and
	// Summarize agents use an empty map to mean "no MCPs" (see loadConfig).
	// Each value lists the tools of that MCP the agent may use; a nil value
	// makes all tools of that MCP available.
	AllowedMCP map[string][]string `json:"allowed_mcp"`

	// The list of LSPs that this agent can use.
	// If this is nil, all LSPs are available; an empty slice allows none.
	AllowedLSP []string `json:"allowed_lsp"`

	// Context paths for this agent; the built-in agents start from
	// Options.ContextPaths.
	ContextPaths []string `json:"context_paths"`
}
111
// MCPType selects how an MCP entry is reached.
type MCPType string

const (
	// MCPStdio presumably launches MCP.Command with Args/Env as a local
	// process — confirm in the MCP client.
	MCPStdio MCPType = "stdio"
	// MCPSse presumably connects to MCP.URL with Headers over server-sent
	// events — confirm in the MCP client.
	MCPSse MCPType = "sse"
)

// MCP is the configuration of one MCP entry, keyed by name in Config.MCP.
// Later config layers replace an entry with the same name wholesale
// (see mergeMCPs).
type MCP struct {
	Command string            `json:"command"`
	Env     []string          `json:"env"`
	Args    []string          `json:"args"`
	Type    MCPType           `json:"type"`
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}
127
// LSPConfig describes one language server, keyed by name in Config.LSP.
type LSPConfig struct {
	// Disabled turns this LSP off. It is read from the "disabled" key so the
	// JSON name matches the field's meaning; the former "enabled" tag
	// inverted user intent ("enabled": true disabled the server).
	Disabled bool     `json:"disabled"`
	Command  string   `json:"command"`
	Args     []string `json:"args"`
	// Options is passed through untyped; its shape depends on the server.
	Options any `json:"options"`
}
134
// TUIOptions holds terminal-UI settings.
type TUIOptions struct {
	// CompactMode can only be switched on by a config layer (see mergeOptions).
	CompactMode bool `json:"compact_mode"`
	// Here we can add themes later or any TUI related options
}
139
// Options holds miscellaneous settings. Boolean flags are OR-ed across
// config layers and DataDirectory is replaced by the last non-empty value
// (see mergeOptions).
type Options struct {
	// Context files; starts from defaultContextPaths and layers append to it.
	ContextPaths         []string   `json:"context_paths"`
	TUI                  TUIOptions `json:"tui"`
	Debug                bool       `json:"debug"`
	DebugLSP             bool       `json:"debug_lsp"`
	DisableAutoSummarize bool       `json:"disable_auto_summarize"`
	// Relative to the cwd
	DataDirectory string `json:"data_directory"`
}
149
// Config is the full application configuration. It is used both for a
// single file layer (global or project-local crush.json) and for the
// merged result that loadConfig produces.
type Config struct {
	// Configured providers, keyed by provider ID.
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// Configured agents, keyed by agent ID.
	Agents map[AgentID]Agent `json:"agents,omitempty"`

	// Configured MCPs, keyed by name.
	MCP map[string]MCP `json:"mcp,omitempty"`

	// Configured LSPs, keyed by name (presumably the language/server name).
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options
	Options Options `json:"options"`
}
166
// Package-level singleton state. All three are written only inside
// InitConfig's once.Do; readers (GetConfig, WorkingDirectory) do not
// synchronize, so they should run after InitConfig has returned.
var (
	instance *Config   // The loaded config returned by GetConfig; nil until InitConfig runs.
	cwd      string    // The working directory passed to InitConfig; returned by WorkingDirectory.
	once     sync.Once // Ensures InitConfig's loading happens only once.
)
173
174func loadConfig(cwd string) (*Config, error) {
175 // First read the global config file
176 cfgPath := ConfigPath()
177
178 cfg := defaultConfigBasedOnEnv()
179
180 var globalCfg *Config
181 if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
182 // some other error occurred while checking the file
183 return nil, err
184 } else if err == nil {
185 // config file exists, read it
186 file, err := os.ReadFile(cfgPath)
187 if err != nil {
188 return nil, err
189 }
190 globalCfg = &Config{}
191 if err := json.Unmarshal(file, globalCfg); err != nil {
192 return nil, err
193 }
194 } else {
195 // config file does not exist, create a new one
196 globalCfg = &Config{}
197 }
198
199 var localConfig *Config
200 // Global config loaded, now read the local config file
201 localConfigPath := filepath.Join(cwd, "crush.json")
202 if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
203 // some other error occurred while checking the file
204 return nil, err
205 } else if err == nil {
206 // local config file exists, read it
207 file, err := os.ReadFile(localConfigPath)
208 if err != nil {
209 return nil, err
210 }
211 localConfig = &Config{}
212 if err := json.Unmarshal(file, localConfig); err != nil {
213 return nil, err
214 }
215 }
216
217 // merge options
218 mergeOptions(cfg, globalCfg, localConfig)
219
220 mergeProviderConfigs(cfg, globalCfg, localConfig)
221 // no providers found the app is not initialized yet
222 if len(cfg.Providers) == 0 {
223 return cfg, nil
224 }
225 preferredProvider := getPreferredProvider(cfg.Providers)
226
227 if preferredProvider == nil {
228 return nil, errors.New("no valid providers configured")
229 }
230
231 agents := map[AgentID]Agent{
232 AgentCoder: {
233 ID: AgentCoder,
234 Name: "Coder",
235 Description: "An agent that helps with executing coding tasks.",
236 Provider: preferredProvider.ID,
237 Model: preferredProvider.DefaultLargeModel,
238 ContextPaths: cfg.Options.ContextPaths,
239 // All tools allowed
240 },
241 AgentTask: {
242 ID: AgentTask,
243 Name: "Task",
244 Description: "An agent that helps with searching for context and finding implementation details.",
245 Provider: preferredProvider.ID,
246 Model: preferredProvider.DefaultLargeModel,
247 ContextPaths: cfg.Options.ContextPaths,
248 AllowedTools: []string{
249 "glob",
250 "grep",
251 "ls",
252 "sourcegraph",
253 "view",
254 },
255 // NO MCPs or LSPs by default
256 AllowedMCP: map[string][]string{},
257 AllowedLSP: []string{},
258 },
259 AgentTitle: {
260 ID: AgentTitle,
261 Name: "Title",
262 Description: "An agent that helps with generating titles for sessions.",
263 Provider: preferredProvider.ID,
264 Model: preferredProvider.DefaultSmallModel,
265 ContextPaths: cfg.Options.ContextPaths,
266 AllowedTools: []string{},
267 // NO MCPs or LSPs by default
268 AllowedMCP: map[string][]string{},
269 AllowedLSP: []string{},
270 },
271 AgentSummarize: {
272 ID: AgentSummarize,
273 Name: "Summarize",
274 Description: "An agent that helps with summarizing sessions.",
275 Provider: preferredProvider.ID,
276 Model: preferredProvider.DefaultSmallModel,
277 ContextPaths: cfg.Options.ContextPaths,
278 AllowedTools: []string{},
279 // NO MCPs or LSPs by default
280 AllowedMCP: map[string][]string{},
281 AllowedLSP: []string{},
282 },
283 }
284 cfg.Agents = agents
285 mergeAgents(cfg, globalCfg, localConfig)
286 mergeMCPs(cfg, globalCfg, localConfig)
287 mergeLSPs(cfg, globalCfg, localConfig)
288
289 return cfg, nil
290}
291
292func InitConfig(workingDir string) *Config {
293 once.Do(func() {
294 cwd = workingDir
295 cfg, err := loadConfig(cwd)
296 if err != nil {
297 // TODO: Handle this better
298 panic("Failed to load config: " + err.Error())
299 }
300 instance = cfg
301 })
302
303 return instance
304}
305
306func GetConfig() *Config {
307 if instance == nil {
308 // TODO: Handle this better
309 panic("Config not initialized. Call InitConfig first.")
310 }
311 return instance
312}
313
314func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
315 providers := Providers()
316 for _, p := range providers {
317 if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
318 return &providerConfig
319 }
320 }
321 // if none found return the first configured provider
322 for _, providerConfig := range configuredProviders {
323 if !providerConfig.Disabled {
324 return &providerConfig
325 }
326 }
327 return nil
328}
329
330func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
331 if other.APIKey != "" {
332 base.APIKey = other.APIKey
333 }
334 // Only change these options if the provider is not a known provider
335 if !slices.Contains(provider.KnownProviders(), p) {
336 if other.BaseURL != "" {
337 base.BaseURL = other.BaseURL
338 }
339 if other.ProviderType != "" {
340 base.ProviderType = other.ProviderType
341 }
342 if len(base.ExtraHeaders) > 0 {
343 if base.ExtraHeaders == nil {
344 base.ExtraHeaders = make(map[string]string)
345 }
346 maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
347 }
348 if len(other.ExtraParams) > 0 {
349 if base.ExtraParams == nil {
350 base.ExtraParams = make(map[string]string)
351 }
352 maps.Copy(base.ExtraParams, other.ExtraParams)
353 }
354 }
355
356 if other.Disabled {
357 base.Disabled = other.Disabled
358 }
359
360 if other.DefaultLargeModel != "" {
361 base.DefaultLargeModel = other.DefaultLargeModel
362 }
363 // Add new models if they don't exist
364 if other.Models != nil {
365 for _, model := range other.Models {
366 // check if the model already exists
367 exists := false
368 for _, existingModel := range base.Models {
369 if existingModel.ID == model.ID {
370 exists = true
371 break
372 }
373 }
374 if !exists {
375 base.Models = append(base.Models, model)
376 }
377 }
378 }
379
380 return base
381}
382
383func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
384 if !slices.Contains(provider.KnownProviders(), p) {
385 if providerConfig.ProviderType != provider.TypeOpenAI {
386 return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
387 }
388 if providerConfig.BaseURL == "" {
389 return errors.New("base URL must be set for custom providers")
390 }
391 if providerConfig.APIKey == "" {
392 return errors.New("API key must be set for custom providers")
393 }
394 }
395 return nil
396}
397
398func mergeOptions(base, global, local *Config) {
399 for _, cfg := range []*Config{global, local} {
400 if cfg == nil {
401 continue
402 }
403 baseOptions := base.Options
404 other := cfg.Options
405 if len(other.ContextPaths) > 0 {
406 baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
407 }
408
409 if other.TUI.CompactMode {
410 baseOptions.TUI.CompactMode = other.TUI.CompactMode
411 }
412
413 if other.Debug {
414 baseOptions.Debug = other.Debug
415 }
416
417 if other.DebugLSP {
418 baseOptions.DebugLSP = other.DebugLSP
419 }
420
421 if other.DisableAutoSummarize {
422 baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
423 }
424
425 if other.DataDirectory != "" {
426 baseOptions.DataDirectory = other.DataDirectory
427 }
428 base.Options = baseOptions
429 }
430}
431
432func mergeAgents(base, global, local *Config) {
433 for _, cfg := range []*Config{global, local} {
434 if cfg == nil {
435 continue
436 }
437 for agentID, newAgent := range cfg.Agents {
438 if _, ok := base.Agents[agentID]; !ok {
439 newAgent.ID = agentID // Ensure the ID is set correctly
440 base.Agents[agentID] = newAgent
441 } else {
442 switch agentID {
443 case AgentCoder:
444 baseAgent := base.Agents[agentID]
445 baseAgent.Model = newAgent.Model
446 baseAgent.Provider = newAgent.Provider
447 baseAgent.AllowedMCP = newAgent.AllowedMCP
448 baseAgent.AllowedLSP = newAgent.AllowedLSP
449 base.Agents[agentID] = baseAgent
450 case AgentTask:
451 baseAgent := base.Agents[agentID]
452 baseAgent.Model = newAgent.Model
453 baseAgent.Provider = newAgent.Provider
454 base.Agents[agentID] = baseAgent
455 case AgentTitle:
456 baseAgent := base.Agents[agentID]
457 baseAgent.Model = newAgent.Model
458 baseAgent.Provider = newAgent.Provider
459 base.Agents[agentID] = baseAgent
460 case AgentSummarize:
461 baseAgent := base.Agents[agentID]
462 baseAgent.Model = newAgent.Model
463 baseAgent.Provider = newAgent.Provider
464 base.Agents[agentID] = baseAgent
465 default:
466 baseAgent := base.Agents[agentID]
467 baseAgent.Name = newAgent.Name
468 baseAgent.Description = newAgent.Description
469 baseAgent.Disabled = newAgent.Disabled
470 baseAgent.Provider = newAgent.Provider
471 baseAgent.Model = newAgent.Model
472 baseAgent.AllowedTools = newAgent.AllowedTools
473 baseAgent.AllowedMCP = newAgent.AllowedMCP
474 baseAgent.AllowedLSP = newAgent.AllowedLSP
475 base.Agents[agentID] = baseAgent
476
477 }
478 }
479 }
480 }
481}
482
483func mergeMCPs(base, global, local *Config) {
484 for _, cfg := range []*Config{global, local} {
485 if cfg == nil {
486 continue
487 }
488 maps.Copy(base.MCP, cfg.MCP)
489 }
490}
491
492func mergeLSPs(base, global, local *Config) {
493 for _, cfg := range []*Config{global, local} {
494 if cfg == nil {
495 continue
496 }
497 maps.Copy(base.LSP, cfg.LSP)
498 }
499}
500
501func mergeProviderConfigs(base, global, local *Config) {
502 for _, cfg := range []*Config{global, local} {
503 if cfg == nil {
504 continue
505 }
506 for providerName, globalProvider := range cfg.Providers {
507 if _, ok := base.Providers[providerName]; !ok {
508 base.Providers[providerName] = globalProvider
509 } else {
510 base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
511 }
512 }
513 }
514
515 finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
516 for providerName, providerConfig := range base.Providers {
517 err := validateProvider(providerName, providerConfig)
518 if err != nil {
519 logging.Warn("Skipping provider", "name", providerName, "error", err)
520 }
521 finalProviders[providerName] = providerConfig
522 }
523 base.Providers = finalProviders
524}
525
526func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
527 switch providerId {
528 case provider.InferenceProviderAnthropic:
529 return ProviderConfig{
530 ID: providerId,
531 ProviderType: provider.TypeAnthropic,
532 }
533 case provider.InferenceProviderOpenAI:
534 return ProviderConfig{
535 ID: providerId,
536 ProviderType: provider.TypeOpenAI,
537 }
538 case provider.InferenceProviderGemini:
539 return ProviderConfig{
540 ID: providerId,
541 ProviderType: provider.TypeGemini,
542 }
543 case provider.InferenceProviderBedrock:
544 return ProviderConfig{
545 ID: providerId,
546 ProviderType: provider.TypeBedrock,
547 }
548 case provider.InferenceProviderAzure:
549 return ProviderConfig{
550 ID: providerId,
551 ProviderType: provider.TypeAzure,
552 }
553 case provider.InferenceProviderOpenRouter:
554 return ProviderConfig{
555 ID: providerId,
556 ProviderType: provider.TypeOpenAI,
557 BaseURL: "https://openrouter.ai/api/v1",
558 ExtraHeaders: map[string]string{
559 "HTTP-Referer": "crush.charm.land",
560 "X-Title": "Crush",
561 },
562 }
563 case provider.InferenceProviderXAI:
564 return ProviderConfig{
565 ID: providerId,
566 ProviderType: provider.TypeXAI,
567 BaseURL: "https://api.x.ai/v1",
568 }
569 case provider.InferenceProviderVertexAI:
570 return ProviderConfig{
571 ID: providerId,
572 ProviderType: provider.TypeVertexAI,
573 }
574 default:
575 return ProviderConfig{
576 ID: providerId,
577 ProviderType: provider.TypeOpenAI,
578 }
579 }
580}
581
582func defaultConfigBasedOnEnv() *Config {
583 cfg := &Config{
584 Options: Options{
585 DataDirectory: defaultDataDirectory,
586 ContextPaths: defaultContextPaths,
587 },
588 Providers: make(map[provider.InferenceProvider]ProviderConfig),
589 }
590
591 providers := Providers()
592
593 for _, p := range providers {
594 if strings.HasPrefix(p.APIKey, "$") {
595 envVar := strings.TrimPrefix(p.APIKey, "$")
596 if apiKey := os.Getenv(envVar); apiKey != "" {
597 providerConfig := providerDefaultConfig(p.ID)
598 providerConfig.APIKey = apiKey
599 providerConfig.DefaultLargeModel = p.DefaultLargeModelID
600 providerConfig.DefaultSmallModel = p.DefaultSmallModelID
601 for _, model := range p.Models {
602 providerConfig.Models = append(providerConfig.Models, Model{
603 ID: model.ID,
604 Name: model.Name,
605 CostPer1MIn: model.CostPer1MIn,
606 CostPer1MOut: model.CostPer1MOut,
607 CostPer1MInCached: model.CostPer1MInCached,
608 CostPer1MOutCached: model.CostPer1MOutCached,
609 ContextWindow: model.ContextWindow,
610 DefaultMaxTokens: model.DefaultMaxTokens,
611 CanReason: model.CanReason,
612 SupportsImages: model.SupportsImages,
613 })
614 }
615 cfg.Providers[p.ID] = providerConfig
616 }
617 }
618 }
619 // TODO: support local models
620
621 if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
622 providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
623 providerConfig.ExtraParams = map[string]string{
624 "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
625 "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
626 }
627 cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
628 }
629
630 if hasAWSCredentials() {
631 providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
632 cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
633 }
634 return cfg
635}
636
// hasAWSCredentials reports whether the environment carries any AWS
// configuration that enables Bedrock. It accepts any of:
//   - a static key pair (AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY),
//   - a named profile (AWS_PROFILE or AWS_DEFAULT_PROFILE),
//   - a region (AWS_REGION or AWS_DEFAULT_REGION),
//   - container credentials (AWS_CONTAINER_CREDENTIALS_RELATIVE_URI or
//     AWS_CONTAINER_CREDENTIALS_FULL_URI).
//
// A variable set to the empty string counts as unset.
func hasAWSCredentials() bool {
	isSet := func(name string) bool { return os.Getenv(name) != "" }

	switch {
	case isSet("AWS_ACCESS_KEY_ID") && isSet("AWS_SECRET_ACCESS_KEY"):
		return true
	case isSet("AWS_PROFILE"), isSet("AWS_DEFAULT_PROFILE"):
		return true
	case isSet("AWS_REGION"), isSet("AWS_DEFAULT_REGION"):
		return true
	case isSet("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"), isSet("AWS_CONTAINER_CREDENTIALS_FULL_URI"):
		return true
	}
	return false
}
657
658func WorkingDirectory() string {
659 return cwd
660}