config.go

  1package configv2
  2
  3import (
  4	"encoding/json"
  5	"errors"
  6	"maps"
  7	"os"
  8	"path/filepath"
  9	"slices"
 10	"strings"
 11	"sync"
 12
 13	"github.com/charmbracelet/crush/internal/logging"
 14	"github.com/charmbracelet/fur/pkg/provider"
 15)
 16
 17const (
 18	defaultDataDirectory = ".crush"
 19	defaultLogLevel      = "info"
 20	appName              = "crush"
 21
 22	MaxTokensFallbackDefault = 4096
 23)
 24
// Model describes a single model: its identifiers, pricing, token limits
// and capabilities.
//
// NOTE(review): the Cost* fields are presumably prices per one million
// tokens (per the "per_1m" JSON keys); the currency is not stated here.
// NOTE(review): Name is encoded under the JSON key "model" and
// SupportsImages under "supports_attachments"; confirm these keys match
// the intended schema.
type Model struct {
	ID                 string  `json:"id"`
	Name               string  `json:"model"`
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	ContextWindow      int64   `json:"context_window"`
	DefaultMaxTokens   int64   `json:"default_max_tokens"`
	CanReason          bool    `json:"can_reason"`
	ReasoningEffort    string  `json:"reasoning_effort"`
	SupportsImages     bool    `json:"supports_attachments"`
}
 38
// VertexAIOptions holds Google Vertex AI connection settings.
//
// NOTE(review): this type is not referenced in this file; the environment
// based Vertex AI setup stores project/location in
// ProviderConfig.ExtraParams instead. Confirm whether it is still needed.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}
 44
// ProviderConfig is the configuration for one inference provider, built
// from environment defaults and the user's config files.
type ProviderConfig struct {
	BaseURL      string            `json:"base_url,omitempty"`
	ProviderType provider.Type     `json:"provider_type"`
	APIKey       string            `json:"api_key,omitempty"`
	Disabled     bool              `json:"disabled"`
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Provider-specific parameters; for example, the Vertex AI provider
	// carries its "project" and "location" values here.
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	// DefaultModel is a model ID; the environment-based defaults fill it
	// from the provider's DefaultModelID.
	DefaultModel string `json:"default_model"`
}
 56
// Agent configures a single agent: the provider and model it uses, and
// which tools, MCP servers and LSPs it may access.
type Agent struct {
	Name string `json:"name"`
	// PromptID is the ID of the system prompt used by the agent.
	// TODO: still needs to be implemented.
	PromptID string `json:"prompt_id"`
	Disabled bool   `json:"disabled"`

	Provider provider.InferenceProvider `json:"provider"`
	Model    Model                      `json:"model"`

	// AllowedTools lists the tools available to the agent.
	// If it is empty, all tools are available.
	AllowedTools []string `json:"allowed_tools"`

	// MCP selects which MCP servers are available to this agent.
	// If the map is empty, all MCP servers are available.
	// Each value lists the tools from that MCP server the agent may use;
	// an empty list means all tools from that server are available.
	MCP map[string][]string `json:"mcp"`

	// LSP lists the LSPs this agent can use.
	// If it is empty, all LSPs are available.
	LSP []string `json:"lsp"`

	// ContextPaths overrides the context paths for this agent.
	ContextPaths []string `json:"context_paths"`
}
 84
 85type MCPType string
 86
 87const (
 88	MCPStdio MCPType = "stdio"
 89	MCPSse   MCPType = "sse"
 90)
 91
// MCP describes how to connect to one MCP server.
//
// NOTE(review): presumably Command, Args and Env apply to MCPStdio servers
// (Env likely as "KEY=VALUE" entries), and URL and Headers apply to MCPSse
// servers. Confirm at the call sites.
type MCP struct {
	Command string            `json:"command"`
	Env     []string          `json:"env"`
	Args    []string          `json:"args"`
	Type    MCPType           `json:"type"`
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}
100
// LSPConfig configures one language server.
type LSPConfig struct {
	// Disabled turns this language server off. It is decoded from the JSON
	// key "disabled". The previous "enabled" key inverted the meaning:
	// `"enabled": true` would have disabled the server.
	Disabled bool `json:"disabled"`

	// Command and Args name the server executable and its arguments.
	// NOTE(review): presumably used to spawn the server; confirm at the
	// call site.
	Command string `json:"command"`

	Args []string `json:"args"`

	// Options holds arbitrary server-specific settings.
	// NOTE(review): presumably forwarded to the server unchanged; confirm.
	Options any `json:"options"`
}
107
// TUIOptions holds settings for the terminal user interface.
type TUIOptions struct {
	CompactMode bool `json:"compact_mode"`
	// Further TUI-related settings (for example, themes) can be added here.
}
112
// Options holds miscellaneous settings that are not tied to a particular
// provider, agent, MCP server or LSP.
type Options struct {
	ContextPaths         []string   `json:"context_paths"`
	TUI                  TUIOptions `json:"tui"`
	Debug                bool       `json:"debug"`
	DebugLSP             bool       `json:"debug_lsp"`
	DisableAutoSummarize bool       `json:"disable_auto_summarize"`
	// DataDirectory is relative to the working directory. The
	// environment-based defaults seed it with ".crush".
	DataDirectory string `json:"data_directory"`
}
122
// Config is the top-level configuration document. It is the shape decoded
// from both the global config file and a project-local crush.json.
type Config struct {
	// Configured providers, keyed by provider ID.
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// Configured agents, keyed by agent name.
	Agents map[string]Agent `json:"agents,omitempty"`

	// Configured MCP servers, keyed by name.
	MCP map[string]MCP `json:"mcp,omitempty"`

	// Configured LSPs, keyed by name.
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options.
	Options Options `json:"options"`

	// Extra models, keyed by provider, for models that are not already in
	// the model repository.
	Models map[provider.InferenceProvider][]provider.Model `json:"models,omitempty"`
}
142
143var (
144	instance *Config // The single instance of the Singleton
145	cwd      string
146	once     sync.Once // Ensures the initialization happens only once
147)
148
149func loadConfig(cwd string) (*Config, error) {
150	// First read the global config file
151	cfgPath := ConfigPath()
152
153	cfg := defaultConfigBasedOnEnv()
154
155	var globalCfg *Config
156	if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
157		// some other error occurred while checking the file
158		return nil, err
159	} else if err == nil {
160		// config file exists, read it
161		file, err := os.ReadFile(cfgPath)
162		if err != nil {
163			return nil, err
164		}
165		globalCfg = &Config{}
166		if err := json.Unmarshal(file, globalCfg); err != nil {
167			return nil, err
168		}
169	} else {
170		// config file does not exist, create a new one
171		globalCfg = &Config{}
172	}
173
174	var localConfig *Config
175	// Global config loaded, now read the local config file
176	localConfigPath := filepath.Join(cwd, "crush.json")
177	if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
178		// some other error occurred while checking the file
179		return nil, err
180	} else if err == nil {
181		// local config file exists, read it
182		file, err := os.ReadFile(localConfigPath)
183		if err != nil {
184			return nil, err
185		}
186		localConfig = &Config{}
187		if err := json.Unmarshal(file, localConfig); err != nil {
188			return nil, err
189		}
190	}
191
192	// merge options
193	cfg.Options = mergeOptions(cfg.Options, globalCfg.Options)
194	cfg.Options = mergeOptions(cfg.Options, localConfig.Options)
195
196	mergeProviderConfigs(cfg, globalCfg, localConfig)
197	return cfg, nil
198}
199
200func InitConfig(workingDir string) *Config {
201	once.Do(func() {
202		cwd = workingDir
203		cfg, err := loadConfig(cwd)
204		if err != nil {
205			// TODO: Handle this better
206			panic("Failed to load config: " + err.Error())
207		}
208		instance = cfg
209	})
210
211	return instance
212}
213
214func GetConfig() *Config {
215	if instance == nil {
216		// TODO: Handle this better
217		panic("Config not initialized. Call InitConfig first.")
218	}
219	return instance
220}
221
222func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
223	if other.APIKey != "" {
224		base.APIKey = other.APIKey
225	}
226	// Only change these options if the provider is not a known provider
227	if !slices.Contains(provider.KnownProviders(), p) {
228		if other.BaseURL != "" {
229			base.BaseURL = other.BaseURL
230		}
231		if other.ProviderType != "" {
232			base.ProviderType = other.ProviderType
233		}
234		if len(base.ExtraHeaders) > 0 {
235			if base.ExtraHeaders == nil {
236				base.ExtraHeaders = make(map[string]string)
237			}
238			maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
239		}
240		if len(other.ExtraParams) > 0 {
241			if base.ExtraParams == nil {
242				base.ExtraParams = make(map[string]string)
243			}
244			maps.Copy(base.ExtraParams, other.ExtraParams)
245		}
246	}
247
248	if other.Disabled {
249		base.Disabled = other.Disabled
250	}
251
252	return base
253}
254
255func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
256	if !slices.Contains(provider.KnownProviders(), p) {
257		if providerConfig.ProviderType != provider.TypeOpenAI {
258			return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
259		}
260		if providerConfig.BaseURL == "" {
261			return errors.New("base URL must be set for custom providers")
262		}
263		if providerConfig.APIKey == "" {
264			return errors.New("API key must be set for custom providers")
265		}
266	}
267	return nil
268}
269
270func mergeOptions(base, other Options) Options {
271	result := base
272
273	if len(other.ContextPaths) > 0 {
274		base.ContextPaths = append(base.ContextPaths, other.ContextPaths...)
275	}
276
277	if other.TUI.CompactMode {
278		result.TUI.CompactMode = other.TUI.CompactMode
279	}
280
281	if other.Debug {
282		result.Debug = other.Debug
283	}
284
285	if other.DebugLSP {
286		result.DebugLSP = other.DebugLSP
287	}
288
289	if other.DisableAutoSummarize {
290		result.DisableAutoSummarize = other.DisableAutoSummarize
291	}
292
293	if other.DataDirectory != "" {
294		result.DataDirectory = other.DataDirectory
295	}
296
297	return result
298}
299
300func mergeProviderConfigs(base, global, local *Config) {
301	if global != nil {
302		for providerName, globalProvider := range global.Providers {
303			if _, ok := base.Providers[providerName]; !ok {
304				base.Providers[providerName] = globalProvider
305			} else {
306				base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
307			}
308		}
309	}
310	if local != nil {
311		for providerName, localProvider := range local.Providers {
312			if _, ok := base.Providers[providerName]; !ok {
313				base.Providers[providerName] = localProvider
314			} else {
315				base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], localProvider)
316			}
317		}
318	}
319
320	finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
321	for providerName, providerConfig := range base.Providers {
322		err := validateProvider(providerName, providerConfig)
323		if err != nil {
324			logging.Warn("Skipping provider", "name", providerName, "error", err)
325		}
326		finalProviders[providerName] = providerConfig
327	}
328	base.Providers = finalProviders
329}
330
331func providerDefaultConfig(providerName provider.InferenceProvider) ProviderConfig {
332	switch providerName {
333	case provider.InferenceProviderAnthropic:
334		return ProviderConfig{
335			ProviderType: provider.TypeAnthropic,
336		}
337	case provider.InferenceProviderOpenAI:
338		return ProviderConfig{
339			ProviderType: provider.TypeOpenAI,
340		}
341	case provider.InferenceProviderGemini:
342		return ProviderConfig{
343			ProviderType: provider.TypeGemini,
344		}
345	case provider.InferenceProviderBedrock:
346		return ProviderConfig{
347			ProviderType: provider.TypeBedrock,
348		}
349	case provider.InferenceProviderAzure:
350		return ProviderConfig{
351			ProviderType: provider.TypeAzure,
352		}
353	case provider.InferenceProviderOpenRouter:
354		return ProviderConfig{
355			ProviderType: provider.TypeOpenAI,
356			BaseURL:      "https://openrouter.ai/api/v1",
357			ExtraHeaders: map[string]string{
358				"HTTP-Referer": "crush.charm.land",
359				"X-Title":      "Crush",
360			},
361		}
362	case provider.InferenceProviderXAI:
363		return ProviderConfig{
364			ProviderType: provider.TypeXAI,
365			BaseURL:      "https://api.x.ai/v1",
366		}
367	case provider.InferenceProviderVertexAI:
368		return ProviderConfig{
369			ProviderType: provider.TypeVertexAI,
370		}
371	default:
372		return ProviderConfig{
373			ProviderType: provider.TypeOpenAI,
374		}
375	}
376}
377
378func defaultConfigBasedOnEnv() *Config {
379	cfg := &Config{
380		Options: Options{
381			DataDirectory: defaultDataDirectory,
382		},
383		Providers: make(map[provider.InferenceProvider]ProviderConfig),
384	}
385
386	providers := Providers()
387
388	for _, p := range providers {
389		if strings.HasPrefix(p.APIKey, "$") {
390			envVar := strings.TrimPrefix(p.APIKey, "$")
391			if apiKey := os.Getenv(envVar); apiKey != "" {
392				providerConfig := providerDefaultConfig(p.ID)
393				providerConfig.APIKey = apiKey
394				providerConfig.DefaultModel = p.DefaultModelID
395				cfg.Providers[p.ID] = providerConfig
396			}
397		}
398	}
399	// TODO: support local models
400
401	if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
402		providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
403		providerConfig.ExtraParams = map[string]string{
404			"project":  os.Getenv("GOOGLE_CLOUD_PROJECT"),
405			"location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
406		}
407		cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
408	}
409
410	if hasAWSCredentials() {
411		providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
412		cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
413	}
414	return cfg
415}
416
// hasAWSCredentials reports whether the environment looks configured for
// AWS. That is true if any of the following holds:
//
//   - both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are set;
//   - a profile variable is set;
//   - a region variable is set;
//   - a container-credentials URI is set.
//
// NOTE(review): a region alone is not a credential. Presumably this relies
// on the SDK's default credential chain; confirm this is intended.
func hasAWSCredentials() bool {
	isSet := func(name string) bool { return os.Getenv(name) != "" }

	// A static key pair only counts when both halves are present.
	if isSet("AWS_ACCESS_KEY_ID") && isSet("AWS_SECRET_ACCESS_KEY") {
		return true
	}

	// Any single one of these variables is treated as sufficient.
	for _, name := range []string{
		"AWS_PROFILE",
		"AWS_DEFAULT_PROFILE",
		"AWS_REGION",
		"AWS_DEFAULT_REGION",
		"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
		"AWS_CONTAINER_CREDENTIALS_FULL_URI",
	} {
		if isSet(name) {
			return true
		}
	}
	return false
}
437
438func WorkingDirectory() string {
439	return cwd
440}