store.go

  1package config
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"slices"
 11
 12	"charm.land/catwalk/pkg/catwalk"
 13	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 14	"github.com/charmbracelet/crush/internal/env"
 15	"github.com/charmbracelet/crush/internal/oauth"
 16	"github.com/charmbracelet/crush/internal/oauth/copilot"
 17	"github.com/charmbracelet/crush/internal/oauth/hyper"
 18	"github.com/tidwall/gjson"
 19	"github.com/tidwall/sjson"
 20)
 21
 22// fileSnapshot captures metadata about a config file at a point in time.
 23type fileSnapshot struct {
 24	Path    string
 25	Exists  bool
 26	Size    int64
 27	ModTime int64 // UnixNano
 28}
 29
 30// RuntimeOverrides holds per-session settings that are never persisted to
 31// disk. They are applied on top of the loaded Config and survive only for
 32// the lifetime of the process (or workspace).
 33type RuntimeOverrides struct {
 34	SkipPermissionRequests bool
 35}
 36
 37// ConfigStore is the single entry point for all config access. It owns the
 38// pure-data Config, runtime state (working directory, resolver, known
 39// providers), and persistence to both global and workspace config files.
 40type ConfigStore struct {
 41	config             *Config
 42	workingDir         string
 43	resolver           VariableResolver
 44	globalDataPath     string   // ~/.local/share/crush/crush.json
 45	workspacePath      string   // .crush/crush.json
 46	loadedPaths        []string // config files that were successfully loaded
 47	knownProviders     []catwalk.Provider
 48	overrides          RuntimeOverrides
 49	trackedConfigPaths []string                // unique, normalized config file paths
 50	snapshots          map[string]fileSnapshot // path -> snapshot at last capture
 51	autoReloadDisabled bool                    // set during load/reload to prevent re-entrancy
 52	reloadInProgress   bool                    // set during reload to avoid disk writes mid-reload
 53}
 54
 55// Config returns the pure-data config struct (read-only after load).
 56func (s *ConfigStore) Config() *Config {
 57	return s.config
 58}
 59
 60// WorkingDir returns the current working directory.
 61func (s *ConfigStore) WorkingDir() string {
 62	return s.workingDir
 63}
 64
 65// Resolver returns the variable resolver.
 66func (s *ConfigStore) Resolver() VariableResolver {
 67	return s.resolver
 68}
 69
 70// Resolve resolves a variable reference using the configured resolver.
 71func (s *ConfigStore) Resolve(key string) (string, error) {
 72	if s.resolver == nil {
 73		return "", fmt.Errorf("no variable resolver configured")
 74	}
 75	return s.resolver.ResolveValue(key)
 76}
 77
 78// KnownProviders returns the list of known providers.
 79func (s *ConfigStore) KnownProviders() []catwalk.Provider {
 80	return s.knownProviders
 81}
 82
 83// SetupAgents configures the coder and task agents on the config.
 84func (s *ConfigStore) SetupAgents() {
 85	s.config.SetupAgents()
 86}
 87
 88// Overrides returns the runtime overrides for this store.
 89func (s *ConfigStore) Overrides() *RuntimeOverrides {
 90	return &s.overrides
 91}
 92
 93// LoadedPaths returns the config file paths that were successfully loaded.
 94func (s *ConfigStore) LoadedPaths() []string {
 95	return slices.Clone(s.loadedPaths)
 96}
 97
 98// configPath returns the file path for the given scope.
 99func (s *ConfigStore) configPath(scope Scope) (string, error) {
100	switch scope {
101	case ScopeWorkspace:
102		if s.workspacePath == "" {
103			return "", ErrNoWorkspaceConfig
104		}
105		return s.workspacePath, nil
106	default:
107		return s.globalDataPath, nil
108	}
109}
110
111// HasConfigField checks whether a key exists in the config file for the given
112// scope.
113func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
114	path, err := s.configPath(scope)
115	if err != nil {
116		return false
117	}
118	data, err := os.ReadFile(path)
119	if err != nil {
120		return false
121	}
122	return gjson.Get(string(data), key).Exists()
123}
124
125// SetConfigField sets a key/value pair in the config file for the given scope.
126// After a successful write, it automatically reloads config to keep in-memory
127// state fresh.
128func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
129	path, err := s.configPath(scope)
130	if err != nil {
131		return fmt.Errorf("%s: %w", key, err)
132	}
133	data, err := os.ReadFile(path)
134	if err != nil {
135		if os.IsNotExist(err) {
136			data = []byte("{}")
137		} else {
138			return fmt.Errorf("failed to read config file: %w", err)
139		}
140	}
141
142	newValue, err := sjson.Set(string(data), key, value)
143	if err != nil {
144		return fmt.Errorf("failed to set config field %s: %w", key, err)
145	}
146	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
147		return fmt.Errorf("failed to create config directory %q: %w", path, err)
148	}
149	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
150		return fmt.Errorf("failed to write config file: %w", err)
151	}
152
153	// Auto-reload to keep in-memory state fresh after config edits.
154	// We use context.Background() since this is an internal operation that
155	// shouldn't be cancelled by user context.
156	if err := s.autoReload(context.Background()); err != nil {
157		// Log warning but don't fail the write - disk is already updated.
158		slog.Warn("Config file updated but failed to reload in-memory state", "error", err)
159	}
160
161	return nil
162}
163
164// RemoveConfigField removes a key from the config file for the given scope.
165// After a successful write, it automatically reloads config to keep in-memory
166// state fresh.
167func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
168	path, err := s.configPath(scope)
169	if err != nil {
170		return fmt.Errorf("%s: %w", key, err)
171	}
172	data, err := os.ReadFile(path)
173	if err != nil {
174		return fmt.Errorf("failed to read config file: %w", err)
175	}
176
177	newValue, err := sjson.Delete(string(data), key)
178	if err != nil {
179		return fmt.Errorf("failed to delete config field %s: %w", key, err)
180	}
181	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
182		return fmt.Errorf("failed to create config directory %q: %w", path, err)
183	}
184	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
185		return fmt.Errorf("failed to write config file: %w", err)
186	}
187
188	// Auto-reload to keep in-memory state fresh after config edits.
189	if err := s.autoReload(context.Background()); err != nil {
190		slog.Warn("Config file updated but failed to reload in-memory state", "error", err)
191	}
192
193	return nil
194}
195
196// UpdatePreferredModel updates the preferred model for the given type and
197// persists it to the config file at the given scope.
198func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
199	s.config.Models[modelType] = model
200	if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
201		return fmt.Errorf("failed to update preferred model: %w", err)
202	}
203	if err := s.recordRecentModel(scope, modelType, model); err != nil {
204		return err
205	}
206	return nil
207}
208
209// SetCompactMode sets the compact mode setting and persists it.
210func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
211	if s.config.Options == nil {
212		s.config.Options = &Options{}
213	}
214	s.config.Options.TUI.CompactMode = enabled
215	return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
216}
217
218// SetTransparentBackground sets the transparent background setting and persists it.
219func (s *ConfigStore) SetTransparentBackground(scope Scope, enabled bool) error {
220	if s.config.Options == nil {
221		s.config.Options = &Options{}
222	}
223	s.config.Options.TUI.Transparent = &enabled
224	return s.SetConfigField(scope, "options.tui.transparent", enabled)
225}
226
227// SetProviderAPIKey sets the API key for a provider and persists it.
228func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
229	var providerConfig ProviderConfig
230	var exists bool
231	var setKeyOrToken func()
232
233	switch v := apiKey.(type) {
234	case string:
235		if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
236			return fmt.Errorf("failed to save api key to config file: %w", err)
237		}
238		setKeyOrToken = func() { providerConfig.APIKey = v }
239	case *oauth.Token:
240		if err := cmp.Or(
241			s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
242			s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
243		); err != nil {
244			return err
245		}
246		setKeyOrToken = func() {
247			providerConfig.APIKey = v.AccessToken
248			providerConfig.OAuthToken = v
249			switch providerID {
250			case string(catwalk.InferenceProviderCopilot):
251				providerConfig.SetupGitHubCopilot()
252			}
253		}
254	}
255
256	providerConfig, exists = s.config.Providers.Get(providerID)
257	if exists {
258		setKeyOrToken()
259		s.config.Providers.Set(providerID, providerConfig)
260		return nil
261	}
262
263	var foundProvider *catwalk.Provider
264	for _, p := range s.knownProviders {
265		if string(p.ID) == providerID {
266			foundProvider = &p
267			break
268		}
269	}
270
271	if foundProvider != nil {
272		providerConfig = ProviderConfig{
273			ID:           providerID,
274			Name:         foundProvider.Name,
275			BaseURL:      foundProvider.APIEndpoint,
276			Type:         foundProvider.Type,
277			Disable:      false,
278			ExtraHeaders: make(map[string]string),
279			ExtraParams:  make(map[string]string),
280			Models:       foundProvider.Models,
281		}
282		setKeyOrToken()
283	} else {
284		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
285	}
286	s.config.Providers.Set(providerID, providerConfig)
287	return nil
288}
289
290// RefreshOAuthToken refreshes the OAuth token for the given provider.
291func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
292	providerConfig, exists := s.config.Providers.Get(providerID)
293	if !exists {
294		return fmt.Errorf("provider %s not found", providerID)
295	}
296
297	if providerConfig.OAuthToken == nil {
298		return fmt.Errorf("provider %s does not have an OAuth token", providerID)
299	}
300
301	var newToken *oauth.Token
302	var refreshErr error
303	switch providerID {
304	case string(catwalk.InferenceProviderCopilot):
305		newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
306	case hyperp.Name:
307		newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
308	default:
309		return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
310	}
311	if refreshErr != nil {
312		return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
313	}
314
315	slog.Info("Successfully refreshed OAuth token", "provider", providerID)
316	providerConfig.OAuthToken = newToken
317	providerConfig.APIKey = newToken.AccessToken
318
319	switch providerID {
320	case string(catwalk.InferenceProviderCopilot):
321		providerConfig.SetupGitHubCopilot()
322	}
323
324	s.config.Providers.Set(providerID, providerConfig)
325
326	if err := cmp.Or(
327		s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
328		s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
329	); err != nil {
330		return fmt.Errorf("failed to persist refreshed token: %w", err)
331	}
332
333	return nil
334}
335
336// recordRecentModel records a model in the recent models list.
337func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
338	if model.Provider == "" || model.Model == "" {
339		return nil
340	}
341
342	if s.config.RecentModels == nil {
343		s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
344	}
345
346	eq := func(a, b SelectedModel) bool {
347		return a.Provider == b.Provider && a.Model == b.Model
348	}
349
350	entry := SelectedModel{
351		Provider: model.Provider,
352		Model:    model.Model,
353	}
354
355	current := s.config.RecentModels[modelType]
356	withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
357		return eq(existing, entry)
358	})
359
360	updated := append([]SelectedModel{entry}, withoutCurrent...)
361	if len(updated) > maxRecentModelsPerType {
362		updated = updated[:maxRecentModelsPerType]
363	}
364
365	if slices.EqualFunc(current, updated, eq) {
366		return nil
367	}
368
369	s.config.RecentModels[modelType] = updated
370
371	if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
372		return fmt.Errorf("failed to persist recent models: %w", err)
373	}
374
375	return nil
376}
377
378// NewTestStore creates a ConfigStore for testing purposes.
379func NewTestStore(cfg *Config, loadedPaths ...string) *ConfigStore {
380	return &ConfigStore{
381		config:      cfg,
382		loadedPaths: loadedPaths,
383	}
384}
385
386// ImportCopilot attempts to import a GitHub Copilot token from disk.
387func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
388	if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
389		return nil, false
390	}
391
392	diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
393	if !hasDiskToken {
394		return nil, false
395	}
396
397	slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
398	token, err := copilot.RefreshToken(context.TODO(), diskToken)
399	if err != nil {
400		slog.Error("Unable to import GitHub Copilot token", "error", err)
401		return nil, false
402	}
403
404	if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
405		return token, false
406	}
407
408	if err := cmp.Or(
409		s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
410		s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
411	); err != nil {
412		slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
413	}
414
415	slog.Info("GitHub Copilot successfully imported")
416	return token, true
417}
418
// StalenessResult is the report produced by ConfigStaleness.
type StalenessResult struct {
	Dirty   bool             // any tracked file changed, went missing, or failed to stat
	Changed []string         // sorted; files that appeared or whose size/mtime differ
	Missing []string         // sorted; files present at the last snapshot but absent now
	Errors  map[string]error // stat errors other than "not exist", keyed by path
}
426
427// ConfigStaleness checks whether any tracked config files have changed on disk
428// since the last snapshot. Returns dirty=true if any files changed or went
429// missing, along with sorted lists of affected paths. Stat errors are
430// captured in Errors map but still treated as non-existence for dirty detection.
431func (s *ConfigStore) ConfigStaleness() StalenessResult {
432	var result StalenessResult
433	result.Errors = make(map[string]error)
434
435	for _, path := range s.trackedConfigPaths {
436		snapshot, hadSnapshot := s.snapshots[path]
437
438		info, err := os.Stat(path)
439		exists := err == nil && !info.IsDir()
440
441		if err != nil && !os.IsNotExist(err) {
442			// Capture permission/IO errors separately from non-existence
443			result.Errors[path] = err
444			result.Dirty = true
445		}
446
447		if !exists {
448			if hadSnapshot && snapshot.Exists {
449				// File existed before but now missing
450				result.Missing = append(result.Missing, path)
451				result.Dirty = true
452			}
453			continue
454		}
455
456		// File exists now
457		if !hadSnapshot || !snapshot.Exists {
458			// File didn't exist before but does now
459			result.Changed = append(result.Changed, path)
460			result.Dirty = true
461			continue
462		}
463
464		// Check for content or metadata changes
465		if snapshot.Size != info.Size() || snapshot.ModTime != info.ModTime().UnixNano() {
466			result.Changed = append(result.Changed, path)
467			result.Dirty = true
468		}
469	}
470
471	// Sort for deterministic output
472	slices.Sort(result.Changed)
473	slices.Sort(result.Missing)
474
475	return result
476}
477
478// RefreshStalenessSnapshot captures fresh snapshots of all tracked config files.
479// Call this after reloading config to clear dirty state.
480func (s *ConfigStore) RefreshStalenessSnapshot() error {
481	if s.snapshots == nil {
482		s.snapshots = make(map[string]fileSnapshot)
483	}
484
485	for _, path := range s.trackedConfigPaths {
486		info, err := os.Stat(path)
487		exists := err == nil && !info.IsDir()
488
489		snapshot := fileSnapshot{
490			Path:   path,
491			Exists: exists,
492		}
493
494		if exists {
495			snapshot.Size = info.Size()
496			snapshot.ModTime = info.ModTime().UnixNano()
497		}
498
499		s.snapshots[path] = snapshot
500	}
501
502	return nil
503}
504
505// CaptureStalenessSnapshot captures snapshots for the given paths, building the
506// tracked config paths list. Paths are deduplicated and normalized.
507func (s *ConfigStore) CaptureStalenessSnapshot(paths []string) {
508	// Build unique set of normalized paths
509	seen := make(map[string]struct{})
510	for _, p := range paths {
511		if p == "" {
512			continue
513		}
514		// Normalize path
515		abs, err := filepath.Abs(p)
516		if err != nil {
517			abs = p
518		}
519		seen[abs] = struct{}{}
520	}
521
522	// Also track workspace and global config paths if set
523	if s.workspacePath != "" {
524		abs, err := filepath.Abs(s.workspacePath)
525		if err == nil {
526			seen[abs] = struct{}{}
527		}
528	}
529	if s.globalDataPath != "" {
530		abs, err := filepath.Abs(s.globalDataPath)
531		if err == nil {
532			seen[abs] = struct{}{}
533		}
534	}
535
536	// Build sorted list for deterministic ordering
537	s.trackedConfigPaths = make([]string, 0, len(seen))
538	for p := range seen {
539		s.trackedConfigPaths = append(s.trackedConfigPaths, p)
540	}
541	slices.Sort(s.trackedConfigPaths)
542
543	// Capture initial snapshots
544	s.RefreshStalenessSnapshot()
545}
546
547// captureStalenessSnapshot is an alias for CaptureStalenessSnapshot for internal use.
548func (s *ConfigStore) captureStalenessSnapshot(paths []string) {
549	s.CaptureStalenessSnapshot(paths)
550}
551
552// ReloadFromDisk re-runs the config load/merge flow and updates the in-memory
553// config atomically. It rebuilds the staleness snapshot after successful reload.
554// On failure, the store state is rolled back to its previous state.
555func (s *ConfigStore) ReloadFromDisk(ctx context.Context) error {
556	if s.workingDir == "" {
557		return fmt.Errorf("cannot reload: working directory not set")
558	}
559
560	// Disable auto-reload during reload to prevent nested/re-entrant calls.
561	s.autoReloadDisabled = true
562	s.reloadInProgress = true
563	defer func() {
564		s.autoReloadDisabled = false
565		s.reloadInProgress = false
566	}()
567
568	configPaths := lookupConfigs(s.workingDir)
569	cfg, loadedPaths, err := loadFromConfigPaths(configPaths)
570	if err != nil {
571		return fmt.Errorf("failed to reload config: %w", err)
572	}
573
574	// Apply defaults (using existing data directory if set)
575	var dataDir string
576	if s.config != nil && s.config.Options != nil {
577		dataDir = s.config.Options.DataDirectory
578	}
579	cfg.setDefaults(s.workingDir, dataDir)
580
581	// Merge workspace config if present
582	workspacePath := filepath.Join(cfg.Options.DataDirectory, fmt.Sprintf("%s.json", appName))
583	if wsData, err := os.ReadFile(workspacePath); err == nil && len(wsData) > 0 {
584		merged, mergeErr := loadFromBytes(append([][]byte{mustMarshalConfig(cfg)}, wsData))
585		if mergeErr == nil {
586			dataDir := cfg.Options.DataDirectory
587			*cfg = *merged
588			cfg.setDefaults(s.workingDir, dataDir)
589			loadedPaths = append(loadedPaths, workspacePath)
590		}
591	}
592
593	// Preserve runtime overrides
594	overrides := s.overrides
595
596	// Reconfigure providers
597	env := env.New()
598	resolver := NewShellVariableResolver(env)
599	providers, err := Providers(cfg)
600	if err != nil {
601		return fmt.Errorf("failed to load providers during reload: %w", err)
602	}
603
604	if err := cfg.configureProviders(s, env, resolver, providers); err != nil {
605		return fmt.Errorf("failed to configure providers during reload: %w", err)
606	}
607
608	// Save current state for potential rollback
609	oldConfig := s.config
610	oldLoadedPaths := s.loadedPaths
611	oldResolver := s.resolver
612	oldKnownProviders := s.knownProviders
613	oldOverrides := s.overrides
614	oldWorkspacePath := s.workspacePath
615
616	// Update store state BEFORE running model/agent setup (so they see new config)
617	s.config = cfg
618	s.loadedPaths = loadedPaths
619	s.resolver = resolver
620	s.knownProviders = providers
621	s.overrides = overrides
622	s.workspacePath = workspacePath
623
624	// Mirror startup flow: setup models and agents against NEW config
625	var setupErr error
626	if !cfg.IsConfigured() {
627		slog.Warn("No providers configured after reload")
628	} else {
629		if err := configureSelectedModels(s, providers, false); err != nil {
630			setupErr = fmt.Errorf("failed to configure selected models during reload: %w", err)
631		} else {
632			s.SetupAgents()
633		}
634	}
635
636	// Rollback on setup failure
637	if setupErr != nil {
638		s.config = oldConfig
639		s.loadedPaths = oldLoadedPaths
640		s.resolver = oldResolver
641		s.knownProviders = oldKnownProviders
642		s.overrides = oldOverrides
643		s.workspacePath = oldWorkspacePath
644		return setupErr
645	}
646
647	// Rebuild staleness tracking
648	s.captureStalenessSnapshot(loadedPaths)
649
650	return nil
651}
652
653// autoReload conditionally reloads config from disk after writes.
654// It returns nil (no error) for expected skip cases: when auto-reload is
655// disabled during load/reload flows, or when working directory is not set
656// (e.g., during testing). Only actual reload failures return an error.
657func (s *ConfigStore) autoReload(ctx context.Context) error {
658	if s.autoReloadDisabled {
659		return nil // Expected skip: already in load/reload flow
660	}
661	if s.workingDir == "" {
662		return nil // Expected skip: working directory not set
663	}
664	return s.ReloadFromDisk(ctx)
665}