1package config
2
3import (
4 "encoding/json"
5 "fmt"
6 "io"
7 "log/slog"
8 "maps"
9 "os"
10 "path/filepath"
11 "runtime"
12 "slices"
13 "strings"
14
15 "github.com/charmbracelet/catwalk/pkg/catwalk"
16 "github.com/charmbracelet/crush/internal/csync"
17 "github.com/charmbracelet/crush/internal/env"
18 "github.com/charmbracelet/crush/internal/log"
19)
20
// defaultCatwalkURL is the default base URL of the catwalk service.
// NOTE(review): it is not referenced in this section; presumably it is the
// fallback endpoint used by Providers() when loading the known-provider
// catalog — confirm.
const defaultCatwalkURL = "https://catwalk.charm.sh"
22
23// LoadReader config via io.Reader.
24func LoadReader(fd io.Reader) (*Config, error) {
25 data, err := io.ReadAll(fd)
26 if err != nil {
27 return nil, err
28 }
29
30 var config Config
31 err = json.Unmarshal(data, &config)
32 if err != nil {
33 return nil, err
34 }
35 return &config, err
36}
37
38// Load loads the configuration from the default paths.
39func Load(workingDir string, debug bool) (*Config, error) {
40 // uses default config paths
41 configPaths := []string{
42 globalConfig(),
43 GlobalConfigData(),
44 filepath.Join(workingDir, fmt.Sprintf("%s.json", appName)),
45 filepath.Join(workingDir, fmt.Sprintf(".%s.json", appName)),
46 }
47 cfg, err := loadFromConfigPaths(configPaths)
48 if err != nil {
49 return nil, fmt.Errorf("failed to load config from paths %v: %w", configPaths, err)
50 }
51
52 cfg.dataConfigDir = GlobalConfigData()
53
54 cfg.setDefaults(workingDir)
55
56 if debug {
57 cfg.Options.Debug = true
58 }
59
60 // Setup logs
61 log.Setup(
62 filepath.Join(cfg.Options.DataDirectory, "logs", fmt.Sprintf("%s.log", appName)),
63 cfg.Options.Debug,
64 )
65
66 // Load known providers, this loads the config from catwalk
67 providers, err := Providers()
68 if err != nil || len(providers) == 0 {
69 return nil, fmt.Errorf("failed to load providers: %w", err)
70 }
71 cfg.knownProviders = providers
72
73 env := env.New()
74 // Configure providers
75 valueResolver := NewShellVariableResolver(env)
76 cfg.resolver = valueResolver
77 if err := cfg.configureProviders(env, valueResolver, providers); err != nil {
78 return nil, fmt.Errorf("failed to configure providers: %w", err)
79 }
80
81 if !cfg.IsConfigured() {
82 slog.Warn("No providers configured")
83 return cfg, nil
84 }
85
86 if err := cfg.configureSelectedModels(providers); err != nil {
87 return nil, fmt.Errorf("failed to configure selected models: %w", err)
88 }
89 cfg.SetupAgents()
90 return cfg, nil
91}
92
// PushPopCrushEnv copies every CRUSH_<NAME> environment variable onto <NAME>
// (e.g. CRUSH_OPENAI_API_KEY -> OPENAI_API_KEY), so that Crush-specific
// values take precedence. It returns a function that undoes the change.
//
// The returned restore function puts back the previous value of each
// overridden variable, or unsets it if it was not set before the call.
//
// It mutates the process environment, so it must not race with other code
// that reads or writes environment variables.
func PushPopCrushEnv() func() {
	const prefix = "CRUSH_"

	// Snapshot all overrides before touching the environment, so applying one
	// override cannot change the value read for another
	// (e.g. CRUSH_CRUSH_X and CRUSH_X both present).
	overrides := make(map[string]string)
	for _, kv := range os.Environ() {
		key, value, ok := strings.Cut(kv, "=")
		if !ok {
			continue
		}
		name, found := strings.CutPrefix(key, prefix)
		// A bare "CRUSH_" maps to an empty name, which Setenv rejects.
		if !found || name == "" {
			continue
		}
		overrides[name] = value
	}

	// Record whether each target was set at all, so restore can unset
	// variables that did not exist rather than leaving them set to "".
	type saved struct {
		value string
		set   bool
	}
	backups := make(map[string]saved, len(overrides))
	for name := range overrides {
		value, set := os.LookupEnv(name)
		backups[name] = saved{value: value, set: set}
	}

	for name, value := range overrides {
		if err := os.Setenv(name, value); err != nil {
			slog.Warn("Failed to apply CRUSH_ environment override", "name", name, "error", err)
		}
	}

	return func() {
		for name, b := range backups {
			var err error
			if b.set {
				err = os.Setenv(name, b.value)
			} else {
				err = os.Unsetenv(name)
			}
			if err != nil {
				slog.Warn("Failed to restore environment variable", "name", name, "error", err)
			}
		}
	}
}
120
121func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
122 knownProviderNames := make(map[string]bool)
123 restore := PushPopCrushEnv()
124 defer restore()
125 for _, p := range knownProviders {
126 knownProviderNames[string(p.ID)] = true
127 config, configExists := c.Providers.Get(string(p.ID))
128 // if the user configured a known provider we need to allow it to override a couple of parameters
129 if configExists {
130 if config.Disable {
131 slog.Debug("Skipping provider due to disable flag", "provider", p.ID)
132 c.Providers.Del(string(p.ID))
133 continue
134 }
135 if config.BaseURL != "" {
136 p.APIEndpoint = config.BaseURL
137 }
138 if config.APIKey != "" {
139 p.APIKey = config.APIKey
140 }
141 if len(config.Models) > 0 {
142 models := []catwalk.Model{}
143 seen := make(map[string]bool)
144
145 for _, model := range config.Models {
146 if seen[model.ID] {
147 continue
148 }
149 seen[model.ID] = true
150 if model.Name == "" {
151 model.Name = model.ID
152 }
153 models = append(models, model)
154 }
155 for _, model := range p.Models {
156 if seen[model.ID] {
157 continue
158 }
159 seen[model.ID] = true
160 if model.Name == "" {
161 model.Name = model.ID
162 }
163 models = append(models, model)
164 }
165
166 p.Models = models
167 }
168 }
169
170 headers := map[string]string{}
171 if len(p.DefaultHeaders) > 0 {
172 maps.Copy(headers, p.DefaultHeaders)
173 }
174 if len(config.ExtraHeaders) > 0 {
175 maps.Copy(headers, config.ExtraHeaders)
176 }
177 prepared := ProviderConfig{
178 ID: string(p.ID),
179 Name: p.Name,
180 BaseURL: p.APIEndpoint,
181 APIKey: p.APIKey,
182 Type: p.Type,
183 Disable: config.Disable,
184 SystemPromptPrefix: config.SystemPromptPrefix,
185 ExtraHeaders: headers,
186 ExtraBody: config.ExtraBody,
187 ExtraParams: make(map[string]string),
188 Models: p.Models,
189 }
190
191 switch p.ID {
192 // Handle specific providers that require additional configuration
193 case catwalk.InferenceProviderVertexAI:
194 if !hasVertexCredentials(env) {
195 if configExists {
196 slog.Warn("Skipping Vertex AI provider due to missing credentials")
197 c.Providers.Del(string(p.ID))
198 }
199 continue
200 }
201 prepared.ExtraParams["project"] = env.Get("VERTEXAI_PROJECT")
202 prepared.ExtraParams["location"] = env.Get("VERTEXAI_LOCATION")
203 case catwalk.InferenceProviderAzure:
204 endpoint, err := resolver.ResolveValue(p.APIEndpoint)
205 if err != nil || endpoint == "" {
206 if configExists {
207 slog.Warn("Skipping Azure provider due to missing API endpoint", "provider", p.ID, "error", err)
208 c.Providers.Del(string(p.ID))
209 }
210 continue
211 }
212 prepared.BaseURL = endpoint
213 prepared.ExtraParams["apiVersion"] = env.Get("AZURE_OPENAI_API_VERSION")
214 case catwalk.InferenceProviderBedrock:
215 if !hasAWSCredentials(env) {
216 if configExists {
217 slog.Warn("Skipping Bedrock provider due to missing AWS credentials")
218 c.Providers.Del(string(p.ID))
219 }
220 continue
221 }
222 prepared.ExtraParams["region"] = env.Get("AWS_REGION")
223 if prepared.ExtraParams["region"] == "" {
224 prepared.ExtraParams["region"] = env.Get("AWS_DEFAULT_REGION")
225 }
226 for _, model := range p.Models {
227 if !strings.HasPrefix(model.ID, "anthropic.") {
228 return fmt.Errorf("bedrock provider only supports anthropic models for now, found: %s", model.ID)
229 }
230 }
231 default:
232 // if the provider api or endpoint are missing we skip them
233 v, err := resolver.ResolveValue(p.APIKey)
234 if v == "" || err != nil {
235 if configExists {
236 slog.Warn("Skipping provider due to missing API key", "provider", p.ID)
237 c.Providers.Del(string(p.ID))
238 }
239 continue
240 }
241 }
242 c.Providers.Set(string(p.ID), prepared)
243 }
244
245 // validate the custom providers
246 for id, providerConfig := range c.Providers.Seq2() {
247 if knownProviderNames[id] {
248 continue
249 }
250
251 // Make sure the provider ID is set
252 providerConfig.ID = id
253 if providerConfig.Name == "" {
254 providerConfig.Name = id // Use ID as name if not set
255 }
256 // default to OpenAI if not set
257 if providerConfig.Type == "" {
258 providerConfig.Type = catwalk.TypeOpenAI
259 }
260
261 if providerConfig.Disable {
262 slog.Debug("Skipping custom provider due to disable flag", "provider", id)
263 c.Providers.Del(id)
264 continue
265 }
266 if providerConfig.APIKey == "" {
267 slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id)
268 }
269 if providerConfig.BaseURL == "" {
270 slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id)
271 c.Providers.Del(id)
272 continue
273 }
274 if len(providerConfig.Models) == 0 {
275 slog.Warn("Skipping custom provider because the provider has no models", "provider", id)
276 c.Providers.Del(id)
277 continue
278 }
279 if providerConfig.Type != catwalk.TypeOpenAI && providerConfig.Type != catwalk.TypeAnthropic {
280 slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
281 c.Providers.Del(id)
282 continue
283 }
284
285 apiKey, err := resolver.ResolveValue(providerConfig.APIKey)
286 if apiKey == "" || err != nil {
287 slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id)
288 }
289 baseURL, err := resolver.ResolveValue(providerConfig.BaseURL)
290 if baseURL == "" || err != nil {
291 slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err)
292 c.Providers.Del(id)
293 continue
294 }
295
296 c.Providers.Set(id, providerConfig)
297 }
298 return nil
299}
300
301func (c *Config) setDefaults(workingDir string) {
302 c.workingDir = workingDir
303 if c.Options == nil {
304 c.Options = &Options{}
305 }
306 if c.Options.TUI == nil {
307 c.Options.TUI = &TUIOptions{}
308 }
309 if c.Options.ContextPaths == nil {
310 c.Options.ContextPaths = []string{}
311 }
312 if c.Options.DataDirectory == "" {
313 c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
314 }
315 if c.Providers == nil {
316 c.Providers = csync.NewMap[string, ProviderConfig]()
317 }
318 if c.Models == nil {
319 c.Models = make(map[SelectedModelType]SelectedModel)
320 }
321 if c.MCP == nil {
322 c.MCP = make(map[string]MCPConfig)
323 }
324 if c.LSP == nil {
325 c.LSP = make(map[string]LSPConfig)
326 }
327
328 // Apply default file types for known LSP servers if not specified
329 applyDefaultLSPFileTypes(c.LSP)
330
331 // Add the default context paths if they are not already present
332 c.Options.ContextPaths = append(defaultContextPaths, c.Options.ContextPaths...)
333 slices.Sort(c.Options.ContextPaths)
334 c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths)
335}
336
// defaultLSPFileTypes maps the executable name of well-known LSP servers to
// the file extensions (without the dot) they handle. The key is matched
// against the lower-cased base name of an LSP command. It is used for LSP
// entries that do not declare FileTypes themselves.
//
// vscode-langservers-extracted ships binaries named
// vscode-{html,css,json}-language-server, and Pyright's LSP binary is
// pyright-langserver. The older "...-languageserver" and plain "pyright" keys
// are kept for backward compatibility.
var defaultLSPFileTypes = map[string][]string{
	"gopls":                       {"go", "mod", "sum", "work"},
	"typescript-language-server":  {"ts", "tsx", "js", "jsx", "mjs", "cjs"},
	"vtsls":                       {"ts", "tsx", "js", "jsx", "mjs", "cjs"},
	"bash-language-server":        {"sh", "bash", "zsh", "ksh"},
	"rust-analyzer":               {"rs"},
	"pyright":                     {"py", "pyi"},
	"pyright-langserver":          {"py", "pyi"},
	"pylsp":                       {"py", "pyi"},
	"clangd":                      {"c", "cpp", "cc", "cxx", "h", "hpp"},
	"jdtls":                       {"java"},
	"vscode-html-languageserver":  {"html", "htm"},
	"vscode-html-language-server": {"html", "htm"},
	"vscode-css-languageserver":   {"css", "scss", "sass", "less"},
	"vscode-css-language-server":  {"css", "scss", "sass", "less"},
	"vscode-json-languageserver":  {"json", "jsonc"},
	"vscode-json-language-server": {"json", "jsonc"},
	"yaml-language-server":        {"yaml", "yml"},
	"lua-language-server":         {"lua"},
	"solargraph":                  {"rb"},
	"elixir-ls":                   {"ex", "exs"},
	"zls":                         {"zig"},
}
356
357// applyDefaultLSPFileTypes sets default file types for known LSP servers
358func applyDefaultLSPFileTypes(lspConfigs map[string]LSPConfig) {
359 for name, config := range lspConfigs {
360 if len(config.FileTypes) != 0 {
361 continue
362 }
363 bin := strings.ToLower(filepath.Base(config.Command))
364 config.FileTypes = defaultLSPFileTypes[bin]
365 lspConfigs[name] = config
366 }
367}
368
369func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
370 if len(knownProviders) == 0 && c.Providers.Len() == 0 {
371 err = fmt.Errorf("no providers configured, please configure at least one provider")
372 return
373 }
374
375 // Use the first provider enabled based on the known providers order
376 // if no provider found that is known use the first provider configured
377 for _, p := range knownProviders {
378 providerConfig, ok := c.Providers.Get(string(p.ID))
379 if !ok || providerConfig.Disable {
380 continue
381 }
382 defaultLargeModel := c.GetModel(string(p.ID), p.DefaultLargeModelID)
383 if defaultLargeModel == nil {
384 err = fmt.Errorf("default large model %s not found for provider %s", p.DefaultLargeModelID, p.ID)
385 return
386 }
387 largeModel = SelectedModel{
388 Provider: string(p.ID),
389 Model: defaultLargeModel.ID,
390 MaxTokens: defaultLargeModel.DefaultMaxTokens,
391 ReasoningEffort: defaultLargeModel.DefaultReasoningEffort,
392 }
393
394 defaultSmallModel := c.GetModel(string(p.ID), p.DefaultSmallModelID)
395 if defaultSmallModel == nil {
396 err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID)
397 return
398 }
399 smallModel = SelectedModel{
400 Provider: string(p.ID),
401 Model: defaultSmallModel.ID,
402 MaxTokens: defaultSmallModel.DefaultMaxTokens,
403 ReasoningEffort: defaultSmallModel.DefaultReasoningEffort,
404 }
405 return
406 }
407
408 enabledProviders := c.EnabledProviders()
409 slices.SortFunc(enabledProviders, func(a, b ProviderConfig) int {
410 return strings.Compare(a.ID, b.ID)
411 })
412
413 if len(enabledProviders) == 0 {
414 err = fmt.Errorf("no providers configured, please configure at least one provider")
415 return
416 }
417
418 providerConfig := enabledProviders[0]
419 if len(providerConfig.Models) == 0 {
420 err = fmt.Errorf("provider %s has no models configured", providerConfig.ID)
421 return
422 }
423 defaultLargeModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
424 largeModel = SelectedModel{
425 Provider: providerConfig.ID,
426 Model: defaultLargeModel.ID,
427 MaxTokens: defaultLargeModel.DefaultMaxTokens,
428 }
429 defaultSmallModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
430 smallModel = SelectedModel{
431 Provider: providerConfig.ID,
432 Model: defaultSmallModel.ID,
433 MaxTokens: defaultSmallModel.DefaultMaxTokens,
434 }
435 return
436}
437
438func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error {
439 defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders)
440 if err != nil {
441 return fmt.Errorf("failed to select default models: %w", err)
442 }
443 large, small := defaultLarge, defaultSmall
444
445 largeModelSelected, largeModelConfigured := c.Models[SelectedModelTypeLarge]
446 if largeModelConfigured {
447 if largeModelSelected.Model != "" {
448 large.Model = largeModelSelected.Model
449 }
450 if largeModelSelected.Provider != "" {
451 large.Provider = largeModelSelected.Provider
452 }
453 model := c.GetModel(large.Provider, large.Model)
454 if model == nil {
455 large = defaultLarge
456 // override the model type to large
457 err := c.UpdatePreferredModel(SelectedModelTypeLarge, large)
458 if err != nil {
459 return fmt.Errorf("failed to update preferred large model: %w", err)
460 }
461 } else {
462 if largeModelSelected.MaxTokens > 0 {
463 large.MaxTokens = largeModelSelected.MaxTokens
464 } else {
465 large.MaxTokens = model.DefaultMaxTokens
466 }
467 if largeModelSelected.ReasoningEffort != "" {
468 large.ReasoningEffort = largeModelSelected.ReasoningEffort
469 }
470 large.Think = largeModelSelected.Think
471 }
472 }
473 smallModelSelected, smallModelConfigured := c.Models[SelectedModelTypeSmall]
474 if smallModelConfigured {
475 if smallModelSelected.Model != "" {
476 small.Model = smallModelSelected.Model
477 }
478 if smallModelSelected.Provider != "" {
479 small.Provider = smallModelSelected.Provider
480 }
481
482 model := c.GetModel(small.Provider, small.Model)
483 if model == nil {
484 small = defaultSmall
485 // override the model type to small
486 err := c.UpdatePreferredModel(SelectedModelTypeSmall, small)
487 if err != nil {
488 return fmt.Errorf("failed to update preferred small model: %w", err)
489 }
490 } else {
491 if smallModelSelected.MaxTokens > 0 {
492 small.MaxTokens = smallModelSelected.MaxTokens
493 } else {
494 small.MaxTokens = model.DefaultMaxTokens
495 }
496 small.ReasoningEffort = smallModelSelected.ReasoningEffort
497 small.Think = smallModelSelected.Think
498 }
499 }
500 c.Models[SelectedModelTypeLarge] = large
501 c.Models[SelectedModelTypeSmall] = small
502 return nil
503}
504
505func loadFromConfigPaths(configPaths []string) (*Config, error) {
506 var configs []io.Reader
507
508 for _, path := range configPaths {
509 fd, err := os.Open(path)
510 if err != nil {
511 if os.IsNotExist(err) {
512 continue
513 }
514 return nil, fmt.Errorf("failed to open config file %s: %w", path, err)
515 }
516 defer fd.Close()
517
518 configs = append(configs, fd)
519 }
520
521 return loadFromReaders(configs)
522}
523
524func loadFromReaders(readers []io.Reader) (*Config, error) {
525 if len(readers) == 0 {
526 return &Config{}, nil
527 }
528
529 merged, err := Merge(readers)
530 if err != nil {
531 return nil, fmt.Errorf("failed to merge configuration readers: %w", err)
532 }
533
534 return LoadReader(merged)
535}
536
537func hasVertexCredentials(env env.Env) bool {
538 hasProject := env.Get("VERTEXAI_PROJECT") != ""
539 hasLocation := env.Get("VERTEXAI_LOCATION") != ""
540 return hasProject && hasLocation
541}
542
543func hasAWSCredentials(env env.Env) bool {
544 if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" {
545 return true
546 }
547
548 if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" {
549 return true
550 }
551
552 if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" {
553 return true
554 }
555
556 if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
557 env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
558 return true
559 }
560 return false
561}
562
563func globalConfig() string {
564 xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
565 if xdgConfigHome != "" {
566 return filepath.Join(xdgConfigHome, appName, fmt.Sprintf("%s.json", appName))
567 }
568
569 // return the path to the main config directory
570 // for windows, it should be in `%LOCALAPPDATA%/crush/`
571 // for linux and macOS, it should be in `$HOME/.config/crush/`
572 if runtime.GOOS == "windows" {
573 localAppData := os.Getenv("LOCALAPPDATA")
574 if localAppData == "" {
575 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
576 }
577 return filepath.Join(localAppData, appName, fmt.Sprintf("%s.json", appName))
578 }
579
580 return filepath.Join(os.Getenv("HOME"), ".config", appName, fmt.Sprintf("%s.json", appName))
581}
582
583// GlobalConfigData returns the path to the main data directory for the application.
584// this config is used when the app overrides configurations instead of updating the global config.
585func GlobalConfigData() string {
586 xdgDataHome := os.Getenv("XDG_DATA_HOME")
587 if xdgDataHome != "" {
588 return filepath.Join(xdgDataHome, appName, fmt.Sprintf("%s.json", appName))
589 }
590
591 // return the path to the main data directory
592 // for windows, it should be in `%LOCALAPPDATA%/crush/`
593 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
594 if runtime.GOOS == "windows" {
595 localAppData := os.Getenv("LOCALAPPDATA")
596 if localAppData == "" {
597 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
598 }
599 return filepath.Join(localAppData, appName, fmt.Sprintf("%s.json", appName))
600 }
601
602 return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, fmt.Sprintf("%s.json", appName))
603}