1package config
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "slices"
11
12 "charm.land/catwalk/pkg/catwalk"
13 hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
14 "github.com/charmbracelet/crush/internal/env"
15 "github.com/charmbracelet/crush/internal/oauth"
16 "github.com/charmbracelet/crush/internal/oauth/copilot"
17 "github.com/charmbracelet/crush/internal/oauth/hyper"
18 "github.com/tidwall/gjson"
19 "github.com/tidwall/sjson"
20)
21
// fileSnapshot captures metadata about a config file at a point in time.
// RefreshStalenessSnapshot fills these in and ConfigStaleness compares live
// os.Stat results against them to decide whether a file changed on disk.
type fileSnapshot struct {
	Path    string
	Exists  bool  // false if the path was missing, a directory, or could not be stat'ed
	Size    int64 // bytes; only populated when Exists is true
	ModTime int64 // UnixNano; only populated when Exists is true
}
29
// RuntimeOverrides holds per-session settings that are never persisted to
// disk. They are applied on top of the loaded Config and survive only for
// the lifetime of the process (or workspace). ReloadFromDisk carries the
// current overrides over unchanged when it swaps in a freshly loaded config.
type RuntimeOverrides struct {
	// SkipPermissionRequests presumably bypasses interactive permission
	// prompts for this session. NOTE(review): confirm against the code that
	// reads it; it is not consulted in this file.
	SkipPermissionRequests bool
}
36
// ConfigStore is the single entry point for all config access. It owns the
// pure-data Config, runtime state (working directory, resolver, known
// providers), and persistence to both global and workspace config files.
//
// Writes made through SetConfigField/RemoveConfigField are followed by an
// automatic reload from disk unless auto-reload is suppressed or no working
// directory is set. The struct contains no locks. NOTE(review): callers
// presumably serialize access; confirm before sharing a store across
// goroutines.
type ConfigStore struct {
	// config is the current configuration. ReloadFromDisk replaces the
	// pointer wholesale, and several setters mutate it in place.
	config *Config

	// workingDir is the directory config files are looked up from.
	// ReloadFromDisk refuses to run and autoReload is skipped while it is
	// empty.
	workingDir string

	// resolver expands variable references for Resolve. It may be nil.
	resolver VariableResolver

	// globalDataPath is the config file used for every non-workspace scope
	// (e.g. ~/.local/share/crush/crush.json).
	globalDataPath string

	// workspacePath is the workspace config file (e.g. .crush/crush.json).
	// When it is empty, workspace-scope operations fail with
	// ErrNoWorkspaceConfig. ReloadFromDisk recomputes it from the data
	// directory.
	workspacePath string

	// loadedPaths lists config files that were successfully loaded.
	loadedPaths []string

	// knownProviders is the provider catalog that SetProviderAPIKey uses to
	// seed entries for providers not yet present in config.
	knownProviders []catwalk.Provider

	// overrides holds session-only settings that are never written to disk.
	overrides RuntimeOverrides

	// trackedConfigPaths lists the config files watched for staleness:
	// unique, made absolute when possible, and sorted.
	trackedConfigPaths []string

	// snapshots maps a path to its snapshot at the last capture.
	snapshots map[string]fileSnapshot

	// autoReloadDisabled is set during load/reload so autoReload does not
	// trigger re-entrant reloads.
	autoReloadDisabled bool

	// reloadInProgress is set while ReloadFromDisk runs, intended to avoid
	// disk writes mid-reload. NOTE(review): it is not read anywhere in this
	// file; confirm it is checked elsewhere.
	reloadInProgress bool
}
54
55// Config returns the pure-data config struct (read-only after load).
56func (s *ConfigStore) Config() *Config {
57 return s.config
58}
59
60// WorkingDir returns the current working directory.
61func (s *ConfigStore) WorkingDir() string {
62 return s.workingDir
63}
64
65// Resolver returns the variable resolver.
66func (s *ConfigStore) Resolver() VariableResolver {
67 return s.resolver
68}
69
70// Resolve resolves a variable reference using the configured resolver.
71func (s *ConfigStore) Resolve(key string) (string, error) {
72 if s.resolver == nil {
73 return "", fmt.Errorf("no variable resolver configured")
74 }
75 return s.resolver.ResolveValue(key)
76}
77
78// KnownProviders returns the list of known providers.
79func (s *ConfigStore) KnownProviders() []catwalk.Provider {
80 return s.knownProviders
81}
82
83// SetupAgents configures the coder and task agents on the config.
84func (s *ConfigStore) SetupAgents() {
85 s.config.SetupAgents()
86}
87
88// Overrides returns the runtime overrides for this store.
89func (s *ConfigStore) Overrides() *RuntimeOverrides {
90 return &s.overrides
91}
92
93// LoadedPaths returns the config file paths that were successfully loaded.
94func (s *ConfigStore) LoadedPaths() []string {
95 return slices.Clone(s.loadedPaths)
96}
97
98// configPath returns the file path for the given scope.
99func (s *ConfigStore) configPath(scope Scope) (string, error) {
100 switch scope {
101 case ScopeWorkspace:
102 if s.workspacePath == "" {
103 return "", ErrNoWorkspaceConfig
104 }
105 return s.workspacePath, nil
106 default:
107 return s.globalDataPath, nil
108 }
109}
110
111// HasConfigField checks whether a key exists in the config file for the given
112// scope.
113func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
114 path, err := s.configPath(scope)
115 if err != nil {
116 return false
117 }
118 data, err := os.ReadFile(path)
119 if err != nil {
120 return false
121 }
122 return gjson.Get(string(data), key).Exists()
123}
124
125// SetConfigField sets a key/value pair in the config file for the given scope.
126// After a successful write, it automatically reloads config to keep in-memory
127// state fresh.
128func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
129 path, err := s.configPath(scope)
130 if err != nil {
131 return fmt.Errorf("%s: %w", key, err)
132 }
133 data, err := os.ReadFile(path)
134 if err != nil {
135 if os.IsNotExist(err) {
136 data = []byte("{}")
137 } else {
138 return fmt.Errorf("failed to read config file: %w", err)
139 }
140 }
141
142 newValue, err := sjson.Set(string(data), key, value)
143 if err != nil {
144 return fmt.Errorf("failed to set config field %s: %w", key, err)
145 }
146 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
147 return fmt.Errorf("failed to create config directory %q: %w", path, err)
148 }
149 if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
150 return fmt.Errorf("failed to write config file: %w", err)
151 }
152
153 // Auto-reload to keep in-memory state fresh after config edits.
154 // We use context.Background() since this is an internal operation that
155 // shouldn't be cancelled by user context.
156 if err := s.autoReload(context.Background()); err != nil {
157 // Log warning but don't fail the write - disk is already updated.
158 slog.Warn("Config file updated but failed to reload in-memory state", "error", err)
159 }
160
161 return nil
162}
163
164// RemoveConfigField removes a key from the config file for the given scope.
165// After a successful write, it automatically reloads config to keep in-memory
166// state fresh.
167func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
168 path, err := s.configPath(scope)
169 if err != nil {
170 return fmt.Errorf("%s: %w", key, err)
171 }
172 data, err := os.ReadFile(path)
173 if err != nil {
174 return fmt.Errorf("failed to read config file: %w", err)
175 }
176
177 newValue, err := sjson.Delete(string(data), key)
178 if err != nil {
179 return fmt.Errorf("failed to delete config field %s: %w", key, err)
180 }
181 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
182 return fmt.Errorf("failed to create config directory %q: %w", path, err)
183 }
184 if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
185 return fmt.Errorf("failed to write config file: %w", err)
186 }
187
188 // Auto-reload to keep in-memory state fresh after config edits.
189 if err := s.autoReload(context.Background()); err != nil {
190 slog.Warn("Config file updated but failed to reload in-memory state", "error", err)
191 }
192
193 return nil
194}
195
196// UpdatePreferredModel updates the preferred model for the given type and
197// persists it to the config file at the given scope.
198func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
199 s.config.Models[modelType] = model
200 if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
201 return fmt.Errorf("failed to update preferred model: %w", err)
202 }
203 if err := s.recordRecentModel(scope, modelType, model); err != nil {
204 return err
205 }
206 return nil
207}
208
209// SetCompactMode sets the compact mode setting and persists it.
210func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
211 if s.config.Options == nil {
212 s.config.Options = &Options{}
213 }
214 s.config.Options.TUI.CompactMode = enabled
215 return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
216}
217
218// SetTransparentBackground sets the transparent background setting and persists it.
219func (s *ConfigStore) SetTransparentBackground(scope Scope, enabled bool) error {
220 if s.config.Options == nil {
221 s.config.Options = &Options{}
222 }
223 s.config.Options.TUI.Transparent = &enabled
224 return s.SetConfigField(scope, "options.tui.transparent", enabled)
225}
226
227// SetProviderAPIKey sets the API key for a provider and persists it.
228func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
229 var providerConfig ProviderConfig
230 var exists bool
231 var setKeyOrToken func()
232
233 switch v := apiKey.(type) {
234 case string:
235 if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
236 return fmt.Errorf("failed to save api key to config file: %w", err)
237 }
238 setKeyOrToken = func() { providerConfig.APIKey = v }
239 case *oauth.Token:
240 if err := cmp.Or(
241 s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
242 s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
243 ); err != nil {
244 return err
245 }
246 setKeyOrToken = func() {
247 providerConfig.APIKey = v.AccessToken
248 providerConfig.OAuthToken = v
249 switch providerID {
250 case string(catwalk.InferenceProviderCopilot):
251 providerConfig.SetupGitHubCopilot()
252 }
253 }
254 }
255
256 providerConfig, exists = s.config.Providers.Get(providerID)
257 if exists {
258 setKeyOrToken()
259 s.config.Providers.Set(providerID, providerConfig)
260 return nil
261 }
262
263 var foundProvider *catwalk.Provider
264 for _, p := range s.knownProviders {
265 if string(p.ID) == providerID {
266 foundProvider = &p
267 break
268 }
269 }
270
271 if foundProvider != nil {
272 providerConfig = ProviderConfig{
273 ID: providerID,
274 Name: foundProvider.Name,
275 BaseURL: foundProvider.APIEndpoint,
276 Type: foundProvider.Type,
277 Disable: false,
278 ExtraHeaders: make(map[string]string),
279 ExtraParams: make(map[string]string),
280 Models: foundProvider.Models,
281 }
282 setKeyOrToken()
283 } else {
284 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
285 }
286 s.config.Providers.Set(providerID, providerConfig)
287 return nil
288}
289
290// RefreshOAuthToken refreshes the OAuth token for the given provider.
291func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
292 providerConfig, exists := s.config.Providers.Get(providerID)
293 if !exists {
294 return fmt.Errorf("provider %s not found", providerID)
295 }
296
297 if providerConfig.OAuthToken == nil {
298 return fmt.Errorf("provider %s does not have an OAuth token", providerID)
299 }
300
301 var newToken *oauth.Token
302 var refreshErr error
303 switch providerID {
304 case string(catwalk.InferenceProviderCopilot):
305 newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
306 case hyperp.Name:
307 newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
308 default:
309 return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
310 }
311 if refreshErr != nil {
312 return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
313 }
314
315 slog.Info("Successfully refreshed OAuth token", "provider", providerID)
316 providerConfig.OAuthToken = newToken
317 providerConfig.APIKey = newToken.AccessToken
318
319 switch providerID {
320 case string(catwalk.InferenceProviderCopilot):
321 providerConfig.SetupGitHubCopilot()
322 }
323
324 s.config.Providers.Set(providerID, providerConfig)
325
326 if err := cmp.Or(
327 s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
328 s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
329 ); err != nil {
330 return fmt.Errorf("failed to persist refreshed token: %w", err)
331 }
332
333 return nil
334}
335
336// recordRecentModel records a model in the recent models list.
337func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
338 if model.Provider == "" || model.Model == "" {
339 return nil
340 }
341
342 if s.config.RecentModels == nil {
343 s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
344 }
345
346 eq := func(a, b SelectedModel) bool {
347 return a.Provider == b.Provider && a.Model == b.Model
348 }
349
350 entry := SelectedModel{
351 Provider: model.Provider,
352 Model: model.Model,
353 }
354
355 current := s.config.RecentModels[modelType]
356 withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
357 return eq(existing, entry)
358 })
359
360 updated := append([]SelectedModel{entry}, withoutCurrent...)
361 if len(updated) > maxRecentModelsPerType {
362 updated = updated[:maxRecentModelsPerType]
363 }
364
365 if slices.EqualFunc(current, updated, eq) {
366 return nil
367 }
368
369 s.config.RecentModels[modelType] = updated
370
371 if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
372 return fmt.Errorf("failed to persist recent models: %w", err)
373 }
374
375 return nil
376}
377
378// NewTestStore creates a ConfigStore for testing purposes.
379func NewTestStore(cfg *Config, loadedPaths ...string) *ConfigStore {
380 return &ConfigStore{
381 config: cfg,
382 loadedPaths: loadedPaths,
383 }
384}
385
386// ImportCopilot attempts to import a GitHub Copilot token from disk.
387func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
388 if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
389 return nil, false
390 }
391
392 diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
393 if !hasDiskToken {
394 return nil, false
395 }
396
397 slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
398 token, err := copilot.RefreshToken(context.TODO(), diskToken)
399 if err != nil {
400 slog.Error("Unable to import GitHub Copilot token", "error", err)
401 return nil, false
402 }
403
404 if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
405 return token, false
406 }
407
408 if err := cmp.Or(
409 s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
410 s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
411 ); err != nil {
412 slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
413 }
414
415 slog.Info("GitHub Copilot successfully imported")
416 return token, true
417}
418
// StalenessResult contains the result of a staleness check, as produced by
// ConfigStaleness.
type StalenessResult struct {
	// Dirty is true when any tracked file changed, went missing, or could
	// not be stat'ed for a reason other than non-existence.
	Dirty bool

	// Changed lists files that appeared since the snapshot or whose size or
	// modification time differs from it. Sorted.
	Changed []string

	// Missing lists files that existed at snapshot time but are now absent
	// (or a directory, or fail to stat). Sorted.
	Missing []string

	// Errors holds stat errors by path, excluding not-exist errors.
	Errors map[string]error
}
426
427// ConfigStaleness checks whether any tracked config files have changed on disk
428// since the last snapshot. Returns dirty=true if any files changed or went
429// missing, along with sorted lists of affected paths. Stat errors are
430// captured in Errors map but still treated as non-existence for dirty detection.
431func (s *ConfigStore) ConfigStaleness() StalenessResult {
432 var result StalenessResult
433 result.Errors = make(map[string]error)
434
435 for _, path := range s.trackedConfigPaths {
436 snapshot, hadSnapshot := s.snapshots[path]
437
438 info, err := os.Stat(path)
439 exists := err == nil && !info.IsDir()
440
441 if err != nil && !os.IsNotExist(err) {
442 // Capture permission/IO errors separately from non-existence
443 result.Errors[path] = err
444 result.Dirty = true
445 }
446
447 if !exists {
448 if hadSnapshot && snapshot.Exists {
449 // File existed before but now missing
450 result.Missing = append(result.Missing, path)
451 result.Dirty = true
452 }
453 continue
454 }
455
456 // File exists now
457 if !hadSnapshot || !snapshot.Exists {
458 // File didn't exist before but does now
459 result.Changed = append(result.Changed, path)
460 result.Dirty = true
461 continue
462 }
463
464 // Check for content or metadata changes
465 if snapshot.Size != info.Size() || snapshot.ModTime != info.ModTime().UnixNano() {
466 result.Changed = append(result.Changed, path)
467 result.Dirty = true
468 }
469 }
470
471 // Sort for deterministic output
472 slices.Sort(result.Changed)
473 slices.Sort(result.Missing)
474
475 return result
476}
477
478// RefreshStalenessSnapshot captures fresh snapshots of all tracked config files.
479// Call this after reloading config to clear dirty state.
480func (s *ConfigStore) RefreshStalenessSnapshot() error {
481 if s.snapshots == nil {
482 s.snapshots = make(map[string]fileSnapshot)
483 }
484
485 for _, path := range s.trackedConfigPaths {
486 info, err := os.Stat(path)
487 exists := err == nil && !info.IsDir()
488
489 snapshot := fileSnapshot{
490 Path: path,
491 Exists: exists,
492 }
493
494 if exists {
495 snapshot.Size = info.Size()
496 snapshot.ModTime = info.ModTime().UnixNano()
497 }
498
499 s.snapshots[path] = snapshot
500 }
501
502 return nil
503}
504
505// CaptureStalenessSnapshot captures snapshots for the given paths, building the
506// tracked config paths list. Paths are deduplicated and normalized.
507func (s *ConfigStore) CaptureStalenessSnapshot(paths []string) {
508 // Build unique set of normalized paths
509 seen := make(map[string]struct{})
510 for _, p := range paths {
511 if p == "" {
512 continue
513 }
514 // Normalize path
515 abs, err := filepath.Abs(p)
516 if err != nil {
517 abs = p
518 }
519 seen[abs] = struct{}{}
520 }
521
522 // Also track workspace and global config paths if set
523 if s.workspacePath != "" {
524 abs, err := filepath.Abs(s.workspacePath)
525 if err == nil {
526 seen[abs] = struct{}{}
527 }
528 }
529 if s.globalDataPath != "" {
530 abs, err := filepath.Abs(s.globalDataPath)
531 if err == nil {
532 seen[abs] = struct{}{}
533 }
534 }
535
536 // Build sorted list for deterministic ordering
537 s.trackedConfigPaths = make([]string, 0, len(seen))
538 for p := range seen {
539 s.trackedConfigPaths = append(s.trackedConfigPaths, p)
540 }
541 slices.Sort(s.trackedConfigPaths)
542
543 // Capture initial snapshots
544 s.RefreshStalenessSnapshot()
545}
546
547// captureStalenessSnapshot is an alias for CaptureStalenessSnapshot for internal use.
548func (s *ConfigStore) captureStalenessSnapshot(paths []string) {
549 s.CaptureStalenessSnapshot(paths)
550}
551
552// ReloadFromDisk re-runs the config load/merge flow and updates the in-memory
553// config atomically. It rebuilds the staleness snapshot after successful reload.
554// On failure, the store state is rolled back to its previous state.
555func (s *ConfigStore) ReloadFromDisk(ctx context.Context) error {
556 if s.workingDir == "" {
557 return fmt.Errorf("cannot reload: working directory not set")
558 }
559
560 // Disable auto-reload during reload to prevent nested/re-entrant calls.
561 s.autoReloadDisabled = true
562 s.reloadInProgress = true
563 defer func() {
564 s.autoReloadDisabled = false
565 s.reloadInProgress = false
566 }()
567
568 configPaths := lookupConfigs(s.workingDir)
569 cfg, loadedPaths, err := loadFromConfigPaths(configPaths)
570 if err != nil {
571 return fmt.Errorf("failed to reload config: %w", err)
572 }
573
574 // Apply defaults (using existing data directory if set)
575 var dataDir string
576 if s.config != nil && s.config.Options != nil {
577 dataDir = s.config.Options.DataDirectory
578 }
579 cfg.setDefaults(s.workingDir, dataDir)
580
581 // Merge workspace config if present
582 workspacePath := filepath.Join(cfg.Options.DataDirectory, fmt.Sprintf("%s.json", appName))
583 if wsData, err := os.ReadFile(workspacePath); err == nil && len(wsData) > 0 {
584 merged, mergeErr := loadFromBytes(append([][]byte{mustMarshalConfig(cfg)}, wsData))
585 if mergeErr == nil {
586 dataDir := cfg.Options.DataDirectory
587 *cfg = *merged
588 cfg.setDefaults(s.workingDir, dataDir)
589 loadedPaths = append(loadedPaths, workspacePath)
590 }
591 }
592
593 // Preserve runtime overrides
594 overrides := s.overrides
595
596 // Reconfigure providers
597 env := env.New()
598 resolver := NewShellVariableResolver(env)
599 providers, err := Providers(cfg)
600 if err != nil {
601 return fmt.Errorf("failed to load providers during reload: %w", err)
602 }
603
604 if err := cfg.configureProviders(s, env, resolver, providers); err != nil {
605 return fmt.Errorf("failed to configure providers during reload: %w", err)
606 }
607
608 // Save current state for potential rollback
609 oldConfig := s.config
610 oldLoadedPaths := s.loadedPaths
611 oldResolver := s.resolver
612 oldKnownProviders := s.knownProviders
613 oldOverrides := s.overrides
614 oldWorkspacePath := s.workspacePath
615
616 // Update store state BEFORE running model/agent setup (so they see new config)
617 s.config = cfg
618 s.loadedPaths = loadedPaths
619 s.resolver = resolver
620 s.knownProviders = providers
621 s.overrides = overrides
622 s.workspacePath = workspacePath
623
624 // Mirror startup flow: setup models and agents against NEW config
625 var setupErr error
626 if !cfg.IsConfigured() {
627 slog.Warn("No providers configured after reload")
628 } else {
629 if err := configureSelectedModels(s, providers, false); err != nil {
630 setupErr = fmt.Errorf("failed to configure selected models during reload: %w", err)
631 } else {
632 s.SetupAgents()
633 }
634 }
635
636 // Rollback on setup failure
637 if setupErr != nil {
638 s.config = oldConfig
639 s.loadedPaths = oldLoadedPaths
640 s.resolver = oldResolver
641 s.knownProviders = oldKnownProviders
642 s.overrides = oldOverrides
643 s.workspacePath = oldWorkspacePath
644 return setupErr
645 }
646
647 // Rebuild staleness tracking
648 s.captureStalenessSnapshot(loadedPaths)
649
650 return nil
651}
652
653// autoReload conditionally reloads config from disk after writes.
654// It returns nil (no error) for expected skip cases: when auto-reload is
655// disabled during load/reload flows, or when working directory is not set
656// (e.g., during testing). Only actual reload failures return an error.
657func (s *ConfigStore) autoReload(ctx context.Context) error {
658 if s.autoReloadDisabled {
659 return nil // Expected skip: already in load/reload flow
660 }
661 if s.workingDir == "" {
662 return nil // Expected skip: working directory not set
663 }
664 return s.ReloadFromDisk(ctx)
665}