1package config
2
3import (
4 "encoding/json"
5 "fmt"
6 "io"
7 "log/slog"
8 "maps"
9 "os"
10 "path/filepath"
11 "runtime"
12 "slices"
13 "strconv"
14 "strings"
15
16 "github.com/charmbracelet/catwalk/pkg/catwalk"
17 "github.com/charmbracelet/crush/internal/csync"
18 "github.com/charmbracelet/crush/internal/fsext"
19 "github.com/charmbracelet/crush/internal/home"
20 "github.com/charmbracelet/crush/internal/version"
21 uv "github.com/charmbracelet/ultraviolet"
22 powernapConfig "github.com/charmbracelet/x/powernap/pkg/config"
23)
24
// environ is a local alias for ultraviolet's environment type. Provider
// credential checks in this file read variables through its Getenv method.
type environ = uv.Environ

// defaultCatwalkURL is the default base URL of the catwalk service.
// NOTE(review): presumably used when fetching the known-provider catalog
// from catwalk — confirm at the call site.
const defaultCatwalkURL = "https://catwalk.charm.sh"
28
29// LoadReader config via io.Reader.
30func LoadReader(fd io.Reader) (*Config, error) {
31 data, err := io.ReadAll(fd)
32 if err != nil {
33 return nil, err
34 }
35
36 var config Config
37 err = json.Unmarshal(data, &config)
38 if err != nil {
39 return nil, err
40 }
41 return &config, err
42}
43
44// Load loads the configuration from the default paths.
45func Load(workingDir, dataDir string, debug bool, envs []string) (*Config, error) {
46 configPaths := lookupConfigs(workingDir)
47
48 cfg, err := loadFromConfigPaths(configPaths)
49 if err != nil {
50 return nil, fmt.Errorf("failed to load config from paths %v: %w", configPaths, err)
51 }
52
53 cfg.dataConfigDir = GlobalConfigData()
54
55 cfg.setDefaults(workingDir, dataDir)
56
57 if debug {
58 cfg.Options.Debug = true
59 }
60
61 // Load known providers, this loads the config from catwalk
62 providers, err := Providers(cfg)
63 if err != nil {
64 return nil, err
65 }
66 cfg.knownProviders = providers
67
68 // Configure providers
69 valueResolver := NewShellVariableResolver(envs)
70 if err := cfg.configureProviders(envs, valueResolver, cfg.knownProviders); err != nil {
71 return nil, fmt.Errorf("failed to configure providers: %w", err)
72 }
73
74 if !cfg.IsConfigured() {
75 slog.Warn("No providers configured")
76 return cfg, nil
77 }
78
79 if err := cfg.configureSelectedModels(cfg.knownProviders); err != nil {
80 return nil, fmt.Errorf("failed to configure selected models: %w", err)
81 }
82 cfg.SetupAgents()
83 return cfg, nil
84}
85
// PushPopCrushEnv promotes every CRUSH_-prefixed environment variable to its
// unprefixed name (CRUSH_FOO=bar sets FOO=bar) and returns a function that
// undoes the change.
//
// The returned restore function puts each promoted variable back exactly as
// it was before the call: variables that previously had a value get that
// value back, and variables that did not exist beforehand are unset again
// rather than being left behind as empty strings.
//
// Both the promotion and the restore mutate the process-wide environment.
func PushPopCrushEnv() func() {
	// Snapshot every promotion before touching the environment, so values are
	// read from the original state and nested names (CRUSH_CRUSH_X) cannot
	// cascade through an earlier Setenv.
	promoted := make(map[string]string)
	for _, kv := range os.Environ() {
		key, _, ok := strings.Cut(kv, "=")
		if !ok {
			continue
		}
		name, hasPrefix := strings.CutPrefix(key, "CRUSH_")
		// A bare "CRUSH_" would map to an empty, invalid variable name.
		if !hasPrefix || name == "" {
			continue
		}
		promoted[name] = os.Getenv(key)
	}

	// Record whether each target existed, not just its value, so restore can
	// distinguish "unset" from "set to the empty string".
	type saved struct {
		value string
		set   bool
	}
	backups := make(map[string]saved, len(promoted))
	for name := range promoted {
		value, set := os.LookupEnv(name)
		backups[name] = saved{value: value, set: set}
	}

	for name, value := range promoted {
		// Setenv only rejects empty names or names containing '=' or NUL,
		// which are excluded above or impossible for an existing entry.
		_ = os.Setenv(name, value)
	}

	return func() {
		for name, prev := range backups {
			if prev.set {
				_ = os.Setenv(name, prev.value)
			} else {
				_ = os.Unsetenv(name)
			}
		}
	}
}
113
// configureProviders builds the final provider set in c.Providers from the
// known (catwalk) providers and the user's provider configuration, then
// validates the user's custom providers.
//
// For every known provider, a user entry with the same ID may override
// BaseURL and APIKey, add models (user models take precedence over catwalk
// models with the same ID), add extra headers and body, set a system prompt
// prefix, and set the disable flag. Known providers that lack what they need
// (Vertex AI project and location, a resolvable Azure endpoint, AWS
// credentials for Bedrock, or a resolvable API key for everything else) are
// skipped, and deleted from c.Providers when the user had configured them.
//
// Custom providers (IDs not present in knownProviders) get ID, Name, and
// Type (OpenAI) defaults and are deleted when disabled, missing a base URL,
// without models, or of a type other than OpenAI, Anthropic, or Gemini.
//
// env supplies the Vertex AI, Azure, and Bedrock settings; resolver resolves
// API keys and endpoints. The only error returned is for a Bedrock model
// whose ID does not start with "anthropic.".
func (c *Config) configureProviders(env environ, resolver VariableResolver, knownProviders []catwalk.Provider) error {
	knownProviderNames := make(map[string]bool)
	// Promote CRUSH_-prefixed environment variables to their unprefixed names
	// for the duration of this call; the deferred restore undoes it.
	// NOTE(review): presumably so value resolution that consults the process
	// environment sees the CRUSH_ overrides — confirm.
	restore := PushPopCrushEnv()
	defer restore()
	for _, p := range knownProviders {
		knownProviderNames[string(p.ID)] = true
		// When the user has no entry for this provider, config is the zero
		// ProviderConfig, so the config.* reads further down yield empty values.
		config, configExists := c.Providers.Get(string(p.ID))
		// if the user configured a known provider we need to allow it to override a couple of parameters
		if configExists {
			if config.BaseURL != "" {
				p.APIEndpoint = config.BaseURL
			}
			if config.APIKey != "" {
				p.APIKey = config.APIKey
			}
			if len(config.Models) > 0 {
				// Merge user models first, then catwalk models. The first
				// occurrence of an ID wins, so user definitions override
				// catwalk ones; models without a name are named after their ID.
				models := []catwalk.Model{}
				seen := make(map[string]bool)

				for _, model := range config.Models {
					if seen[model.ID] {
						continue
					}
					seen[model.ID] = true
					if model.Name == "" {
						model.Name = model.ID
					}
					models = append(models, model)
				}
				for _, model := range p.Models {
					if seen[model.ID] {
						continue
					}
					seen[model.ID] = true
					if model.Name == "" {
						model.Name = model.ID
					}
					models = append(models, model)
				}

				p.Models = models
			}
		}

		// Catwalk default headers first; user ExtraHeaders overwrite them on
		// key collisions.
		headers := map[string]string{}
		if len(p.DefaultHeaders) > 0 {
			maps.Copy(headers, p.DefaultHeaders)
		}
		if len(config.ExtraHeaders) > 0 {
			maps.Copy(headers, config.ExtraHeaders)
		}
		prepared := ProviderConfig{
			ID:                 string(p.ID),
			Name:               p.Name,
			BaseURL:            p.APIEndpoint,
			APIKey:             p.APIKey,
			Type:               p.Type,
			Disable:            config.Disable,
			SystemPromptPrefix: config.SystemPromptPrefix,
			ExtraHeaders:       headers,
			ExtraBody:          config.ExtraBody,
			ExtraParams:        make(map[string]string),
			Models:             p.Models,
		}

		switch p.ID {
		// Handle specific providers that require additional configuration
		case catwalk.InferenceProviderVertexAI:
			if !hasVertexCredentials(env) {
				if configExists {
					slog.Warn("Skipping Vertex AI provider due to missing credentials")
					c.Providers.Del(string(p.ID))
				}
				continue
			}
			prepared.ExtraParams["project"] = env.Getenv("VERTEXAI_PROJECT")
			prepared.ExtraParams["location"] = env.Getenv("VERTEXAI_LOCATION")
		case catwalk.InferenceProviderAzure:
			// Unlike the default case, the resolved endpoint is what gets
			// stored as BaseURL.
			endpoint, err := resolver.ResolveValue(p.APIEndpoint)
			if err != nil || endpoint == "" {
				if configExists {
					slog.Warn("Skipping Azure provider due to missing API endpoint", "provider", p.ID, "error", err)
					c.Providers.Del(string(p.ID))
				}
				continue
			}
			prepared.BaseURL = endpoint
			prepared.ExtraParams["apiVersion"] = env.Getenv("AZURE_OPENAI_API_VERSION")
		case catwalk.InferenceProviderBedrock:
			if !hasAWSCredentials(env) {
				if configExists {
					slog.Warn("Skipping Bedrock provider due to missing AWS credentials")
					c.Providers.Del(string(p.ID))
				}
				continue
			}
			// AWS_REGION takes precedence over AWS_DEFAULT_REGION.
			prepared.ExtraParams["region"] = env.Getenv("AWS_REGION")
			if prepared.ExtraParams["region"] == "" {
				prepared.ExtraParams["region"] = env.Getenv("AWS_DEFAULT_REGION")
			}
			// A non-Anthropic Bedrock model fails the whole configuration,
			// not just this provider.
			for _, model := range p.Models {
				if !strings.HasPrefix(model.ID, "anthropic.") {
					return fmt.Errorf("bedrock provider only supports anthropic models for now, found: %s", model.ID)
				}
			}
		default:
			// Skip providers whose API key is empty or fails to resolve. The
			// resolved value is only used for this check; prepared.APIKey
			// keeps the unresolved value.
			v, err := resolver.ResolveValue(p.APIKey)
			if v == "" || err != nil {
				if configExists {
					slog.Warn("Skipping provider due to missing API key", "provider", p.ID)
					c.Providers.Del(string(p.ID))
				}
				continue
			}
		}
		c.Providers.Set(string(p.ID), prepared)
	}

	// validate the custom providers
	// NOTE(review): entries are deleted and re-set while iterating
	// c.Providers.Seq2(); this relies on csync.Map tolerating mutation during
	// iteration — confirm.
	for id, providerConfig := range c.Providers.Seq2() {
		if knownProviderNames[id] {
			continue
		}

		// Make sure the provider ID is set
		providerConfig.ID = id
		if providerConfig.Name == "" {
			providerConfig.Name = id // Use ID as name if not set
		}
		// default to OpenAI if not set
		if providerConfig.Type == "" {
			providerConfig.Type = catwalk.TypeOpenAI
		}

		if providerConfig.Disable {
			slog.Debug("Skipping custom provider due to disable flag", "provider", id)
			c.Providers.Del(id)
			continue
		}
		// NOTE(review): when APIKey is empty the same warning is likely logged
		// again after ResolveValue below, i.e. twice per provider.
		if providerConfig.APIKey == "" {
			slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id)
		}
		if providerConfig.BaseURL == "" {
			slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id)
			c.Providers.Del(id)
			continue
		}
		if len(providerConfig.Models) == 0 {
			slog.Warn("Skipping custom provider because the provider has no models", "provider", id)
			c.Providers.Del(id)
			continue
		}
		if providerConfig.Type != catwalk.TypeOpenAI && providerConfig.Type != catwalk.TypeAnthropic && providerConfig.Type != catwalk.TypeGemini {
			slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
			c.Providers.Del(id)
			continue
		}

		// The resolved key and URL are only validated here; the stored config
		// keeps the unresolved values.
		apiKey, err := resolver.ResolveValue(providerConfig.APIKey)
		if apiKey == "" || err != nil {
			slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id)
		}
		baseURL, err := resolver.ResolveValue(providerConfig.BaseURL)
		if baseURL == "" || err != nil {
			slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err)
			c.Providers.Del(id)
			continue
		}

		c.Providers.Set(id, providerConfig)
	}
	return nil
}
288
289func (c *Config) setDefaults(workingDir, dataDir string) {
290 c.workingDir = workingDir
291 if c.Options == nil {
292 c.Options = &Options{}
293 }
294 if c.Options.TUI == nil {
295 c.Options.TUI = &TUIOptions{}
296 }
297 if c.Options.ContextPaths == nil {
298 c.Options.ContextPaths = []string{}
299 }
300 if dataDir != "" {
301 c.Options.DataDirectory = dataDir
302 } else if c.Options.DataDirectory == "" {
303 if path, ok := fsext.LookupClosest(workingDir, defaultDataDirectory); ok {
304 c.Options.DataDirectory = path
305 } else {
306 c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
307 }
308 }
309 if c.Providers == nil {
310 c.Providers = csync.NewMap[string, ProviderConfig]()
311 }
312 if c.Models == nil {
313 c.Models = make(map[SelectedModelType]SelectedModel)
314 }
315 if c.MCP == nil {
316 c.MCP = make(map[string]MCPConfig)
317 }
318 if c.LSP == nil {
319 c.LSP = make(map[string]LSPConfig)
320 }
321
322 // Apply defaults to LSP configurations
323 c.applyLSPDefaults()
324
325 // Add the default context paths if they are not already present
326 c.Options.ContextPaths = append(defaultContextPaths, c.Options.ContextPaths...)
327 slices.Sort(c.Options.ContextPaths)
328 c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths)
329
330 if str, ok := os.LookupEnv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE"); ok {
331 c.Options.DisableProviderAutoUpdate, _ = strconv.ParseBool(str)
332 }
333}
334
335// applyLSPDefaults applies default values from powernap to LSP configurations
336func (c *Config) applyLSPDefaults() {
337 // Get powernap's default configuration
338 configManager := powernapConfig.NewManager()
339 configManager.LoadDefaults()
340
341 // Apply defaults to each LSP configuration
342 for name, cfg := range c.LSP {
343 // Try to get defaults from powernap based on name or command name.
344 base, ok := configManager.GetServer(name)
345 if !ok {
346 base, ok = configManager.GetServer(cfg.Command)
347 if !ok {
348 continue
349 }
350 }
351 if cfg.Options == nil {
352 cfg.Options = base.Settings
353 }
354 if cfg.InitOptions == nil {
355 cfg.InitOptions = base.InitOptions
356 }
357 if len(cfg.FileTypes) == 0 {
358 cfg.FileTypes = base.FileTypes
359 }
360 if len(cfg.RootMarkers) == 0 {
361 cfg.RootMarkers = base.RootMarkers
362 }
363 if cfg.Command == "" {
364 cfg.Command = base.Command
365 }
366 if len(cfg.Args) == 0 {
367 cfg.Args = base.Args
368 }
369 if len(cfg.Env) == 0 {
370 cfg.Env = base.Environment
371 }
372 // Update the config in the map
373 c.LSP[name] = cfg
374 }
375}
376
377func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
378 if len(knownProviders) == 0 && c.Providers.Len() == 0 {
379 err = fmt.Errorf("no providers configured, please configure at least one provider")
380 return largeModel, smallModel, err
381 }
382
383 // Use the first provider enabled based on the known providers order
384 // if no provider found that is known use the first provider configured
385 for _, p := range knownProviders {
386 providerConfig, ok := c.Providers.Get(string(p.ID))
387 if !ok || providerConfig.Disable {
388 continue
389 }
390 defaultLargeModel := c.GetModel(string(p.ID), p.DefaultLargeModelID)
391 if defaultLargeModel == nil {
392 err = fmt.Errorf("default large model %s not found for provider %s", p.DefaultLargeModelID, p.ID)
393 return largeModel, smallModel, err
394 }
395 largeModel = SelectedModel{
396 Provider: string(p.ID),
397 Model: defaultLargeModel.ID,
398 MaxTokens: defaultLargeModel.DefaultMaxTokens,
399 ReasoningEffort: defaultLargeModel.DefaultReasoningEffort,
400 }
401
402 defaultSmallModel := c.GetModel(string(p.ID), p.DefaultSmallModelID)
403 if defaultSmallModel == nil {
404 err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID)
405 return largeModel, smallModel, err
406 }
407 smallModel = SelectedModel{
408 Provider: string(p.ID),
409 Model: defaultSmallModel.ID,
410 MaxTokens: defaultSmallModel.DefaultMaxTokens,
411 ReasoningEffort: defaultSmallModel.DefaultReasoningEffort,
412 }
413 return largeModel, smallModel, err
414 }
415
416 enabledProviders := c.EnabledProviders()
417 slices.SortFunc(enabledProviders, func(a, b ProviderConfig) int {
418 return strings.Compare(a.ID, b.ID)
419 })
420
421 if len(enabledProviders) == 0 {
422 err = fmt.Errorf("no providers configured, please configure at least one provider")
423 return largeModel, smallModel, err
424 }
425
426 providerConfig := enabledProviders[0]
427 if len(providerConfig.Models) == 0 {
428 err = fmt.Errorf("provider %s has no models configured", providerConfig.ID)
429 return largeModel, smallModel, err
430 }
431 defaultLargeModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
432 largeModel = SelectedModel{
433 Provider: providerConfig.ID,
434 Model: defaultLargeModel.ID,
435 MaxTokens: defaultLargeModel.DefaultMaxTokens,
436 }
437 defaultSmallModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
438 smallModel = SelectedModel{
439 Provider: providerConfig.ID,
440 Model: defaultSmallModel.ID,
441 MaxTokens: defaultSmallModel.DefaultMaxTokens,
442 }
443 return largeModel, smallModel, err
444}
445
446func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error {
447 defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders)
448 if err != nil {
449 return fmt.Errorf("failed to select default models: %w", err)
450 }
451 large, small := defaultLarge, defaultSmall
452
453 largeModelSelected, largeModelConfigured := c.Models[SelectedModelTypeLarge]
454 if largeModelConfigured {
455 if largeModelSelected.Model != "" {
456 large.Model = largeModelSelected.Model
457 }
458 if largeModelSelected.Provider != "" {
459 large.Provider = largeModelSelected.Provider
460 }
461 model := c.GetModel(large.Provider, large.Model)
462 if model == nil {
463 large = defaultLarge
464 // override the model type to large
465 err := c.UpdatePreferredModel(SelectedModelTypeLarge, large)
466 if err != nil {
467 return fmt.Errorf("failed to update preferred large model: %w", err)
468 }
469 } else {
470 if largeModelSelected.MaxTokens > 0 {
471 large.MaxTokens = largeModelSelected.MaxTokens
472 } else {
473 large.MaxTokens = model.DefaultMaxTokens
474 }
475 if largeModelSelected.ReasoningEffort != "" {
476 large.ReasoningEffort = largeModelSelected.ReasoningEffort
477 }
478 large.Think = largeModelSelected.Think
479 }
480 }
481 smallModelSelected, smallModelConfigured := c.Models[SelectedModelTypeSmall]
482 if smallModelConfigured {
483 if smallModelSelected.Model != "" {
484 small.Model = smallModelSelected.Model
485 }
486 if smallModelSelected.Provider != "" {
487 small.Provider = smallModelSelected.Provider
488 }
489
490 model := c.GetModel(small.Provider, small.Model)
491 if model == nil {
492 small = defaultSmall
493 // override the model type to small
494 err := c.UpdatePreferredModel(SelectedModelTypeSmall, small)
495 if err != nil {
496 return fmt.Errorf("failed to update preferred small model: %w", err)
497 }
498 } else {
499 if smallModelSelected.MaxTokens > 0 {
500 small.MaxTokens = smallModelSelected.MaxTokens
501 } else {
502 small.MaxTokens = model.DefaultMaxTokens
503 }
504 small.ReasoningEffort = smallModelSelected.ReasoningEffort
505 small.Think = smallModelSelected.Think
506 }
507 }
508 c.Models[SelectedModelTypeLarge] = large
509 c.Models[SelectedModelTypeSmall] = small
510 return nil
511}
512
513// lookupConfigs searches config files recursively from CWD up to FS root
514func lookupConfigs(cwd string) []string {
515 // prepend default config paths
516 configPaths := []string{
517 globalConfig(),
518 GlobalConfigData(),
519 }
520
521 if cwd == "" {
522 return configPaths
523 }
524
525 configNames := []string{version.AppName + ".json", "." + version.AppName + ".json"}
526
527 foundConfigs, err := fsext.Lookup(cwd, configNames...)
528 if err != nil {
529 // returns at least default configs
530 return configPaths
531 }
532
533 // reverse order so last config has more priority
534 slices.Reverse(foundConfigs)
535
536 return append(configPaths, foundConfigs...)
537}
538
539func loadFromConfigPaths(configPaths []string) (*Config, error) {
540 var configs []io.Reader
541
542 for _, path := range configPaths {
543 fd, err := os.Open(path)
544 if err != nil {
545 if os.IsNotExist(err) {
546 continue
547 }
548 return nil, fmt.Errorf("failed to open config file %s: %w", path, err)
549 }
550 defer fd.Close()
551
552 configs = append(configs, fd)
553 }
554
555 return loadFromReaders(configs)
556}
557
558func loadFromReaders(readers []io.Reader) (*Config, error) {
559 if len(readers) == 0 {
560 return &Config{}, nil
561 }
562
563 merged, err := Merge(readers)
564 if err != nil {
565 return nil, fmt.Errorf("failed to merge configuration readers: %w", err)
566 }
567
568 return LoadReader(merged)
569}
570
571func hasVertexCredentials(env environ) bool {
572 hasProject := env.Getenv("VERTEXAI_PROJECT") != ""
573 hasLocation := env.Getenv("VERTEXAI_LOCATION") != ""
574 return hasProject && hasLocation
575}
576
577func hasAWSCredentials(env environ) bool {
578 if env.Getenv("AWS_ACCESS_KEY_ID") != "" && env.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
579 return true
580 }
581
582 if env.Getenv("AWS_PROFILE") != "" || env.Getenv("AWS_DEFAULT_PROFILE") != "" {
583 return true
584 }
585
586 if env.Getenv("AWS_REGION") != "" || env.Getenv("AWS_DEFAULT_REGION") != "" {
587 return true
588 }
589
590 if env.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
591 env.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
592 return true
593 }
594 return false
595}
596
597func globalConfig() string {
598 xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
599 if xdgConfigHome != "" {
600 return filepath.Join(xdgConfigHome, version.AppName, fmt.Sprintf("%s.json", version.AppName))
601 }
602
603 // return the path to the main config directory
604 // for windows, it should be in `%LOCALAPPDATA%/crush/`
605 // for linux and macOS, it should be in `$HOME/.config/crush/`
606 if runtime.GOOS == "windows" {
607 localAppData := os.Getenv("LOCALAPPDATA")
608 if localAppData == "" {
609 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
610 }
611 return filepath.Join(localAppData, version.AppName, fmt.Sprintf("%s.json", version.AppName))
612 }
613
614 return filepath.Join(home.Dir(), ".config", version.AppName, fmt.Sprintf("%s.json", version.AppName))
615}
616
617// GlobalConfigData returns the path to the main data directory for the application.
618// this config is used when the app overrides configurations instead of updating the global config.
619func GlobalConfigData() string {
620 xdgDataHome := os.Getenv("XDG_DATA_HOME")
621 if xdgDataHome != "" {
622 return filepath.Join(xdgDataHome, version.AppName, fmt.Sprintf("%s.json", version.AppName))
623 }
624
625 // return the path to the main data directory
626 // for windows, it should be in `%LOCALAPPDATA%/crush/`
627 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
628 if runtime.GOOS == "windows" {
629 localAppData := os.Getenv("LOCALAPPDATA")
630 if localAppData == "" {
631 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
632 }
633 return filepath.Join(localAppData, version.AppName, fmt.Sprintf("%s.json", version.AppName))
634 }
635
636 return filepath.Join(home.Dir(), ".local", "share", version.AppName, fmt.Sprintf("%s.json", version.AppName))
637}
638
639// GlobalCacheDir returns the path to the main cache directory for the application.
640func GlobalCacheDir() string {
641 xdgCacheHome := os.Getenv("XDG_CACHE_HOME")
642 if xdgCacheHome != "" {
643 return filepath.Join(xdgCacheHome, version.AppName)
644 }
645
646 // return the path to the main cache directory
647 // for windows, it should be in `%LOCALAPPDATA%/crush/Cache`
648 // for linux and macOS, it should be in `$HOME/.cache/crush/`
649 if runtime.GOOS == "windows" {
650 localAppData := os.Getenv("LOCALAPPDATA")
651 if localAppData == "" {
652 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
653 }
654 return filepath.Join(localAppData, version.AppName, "Cache")
655 }
656
657 return filepath.Join(home.Dir(), ".cache", version.AppName)
658}