load.go

  1package config
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"io"
  7	"os"
  8	"path/filepath"
  9	"runtime"
 10	"slices"
 11	"strings"
 12
 13	"github.com/charmbracelet/crush/internal/env"
 14	"github.com/charmbracelet/crush/internal/fur/client"
 15	"github.com/charmbracelet/crush/internal/fur/provider"
 16	"github.com/charmbracelet/crush/internal/log"
 17	"golang.org/x/exp/slog"
 18)
 19
 20// LoadReader config via io.Reader.
 21func LoadReader(fd io.Reader) (*Config, error) {
 22	data, err := io.ReadAll(fd)
 23	if err != nil {
 24		return nil, err
 25	}
 26
 27	var config Config
 28	err = json.Unmarshal(data, &config)
 29	if err != nil {
 30		return nil, err
 31	}
 32	return &config, err
 33}
 34
 35// Load loads the configuration from the default paths.
 36func Load(workingDir string, debug bool) (*Config, error) {
 37	// uses default config paths
 38	configPaths := []string{
 39		globalConfig(),
 40		GlobalConfigData(),
 41		filepath.Join(workingDir, fmt.Sprintf("%s.json", appName)),
 42		filepath.Join(workingDir, fmt.Sprintf(".%s.json", appName)),
 43	}
 44	cfg, err := loadFromConfigPaths(configPaths)
 45	if err != nil {
 46		return nil, fmt.Errorf("failed to load config from paths %v: %w", configPaths, err)
 47	}
 48
 49	cfg.dataConfigDir = GlobalConfigData()
 50
 51	cfg.setDefaults(workingDir)
 52
 53	if debug {
 54		cfg.Options.Debug = true
 55	}
 56
 57	// Setup logs
 58	log.Setup(
 59		filepath.Join(cfg.Options.DataDirectory, "logs", fmt.Sprintf("%s.log", appName)),
 60		cfg.Options.Debug,
 61	)
 62
 63	// Load known providers, this loads the config from fur
 64	providers, err := LoadProviders(client.New())
 65	if err != nil || len(providers) == 0 {
 66		return nil, fmt.Errorf("failed to load providers: %w", err)
 67	}
 68
 69	env := env.New()
 70	// Configure providers
 71	valueResolver := NewShellVariableResolver(env)
 72	cfg.resolver = valueResolver
 73	if err := cfg.configureProviders(env, valueResolver, providers); err != nil {
 74		return nil, fmt.Errorf("failed to configure providers: %w", err)
 75	}
 76
 77	if !cfg.IsConfigured() {
 78		slog.Warn("No providers configured")
 79		return cfg, nil
 80	}
 81
 82	if err := cfg.configureSelectedModels(providers); err != nil {
 83		return nil, fmt.Errorf("failed to configure selected models: %w", err)
 84	}
 85
 86	// TODO: remove the agents concept from the config
 87	agents := map[string]Agent{
 88		"coder": {
 89			ID:           "coder",
 90			Name:         "Coder",
 91			Description:  "An agent that helps with executing coding tasks.",
 92			Model:        SelectedModelTypeLarge,
 93			ContextPaths: cfg.Options.ContextPaths,
 94			// All tools allowed
 95		},
 96		"task": {
 97			ID:           "task",
 98			Name:         "Task",
 99			Description:  "An agent that helps with searching for context and finding implementation details.",
100			Model:        SelectedModelTypeLarge,
101			ContextPaths: cfg.Options.ContextPaths,
102			AllowedTools: []string{
103				"glob",
104				"grep",
105				"ls",
106				"sourcegraph",
107				"view",
108			},
109			// NO MCPs or LSPs by default
110			AllowedMCP: map[string][]string{},
111			AllowedLSP: []string{},
112		},
113	}
114	cfg.Agents = agents
115
116	return cfg, nil
117}
118
119func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error {
120	knownProviderNames := make(map[string]bool)
121	for _, p := range knownProviders {
122		knownProviderNames[string(p.ID)] = true
123		config, configExists := cfg.Providers[string(p.ID)]
124		// if the user configured a known provider we need to allow it to override a couple of parameters
125		if configExists {
126			if config.Disable {
127				slog.Debug("Skipping provider due to disable flag", "provider", p.ID)
128				delete(cfg.Providers, string(p.ID))
129				continue
130			}
131			if config.BaseURL != "" {
132				p.APIEndpoint = config.BaseURL
133			}
134			if config.APIKey != "" {
135				p.APIKey = config.APIKey
136			}
137			if len(config.Models) > 0 {
138				models := []provider.Model{}
139				seen := make(map[string]bool)
140
141				for _, model := range config.Models {
142					if seen[model.ID] {
143						continue
144					}
145					seen[model.ID] = true
146					if model.Model == "" {
147						model.Model = model.ID
148					}
149					models = append(models, model)
150				}
151				for _, model := range p.Models {
152					if seen[model.ID] {
153						continue
154					}
155					seen[model.ID] = true
156					if model.Model == "" {
157						model.Model = model.ID
158					}
159					models = append(models, model)
160				}
161
162				p.Models = models
163			}
164		}
165		prepared := ProviderConfig{
166			ID:           string(p.ID),
167			Name:         p.Name,
168			BaseURL:      p.APIEndpoint,
169			APIKey:       p.APIKey,
170			Type:         p.Type,
171			Disable:      config.Disable,
172			ExtraHeaders: config.ExtraHeaders,
173			ExtraParams:  make(map[string]string),
174			Models:       p.Models,
175		}
176
177		switch p.ID {
178		// Handle specific providers that require additional configuration
179		case provider.InferenceProviderVertexAI:
180			if !hasVertexCredentials(env) {
181				if configExists {
182					slog.Warn("Skipping Vertex AI provider due to missing credentials")
183					delete(cfg.Providers, string(p.ID))
184				}
185				continue
186			}
187			prepared.ExtraParams["project"] = env.Get("GOOGLE_CLOUD_PROJECT")
188			prepared.ExtraParams["location"] = env.Get("GOOGLE_CLOUD_LOCATION")
189		case provider.InferenceProviderBedrock:
190			if !hasAWSCredentials(env) {
191				if configExists {
192					slog.Warn("Skipping Bedrock provider due to missing AWS credentials")
193					delete(cfg.Providers, string(p.ID))
194				}
195				continue
196			}
197			for _, model := range p.Models {
198				if !strings.HasPrefix(model.ID, "anthropic.") {
199					return fmt.Errorf("bedrock provider only supports anthropic models for now, found: %s", model.ID)
200				}
201			}
202		default:
203			// if the provider api or endpoint are missing we skip them
204			v, err := resolver.ResolveValue(p.APIKey)
205			if v == "" || err != nil {
206				if configExists {
207					slog.Warn("Skipping provider due to missing API key", "provider", p.ID)
208					delete(cfg.Providers, string(p.ID))
209				}
210				continue
211			}
212		}
213		cfg.Providers[string(p.ID)] = prepared
214	}
215
216	// validate the custom providers
217	for id, providerConfig := range cfg.Providers {
218		if knownProviderNames[id] {
219			continue
220		}
221
222		// Make sure the provider ID is set
223		providerConfig.ID = id
224		if providerConfig.Name == "" {
225			providerConfig.Name = id // Use ID as name if not set
226		}
227		// default to OpenAI if not set
228		if providerConfig.Type == "" {
229			providerConfig.Type = provider.TypeOpenAI
230		}
231
232		if providerConfig.Disable {
233			slog.Debug("Skipping custom provider due to disable flag", "provider", id)
234			delete(cfg.Providers, id)
235			continue
236		}
237		if providerConfig.APIKey == "" {
238			slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id)
239		}
240		if providerConfig.BaseURL == "" {
241			slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id)
242			delete(cfg.Providers, id)
243			continue
244		}
245		if len(providerConfig.Models) == 0 {
246			slog.Warn("Skipping custom provider because the provider has no models", "provider", id)
247			delete(cfg.Providers, id)
248			continue
249		}
250		if providerConfig.Type != provider.TypeOpenAI {
251			slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
252			delete(cfg.Providers, id)
253			continue
254		}
255
256		apiKey, err := resolver.ResolveValue(providerConfig.APIKey)
257		if apiKey == "" || err != nil {
258			slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id)
259		}
260		baseURL, err := resolver.ResolveValue(providerConfig.BaseURL)
261		if baseURL == "" || err != nil {
262			slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err)
263			delete(cfg.Providers, id)
264			continue
265		}
266
267		cfg.Providers[id] = providerConfig
268	}
269	return nil
270}
271
272func (cfg *Config) setDefaults(workingDir string) {
273	cfg.workingDir = workingDir
274	if cfg.Options == nil {
275		cfg.Options = &Options{}
276	}
277	if cfg.Options.TUI == nil {
278		cfg.Options.TUI = &TUIOptions{}
279	}
280	if cfg.Options.ContextPaths == nil {
281		cfg.Options.ContextPaths = []string{}
282	}
283	if cfg.Options.DataDirectory == "" {
284		cfg.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
285	}
286	if cfg.Providers == nil {
287		cfg.Providers = make(map[string]ProviderConfig)
288	}
289	if cfg.Models == nil {
290		cfg.Models = make(map[SelectedModelType]SelectedModel)
291	}
292	if cfg.MCP == nil {
293		cfg.MCP = make(map[string]MCPConfig)
294	}
295	if cfg.LSP == nil {
296		cfg.LSP = make(map[string]LSPConfig)
297	}
298
299	// Add the default context paths if they are not already present
300	cfg.Options.ContextPaths = append(defaultContextPaths, cfg.Options.ContextPaths...)
301	slices.Sort(cfg.Options.ContextPaths)
302	cfg.Options.ContextPaths = slices.Compact(cfg.Options.ContextPaths)
303}
304
305func (cfg *Config) defaultModelSelection(knownProviders []provider.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
306	if len(knownProviders) == 0 && len(cfg.Providers) == 0 {
307		err = fmt.Errorf("no providers configured, please configure at least one provider")
308		return
309	}
310
311	// Use the first provider enabled based on the known providers order
312	// if no provider found that is known use the first provider configured
313	for _, p := range knownProviders {
314		providerConfig, ok := cfg.Providers[string(p.ID)]
315		if !ok || providerConfig.Disable {
316			continue
317		}
318		defaultLargeModel := cfg.GetModel(string(p.ID), p.DefaultLargeModelID)
319		if defaultLargeModel == nil {
320			err = fmt.Errorf("default large model %s not found for provider %s", p.DefaultLargeModelID, p.ID)
321			return
322		}
323		largeModel = SelectedModel{
324			Provider:        string(p.ID),
325			Model:           defaultLargeModel.ID,
326			MaxTokens:       defaultLargeModel.DefaultMaxTokens,
327			ReasoningEffort: defaultLargeModel.DefaultReasoningEffort,
328		}
329
330		defaultSmallModel := cfg.GetModel(string(p.ID), p.DefaultSmallModelID)
331		if defaultSmallModel == nil {
332			err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID)
333		}
334		smallModel = SelectedModel{
335			Provider:        string(p.ID),
336			Model:           defaultSmallModel.ID,
337			MaxTokens:       defaultSmallModel.DefaultMaxTokens,
338			ReasoningEffort: defaultSmallModel.DefaultReasoningEffort,
339		}
340		return
341	}
342
343	enabledProviders := cfg.EnabledProviders()
344	slices.SortFunc(enabledProviders, func(a, b ProviderConfig) int {
345		return strings.Compare(a.ID, b.ID)
346	})
347
348	if len(enabledProviders) == 0 {
349		err = fmt.Errorf("no providers configured, please configure at least one provider")
350		return
351	}
352
353	providerConfig := enabledProviders[0]
354	if len(providerConfig.Models) == 0 {
355		err = fmt.Errorf("provider %s has no models configured", providerConfig.ID)
356		return
357	}
358	defaultLargeModel := cfg.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
359	largeModel = SelectedModel{
360		Provider:  providerConfig.ID,
361		Model:     defaultLargeModel.ID,
362		MaxTokens: defaultLargeModel.DefaultMaxTokens,
363	}
364	defaultSmallModel := cfg.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
365	smallModel = SelectedModel{
366		Provider:  providerConfig.ID,
367		Model:     defaultSmallModel.ID,
368		MaxTokens: defaultSmallModel.DefaultMaxTokens,
369	}
370	return
371}
372
373func (cfg *Config) configureSelectedModels(knownProviders []provider.Provider) error {
374	defaultLarge, defaultSmall, err := cfg.defaultModelSelection(knownProviders)
375	if err != nil {
376		return fmt.Errorf("failed to select default models: %w", err)
377	}
378	large, small := defaultLarge, defaultSmall
379
380	largeModelSelected, largeModelConfigured := cfg.Models[SelectedModelTypeLarge]
381	if largeModelConfigured {
382		if largeModelSelected.Model != "" {
383			large.Model = largeModelSelected.Model
384		}
385		if largeModelSelected.Provider != "" {
386			large.Provider = largeModelSelected.Provider
387		}
388		model := cfg.GetModel(large.Provider, large.Model)
389		slog.Info("Configuring selected large model", "provider", large.Provider, "model", large.Model)
390		slog.Info("MOdel configured", "model", model)
391		if model == nil {
392			large = defaultLarge
393			// override the model type to large
394			err := cfg.UpdatePreferredModel(SelectedModelTypeLarge, large)
395			if err != nil {
396				return fmt.Errorf("failed to update preferred large model: %w", err)
397			}
398		} else {
399			if largeModelSelected.MaxTokens > 0 {
400				large.MaxTokens = largeModelSelected.MaxTokens
401			} else {
402				large.MaxTokens = model.DefaultMaxTokens
403			}
404			if largeModelSelected.ReasoningEffort != "" {
405				large.ReasoningEffort = largeModelSelected.ReasoningEffort
406			}
407			large.Think = largeModelSelected.Think
408		}
409	}
410	smallModelSelected, smallModelConfigured := cfg.Models[SelectedModelTypeSmall]
411	if smallModelConfigured {
412		if smallModelSelected.Model != "" {
413			small.Model = smallModelSelected.Model
414		}
415		if smallModelSelected.Provider != "" {
416			small.Provider = smallModelSelected.Provider
417		}
418
419		model := cfg.GetModel(small.Provider, small.Model)
420		if model == nil {
421			small = defaultSmall
422			// override the model type to small
423			err := cfg.UpdatePreferredModel(SelectedModelTypeSmall, small)
424			if err != nil {
425				return fmt.Errorf("failed to update preferred small model: %w", err)
426			}
427		} else {
428			if smallModelSelected.MaxTokens > 0 {
429				small.MaxTokens = smallModelSelected.MaxTokens
430			} else {
431				small.MaxTokens = model.DefaultMaxTokens
432			}
433			small.ReasoningEffort = smallModelSelected.ReasoningEffort
434			small.Think = smallModelSelected.Think
435		}
436	}
437	cfg.Models[SelectedModelTypeLarge] = large
438	cfg.Models[SelectedModelTypeSmall] = small
439	return nil
440}
441
442func loadFromConfigPaths(configPaths []string) (*Config, error) {
443	var configs []io.Reader
444
445	for _, path := range configPaths {
446		fd, err := os.Open(path)
447		if err != nil {
448			if os.IsNotExist(err) {
449				continue
450			}
451			return nil, fmt.Errorf("failed to open config file %s: %w", path, err)
452		}
453		defer fd.Close()
454
455		configs = append(configs, fd)
456	}
457
458	return loadFromReaders(configs)
459}
460
461func loadFromReaders(readers []io.Reader) (*Config, error) {
462	if len(readers) == 0 {
463		return &Config{}, nil
464	}
465
466	merged, err := Merge(readers)
467	if err != nil {
468		return nil, fmt.Errorf("failed to merge configuration readers: %w", err)
469	}
470
471	return LoadReader(merged)
472}
473
474func hasVertexCredentials(env env.Env) bool {
475	useVertex := env.Get("GOOGLE_GENAI_USE_VERTEXAI") == "true"
476	hasProject := env.Get("GOOGLE_CLOUD_PROJECT") != ""
477	hasLocation := env.Get("GOOGLE_CLOUD_LOCATION") != ""
478	return useVertex && hasProject && hasLocation
479}
480
481func hasAWSCredentials(env env.Env) bool {
482	if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" {
483		return true
484	}
485
486	if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" {
487		return true
488	}
489
490	if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" {
491		return true
492	}
493
494	if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
495		env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
496		return true
497	}
498
499	return false
500}
501
502func globalConfig() string {
503	xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
504	if xdgConfigHome != "" {
505		return filepath.Join(xdgConfigHome, appName, fmt.Sprintf("%s.json", appName))
506	}
507
508	// return the path to the main config directory
509	// for windows, it should be in `%LOCALAPPDATA%/crush/`
510	// for linux and macOS, it should be in `$HOME/.config/crush/`
511	if runtime.GOOS == "windows" {
512		localAppData := os.Getenv("LOCALAPPDATA")
513		if localAppData == "" {
514			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
515		}
516		return filepath.Join(localAppData, appName, fmt.Sprintf("%s.json", appName))
517	}
518
519	return filepath.Join(os.Getenv("HOME"), ".config", appName, fmt.Sprintf("%s.json", appName))
520}
521
522// GlobalConfigData returns the path to the main data directory for the application.
523// this config is used when the app overrides configurations instead of updating the global config.
524func GlobalConfigData() string {
525	xdgDataHome := os.Getenv("XDG_DATA_HOME")
526	if xdgDataHome != "" {
527		return filepath.Join(xdgDataHome, appName, fmt.Sprintf("%s.json", appName))
528	}
529
530	// return the path to the main data directory
531	// for windows, it should be in `%LOCALAPPDATA%/crush/`
532	// for linux and macOS, it should be in `$HOME/.local/share/crush/`
533	if runtime.GOOS == "windows" {
534		localAppData := os.Getenv("LOCALAPPDATA")
535		if localAppData == "" {
536			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
537		}
538		return filepath.Join(localAppData, appName, fmt.Sprintf("%s.json", appName))
539	}
540
541	return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, fmt.Sprintf("%s.json", appName))
542}
543
// HomeDir returns the user's home directory as read from the environment. It
// returns the first non-empty value of $HOME, %USERPROFILE% (Windows) or
// %HOMEPATH%, and "" when none is set.
func HomeDir() string {
	for _, key := range []string{"HOME", "USERPROFILE", "HOMEPATH"} {
		if dir := os.Getenv(key); dir != "" {
			return dir
		}
	}
	return ""
}