store.go

  1package config
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"slices"
 11
 12	"charm.land/catwalk/pkg/catwalk"
 13	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 14	"github.com/charmbracelet/crush/internal/env"
 15	"github.com/charmbracelet/crush/internal/oauth"
 16	"github.com/charmbracelet/crush/internal/oauth/copilot"
 17	"github.com/charmbracelet/crush/internal/oauth/hyper"
 18	"github.com/tidwall/gjson"
 19	"github.com/tidwall/sjson"
 20)
 21
 22// fileSnapshot captures metadata about a config file at a point in time.
 23type fileSnapshot struct {
 24	Path    string
 25	Exists  bool
 26	Size    int64
 27	ModTime int64 // UnixNano
 28}
 29
 30// RuntimeOverrides holds per-session settings that are never persisted to
 31// disk. They are applied on top of the loaded Config and survive only for
 32// the lifetime of the process (or workspace).
 33type RuntimeOverrides struct {
 34	SkipPermissionRequests bool
 35}
 36
 37// ConfigStore is the single entry point for all config access. It owns the
 38// pure-data Config, runtime state (working directory, resolver, known
 39// providers), and persistence to both global and workspace config files.
 40type ConfigStore struct {
 41	config             *Config
 42	workingDir         string
 43	resolver           VariableResolver
 44	globalDataPath     string   // ~/.local/share/crush/crush.json
 45	workspacePath      string   // .crush/crush.json
 46	loadedPaths        []string // config files that were successfully loaded
 47	knownProviders     []catwalk.Provider
 48	overrides          RuntimeOverrides
 49	trackedConfigPaths []string                // unique, normalized config file paths
 50	snapshots          map[string]fileSnapshot // path -> snapshot at last capture
 51}
 52
 53// Config returns the pure-data config struct (read-only after load).
 54func (s *ConfigStore) Config() *Config {
 55	return s.config
 56}
 57
 58// WorkingDir returns the current working directory.
 59func (s *ConfigStore) WorkingDir() string {
 60	return s.workingDir
 61}
 62
 63// Resolver returns the variable resolver.
 64func (s *ConfigStore) Resolver() VariableResolver {
 65	return s.resolver
 66}
 67
 68// Resolve resolves a variable reference using the configured resolver.
 69func (s *ConfigStore) Resolve(key string) (string, error) {
 70	if s.resolver == nil {
 71		return "", fmt.Errorf("no variable resolver configured")
 72	}
 73	return s.resolver.ResolveValue(key)
 74}
 75
 76// KnownProviders returns the list of known providers.
 77func (s *ConfigStore) KnownProviders() []catwalk.Provider {
 78	return s.knownProviders
 79}
 80
 81// SetupAgents configures the coder and task agents on the config.
 82func (s *ConfigStore) SetupAgents() {
 83	s.config.SetupAgents()
 84}
 85
 86// Overrides returns the runtime overrides for this store.
 87func (s *ConfigStore) Overrides() *RuntimeOverrides {
 88	return &s.overrides
 89}
 90
 91// LoadedPaths returns the config file paths that were successfully loaded.
 92func (s *ConfigStore) LoadedPaths() []string {
 93	return slices.Clone(s.loadedPaths)
 94}
 95
 96// configPath returns the file path for the given scope.
 97func (s *ConfigStore) configPath(scope Scope) (string, error) {
 98	switch scope {
 99	case ScopeWorkspace:
100		if s.workspacePath == "" {
101			return "", ErrNoWorkspaceConfig
102		}
103		return s.workspacePath, nil
104	default:
105		return s.globalDataPath, nil
106	}
107}
108
109// HasConfigField checks whether a key exists in the config file for the given
110// scope.
111func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
112	path, err := s.configPath(scope)
113	if err != nil {
114		return false
115	}
116	data, err := os.ReadFile(path)
117	if err != nil {
118		return false
119	}
120	return gjson.Get(string(data), key).Exists()
121}
122
123// SetConfigField sets a key/value pair in the config file for the given scope.
124func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
125	path, err := s.configPath(scope)
126	if err != nil {
127		return fmt.Errorf("%s: %w", key, err)
128	}
129	data, err := os.ReadFile(path)
130	if err != nil {
131		if os.IsNotExist(err) {
132			data = []byte("{}")
133		} else {
134			return fmt.Errorf("failed to read config file: %w", err)
135		}
136	}
137
138	newValue, err := sjson.Set(string(data), key, value)
139	if err != nil {
140		return fmt.Errorf("failed to set config field %s: %w", key, err)
141	}
142	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
143		return fmt.Errorf("failed to create config directory %q: %w", path, err)
144	}
145	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
146		return fmt.Errorf("failed to write config file: %w", err)
147	}
148	return nil
149}
150
151// RemoveConfigField removes a key from the config file for the given scope.
152func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
153	path, err := s.configPath(scope)
154	if err != nil {
155		return fmt.Errorf("%s: %w", key, err)
156	}
157	data, err := os.ReadFile(path)
158	if err != nil {
159		return fmt.Errorf("failed to read config file: %w", err)
160	}
161
162	newValue, err := sjson.Delete(string(data), key)
163	if err != nil {
164		return fmt.Errorf("failed to delete config field %s: %w", key, err)
165	}
166	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
167		return fmt.Errorf("failed to create config directory %q: %w", path, err)
168	}
169	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
170		return fmt.Errorf("failed to write config file: %w", err)
171	}
172	return nil
173}
174
175// UpdatePreferredModel updates the preferred model for the given type and
176// persists it to the config file at the given scope.
177func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
178	s.config.Models[modelType] = model
179	if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
180		return fmt.Errorf("failed to update preferred model: %w", err)
181	}
182	if err := s.recordRecentModel(scope, modelType, model); err != nil {
183		return err
184	}
185	return nil
186}
187
188// SetCompactMode sets the compact mode setting and persists it.
189func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
190	if s.config.Options == nil {
191		s.config.Options = &Options{}
192	}
193	s.config.Options.TUI.CompactMode = enabled
194	return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
195}
196
197// SetTransparentBackground sets the transparent background setting and persists it.
198func (s *ConfigStore) SetTransparentBackground(scope Scope, enabled bool) error {
199	if s.config.Options == nil {
200		s.config.Options = &Options{}
201	}
202	s.config.Options.TUI.Transparent = &enabled
203	return s.SetConfigField(scope, "options.tui.transparent", enabled)
204}
205
206// SetProviderAPIKey sets the API key for a provider and persists it.
207func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
208	var providerConfig ProviderConfig
209	var exists bool
210	var setKeyOrToken func()
211
212	switch v := apiKey.(type) {
213	case string:
214		if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
215			return fmt.Errorf("failed to save api key to config file: %w", err)
216		}
217		setKeyOrToken = func() { providerConfig.APIKey = v }
218	case *oauth.Token:
219		if err := cmp.Or(
220			s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
221			s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
222		); err != nil {
223			return err
224		}
225		setKeyOrToken = func() {
226			providerConfig.APIKey = v.AccessToken
227			providerConfig.OAuthToken = v
228			switch providerID {
229			case string(catwalk.InferenceProviderCopilot):
230				providerConfig.SetupGitHubCopilot()
231			}
232		}
233	}
234
235	providerConfig, exists = s.config.Providers.Get(providerID)
236	if exists {
237		setKeyOrToken()
238		s.config.Providers.Set(providerID, providerConfig)
239		return nil
240	}
241
242	var foundProvider *catwalk.Provider
243	for _, p := range s.knownProviders {
244		if string(p.ID) == providerID {
245			foundProvider = &p
246			break
247		}
248	}
249
250	if foundProvider != nil {
251		providerConfig = ProviderConfig{
252			ID:           providerID,
253			Name:         foundProvider.Name,
254			BaseURL:      foundProvider.APIEndpoint,
255			Type:         foundProvider.Type,
256			Disable:      false,
257			ExtraHeaders: make(map[string]string),
258			ExtraParams:  make(map[string]string),
259			Models:       foundProvider.Models,
260		}
261		setKeyOrToken()
262	} else {
263		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
264	}
265	s.config.Providers.Set(providerID, providerConfig)
266	return nil
267}
268
269// RefreshOAuthToken refreshes the OAuth token for the given provider.
270func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
271	providerConfig, exists := s.config.Providers.Get(providerID)
272	if !exists {
273		return fmt.Errorf("provider %s not found", providerID)
274	}
275
276	if providerConfig.OAuthToken == nil {
277		return fmt.Errorf("provider %s does not have an OAuth token", providerID)
278	}
279
280	var newToken *oauth.Token
281	var refreshErr error
282	switch providerID {
283	case string(catwalk.InferenceProviderCopilot):
284		newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
285	case hyperp.Name:
286		newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
287	default:
288		return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
289	}
290	if refreshErr != nil {
291		return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
292	}
293
294	slog.Info("Successfully refreshed OAuth token", "provider", providerID)
295	providerConfig.OAuthToken = newToken
296	providerConfig.APIKey = newToken.AccessToken
297
298	switch providerID {
299	case string(catwalk.InferenceProviderCopilot):
300		providerConfig.SetupGitHubCopilot()
301	}
302
303	s.config.Providers.Set(providerID, providerConfig)
304
305	if err := cmp.Or(
306		s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
307		s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
308	); err != nil {
309		return fmt.Errorf("failed to persist refreshed token: %w", err)
310	}
311
312	return nil
313}
314
315// recordRecentModel records a model in the recent models list.
316func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
317	if model.Provider == "" || model.Model == "" {
318		return nil
319	}
320
321	if s.config.RecentModels == nil {
322		s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
323	}
324
325	eq := func(a, b SelectedModel) bool {
326		return a.Provider == b.Provider && a.Model == b.Model
327	}
328
329	entry := SelectedModel{
330		Provider: model.Provider,
331		Model:    model.Model,
332	}
333
334	current := s.config.RecentModels[modelType]
335	withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
336		return eq(existing, entry)
337	})
338
339	updated := append([]SelectedModel{entry}, withoutCurrent...)
340	if len(updated) > maxRecentModelsPerType {
341		updated = updated[:maxRecentModelsPerType]
342	}
343
344	if slices.EqualFunc(current, updated, eq) {
345		return nil
346	}
347
348	s.config.RecentModels[modelType] = updated
349
350	if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
351		return fmt.Errorf("failed to persist recent models: %w", err)
352	}
353
354	return nil
355}
356
357// NewTestStore creates a ConfigStore for testing purposes.
358func NewTestStore(cfg *Config, loadedPaths ...string) *ConfigStore {
359	return &ConfigStore{
360		config:      cfg,
361		loadedPaths: loadedPaths,
362	}
363}
364
365// ImportCopilot attempts to import a GitHub Copilot token from disk.
366func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
367	if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
368		return nil, false
369	}
370
371	diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
372	if !hasDiskToken {
373		return nil, false
374	}
375
376	slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
377	token, err := copilot.RefreshToken(context.TODO(), diskToken)
378	if err != nil {
379		slog.Error("Unable to import GitHub Copilot token", "error", err)
380		return nil, false
381	}
382
383	if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
384		return token, false
385	}
386
387	if err := cmp.Or(
388		s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
389		s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
390	); err != nil {
391		slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
392	}
393
394	slog.Info("GitHub Copilot successfully imported")
395	return token, true
396}
397
// StalenessResult describes the outcome of comparing tracked config files
// against their last snapshots.
type StalenessResult struct {
	// Dirty is true when any tracked file changed, went missing, or could
	// not be stat'ed.
	Dirty bool
	// Changed lists, in sorted order, files that appeared or whose size or
	// modification time differs from the snapshot.
	Changed []string
	// Missing lists, in sorted order, files that existed at snapshot time
	// but no longer do.
	Missing []string
	// Errors holds stat errors, other than not-exist, keyed by path.
	Errors map[string]error
}
405
406// ConfigStaleness checks whether any tracked config files have changed on disk
407// since the last snapshot. Returns dirty=true if any files changed or went
408// missing, along with sorted lists of affected paths. Stat errors are
409// captured in Errors map but still treated as non-existence for dirty detection.
410func (s *ConfigStore) ConfigStaleness() StalenessResult {
411	var result StalenessResult
412	result.Errors = make(map[string]error)
413
414	for _, path := range s.trackedConfigPaths {
415		snapshot, hadSnapshot := s.snapshots[path]
416
417		info, err := os.Stat(path)
418		exists := err == nil && !info.IsDir()
419
420		if err != nil && !os.IsNotExist(err) {
421			// Capture permission/IO errors separately from non-existence
422			result.Errors[path] = err
423			result.Dirty = true
424		}
425
426		if !exists {
427			if hadSnapshot && snapshot.Exists {
428				// File existed before but now missing
429				result.Missing = append(result.Missing, path)
430				result.Dirty = true
431			}
432			continue
433		}
434
435		// File exists now
436		if !hadSnapshot || !snapshot.Exists {
437			// File didn't exist before but does now
438			result.Changed = append(result.Changed, path)
439			result.Dirty = true
440			continue
441		}
442
443		// Check for content or metadata changes
444		if snapshot.Size != info.Size() || snapshot.ModTime != info.ModTime().UnixNano() {
445			result.Changed = append(result.Changed, path)
446			result.Dirty = true
447		}
448	}
449
450	// Sort for deterministic output
451	slices.Sort(result.Changed)
452	slices.Sort(result.Missing)
453
454	return result
455}
456
457// RefreshStalenessSnapshot captures fresh snapshots of all tracked config files.
458// Call this after reloading config to clear dirty state.
459func (s *ConfigStore) RefreshStalenessSnapshot() error {
460	if s.snapshots == nil {
461		s.snapshots = make(map[string]fileSnapshot)
462	}
463
464	for _, path := range s.trackedConfigPaths {
465		info, err := os.Stat(path)
466		exists := err == nil && !info.IsDir()
467
468		snapshot := fileSnapshot{
469			Path:   path,
470			Exists: exists,
471		}
472
473		if exists {
474			snapshot.Size = info.Size()
475			snapshot.ModTime = info.ModTime().UnixNano()
476		}
477
478		s.snapshots[path] = snapshot
479	}
480
481	return nil
482}
483
484// CaptureStalenessSnapshot captures snapshots for the given paths, building the
485// tracked config paths list. Paths are deduplicated and normalized.
486func (s *ConfigStore) CaptureStalenessSnapshot(paths []string) {
487	// Build unique set of normalized paths
488	seen := make(map[string]struct{})
489	for _, p := range paths {
490		if p == "" {
491			continue
492		}
493		// Normalize path
494		abs, err := filepath.Abs(p)
495		if err != nil {
496			abs = p
497		}
498		seen[abs] = struct{}{}
499	}
500
501	// Also track workspace and global config paths if set
502	if s.workspacePath != "" {
503		abs, err := filepath.Abs(s.workspacePath)
504		if err == nil {
505			seen[abs] = struct{}{}
506		}
507	}
508	if s.globalDataPath != "" {
509		abs, err := filepath.Abs(s.globalDataPath)
510		if err == nil {
511			seen[abs] = struct{}{}
512		}
513	}
514
515	// Build sorted list for deterministic ordering
516	s.trackedConfigPaths = make([]string, 0, len(seen))
517	for p := range seen {
518		s.trackedConfigPaths = append(s.trackedConfigPaths, p)
519	}
520	slices.Sort(s.trackedConfigPaths)
521
522	// Capture initial snapshots
523	s.RefreshStalenessSnapshot()
524}
525
526// captureStalenessSnapshot is an alias for CaptureStalenessSnapshot for internal use.
527func (s *ConfigStore) captureStalenessSnapshot(paths []string) {
528	s.CaptureStalenessSnapshot(paths)
529}
530
531// ReloadFromDisk re-runs the config load/merge flow and updates the in-memory
532// config safely. It rebuilds the staleness snapshot after successful reload.
533func (s *ConfigStore) ReloadFromDisk(ctx context.Context) error {
534	if s.workingDir == "" {
535		return fmt.Errorf("cannot reload: working directory not set")
536	}
537
538	configPaths := lookupConfigs(s.workingDir)
539	cfg, loadedPaths, err := loadFromConfigPaths(configPaths)
540	if err != nil {
541		return fmt.Errorf("failed to reload config: %w", err)
542	}
543
544	// Apply defaults (using existing data directory if set)
545	var dataDir string
546	if s.config != nil && s.config.Options != nil {
547		dataDir = s.config.Options.DataDirectory
548	}
549	cfg.setDefaults(s.workingDir, dataDir)
550
551	// Merge workspace config if present
552	workspacePath := filepath.Join(cfg.Options.DataDirectory, fmt.Sprintf("%s.json", appName))
553	if wsData, err := os.ReadFile(workspacePath); err == nil && len(wsData) > 0 {
554		merged, mergeErr := loadFromBytes(append([][]byte{mustMarshalConfig(cfg)}, wsData))
555		if mergeErr == nil {
556			dataDir := cfg.Options.DataDirectory
557			*cfg = *merged
558			cfg.setDefaults(s.workingDir, dataDir)
559			loadedPaths = append(loadedPaths, workspacePath)
560		}
561	}
562
563	// Preserve runtime overrides
564	overrides := s.overrides
565
566	// Reconfigure providers
567	env := env.New()
568	resolver := NewShellVariableResolver(env)
569	providers, err := Providers(cfg)
570	if err != nil {
571		return fmt.Errorf("failed to load providers during reload: %w", err)
572	}
573
574	if err := cfg.configureProviders(s, env, resolver, providers); err != nil {
575		return fmt.Errorf("failed to configure providers during reload: %w", err)
576	}
577
578	// Update store state BEFORE running model/agent setup (so they see new config)
579	s.config = cfg
580	s.loadedPaths = loadedPaths
581	s.resolver = resolver
582	s.knownProviders = providers
583	s.overrides = overrides
584	s.workspacePath = workspacePath
585
586	// Mirror startup flow: setup models and agents against NEW config
587	if !cfg.IsConfigured() {
588		slog.Warn("No providers configured after reload")
589	} else {
590		if err := configureSelectedModels(s, providers); err != nil {
591			return fmt.Errorf("failed to configure selected models during reload: %w", err)
592		}
593		s.SetupAgents()
594	}
595
596	// Rebuild staleness tracking
597	s.captureStalenessSnapshot(loadedPaths)
598
599	return nil
600}