1package config
2
3import (
4 "encoding/json"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "os"
10 "path/filepath"
11 "slices"
12 "strings"
13 "sync"
14
15 "github.com/charmbracelet/crush/internal/fur/provider"
16 "github.com/charmbracelet/crush/internal/logging"
17)
18
const (
	// defaultDataDirectory is the initial Options.DataDirectory, which is
	// interpreted relative to the working directory.
	defaultDataDirectory = ".crush"
	// defaultLogLevel and appName are not referenced in this part of the
	// file. NOTE(review): confirm they are still used elsewhere.
	defaultLogLevel = "info"
	appName         = "crush"

	// MaxTokensFallbackDefault is a fallback max-tokens value. Presumably it is
	// used when a model does not declare DefaultMaxTokens; confirm at call sites.
	MaxTokensFallbackDefault = 4096
)
26
// defaultContextPaths seeds Options.ContextPaths. Paths from the global and
// local config files are appended to it (see mergeOptions). The entries are
// presumably instruction/context files read for the agents; the entry ending
// in "/" looks like a directory. Confirm against the consumer.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
42
// AgentID identifies an agent; it is the key of Config.Agents.
type AgentID string

// Built-in agents, defined by loadConfig.
const (
	AgentCoder AgentID = "coder"
	AgentTask  AgentID = "task"
)
49
// Model describes one model offered by a provider: identity, pricing, limits
// and capabilities. defaultConfigBasedOnEnv fills it from the provider
// catalog returned by Providers().
type Model struct {
	ID string `json:"id"`
	// NOTE(review): Name is serialized under the "model" key, not "name".
	// Confirm this is intended.
	Name string `json:"model"`
	// Costs per one million tokens (currency not stated in this file).
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	ContextWindow      int64   `json:"context_window"`
	DefaultMaxTokens   int64   `json:"default_max_tokens"`
	CanReason          bool    `json:"can_reason"`
	// ReasoningEffort is not copied from the catalog by defaultConfigBasedOnEnv.
	ReasoningEffort string `json:"reasoning_effort"`
	SupportsImages  bool   `json:"supports_attachments"`
}
63
// VertexAIOptions groups Vertex AI connection settings.
// NOTE(review): not referenced in this file. defaultConfigBasedOnEnv passes
// the Vertex project/location through ProviderConfig.ExtraParams instead.
// Confirm whether this type is still used.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}
69
// ProviderConfig configures one inference provider. It starts from the
// environment-derived defaults (defaultConfigBasedOnEnv) and is then overlaid
// by the global and local config files (mergeProviderConfig).
type ProviderConfig struct {
	// ID is the provider identifier; it normally matches the key in
	// Config.Providers.
	ID provider.InferenceProvider `json:"id"`
	// BaseURL and ProviderType can only be changed from config files for
	// providers that are not in provider.KnownProviders().
	BaseURL      string        `json:"base_url,omitempty"`
	ProviderType provider.Type `json:"provider_type"`
	APIKey       string        `json:"api_key,omitempty"`
	// Disabled providers are never picked by getPreferredProvider.
	Disabled bool `json:"disabled"`
	// ExtraHeaders are extra HTTP headers, presumably added to every request
	// to the provider (e.g. the OpenRouter headers in providerDefaultConfig).
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// ExtraParams holds provider-specific settings, e.g. "project" and
	// "location" for Vertex AI or "region" for Bedrock.
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	// Model IDs used for PreferredModels when this is the preferred provider
	// and the config files do not select models explicitly.
	DefaultLargeModel string `json:"default_large_model,omitempty"`
	DefaultSmallModel string `json:"default_small_model,omitempty"`

	// Models offered by this provider; looked up by Model.ID.
	Models []Model `json:"models,omitempty"`
}
85
// Agent describes one agent: the model it runs on and the tools, MCP servers
// and LSPs it may use. loadConfig defines the built-in AgentCoder and
// AgentTask agents; config files can adjust them or add new ones
// (see mergeAgents).
type Agent struct {
	ID          AgentID `json:"id"`
	Name        string  `json:"name"`
	Description string  `json:"description,omitempty"`
	// Disabled marks the agent as turned off.
	Disabled bool `json:"disabled"`

	// Provider and Model select the agent's model: a provider ID and a
	// Model.ID within that provider (resolved by GetAgentModel).
	Provider provider.InferenceProvider `json:"provider"`
	Model    string                     `json:"model"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools"`

	// AllowedMCP lists the MCP servers this agent may use, presumably keyed by
	// the names used in Config.MCP. Each value lists the allowed tools of that
	// server; a nil list allows all of its tools.
	// A nil map makes all MCPs available.
	// NOTE(review): the built-in task agent gets an empty, non-nil map meaning
	// "no MCPs". Confirm the consumer treats nil and empty differently.
	AllowedMCP map[string][]string `json:"allowed_mcp"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths"`
}
113
// MCPType is the transport used to talk to an MCP server.
type MCPType string

const (
	// MCPStdio: presumably a local process spoken to over stdin/stdout.
	MCPStdio MCPType = "stdio"
	// MCPSse: presumably a remote server reached over Server-Sent Events.
	MCPSse MCPType = "sse"
)
120
// MCP configures one MCP server; it is stored by name in Config.MCP.
// Presumably Command, Env and Args apply to MCPStdio servers and URL and
// Headers to MCPSse ones. The consumer is not in this file; confirm there.
type MCP struct {
	Command string            `json:"command"`
	Env     []string          `json:"env"`
	Args    []string          `json:"args"`
	Type    MCPType           `json:"type"`
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}
129
// LSPConfig configures a language server: the command and arguments that
// start it, whether it is disabled, and server-specific options.
type LSPConfig struct {
	// Disabled turns the server off. It used to be decoded from the
	// "enabled" key, which inverted its meaning ("enabled": true disabled
	// the server); the JSON key now matches the field.
	Disabled bool     `json:"disabled"`
	Command  string   `json:"command"`
	Args     []string `json:"args"`
	// Options is opaque here; presumably it is forwarded to the server
	// (e.g. as initialization options). Confirm with the consumer.
	Options any `json:"options"`
}
136
// TUIOptions holds terminal-UI settings.
type TUIOptions struct {
	// CompactMode can only be switched on by a config layer (see mergeOptions).
	CompactMode bool `json:"compact_mode"`
	// Here we can add themes later or any TUI related options
}
141
// Options are miscellaneous application settings. Defaults come from
// defaultConfigBasedOnEnv; mergeOptions overlays the config files.
type Options struct {
	// ContextPaths starts as defaultContextPaths; each config layer appends
	// its own entries.
	ContextPaths []string   `json:"context_paths"`
	TUI          TUIOptions `json:"tui"`
	// Debug is initialized from loadConfig's debug argument; a config file
	// can turn it on but not off.
	Debug                bool `json:"debug"`
	DebugLSP             bool `json:"debug_lsp"`
	DisableAutoSummarize bool `json:"disable_auto_summarize"`
	// Relative to the cwd
	DataDirectory string `json:"data_directory"`
}
151
// PreferredModel names a model by its ID within a given provider.
type PreferredModel struct {
	ModelID  string                     `json:"model_id"`
	Provider provider.InferenceProvider `json:"provider"`
}
156
// PreferredModels holds the default large and small models. loadConfig
// derives them from the preferred provider; config files may override each
// one, but only when both ModelID and Provider are set (see mergeModels).
type PreferredModels struct {
	Large PreferredModel `json:"large"`
	Small PreferredModel `json:"small"`
}
161
// Config is the complete application configuration. The same shape is used
// both for the effective configuration built by loadConfig and for the
// global and local config files it decodes and merges.
type Config struct {
	Models PreferredModels `json:"models"`
	// List of configured providers, keyed by provider ID
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// List of configured agents, keyed by agent ID
	Agents map[AgentID]Agent `json:"agents,omitempty"`

	// List of configured MCPs, keyed by name
	MCP map[string]MCP `json:"mcp,omitempty"`

	// List of configured LSPs, keyed by name
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options
	Options Options `json:"options"`
}
179
// Package-level state populated once by Init.
//
// Get and WorkingDirectory read these variables directly, not through once,
// so they must only be called after Init has returned.
var (
	instance *Config   // The single instance of the Singleton, set by Init
	cwd      string    // working directory passed to Init
	once     sync.Once // Ensures the initialization happens only once

)
186
187func loadConfig(cwd string, debug bool) (*Config, error) {
188 // First read the global config file
189 cfgPath := ConfigPath()
190
191 cfg := defaultConfigBasedOnEnv()
192 cfg.Options.Debug = debug
193 defaultLevel := slog.LevelInfo
194 if cfg.Options.Debug {
195 defaultLevel = slog.LevelDebug
196 }
197 if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
198 loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log")
199
200 // if file does not exist create it
201 if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
202 if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
203 return cfg, fmt.Errorf("failed to create directory: %w", err)
204 }
205 if _, err := os.Create(loggingFile); err != nil {
206 return cfg, fmt.Errorf("failed to create log file: %w", err)
207 }
208 }
209
210 sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
211 if err != nil {
212 return cfg, fmt.Errorf("failed to open log file: %w", err)
213 }
214 // Configure logger
215 logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
216 Level: defaultLevel,
217 }))
218 slog.SetDefault(logger)
219 } else {
220 // Configure logger
221 logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
222 Level: defaultLevel,
223 }))
224 slog.SetDefault(logger)
225 }
226 var globalCfg *Config
227 if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
228 // some other error occurred while checking the file
229 return nil, err
230 } else if err == nil {
231 // config file exists, read it
232 file, err := os.ReadFile(cfgPath)
233 if err != nil {
234 return nil, err
235 }
236 globalCfg = &Config{}
237 if err := json.Unmarshal(file, globalCfg); err != nil {
238 return nil, err
239 }
240 } else {
241 // config file does not exist, create a new one
242 globalCfg = &Config{}
243 }
244
245 var localConfig *Config
246 // Global config loaded, now read the local config file
247 localConfigPath := filepath.Join(cwd, "crush.json")
248 if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
249 // some other error occurred while checking the file
250 return nil, err
251 } else if err == nil {
252 // local config file exists, read it
253 file, err := os.ReadFile(localConfigPath)
254 if err != nil {
255 return nil, err
256 }
257 localConfig = &Config{}
258 if err := json.Unmarshal(file, localConfig); err != nil {
259 return nil, err
260 }
261 }
262
263 // merge options
264 mergeOptions(cfg, globalCfg, localConfig)
265
266 mergeProviderConfigs(cfg, globalCfg, localConfig)
267 // no providers found the app is not initialized yet
268 if len(cfg.Providers) == 0 {
269 return cfg, nil
270 }
271 preferredProvider := getPreferredProvider(cfg.Providers)
272 cfg.Models = PreferredModels{
273 Large: PreferredModel{
274 ModelID: preferredProvider.DefaultLargeModel,
275 Provider: preferredProvider.ID,
276 },
277 Small: PreferredModel{
278 ModelID: preferredProvider.DefaultSmallModel,
279 Provider: preferredProvider.ID,
280 },
281 }
282
283 mergeModels(cfg, globalCfg, localConfig)
284
285 if preferredProvider == nil {
286 return nil, errors.New("no valid providers configured")
287 }
288
289 agents := map[AgentID]Agent{
290 AgentCoder: {
291 ID: AgentCoder,
292 Name: "Coder",
293 Description: "An agent that helps with executing coding tasks.",
294 Provider: cfg.Models.Large.Provider,
295 Model: cfg.Models.Large.ModelID,
296 ContextPaths: cfg.Options.ContextPaths,
297 // All tools allowed
298 },
299 AgentTask: {
300 ID: AgentTask,
301 Name: "Task",
302 Description: "An agent that helps with searching for context and finding implementation details.",
303 Provider: cfg.Models.Large.Provider,
304 Model: cfg.Models.Large.ModelID,
305 ContextPaths: cfg.Options.ContextPaths,
306 AllowedTools: []string{
307 "glob",
308 "grep",
309 "ls",
310 "sourcegraph",
311 "view",
312 },
313 // NO MCPs or LSPs by default
314 AllowedMCP: map[string][]string{},
315 AllowedLSP: []string{},
316 },
317 }
318 cfg.Agents = agents
319 mergeAgents(cfg, globalCfg, localConfig)
320 mergeMCPs(cfg, globalCfg, localConfig)
321 mergeLSPs(cfg, globalCfg, localConfig)
322
323 return cfg, nil
324}
325
326func Init(workingDir string, debug bool) (*Config, error) {
327 var err error
328 once.Do(func() {
329 cwd = workingDir
330 instance, err = loadConfig(cwd, debug)
331 if err != nil {
332 logging.Error("Failed to load config", "error", err)
333 }
334 })
335
336 return instance, err
337}
338
339func Get() *Config {
340 if instance == nil {
341 // TODO: Handle this better
342 panic("Config not initialized. Call InitConfig first.")
343 }
344 return instance
345}
346
347func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
348 providers := Providers()
349 for _, p := range providers {
350 if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
351 return &providerConfig
352 }
353 }
354 // if none found return the first configured provider
355 for _, providerConfig := range configuredProviders {
356 if !providerConfig.Disabled {
357 return &providerConfig
358 }
359 }
360 return nil
361}
362
363func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
364 if other.APIKey != "" {
365 base.APIKey = other.APIKey
366 }
367 // Only change these options if the provider is not a known provider
368 if !slices.Contains(provider.KnownProviders(), p) {
369 if other.BaseURL != "" {
370 base.BaseURL = other.BaseURL
371 }
372 if other.ProviderType != "" {
373 base.ProviderType = other.ProviderType
374 }
375 if len(base.ExtraHeaders) > 0 {
376 if base.ExtraHeaders == nil {
377 base.ExtraHeaders = make(map[string]string)
378 }
379 maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
380 }
381 if len(other.ExtraParams) > 0 {
382 if base.ExtraParams == nil {
383 base.ExtraParams = make(map[string]string)
384 }
385 maps.Copy(base.ExtraParams, other.ExtraParams)
386 }
387 }
388
389 if other.Disabled {
390 base.Disabled = other.Disabled
391 }
392
393 if other.DefaultLargeModel != "" {
394 base.DefaultLargeModel = other.DefaultLargeModel
395 }
396 // Add new models if they don't exist
397 if other.Models != nil {
398 for _, model := range other.Models {
399 // check if the model already exists
400 exists := false
401 for _, existingModel := range base.Models {
402 if existingModel.ID == model.ID {
403 exists = true
404 break
405 }
406 }
407 if !exists {
408 base.Models = append(base.Models, model)
409 }
410 }
411 }
412
413 return base
414}
415
416func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
417 if !slices.Contains(provider.KnownProviders(), p) {
418 if providerConfig.ProviderType != provider.TypeOpenAI {
419 return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
420 }
421 if providerConfig.BaseURL == "" {
422 return errors.New("base URL must be set for custom providers")
423 }
424 if providerConfig.APIKey == "" {
425 return errors.New("API key must be set for custom providers")
426 }
427 }
428 return nil
429}
430
431func mergeModels(base, global, local *Config) {
432 for _, cfg := range []*Config{global, local} {
433 if cfg == nil {
434 continue
435 }
436 if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
437 base.Models.Large = cfg.Models.Large
438 }
439
440 if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
441 base.Models.Small = cfg.Models.Small
442 }
443 }
444}
445
446func mergeOptions(base, global, local *Config) {
447 for _, cfg := range []*Config{global, local} {
448 if cfg == nil {
449 continue
450 }
451 baseOptions := base.Options
452 other := cfg.Options
453 if len(other.ContextPaths) > 0 {
454 baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
455 }
456
457 if other.TUI.CompactMode {
458 baseOptions.TUI.CompactMode = other.TUI.CompactMode
459 }
460
461 if other.Debug {
462 baseOptions.Debug = other.Debug
463 }
464
465 if other.DebugLSP {
466 baseOptions.DebugLSP = other.DebugLSP
467 }
468
469 if other.DisableAutoSummarize {
470 baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
471 }
472
473 if other.DataDirectory != "" {
474 baseOptions.DataDirectory = other.DataDirectory
475 }
476 base.Options = baseOptions
477 }
478}
479
480func mergeAgents(base, global, local *Config) {
481 for _, cfg := range []*Config{global, local} {
482 if cfg == nil {
483 continue
484 }
485 for agentID, newAgent := range cfg.Agents {
486 if _, ok := base.Agents[agentID]; !ok {
487 newAgent.ID = agentID // Ensure the ID is set correctly
488 base.Agents[agentID] = newAgent
489 } else {
490 switch agentID {
491 case AgentCoder:
492 baseAgent := base.Agents[agentID]
493 if newAgent.Model != "" && newAgent.Provider != "" {
494 baseAgent.Model = newAgent.Model
495 baseAgent.Provider = newAgent.Provider
496 }
497 baseAgent.AllowedMCP = newAgent.AllowedMCP
498 baseAgent.AllowedLSP = newAgent.AllowedLSP
499 base.Agents[agentID] = baseAgent
500 default:
501 baseAgent := base.Agents[agentID]
502 baseAgent.Name = newAgent.Name
503 baseAgent.Description = newAgent.Description
504 baseAgent.Disabled = newAgent.Disabled
505 if newAgent.Model == "" || newAgent.Provider == "" {
506 baseAgent.Provider = base.Models.Large.Provider
507 baseAgent.Model = base.Models.Large.ModelID
508 }
509 baseAgent.AllowedTools = newAgent.AllowedTools
510 baseAgent.AllowedMCP = newAgent.AllowedMCP
511 baseAgent.AllowedLSP = newAgent.AllowedLSP
512 base.Agents[agentID] = baseAgent
513 }
514 }
515 }
516 }
517}
518
519func mergeMCPs(base, global, local *Config) {
520 for _, cfg := range []*Config{global, local} {
521 if cfg == nil {
522 continue
523 }
524 maps.Copy(base.MCP, cfg.MCP)
525 }
526}
527
528func mergeLSPs(base, global, local *Config) {
529 for _, cfg := range []*Config{global, local} {
530 if cfg == nil {
531 continue
532 }
533 maps.Copy(base.LSP, cfg.LSP)
534 }
535}
536
537func mergeProviderConfigs(base, global, local *Config) {
538 for _, cfg := range []*Config{global, local} {
539 if cfg == nil {
540 continue
541 }
542 for providerName, globalProvider := range cfg.Providers {
543 if _, ok := base.Providers[providerName]; !ok {
544 base.Providers[providerName] = globalProvider
545 } else {
546 base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
547 }
548 }
549 }
550
551 finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
552 for providerName, providerConfig := range base.Providers {
553 err := validateProvider(providerName, providerConfig)
554 if err != nil {
555 logging.Warn("Skipping provider", "name", providerName, "error", err)
556 }
557 finalProviders[providerName] = providerConfig
558 }
559 base.Providers = finalProviders
560}
561
562func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
563 switch providerId {
564 case provider.InferenceProviderAnthropic:
565 return ProviderConfig{
566 ID: providerId,
567 ProviderType: provider.TypeAnthropic,
568 }
569 case provider.InferenceProviderOpenAI:
570 return ProviderConfig{
571 ID: providerId,
572 ProviderType: provider.TypeOpenAI,
573 }
574 case provider.InferenceProviderGemini:
575 return ProviderConfig{
576 ID: providerId,
577 ProviderType: provider.TypeGemini,
578 }
579 case provider.InferenceProviderBedrock:
580 return ProviderConfig{
581 ID: providerId,
582 ProviderType: provider.TypeBedrock,
583 }
584 case provider.InferenceProviderAzure:
585 return ProviderConfig{
586 ID: providerId,
587 ProviderType: provider.TypeAzure,
588 }
589 case provider.InferenceProviderOpenRouter:
590 return ProviderConfig{
591 ID: providerId,
592 ProviderType: provider.TypeOpenAI,
593 BaseURL: "https://openrouter.ai/api/v1",
594 ExtraHeaders: map[string]string{
595 "HTTP-Referer": "crush.charm.land",
596 "X-Title": "Crush",
597 },
598 }
599 case provider.InferenceProviderXAI:
600 return ProviderConfig{
601 ID: providerId,
602 ProviderType: provider.TypeXAI,
603 BaseURL: "https://api.x.ai/v1",
604 }
605 case provider.InferenceProviderVertexAI:
606 return ProviderConfig{
607 ID: providerId,
608 ProviderType: provider.TypeVertexAI,
609 }
610 default:
611 return ProviderConfig{
612 ID: providerId,
613 ProviderType: provider.TypeOpenAI,
614 }
615 }
616}
617
618func defaultConfigBasedOnEnv() *Config {
619 cfg := &Config{
620 Options: Options{
621 DataDirectory: defaultDataDirectory,
622 ContextPaths: defaultContextPaths,
623 },
624 Providers: make(map[provider.InferenceProvider]ProviderConfig),
625 }
626
627 providers := Providers()
628
629 for _, p := range providers {
630 if strings.HasPrefix(p.APIKey, "$") {
631 envVar := strings.TrimPrefix(p.APIKey, "$")
632 if apiKey := os.Getenv(envVar); apiKey != "" {
633 providerConfig := providerDefaultConfig(p.ID)
634 providerConfig.APIKey = apiKey
635 providerConfig.DefaultLargeModel = p.DefaultLargeModelID
636 providerConfig.DefaultSmallModel = p.DefaultSmallModelID
637 baseURL := p.APIEndpoint
638 if strings.HasPrefix(baseURL, "$") {
639 envVar := strings.TrimPrefix(baseURL, "$")
640 if url := os.Getenv(envVar); url != "" {
641 baseURL = url
642 }
643 }
644 providerConfig.BaseURL = baseURL
645 for _, model := range p.Models {
646 providerConfig.Models = append(providerConfig.Models, Model{
647 ID: model.ID,
648 Name: model.Name,
649 CostPer1MIn: model.CostPer1MIn,
650 CostPer1MOut: model.CostPer1MOut,
651 CostPer1MInCached: model.CostPer1MInCached,
652 CostPer1MOutCached: model.CostPer1MOutCached,
653 ContextWindow: model.ContextWindow,
654 DefaultMaxTokens: model.DefaultMaxTokens,
655 CanReason: model.CanReason,
656 SupportsImages: model.SupportsImages,
657 })
658 }
659 cfg.Providers[p.ID] = providerConfig
660 }
661 }
662 }
663 // TODO: support local models
664
665 if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
666 providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
667 providerConfig.ExtraParams = map[string]string{
668 "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
669 "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
670 }
671 cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
672 }
673
674 if hasAWSCredentials() {
675 providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
676 providerConfig.ExtraParams = map[string]string{
677 "region": os.Getenv("AWS_DEFAULT_REGION"),
678 }
679 if providerConfig.ExtraParams["region"] == "" {
680 providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION")
681 }
682 cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
683 }
684 return cfg
685}
686
// hasAWSCredentials reports whether the environment suggests AWS (and so
// Bedrock) is usable. That is the case when a static key pair is set, or
// when any one of the profile, region or container-credential variables
// below is non-empty. A region alone counts too, even though it is not a
// credential.
func hasAWSCredentials() bool {
	hasStaticKeys := os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != ""
	if hasStaticKeys {
		return true
	}

	singleHints := []string{
		"AWS_PROFILE",
		"AWS_DEFAULT_PROFILE",
		"AWS_REGION",
		"AWS_DEFAULT_REGION",
		"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
		"AWS_CONTAINER_CREDENTIALS_FULL_URI",
	}
	for _, name := range singleHints {
		if os.Getenv(name) != "" {
			return true
		}
	}
	return false
}
707
708func WorkingDirectory() string {
709 return cwd
710}
711
712func GetAgentModel(agentID AgentID) Model {
713 cfg := Get()
714 agent, ok := cfg.Agents[agentID]
715 if !ok {
716 logging.Error("Agent not found", "agent_id", agentID)
717 return Model{}
718 }
719
720 providerConfig, ok := cfg.Providers[agent.Provider]
721 if !ok {
722 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", agent.Provider)
723 return Model{}
724 }
725
726 for _, model := range providerConfig.Models {
727 if model.ID == agent.Model {
728 return model
729 }
730 }
731
732 logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model)
733 return Model{}
734}
735
736func GetProviderModel(provider provider.InferenceProvider, modelID string) Model {
737 cfg := Get()
738 providerConfig, ok := cfg.Providers[provider]
739 if !ok {
740 logging.Error("Provider not found", "provider", provider)
741 return Model{}
742 }
743
744 for _, model := range providerConfig.Models {
745 if model.ID == modelID {
746 return model
747 }
748 }
749
750 logging.Error("Model not found for provider", "provider", provider, "model_id", modelID)
751 return Model{}
752}