config.go

  1package configv2
  2
  3import (
  4	"encoding/json"
  5	"errors"
  6	"maps"
  7	"os"
  8	"path/filepath"
  9	"slices"
 10	"strings"
 11	"sync"
 12
 13	"github.com/charmbracelet/crush/internal/fur/provider"
 14	"github.com/charmbracelet/crush/internal/logging"
 15)
 16
 17const (
 18	defaultDataDirectory = ".crush"
 19	defaultLogLevel      = "info"
 20	appName              = "crush"
 21
 22	MaxTokensFallbackDefault = 4096
 23)
 24
// defaultContextPaths is the initial value of Options.ContextPaths. Paths from
// the global and local config files are appended to it (see mergeOptions).
// Several casings of the crush instruction file are listed explicitly.
// NOTE(review): presumably because lookups are case-sensitive on some
// filesystems, and ".cursor/rules/" is presumably treated as a directory —
// confirm with the code that reads these paths.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
 38
// AgentID identifies an agent; it is the key of Config.Agents.
type AgentID string

// Built-in agents. loadConfig creates all four whenever at least one provider
// is configured.
const (
	AgentCoder     AgentID = "coder"
	AgentTask      AgentID = "task"
	AgentTitle     AgentID = "title"
	AgentSummarize AgentID = "summarize"
)
 47
// Model describes one model offered by a provider: identifiers, pricing and
// capability metadata. defaultConfigBasedOnEnv fills these in from the catalog
// returned by Providers(). Config files can add more via ProviderConfig.Models.
type Model struct {
	ID string `json:"id"`
	// Name is serialized under the "model" JSON key, not "name".
	Name string `json:"model"`
	// Prices per 1M tokens, for input/output and their cached variants.
	// NOTE(review): the currency is not established in this file.
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	// Token limits for the model.
	ContextWindow    int64 `json:"context_window"`
	DefaultMaxTokens int64 `json:"default_max_tokens"`
	CanReason        bool  `json:"can_reason"`
	// ReasoningEffort is not populated by defaultConfigBasedOnEnv; it only
	// comes from config files.
	ReasoningEffort string `json:"reasoning_effort"`
	// SupportsImages is serialized as "supports_attachments".
	SupportsImages bool `json:"supports_attachments"`
}
 61
// VertexAIOptions holds Google Vertex AI connection settings.
// NOTE(review): this type is not referenced in this file. Vertex AI project
// and location are currently carried in ProviderConfig.ExtraParams instead —
// confirm whether this type is still used elsewhere.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}
 67
// ProviderConfig is the configuration of one inference provider, keyed by ID
// in Config.Providers.
//
// For providers listed in provider.KnownProviders(), mergeProviderConfig
// ignores BaseURL, ProviderType, ExtraHeaders and ExtraParams coming from
// config files. Only custom providers may set those.
type ProviderConfig struct {
	ID           provider.InferenceProvider `json:"id"`
	BaseURL      string                     `json:"base_url,omitempty"`
	ProviderType provider.Type              `json:"provider_type"`
	APIKey       string                     `json:"api_key,omitempty"`
	Disabled     bool                       `json:"disabled"`
	ExtraHeaders map[string]string          `json:"extra_headers,omitempty"`
	// ExtraParams carries provider-specific settings. For example,
	// defaultConfigBasedOnEnv stores the Vertex AI "project" and "location"
	// here.
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	// Model IDs that agents use by default. loadConfig assigns the large model
	// to the coder and task agents, and the small model to the title and
	// summarize agents.
	DefaultLargeModel string `json:"default_large_model,omitempty"`
	DefaultSmallModel string `json:"default_small_model,omitempty"`

	Models []Model `json:"models,omitempty"`
}
 83
// Agent describes one agent: which provider/model it runs on and which tools,
// MCP servers, LSPs and context files it may use.
type Agent struct {
	ID          AgentID `json:"id"`
	Name        string  `json:"name"`
	Description string  `json:"description,omitempty"`
	// Disabled marks the agent as turned off.
	Disabled bool `json:"disabled"`

	// Provider and Model select what the agent runs on. For built-in agents,
	// loadConfig defaults them to the preferred provider's large or small model.
	Provider provider.InferenceProvider `json:"provider"`
	Model    string                     `json:"model"`

	// The available tools for the agent
	//  if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools"`

	// this tells us which MCPs are available for this agent
	//  if this is empty all mcps are available
	//  the string array is the list of tools from the AllowedMCP the agent has available
	//  if the string array is nil, all tools from the AllowedMCP are available
	AllowedMCP map[string][]string `json:"allowed_mcp"`

	// The list of LSPs that this agent can use
	//  if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths"`
}
111
// MCPType is the transport used to talk to an MCP server.
type MCPType string

// Supported MCP transports.
const (
	MCPStdio MCPType = "stdio"
	MCPSse   MCPType = "sse"
)
118
// MCP describes one MCP server, keyed by name in Config.MCP.
// NOTE(review): presumably Command/Env/Args apply to MCPStdio servers and
// URL/Headers to MCPSse servers — confirm against the client code.
type MCP struct {
	Command string            `json:"command"`
	Env     []string          `json:"env"`
	Args    []string          `json:"args"`
	Type    MCPType           `json:"type"`
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}
127
// LSPConfig describes one language server, keyed by name in Config.LSP.
type LSPConfig struct {
	// Disabled turns this language server off. It is serialized as "disabled"
	// so the JSON key matches the field's meaning, as in ProviderConfig and
	// Agent. A key named "enabled" would invert the setting.
	Disabled bool `json:"disabled"`
	// NOTE(review): presumably the executable and arguments used to launch
	// the server — confirm against the LSP client.
	Command string   `json:"command"`
	Args    []string `json:"args"`
	// Options is decoded as arbitrary JSON and is not interpreted in this file.
	Options any `json:"options"`
}
134
// TUIOptions groups terminal-UI settings.
type TUIOptions struct {
	// CompactMode can only be switched on by a config layer (see mergeOptions).
	CompactMode bool `json:"compact_mode"`
	// Here we can add themes later or any TUI related options
}
139
// Options holds miscellaneous settings.
//
// When merging layers (see mergeOptions), ContextPaths are appended and the
// boolean flags can only be switched on. DataDirectory is replaced when a
// layer sets it.
type Options struct {
	// ContextPaths starts from defaultContextPaths.
	ContextPaths []string   `json:"context_paths"`
	TUI          TUIOptions `json:"tui"`
	// Debug flags; not interpreted in this file.
	Debug                bool `json:"debug"`
	DebugLSP             bool `json:"debug_lsp"`
	DisableAutoSummarize bool `json:"disable_auto_summarize"`
	// Relative to the cwd; defaults to defaultDataDirectory.
	DataDirectory string `json:"data_directory"`
}
149
// Config is the complete application configuration. The same type is used
// both for the merged result and for the individual global/local config files
// that are layered onto it.
type Config struct {
	// Configured providers, keyed by provider ID
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// Configured agents, keyed by agent ID
	Agents map[AgentID]Agent `json:"agents,omitempty"`

	// Configured MCP servers, keyed by name
	MCP map[string]MCP `json:"mcp,omitempty"`

	// Configured language servers, keyed by name
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options
	Options Options `json:"options"`
}
166
// Package-level singleton state, written only inside InitConfig's once.Do.
var (
	instance *Config   // The single instance of the Singleton; nil until InitConfig succeeds
	cwd      string    // working directory passed to the first InitConfig call
	once     sync.Once // Ensures the initialization happens only once
)
173
174func loadConfig(cwd string) (*Config, error) {
175	// First read the global config file
176	cfgPath := ConfigPath()
177
178	cfg := defaultConfigBasedOnEnv()
179
180	var globalCfg *Config
181	if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
182		// some other error occurred while checking the file
183		return nil, err
184	} else if err == nil {
185		// config file exists, read it
186		file, err := os.ReadFile(cfgPath)
187		if err != nil {
188			return nil, err
189		}
190		globalCfg = &Config{}
191		if err := json.Unmarshal(file, globalCfg); err != nil {
192			return nil, err
193		}
194	} else {
195		// config file does not exist, create a new one
196		globalCfg = &Config{}
197	}
198
199	var localConfig *Config
200	// Global config loaded, now read the local config file
201	localConfigPath := filepath.Join(cwd, "crush.json")
202	if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
203		// some other error occurred while checking the file
204		return nil, err
205	} else if err == nil {
206		// local config file exists, read it
207		file, err := os.ReadFile(localConfigPath)
208		if err != nil {
209			return nil, err
210		}
211		localConfig = &Config{}
212		if err := json.Unmarshal(file, localConfig); err != nil {
213			return nil, err
214		}
215	}
216
217	// merge options
218	mergeOptions(cfg, globalCfg, localConfig)
219
220	mergeProviderConfigs(cfg, globalCfg, localConfig)
221	// no providers found the app is not initialized yet
222	if len(cfg.Providers) == 0 {
223		return cfg, nil
224	}
225	preferredProvider := getPreferredProvider(cfg.Providers)
226
227	if preferredProvider == nil {
228		return nil, errors.New("no valid providers configured")
229	}
230
231	agents := map[AgentID]Agent{
232		AgentCoder: {
233			ID:           AgentCoder,
234			Name:         "Coder",
235			Description:  "An agent that helps with executing coding tasks.",
236			Provider:     preferredProvider.ID,
237			Model:        preferredProvider.DefaultLargeModel,
238			ContextPaths: cfg.Options.ContextPaths,
239			// All tools allowed
240		},
241		AgentTask: {
242			ID:           AgentTask,
243			Name:         "Task",
244			Description:  "An agent that helps with searching for context and finding implementation details.",
245			Provider:     preferredProvider.ID,
246			Model:        preferredProvider.DefaultLargeModel,
247			ContextPaths: cfg.Options.ContextPaths,
248			AllowedTools: []string{
249				"glob",
250				"grep",
251				"ls",
252				"sourcegraph",
253				"view",
254			},
255			// NO MCPs or LSPs by default
256			AllowedMCP: map[string][]string{},
257			AllowedLSP: []string{},
258		},
259		AgentTitle: {
260			ID:           AgentTitle,
261			Name:         "Title",
262			Description:  "An agent that helps with generating titles for sessions.",
263			Provider:     preferredProvider.ID,
264			Model:        preferredProvider.DefaultSmallModel,
265			ContextPaths: cfg.Options.ContextPaths,
266			AllowedTools: []string{},
267			// NO MCPs or LSPs by default
268			AllowedMCP: map[string][]string{},
269			AllowedLSP: []string{},
270		},
271		AgentSummarize: {
272			ID:           AgentSummarize,
273			Name:         "Summarize",
274			Description:  "An agent that helps with summarizing sessions.",
275			Provider:     preferredProvider.ID,
276			Model:        preferredProvider.DefaultSmallModel,
277			ContextPaths: cfg.Options.ContextPaths,
278			AllowedTools: []string{},
279			// NO MCPs or LSPs by default
280			AllowedMCP: map[string][]string{},
281			AllowedLSP: []string{},
282		},
283	}
284	cfg.Agents = agents
285	mergeAgents(cfg, globalCfg, localConfig)
286	mergeMCPs(cfg, globalCfg, localConfig)
287	mergeLSPs(cfg, globalCfg, localConfig)
288
289	return cfg, nil
290}
291
292func InitConfig(workingDir string) *Config {
293	once.Do(func() {
294		cwd = workingDir
295		cfg, err := loadConfig(cwd)
296		if err != nil {
297			// TODO: Handle this better
298			panic("Failed to load config: " + err.Error())
299		}
300		instance = cfg
301	})
302
303	return instance
304}
305
306func GetConfig() *Config {
307	if instance == nil {
308		// TODO: Handle this better
309		panic("Config not initialized. Call InitConfig first.")
310	}
311	return instance
312}
313
314func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
315	providers := Providers()
316	for _, p := range providers {
317		if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
318			return &providerConfig
319		}
320	}
321	// if none found return the first configured provider
322	for _, providerConfig := range configuredProviders {
323		if !providerConfig.Disabled {
324			return &providerConfig
325		}
326	}
327	return nil
328}
329
330func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
331	if other.APIKey != "" {
332		base.APIKey = other.APIKey
333	}
334	// Only change these options if the provider is not a known provider
335	if !slices.Contains(provider.KnownProviders(), p) {
336		if other.BaseURL != "" {
337			base.BaseURL = other.BaseURL
338		}
339		if other.ProviderType != "" {
340			base.ProviderType = other.ProviderType
341		}
342		if len(base.ExtraHeaders) > 0 {
343			if base.ExtraHeaders == nil {
344				base.ExtraHeaders = make(map[string]string)
345			}
346			maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
347		}
348		if len(other.ExtraParams) > 0 {
349			if base.ExtraParams == nil {
350				base.ExtraParams = make(map[string]string)
351			}
352			maps.Copy(base.ExtraParams, other.ExtraParams)
353		}
354	}
355
356	if other.Disabled {
357		base.Disabled = other.Disabled
358	}
359
360	if other.DefaultLargeModel != "" {
361		base.DefaultLargeModel = other.DefaultLargeModel
362	}
363	// Add new models if they don't exist
364	if other.Models != nil {
365		for _, model := range other.Models {
366			// check if the model already exists
367			exists := false
368			for _, existingModel := range base.Models {
369				if existingModel.ID == model.ID {
370					exists = true
371					break
372				}
373			}
374			if !exists {
375				base.Models = append(base.Models, model)
376			}
377		}
378	}
379
380	return base
381}
382
383func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
384	if !slices.Contains(provider.KnownProviders(), p) {
385		if providerConfig.ProviderType != provider.TypeOpenAI {
386			return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
387		}
388		if providerConfig.BaseURL == "" {
389			return errors.New("base URL must be set for custom providers")
390		}
391		if providerConfig.APIKey == "" {
392			return errors.New("API key must be set for custom providers")
393		}
394	}
395	return nil
396}
397
398func mergeOptions(base, global, local *Config) {
399	for _, cfg := range []*Config{global, local} {
400		if cfg == nil {
401			continue
402		}
403		baseOptions := base.Options
404		other := cfg.Options
405		if len(other.ContextPaths) > 0 {
406			baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
407		}
408
409		if other.TUI.CompactMode {
410			baseOptions.TUI.CompactMode = other.TUI.CompactMode
411		}
412
413		if other.Debug {
414			baseOptions.Debug = other.Debug
415		}
416
417		if other.DebugLSP {
418			baseOptions.DebugLSP = other.DebugLSP
419		}
420
421		if other.DisableAutoSummarize {
422			baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
423		}
424
425		if other.DataDirectory != "" {
426			baseOptions.DataDirectory = other.DataDirectory
427		}
428		base.Options = baseOptions
429	}
430}
431
432func mergeAgents(base, global, local *Config) {
433	for _, cfg := range []*Config{global, local} {
434		if cfg == nil {
435			continue
436		}
437		for agentID, newAgent := range cfg.Agents {
438			if _, ok := base.Agents[agentID]; !ok {
439				newAgent.ID = agentID // Ensure the ID is set correctly
440				base.Agents[agentID] = newAgent
441			} else {
442				switch agentID {
443				case AgentCoder:
444					baseAgent := base.Agents[agentID]
445					baseAgent.Model = newAgent.Model
446					baseAgent.Provider = newAgent.Provider
447					baseAgent.AllowedMCP = newAgent.AllowedMCP
448					baseAgent.AllowedLSP = newAgent.AllowedLSP
449					base.Agents[agentID] = baseAgent
450				case AgentTask:
451					baseAgent := base.Agents[agentID]
452					baseAgent.Model = newAgent.Model
453					baseAgent.Provider = newAgent.Provider
454					base.Agents[agentID] = baseAgent
455				case AgentTitle:
456					baseAgent := base.Agents[agentID]
457					baseAgent.Model = newAgent.Model
458					baseAgent.Provider = newAgent.Provider
459					base.Agents[agentID] = baseAgent
460				case AgentSummarize:
461					baseAgent := base.Agents[agentID]
462					baseAgent.Model = newAgent.Model
463					baseAgent.Provider = newAgent.Provider
464					base.Agents[agentID] = baseAgent
465				default:
466					baseAgent := base.Agents[agentID]
467					baseAgent.Name = newAgent.Name
468					baseAgent.Description = newAgent.Description
469					baseAgent.Disabled = newAgent.Disabled
470					baseAgent.Provider = newAgent.Provider
471					baseAgent.Model = newAgent.Model
472					baseAgent.AllowedTools = newAgent.AllowedTools
473					baseAgent.AllowedMCP = newAgent.AllowedMCP
474					baseAgent.AllowedLSP = newAgent.AllowedLSP
475					base.Agents[agentID] = baseAgent
476
477				}
478			}
479		}
480	}
481}
482
483func mergeMCPs(base, global, local *Config) {
484	for _, cfg := range []*Config{global, local} {
485		if cfg == nil {
486			continue
487		}
488		maps.Copy(base.MCP, cfg.MCP)
489	}
490}
491
492func mergeLSPs(base, global, local *Config) {
493	for _, cfg := range []*Config{global, local} {
494		if cfg == nil {
495			continue
496		}
497		maps.Copy(base.LSP, cfg.LSP)
498	}
499}
500
501func mergeProviderConfigs(base, global, local *Config) {
502	for _, cfg := range []*Config{global, local} {
503		if cfg == nil {
504			continue
505		}
506		for providerName, globalProvider := range cfg.Providers {
507			if _, ok := base.Providers[providerName]; !ok {
508				base.Providers[providerName] = globalProvider
509			} else {
510				base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
511			}
512		}
513	}
514
515	finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
516	for providerName, providerConfig := range base.Providers {
517		err := validateProvider(providerName, providerConfig)
518		if err != nil {
519			logging.Warn("Skipping provider", "name", providerName, "error", err)
520		}
521		finalProviders[providerName] = providerConfig
522	}
523	base.Providers = finalProviders
524}
525
526func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
527	switch providerId {
528	case provider.InferenceProviderAnthropic:
529		return ProviderConfig{
530			ID:           providerId,
531			ProviderType: provider.TypeAnthropic,
532		}
533	case provider.InferenceProviderOpenAI:
534		return ProviderConfig{
535			ID:           providerId,
536			ProviderType: provider.TypeOpenAI,
537		}
538	case provider.InferenceProviderGemini:
539		return ProviderConfig{
540			ID:           providerId,
541			ProviderType: provider.TypeGemini,
542		}
543	case provider.InferenceProviderBedrock:
544		return ProviderConfig{
545			ID:           providerId,
546			ProviderType: provider.TypeBedrock,
547		}
548	case provider.InferenceProviderAzure:
549		return ProviderConfig{
550			ID:           providerId,
551			ProviderType: provider.TypeAzure,
552		}
553	case provider.InferenceProviderOpenRouter:
554		return ProviderConfig{
555			ID:           providerId,
556			ProviderType: provider.TypeOpenAI,
557			BaseURL:      "https://openrouter.ai/api/v1",
558			ExtraHeaders: map[string]string{
559				"HTTP-Referer": "crush.charm.land",
560				"X-Title":      "Crush",
561			},
562		}
563	case provider.InferenceProviderXAI:
564		return ProviderConfig{
565			ID:           providerId,
566			ProviderType: provider.TypeXAI,
567			BaseURL:      "https://api.x.ai/v1",
568		}
569	case provider.InferenceProviderVertexAI:
570		return ProviderConfig{
571			ID:           providerId,
572			ProviderType: provider.TypeVertexAI,
573		}
574	default:
575		return ProviderConfig{
576			ID:           providerId,
577			ProviderType: provider.TypeOpenAI,
578		}
579	}
580}
581
582func defaultConfigBasedOnEnv() *Config {
583	cfg := &Config{
584		Options: Options{
585			DataDirectory: defaultDataDirectory,
586			ContextPaths:  defaultContextPaths,
587		},
588		Providers: make(map[provider.InferenceProvider]ProviderConfig),
589	}
590
591	providers := Providers()
592
593	for _, p := range providers {
594		if strings.HasPrefix(p.APIKey, "$") {
595			envVar := strings.TrimPrefix(p.APIKey, "$")
596			if apiKey := os.Getenv(envVar); apiKey != "" {
597				providerConfig := providerDefaultConfig(p.ID)
598				providerConfig.APIKey = apiKey
599				providerConfig.DefaultLargeModel = p.DefaultLargeModelID
600				providerConfig.DefaultSmallModel = p.DefaultSmallModelID
601				for _, model := range p.Models {
602					providerConfig.Models = append(providerConfig.Models, Model{
603						ID:                 model.ID,
604						Name:               model.Name,
605						CostPer1MIn:        model.CostPer1MIn,
606						CostPer1MOut:       model.CostPer1MOut,
607						CostPer1MInCached:  model.CostPer1MInCached,
608						CostPer1MOutCached: model.CostPer1MOutCached,
609						ContextWindow:      model.ContextWindow,
610						DefaultMaxTokens:   model.DefaultMaxTokens,
611						CanReason:          model.CanReason,
612						SupportsImages:     model.SupportsImages,
613					})
614				}
615				cfg.Providers[p.ID] = providerConfig
616			}
617		}
618	}
619	// TODO: support local models
620
621	if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
622		providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
623		providerConfig.ExtraParams = map[string]string{
624			"project":  os.Getenv("GOOGLE_CLOUD_PROJECT"),
625			"location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
626		}
627		cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
628	}
629
630	if hasAWSCredentials() {
631		providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
632		cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
633	}
634	return cfg
635}
636
// hasAWSCredentials reports whether the environment looks configured for AWS.
// Any of the following counts:
//   - static keys (AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY together),
//   - a named profile,
//   - a region,
//   - container credential endpoints.
//
// NOTE(review): a region variable alone does not carry credentials; confirm
// that treating it as sufficient is intended.
func hasAWSCredentials() bool {
	isSet := func(name string) bool { return os.Getenv(name) != "" }

	switch {
	case isSet("AWS_ACCESS_KEY_ID") && isSet("AWS_SECRET_ACCESS_KEY"):
		return true
	case isSet("AWS_PROFILE"), isSet("AWS_DEFAULT_PROFILE"):
		return true
	case isSet("AWS_REGION"), isSet("AWS_DEFAULT_REGION"):
		return true
	case isSet("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"), isSet("AWS_CONTAINER_CREDENTIALS_FULL_URI"):
		return true
	default:
		return false
	}
}
657
// WorkingDirectory returns the working directory passed to the first
// InitConfig call, or "" if InitConfig has not been called yet.
func WorkingDirectory() string {
	return cwd
}