load.go

  1package config
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"io"
  7	"os"
  8	"path/filepath"
  9	"runtime"
 10	"slices"
 11	"strings"
 12
 13	"github.com/charmbracelet/crush/internal/fur/client"
 14	"github.com/charmbracelet/crush/internal/fur/provider"
 15	"github.com/charmbracelet/crush/pkg/env"
 16)
 17
 18// LoadReader config via io.Reader.
 19func LoadReader(fd io.Reader) (*Config, error) {
 20	data, err := io.ReadAll(fd)
 21	if err != nil {
 22		return nil, err
 23	}
 24
 25	var config Config
 26	err = json.Unmarshal(data, &config)
 27	if err != nil {
 28		return nil, err
 29	}
 30	return &config, err
 31}
 32
 33// Load loads the configuration from the default paths.
 34func Load(workingDir string, env env.Env) (*Config, error) {
 35	// uses default config paths
 36	configPaths := []string{
 37		globalConfig(),
 38		globalConfigData(),
 39		filepath.Join(workingDir, fmt.Sprintf("%s.json", appName)),
 40		filepath.Join(workingDir, fmt.Sprintf(".%s.json", appName)),
 41	}
 42	cfg, err := loadFromConfigPaths(configPaths)
 43	if err != nil {
 44		return nil, fmt.Errorf("failed to load config: %w", err)
 45	}
 46	// TODO: maybe add a validation step here right after loading
 47	// e.x validate the models
 48	// e.x validate provider config
 49
 50	setDefaults(workingDir, cfg)
 51
 52	// Load known providers, this loads the config from fur
 53	providers, err := LoadProviders(client.New())
 54	if err != nil {
 55		return nil, fmt.Errorf("failed to load providers: %w", err)
 56	}
 57
 58	// Configure providers
 59	valueResolver := NewShellVariableResolver(env)
 60	if err := configureProviders(cfg, env, valueResolver, providers); err != nil {
 61		return nil, fmt.Errorf("failed to configure providers: %w", err)
 62	}
 63
 64	return cfg, nil
 65}
 66
 67func configureProviders(cfg *Config, env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error {
 68	for _, p := range knownProviders {
 69
 70		config, ok := cfg.Providers[string(p.ID)]
 71		// if the user configured a known provider we need to allow it to override a couple of parameters
 72		if ok {
 73			if config.BaseURL != "" {
 74				p.APIEndpoint = config.BaseURL
 75			}
 76			if config.APIKey != "" {
 77				p.APIKey = config.APIKey
 78			}
 79			if len(config.Models) > 0 {
 80				models := []provider.Model{}
 81				seen := make(map[string]bool)
 82
 83				for _, model := range config.Models {
 84					if seen[model.ID] {
 85						continue
 86					}
 87					seen[model.ID] = true
 88					models = append(models, model)
 89				}
 90				for _, model := range p.Models {
 91					if seen[model.ID] {
 92						continue
 93					}
 94					seen[model.ID] = true
 95					models = append(models, model)
 96				}
 97
 98				p.Models = models
 99			}
100		}
101		prepared := ProviderConfig{
102			BaseURL:      p.APIEndpoint,
103			APIKey:       p.APIKey,
104			Type:         p.Type,
105			Disable:      config.Disable,
106			ExtraHeaders: config.ExtraHeaders,
107			ExtraParams:  make(map[string]string),
108			Models:       p.Models,
109		}
110
111		switch p.ID {
112		// Handle specific providers that require additional configuration
113		case provider.InferenceProviderVertexAI:
114			if !hasVertexCredentials(env) {
115				continue
116			}
117			prepared.ExtraParams["project"] = env.Get("GOOGLE_CLOUD_PROJECT")
118			prepared.ExtraParams["location"] = env.Get("GOOGLE_CLOUD_LOCATION")
119		case provider.InferenceProviderBedrock:
120			if !hasAWSCredentials(env) {
121				continue
122			}
123			for _, model := range p.Models {
124				if !strings.HasPrefix(model.ID, "anthropic.") {
125					return fmt.Errorf("bedrock provider only supports anthropic models for now, found: %s", model.ID)
126				}
127			}
128		default:
129			// if the provider api or endpoint are missing we skip them
130			v, err := resolver.ResolveValue(p.APIKey)
131			if v == "" || err != nil {
132				continue
133			}
134			v, err = resolver.ResolveValue(p.APIEndpoint)
135			if v == "" || err != nil {
136				continue
137			}
138		}
139		cfg.Providers[string(p.ID)] = prepared
140	}
141	return nil
142}
143
144func hasVertexCredentials(env env.Env) bool {
145	useVertex := env.Get("GOOGLE_GENAI_USE_VERTEXAI") == "true"
146	hasProject := env.Get("GOOGLE_CLOUD_PROJECT") != ""
147	hasLocation := env.Get("GOOGLE_CLOUD_LOCATION") != ""
148	return useVertex && hasProject && hasLocation
149}
150
151func hasAWSCredentials(env env.Env) bool {
152	if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" {
153		return true
154	}
155
156	if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" {
157		return true
158	}
159
160	if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" {
161		return true
162	}
163
164	if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
165		env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
166		return true
167	}
168
169	return false
170}
171
172func setDefaults(workingDir string, cfg *Config) {
173	cfg.workingDir = workingDir
174	if cfg.Options == nil {
175		cfg.Options = &Options{}
176	}
177	if cfg.Options.TUI == nil {
178		cfg.Options.TUI = &TUIOptions{}
179	}
180	if cfg.Options.ContextPaths == nil {
181		cfg.Options.ContextPaths = []string{}
182	}
183	if cfg.Options.DataDirectory == "" {
184		cfg.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
185	}
186	if cfg.Providers == nil {
187		cfg.Providers = make(map[string]ProviderConfig)
188	}
189	if cfg.Models == nil {
190		cfg.Models = make(map[string]SelectedModel)
191	}
192	if cfg.MCP == nil {
193		cfg.MCP = make(map[string]MCP)
194	}
195	if cfg.LSP == nil {
196		cfg.LSP = make(map[string]LSPConfig)
197	}
198
199	// Add the default context paths if they are not already present
200	cfg.Options.ContextPaths = append(defaultContextPaths, cfg.Options.ContextPaths...)
201	slices.Sort(cfg.Options.ContextPaths)
202	cfg.Options.ContextPaths = slices.Compact(cfg.Options.ContextPaths)
203}
204
205func loadFromConfigPaths(configPaths []string) (*Config, error) {
206	var configs []io.Reader
207
208	for _, path := range configPaths {
209		fd, err := os.Open(path)
210		if err != nil {
211			if os.IsNotExist(err) {
212				continue
213			}
214			return nil, fmt.Errorf("failed to open config file %s: %w", path, err)
215		}
216		defer fd.Close()
217
218		configs = append(configs, fd)
219	}
220
221	return loadFromReaders(configs)
222}
223
224func loadFromReaders(readers []io.Reader) (*Config, error) {
225	if len(readers) == 0 {
226		return nil, fmt.Errorf("no configuration readers provided")
227	}
228
229	merged, err := Merge(readers)
230	if err != nil {
231		return nil, fmt.Errorf("failed to merge configuration readers: %w", err)
232	}
233
234	return LoadReader(merged)
235}
236
237func globalConfig() string {
238	xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
239	if xdgConfigHome != "" {
240		return filepath.Join(xdgConfigHome, "crush")
241	}
242
243	// return the path to the main config directory
244	// for windows, it should be in `%LOCALAPPDATA%/crush/`
245	// for linux and macOS, it should be in `$HOME/.config/crush/`
246	if runtime.GOOS == "windows" {
247		localAppData := os.Getenv("LOCALAPPDATA")
248		if localAppData == "" {
249			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
250		}
251		return filepath.Join(localAppData, appName)
252	}
253
254	return filepath.Join(os.Getenv("HOME"), ".config", appName, fmt.Sprintf("%s.json", appName))
255}
256
257// globalConfigData returns the path to the main data directory for the application.
258// this config is used when the app overrides configurations instead of updating the global config.
259func globalConfigData() string {
260	xdgDataHome := os.Getenv("XDG_DATA_HOME")
261	if xdgDataHome != "" {
262		return filepath.Join(xdgDataHome, appName)
263	}
264
265	// return the path to the main data directory
266	// for windows, it should be in `%LOCALAPPDATA%/crush/`
267	// for linux and macOS, it should be in `$HOME/.local/share/crush/`
268	if runtime.GOOS == "windows" {
269		localAppData := os.Getenv("LOCALAPPDATA")
270		if localAppData == "" {
271			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
272		}
273		return filepath.Join(localAppData, appName)
274	}
275
276	return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, fmt.Sprintf("%s.json", appName))
277}