load.go

  1package config
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"io"
  7	"log/slog"
  8	"maps"
  9	"os"
 10	"path/filepath"
 11	"runtime"
 12	"slices"
 13	"strings"
 14
 15	"github.com/charmbracelet/catwalk/pkg/catwalk"
 16	"github.com/charmbracelet/crush/internal/csync"
 17	"github.com/charmbracelet/crush/internal/env"
 18	"github.com/charmbracelet/crush/internal/fsext"
 19	"github.com/charmbracelet/crush/internal/home"
 20	"github.com/charmbracelet/crush/internal/log"
 21)
 22
 23const defaultCatwalkURL = "https://catwalk.charm.sh"
 24
 25// LoadReader config via io.Reader.
 26func LoadReader(fd io.Reader) (*Config, error) {
 27	data, err := io.ReadAll(fd)
 28	if err != nil {
 29		return nil, err
 30	}
 31
 32	var config Config
 33	err = json.Unmarshal(data, &config)
 34	if err != nil {
 35		return nil, err
 36	}
 37	return &config, err
 38}
 39
 40// Load loads the configuration from the default paths.
 41func Load(workingDir, dataDir string, debug bool) (*Config, error) {
 42	// uses default config paths
 43	configPaths := []string{
 44		globalConfig(),
 45		GlobalConfigData(),
 46		filepath.Join(workingDir, fmt.Sprintf("%s.json", appName)),
 47		filepath.Join(workingDir, fmt.Sprintf(".%s.json", appName)),
 48	}
 49	cfg, err := loadFromConfigPaths(configPaths)
 50	if err != nil {
 51		return nil, fmt.Errorf("failed to load config from paths %v: %w", configPaths, err)
 52	}
 53
 54	cfg.dataConfigDir = GlobalConfigData()
 55
 56	cfg.setDefaults(workingDir, dataDir)
 57
 58	if debug {
 59		cfg.Options.Debug = true
 60	}
 61
 62	// Setup logs
 63	log.Setup(
 64		filepath.Join(cfg.Options.DataDirectory, "logs", fmt.Sprintf("%s.log", appName)),
 65		cfg.Options.Debug,
 66	)
 67
 68	// Load known providers, this loads the config from catwalk
 69	providers, err := Providers()
 70	if err != nil || len(providers) == 0 {
 71		return nil, fmt.Errorf("failed to load providers: %w", err)
 72	}
 73	cfg.knownProviders = providers
 74
 75	env := env.New()
 76	// Configure providers
 77	valueResolver := NewShellVariableResolver(env)
 78	cfg.resolver = valueResolver
 79	if err := cfg.configureProviders(env, valueResolver, providers); err != nil {
 80		return nil, fmt.Errorf("failed to configure providers: %w", err)
 81	}
 82
 83	if !cfg.IsConfigured() {
 84		slog.Warn("No providers configured")
 85		return cfg, nil
 86	}
 87
 88	if err := cfg.configureSelectedModels(providers); err != nil {
 89		return nil, fmt.Errorf("failed to configure selected models: %w", err)
 90	}
 91	cfg.SetupAgents()
 92	return cfg, nil
 93}
 94
 95func PushPopCrushEnv() func() {
 96	found := []string{}
 97	for _, ev := range os.Environ() {
 98		if strings.HasPrefix(ev, "CRUSH_") {
 99			pair := strings.SplitN(ev, "=", 2)
100			if len(pair) != 2 {
101				continue
102			}
103			found = append(found, strings.TrimPrefix(pair[0], "CRUSH_"))
104		}
105	}
106	backups := make(map[string]string)
107	for _, ev := range found {
108		backups[ev] = os.Getenv(ev)
109	}
110
111	for _, ev := range found {
112		os.Setenv(ev, os.Getenv("CRUSH_"+ev))
113	}
114
115	restore := func() {
116		for k, v := range backups {
117			os.Setenv(k, v)
118		}
119	}
120	return restore
121}
122
// configureProviders reconciles the user-configured providers in c.Providers
// with the known (catwalk) provider catalogue, then validates custom
// providers.
//
// For each known provider:
//   - a user entry with Disable set removes the provider;
//   - a non-empty user BaseURL / APIKey overrides the catalogue value;
//   - when the user lists models, those come first, followed by catalogue
//     models whose ID the user did not already define (first occurrence of
//     an ID wins, and an empty Name defaults to the ID);
//   - catalogue DefaultHeaders are merged with user ExtraHeaders (user wins);
//   - provider-specific checks run (Vertex AI and Bedrock credentials, the
//     Azure endpoint, otherwise a resolvable API key). A provider that fails
//     its check is skipped, and is also removed from c.Providers if the user
//     had configured it. Otherwise the prepared config is stored.
//
// Every remaining entry whose ID is not a known provider is treated as a
// custom provider. Its ID, Name and Type are defaulted. Entries that are
// disabled, lack a (resolvable) base URL, have no models, or use an
// unsupported type are removed.
//
// CRUSH_-prefixed environment overrides are applied for the duration of the
// call (PushPopCrushEnv). The only error returned is for a Bedrock provider
// listing a model whose ID does not start with "anthropic.".
func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
	knownProviderNames := make(map[string]bool, len(knownProviders))
	restore := PushPopCrushEnv()
	defer restore()

	for _, p := range knownProviders {
		// p is a copy of the catalogue entry, so the overrides below do not
		// modify knownProviders.
		knownProviderNames[string(p.ID)] = true
		config, configExists := c.Providers.Get(string(p.ID))
		// If the user configured a known provider, allow it to override a
		// couple of parameters.
		if configExists {
			if config.Disable {
				slog.Debug("Skipping provider due to disable flag", "provider", p.ID)
				c.Providers.Del(string(p.ID))
				continue
			}
			if config.BaseURL != "" {
				p.APIEndpoint = config.BaseURL
			}
			if config.APIKey != "" {
				p.APIKey = config.APIKey
			}
			if len(config.Models) > 0 {
				// User models take precedence over catalogue models with the
				// same ID.
				models := make([]catwalk.Model, 0, len(config.Models)+len(p.Models))
				seen := make(map[string]bool, len(config.Models)+len(p.Models))
				for _, group := range [][]catwalk.Model{config.Models, p.Models} {
					for _, model := range group {
						if seen[model.ID] {
							continue
						}
						seen[model.ID] = true
						if model.Name == "" {
							model.Name = model.ID
						}
						models = append(models, model)
					}
				}
				p.Models = models
			}
		}

		// Catalogue headers first, user headers layered on top. maps.Copy is
		// a no-op for a nil source, so no length guards are needed.
		headers := map[string]string{}
		maps.Copy(headers, p.DefaultHeaders)
		maps.Copy(headers, config.ExtraHeaders)

		prepared := ProviderConfig{
			ID:                 string(p.ID),
			Name:               p.Name,
			BaseURL:            p.APIEndpoint,
			APIKey:             p.APIKey,
			Type:               p.Type,
			Disable:            config.Disable,
			SystemPromptPrefix: config.SystemPromptPrefix,
			ExtraHeaders:       headers,
			ExtraBody:          config.ExtraBody,
			ExtraParams:        make(map[string]string),
			Models:             p.Models,
		}

		// Provider-specific validation. A provider that fails it is skipped;
		// if the user had configured it, its entry is removed as well.
		switch p.ID {
		case catwalk.InferenceProviderVertexAI:
			if !hasVertexCredentials(env) {
				if configExists {
					slog.Warn("Skipping Vertex AI provider due to missing credentials")
					c.Providers.Del(string(p.ID))
				}
				continue
			}
			prepared.ExtraParams["project"] = env.Get("VERTEXAI_PROJECT")
			prepared.ExtraParams["location"] = env.Get("VERTEXAI_LOCATION")
		case catwalk.InferenceProviderAzure:
			endpoint, err := resolver.ResolveValue(p.APIEndpoint)
			if err != nil || endpoint == "" {
				if configExists {
					slog.Warn("Skipping Azure provider due to missing API endpoint", "provider", p.ID, "error", err)
					c.Providers.Del(string(p.ID))
				}
				continue
			}
			prepared.BaseURL = endpoint
			prepared.ExtraParams["apiVersion"] = env.Get("AZURE_OPENAI_API_VERSION")
		case catwalk.InferenceProviderBedrock:
			if !hasAWSCredentials(env) {
				if configExists {
					slog.Warn("Skipping Bedrock provider due to missing AWS credentials")
					c.Providers.Del(string(p.ID))
				}
				continue
			}
			prepared.ExtraParams["region"] = env.Get("AWS_REGION")
			if prepared.ExtraParams["region"] == "" {
				prepared.ExtraParams["region"] = env.Get("AWS_DEFAULT_REGION")
			}
			for _, model := range p.Models {
				if !strings.HasPrefix(model.ID, "anthropic.") {
					return fmt.Errorf("bedrock provider only supports anthropic models for now, found: %s", model.ID)
				}
			}
		default:
			// if the provider api or endpoint are missing we skip them
			v, err := resolver.ResolveValue(p.APIKey)
			if v == "" || err != nil {
				if configExists {
					slog.Warn("Skipping provider due to missing API key", "provider", p.ID)
					c.Providers.Del(string(p.ID))
				}
				continue
			}
		}
		c.Providers.Set(string(p.ID), prepared)
	}

	// Validate the custom providers: every remaining entry that is not a
	// known provider.
	for id, providerConfig := range c.Providers.Seq2() {
		if knownProviderNames[id] {
			continue
		}

		// Make sure the provider ID is set
		providerConfig.ID = id
		if providerConfig.Name == "" {
			providerConfig.Name = id // Use ID as name if not set
		}
		// default to OpenAI if not set
		if providerConfig.Type == "" {
			providerConfig.Type = catwalk.TypeOpenAI
		}

		if providerConfig.Disable {
			slog.Debug("Skipping custom provider due to disable flag", "provider", id)
			c.Providers.Del(id)
			continue
		}
		if providerConfig.BaseURL == "" {
			slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id)
			c.Providers.Del(id)
			continue
		}
		if len(providerConfig.Models) == 0 {
			slog.Warn("Skipping custom provider because the provider has no models", "provider", id)
			c.Providers.Del(id)
			continue
		}
		if providerConfig.Type != catwalk.TypeOpenAI && providerConfig.Type != catwalk.TypeAnthropic {
			slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
			c.Providers.Del(id)
			continue
		}

		// A missing key is only a warning, since local providers may not
		// need one. It is checked once, after resolution. An empty raw key is
		// expected to resolve to empty (or fail), so a separate raw pre-check
		// would only log this warning twice.
		if apiKey, err := resolver.ResolveValue(providerConfig.APIKey); apiKey == "" || err != nil {
			slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id, "error", err)
		}
		baseURL, err := resolver.ResolveValue(providerConfig.BaseURL)
		if baseURL == "" || err != nil {
			slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err)
			c.Providers.Del(id)
			continue
		}

		c.Providers.Set(id, providerConfig)
	}
	return nil
}
302
// setDefaults records workingDir and fills in every unset part of the
// configuration.
//
// Options.DataDirectory is chosen as follows:
//   - dataDir, when non-empty, always wins;
//   - otherwise an existing DataDirectory is kept;
//   - otherwise the defaultDataDirectory location returned by
//     fsext.SearchParent(workingDir, ...) is used;
//   - failing that, workingDir/defaultDataDirectory.
//
// Nil option structs and maps are initialised and default LSP file types are
// applied. The default context paths are merged into Options.ContextPaths,
// which is then sorted and de-duplicated.
func (c *Config) setDefaults(workingDir, dataDir string) {
	c.workingDir = workingDir
	if c.Options == nil {
		c.Options = &Options{}
	}
	if c.Options.TUI == nil {
		c.Options.TUI = &TUIOptions{}
	}
	if c.Options.ContextPaths == nil {
		c.Options.ContextPaths = []string{}
	}
	if dataDir != "" {
		c.Options.DataDirectory = dataDir
	} else if c.Options.DataDirectory == "" {
		if path, ok := fsext.SearchParent(workingDir, defaultDataDirectory); ok {
			c.Options.DataDirectory = path
		} else {
			c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
		}
	}
	if c.Providers == nil {
		c.Providers = csync.NewMap[string, ProviderConfig]()
	}
	if c.Models == nil {
		c.Models = make(map[SelectedModelType]SelectedModel)
	}
	if c.MCP == nil {
		c.MCP = make(map[string]MCPConfig)
	}
	if c.LSP == nil {
		c.LSP = make(map[string]LSPConfig)
	}

	// Apply default file types for known LSP servers if not specified
	applyDefaultLSPFileTypes(c.LSP)

	// Build a fresh slice for the merged context paths. A plain
	// append(defaultContextPaths, ...) returns the package-level slice itself
	// when there is nothing to append (or writes into its spare capacity).
	// The in-place Sort/Compact below would then rewrite the shared
	// defaults and alias them into this config.
	contextPaths := make([]string, 0, len(defaultContextPaths)+len(c.Options.ContextPaths))
	contextPaths = append(contextPaths, defaultContextPaths...)
	contextPaths = append(contextPaths, c.Options.ContextPaths...)
	slices.Sort(contextPaths)
	c.Options.ContextPaths = slices.Compact(contextPaths)
}
344
// defaultLSPFileTypes lists, per language-server executable (matched against
// the lower-cased base name of an LSP config's Command), the file extensions
// assigned to that server when its config does not list any FileTypes.
// Entries are kept in alphabetical order.
var defaultLSPFileTypes = map[string][]string{
	"bash-language-server":       {"sh", "bash", "zsh", "ksh"},
	"clangd":                     {"c", "cpp", "cc", "cxx", "h", "hpp"},
	"elixir-ls":                  {"ex", "exs"},
	"gopls":                      {"go", "mod", "sum", "work"},
	"jdtls":                      {"java"},
	"lua-language-server":        {"lua"},
	"pylsp":                      {"py", "pyi"},
	"pyright":                    {"py", "pyi"},
	"rust-analyzer":              {"rs"},
	"solargraph":                 {"rb"},
	"typescript-language-server": {"ts", "tsx", "js", "jsx", "mjs", "cjs"},
	"vscode-css-languageserver":  {"css", "scss", "sass", "less"},
	"vscode-html-languageserver": {"html", "htm"},
	"vscode-json-languageserver": {"json", "jsonc"},
	"vtsls":                      {"ts", "tsx", "js", "jsx", "mjs", "cjs"},
	"yaml-language-server":       {"yaml", "yml"},
	"zls":                        {"zig"},
}
364
// applyDefaultLSPFileTypes fills in FileTypes for every LSP config that does
// not declare any. The defaults come from defaultLSPFileTypes, keyed by the
// lower-cased base name of the config's Command.
//
// A trailing ".exe" is ignored so that Windows executables such as
// gopls.exe match their defaults. Configs whose command is not a known
// server are left untouched. The default slice is cloned so that later
// in-place edits to one config's FileTypes cannot leak into the shared
// defaults or into other configs.
func applyDefaultLSPFileTypes(lspConfigs map[string]LSPConfig) {
	for name, config := range lspConfigs {
		if len(config.FileTypes) != 0 {
			continue
		}
		bin := strings.TrimSuffix(strings.ToLower(filepath.Base(config.Command)), ".exe")
		fileTypes, ok := defaultLSPFileTypes[bin]
		if !ok {
			continue
		}
		config.FileTypes = slices.Clone(fileTypes)
		// Overwriting an existing key while ranging over a map is allowed.
		lspConfigs[name] = config
	}
}
376
// defaultModelSelection picks the large and small models to use when the
// user has not selected any, or has selected ones that do not exist.
//
// Known providers are tried in catalogue order. The first one present in
// c.Providers and not disabled supplies its DefaultLargeModelID and
// DefaultSmallModelID. If no known provider is configured, the enabled
// providers are sorted by ID and the first model of the first provider is
// used for both sizes.
//
// An error, with zero-valued selections, is returned in these cases:
//   - no provider is configured;
//   - a known provider's default model cannot be found;
//   - the fallback provider has no usable model.
func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
	if len(knownProviders) == 0 && c.Providers.Len() == 0 {
		return SelectedModel{}, SelectedModel{}, fmt.Errorf("no providers configured, please configure at least one provider")
	}

	// Use the first enabled provider in known-provider order.
	for _, p := range knownProviders {
		providerConfig, ok := c.Providers.Get(string(p.ID))
		if !ok || providerConfig.Disable {
			continue
		}
		defaultLargeModel := c.GetModel(string(p.ID), p.DefaultLargeModelID)
		if defaultLargeModel == nil {
			return SelectedModel{}, SelectedModel{}, fmt.Errorf("default large model %s not found for provider %s", p.DefaultLargeModelID, p.ID)
		}
		defaultSmallModel := c.GetModel(string(p.ID), p.DefaultSmallModelID)
		if defaultSmallModel == nil {
			return SelectedModel{}, SelectedModel{}, fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID)
		}
		large := SelectedModel{
			Provider:        string(p.ID),
			Model:           defaultLargeModel.ID,
			MaxTokens:       defaultLargeModel.DefaultMaxTokens,
			ReasoningEffort: defaultLargeModel.DefaultReasoningEffort,
		}
		small := SelectedModel{
			Provider:        string(p.ID),
			Model:           defaultSmallModel.ID,
			MaxTokens:       defaultSmallModel.DefaultMaxTokens,
			ReasoningEffort: defaultSmallModel.DefaultReasoningEffort,
		}
		return large, small, nil
	}

	// No known provider is configured: fall back to the first enabled
	// provider by ID so the choice is deterministic.
	enabledProviders := c.EnabledProviders()
	slices.SortFunc(enabledProviders, func(a, b ProviderConfig) int {
		return strings.Compare(a.ID, b.ID)
	})

	if len(enabledProviders) == 0 {
		return SelectedModel{}, SelectedModel{}, fmt.Errorf("no providers configured, please configure at least one provider")
	}

	providerConfig := enabledProviders[0]
	if len(providerConfig.Models) == 0 {
		return SelectedModel{}, SelectedModel{}, fmt.Errorf("provider %s has no models configured", providerConfig.ID)
	}
	// The same model serves as both the large and the small default.
	model := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
	if model == nil {
		return SelectedModel{}, SelectedModel{}, fmt.Errorf("default model %s not found for provider %s", providerConfig.Models[0].ID, providerConfig.ID)
	}
	selected := SelectedModel{
		Provider:        providerConfig.ID,
		Model:           model.ID,
		MaxTokens:       model.DefaultMaxTokens,
		ReasoningEffort: model.DefaultReasoningEffort,
	}
	return selected, selected, nil
}
445
// configureSelectedModels resolves the large and small model selections and
// stores them in c.Models.
//
// Each selection starts from the default returned by defaultModelSelection.
// If the user configured a selection of that type, its non-empty Model and
// Provider override the default. When the resulting model exists:
//   - the user's MaxTokens (if > 0) and ReasoningEffort (if non-empty) are
//     used, otherwise that model's own defaults;
//   - Think is copied from the user's selection.
//
// When the configured model does not exist, the default is used and
// persisted via UpdatePreferredModel.
func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error {
	defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders)
	if err != nil {
		return fmt.Errorf("failed to select default models: %w", err)
	}

	large, err := c.mergeSelectedModel(SelectedModelTypeLarge, "large", defaultLarge)
	if err != nil {
		return err
	}
	small, err := c.mergeSelectedModel(SelectedModelTypeSmall, "small", defaultSmall)
	if err != nil {
		return err
	}
	c.Models[SelectedModelTypeLarge] = large
	c.Models[SelectedModelTypeSmall] = small
	return nil
}

// mergeSelectedModel layers the user's selection for modelType (if any) on
// top of fallback and returns the result. label ("large"/"small") is used
// only in error messages.
func (c *Config) mergeSelectedModel(modelType SelectedModelType, label string, fallback SelectedModel) (SelectedModel, error) {
	selected, configured := c.Models[modelType]
	if !configured {
		return fallback, nil
	}

	resolved := fallback
	if selected.Model != "" {
		resolved.Model = selected.Model
	}
	if selected.Provider != "" {
		resolved.Provider = selected.Provider
	}

	model := c.GetModel(resolved.Provider, resolved.Model)
	if model == nil {
		// The configured model does not exist: use the default and persist
		// it as the preferred model of this type.
		if err := c.UpdatePreferredModel(modelType, fallback); err != nil {
			return SelectedModel{}, fmt.Errorf("failed to update preferred %s model: %w", label, err)
		}
		return fallback, nil
	}

	resolved.MaxTokens = model.DefaultMaxTokens
	if selected.MaxTokens > 0 {
		resolved.MaxTokens = selected.MaxTokens
	}
	// Default to the selected model's own effort, not the one inherited
	// from fallback, which may describe a different model.
	resolved.ReasoningEffort = model.DefaultReasoningEffort
	if selected.ReasoningEffort != "" {
		resolved.ReasoningEffort = selected.ReasoningEffort
	}
	resolved.Think = selected.Think
	return resolved, nil
}
512
513func loadFromConfigPaths(configPaths []string) (*Config, error) {
514	var configs []io.Reader
515
516	for _, path := range configPaths {
517		fd, err := os.Open(path)
518		if err != nil {
519			if os.IsNotExist(err) {
520				continue
521			}
522			return nil, fmt.Errorf("failed to open config file %s: %w", path, err)
523		}
524		defer fd.Close()
525
526		configs = append(configs, fd)
527	}
528
529	return loadFromReaders(configs)
530}
531
// loadFromReaders combines readers with Merge and decodes the merged stream
// with LoadReader. With no readers it returns an empty, non-nil Config.
func loadFromReaders(readers []io.Reader) (*Config, error) {
	switch {
	case len(readers) == 0:
		return &Config{}, nil
	default:
		combined, mergeErr := Merge(readers)
		if mergeErr != nil {
			return nil, fmt.Errorf("failed to merge configuration readers: %w", mergeErr)
		}
		return LoadReader(combined)
	}
}
544
545func hasVertexCredentials(env env.Env) bool {
546	hasProject := env.Get("VERTEXAI_PROJECT") != ""
547	hasLocation := env.Get("VERTEXAI_LOCATION") != ""
548	return hasProject && hasLocation
549}
550
551func hasAWSCredentials(env env.Env) bool {
552	if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" {
553		return true
554	}
555
556	if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" {
557		return true
558	}
559
560	if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" {
561		return true
562	}
563
564	if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
565		env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
566		return true
567	}
568	return false
569}
570
// globalConfig returns the path of the user-wide config file, named
// <appName>.json, inside an <appName> directory chosen as follows:
//   - $XDG_CONFIG_HOME/<appName>/ when XDG_CONFIG_HOME is set;
//   - on Windows, %LOCALAPPDATA%/<appName>/, falling back to
//     %USERPROFILE%/AppData/Local/<appName>/ when LOCALAPPDATA is empty;
//   - otherwise <home>/.config/<appName>/, with <home> from home.Dir.
func globalConfig() string {
	fileName := fmt.Sprintf("%s.json", appName)
	if configHome := os.Getenv("XDG_CONFIG_HOME"); configHome != "" {
		return filepath.Join(configHome, appName, fileName)
	}
	if runtime.GOOS != "windows" {
		return filepath.Join(home.Dir(), ".config", appName, fileName)
	}
	base := os.Getenv("LOCALAPPDATA")
	if base == "" {
		base = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
	}
	return filepath.Join(base, appName, fileName)
}
590
// GlobalConfigData returns the path of the <appName>.json file kept in the
// application's data directory. This file is used when the app overrides
// configuration instead of updating the global config.
//
// The directory is chosen as follows:
//   - $XDG_DATA_HOME/<appName>/ when XDG_DATA_HOME is set;
//   - on Windows, %LOCALAPPDATA%/<appName>/, falling back to
//     %USERPROFILE%/AppData/Local/<appName>/ when LOCALAPPDATA is empty;
//   - otherwise <home>/.local/share/<appName>/, with <home> from home.Dir.
//
// NOTE(review): on Windows without XDG_DATA_HOME this yields the same path
// that globalConfig yields without XDG_CONFIG_HOME, so the override file and
// the global config file are the same file there — confirm that is intended.
func GlobalConfigData() string {
	fileName := fmt.Sprintf("%s.json", appName)
	if dataHome := os.Getenv("XDG_DATA_HOME"); dataHome != "" {
		return filepath.Join(dataHome, appName, fileName)
	}
	if runtime.GOOS != "windows" {
		return filepath.Join(home.Dir(), ".local", "share", appName, fileName)
	}
	base := os.Getenv("LOCALAPPDATA")
	if base == "" {
		base = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
	}
	return filepath.Join(base, appName, fileName)
}