1package config
2
3import (
4 "encoding/json"
5 "fmt"
6 "io"
7 "os"
8 "path/filepath"
9 "runtime"
10 "slices"
11 "strings"
12
13 "github.com/charmbracelet/crush/internal/fur/client"
14 "github.com/charmbracelet/crush/internal/fur/provider"
15 "github.com/charmbracelet/crush/pkg/env"
16)
17
18// LoadReader config via io.Reader.
19func LoadReader(fd io.Reader) (*Config, error) {
20 data, err := io.ReadAll(fd)
21 if err != nil {
22 return nil, err
23 }
24
25 var config Config
26 err = json.Unmarshal(data, &config)
27 if err != nil {
28 return nil, err
29 }
30 return &config, err
31}
32
33// Load loads the configuration from the default paths.
34func Load(workingDir string, env env.Env) (*Config, error) {
35 // uses default config paths
36 configPaths := []string{
37 globalConfig(),
38 globalConfigData(),
39 filepath.Join(workingDir, fmt.Sprintf("%s.json", appName)),
40 filepath.Join(workingDir, fmt.Sprintf(".%s.json", appName)),
41 }
42 cfg, err := loadFromConfigPaths(configPaths)
43 if err != nil {
44 return nil, fmt.Errorf("failed to load config: %w", err)
45 }
46 // TODO: maybe add a validation step here right after loading
47 // e.x validate the models
48 // e.x validate provider config
49
50 setDefaults(workingDir, cfg)
51
52 // Load known providers, this loads the config from fur
53 providers, err := LoadProviders(client.New())
54 if err != nil {
55 return nil, fmt.Errorf("failed to load providers: %w", err)
56 }
57
58 // Configure providers
59 valueResolver := NewShellVariableResolver(env)
60 if err := configureProviders(cfg, env, valueResolver, providers); err != nil {
61 return nil, fmt.Errorf("failed to configure providers: %w", err)
62 }
63
64 return cfg, nil
65}
66
67func configureProviders(cfg *Config, env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error {
68 for _, p := range knownProviders {
69
70 config, ok := cfg.Providers[string(p.ID)]
71 // if the user configured a known provider we need to allow it to override a couple of parameters
72 if ok {
73 if config.BaseURL != "" {
74 p.APIEndpoint = config.BaseURL
75 }
76 if config.APIKey != "" {
77 p.APIKey = config.APIKey
78 }
79 if len(config.Models) > 0 {
80 models := []provider.Model{}
81 seen := make(map[string]bool)
82
83 for _, model := range config.Models {
84 if seen[model.ID] {
85 continue
86 }
87 seen[model.ID] = true
88 models = append(models, model)
89 }
90 for _, model := range p.Models {
91 if seen[model.ID] {
92 continue
93 }
94 seen[model.ID] = true
95 models = append(models, model)
96 }
97
98 p.Models = models
99 }
100 }
101 prepared := ProviderConfig{
102 BaseURL: p.APIEndpoint,
103 APIKey: p.APIKey,
104 Type: p.Type,
105 Disable: config.Disable,
106 ExtraHeaders: config.ExtraHeaders,
107 ExtraParams: make(map[string]string),
108 Models: p.Models,
109 }
110
111 switch p.ID {
112 // Handle specific providers that require additional configuration
113 case provider.InferenceProviderVertexAI:
114 if !hasVertexCredentials(env) {
115 continue
116 }
117 prepared.ExtraParams["project"] = env.Get("GOOGLE_CLOUD_PROJECT")
118 prepared.ExtraParams["location"] = env.Get("GOOGLE_CLOUD_LOCATION")
119 case provider.InferenceProviderBedrock:
120 if !hasAWSCredentials(env) {
121 continue
122 }
123 for _, model := range p.Models {
124 if !strings.HasPrefix(model.ID, "anthropic.") {
125 return fmt.Errorf("bedrock provider only supports anthropic models for now, found: %s", model.ID)
126 }
127 }
128 default:
129 // if the provider api or endpoint are missing we skip them
130 v, err := resolver.ResolveValue(p.APIKey)
131 if v == "" || err != nil {
132 continue
133 }
134 v, err = resolver.ResolveValue(p.APIEndpoint)
135 if v == "" || err != nil {
136 continue
137 }
138 }
139 cfg.Providers[string(p.ID)] = prepared
140 }
141 return nil
142}
143
144func hasVertexCredentials(env env.Env) bool {
145 useVertex := env.Get("GOOGLE_GENAI_USE_VERTEXAI") == "true"
146 hasProject := env.Get("GOOGLE_CLOUD_PROJECT") != ""
147 hasLocation := env.Get("GOOGLE_CLOUD_LOCATION") != ""
148 return useVertex && hasProject && hasLocation
149}
150
151func hasAWSCredentials(env env.Env) bool {
152 if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" {
153 return true
154 }
155
156 if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" {
157 return true
158 }
159
160 if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" {
161 return true
162 }
163
164 if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
165 env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
166 return true
167 }
168
169 return false
170}
171
172func setDefaults(workingDir string, cfg *Config) {
173 cfg.workingDir = workingDir
174 if cfg.Options == nil {
175 cfg.Options = &Options{}
176 }
177 if cfg.Options.TUI == nil {
178 cfg.Options.TUI = &TUIOptions{}
179 }
180 if cfg.Options.ContextPaths == nil {
181 cfg.Options.ContextPaths = []string{}
182 }
183 if cfg.Options.DataDirectory == "" {
184 cfg.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
185 }
186 if cfg.Providers == nil {
187 cfg.Providers = make(map[string]ProviderConfig)
188 }
189 if cfg.Models == nil {
190 cfg.Models = make(map[string]SelectedModel)
191 }
192 if cfg.MCP == nil {
193 cfg.MCP = make(map[string]MCP)
194 }
195 if cfg.LSP == nil {
196 cfg.LSP = make(map[string]LSPConfig)
197 }
198
199 // Add the default context paths if they are not already present
200 cfg.Options.ContextPaths = append(defaultContextPaths, cfg.Options.ContextPaths...)
201 slices.Sort(cfg.Options.ContextPaths)
202 cfg.Options.ContextPaths = slices.Compact(cfg.Options.ContextPaths)
203}
204
205func loadFromConfigPaths(configPaths []string) (*Config, error) {
206 var configs []io.Reader
207
208 for _, path := range configPaths {
209 fd, err := os.Open(path)
210 if err != nil {
211 if os.IsNotExist(err) {
212 continue
213 }
214 return nil, fmt.Errorf("failed to open config file %s: %w", path, err)
215 }
216 defer fd.Close()
217
218 configs = append(configs, fd)
219 }
220
221 return loadFromReaders(configs)
222}
223
224func loadFromReaders(readers []io.Reader) (*Config, error) {
225 if len(readers) == 0 {
226 return nil, fmt.Errorf("no configuration readers provided")
227 }
228
229 merged, err := Merge(readers)
230 if err != nil {
231 return nil, fmt.Errorf("failed to merge configuration readers: %w", err)
232 }
233
234 return LoadReader(merged)
235}
236
237func globalConfig() string {
238 xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
239 if xdgConfigHome != "" {
240 return filepath.Join(xdgConfigHome, "crush")
241 }
242
243 // return the path to the main config directory
244 // for windows, it should be in `%LOCALAPPDATA%/crush/`
245 // for linux and macOS, it should be in `$HOME/.config/crush/`
246 if runtime.GOOS == "windows" {
247 localAppData := os.Getenv("LOCALAPPDATA")
248 if localAppData == "" {
249 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
250 }
251 return filepath.Join(localAppData, appName)
252 }
253
254 return filepath.Join(os.Getenv("HOME"), ".config", appName, fmt.Sprintf("%s.json", appName))
255}
256
257// globalConfigData returns the path to the main data directory for the application.
258// this config is used when the app overrides configurations instead of updating the global config.
259func globalConfigData() string {
260 xdgDataHome := os.Getenv("XDG_DATA_HOME")
261 if xdgDataHome != "" {
262 return filepath.Join(xdgDataHome, appName)
263 }
264
265 // return the path to the main data directory
266 // for windows, it should be in `%LOCALAPPDATA%/crush/`
267 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
268 if runtime.GOOS == "windows" {
269 localAppData := os.Getenv("LOCALAPPDATA")
270 if localAppData == "" {
271 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
272 }
273 return filepath.Join(localAppData, appName)
274 }
275
276 return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, fmt.Sprintf("%s.json", appName))
277}