1package config
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "slices"
11 "sync"
12 "time"
13
14 "charm.land/catwalk/pkg/catwalk"
15 hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
16 "github.com/charmbracelet/crush/internal/env"
17 "github.com/charmbracelet/crush/internal/lock"
18 "github.com/charmbracelet/crush/internal/oauth"
19 "github.com/charmbracelet/crush/internal/oauth/copilot"
20 "github.com/charmbracelet/crush/internal/oauth/hyper"
21 "github.com/tidwall/gjson"
22 "github.com/tidwall/sjson"
23)
24
// configLockDeadline is the timeout placed on the context that lockConfig
// hands to lock.File when acquiring the cross-process config lock. Honest
// contention clears within a few seconds; waiting any longer suggests that
// something is wedged, so the acquisition gives up instead.
const configLockDeadline = 5 * time.Second
29
// fileSnapshot records the metadata observed for a single config file at
// the moment the snapshot was taken. Two snapshots of the same path are
// compared by Size and ModTime to decide whether the file changed.
type fileSnapshot struct {
	// Path is the file path the snapshot describes.
	Path string
	// Exists reports whether the path could be stat'ed and was not a
	// directory when the snapshot was taken.
	Exists bool
	// Size is the file size in bytes; it is left at zero when Exists is false.
	Size int64
	// ModTime is the modification time expressed as UnixNano; it is left at
	// zero when Exists is false.
	ModTime int64
}
37
// RuntimeOverrides carries per-session settings that are layered on top of
// the loaded Config. They are never written to disk and therefore last only
// as long as the process (or workspace) that owns the store.
type RuntimeOverrides struct {
	// SkipPermissionRequests is the per-session permission-request override.
	SkipPermissionRequests bool
}
44
45// ConfigStore is the single entry point for all config access. It owns the
46// pure-data Config, runtime state (working directory, resolver, known
47// providers), and persistence to both global and workspace config files.
48//
49// mu serialises all config file mutations (SetConfigFields,
50// RemoveConfigField, RefreshOAuthToken) to prevent both in-process
51// goroutine races and, together with the shared lock.File, cross-process
52// races on the config file.
53type ConfigStore struct {
54 config *Config
55 workingDir string
56 resolver VariableResolver
57 globalDataPath string // ~/.local/share/crush/crush.json
58 workspacePath string // .crush/crush.json
59 loadedPaths []string // config files that were successfully loaded
60 knownProviders []catwalk.Provider
61 overrides RuntimeOverrides
62 trackedConfigPaths []string // unique, normalized config file paths
63 snapshots map[string]fileSnapshot // path -> snapshot at last capture
64 autoReloadDisabled bool // set during load/reload to prevent re-entrancy
65 reloadInProgress bool // set during reload to avoid disk writes mid-reload
66
67 mu sync.Mutex
68}
69
70// Config returns the pure-data config struct (read-only after load).
71func (s *ConfigStore) Config() *Config {
72 return s.config
73}
74
75// WorkingDir returns the current working directory.
76func (s *ConfigStore) WorkingDir() string {
77 return s.workingDir
78}
79
80// Resolver returns the variable resolver.
81func (s *ConfigStore) Resolver() VariableResolver {
82 return s.resolver
83}
84
85// Resolve resolves a variable reference using the configured resolver.
86func (s *ConfigStore) Resolve(key string) (string, error) {
87 if s.resolver == nil {
88 return "", fmt.Errorf("no variable resolver configured")
89 }
90 return s.resolver.ResolveValue(key)
91}
92
93// KnownProviders returns the list of known providers.
94func (s *ConfigStore) KnownProviders() []catwalk.Provider {
95 return s.knownProviders
96}
97
98// SetupAgents configures the coder and task agents on the config.
99func (s *ConfigStore) SetupAgents() {
100 s.config.SetupAgents()
101}
102
103// Overrides returns the runtime overrides for this store.
104func (s *ConfigStore) Overrides() *RuntimeOverrides {
105 return &s.overrides
106}
107
108// LoadedPaths returns the config file paths that were successfully loaded.
109func (s *ConfigStore) LoadedPaths() []string {
110 return slices.Clone(s.loadedPaths)
111}
112
113// lockConfig acquires both the in-process mutex and a cross-process flock
114// on the config file for the given scope. Callers that need to do I/O
115// between reading and writing (e.g. an HTTP token exchange) must use
116// lockConfig explicitly rather than atomicWrite.
117//
118// The returned release function drops both locks. Callers must call it
119// as soon as the file access is complete — no I/O should be performed
120// while the lock is held.
121func (s *ConfigStore) lockConfig(scope Scope) (func(), error) {
122 s.mu.Lock()
123 path, err := s.configPath(scope)
124 if err != nil {
125 s.mu.Unlock()
126 return nil, err
127 }
128 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
129 s.mu.Unlock()
130 return nil, fmt.Errorf("create config directory: %w", err)
131 }
132 ctx, cancel := context.WithTimeout(context.Background(), configLockDeadline)
133 defer cancel()
134 release, err := lock.File(ctx, path+".lock")
135 if err != nil {
136 s.mu.Unlock()
137 return nil, fmt.Errorf("acquire config lock: %w", err)
138 }
139 return func() {
140 release()
141 s.mu.Unlock()
142 }, nil
143}
144
145// atomicWrite handles the lock-read-transform-write-unlock cycle for
146// config file mutations. The fn callback receives the current file
147// contents (raw bytes, or {} if the file is missing) and must return the
148// new contents. fn must be pure — no I/O, no network calls.
149func (s *ConfigStore) atomicWrite(scope Scope, fn func(current []byte) ([]byte, error)) error {
150 unlock, err := s.lockConfig(scope)
151 if err != nil {
152 return err
153 }
154 defer unlock()
155
156 path, err := s.configPath(scope)
157 if err != nil {
158 return err
159 }
160
161 data, err := os.ReadFile(path)
162 if err != nil {
163 if os.IsNotExist(err) {
164 data = []byte("{}")
165 } else {
166 return fmt.Errorf("read config file: %w", err)
167 }
168 }
169
170 newData, err := fn(data)
171 if err != nil {
172 return err
173 }
174
175 return atomicWriteFile(path, newData, 0o600)
176}
177
178// configPath returns the file path for the given scope.
179func (s *ConfigStore) configPath(scope Scope) (string, error) {
180 switch scope {
181 case ScopeWorkspace:
182 if s.workspacePath == "" {
183 return "", ErrNoWorkspaceConfig
184 }
185 return s.workspacePath, nil
186 default:
187 return s.globalDataPath, nil
188 }
189}
190
191// HasConfigField checks whether a key exists in the config file for the given
192// scope.
193func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
194 path, err := s.configPath(scope)
195 if err != nil {
196 return false
197 }
198 data, err := os.ReadFile(path)
199 if err != nil {
200 return false
201 }
202 return gjson.Get(string(data), key).Exists()
203}
204
205// SetConfigField sets a key/value pair in the config file for the given scope.
206// After a successful write, it automatically reloads config to keep in-memory
207// state fresh.
208func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
209 return s.SetConfigFields(scope, map[string]any{key: value})
210}
211
212// SetConfigFields sets multiple key/value pairs in the config file for the given
213// scope in a single write. After a successful write, it automatically reloads
214// config to keep in-memory state fresh. This is preferred over multiple
215// SetConfigField calls when writing several fields atomically to avoid
216// intermediate reloads with partial state.
217//
218// The write is protected by an in-process mutex and a cross-process flock
219// to prevent races between concurrent writers in different processes.
220func (s *ConfigStore) SetConfigFields(scope Scope, kv map[string]any) error {
221 // Sort keys for deterministic output regardless of map iteration
222 // order. This also ensures consistent results when callers pass
223 // overlapping JSONPath keys (e.g. "a" and "a.b").
224 keys := make([]string, 0, len(kv))
225 for k := range kv {
226 keys = append(keys, k)
227 }
228 slices.Sort(keys)
229
230 err := s.atomicWrite(scope, func(data []byte) ([]byte, error) {
231 v := string(data)
232 for _, key := range keys {
233 var sErr error
234 v, sErr = sjson.Set(v, key, kv[key])
235 if sErr != nil {
236 return nil, fmt.Errorf("failed to set config field %s: %w", key, sErr)
237 }
238 }
239 return []byte(v), nil
240 })
241 if err != nil {
242 return err
243 }
244
245 // Auto-reload to keep in-memory state fresh after config edits.
246 // We use context.Background() since this is an internal operation that
247 // shouldn't be cancelled by user context.
248 if err := s.autoReload(context.Background()); err != nil {
249 // Log warning but don't fail the write - disk is already updated.
250 slog.Warn("Config file updated but failed to reload in-memory state", "error", err)
251 }
252
253 return nil
254}
255
256// RemoveConfigField removes a key from the config file for the given scope.
257// After a successful write, it automatically reloads config to keep in-memory
258// state fresh.
259//
260// The write is protected by an in-process mutex and a cross-process flock.
261func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
262 err := s.atomicWrite(scope, func(data []byte) ([]byte, error) {
263 v, sErr := sjson.Delete(string(data), key)
264 if sErr != nil {
265 return nil, fmt.Errorf("failed to delete config field %s: %w", key, sErr)
266 }
267 return []byte(v), nil
268 })
269 if err != nil {
270 return err
271 }
272
273 if err := s.autoReload(context.Background()); err != nil {
274 slog.Warn("Config file updated but failed to reload in-memory state", "error", err)
275 }
276
277 return nil
278}
279
280// UpdatePreferredModel updates the preferred model for the given type and
281// persists it to the config file at the given scope.
282func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
283 s.config.Models[modelType] = model
284 if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
285 return fmt.Errorf("failed to update preferred model: %w", err)
286 }
287 if err := s.recordRecentModel(scope, modelType, model); err != nil {
288 return err
289 }
290 return nil
291}
292
293// SetCompactMode sets the compact mode setting and persists it.
294func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
295 if s.config.Options == nil {
296 s.config.Options = &Options{}
297 }
298 s.config.Options.TUI.CompactMode = enabled
299 return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
300}
301
302// SetTransparentBackground sets the transparent background setting and persists it.
303func (s *ConfigStore) SetTransparentBackground(scope Scope, enabled bool) error {
304 if s.config.Options == nil {
305 s.config.Options = &Options{}
306 }
307 s.config.Options.TUI.Transparent = &enabled
308 return s.SetConfigField(scope, "options.tui.transparent", enabled)
309}
310
311// SetProviderAPIKey sets the API key for a provider and persists it.
312func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
313 var providerConfig ProviderConfig
314 var exists bool
315 var setKeyOrToken func()
316
317 switch v := apiKey.(type) {
318 case string:
319 if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
320 return fmt.Errorf("failed to save api key to config file: %w", err)
321 }
322 setKeyOrToken = func() { providerConfig.APIKey = v }
323 case *oauth.Token:
324 if err := s.SetConfigFields(scope, map[string]any{
325 fmt.Sprintf("providers.%s.api_key", providerID): v.AccessToken,
326 fmt.Sprintf("providers.%s.oauth", providerID): v,
327 }); err != nil {
328 return err
329 }
330 setKeyOrToken = func() {
331 providerConfig.APIKey = v.AccessToken
332 providerConfig.OAuthToken = v
333 switch providerID {
334 case string(catwalk.InferenceProviderCopilot):
335 providerConfig.SetupGitHubCopilot()
336 }
337 }
338 }
339
340 providerConfig, exists = s.config.Providers.Get(providerID)
341 if exists {
342 setKeyOrToken()
343 s.config.Providers.Set(providerID, providerConfig)
344 return nil
345 }
346
347 var foundProvider *catwalk.Provider
348 for _, p := range s.knownProviders {
349 if string(p.ID) == providerID {
350 foundProvider = &p
351 break
352 }
353 }
354
355 if foundProvider != nil {
356 providerConfig = ProviderConfig{
357 ID: providerID,
358 Name: foundProvider.Name,
359 BaseURL: foundProvider.APIEndpoint,
360 Type: foundProvider.Type,
361 Disable: false,
362 ExtraHeaders: make(map[string]string),
363 ExtraParams: make(map[string]string),
364 Models: foundProvider.Models,
365 }
366 setKeyOrToken()
367 } else {
368 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
369 }
370 s.config.Providers.Set(providerID, providerConfig)
371 return nil
372}
373
374// RefreshOAuthToken refreshes the OAuth token for the given provider.
375//
376// It uses two-phase locking: the pre-check (reading the config file to
377// see if another process already refreshed) happens under the config
378// lock, then the HTTP exchange runs without any lock held, and finally
379// the result is persisted via SetConfigFields (which acquires the lock
380// internally). If the exchange fails — e.g. because another process
381// already rotated the refresh token — the disk is re-checked under lock
382// to recover the other process's token.
383func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
384 providerConfig, exists := s.config.Providers.Get(providerID)
385 if !exists {
386 return fmt.Errorf("provider %s not found", providerID)
387 }
388
389 if providerConfig.OAuthToken == nil {
390 return fmt.Errorf("provider %s does not have an OAuth token", providerID)
391 }
392
393 // Phase 1: Pre-check under lock — did another process already refresh?
394 release, lockErr := s.lockConfig(scope)
395 if lockErr != nil {
396 slog.Warn("Failed to lock config for pre-check, proceeding anyway", "provider", providerID, "error", lockErr)
397 } else {
398 diskToken, err := s.loadTokenFromDisk(scope, providerID)
399 release()
400 if err != nil {
401 slog.Warn("Failed to read token from config file", "provider", providerID, "error", err)
402 } else if diskToken != nil && !diskToken.IsExpired() && diskToken.AccessToken != providerConfig.OAuthToken.AccessToken {
403 slog.Info("Using token refreshed by another session", "provider", providerID)
404 return s.applyToken(providerConfig, diskToken, providerID)
405 }
406 }
407
408 // Phase 2: HTTP exchange — no lock held.
409 var refreshedToken *oauth.Token
410 var refreshErr error
411 switch providerID {
412 case string(catwalk.InferenceProviderCopilot):
413 refreshedToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
414 case hyperp.Name:
415 refreshedToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
416 default:
417 return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
418 }
419 if refreshErr != nil {
420 // Phase 3: Fallback — re-check disk under lock. The exchange may
421 // have failed because another process already rotated the refresh
422 // token.
423 if release, lockErr := s.lockConfig(scope); lockErr == nil {
424 diskToken, diskErr := s.loadTokenFromDisk(scope, providerID)
425 release()
426 if diskErr == nil &&
427 diskToken != nil &&
428 !diskToken.IsExpired() &&
429 diskToken.AccessToken != providerConfig.OAuthToken.AccessToken {
430 slog.Info("Using token refreshed by another session after exchange failure", "provider", providerID)
431 return s.applyToken(providerConfig, diskToken, providerID)
432 }
433 }
434 return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
435 }
436
437 slog.Info("Successfully refreshed OAuth token", "provider", providerID)
438 providerConfig.OAuthToken = refreshedToken
439 providerConfig.APIKey = refreshedToken.AccessToken
440
441 switch providerID {
442 case string(catwalk.InferenceProviderCopilot):
443 providerConfig.SetupGitHubCopilot()
444 }
445
446 s.config.Providers.Set(providerID, providerConfig)
447
448 if err := s.SetConfigFields(scope, map[string]any{
449 fmt.Sprintf("providers.%s.api_key", providerID): refreshedToken.AccessToken,
450 fmt.Sprintf("providers.%s.oauth", providerID): refreshedToken,
451 }); err != nil {
452 return fmt.Errorf("failed to persist refreshed token: %w", err)
453 }
454
455 return nil
456}
457
458// applyToken updates the in-memory provider config with the given token.
459func (s *ConfigStore) applyToken(providerConfig ProviderConfig, token *oauth.Token, providerID string) error {
460 providerConfig.OAuthToken = token
461 providerConfig.APIKey = token.AccessToken
462 if providerID == string(catwalk.InferenceProviderCopilot) {
463 providerConfig.SetupGitHubCopilot()
464 }
465 s.config.Providers.Set(providerID, providerConfig)
466 return nil
467}
468
469// loadTokenFromDisk reads the OAuth token for the given provider from the
470// config file on disk. Returns nil if the token is not found or matches the
471// current in-memory token.
472func (s *ConfigStore) loadTokenFromDisk(scope Scope, providerID string) (*oauth.Token, error) {
473 path, err := s.configPath(scope)
474 if err != nil {
475 return nil, err
476 }
477
478 data, err := os.ReadFile(path)
479 if err != nil {
480 if os.IsNotExist(err) {
481 return nil, nil
482 }
483 return nil, err
484 }
485
486 oauthKey := fmt.Sprintf("providers.%s.oauth", providerID)
487 oauthResult := gjson.Get(string(data), oauthKey)
488 if !oauthResult.Exists() {
489 return nil, nil
490 }
491
492 var token oauth.Token
493 if err := json.Unmarshal([]byte(oauthResult.Raw), &token); err != nil {
494 return nil, err
495 }
496
497 if token.AccessToken == "" {
498 return nil, nil
499 }
500
501 return &token, nil
502}
503
504// recordRecentModel records a model in the recent models list.
505func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
506 if model.Provider == "" || model.Model == "" {
507 return nil
508 }
509
510 if s.config.RecentModels == nil {
511 s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
512 }
513
514 eq := func(a, b SelectedModel) bool {
515 return a.Provider == b.Provider && a.Model == b.Model
516 }
517
518 entry := SelectedModel{
519 Provider: model.Provider,
520 Model: model.Model,
521 }
522
523 current := s.config.RecentModels[modelType]
524 withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
525 return eq(existing, entry)
526 })
527
528 updated := append([]SelectedModel{entry}, withoutCurrent...)
529 if len(updated) > maxRecentModelsPerType {
530 updated = updated[:maxRecentModelsPerType]
531 }
532
533 if slices.EqualFunc(current, updated, eq) {
534 return nil
535 }
536
537 s.config.RecentModels[modelType] = updated
538
539 if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
540 return fmt.Errorf("failed to persist recent models: %w", err)
541 }
542
543 return nil
544}
545
546// NewTestStore creates a ConfigStore for testing purposes.
547func NewTestStore(cfg *Config, loadedPaths ...string) *ConfigStore {
548 return &ConfigStore{
549 config: cfg,
550 loadedPaths: loadedPaths,
551 }
552}
553
554// ImportCopilot attempts to import a GitHub Copilot token from disk.
555func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
556 if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
557 return nil, false
558 }
559
560 diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
561 if !hasDiskToken {
562 return nil, false
563 }
564
565 slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
566 token, err := copilot.RefreshToken(context.TODO(), diskToken)
567 if err != nil {
568 slog.Error("Unable to import GitHub Copilot token", "error", err)
569 return nil, false
570 }
571
572 if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
573 return token, false
574 }
575
576 if err := s.SetConfigFields(ScopeGlobal, map[string]any{
577 "providers.copilot.api_key": token.AccessToken,
578 "providers.copilot.oauth": token,
579 }); err != nil {
580 slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
581 }
582
583 slog.Info("GitHub Copilot successfully imported")
584 return token, true
585}
586
// StalenessResult is the outcome of a staleness check over the tracked
// config files.
type StalenessResult struct {
	// Dirty is true when any file changed, went missing, or could not be
	// stat'ed.
	Dirty bool
	// Changed lists files that appeared or whose size or modification time
	// differs from the snapshot.
	Changed []string
	// Missing lists files that existed in the snapshot but no longer exist.
	Missing []string
	// Errors holds stat errors by path.
	Errors map[string]error
}
594
595// ConfigStaleness checks whether any tracked config files have changed on disk
596// since the last snapshot. Returns dirty=true if any files changed or went
597// missing, along with sorted lists of affected paths. Stat errors are
598// captured in Errors map but still treated as non-existence for dirty detection.
599func (s *ConfigStore) ConfigStaleness() StalenessResult {
600 var result StalenessResult
601 result.Errors = make(map[string]error)
602
603 for _, path := range s.trackedConfigPaths {
604 snapshot, hadSnapshot := s.snapshots[path]
605
606 info, err := os.Stat(path)
607 exists := err == nil && !info.IsDir()
608
609 if err != nil && !os.IsNotExist(err) {
610 // Capture permission/IO errors separately from non-existence
611 result.Errors[path] = err
612 result.Dirty = true
613 }
614
615 if !exists {
616 if hadSnapshot && snapshot.Exists {
617 // File existed before but now missing
618 result.Missing = append(result.Missing, path)
619 result.Dirty = true
620 }
621 continue
622 }
623
624 // File exists now
625 if !hadSnapshot || !snapshot.Exists {
626 // File didn't exist before but does now
627 result.Changed = append(result.Changed, path)
628 result.Dirty = true
629 continue
630 }
631
632 // Check for content or metadata changes
633 if snapshot.Size != info.Size() || snapshot.ModTime != info.ModTime().UnixNano() {
634 result.Changed = append(result.Changed, path)
635 result.Dirty = true
636 }
637 }
638
639 // Sort for deterministic output
640 slices.Sort(result.Changed)
641 slices.Sort(result.Missing)
642
643 return result
644}
645
646// RefreshStalenessSnapshot captures fresh snapshots of all tracked config files.
647// Call this after reloading config to clear dirty state.
648func (s *ConfigStore) RefreshStalenessSnapshot() error {
649 if s.snapshots == nil {
650 s.snapshots = make(map[string]fileSnapshot)
651 }
652
653 for _, path := range s.trackedConfigPaths {
654 info, err := os.Stat(path)
655 exists := err == nil && !info.IsDir()
656
657 snapshot := fileSnapshot{
658 Path: path,
659 Exists: exists,
660 }
661
662 if exists {
663 snapshot.Size = info.Size()
664 snapshot.ModTime = info.ModTime().UnixNano()
665 }
666
667 s.snapshots[path] = snapshot
668 }
669
670 return nil
671}
672
673// CaptureStalenessSnapshot captures snapshots for the given paths, building the
674// tracked config paths list. Paths are deduplicated and normalized.
675func (s *ConfigStore) CaptureStalenessSnapshot(paths []string) {
676 // Build unique set of normalized paths
677 seen := make(map[string]struct{})
678 for _, p := range paths {
679 if p == "" {
680 continue
681 }
682 // Normalize path
683 abs, err := filepath.Abs(p)
684 if err != nil {
685 abs = p
686 }
687 seen[abs] = struct{}{}
688 }
689
690 // Also track workspace and global config paths if set
691 if s.workspacePath != "" {
692 abs, err := filepath.Abs(s.workspacePath)
693 if err == nil {
694 seen[abs] = struct{}{}
695 }
696 }
697 if s.globalDataPath != "" {
698 abs, err := filepath.Abs(s.globalDataPath)
699 if err == nil {
700 seen[abs] = struct{}{}
701 }
702 }
703
704 // Build sorted list for deterministic ordering
705 s.trackedConfigPaths = make([]string, 0, len(seen))
706 for p := range seen {
707 s.trackedConfigPaths = append(s.trackedConfigPaths, p)
708 }
709 slices.Sort(s.trackedConfigPaths)
710
711 // Capture initial snapshots
712 s.RefreshStalenessSnapshot()
713}
714
715// captureStalenessSnapshot is an alias for CaptureStalenessSnapshot for internal use.
716func (s *ConfigStore) captureStalenessSnapshot(paths []string) {
717 s.CaptureStalenessSnapshot(paths)
718}
719
720// ReloadFromDisk re-runs the config load/merge flow and updates the in-memory
721// config atomically. It rebuilds the staleness snapshot after successful reload.
722// On failure, the store state is rolled back to its previous state.
723func (s *ConfigStore) ReloadFromDisk(ctx context.Context) error {
724 if s.workingDir == "" {
725 return fmt.Errorf("cannot reload: working directory not set")
726 }
727
728 // Disable auto-reload during reload to prevent nested/re-entrant calls.
729 s.autoReloadDisabled = true
730 s.reloadInProgress = true
731 defer func() {
732 s.autoReloadDisabled = false
733 s.reloadInProgress = false
734 }()
735
736 // Migrate deprecated disable_notifications before reloading config.
737 migrateDisableNotifications()
738
739 configPaths := lookupConfigs(s.workingDir)
740 cfg, loadedPaths, err := loadFromConfigPaths(configPaths)
741 if err != nil {
742 return fmt.Errorf("failed to reload config: %w", err)
743 }
744
745 // Apply defaults (using existing data directory if set)
746 var dataDir string
747 if s.config != nil && s.config.Options != nil {
748 dataDir = s.config.Options.DataDirectory
749 }
750 cfg.setDefaults(s.workingDir, dataDir)
751
752 // Merge workspace config if present
753 workspacePath := filepath.Join(cfg.Options.DataDirectory, fmt.Sprintf("%s.json", appName))
754 if wsData, err := os.ReadFile(workspacePath); err == nil && len(wsData) > 0 {
755 if !json.Valid(wsData) {
756 return fmt.Errorf("invalid JSON in config file %s", workspacePath)
757 }
758 merged, mergeErr := loadFromBytes(append([][]byte{mustMarshalConfig(cfg)}, wsData))
759 if mergeErr == nil {
760 dataDir := cfg.Options.DataDirectory
761 *cfg = *merged
762 cfg.setDefaults(s.workingDir, dataDir)
763 loadedPaths = append(loadedPaths, workspacePath)
764 }
765 }
766
767 // Validate hooks after all config merging is complete so matcher
768 // regexes are recompiled on the reloaded config (mirrors Load).
769 if err := cfg.ValidateHooks(); err != nil {
770 return fmt.Errorf("invalid hook configuration on reload: %w", err)
771 }
772
773 // Preserve runtime overrides
774 overrides := s.overrides
775
776 // Reconfigure providers
777 env := env.New()
778 resolver := NewShellVariableResolver(env)
779 providers, err := Providers(cfg)
780 if err != nil {
781 return fmt.Errorf("failed to load providers during reload: %w", err)
782 }
783
784 if err := cfg.configureProviders(s, env, resolver, providers); err != nil {
785 return fmt.Errorf("failed to configure providers during reload: %w", err)
786 }
787
788 // Save current state for potential rollback
789 oldConfig := s.config
790 oldLoadedPaths := s.loadedPaths
791 oldResolver := s.resolver
792 oldKnownProviders := s.knownProviders
793 oldOverrides := s.overrides
794 oldWorkspacePath := s.workspacePath
795
796 // Update store state BEFORE running model/agent setup (so they see new config)
797 s.config = cfg
798 s.loadedPaths = loadedPaths
799 s.resolver = resolver
800 s.knownProviders = providers
801 s.overrides = overrides
802 s.workspacePath = workspacePath
803
804 // Mirror startup flow: setup models and agents against NEW config
805 var setupErr error
806 if !cfg.IsConfigured() {
807 slog.Warn("No providers configured after reload")
808 } else {
809 if err := configureSelectedModels(s, providers, false); err != nil {
810 setupErr = fmt.Errorf("failed to configure selected models during reload: %w", err)
811 } else {
812 s.SetupAgents()
813 }
814 }
815
816 // Rollback on setup failure
817 if setupErr != nil {
818 s.config = oldConfig
819 s.loadedPaths = oldLoadedPaths
820 s.resolver = oldResolver
821 s.knownProviders = oldKnownProviders
822 s.overrides = oldOverrides
823 s.workspacePath = oldWorkspacePath
824 return setupErr
825 }
826
827 // Rebuild staleness tracking
828 s.captureStalenessSnapshot(loadedPaths)
829
830 return nil
831}
832
833// autoReload conditionally reloads config from disk after writes.
834// It returns nil (no error) for expected skip cases: when auto-reload is
835// disabled during load/reload flows, or when working directory is not set
836// (e.g., during testing). Only actual reload failures return an error.
837func (s *ConfigStore) autoReload(ctx context.Context) error {
838 if s.autoReloadDisabled {
839 return nil // Expected skip: already in load/reload flow
840 }
841 if s.workingDir == "" {
842 return nil // Expected skip: working directory not set
843 }
844 return s.ReloadFromDisk(ctx)
845}