1package config
2
3import (
4 "encoding/json"
5 "fmt"
6 "io"
7 "log/slog"
8 "maps"
9 "os"
10 "path/filepath"
11 "runtime"
12 "slices"
13 "strings"
14
15 "github.com/charmbracelet/catwalk/pkg/catwalk"
16 "github.com/charmbracelet/crush/internal/csync"
17 "github.com/charmbracelet/crush/internal/env"
18 "github.com/charmbracelet/crush/internal/fsext"
19 "github.com/charmbracelet/crush/internal/home"
20 "github.com/charmbracelet/crush/internal/log"
21)
22
// defaultCatwalkURL is the default Catwalk service URL.
// NOTE(review): presumably the fallback endpoint used when fetching the known
// provider catalog (see Providers) — confirm against its callers.
const defaultCatwalkURL = "https://catwalk.charm.sh"
24
25// LoadReader config via io.Reader.
26func LoadReader(fd io.Reader) (*Config, error) {
27 data, err := io.ReadAll(fd)
28 if err != nil {
29 return nil, err
30 }
31
32 var config Config
33 err = json.Unmarshal(data, &config)
34 if err != nil {
35 return nil, err
36 }
37 return &config, err
38}
39
40// Load loads the configuration from the default paths.
41func Load(workingDir, dataDir string, debug bool) (*Config, error) {
42 // uses default config paths
43 configPaths := []string{
44 globalConfig(),
45 GlobalConfigData(),
46 filepath.Join(workingDir, fmt.Sprintf("%s.json", appName)),
47 filepath.Join(workingDir, fmt.Sprintf(".%s.json", appName)),
48 }
49 cfg, err := loadFromConfigPaths(configPaths)
50 if err != nil {
51 return nil, fmt.Errorf("failed to load config from paths %v: %w", configPaths, err)
52 }
53
54 cfg.dataConfigDir = GlobalConfigData()
55
56 cfg.setDefaults(workingDir, dataDir)
57
58 if debug {
59 cfg.Options.Debug = true
60 }
61
62 // Setup logs
63 log.Setup(
64 filepath.Join(cfg.Options.DataDirectory, "logs", fmt.Sprintf("%s.log", appName)),
65 cfg.Options.Debug,
66 )
67
68 // Load known providers, this loads the config from catwalk
69 providers, err := Providers()
70 if err != nil || len(providers) == 0 {
71 return nil, fmt.Errorf("failed to load providers: %w", err)
72 }
73 cfg.knownProviders = providers
74
75 env := env.New()
76 // Configure providers
77 valueResolver := NewShellVariableResolver(env)
78 cfg.resolver = valueResolver
79 if err := cfg.configureProviders(env, valueResolver, providers); err != nil {
80 return nil, fmt.Errorf("failed to configure providers: %w", err)
81 }
82
83 if !cfg.IsConfigured() {
84 slog.Warn("No providers configured")
85 return cfg, nil
86 }
87
88 if err := cfg.configureSelectedModels(providers); err != nil {
89 return nil, fmt.Errorf("failed to configure selected models: %w", err)
90 }
91 cfg.SetupAgents()
92 return cfg, nil
93}
94
// PushPopCrushEnv mirrors every CRUSH_-prefixed environment variable onto its
// unprefixed name (CRUSH_FOO=bar sets FOO=bar) and returns a function that
// undoes the change.
//
// The returned restore function puts each affected variable back exactly as
// it was. Variables that previously held a value get that value back.
// Variables that did not exist before are unset again rather than being left
// behind as empty strings.
//
// Override values come from a single snapshot of os.Environ taken before
// anything is modified. The result therefore does not depend on iteration
// order, even when one CRUSH_ name targets another CRUSH_ variable
// (CRUSH_CRUSH_X -> CRUSH_X).
//
// This mutates the process-wide environment.
func PushPopCrushEnv() func() {
	const prefix = "CRUSH_"

	// Snapshot the overrides before touching the environment.
	overrides := make(map[string]string)
	for _, kv := range os.Environ() {
		key, value, ok := strings.Cut(kv, "=")
		if !ok {
			continue
		}
		name, ok := strings.CutPrefix(key, prefix)
		if !ok || name == "" {
			// A bare "CRUSH_" has no target name to mirror onto.
			continue
		}
		overrides[name] = value
	}

	// Record whether each target existed, not just its value, so restore can
	// distinguish "was empty" from "was unset".
	type saved struct {
		value  string
		exists bool
	}
	backups := make(map[string]saved, len(overrides))
	for name := range overrides {
		value, exists := os.LookupEnv(name)
		backups[name] = saved{value: value, exists: exists}
	}

	for name, value := range overrides {
		// Best effort: a name the OS rejects is simply not mirrored.
		_ = os.Setenv(name, value)
	}

	return func() {
		for name, b := range backups {
			if b.exists {
				_ = os.Setenv(name, b.value)
			} else {
				_ = os.Unsetenv(name)
			}
		}
	}
}
122
123func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
124 knownProviderNames := make(map[string]bool)
125 restore := PushPopCrushEnv()
126 defer restore()
127 for _, p := range knownProviders {
128 knownProviderNames[string(p.ID)] = true
129 config, configExists := c.Providers.Get(string(p.ID))
130 // if the user configured a known provider we need to allow it to override a couple of parameters
131 if configExists {
132 if config.Disable {
133 slog.Debug("Skipping provider due to disable flag", "provider", p.ID)
134 c.Providers.Del(string(p.ID))
135 continue
136 }
137 if config.BaseURL != "" {
138 p.APIEndpoint = config.BaseURL
139 }
140 if config.APIKey != "" {
141 p.APIKey = config.APIKey
142 }
143 if len(config.Models) > 0 {
144 models := []catwalk.Model{}
145 seen := make(map[string]bool)
146
147 for _, model := range config.Models {
148 if seen[model.ID] {
149 continue
150 }
151 seen[model.ID] = true
152 if model.Name == "" {
153 model.Name = model.ID
154 }
155 models = append(models, model)
156 }
157 for _, model := range p.Models {
158 if seen[model.ID] {
159 continue
160 }
161 seen[model.ID] = true
162 if model.Name == "" {
163 model.Name = model.ID
164 }
165 models = append(models, model)
166 }
167
168 p.Models = models
169 }
170 }
171
172 headers := map[string]string{}
173 if len(p.DefaultHeaders) > 0 {
174 maps.Copy(headers, p.DefaultHeaders)
175 }
176 if len(config.ExtraHeaders) > 0 {
177 maps.Copy(headers, config.ExtraHeaders)
178 }
179 prepared := ProviderConfig{
180 ID: string(p.ID),
181 Name: p.Name,
182 BaseURL: p.APIEndpoint,
183 APIKey: p.APIKey,
184 Type: p.Type,
185 Disable: config.Disable,
186 SystemPromptPrefix: config.SystemPromptPrefix,
187 ExtraHeaders: headers,
188 ExtraBody: config.ExtraBody,
189 ExtraParams: make(map[string]string),
190 Models: p.Models,
191 }
192
193 switch p.ID {
194 // Handle specific providers that require additional configuration
195 case catwalk.InferenceProviderVertexAI:
196 if !hasVertexCredentials(env) {
197 if configExists {
198 slog.Warn("Skipping Vertex AI provider due to missing credentials")
199 c.Providers.Del(string(p.ID))
200 }
201 continue
202 }
203 prepared.ExtraParams["project"] = env.Get("VERTEXAI_PROJECT")
204 prepared.ExtraParams["location"] = env.Get("VERTEXAI_LOCATION")
205 case catwalk.InferenceProviderAzure:
206 endpoint, err := resolver.ResolveValue(p.APIEndpoint)
207 if err != nil || endpoint == "" {
208 if configExists {
209 slog.Warn("Skipping Azure provider due to missing API endpoint", "provider", p.ID, "error", err)
210 c.Providers.Del(string(p.ID))
211 }
212 continue
213 }
214 prepared.BaseURL = endpoint
215 prepared.ExtraParams["apiVersion"] = env.Get("AZURE_OPENAI_API_VERSION")
216 case catwalk.InferenceProviderBedrock:
217 if !hasAWSCredentials(env) {
218 if configExists {
219 slog.Warn("Skipping Bedrock provider due to missing AWS credentials")
220 c.Providers.Del(string(p.ID))
221 }
222 continue
223 }
224 prepared.ExtraParams["region"] = env.Get("AWS_REGION")
225 if prepared.ExtraParams["region"] == "" {
226 prepared.ExtraParams["region"] = env.Get("AWS_DEFAULT_REGION")
227 }
228 for _, model := range p.Models {
229 if !strings.HasPrefix(model.ID, "anthropic.") {
230 return fmt.Errorf("bedrock provider only supports anthropic models for now, found: %s", model.ID)
231 }
232 }
233 default:
234 // if the provider api or endpoint are missing we skip them
235 v, err := resolver.ResolveValue(p.APIKey)
236 if v == "" || err != nil {
237 if configExists {
238 slog.Warn("Skipping provider due to missing API key", "provider", p.ID)
239 c.Providers.Del(string(p.ID))
240 }
241 continue
242 }
243 }
244 c.Providers.Set(string(p.ID), prepared)
245 }
246
247 // validate the custom providers
248 for id, providerConfig := range c.Providers.Seq2() {
249 if knownProviderNames[id] {
250 continue
251 }
252
253 // Make sure the provider ID is set
254 providerConfig.ID = id
255 if providerConfig.Name == "" {
256 providerConfig.Name = id // Use ID as name if not set
257 }
258 // default to OpenAI if not set
259 if providerConfig.Type == "" {
260 providerConfig.Type = catwalk.TypeOpenAI
261 }
262
263 if providerConfig.Disable {
264 slog.Debug("Skipping custom provider due to disable flag", "provider", id)
265 c.Providers.Del(id)
266 continue
267 }
268 if providerConfig.APIKey == "" {
269 slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id)
270 }
271 if providerConfig.BaseURL == "" {
272 slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id)
273 c.Providers.Del(id)
274 continue
275 }
276 if len(providerConfig.Models) == 0 {
277 slog.Warn("Skipping custom provider because the provider has no models", "provider", id)
278 c.Providers.Del(id)
279 continue
280 }
281 if providerConfig.Type != catwalk.TypeOpenAI && providerConfig.Type != catwalk.TypeAnthropic {
282 slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
283 c.Providers.Del(id)
284 continue
285 }
286
287 apiKey, err := resolver.ResolveValue(providerConfig.APIKey)
288 if apiKey == "" || err != nil {
289 slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id)
290 }
291 baseURL, err := resolver.ResolveValue(providerConfig.BaseURL)
292 if baseURL == "" || err != nil {
293 slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err)
294 c.Providers.Del(id)
295 continue
296 }
297
298 c.Providers.Set(id, providerConfig)
299 }
300 return nil
301}
302
303func (c *Config) setDefaults(workingDir, dataDir string) {
304 c.workingDir = workingDir
305 if c.Options == nil {
306 c.Options = &Options{}
307 }
308 if c.Options.TUI == nil {
309 c.Options.TUI = &TUIOptions{}
310 }
311 if c.Options.ContextPaths == nil {
312 c.Options.ContextPaths = []string{}
313 }
314 if dataDir != "" {
315 c.Options.DataDirectory = dataDir
316 } else if c.Options.DataDirectory == "" {
317 if path, ok := fsext.SearchParent(workingDir, defaultDataDirectory); ok {
318 c.Options.DataDirectory = path
319 } else {
320 c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
321 }
322 }
323 if c.Providers == nil {
324 c.Providers = csync.NewMap[string, ProviderConfig]()
325 }
326 if c.Models == nil {
327 c.Models = make(map[SelectedModelType]SelectedModel)
328 }
329 if c.MCP == nil {
330 c.MCP = make(map[string]MCPConfig)
331 }
332 if c.LSP == nil {
333 c.LSP = make(map[string]LSPConfig)
334 }
335
336 // Apply default file types for known LSP servers if not specified
337 applyDefaultLSPFileTypes(c.LSP)
338
339 // Add the default context paths if they are not already present
340 c.Options.ContextPaths = append(defaultContextPaths, c.Options.ContextPaths...)
341 slices.Sort(c.Options.ContextPaths)
342 c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths)
343}
344
// defaultLSPFileTypes maps an LSP server's executable name (the lower-cased
// base name of its command) to the file extensions it handles by default. It
// is consulted only for LSP configs that declare no FileTypes of their own.
//
// The vscode-*-languageserver spellings are kept for existing setups. The
// vscode-*-language-server spellings match the executable names shipped by
// vscode-langservers-extracted, which the older keys never matched.
var defaultLSPFileTypes = map[string][]string{
	"gopls":                       {"go", "mod", "sum", "work"},
	"typescript-language-server":  {"ts", "tsx", "js", "jsx", "mjs", "cjs"},
	"vtsls":                       {"ts", "tsx", "js", "jsx", "mjs", "cjs"},
	"bash-language-server":        {"sh", "bash", "zsh", "ksh"},
	"rust-analyzer":               {"rs"},
	"pyright":                     {"py", "pyi"},
	"pylsp":                       {"py", "pyi"},
	"clangd":                      {"c", "cpp", "cc", "cxx", "h", "hpp"},
	"jdtls":                       {"java"},
	"vscode-html-languageserver":  {"html", "htm"},
	"vscode-html-language-server": {"html", "htm"},
	"vscode-css-languageserver":   {"css", "scss", "sass", "less"},
	"vscode-css-language-server":  {"css", "scss", "sass", "less"},
	"vscode-json-languageserver":  {"json", "jsonc"},
	"vscode-json-language-server": {"json", "jsonc"},
	"yaml-language-server":        {"yaml", "yml"},
	"lua-language-server":         {"lua"},
	"solargraph":                  {"rb"},
	"elixir-ls":                   {"ex", "exs"},
	"zls":                         {"zig"},
}
364
365// applyDefaultLSPFileTypes sets default file types for known LSP servers
366func applyDefaultLSPFileTypes(lspConfigs map[string]LSPConfig) {
367 for name, config := range lspConfigs {
368 if len(config.FileTypes) != 0 {
369 continue
370 }
371 bin := strings.ToLower(filepath.Base(config.Command))
372 config.FileTypes = defaultLSPFileTypes[bin]
373 lspConfigs[name] = config
374 }
375}
376
377func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
378 if len(knownProviders) == 0 && c.Providers.Len() == 0 {
379 err = fmt.Errorf("no providers configured, please configure at least one provider")
380 return
381 }
382
383 // Use the first provider enabled based on the known providers order
384 // if no provider found that is known use the first provider configured
385 for _, p := range knownProviders {
386 providerConfig, ok := c.Providers.Get(string(p.ID))
387 if !ok || providerConfig.Disable {
388 continue
389 }
390 defaultLargeModel := c.GetModel(string(p.ID), p.DefaultLargeModelID)
391 if defaultLargeModel == nil {
392 err = fmt.Errorf("default large model %s not found for provider %s", p.DefaultLargeModelID, p.ID)
393 return
394 }
395 largeModel = SelectedModel{
396 Provider: string(p.ID),
397 Model: defaultLargeModel.ID,
398 MaxTokens: defaultLargeModel.DefaultMaxTokens,
399 ReasoningEffort: defaultLargeModel.DefaultReasoningEffort,
400 }
401
402 defaultSmallModel := c.GetModel(string(p.ID), p.DefaultSmallModelID)
403 if defaultSmallModel == nil {
404 err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID)
405 return
406 }
407 smallModel = SelectedModel{
408 Provider: string(p.ID),
409 Model: defaultSmallModel.ID,
410 MaxTokens: defaultSmallModel.DefaultMaxTokens,
411 ReasoningEffort: defaultSmallModel.DefaultReasoningEffort,
412 }
413 return
414 }
415
416 enabledProviders := c.EnabledProviders()
417 slices.SortFunc(enabledProviders, func(a, b ProviderConfig) int {
418 return strings.Compare(a.ID, b.ID)
419 })
420
421 if len(enabledProviders) == 0 {
422 err = fmt.Errorf("no providers configured, please configure at least one provider")
423 return
424 }
425
426 providerConfig := enabledProviders[0]
427 if len(providerConfig.Models) == 0 {
428 err = fmt.Errorf("provider %s has no models configured", providerConfig.ID)
429 return
430 }
431 defaultLargeModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
432 largeModel = SelectedModel{
433 Provider: providerConfig.ID,
434 Model: defaultLargeModel.ID,
435 MaxTokens: defaultLargeModel.DefaultMaxTokens,
436 }
437 defaultSmallModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
438 smallModel = SelectedModel{
439 Provider: providerConfig.ID,
440 Model: defaultSmallModel.ID,
441 MaxTokens: defaultSmallModel.DefaultMaxTokens,
442 }
443 return
444}
445
446func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error {
447 defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders)
448 if err != nil {
449 return fmt.Errorf("failed to select default models: %w", err)
450 }
451 large, small := defaultLarge, defaultSmall
452
453 largeModelSelected, largeModelConfigured := c.Models[SelectedModelTypeLarge]
454 if largeModelConfigured {
455 if largeModelSelected.Model != "" {
456 large.Model = largeModelSelected.Model
457 }
458 if largeModelSelected.Provider != "" {
459 large.Provider = largeModelSelected.Provider
460 }
461 model := c.GetModel(large.Provider, large.Model)
462 if model == nil {
463 large = defaultLarge
464 // override the model type to large
465 err := c.UpdatePreferredModel(SelectedModelTypeLarge, large)
466 if err != nil {
467 return fmt.Errorf("failed to update preferred large model: %w", err)
468 }
469 } else {
470 if largeModelSelected.MaxTokens > 0 {
471 large.MaxTokens = largeModelSelected.MaxTokens
472 } else {
473 large.MaxTokens = model.DefaultMaxTokens
474 }
475 if largeModelSelected.ReasoningEffort != "" {
476 large.ReasoningEffort = largeModelSelected.ReasoningEffort
477 }
478 large.Think = largeModelSelected.Think
479 }
480 }
481 smallModelSelected, smallModelConfigured := c.Models[SelectedModelTypeSmall]
482 if smallModelConfigured {
483 if smallModelSelected.Model != "" {
484 small.Model = smallModelSelected.Model
485 }
486 if smallModelSelected.Provider != "" {
487 small.Provider = smallModelSelected.Provider
488 }
489
490 model := c.GetModel(small.Provider, small.Model)
491 if model == nil {
492 small = defaultSmall
493 // override the model type to small
494 err := c.UpdatePreferredModel(SelectedModelTypeSmall, small)
495 if err != nil {
496 return fmt.Errorf("failed to update preferred small model: %w", err)
497 }
498 } else {
499 if smallModelSelected.MaxTokens > 0 {
500 small.MaxTokens = smallModelSelected.MaxTokens
501 } else {
502 small.MaxTokens = model.DefaultMaxTokens
503 }
504 small.ReasoningEffort = smallModelSelected.ReasoningEffort
505 small.Think = smallModelSelected.Think
506 }
507 }
508 c.Models[SelectedModelTypeLarge] = large
509 c.Models[SelectedModelTypeSmall] = small
510 return nil
511}
512
513func loadFromConfigPaths(configPaths []string) (*Config, error) {
514 var configs []io.Reader
515
516 for _, path := range configPaths {
517 fd, err := os.Open(path)
518 if err != nil {
519 if os.IsNotExist(err) {
520 continue
521 }
522 return nil, fmt.Errorf("failed to open config file %s: %w", path, err)
523 }
524 defer fd.Close()
525
526 configs = append(configs, fd)
527 }
528
529 return loadFromReaders(configs)
530}
531
532func loadFromReaders(readers []io.Reader) (*Config, error) {
533 if len(readers) == 0 {
534 return &Config{}, nil
535 }
536
537 merged, err := Merge(readers)
538 if err != nil {
539 return nil, fmt.Errorf("failed to merge configuration readers: %w", err)
540 }
541
542 return LoadReader(merged)
543}
544
545func hasVertexCredentials(env env.Env) bool {
546 hasProject := env.Get("VERTEXAI_PROJECT") != ""
547 hasLocation := env.Get("VERTEXAI_LOCATION") != ""
548 return hasProject && hasLocation
549}
550
551func hasAWSCredentials(env env.Env) bool {
552 if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" {
553 return true
554 }
555
556 if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" {
557 return true
558 }
559
560 if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" {
561 return true
562 }
563
564 if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
565 env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
566 return true
567 }
568 return false
569}
570
571func globalConfig() string {
572 xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
573 if xdgConfigHome != "" {
574 return filepath.Join(xdgConfigHome, appName, fmt.Sprintf("%s.json", appName))
575 }
576
577 // return the path to the main config directory
578 // for windows, it should be in `%LOCALAPPDATA%/crush/`
579 // for linux and macOS, it should be in `$HOME/.config/crush/`
580 if runtime.GOOS == "windows" {
581 localAppData := os.Getenv("LOCALAPPDATA")
582 if localAppData == "" {
583 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
584 }
585 return filepath.Join(localAppData, appName, fmt.Sprintf("%s.json", appName))
586 }
587
588 return filepath.Join(home.Dir(), ".config", appName, fmt.Sprintf("%s.json", appName))
589}
590
591// GlobalConfigData returns the path to the main data directory for the application.
592// this config is used when the app overrides configurations instead of updating the global config.
593func GlobalConfigData() string {
594 xdgDataHome := os.Getenv("XDG_DATA_HOME")
595 if xdgDataHome != "" {
596 return filepath.Join(xdgDataHome, appName, fmt.Sprintf("%s.json", appName))
597 }
598
599 // return the path to the main data directory
600 // for windows, it should be in `%LOCALAPPDATA%/crush/`
601 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
602 if runtime.GOOS == "windows" {
603 localAppData := os.Getenv("LOCALAPPDATA")
604 if localAppData == "" {
605 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
606 }
607 return filepath.Join(localAppData, appName, fmt.Sprintf("%s.json", appName))
608 }
609
610 return filepath.Join(home.Dir(), ".local", "share", appName, fmt.Sprintf("%s.json", appName))
611}