1package config
2
3import (
4 "encoding/json"
5 "fmt"
6 "io"
7 "log/slog"
8 "maps"
9 "os"
10 "path/filepath"
11 "runtime"
12 "slices"
13 "strconv"
14 "strings"
15
16 "github.com/charmbracelet/catwalk/pkg/catwalk"
17 "github.com/charmbracelet/crush/internal/csync"
18 "github.com/charmbracelet/crush/internal/fsext"
19 "github.com/charmbracelet/crush/internal/home"
20 "github.com/charmbracelet/crush/internal/version"
21 uv "github.com/charmbracelet/ultraviolet"
22 powernapConfig "github.com/charmbracelet/x/powernap/pkg/config"
23)
24
// environ is a package-local alias for uv.Environ. The provider-configuration
// code in this file reads environment variables from it via Getenv.
type environ = uv.Environ

// defaultCatwalkURL is the default base URL of the catwalk service.
// NOTE(review): presumably the endpoint used when fetching the known-provider
// list if no other URL is configured — confirm at the call site (not visible here).
const defaultCatwalkURL = "https://catwalk.charm.sh"
28
29// LoadReader config via io.Reader.
30func LoadReader(fd io.Reader) (*Config, error) {
31 data, err := io.ReadAll(fd)
32 if err != nil {
33 return nil, err
34 }
35
36 var config Config
37 err = json.Unmarshal(data, &config)
38 if err != nil {
39 return nil, err
40 }
41 return &config, err
42}
43
44// Load loads the configuration from the default paths.
45func Load(workingDir, dataDir string, debug bool) (*Config, error) {
46 configPaths := lookupConfigs(workingDir)
47
48 cfg, err := loadFromConfigPaths(configPaths)
49 if err != nil {
50 return nil, fmt.Errorf("failed to load config from paths %v: %w", configPaths, err)
51 }
52
53 cfg.dataConfigDir = GlobalConfigData()
54
55 cfg.setDefaults(workingDir, dataDir)
56
57 if debug {
58 cfg.Options.Debug = true
59 }
60
61 // Load known providers, this loads the config from catwalk
62 providers, err := Providers(cfg)
63 if err != nil {
64 return nil, err
65 }
66 cfg.knownProviders = providers
67
68 envs := os.Environ()
69 // Configure providers
70 valueResolver := NewShellVariableResolver(envs)
71 cfg.resolver = valueResolver
72 if err := cfg.configureProviders(envs, valueResolver, cfg.knownProviders); err != nil {
73 return nil, fmt.Errorf("failed to configure providers: %w", err)
74 }
75
76 if !cfg.IsConfigured() {
77 slog.Warn("No providers configured")
78 return cfg, nil
79 }
80
81 if err := cfg.configureSelectedModels(cfg.knownProviders); err != nil {
82 return nil, fmt.Errorf("failed to configure selected models: %w", err)
83 }
84 cfg.SetupAgents()
85 return cfg, nil
86}
87
// PushPopCrushEnv promotes every CRUSH_-prefixed environment variable to its
// unprefixed name: CRUSH_FOO=bar sets FOO=bar in the process environment.
//
// It returns a restore function that puts each overridden variable back exactly
// as it was before the call. Variables that previously held a value get that
// value back. Variables that did not exist are unset again, rather than being
// left behind as empty strings. Both the push and the restore mutate the
// process-wide environment via os.Setenv / os.Unsetenv. Failures are logged,
// not returned.
func PushPopCrushEnv() func() {
	const prefix = "CRUSH_"

	// previous records the state of a variable before it was overridden.
	type previous struct {
		value  string
		exists bool
	}

	// Collect all overrides from a single environment snapshot.
	overrides := make(map[string]string)
	for _, entry := range os.Environ() {
		key, value, ok := strings.Cut(entry, "=")
		if !ok {
			continue
		}
		name, ok := strings.CutPrefix(key, prefix)
		if !ok || name == "" {
			// A bare "CRUSH_" has no target name; os.Setenv("") would fail.
			continue
		}
		overrides[name] = value
	}

	// Remember whether each target existed, not just its value, so that
	// restore can distinguish "was empty" from "was unset".
	backups := make(map[string]previous, len(overrides))
	for name := range overrides {
		value, exists := os.LookupEnv(name)
		backups[name] = previous{value: value, exists: exists}
	}

	for name, value := range overrides {
		if err := os.Setenv(name, value); err != nil {
			slog.Warn("Failed to apply CRUSH_ environment override", "name", name, "error", err)
		}
	}

	return func() {
		for name, prev := range backups {
			var err error
			if prev.exists {
				err = os.Setenv(name, prev.value)
			} else {
				err = os.Unsetenv(name)
			}
			if err != nil {
				slog.Warn("Failed to restore environment variable", "name", name, "error", err)
			}
		}
	}
}
115
116func (c *Config) configureProviders(env environ, resolver VariableResolver, knownProviders []catwalk.Provider) error {
117 knownProviderNames := make(map[string]bool)
118 restore := PushPopCrushEnv()
119 defer restore()
120 for _, p := range knownProviders {
121 knownProviderNames[string(p.ID)] = true
122 config, configExists := c.Providers.Get(string(p.ID))
123 // if the user configured a known provider we need to allow it to override a couple of parameters
124 if configExists {
125 if config.BaseURL != "" {
126 p.APIEndpoint = config.BaseURL
127 }
128 if config.APIKey != "" {
129 p.APIKey = config.APIKey
130 }
131 if len(config.Models) > 0 {
132 models := []catwalk.Model{}
133 seen := make(map[string]bool)
134
135 for _, model := range config.Models {
136 if seen[model.ID] {
137 continue
138 }
139 seen[model.ID] = true
140 if model.Name == "" {
141 model.Name = model.ID
142 }
143 models = append(models, model)
144 }
145 for _, model := range p.Models {
146 if seen[model.ID] {
147 continue
148 }
149 seen[model.ID] = true
150 if model.Name == "" {
151 model.Name = model.ID
152 }
153 models = append(models, model)
154 }
155
156 p.Models = models
157 }
158 }
159
160 headers := map[string]string{}
161 if len(p.DefaultHeaders) > 0 {
162 maps.Copy(headers, p.DefaultHeaders)
163 }
164 if len(config.ExtraHeaders) > 0 {
165 maps.Copy(headers, config.ExtraHeaders)
166 }
167 prepared := ProviderConfig{
168 ID: string(p.ID),
169 Name: p.Name,
170 BaseURL: p.APIEndpoint,
171 APIKey: p.APIKey,
172 Type: p.Type,
173 Disable: config.Disable,
174 SystemPromptPrefix: config.SystemPromptPrefix,
175 ExtraHeaders: headers,
176 ExtraBody: config.ExtraBody,
177 ExtraParams: make(map[string]string),
178 Models: p.Models,
179 }
180
181 switch p.ID {
182 // Handle specific providers that require additional configuration
183 case catwalk.InferenceProviderVertexAI:
184 if !hasVertexCredentials(env) {
185 if configExists {
186 slog.Warn("Skipping Vertex AI provider due to missing credentials")
187 c.Providers.Del(string(p.ID))
188 }
189 continue
190 }
191 prepared.ExtraParams["project"] = env.Getenv("VERTEXAI_PROJECT")
192 prepared.ExtraParams["location"] = env.Getenv("VERTEXAI_LOCATION")
193 case catwalk.InferenceProviderAzure:
194 endpoint, err := resolver.ResolveValue(p.APIEndpoint)
195 if err != nil || endpoint == "" {
196 if configExists {
197 slog.Warn("Skipping Azure provider due to missing API endpoint", "provider", p.ID, "error", err)
198 c.Providers.Del(string(p.ID))
199 }
200 continue
201 }
202 prepared.BaseURL = endpoint
203 prepared.ExtraParams["apiVersion"] = env.Getenv("AZURE_OPENAI_API_VERSION")
204 case catwalk.InferenceProviderBedrock:
205 if !hasAWSCredentials(env) {
206 if configExists {
207 slog.Warn("Skipping Bedrock provider due to missing AWS credentials")
208 c.Providers.Del(string(p.ID))
209 }
210 continue
211 }
212 prepared.ExtraParams["region"] = env.Getenv("AWS_REGION")
213 if prepared.ExtraParams["region"] == "" {
214 prepared.ExtraParams["region"] = env.Getenv("AWS_DEFAULT_REGION")
215 }
216 for _, model := range p.Models {
217 if !strings.HasPrefix(model.ID, "anthropic.") {
218 return fmt.Errorf("bedrock provider only supports anthropic models for now, found: %s", model.ID)
219 }
220 }
221 default:
222 // if the provider api or endpoint are missing we skip them
223 v, err := resolver.ResolveValue(p.APIKey)
224 if v == "" || err != nil {
225 if configExists {
226 slog.Warn("Skipping provider due to missing API key", "provider", p.ID)
227 c.Providers.Del(string(p.ID))
228 }
229 continue
230 }
231 }
232 c.Providers.Set(string(p.ID), prepared)
233 }
234
235 // validate the custom providers
236 for id, providerConfig := range c.Providers.Seq2() {
237 if knownProviderNames[id] {
238 continue
239 }
240
241 // Make sure the provider ID is set
242 providerConfig.ID = id
243 if providerConfig.Name == "" {
244 providerConfig.Name = id // Use ID as name if not set
245 }
246 // default to OpenAI if not set
247 if providerConfig.Type == "" {
248 providerConfig.Type = catwalk.TypeOpenAI
249 }
250
251 if providerConfig.Disable {
252 slog.Debug("Skipping custom provider due to disable flag", "provider", id)
253 c.Providers.Del(id)
254 continue
255 }
256 if providerConfig.APIKey == "" {
257 slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id)
258 }
259 if providerConfig.BaseURL == "" {
260 slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id)
261 c.Providers.Del(id)
262 continue
263 }
264 if len(providerConfig.Models) == 0 {
265 slog.Warn("Skipping custom provider because the provider has no models", "provider", id)
266 c.Providers.Del(id)
267 continue
268 }
269 if providerConfig.Type != catwalk.TypeOpenAI && providerConfig.Type != catwalk.TypeAnthropic && providerConfig.Type != catwalk.TypeGemini {
270 slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
271 c.Providers.Del(id)
272 continue
273 }
274
275 apiKey, err := resolver.ResolveValue(providerConfig.APIKey)
276 if apiKey == "" || err != nil {
277 slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id)
278 }
279 baseURL, err := resolver.ResolveValue(providerConfig.BaseURL)
280 if baseURL == "" || err != nil {
281 slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err)
282 c.Providers.Del(id)
283 continue
284 }
285
286 c.Providers.Set(id, providerConfig)
287 }
288 return nil
289}
290
291func (c *Config) setDefaults(workingDir, dataDir string) {
292 c.workingDir = workingDir
293 if c.Options == nil {
294 c.Options = &Options{}
295 }
296 if c.Options.TUI == nil {
297 c.Options.TUI = &TUIOptions{}
298 }
299 if c.Options.ContextPaths == nil {
300 c.Options.ContextPaths = []string{}
301 }
302 if dataDir != "" {
303 c.Options.DataDirectory = dataDir
304 } else if c.Options.DataDirectory == "" {
305 if path, ok := fsext.LookupClosest(workingDir, defaultDataDirectory); ok {
306 c.Options.DataDirectory = path
307 } else {
308 c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
309 }
310 }
311 if c.Providers == nil {
312 c.Providers = csync.NewMap[string, ProviderConfig]()
313 }
314 if c.Models == nil {
315 c.Models = make(map[SelectedModelType]SelectedModel)
316 }
317 if c.MCP == nil {
318 c.MCP = make(map[string]MCPConfig)
319 }
320 if c.LSP == nil {
321 c.LSP = make(map[string]LSPConfig)
322 }
323
324 // Apply defaults to LSP configurations
325 c.applyLSPDefaults()
326
327 // Add the default context paths if they are not already present
328 c.Options.ContextPaths = append(defaultContextPaths, c.Options.ContextPaths...)
329 slices.Sort(c.Options.ContextPaths)
330 c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths)
331
332 if str, ok := os.LookupEnv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE"); ok {
333 c.Options.DisableProviderAutoUpdate, _ = strconv.ParseBool(str)
334 }
335}
336
337// applyLSPDefaults applies default values from powernap to LSP configurations
338func (c *Config) applyLSPDefaults() {
339 // Get powernap's default configuration
340 configManager := powernapConfig.NewManager()
341 configManager.LoadDefaults()
342
343 // Apply defaults to each LSP configuration
344 for name, cfg := range c.LSP {
345 // Try to get defaults from powernap based on name or command name.
346 base, ok := configManager.GetServer(name)
347 if !ok {
348 base, ok = configManager.GetServer(cfg.Command)
349 if !ok {
350 continue
351 }
352 }
353 if cfg.Options == nil {
354 cfg.Options = base.Settings
355 }
356 if cfg.InitOptions == nil {
357 cfg.InitOptions = base.InitOptions
358 }
359 if len(cfg.FileTypes) == 0 {
360 cfg.FileTypes = base.FileTypes
361 }
362 if len(cfg.RootMarkers) == 0 {
363 cfg.RootMarkers = base.RootMarkers
364 }
365 if cfg.Command == "" {
366 cfg.Command = base.Command
367 }
368 if len(cfg.Args) == 0 {
369 cfg.Args = base.Args
370 }
371 if len(cfg.Env) == 0 {
372 cfg.Env = base.Environment
373 }
374 // Update the config in the map
375 c.LSP[name] = cfg
376 }
377}
378
379func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
380 if len(knownProviders) == 0 && c.Providers.Len() == 0 {
381 err = fmt.Errorf("no providers configured, please configure at least one provider")
382 return largeModel, smallModel, err
383 }
384
385 // Use the first provider enabled based on the known providers order
386 // if no provider found that is known use the first provider configured
387 for _, p := range knownProviders {
388 providerConfig, ok := c.Providers.Get(string(p.ID))
389 if !ok || providerConfig.Disable {
390 continue
391 }
392 defaultLargeModel := c.GetModel(string(p.ID), p.DefaultLargeModelID)
393 if defaultLargeModel == nil {
394 err = fmt.Errorf("default large model %s not found for provider %s", p.DefaultLargeModelID, p.ID)
395 return largeModel, smallModel, err
396 }
397 largeModel = SelectedModel{
398 Provider: string(p.ID),
399 Model: defaultLargeModel.ID,
400 MaxTokens: defaultLargeModel.DefaultMaxTokens,
401 ReasoningEffort: defaultLargeModel.DefaultReasoningEffort,
402 }
403
404 defaultSmallModel := c.GetModel(string(p.ID), p.DefaultSmallModelID)
405 if defaultSmallModel == nil {
406 err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID)
407 return largeModel, smallModel, err
408 }
409 smallModel = SelectedModel{
410 Provider: string(p.ID),
411 Model: defaultSmallModel.ID,
412 MaxTokens: defaultSmallModel.DefaultMaxTokens,
413 ReasoningEffort: defaultSmallModel.DefaultReasoningEffort,
414 }
415 return largeModel, smallModel, err
416 }
417
418 enabledProviders := c.EnabledProviders()
419 slices.SortFunc(enabledProviders, func(a, b ProviderConfig) int {
420 return strings.Compare(a.ID, b.ID)
421 })
422
423 if len(enabledProviders) == 0 {
424 err = fmt.Errorf("no providers configured, please configure at least one provider")
425 return largeModel, smallModel, err
426 }
427
428 providerConfig := enabledProviders[0]
429 if len(providerConfig.Models) == 0 {
430 err = fmt.Errorf("provider %s has no models configured", providerConfig.ID)
431 return largeModel, smallModel, err
432 }
433 defaultLargeModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
434 largeModel = SelectedModel{
435 Provider: providerConfig.ID,
436 Model: defaultLargeModel.ID,
437 MaxTokens: defaultLargeModel.DefaultMaxTokens,
438 }
439 defaultSmallModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
440 smallModel = SelectedModel{
441 Provider: providerConfig.ID,
442 Model: defaultSmallModel.ID,
443 MaxTokens: defaultSmallModel.DefaultMaxTokens,
444 }
445 return largeModel, smallModel, err
446}
447
448func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error {
449 defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders)
450 if err != nil {
451 return fmt.Errorf("failed to select default models: %w", err)
452 }
453 large, small := defaultLarge, defaultSmall
454
455 largeModelSelected, largeModelConfigured := c.Models[SelectedModelTypeLarge]
456 if largeModelConfigured {
457 if largeModelSelected.Model != "" {
458 large.Model = largeModelSelected.Model
459 }
460 if largeModelSelected.Provider != "" {
461 large.Provider = largeModelSelected.Provider
462 }
463 model := c.GetModel(large.Provider, large.Model)
464 if model == nil {
465 large = defaultLarge
466 // override the model type to large
467 err := c.UpdatePreferredModel(SelectedModelTypeLarge, large)
468 if err != nil {
469 return fmt.Errorf("failed to update preferred large model: %w", err)
470 }
471 } else {
472 if largeModelSelected.MaxTokens > 0 {
473 large.MaxTokens = largeModelSelected.MaxTokens
474 } else {
475 large.MaxTokens = model.DefaultMaxTokens
476 }
477 if largeModelSelected.ReasoningEffort != "" {
478 large.ReasoningEffort = largeModelSelected.ReasoningEffort
479 }
480 large.Think = largeModelSelected.Think
481 }
482 }
483 smallModelSelected, smallModelConfigured := c.Models[SelectedModelTypeSmall]
484 if smallModelConfigured {
485 if smallModelSelected.Model != "" {
486 small.Model = smallModelSelected.Model
487 }
488 if smallModelSelected.Provider != "" {
489 small.Provider = smallModelSelected.Provider
490 }
491
492 model := c.GetModel(small.Provider, small.Model)
493 if model == nil {
494 small = defaultSmall
495 // override the model type to small
496 err := c.UpdatePreferredModel(SelectedModelTypeSmall, small)
497 if err != nil {
498 return fmt.Errorf("failed to update preferred small model: %w", err)
499 }
500 } else {
501 if smallModelSelected.MaxTokens > 0 {
502 small.MaxTokens = smallModelSelected.MaxTokens
503 } else {
504 small.MaxTokens = model.DefaultMaxTokens
505 }
506 small.ReasoningEffort = smallModelSelected.ReasoningEffort
507 small.Think = smallModelSelected.Think
508 }
509 }
510 c.Models[SelectedModelTypeLarge] = large
511 c.Models[SelectedModelTypeSmall] = small
512 return nil
513}
514
515// lookupConfigs searches config files recursively from CWD up to FS root
516func lookupConfigs(cwd string) []string {
517 // prepend default config paths
518 configPaths := []string{
519 globalConfig(),
520 GlobalConfigData(),
521 }
522
523 if cwd == "" {
524 return configPaths
525 }
526
527 configNames := []string{version.AppName + ".json", "." + version.AppName + ".json"}
528
529 foundConfigs, err := fsext.Lookup(cwd, configNames...)
530 if err != nil {
531 // returns at least default configs
532 return configPaths
533 }
534
535 // reverse order so last config has more priority
536 slices.Reverse(foundConfigs)
537
538 return append(configPaths, foundConfigs...)
539}
540
541func loadFromConfigPaths(configPaths []string) (*Config, error) {
542 var configs []io.Reader
543
544 for _, path := range configPaths {
545 fd, err := os.Open(path)
546 if err != nil {
547 if os.IsNotExist(err) {
548 continue
549 }
550 return nil, fmt.Errorf("failed to open config file %s: %w", path, err)
551 }
552 defer fd.Close()
553
554 configs = append(configs, fd)
555 }
556
557 return loadFromReaders(configs)
558}
559
560func loadFromReaders(readers []io.Reader) (*Config, error) {
561 if len(readers) == 0 {
562 return &Config{}, nil
563 }
564
565 merged, err := Merge(readers)
566 if err != nil {
567 return nil, fmt.Errorf("failed to merge configuration readers: %w", err)
568 }
569
570 return LoadReader(merged)
571}
572
573func hasVertexCredentials(env environ) bool {
574 hasProject := env.Getenv("VERTEXAI_PROJECT") != ""
575 hasLocation := env.Getenv("VERTEXAI_LOCATION") != ""
576 return hasProject && hasLocation
577}
578
579func hasAWSCredentials(env environ) bool {
580 if env.Getenv("AWS_ACCESS_KEY_ID") != "" && env.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
581 return true
582 }
583
584 if env.Getenv("AWS_PROFILE") != "" || env.Getenv("AWS_DEFAULT_PROFILE") != "" {
585 return true
586 }
587
588 if env.Getenv("AWS_REGION") != "" || env.Getenv("AWS_DEFAULT_REGION") != "" {
589 return true
590 }
591
592 if env.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
593 env.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
594 return true
595 }
596 return false
597}
598
599func globalConfig() string {
600 xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
601 if xdgConfigHome != "" {
602 return filepath.Join(xdgConfigHome, version.AppName, fmt.Sprintf("%s.json", version.AppName))
603 }
604
605 // return the path to the main config directory
606 // for windows, it should be in `%LOCALAPPDATA%/crush/`
607 // for linux and macOS, it should be in `$HOME/.config/crush/`
608 if runtime.GOOS == "windows" {
609 localAppData := os.Getenv("LOCALAPPDATA")
610 if localAppData == "" {
611 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
612 }
613 return filepath.Join(localAppData, version.AppName, fmt.Sprintf("%s.json", version.AppName))
614 }
615
616 return filepath.Join(home.Dir(), ".config", version.AppName, fmt.Sprintf("%s.json", version.AppName))
617}
618
619// GlobalConfigData returns the path to the main data directory for the application.
620// this config is used when the app overrides configurations instead of updating the global config.
621func GlobalConfigData() string {
622 xdgDataHome := os.Getenv("XDG_DATA_HOME")
623 if xdgDataHome != "" {
624 return filepath.Join(xdgDataHome, version.AppName, fmt.Sprintf("%s.json", version.AppName))
625 }
626
627 // return the path to the main data directory
628 // for windows, it should be in `%LOCALAPPDATA%/crush/`
629 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
630 if runtime.GOOS == "windows" {
631 localAppData := os.Getenv("LOCALAPPDATA")
632 if localAppData == "" {
633 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
634 }
635 return filepath.Join(localAppData, version.AppName, fmt.Sprintf("%s.json", version.AppName))
636 }
637
638 return filepath.Join(home.Dir(), ".local", "share", version.AppName, fmt.Sprintf("%s.json", version.AppName))
639}
640
641// GlobalCacheDir returns the path to the main cache directory for the application.
642func GlobalCacheDir() string {
643 xdgCacheHome := os.Getenv("XDG_CACHE_HOME")
644 if xdgCacheHome != "" {
645 return filepath.Join(xdgCacheHome, version.AppName)
646 }
647
648 // return the path to the main cache directory
649 // for windows, it should be in `%LOCALAPPDATA%/crush/Cache`
650 // for linux and macOS, it should be in `$HOME/.cache/crush/`
651 if runtime.GOOS == "windows" {
652 localAppData := os.Getenv("LOCALAPPDATA")
653 if localAppData == "" {
654 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
655 }
656 return filepath.Join(localAppData, version.AppName, "Cache")
657 }
658
659 return filepath.Join(home.Dir(), ".cache", version.AppName)
660}