1package configv2
2
3import (
4 "encoding/json"
5 "errors"
6 "maps"
7 "os"
8 "path/filepath"
9 "slices"
10 "strings"
11 "sync"
12
13 "github.com/charmbracelet/crush/internal/logging"
14 "github.com/charmbracelet/fur/pkg/provider"
15)
16
const (
	// defaultDataDirectory is the data directory placed in the default
	// configuration built by defaultConfigBasedOnEnv.
	defaultDataDirectory = ".crush"
	// defaultLogLevel and appName are not referenced in this part of the file.
	defaultLogLevel = "info"
	appName = "crush"

	// MaxTokensFallbackDefault is presumably the max-token limit to fall back
	// on when a model does not declare one — TODO confirm against callers.
	MaxTokensFallbackDefault = 4096
)
24
// Model is the JSON-serializable description of a model as stored on an
// Agent.
type Model struct {
	ID string `json:"id"`
	// Name is serialized under the "model" key, not "name".
	Name string `json:"model"`
	// The CostPer1M* fields presumably hold prices per one million tokens
	// (input, output, and their cached variants) — TODO confirm units.
	CostPer1MIn float64 `json:"cost_per_1m_in"`
	CostPer1MOut float64 `json:"cost_per_1m_out"`
	CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	ContextWindow int64 `json:"context_window"`
	DefaultMaxTokens int64 `json:"default_max_tokens"`
	CanReason bool `json:"can_reason"`
	ReasoningEffort string `json:"reasoning_effort"`
	// SupportsImages is serialized under the "supports_attachments" key.
	SupportsImages bool `json:"supports_attachments"`
}
38
// VertexAIOptions groups the Vertex AI specific settings. All fields are
// omitted from JSON output when empty.
// NOTE(review): not referenced in this part of the file; the Vertex AI
// project and location are carried in ProviderConfig.ExtraParams instead.
type VertexAIOptions struct {
	APIKey string `json:"api_key,omitempty"`
	Project string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}
44
// ProviderConfig holds the connection settings for a single inference
// provider entry in Config.Providers.
type ProviderConfig struct {
	BaseURL string `json:"base_url,omitempty"`
	ProviderType provider.Type `json:"provider_type"`
	APIKey string `json:"api_key,omitempty"`
	Disabled bool `json:"disabled"`
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Used, e.g., for Vertex AI to set the project and location.
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	// DefaultModel is filled from the provider's DefaultModelID when the
	// provider is discovered through the environment.
	DefaultModel string `json:"default_model"`
}
56
// Agent is the JSON-serializable configuration of a single agent entry in
// Config.Agents.
type Agent struct {
	Name string `json:"name"`
	// This is the ID of the system prompt used by the agent.
	// TODO: still needs to be implemented
	PromptID string `json:"prompt_id"`
	Disabled bool `json:"disabled"`

	Provider provider.InferenceProvider `json:"provider"`
	Model Model `json:"model"`

	// The available tools for the agent.
	// If this is empty, all tools are available.
	AllowedTools []string `json:"allowed_tools"`

	// This tells us which MCPs are available for this agent.
	// If this is empty, all MCPs are available.
	// The string array is the list of tools from the MCP the agent has available;
	// if the string array is empty, all tools from the MCP are available.
	MCP map[string][]string `json:"mcp"`

	// The list of LSPs that this agent can use.
	// If this is empty, all LSPs are available.
	LSP []string `json:"lsp"`

	// Overrides the context paths for this agent.
	ContextPaths []string `json:"context_paths"`
}
84
// MCPType is the value of the "type" field of an MCP entry.
type MCPType string

// MCP types declared in this file.
const (
	MCPStdio MCPType = "stdio"
	MCPSse MCPType = "sse"
)
91
// MCP is the JSON-serializable configuration of a single entry in
// Config.MCP.
// NOTE(review): Command/Env/Args presumably apply to the "stdio" type and
// URL/Headers to the "sse" type — confirm against the code that consumes it.
type MCP struct {
	Command string `json:"command"`
	Env []string `json:"env"`
	Args []string `json:"args"`
	Type MCPType `json:"type"`
	URL string `json:"url"`
	Headers map[string]string `json:"headers"`
}
100
// LSPConfig is the JSON-serializable configuration of a single entry in
// Config.LSP.
type LSPConfig struct {
	// Disabled is serialized under the "disabled" key, the same key that
	// ProviderConfig and Agent use for their Disabled fields. It was
	// previously tagged "enabled", which meant that writing
	// `"enabled": true` in a config file set Disabled to true.
	Disabled bool `json:"disabled"`
	Command string `json:"command"`
	Args []string `json:"args"`
	// Options is decoded as arbitrary JSON; its schema is not defined here.
	Options any `json:"options"`
}
107
// TUIOptions holds the TUI related settings nested under Options.TUI.
type TUIOptions struct {
	CompactMode bool `json:"compact_mode"`
	// Here we can add themes later or any TUI related options
}
112
// Options holds the miscellaneous settings stored under Config.Options.
// Values from several sources are combined by mergeOptions.
type Options struct {
	ContextPaths []string `json:"context_paths"`
	TUI TUIOptions `json:"tui"`
	Debug bool `json:"debug"`
	DebugLSP bool `json:"debug_lsp"`
	DisableAutoSummarize bool `json:"disable_auto_summarize"`
	// Relative to the cwd
	DataDirectory string `json:"data_directory"`
}
122
// Config is the root of the configuration. The same shape is used for the
// global config file, the per-directory crush.json and the merged result
// returned by loadConfig.
type Config struct {
	// List of configured providers
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// List of configured agents
	Agents map[string]Agent `json:"agents,omitempty"`

	// List of configured MCPs
	MCP map[string]MCP `json:"mcp,omitempty"`

	// List of configured LSPs
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options
	Options Options `json:"options"`

	// Used to add models that are not already in the repository
	Models map[provider.InferenceProvider][]provider.Model `json:"models,omitempty"`
}
142
var (
	// instance is the single shared Config; it is assigned by InitConfig and
	// read by GetConfig.
	instance *Config
	// cwd is the working directory passed to InitConfig; it is returned by
	// WorkingDirectory.
	cwd string
	// once ensures the initialization in InitConfig happens only once.
	once sync.Once
)
148
149func loadConfig(cwd string) (*Config, error) {
150 // First read the global config file
151 cfgPath := ConfigPath()
152
153 cfg := defaultConfigBasedOnEnv()
154
155 var globalCfg *Config
156 if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
157 // some other error occurred while checking the file
158 return nil, err
159 } else if err == nil {
160 // config file exists, read it
161 file, err := os.ReadFile(cfgPath)
162 if err != nil {
163 return nil, err
164 }
165 globalCfg = &Config{}
166 if err := json.Unmarshal(file, globalCfg); err != nil {
167 return nil, err
168 }
169 } else {
170 // config file does not exist, create a new one
171 globalCfg = &Config{}
172 }
173
174 var localConfig *Config
175 // Global config loaded, now read the local config file
176 localConfigPath := filepath.Join(cwd, "crush.json")
177 if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
178 // some other error occurred while checking the file
179 return nil, err
180 } else if err == nil {
181 // local config file exists, read it
182 file, err := os.ReadFile(localConfigPath)
183 if err != nil {
184 return nil, err
185 }
186 localConfig = &Config{}
187 if err := json.Unmarshal(file, localConfig); err != nil {
188 return nil, err
189 }
190 }
191
192 // merge options
193 cfg.Options = mergeOptions(cfg.Options, globalCfg.Options)
194 cfg.Options = mergeOptions(cfg.Options, localConfig.Options)
195
196 mergeProviderConfigs(cfg, globalCfg, localConfig)
197 return cfg, nil
198}
199
200func InitConfig(workingDir string) *Config {
201 once.Do(func() {
202 cwd = workingDir
203 cfg, err := loadConfig(cwd)
204 if err != nil {
205 // TODO: Handle this better
206 panic("Failed to load config: " + err.Error())
207 }
208 instance = cfg
209 })
210
211 return instance
212}
213
214func GetConfig() *Config {
215 if instance == nil {
216 // TODO: Handle this better
217 panic("Config not initialized. Call InitConfig first.")
218 }
219 return instance
220}
221
222func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
223 if other.APIKey != "" {
224 base.APIKey = other.APIKey
225 }
226 // Only change these options if the provider is not a known provider
227 if !slices.Contains(provider.KnownProviders(), p) {
228 if other.BaseURL != "" {
229 base.BaseURL = other.BaseURL
230 }
231 if other.ProviderType != "" {
232 base.ProviderType = other.ProviderType
233 }
234 if len(base.ExtraHeaders) > 0 {
235 if base.ExtraHeaders == nil {
236 base.ExtraHeaders = make(map[string]string)
237 }
238 maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
239 }
240 if len(other.ExtraParams) > 0 {
241 if base.ExtraParams == nil {
242 base.ExtraParams = make(map[string]string)
243 }
244 maps.Copy(base.ExtraParams, other.ExtraParams)
245 }
246 }
247
248 if other.Disabled {
249 base.Disabled = other.Disabled
250 }
251
252 return base
253}
254
255func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
256 if !slices.Contains(provider.KnownProviders(), p) {
257 if providerConfig.ProviderType != provider.TypeOpenAI {
258 return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
259 }
260 if providerConfig.BaseURL == "" {
261 return errors.New("base URL must be set for custom providers")
262 }
263 if providerConfig.APIKey == "" {
264 return errors.New("API key must be set for custom providers")
265 }
266 }
267 return nil
268}
269
270func mergeOptions(base, other Options) Options {
271 result := base
272
273 if len(other.ContextPaths) > 0 {
274 base.ContextPaths = append(base.ContextPaths, other.ContextPaths...)
275 }
276
277 if other.TUI.CompactMode {
278 result.TUI.CompactMode = other.TUI.CompactMode
279 }
280
281 if other.Debug {
282 result.Debug = other.Debug
283 }
284
285 if other.DebugLSP {
286 result.DebugLSP = other.DebugLSP
287 }
288
289 if other.DisableAutoSummarize {
290 result.DisableAutoSummarize = other.DisableAutoSummarize
291 }
292
293 if other.DataDirectory != "" {
294 result.DataDirectory = other.DataDirectory
295 }
296
297 return result
298}
299
300func mergeProviderConfigs(base, global, local *Config) {
301 if global != nil {
302 for providerName, globalProvider := range global.Providers {
303 if _, ok := base.Providers[providerName]; !ok {
304 base.Providers[providerName] = globalProvider
305 } else {
306 base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
307 }
308 }
309 }
310 if local != nil {
311 for providerName, localProvider := range local.Providers {
312 if _, ok := base.Providers[providerName]; !ok {
313 base.Providers[providerName] = localProvider
314 } else {
315 base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], localProvider)
316 }
317 }
318 }
319
320 finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
321 for providerName, providerConfig := range base.Providers {
322 err := validateProvider(providerName, providerConfig)
323 if err != nil {
324 logging.Warn("Skipping provider", "name", providerName, "error", err)
325 }
326 finalProviders[providerName] = providerConfig
327 }
328 base.Providers = finalProviders
329}
330
331func providerDefaultConfig(providerName provider.InferenceProvider) ProviderConfig {
332 switch providerName {
333 case provider.InferenceProviderAnthropic:
334 return ProviderConfig{
335 ProviderType: provider.TypeAnthropic,
336 }
337 case provider.InferenceProviderOpenAI:
338 return ProviderConfig{
339 ProviderType: provider.TypeOpenAI,
340 }
341 case provider.InferenceProviderGemini:
342 return ProviderConfig{
343 ProviderType: provider.TypeGemini,
344 }
345 case provider.InferenceProviderBedrock:
346 return ProviderConfig{
347 ProviderType: provider.TypeBedrock,
348 }
349 case provider.InferenceProviderAzure:
350 return ProviderConfig{
351 ProviderType: provider.TypeAzure,
352 }
353 case provider.InferenceProviderOpenRouter:
354 return ProviderConfig{
355 ProviderType: provider.TypeOpenAI,
356 BaseURL: "https://openrouter.ai/api/v1",
357 ExtraHeaders: map[string]string{
358 "HTTP-Referer": "crush.charm.land",
359 "X-Title": "Crush",
360 },
361 }
362 case provider.InferenceProviderXAI:
363 return ProviderConfig{
364 ProviderType: provider.TypeXAI,
365 BaseURL: "https://api.x.ai/v1",
366 }
367 case provider.InferenceProviderVertexAI:
368 return ProviderConfig{
369 ProviderType: provider.TypeVertexAI,
370 }
371 default:
372 return ProviderConfig{
373 ProviderType: provider.TypeOpenAI,
374 }
375 }
376}
377
378func defaultConfigBasedOnEnv() *Config {
379 cfg := &Config{
380 Options: Options{
381 DataDirectory: defaultDataDirectory,
382 },
383 Providers: make(map[provider.InferenceProvider]ProviderConfig),
384 }
385
386 providers := Providers()
387
388 for _, p := range providers {
389 if strings.HasPrefix(p.APIKey, "$") {
390 envVar := strings.TrimPrefix(p.APIKey, "$")
391 if apiKey := os.Getenv(envVar); apiKey != "" {
392 providerConfig := providerDefaultConfig(p.ID)
393 providerConfig.APIKey = apiKey
394 providerConfig.DefaultModel = p.DefaultModelID
395 cfg.Providers[p.ID] = providerConfig
396 }
397 }
398 }
399 // TODO: support local models
400
401 if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
402 providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
403 providerConfig.ExtraParams = map[string]string{
404 "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
405 "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
406 }
407 cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
408 }
409
410 if hasAWSCredentials() {
411 providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
412 cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
413 }
414 return cfg
415}
416
// hasAWSCredentials reports whether the environment carries any of the AWS
// settings checked here: an access key ID together with a secret access key,
// a profile, a region, or a container credentials URI.
//
// Note that a region or profile on its own is enough for a true result.
func hasAWSCredentials() bool {
	isSet := func(name string) bool {
		return os.Getenv(name) != ""
	}

	switch {
	case isSet("AWS_ACCESS_KEY_ID") && isSet("AWS_SECRET_ACCESS_KEY"):
		return true
	case isSet("AWS_PROFILE") || isSet("AWS_DEFAULT_PROFILE"):
		return true
	case isSet("AWS_REGION") || isSet("AWS_DEFAULT_REGION"):
		return true
	case isSet("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") || isSet("AWS_CONTAINER_CREDENTIALS_FULL_URI"):
		return true
	}

	return false
}
437
// WorkingDirectory returns the working directory recorded by InitConfig. It
// is the empty string until InitConfig has been called.
func WorkingDirectory() string {
	return cwd
}