1package config
2
3import (
4 "encoding/json"
5 "fmt"
6 "io"
7 "os"
8 "path/filepath"
9 "runtime"
10 "slices"
11 "strings"
12
13 "github.com/charmbracelet/crush/internal/fur/client"
14 "github.com/charmbracelet/crush/internal/fur/provider"
15 "github.com/charmbracelet/crush/pkg/env"
16 "github.com/charmbracelet/crush/pkg/log"
17)
18
19// LoadReader config via io.Reader.
20func LoadReader(fd io.Reader) (*Config, error) {
21 data, err := io.ReadAll(fd)
22 if err != nil {
23 return nil, err
24 }
25
26 var config Config
27 err = json.Unmarshal(data, &config)
28 if err != nil {
29 return nil, err
30 }
31 return &config, err
32}
33
34// Load loads the configuration from the default paths.
35func Load(workingDir string, debug bool) (*Config, error) {
36 // uses default config paths
37 configPaths := []string{
38 globalConfig(),
39 globalConfigData(),
40 filepath.Join(workingDir, fmt.Sprintf("%s.json", appName)),
41 filepath.Join(workingDir, fmt.Sprintf(".%s.json", appName)),
42 }
43 cfg, err := loadFromConfigPaths(configPaths)
44
45 if debug {
46 cfg.Options.Debug = true
47 }
48
49 // Init logs
50 log.Init(
51 filepath.Join(cfg.Options.DataDirectory, "logs", fmt.Sprintf("%s.log", appName)),
52 cfg.Options.Debug,
53 )
54
55 if err != nil {
56 return nil, fmt.Errorf("failed to load config: %w", err)
57 }
58 // TODO: maybe add a validation step here right after loading
59 // e.x validate the models
60 // e.x validate provider config
61
62 cfg.setDefaults(workingDir)
63
64 // Load known providers, this loads the config from fur
65 providers, err := LoadProviders(client.New())
66 if err != nil {
67 return nil, fmt.Errorf("failed to load providers: %w", err)
68 }
69
70 env := env.New()
71 // Configure providers
72 valueResolver := NewShellVariableResolver(env)
73 if err := cfg.configureProviders(env, valueResolver, providers); err != nil {
74 return nil, fmt.Errorf("failed to configure providers: %w", err)
75 }
76
77 return cfg, nil
78}
79
80func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error {
81 for _, p := range knownProviders {
82
83 config, ok := cfg.Providers[string(p.ID)]
84 // if the user configured a known provider we need to allow it to override a couple of parameters
85 if ok {
86 if config.BaseURL != "" {
87 p.APIEndpoint = config.BaseURL
88 }
89 if config.APIKey != "" {
90 p.APIKey = config.APIKey
91 }
92 if len(config.Models) > 0 {
93 models := []provider.Model{}
94 seen := make(map[string]bool)
95
96 for _, model := range config.Models {
97 if seen[model.ID] {
98 continue
99 }
100 seen[model.ID] = true
101 models = append(models, model)
102 }
103 for _, model := range p.Models {
104 if seen[model.ID] {
105 continue
106 }
107 seen[model.ID] = true
108 models = append(models, model)
109 }
110
111 p.Models = models
112 }
113 }
114 prepared := ProviderConfig{
115 BaseURL: p.APIEndpoint,
116 APIKey: p.APIKey,
117 Type: p.Type,
118 Disable: config.Disable,
119 ExtraHeaders: config.ExtraHeaders,
120 ExtraParams: make(map[string]string),
121 Models: p.Models,
122 }
123
124 switch p.ID {
125 // Handle specific providers that require additional configuration
126 case provider.InferenceProviderVertexAI:
127 if !hasVertexCredentials(env) {
128 continue
129 }
130 prepared.ExtraParams["project"] = env.Get("GOOGLE_CLOUD_PROJECT")
131 prepared.ExtraParams["location"] = env.Get("GOOGLE_CLOUD_LOCATION")
132 case provider.InferenceProviderBedrock:
133 if !hasAWSCredentials(env) {
134 continue
135 }
136 for _, model := range p.Models {
137 if !strings.HasPrefix(model.ID, "anthropic.") {
138 return fmt.Errorf("bedrock provider only supports anthropic models for now, found: %s", model.ID)
139 }
140 }
141 default:
142 // if the provider api or endpoint are missing we skip them
143 v, err := resolver.ResolveValue(p.APIKey)
144 if v == "" || err != nil {
145 continue
146 }
147 v, err = resolver.ResolveValue(p.APIEndpoint)
148 if v == "" || err != nil {
149 continue
150 }
151 }
152 cfg.Providers[string(p.ID)] = prepared
153 }
154 return nil
155}
156
157func hasVertexCredentials(env env.Env) bool {
158 useVertex := env.Get("GOOGLE_GENAI_USE_VERTEXAI") == "true"
159 hasProject := env.Get("GOOGLE_CLOUD_PROJECT") != ""
160 hasLocation := env.Get("GOOGLE_CLOUD_LOCATION") != ""
161 return useVertex && hasProject && hasLocation
162}
163
164func hasAWSCredentials(env env.Env) bool {
165 if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" {
166 return true
167 }
168
169 if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" {
170 return true
171 }
172
173 if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" {
174 return true
175 }
176
177 if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
178 env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
179 return true
180 }
181
182 return false
183}
184
185func (cfg *Config) setDefaults(workingDir string) {
186 cfg.workingDir = workingDir
187 if cfg.Options == nil {
188 cfg.Options = &Options{}
189 }
190 if cfg.Options.TUI == nil {
191 cfg.Options.TUI = &TUIOptions{}
192 }
193 if cfg.Options.ContextPaths == nil {
194 cfg.Options.ContextPaths = []string{}
195 }
196 if cfg.Options.DataDirectory == "" {
197 cfg.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
198 }
199 if cfg.Providers == nil {
200 cfg.Providers = make(map[string]ProviderConfig)
201 }
202 if cfg.Models == nil {
203 cfg.Models = make(map[string]SelectedModel)
204 }
205 if cfg.MCP == nil {
206 cfg.MCP = make(map[string]MCPConfig)
207 }
208 if cfg.LSP == nil {
209 cfg.LSP = make(map[string]LSPConfig)
210 }
211
212 // Add the default context paths if they are not already present
213 cfg.Options.ContextPaths = append(defaultContextPaths, cfg.Options.ContextPaths...)
214 slices.Sort(cfg.Options.ContextPaths)
215 cfg.Options.ContextPaths = slices.Compact(cfg.Options.ContextPaths)
216}
217
218func loadFromConfigPaths(configPaths []string) (*Config, error) {
219 var configs []io.Reader
220
221 for _, path := range configPaths {
222 fd, err := os.Open(path)
223 if err != nil {
224 if os.IsNotExist(err) {
225 continue
226 }
227 return nil, fmt.Errorf("failed to open config file %s: %w", path, err)
228 }
229 defer fd.Close()
230
231 configs = append(configs, fd)
232 }
233
234 return loadFromReaders(configs)
235}
236
237func loadFromReaders(readers []io.Reader) (*Config, error) {
238 if len(readers) == 0 {
239 return nil, fmt.Errorf("no configuration readers provided")
240 }
241
242 merged, err := Merge(readers)
243 if err != nil {
244 return nil, fmt.Errorf("failed to merge configuration readers: %w", err)
245 }
246
247 return LoadReader(merged)
248}
249
250func globalConfig() string {
251 xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
252 if xdgConfigHome != "" {
253 return filepath.Join(xdgConfigHome, "crush")
254 }
255
256 // return the path to the main config directory
257 // for windows, it should be in `%LOCALAPPDATA%/crush/`
258 // for linux and macOS, it should be in `$HOME/.config/crush/`
259 if runtime.GOOS == "windows" {
260 localAppData := os.Getenv("LOCALAPPDATA")
261 if localAppData == "" {
262 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
263 }
264 return filepath.Join(localAppData, appName)
265 }
266
267 return filepath.Join(os.Getenv("HOME"), ".config", appName, fmt.Sprintf("%s.json", appName))
268}
269
270// globalConfigData returns the path to the main data directory for the application.
271// this config is used when the app overrides configurations instead of updating the global config.
272func globalConfigData() string {
273 xdgDataHome := os.Getenv("XDG_DATA_HOME")
274 if xdgDataHome != "" {
275 return filepath.Join(xdgDataHome, appName)
276 }
277
278 // return the path to the main data directory
279 // for windows, it should be in `%LOCALAPPDATA%/crush/`
280 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
281 if runtime.GOOS == "windows" {
282 localAppData := os.Getenv("LOCALAPPDATA")
283 if localAppData == "" {
284 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
285 }
286 return filepath.Join(localAppData, appName)
287 }
288
289 return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, fmt.Sprintf("%s.json", appName))
290}