config.go

  1package configv2
  2
  3import (
  4	"encoding/json"
  5	"errors"
  6	"maps"
  7	"os"
  8	"path/filepath"
  9	"slices"
 10	"strings"
 11	"sync"
 12
 13	"github.com/charmbracelet/crush/internal/fur/provider"
 14	"github.com/charmbracelet/crush/internal/logging"
 15)
 16
// Application-wide defaults.
const (
	// defaultDataDirectory seeds Options.DataDirectory; it is relative to the cwd.
	defaultDataDirectory = ".crush"
	// NOTE(review): defaultLogLevel and appName are not referenced in this file.
	defaultLogLevel = "info"
	appName         = "crush"

	// MaxTokensFallbackDefault is a fallback max-tokens value.
	// NOTE(review): not referenced in this file; presumably used when a model
	// has no DefaultMaxTokens — confirm at the call site.
	MaxTokensFallbackDefault = 4096
)
 24
// defaultContextPaths lists the files consulted by default as context for the
// agents (presumably read as project instructions; entries ending in "/" look
// like directories — confirm with the reader). defaultConfigBasedOnEnv uses it
// to seed Options.ContextPaths, and mergeOptions appends paths from config
// files after these defaults.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
 38
// AgentID identifies an agent; it is the key of Config.Agents.
type AgentID string

// IDs of the built-in agents that loadConfig creates when at least one
// provider is configured.
const (
	AgentCoder     AgentID = "coder"
	AgentTask      AgentID = "task"
	AgentTitle     AgentID = "title"
	AgentSummarize AgentID = "summarize"
)
 47
// Model describes one model offered by a provider: its identity, its price
// per million input/output tokens (plain and cached), its context window, its
// default max-tokens setting, and its capabilities.
//
// NOTE(review): the JSON keys do not all mirror the field names — Name is
// stored under "model" and SupportsImages under "supports_attachments".
type Model struct {
	ID                 string  `json:"id"`
	Name               string  `json:"model"`
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	ContextWindow      int64   `json:"context_window"`
	DefaultMaxTokens   int64   `json:"default_max_tokens"`
	CanReason          bool    `json:"can_reason"`
	ReasoningEffort    string  `json:"reasoning_effort"`
	SupportsImages     bool    `json:"supports_attachments"`
}
 61
// VertexAIOptions groups Vertex AI connection settings.
//
// NOTE(review): this type is not referenced in this file — the Vertex AI
// project and location are passed through ProviderConfig.ExtraParams by
// defaultConfigBasedOnEnv instead. Confirm whether it is still used.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}
 67
// ProviderConfig configures a single inference provider.
//
// Entries are seeded from the environment by defaultConfigBasedOnEnv and then
// overlaid with the global and project config files by mergeProviderConfigs.
// For providers in provider.KnownProviders() only a subset of fields (e.g.
// APIKey) can be overridden from config files — see mergeProviderConfig.
// Custom providers must also pass validateProvider.
type ProviderConfig struct {
	ID           provider.InferenceProvider `json:"id"`
	BaseURL      string                     `json:"base_url,omitempty"`
	ProviderType provider.Type              `json:"provider_type"`
	APIKey       string                     `json:"api_key,omitempty"`
	Disabled     bool                       `json:"disabled"`
	ExtraHeaders map[string]string          `json:"extra_headers,omitempty"`
	// ExtraParams holds provider-specific settings; for example,
	// defaultConfigBasedOnEnv stores the Vertex AI "project" and "location" here.
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	// Model IDs that loadConfig assigns to the built-in agents: the large model
	// to coder/task, the small model to title/summarize.
	DefaultLargeModel string `json:"default_large_model,omitempty"`
	DefaultSmallModel string `json:"default_small_model,omitempty"`

	Models []Model `json:"models,omitempty"`
}
 83
// Agent configures one agent: the provider/model it runs on and the tools,
// MCP servers, LSPs and context files it may use.
type Agent struct {
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	// Disabled turns the agent off. NOTE(review): mergeAgents only copies this
	// field for custom agents, so config files cannot disable built-in agents.
	Disabled bool `json:"disabled"`

	Provider provider.InferenceProvider `json:"provider"`
	Model    string                     `json:"model"`

	// The available tools for the agent.
	//  nil means all tools are available; the built-in title/summarize agents
	//  use an empty list, presumably meaning no tools.
	AllowedTools []string `json:"allowed_tools"`

	// The MCP servers available to this agent, keyed by MCP server name.
	//  A nil map means all MCPs are available; an empty map means none (the
	//  built-in task/title/summarize agents use this to disable MCPs).
	//  Each value lists the tools the agent may use from that MCP; a nil value
	//  means all tools from that MCP are available.
	AllowedMCP map[string][]string `json:"allowed_mcp"`

	// The list of LSPs that this agent can use.
	//  nil means all LSPs are available; an empty list means none.
	AllowedLSP []string `json:"allowed_lsp"`

	// Overrides the context paths for this agent.
	ContextPaths []string `json:"context_paths"`
}
110
// MCPType selects the transport used to reach an MCP server.
type MCPType string

// Supported MCP transports: a local process over stdio, or a server reached
// over SSE.
const (
	MCPStdio MCPType = "stdio"
	MCPSse   MCPType = "sse"
)
117
// MCP describes how to reach one MCP server; Type selects the transport.
//
// NOTE(review): presumably Command, Args and Env (KEY=VALUE entries?) apply to
// MCPStdio servers and URL/Headers to MCPSse servers — confirm against the
// MCP client.
type MCP struct {
	Command string            `json:"command"`
	Env     []string          `json:"env"`
	Args    []string          `json:"args"`
	Type    MCPType           `json:"type"`
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}
126
// LSPConfig describes a language server that agents may use: the command to
// launch, its arguments, and Options, an arbitrary JSON value with
// server-specific settings.
type LSPConfig struct {
	// Disabled turns the server off. Its JSON key used to be "enabled", which
	// inverted the setting ("enabled": true disabled the server); the key now
	// matches the field's meaning.
	Disabled bool `json:"disabled"`

	Command string   `json:"command"`
	Args    []string `json:"args"`
	Options any      `json:"options"`
}
133
// TUIOptions holds settings for the terminal UI.
type TUIOptions struct {
	CompactMode bool `json:"compact_mode"`
	// Themes and other TUI-related options can be added here later.
}
138
// Options holds miscellaneous settings. mergeOptions layers the global and
// then the project config over the defaults: ContextPaths are appended,
// boolean flags can only be switched on, and a non-empty DataDirectory
// replaces the previous value.
type Options struct {
	ContextPaths         []string   `json:"context_paths"`
	TUI                  TUIOptions `json:"tui"`
	Debug                bool       `json:"debug"`
	DebugLSP             bool       `json:"debug_lsp"`
	DisableAutoSummarize bool       `json:"disable_auto_summarize"`
	// DataDirectory is relative to the cwd.
	DataDirectory string `json:"data_directory"`
}
148
// Config is the merged application configuration: environment defaults,
// overlaid with the global config file and then the project's crush.json.
type Config struct {
	// Configured providers, keyed by provider ID.
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// Configured agents, keyed by agent ID.
	Agents map[AgentID]Agent `json:"agents,omitempty"`

	// Configured MCP servers, keyed by the name used in the config file.
	MCP map[string]MCP `json:"mcp,omitempty"`

	// Configured LSPs, keyed by the name used in the config file.
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options
	Options Options `json:"options"`
}
165
// Process-wide configuration state, populated by InitConfig.
// NOTE(review): GetConfig and WorkingDirectory read these without
// synchronization; they are only safe once InitConfig has returned.
var (
	// instance is the configuration loaded by InitConfig; nil until it succeeds.
	instance *Config
	// cwd is the working directory passed to the first InitConfig call.
	cwd string
	// once makes InitConfig load the configuration at most once.
	once sync.Once
)
172
173func loadConfig(cwd string) (*Config, error) {
174	// First read the global config file
175	cfgPath := ConfigPath()
176
177	cfg := defaultConfigBasedOnEnv()
178
179	var globalCfg *Config
180	if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
181		// some other error occurred while checking the file
182		return nil, err
183	} else if err == nil {
184		// config file exists, read it
185		file, err := os.ReadFile(cfgPath)
186		if err != nil {
187			return nil, err
188		}
189		globalCfg = &Config{}
190		if err := json.Unmarshal(file, globalCfg); err != nil {
191			return nil, err
192		}
193	} else {
194		// config file does not exist, create a new one
195		globalCfg = &Config{}
196	}
197
198	var localConfig *Config
199	// Global config loaded, now read the local config file
200	localConfigPath := filepath.Join(cwd, "crush.json")
201	if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
202		// some other error occurred while checking the file
203		return nil, err
204	} else if err == nil {
205		// local config file exists, read it
206		file, err := os.ReadFile(localConfigPath)
207		if err != nil {
208			return nil, err
209		}
210		localConfig = &Config{}
211		if err := json.Unmarshal(file, localConfig); err != nil {
212			return nil, err
213		}
214	}
215
216	// merge options
217	mergeOptions(cfg, globalCfg, localConfig)
218
219	mergeProviderConfigs(cfg, globalCfg, localConfig)
220	// no providers found the app is not initialized yet
221	if len(cfg.Providers) == 0 {
222		return cfg, nil
223	}
224	preferredProvider := getPreferredProvider(cfg.Providers)
225
226	if preferredProvider == nil {
227		return nil, errors.New("no valid providers configured")
228	}
229
230	agents := map[AgentID]Agent{
231		AgentCoder: {
232			Name:         "Coder",
233			Description:  "An agent that helps with executing coding tasks.",
234			Provider:     preferredProvider.ID,
235			Model:        preferredProvider.DefaultLargeModel,
236			ContextPaths: cfg.Options.ContextPaths,
237			// All tools allowed
238		},
239		AgentTask: {
240			Name:         "Task",
241			Description:  "An agent that helps with searching for context and finding implementation details.",
242			Provider:     preferredProvider.ID,
243			Model:        preferredProvider.DefaultLargeModel,
244			ContextPaths: cfg.Options.ContextPaths,
245			AllowedTools: []string{
246				"glob",
247				"grep",
248				"ls",
249				"sourcegraph",
250				"view",
251			},
252			// NO MCPs or LSPs by default
253			AllowedMCP: map[string][]string{},
254			AllowedLSP: []string{},
255		},
256		AgentTitle: {
257			Name:         "Title",
258			Description:  "An agent that helps with generating titles for sessions.",
259			Provider:     preferredProvider.ID,
260			Model:        preferredProvider.DefaultSmallModel,
261			ContextPaths: cfg.Options.ContextPaths,
262			AllowedTools: []string{},
263			// NO MCPs or LSPs by default
264			AllowedMCP: map[string][]string{},
265			AllowedLSP: []string{},
266		},
267		AgentSummarize: {
268			Name:         "Summarize",
269			Description:  "An agent that helps with summarizing sessions.",
270			Provider:     preferredProvider.ID,
271			Model:        preferredProvider.DefaultSmallModel,
272			ContextPaths: cfg.Options.ContextPaths,
273			AllowedTools: []string{},
274			// NO MCPs or LSPs by default
275			AllowedMCP: map[string][]string{},
276			AllowedLSP: []string{},
277		},
278	}
279	cfg.Agents = agents
280	mergeAgents(cfg, globalCfg, localConfig)
281	mergeMCPs(cfg, globalCfg, localConfig)
282	mergeLSPs(cfg, globalCfg, localConfig)
283
284	return cfg, nil
285}
286
287func InitConfig(workingDir string) *Config {
288	once.Do(func() {
289		cwd = workingDir
290		cfg, err := loadConfig(cwd)
291		if err != nil {
292			// TODO: Handle this better
293			panic("Failed to load config: " + err.Error())
294		}
295		instance = cfg
296	})
297
298	return instance
299}
300
301func GetConfig() *Config {
302	if instance == nil {
303		// TODO: Handle this better
304		panic("Config not initialized. Call InitConfig first.")
305	}
306	return instance
307}
308
309func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
310	providers := Providers()
311	for _, p := range providers {
312		if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
313			return &providerConfig
314		}
315	}
316	// if none found return the first configured provider
317	for _, providerConfig := range configuredProviders {
318		if !providerConfig.Disabled {
319			return &providerConfig
320		}
321	}
322	return nil
323}
324
325func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
326	if other.APIKey != "" {
327		base.APIKey = other.APIKey
328	}
329	// Only change these options if the provider is not a known provider
330	if !slices.Contains(provider.KnownProviders(), p) {
331		if other.BaseURL != "" {
332			base.BaseURL = other.BaseURL
333		}
334		if other.ProviderType != "" {
335			base.ProviderType = other.ProviderType
336		}
337		if len(base.ExtraHeaders) > 0 {
338			if base.ExtraHeaders == nil {
339				base.ExtraHeaders = make(map[string]string)
340			}
341			maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
342		}
343		if len(other.ExtraParams) > 0 {
344			if base.ExtraParams == nil {
345				base.ExtraParams = make(map[string]string)
346			}
347			maps.Copy(base.ExtraParams, other.ExtraParams)
348		}
349	}
350
351	if other.Disabled {
352		base.Disabled = other.Disabled
353	}
354
355	if other.DefaultLargeModel != "" {
356		base.DefaultLargeModel = other.DefaultLargeModel
357	}
358	// Add new models if they don't exist
359	if other.Models != nil {
360		for _, model := range other.Models {
361			// check if the model already exists
362			exists := false
363			for _, existingModel := range base.Models {
364				if existingModel.ID == model.ID {
365					exists = true
366					break
367				}
368			}
369			if !exists {
370				base.Models = append(base.Models, model)
371			}
372		}
373	}
374
375	return base
376}
377
378func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
379	if !slices.Contains(provider.KnownProviders(), p) {
380		if providerConfig.ProviderType != provider.TypeOpenAI {
381			return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
382		}
383		if providerConfig.BaseURL == "" {
384			return errors.New("base URL must be set for custom providers")
385		}
386		if providerConfig.APIKey == "" {
387			return errors.New("API key must be set for custom providers")
388		}
389	}
390	return nil
391}
392
393func mergeOptions(base, global, local *Config) {
394	for _, cfg := range []*Config{global, local} {
395		if cfg == nil {
396			continue
397		}
398		baseOptions := base.Options
399		other := cfg.Options
400		if len(other.ContextPaths) > 0 {
401			baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
402		}
403
404		if other.TUI.CompactMode {
405			baseOptions.TUI.CompactMode = other.TUI.CompactMode
406		}
407
408		if other.Debug {
409			baseOptions.Debug = other.Debug
410		}
411
412		if other.DebugLSP {
413			baseOptions.DebugLSP = other.DebugLSP
414		}
415
416		if other.DisableAutoSummarize {
417			baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
418		}
419
420		if other.DataDirectory != "" {
421			baseOptions.DataDirectory = other.DataDirectory
422		}
423		base.Options = baseOptions
424	}
425}
426
427func mergeAgents(base, global, local *Config) {
428	for _, cfg := range []*Config{global, local} {
429		if cfg == nil {
430			continue
431		}
432		for agentID, globalAgent := range cfg.Agents {
433			if _, ok := base.Agents[agentID]; !ok {
434				base.Agents[agentID] = globalAgent
435			} else {
436				switch agentID {
437				case AgentCoder:
438					baseAgent := base.Agents[agentID]
439					baseAgent.Model = globalAgent.Model
440					baseAgent.Provider = globalAgent.Provider
441					baseAgent.AllowedMCP = globalAgent.AllowedMCP
442					baseAgent.AllowedLSP = globalAgent.AllowedLSP
443					base.Agents[agentID] = baseAgent
444				case AgentTask:
445					baseAgent := base.Agents[agentID]
446					baseAgent.Model = globalAgent.Model
447					baseAgent.Provider = globalAgent.Provider
448					base.Agents[agentID] = baseAgent
449				case AgentTitle:
450					baseAgent := base.Agents[agentID]
451					baseAgent.Model = globalAgent.Model
452					baseAgent.Provider = globalAgent.Provider
453					base.Agents[agentID] = baseAgent
454				case AgentSummarize:
455					baseAgent := base.Agents[agentID]
456					baseAgent.Model = globalAgent.Model
457					baseAgent.Provider = globalAgent.Provider
458					base.Agents[agentID] = baseAgent
459				default:
460					baseAgent := base.Agents[agentID]
461					baseAgent.Name = globalAgent.Name
462					baseAgent.Description = globalAgent.Description
463					baseAgent.Disabled = globalAgent.Disabled
464					baseAgent.Provider = globalAgent.Provider
465					baseAgent.Model = globalAgent.Model
466					baseAgent.AllowedTools = globalAgent.AllowedTools
467					baseAgent.AllowedMCP = globalAgent.AllowedMCP
468					baseAgent.AllowedLSP = globalAgent.AllowedLSP
469					base.Agents[agentID] = baseAgent
470
471				}
472			}
473		}
474	}
475}
476
477func mergeMCPs(base, global, local *Config) {
478	for _, cfg := range []*Config{global, local} {
479		if cfg == nil {
480			continue
481		}
482		maps.Copy(base.MCP, cfg.MCP)
483	}
484}
485
486func mergeLSPs(base, global, local *Config) {
487	for _, cfg := range []*Config{global, local} {
488		if cfg == nil {
489			continue
490		}
491		maps.Copy(base.LSP, cfg.LSP)
492	}
493}
494
495func mergeProviderConfigs(base, global, local *Config) {
496	for _, cfg := range []*Config{global, local} {
497		if cfg == nil {
498			continue
499		}
500		for providerName, globalProvider := range cfg.Providers {
501			if _, ok := base.Providers[providerName]; !ok {
502				base.Providers[providerName] = globalProvider
503			} else {
504				base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
505			}
506		}
507	}
508
509	finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
510	for providerName, providerConfig := range base.Providers {
511		err := validateProvider(providerName, providerConfig)
512		if err != nil {
513			logging.Warn("Skipping provider", "name", providerName, "error", err)
514		}
515		finalProviders[providerName] = providerConfig
516	}
517	base.Providers = finalProviders
518}
519
520func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
521	switch providerId {
522	case provider.InferenceProviderAnthropic:
523		return ProviderConfig{
524			ID:           providerId,
525			ProviderType: provider.TypeAnthropic,
526		}
527	case provider.InferenceProviderOpenAI:
528		return ProviderConfig{
529			ID:           providerId,
530			ProviderType: provider.TypeOpenAI,
531		}
532	case provider.InferenceProviderGemini:
533		return ProviderConfig{
534			ID:           providerId,
535			ProviderType: provider.TypeGemini,
536		}
537	case provider.InferenceProviderBedrock:
538		return ProviderConfig{
539			ID:           providerId,
540			ProviderType: provider.TypeBedrock,
541		}
542	case provider.InferenceProviderAzure:
543		return ProviderConfig{
544			ID:           providerId,
545			ProviderType: provider.TypeAzure,
546		}
547	case provider.InferenceProviderOpenRouter:
548		return ProviderConfig{
549			ID:           providerId,
550			ProviderType: provider.TypeOpenAI,
551			BaseURL:      "https://openrouter.ai/api/v1",
552			ExtraHeaders: map[string]string{
553				"HTTP-Referer": "crush.charm.land",
554				"X-Title":      "Crush",
555			},
556		}
557	case provider.InferenceProviderXAI:
558		return ProviderConfig{
559			ID:           providerId,
560			ProviderType: provider.TypeXAI,
561			BaseURL:      "https://api.x.ai/v1",
562		}
563	case provider.InferenceProviderVertexAI:
564		return ProviderConfig{
565			ID:           providerId,
566			ProviderType: provider.TypeVertexAI,
567		}
568	default:
569		return ProviderConfig{
570			ID:           providerId,
571			ProviderType: provider.TypeOpenAI,
572		}
573	}
574}
575
576func defaultConfigBasedOnEnv() *Config {
577	cfg := &Config{
578		Options: Options{
579			DataDirectory: defaultDataDirectory,
580			ContextPaths:  defaultContextPaths,
581		},
582		Providers: make(map[provider.InferenceProvider]ProviderConfig),
583	}
584
585	providers := Providers()
586
587	for _, p := range providers {
588		if strings.HasPrefix(p.APIKey, "$") {
589			envVar := strings.TrimPrefix(p.APIKey, "$")
590			if apiKey := os.Getenv(envVar); apiKey != "" {
591				providerConfig := providerDefaultConfig(p.ID)
592				providerConfig.APIKey = apiKey
593				providerConfig.DefaultLargeModel = p.DefaultLargeModelID
594				providerConfig.DefaultSmallModel = p.DefaultSmallModelID
595				for _, model := range p.Models {
596					providerConfig.Models = append(providerConfig.Models, Model{
597						ID:                 model.ID,
598						Name:               model.Name,
599						CostPer1MIn:        model.CostPer1MIn,
600						CostPer1MOut:       model.CostPer1MOut,
601						CostPer1MInCached:  model.CostPer1MInCached,
602						CostPer1MOutCached: model.CostPer1MOutCached,
603						ContextWindow:      model.ContextWindow,
604						DefaultMaxTokens:   model.DefaultMaxTokens,
605						CanReason:          model.CanReason,
606						SupportsImages:     model.SupportsImages,
607					})
608				}
609				cfg.Providers[p.ID] = providerConfig
610			}
611		}
612	}
613	// TODO: support local models
614
615	if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
616		providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
617		providerConfig.ExtraParams = map[string]string{
618			"project":  os.Getenv("GOOGLE_CLOUD_PROJECT"),
619			"location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
620		}
621		cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
622	}
623
624	if hasAWSCredentials() {
625		providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
626		cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
627	}
628	return cfg
629}
630
// hasAWSCredentials reports whether the environment suggests that AWS access is
// configured. defaultConfigBasedOnEnv uses it to decide whether to enable
// Bedrock.
//
// This is a heuristic over environment variables only. Any one of these is
// enough:
//   - a static key pair (AWS_ACCESS_KEY_ID together with AWS_SECRET_ACCESS_KEY),
//   - a named profile (AWS_PROFILE or AWS_DEFAULT_PROFILE),
//   - a region (AWS_REGION or AWS_DEFAULT_REGION),
//   - container credential endpoints (AWS_CONTAINER_CREDENTIALS_RELATIVE_URI
//     or AWS_CONTAINER_CREDENTIALS_FULL_URI).
//
// It does not verify that credentials actually resolve.
func hasAWSCredentials() bool {
	present := func(name string) bool { return os.Getenv(name) != "" }

	switch {
	case present("AWS_ACCESS_KEY_ID") && present("AWS_SECRET_ACCESS_KEY"):
		return true
	case present("AWS_PROFILE"), present("AWS_DEFAULT_PROFILE"):
		return true
	case present("AWS_REGION"), present("AWS_DEFAULT_REGION"):
		return true
	case present("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"),
		present("AWS_CONTAINER_CREDENTIALS_FULL_URI"):
		return true
	default:
		return false
	}
}
651
652func WorkingDirectory() string {
653	return cwd
654}