1package config
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "slices"
11
12 "charm.land/catwalk/pkg/catwalk"
13 hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
14 "github.com/charmbracelet/crush/internal/env"
15 "github.com/charmbracelet/crush/internal/oauth"
16 "github.com/charmbracelet/crush/internal/oauth/copilot"
17 "github.com/charmbracelet/crush/internal/oauth/hyper"
18 "github.com/tidwall/gjson"
19 "github.com/tidwall/sjson"
20)
21
// fileSnapshot captures metadata about a config file at a point in time.
// ConfigStaleness compares these fields against a fresh os.Stat of the same
// path to decide whether the file changed or disappeared.
type fileSnapshot struct {
	// Path is the tracked (absolute-normalized) path this snapshot describes.
	Path string
	// Exists is true only if the path was present and was not a directory.
	Exists bool
	// Size is the file size in bytes; left zero when Exists is false.
	Size int64
	// ModTime is the modification time; left zero when Exists is false.
	ModTime int64 // UnixNano
}
29
// RuntimeOverrides holds per-session settings that are never persisted to
// disk. They are applied on top of the loaded Config and survive only for
// the lifetime of the process (or workspace). ReloadFromDisk carries the
// current overrides over to the reloaded config unchanged.
type RuntimeOverrides struct {
	// SkipPermissionRequests is read by code outside this file.
	// NOTE(review): presumably suppresses permission prompts for the
	// session — confirm against its consumers.
	SkipPermissionRequests bool
}
36
// ConfigStore is the single entry point for all config access. It owns the
// pure-data Config, runtime state (working directory, resolver, known
// providers), and persistence to both global and workspace config files.
//
// No field is guarded by a lock, and methods mutate fields directly.
// NOTE(review): callers presumably serialize access — confirm before sharing
// a store between goroutines.
type ConfigStore struct {
	// config is the loaded config data; ReloadFromDisk replaces it wholesale.
	config *Config
	// workingDir is where config lookup starts (see ReloadFromDisk).
	workingDir string
	// resolver backs Resolve; Resolve errors when it is nil.
	resolver VariableResolver
	// globalDataPath is the file used for ScopeGlobal reads/writes.
	globalDataPath string // ~/.local/share/crush/crush.json
	// workspacePath is the file used for ScopeWorkspace; empty means there is
	// no workspace config (configPath returns ErrNoWorkspaceConfig).
	workspacePath string // .crush/crush.json
	loadedPaths   []string // config files that were successfully loaded
	// knownProviders is the provider catalogue SetProviderAPIKey builds new
	// provider configs from.
	knownProviders []catwalk.Provider
	// overrides are per-session settings that are never written to disk.
	overrides RuntimeOverrides
	// trackedConfigPaths is sorted; it is rebuilt by CaptureStalenessSnapshot.
	trackedConfigPaths []string // unique, normalized config file paths
	snapshots          map[string]fileSnapshot // path -> snapshot at last capture
}
52
53// Config returns the pure-data config struct (read-only after load).
54func (s *ConfigStore) Config() *Config {
55 return s.config
56}
57
58// WorkingDir returns the current working directory.
59func (s *ConfigStore) WorkingDir() string {
60 return s.workingDir
61}
62
63// Resolver returns the variable resolver.
64func (s *ConfigStore) Resolver() VariableResolver {
65 return s.resolver
66}
67
// Resolve looks up a variable reference through the store's resolver.
// It returns an error if no resolver has been configured; otherwise it
// returns whatever the resolver's ResolveValue returns.
func (s *ConfigStore) Resolve(key string) (string, error) {
	if r := s.resolver; r != nil {
		return r.ResolveValue(key)
	}
	return "", fmt.Errorf("no variable resolver configured")
}
75
76// KnownProviders returns the list of known providers.
77func (s *ConfigStore) KnownProviders() []catwalk.Provider {
78 return s.knownProviders
79}
80
81// SetupAgents configures the coder and task agents on the config.
82func (s *ConfigStore) SetupAgents() {
83 s.config.SetupAgents()
84}
85
86// Overrides returns the runtime overrides for this store.
87func (s *ConfigStore) Overrides() *RuntimeOverrides {
88 return &s.overrides
89}
90
91// LoadedPaths returns the config file paths that were successfully loaded.
92func (s *ConfigStore) LoadedPaths() []string {
93 return slices.Clone(s.loadedPaths)
94}
95
96// configPath returns the file path for the given scope.
97func (s *ConfigStore) configPath(scope Scope) (string, error) {
98 switch scope {
99 case ScopeWorkspace:
100 if s.workspacePath == "" {
101 return "", ErrNoWorkspaceConfig
102 }
103 return s.workspacePath, nil
104 default:
105 return s.globalDataPath, nil
106 }
107}
108
// HasConfigField reports whether key (gjson path syntax) is present in the
// config file for scope. Any problem locating or reading the file yields
// false rather than an error.
func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
	path, err := s.configPath(scope)
	if err != nil {
		return false
	}
	raw, err := os.ReadFile(path)
	return err == nil && gjson.GetBytes(raw, key).Exists()
}
122
// SetConfigField sets key (sjson path syntax, e.g. "options.tui.compact_mode")
// to value in the config file for scope, and writes the file back.
//
// A missing file is treated as an empty JSON object, and its parent
// directory is created if needed. When the file already exists and the
// update leaves its contents unchanged, the file is not rewritten. This
// keeps its modification time, which ConfigStaleness compares, from being
// bumped needlessly.
func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
	path, err := s.configPath(scope)
	if err != nil {
		return fmt.Errorf("%s: %w", key, err)
	}

	existed := true
	data, err := os.ReadFile(path)
	if err != nil {
		if !os.IsNotExist(err) {
			return fmt.Errorf("failed to read config file: %w", err)
		}
		existed = false
		data = []byte("{}")
	}

	newValue, err := sjson.Set(string(data), key, value)
	if err != nil {
		return fmt.Errorf("failed to set config field %s: %w", key, err)
	}
	// Skip the write when nothing changed so the staleness tracker does not
	// see a spurious modification.
	if existed && newValue == string(data) {
		return nil
	}

	dir := filepath.Dir(path)
	if err := os.MkdirAll(dir, 0o755); err != nil {
		return fmt.Errorf("failed to create config directory %q: %w", dir, err)
	}
	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
		return fmt.Errorf("failed to write config file: %w", err)
	}
	return nil
}
150
// RemoveConfigField deletes key (sjson path syntax) from the config file for
// scope. The file must already exist: any read failure, including a missing
// file, is returned as an error.
//
// If the deletion leaves the contents unchanged, the file is not rewritten.
// This keeps its modification time, which ConfigStaleness compares, intact.
func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
	path, err := s.configPath(scope)
	if err != nil {
		return fmt.Errorf("%s: %w", key, err)
	}
	data, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("failed to read config file: %w", err)
	}

	newValue, err := sjson.Delete(string(data), key)
	if err != nil {
		return fmt.Errorf("failed to delete config field %s: %w", key, err)
	}
	if newValue == string(data) {
		return nil
	}

	// The file was just read successfully, so its directory already exists.
	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
		return fmt.Errorf("failed to write config file: %w", err)
	}
	return nil
}
174
// UpdatePreferredModel records model as the preferred model for modelType.
// It updates the in-memory config, persists "models.<modelType>" to the
// config file at scope, and then pushes the model onto the recent-models list
// for that type.
//
// The in-memory value is updated before persisting, so it reflects the new
// choice even if writing the file fails.
func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
	// Writing into a nil map panics. Initialise it the same way
	// recordRecentModel does for RecentModels.
	if s.config.Models == nil {
		s.config.Models = make(map[SelectedModelType]SelectedModel)
	}
	s.config.Models[modelType] = model
	if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
		return fmt.Errorf("failed to update preferred model: %w", err)
	}
	return s.recordRecentModel(scope, modelType, model)
}
187
188// SetCompactMode sets the compact mode setting and persists it.
189func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
190 if s.config.Options == nil {
191 s.config.Options = &Options{}
192 }
193 s.config.Options.TUI.CompactMode = enabled
194 return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
195}
196
197// SetTransparentBackground sets the transparent background setting and persists it.
198func (s *ConfigStore) SetTransparentBackground(scope Scope, enabled bool) error {
199 if s.config.Options == nil {
200 s.config.Options = &Options{}
201 }
202 s.config.Options.TUI.Transparent = &enabled
203 return s.SetConfigField(scope, "options.tui.transparent", enabled)
204}
205
// SetProviderAPIKey stores credentials for providerID. It persists them to
// the config file for scope and updates the in-memory provider config.
//
// apiKey may be:
//   - a string: written to providers.<id>.api_key and set as APIKey;
//   - a *oauth.Token: its AccessToken is written to providers.<id>.api_key,
//     the token itself to providers.<id>.oauth, and both are set on the
//     provider. The Copilot provider additionally gets SetupGitHubCopilot.
//
// Any other type, or a nil token, is rejected with an error before anything
// is written. If the provider is not configured yet, it is created from the
// matching entry in the known providers list. If there is no such entry, an
// error is returned; the credentials have already been written to disk at
// that point.
func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
	var providerConfig ProviderConfig
	// setKeyOrToken applies the credentials to providerConfig once it has
	// been looked up or built below.
	var setKeyOrToken func()

	apiKeyField := fmt.Sprintf("providers.%s.api_key", providerID)
	switch v := apiKey.(type) {
	case string:
		if err := s.SetConfigField(scope, apiKeyField, v); err != nil {
			return fmt.Errorf("failed to save api key to config file: %w", err)
		}
		setKeyOrToken = func() { providerConfig.APIKey = v }
	case *oauth.Token:
		if v == nil {
			return fmt.Errorf("nil OAuth token for provider %s", providerID)
		}
		if err := cmp.Or(
			s.SetConfigField(scope, apiKeyField, v.AccessToken),
			s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
		); err != nil {
			return fmt.Errorf("failed to save OAuth token to config file: %w", err)
		}
		setKeyOrToken = func() {
			providerConfig.APIKey = v.AccessToken
			providerConfig.OAuthToken = v
			if providerID == string(catwalk.InferenceProviderCopilot) {
				providerConfig.SetupGitHubCopilot()
			}
		}
	default:
		// Without this, setKeyOrToken would stay nil and calling it below
		// would panic.
		return fmt.Errorf("unsupported credential type %T for provider %s", apiKey, providerID)
	}

	if existing, ok := s.config.Providers.Get(providerID); ok {
		providerConfig = existing
		setKeyOrToken()
		s.config.Providers.Set(providerID, providerConfig)
		return nil
	}

	idx := slices.IndexFunc(s.knownProviders, func(p catwalk.Provider) bool {
		return string(p.ID) == providerID
	})
	if idx < 0 {
		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
	}
	known := s.knownProviders[idx]

	providerConfig = ProviderConfig{
		ID:           providerID,
		Name:         known.Name,
		BaseURL:      known.APIEndpoint,
		Type:         known.Type,
		Disable:      false,
		ExtraHeaders: make(map[string]string),
		ExtraParams:  make(map[string]string),
		Models:       known.Models,
	}
	setKeyOrToken()
	s.config.Providers.Set(providerID, providerConfig)
	return nil
}
268
// RefreshOAuthToken exchanges the stored refresh token of providerID for a
// new OAuth token. Only the Copilot and Hyper providers are supported.
//
// On success the provider's OAuthToken and APIKey are updated in memory
// (Copilot also gets SetupGitHubCopilot). The new access token and token are
// then persisted to providers.<id>.api_key and providers.<id>.oauth in the
// config file for scope.
func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
	pc, ok := s.config.Providers.Get(providerID)
	if !ok {
		return fmt.Errorf("provider %s not found", providerID)
	}
	if pc.OAuthToken == nil {
		return fmt.Errorf("provider %s does not have an OAuth token", providerID)
	}

	isCopilot := providerID == string(catwalk.InferenceProviderCopilot)
	refreshToken := pc.OAuthToken.RefreshToken

	var (
		fresh      *oauth.Token
		refreshErr error
	)
	switch {
	case isCopilot:
		fresh, refreshErr = copilot.RefreshToken(ctx, refreshToken)
	case providerID == hyperp.Name:
		fresh, refreshErr = hyper.ExchangeToken(ctx, refreshToken)
	default:
		return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
	}
	if refreshErr != nil {
		return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
	}

	slog.Info("Successfully refreshed OAuth token", "provider", providerID)

	pc.OAuthToken = fresh
	pc.APIKey = fresh.AccessToken
	if isCopilot {
		pc.SetupGitHubCopilot()
	}
	s.config.Providers.Set(providerID, pc)

	// Both writes are attempted; the first error (if any) is reported.
	apiKeyErr := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), fresh.AccessToken)
	oauthErr := s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), fresh)
	if persistErr := cmp.Or(apiKeyErr, oauthErr); persistErr != nil {
		return fmt.Errorf("failed to persist refreshed token: %w", persistErr)
	}
	return nil
}
314
// recordRecentModel moves model to the front of the recent-models list for
// modelType. The list is capped at maxRecentModelsPerType entries, and
// entries are compared by provider and model only.
//
// Models missing a provider or model name are ignored. When the resulting
// list equals the current one, nothing is written. Otherwise the list is
// updated in memory and persisted to "recent_models.<modelType>" at scope.
func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
	if model.Provider == "" || model.Model == "" {
		return nil
	}
	if s.config.RecentModels == nil {
		s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
	}

	sameModel := func(a, b SelectedModel) bool {
		return a.Provider == b.Provider && a.Model == b.Model
	}
	// Only the identifying fields are kept in the recent list.
	head := SelectedModel{Provider: model.Provider, Model: model.Model}

	previous := s.config.RecentModels[modelType]
	next := make([]SelectedModel, 0, len(previous)+1)
	next = append(next, head)
	for _, m := range previous {
		if !sameModel(m, head) {
			next = append(next, m)
		}
	}
	if len(next) > maxRecentModelsPerType {
		next = next[:maxRecentModelsPerType]
	}

	if slices.EqualFunc(previous, next, sameModel) {
		return nil
	}
	s.config.RecentModels[modelType] = next

	if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), next); err != nil {
		return fmt.Errorf("failed to persist recent models: %w", err)
	}
	return nil
}
356
357// NewTestStore creates a ConfigStore for testing purposes.
358func NewTestStore(cfg *Config, loadedPaths ...string) *ConfigStore {
359 return &ConfigStore{
360 config: cfg,
361 loadedPaths: loadedPaths,
362 }
363}
364
// ImportCopilot attempts to import a GitHub Copilot token from disk.
//
// It does nothing if the global config already has a Copilot api_key or
// oauth entry. Otherwise it reads a refresh token via
// copilot.RefreshTokenFromDisk and exchanges it with copilot.RefreshToken.
// The result is stored through SetProviderAPIKey at global scope, which both
// updates the in-memory provider and writes api_key/oauth to the global
// config file.
//
// Returns (token, true) on success. Returns (nil, false) when skipped or when
// the exchange fails. Returns (token, false) when the token was obtained but
// could not be stored.
func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
	providerID := string(catwalk.InferenceProviderCopilot)
	if s.HasConfigField(ScopeGlobal, fmt.Sprintf("providers.%s.api_key", providerID)) ||
		s.HasConfigField(ScopeGlobal, fmt.Sprintf("providers.%s.oauth", providerID)) {
		return nil, false
	}

	diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
	if !hasDiskToken {
		return nil, false
	}

	slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
	token, err := copilot.RefreshToken(context.TODO(), diskToken)
	if err != nil {
		slog.Error("Unable to import GitHub Copilot token", "error", err)
		return nil, false
	}

	// SetProviderAPIKey already persists api_key and oauth for the provider,
	// so no separate writes are needed here.
	if err := s.SetProviderAPIKey(ScopeGlobal, providerID, token); err != nil {
		slog.Error("Unable to save GitHub Copilot token", "error", err)
		return token, false
	}

	slog.Info("GitHub Copilot successfully imported")
	return token, true
}
397
// StalenessResult contains the result of a staleness check, as produced by
// ConfigStaleness.
type StalenessResult struct {
	// Dirty is true if any tracked file changed, went missing, or could not
	// be stat'ed.
	Dirty bool
	// Changed lists (sorted) paths that appeared since the snapshot or whose
	// size or modification time differ from it.
	Changed []string
	// Missing lists (sorted) paths that existed as regular files at snapshot
	// time but no longer do.
	Missing []string
	Errors  map[string]error // stat errors by path
}
405
406// ConfigStaleness checks whether any tracked config files have changed on disk
407// since the last snapshot. Returns dirty=true if any files changed or went
408// missing, along with sorted lists of affected paths. Stat errors are
409// captured in Errors map but still treated as non-existence for dirty detection.
410func (s *ConfigStore) ConfigStaleness() StalenessResult {
411 var result StalenessResult
412 result.Errors = make(map[string]error)
413
414 for _, path := range s.trackedConfigPaths {
415 snapshot, hadSnapshot := s.snapshots[path]
416
417 info, err := os.Stat(path)
418 exists := err == nil && !info.IsDir()
419
420 if err != nil && !os.IsNotExist(err) {
421 // Capture permission/IO errors separately from non-existence
422 result.Errors[path] = err
423 result.Dirty = true
424 }
425
426 if !exists {
427 if hadSnapshot && snapshot.Exists {
428 // File existed before but now missing
429 result.Missing = append(result.Missing, path)
430 result.Dirty = true
431 }
432 continue
433 }
434
435 // File exists now
436 if !hadSnapshot || !snapshot.Exists {
437 // File didn't exist before but does now
438 result.Changed = append(result.Changed, path)
439 result.Dirty = true
440 continue
441 }
442
443 // Check for content or metadata changes
444 if snapshot.Size != info.Size() || snapshot.ModTime != info.ModTime().UnixNano() {
445 result.Changed = append(result.Changed, path)
446 result.Dirty = true
447 }
448 }
449
450 // Sort for deterministic output
451 slices.Sort(result.Changed)
452 slices.Sort(result.Missing)
453
454 return result
455}
456
// RefreshStalenessSnapshot re-stats every tracked config file and replaces
// the stored snapshots with the results. Call this after reloading config to
// clear dirty state.
//
// The snapshot map is rebuilt from scratch, so entries for paths that are no
// longer tracked are dropped instead of lingering. A path that cannot be
// stat'ed, or that is a directory, is recorded as not existing. The returned
// error is currently always nil.
func (s *ConfigStore) RefreshStalenessSnapshot() error {
	fresh := make(map[string]fileSnapshot, len(s.trackedConfigPaths))
	for _, path := range s.trackedConfigPaths {
		snap := fileSnapshot{Path: path}
		if info, err := os.Stat(path); err == nil && !info.IsDir() {
			snap.Exists = true
			snap.Size = info.Size()
			snap.ModTime = info.ModTime().UnixNano()
		}
		fresh[path] = snap
	}
	s.snapshots = fresh
	return nil
}
483
// CaptureStalenessSnapshot replaces the set of tracked config paths and then
// snapshots them all. The new set is paths plus the store's workspace and
// global config paths, when those are set.
//
// Empty paths are skipped. Every path is made absolute with filepath.Abs,
// falling back to the path as given if that fails; the same rule applies to
// caller-supplied and store paths. The resulting list is sorted and
// de-duplicated so the tracked order is deterministic.
func (s *ConfigStore) CaptureStalenessSnapshot(paths []string) {
	candidates := make([]string, 0, len(paths)+2)
	candidates = append(candidates, paths...)
	candidates = append(candidates, s.workspacePath, s.globalDataPath)

	// Filter and normalize in place: the write index never passes the read
	// index, so reusing the backing array is safe.
	tracked := candidates[:0]
	for _, p := range candidates {
		if p == "" {
			continue
		}
		if abs, err := filepath.Abs(p); err == nil {
			p = abs
		}
		tracked = append(tracked, p)
	}
	slices.Sort(tracked)
	s.trackedConfigPaths = slices.Compact(tracked)

	if err := s.RefreshStalenessSnapshot(); err != nil {
		slog.Warn("Failed to capture config staleness snapshot", "error", err)
	}
}
525
526// captureStalenessSnapshot is an alias for CaptureStalenessSnapshot for internal use.
527func (s *ConfigStore) captureStalenessSnapshot(paths []string) {
528 s.CaptureStalenessSnapshot(paths)
529}
530
// ReloadFromDisk re-runs the config load/merge flow for the store's working
// directory and swaps the result into the store.
//
// The steps are:
//  1. Load and merge the config files found by lookupConfigs.
//  2. Apply defaults, keeping the current data directory if there is one.
//  3. Merge the workspace config file from the data directory on top.
//  4. Reconfigure providers.
//  5. Set up selected models and agents against the new config.
//
// Runtime overrides are carried over unchanged. A workspace config that
// cannot be read (other than not existing) or cannot be merged is logged and
// skipped.
//
// Errors returned before the swap leave the store's config in place. If
// configuring the selected models fails after the swap, the store's previous
// state is restored so it is not left with a half-set-up config. On success,
// the staleness snapshot is rebuilt from the loaded paths. ctx is currently
// unused.
func (s *ConfigStore) ReloadFromDisk(ctx context.Context) error {
	if s.workingDir == "" {
		return fmt.Errorf("cannot reload: working directory not set")
	}

	configPaths := lookupConfigs(s.workingDir)
	cfg, loadedPaths, err := loadFromConfigPaths(configPaths)
	if err != nil {
		return fmt.Errorf("failed to reload config: %w", err)
	}

	// Apply defaults (using existing data directory if set).
	var dataDir string
	if s.config != nil && s.config.Options != nil {
		dataDir = s.config.Options.DataDirectory
	}
	cfg.setDefaults(s.workingDir, dataDir)

	// Merge workspace config if present.
	workspacePath := filepath.Join(cfg.Options.DataDirectory, fmt.Sprintf("%s.json", appName))
	wsData, readErr := os.ReadFile(workspacePath)
	if readErr != nil && !os.IsNotExist(readErr) {
		slog.Warn("Failed to read workspace config during reload", "path", workspacePath, "error", readErr)
	}
	if readErr == nil && len(wsData) > 0 {
		merged, mergeErr := loadFromBytes(append([][]byte{mustMarshalConfig(cfg)}, wsData))
		if mergeErr != nil {
			slog.Warn("Failed to merge workspace config during reload", "path", workspacePath, "error", mergeErr)
		} else {
			wsDataDir := cfg.Options.DataDirectory
			*cfg = *merged
			cfg.setDefaults(s.workingDir, wsDataDir)
			loadedPaths = append(loadedPaths, workspacePath)
		}
	}

	// Preserve runtime overrides across provider configuration.
	overrides := s.overrides

	// Reconfigure providers. The local is not named env so the env package
	// stays accessible.
	environ := env.New()
	resolver := NewShellVariableResolver(environ)
	providers, err := Providers(cfg)
	if err != nil {
		return fmt.Errorf("failed to load providers during reload: %w", err)
	}
	if err := cfg.configureProviders(s, environ, resolver, providers); err != nil {
		return fmt.Errorf("failed to configure providers during reload: %w", err)
	}

	// Keep the pre-swap state so a failure below can be rolled back.
	previous := *s

	// Update store state BEFORE running model/agent setup (so they see new config).
	s.config = cfg
	s.loadedPaths = loadedPaths
	s.resolver = resolver
	s.knownProviders = providers
	s.overrides = overrides
	s.workspacePath = workspacePath

	// Mirror startup flow: set up models and agents against the NEW config.
	if cfg.IsConfigured() {
		if err := configureSelectedModels(s, providers); err != nil {
			*s = previous
			return fmt.Errorf("failed to configure selected models during reload: %w", err)
		}
		s.SetupAgents()
	} else {
		slog.Warn("No providers configured after reload")
	}

	// Rebuild staleness tracking.
	s.captureStalenessSnapshot(loadedPaths)

	return nil
}