load.go

  1package config
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"io"
  7	"os"
  8	"path/filepath"
  9	"runtime"
 10	"slices"
 11	"strings"
 12
 13	"github.com/charmbracelet/crush/internal/fur/client"
 14	"github.com/charmbracelet/crush/internal/fur/provider"
 15	"github.com/charmbracelet/crush/pkg/env"
 16	"github.com/charmbracelet/crush/pkg/log"
 17)
 18
 19// LoadReader config via io.Reader.
 20func LoadReader(fd io.Reader) (*Config, error) {
 21	data, err := io.ReadAll(fd)
 22	if err != nil {
 23		return nil, err
 24	}
 25
 26	var config Config
 27	err = json.Unmarshal(data, &config)
 28	if err != nil {
 29		return nil, err
 30	}
 31	return &config, err
 32}
 33
 34// Load loads the configuration from the default paths.
 35func Load(workingDir string, debug bool) (*Config, error) {
 36	// uses default config paths
 37	configPaths := []string{
 38		globalConfig(),
 39		globalConfigData(),
 40		filepath.Join(workingDir, fmt.Sprintf("%s.json", appName)),
 41		filepath.Join(workingDir, fmt.Sprintf(".%s.json", appName)),
 42	}
 43	cfg, err := loadFromConfigPaths(configPaths)
 44
 45	if debug {
 46		cfg.Options.Debug = true
 47	}
 48
 49	// Init logs
 50	log.Init(
 51		filepath.Join(cfg.Options.DataDirectory, "logs", fmt.Sprintf("%s.log", appName)),
 52		cfg.Options.Debug,
 53	)
 54
 55	if err != nil {
 56		return nil, fmt.Errorf("failed to load config: %w", err)
 57	}
 58	// TODO: maybe add a validation step here right after loading
 59	// e.x validate the models
 60	// e.x validate provider config
 61
 62	cfg.setDefaults(workingDir)
 63
 64	// Load known providers, this loads the config from fur
 65	providers, err := LoadProviders(client.New())
 66	if err != nil {
 67		return nil, fmt.Errorf("failed to load providers: %w", err)
 68	}
 69
 70	env := env.New()
 71	// Configure providers
 72	valueResolver := NewShellVariableResolver(env)
 73	if err := cfg.configureProviders(env, valueResolver, providers); err != nil {
 74		return nil, fmt.Errorf("failed to configure providers: %w", err)
 75	}
 76
 77	return cfg, nil
 78}
 79
 80func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error {
 81	for _, p := range knownProviders {
 82
 83		config, ok := cfg.Providers[string(p.ID)]
 84		// if the user configured a known provider we need to allow it to override a couple of parameters
 85		if ok {
 86			if config.BaseURL != "" {
 87				p.APIEndpoint = config.BaseURL
 88			}
 89			if config.APIKey != "" {
 90				p.APIKey = config.APIKey
 91			}
 92			if len(config.Models) > 0 {
 93				models := []provider.Model{}
 94				seen := make(map[string]bool)
 95
 96				for _, model := range config.Models {
 97					if seen[model.ID] {
 98						continue
 99					}
100					seen[model.ID] = true
101					models = append(models, model)
102				}
103				for _, model := range p.Models {
104					if seen[model.ID] {
105						continue
106					}
107					seen[model.ID] = true
108					models = append(models, model)
109				}
110
111				p.Models = models
112			}
113		}
114		prepared := ProviderConfig{
115			BaseURL:      p.APIEndpoint,
116			APIKey:       p.APIKey,
117			Type:         p.Type,
118			Disable:      config.Disable,
119			ExtraHeaders: config.ExtraHeaders,
120			ExtraParams:  make(map[string]string),
121			Models:       p.Models,
122		}
123
124		switch p.ID {
125		// Handle specific providers that require additional configuration
126		case provider.InferenceProviderVertexAI:
127			if !hasVertexCredentials(env) {
128				continue
129			}
130			prepared.ExtraParams["project"] = env.Get("GOOGLE_CLOUD_PROJECT")
131			prepared.ExtraParams["location"] = env.Get("GOOGLE_CLOUD_LOCATION")
132		case provider.InferenceProviderBedrock:
133			if !hasAWSCredentials(env) {
134				continue
135			}
136			for _, model := range p.Models {
137				if !strings.HasPrefix(model.ID, "anthropic.") {
138					return fmt.Errorf("bedrock provider only supports anthropic models for now, found: %s", model.ID)
139				}
140			}
141		default:
142			// if the provider api or endpoint are missing we skip them
143			v, err := resolver.ResolveValue(p.APIKey)
144			if v == "" || err != nil {
145				continue
146			}
147			v, err = resolver.ResolveValue(p.APIEndpoint)
148			if v == "" || err != nil {
149				continue
150			}
151		}
152		cfg.Providers[string(p.ID)] = prepared
153	}
154	return nil
155}
156
157func hasVertexCredentials(env env.Env) bool {
158	useVertex := env.Get("GOOGLE_GENAI_USE_VERTEXAI") == "true"
159	hasProject := env.Get("GOOGLE_CLOUD_PROJECT") != ""
160	hasLocation := env.Get("GOOGLE_CLOUD_LOCATION") != ""
161	return useVertex && hasProject && hasLocation
162}
163
164func hasAWSCredentials(env env.Env) bool {
165	if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" {
166		return true
167	}
168
169	if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" {
170		return true
171	}
172
173	if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" {
174		return true
175	}
176
177	if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
178		env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
179		return true
180	}
181
182	return false
183}
184
185func (cfg *Config) setDefaults(workingDir string) {
186	cfg.workingDir = workingDir
187	if cfg.Options == nil {
188		cfg.Options = &Options{}
189	}
190	if cfg.Options.TUI == nil {
191		cfg.Options.TUI = &TUIOptions{}
192	}
193	if cfg.Options.ContextPaths == nil {
194		cfg.Options.ContextPaths = []string{}
195	}
196	if cfg.Options.DataDirectory == "" {
197		cfg.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
198	}
199	if cfg.Providers == nil {
200		cfg.Providers = make(map[string]ProviderConfig)
201	}
202	if cfg.Models == nil {
203		cfg.Models = make(map[string]SelectedModel)
204	}
205	if cfg.MCP == nil {
206		cfg.MCP = make(map[string]MCPConfig)
207	}
208	if cfg.LSP == nil {
209		cfg.LSP = make(map[string]LSPConfig)
210	}
211
212	// Add the default context paths if they are not already present
213	cfg.Options.ContextPaths = append(defaultContextPaths, cfg.Options.ContextPaths...)
214	slices.Sort(cfg.Options.ContextPaths)
215	cfg.Options.ContextPaths = slices.Compact(cfg.Options.ContextPaths)
216}
217
218func loadFromConfigPaths(configPaths []string) (*Config, error) {
219	var configs []io.Reader
220
221	for _, path := range configPaths {
222		fd, err := os.Open(path)
223		if err != nil {
224			if os.IsNotExist(err) {
225				continue
226			}
227			return nil, fmt.Errorf("failed to open config file %s: %w", path, err)
228		}
229		defer fd.Close()
230
231		configs = append(configs, fd)
232	}
233
234	return loadFromReaders(configs)
235}
236
237func loadFromReaders(readers []io.Reader) (*Config, error) {
238	if len(readers) == 0 {
239		return nil, fmt.Errorf("no configuration readers provided")
240	}
241
242	merged, err := Merge(readers)
243	if err != nil {
244		return nil, fmt.Errorf("failed to merge configuration readers: %w", err)
245	}
246
247	return LoadReader(merged)
248}
249
250func globalConfig() string {
251	xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
252	if xdgConfigHome != "" {
253		return filepath.Join(xdgConfigHome, "crush")
254	}
255
256	// return the path to the main config directory
257	// for windows, it should be in `%LOCALAPPDATA%/crush/`
258	// for linux and macOS, it should be in `$HOME/.config/crush/`
259	if runtime.GOOS == "windows" {
260		localAppData := os.Getenv("LOCALAPPDATA")
261		if localAppData == "" {
262			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
263		}
264		return filepath.Join(localAppData, appName)
265	}
266
267	return filepath.Join(os.Getenv("HOME"), ".config", appName, fmt.Sprintf("%s.json", appName))
268}
269
270// globalConfigData returns the path to the main data directory for the application.
271// this config is used when the app overrides configurations instead of updating the global config.
272func globalConfigData() string {
273	xdgDataHome := os.Getenv("XDG_DATA_HOME")
274	if xdgDataHome != "" {
275		return filepath.Join(xdgDataHome, appName)
276	}
277
278	// return the path to the main data directory
279	// for windows, it should be in `%LOCALAPPDATA%/crush/`
280	// for linux and macOS, it should be in `$HOME/.local/share/crush/`
281	if runtime.GOOS == "windows" {
282		localAppData := os.Getenv("LOCALAPPDATA")
283		if localAppData == "" {
284			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
285		}
286		return filepath.Join(localAppData, appName)
287	}
288
289	return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, fmt.Sprintf("%s.json", appName))
290}