store.go

  1package config
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"slices"
 11
 12	"charm.land/catwalk/pkg/catwalk"
 13	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 14	"github.com/charmbracelet/crush/internal/oauth"
 15	"github.com/charmbracelet/crush/internal/oauth/copilot"
 16	"github.com/charmbracelet/crush/internal/oauth/hyper"
 17	"github.com/tidwall/gjson"
 18	"github.com/tidwall/sjson"
 19)
 20
 21// RuntimeOverrides holds per-session settings that are never persisted to
 22// disk. They are applied on top of the loaded Config and survive only for
 23// the lifetime of the process (or workspace).
 24type RuntimeOverrides struct {
 25	SkipPermissionRequests bool
 26}
 27
 28// ConfigStore is the single entry point for all config access. It owns the
 29// pure-data Config, runtime state (working directory, resolver, known
 30// providers), and persistence to both global and workspace config files.
 31type ConfigStore struct {
 32	config         *Config
 33	workingDir     string
 34	resolver       VariableResolver
 35	globalDataPath string   // ~/.local/share/crush/crush.json
 36	workspacePath  string   // .crush/crush.json
 37	loadedPaths    []string // config files that were successfully loaded
 38	knownProviders []catwalk.Provider
 39	overrides      RuntimeOverrides
 40}
 41
 42// Config returns the pure-data config struct (read-only after load).
 43func (s *ConfigStore) Config() *Config {
 44	return s.config
 45}
 46
 47// WorkingDir returns the current working directory.
 48func (s *ConfigStore) WorkingDir() string {
 49	return s.workingDir
 50}
 51
 52// Resolver returns the variable resolver.
 53func (s *ConfigStore) Resolver() VariableResolver {
 54	return s.resolver
 55}
 56
 57// Resolve resolves a variable reference using the configured resolver.
 58func (s *ConfigStore) Resolve(key string) (string, error) {
 59	if s.resolver == nil {
 60		return "", fmt.Errorf("no variable resolver configured")
 61	}
 62	return s.resolver.ResolveValue(key)
 63}
 64
 65// KnownProviders returns the list of known providers.
 66func (s *ConfigStore) KnownProviders() []catwalk.Provider {
 67	return s.knownProviders
 68}
 69
 70// SetupAgents configures the coder and task agents on the config.
 71func (s *ConfigStore) SetupAgents() {
 72	s.config.SetupAgents()
 73}
 74
 75// Overrides returns the runtime overrides for this store.
 76func (s *ConfigStore) Overrides() *RuntimeOverrides {
 77	return &s.overrides
 78}
 79
 80// LoadedPaths returns the config file paths that were successfully loaded.
 81func (s *ConfigStore) LoadedPaths() []string {
 82	return slices.Clone(s.loadedPaths)
 83}
 84
 85// configPath returns the file path for the given scope.
 86func (s *ConfigStore) configPath(scope Scope) (string, error) {
 87	switch scope {
 88	case ScopeWorkspace:
 89		if s.workspacePath == "" {
 90			return "", ErrNoWorkspaceConfig
 91		}
 92		return s.workspacePath, nil
 93	default:
 94		return s.globalDataPath, nil
 95	}
 96}
 97
 98// HasConfigField checks whether a key exists in the config file for the given
 99// scope.
100func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
101	path, err := s.configPath(scope)
102	if err != nil {
103		return false
104	}
105	data, err := os.ReadFile(path)
106	if err != nil {
107		return false
108	}
109	return gjson.Get(string(data), key).Exists()
110}
111
112// SetConfigField sets a key/value pair in the config file for the given scope.
113func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
114	path, err := s.configPath(scope)
115	if err != nil {
116		return fmt.Errorf("%s: %w", key, err)
117	}
118	data, err := os.ReadFile(path)
119	if err != nil {
120		if os.IsNotExist(err) {
121			data = []byte("{}")
122		} else {
123			return fmt.Errorf("failed to read config file: %w", err)
124		}
125	}
126
127	newValue, err := sjson.Set(string(data), key, value)
128	if err != nil {
129		return fmt.Errorf("failed to set config field %s: %w", key, err)
130	}
131	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
132		return fmt.Errorf("failed to create config directory %q: %w", path, err)
133	}
134	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
135		return fmt.Errorf("failed to write config file: %w", err)
136	}
137	return nil
138}
139
140// RemoveConfigField removes a key from the config file for the given scope.
141func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
142	path, err := s.configPath(scope)
143	if err != nil {
144		return fmt.Errorf("%s: %w", key, err)
145	}
146	data, err := os.ReadFile(path)
147	if err != nil {
148		return fmt.Errorf("failed to read config file: %w", err)
149	}
150
151	newValue, err := sjson.Delete(string(data), key)
152	if err != nil {
153		return fmt.Errorf("failed to delete config field %s: %w", key, err)
154	}
155	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
156		return fmt.Errorf("failed to create config directory %q: %w", path, err)
157	}
158	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
159		return fmt.Errorf("failed to write config file: %w", err)
160	}
161	return nil
162}
163
164// UpdatePreferredModel updates the preferred model for the given type and
165// persists it to the config file at the given scope.
166func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
167	s.config.Models[modelType] = model
168	if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
169		return fmt.Errorf("failed to update preferred model: %w", err)
170	}
171	if err := s.recordRecentModel(scope, modelType, model); err != nil {
172		return err
173	}
174	return nil
175}
176
177// SetCompactMode sets the compact mode setting and persists it.
178func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
179	if s.config.Options == nil {
180		s.config.Options = &Options{}
181	}
182	s.config.Options.TUI.CompactMode = enabled
183	return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
184}
185
186// SetTransparentBackground sets the transparent background setting and persists it.
187func (s *ConfigStore) SetTransparentBackground(scope Scope, enabled bool) error {
188	if s.config.Options == nil {
189		s.config.Options = &Options{}
190	}
191	s.config.Options.TUI.Transparent = &enabled
192	return s.SetConfigField(scope, "options.tui.transparent", enabled)
193}
194
195// SetProviderAPIKey sets the API key for a provider and persists it.
196func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
197	var providerConfig ProviderConfig
198	var exists bool
199	var setKeyOrToken func()
200
201	switch v := apiKey.(type) {
202	case string:
203		if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
204			return fmt.Errorf("failed to save api key to config file: %w", err)
205		}
206		setKeyOrToken = func() { providerConfig.APIKey = v }
207	case *oauth.Token:
208		if err := cmp.Or(
209			s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
210			s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
211		); err != nil {
212			return err
213		}
214		setKeyOrToken = func() {
215			providerConfig.APIKey = v.AccessToken
216			providerConfig.OAuthToken = v
217			switch providerID {
218			case string(catwalk.InferenceProviderCopilot):
219				providerConfig.SetupGitHubCopilot()
220			}
221		}
222	}
223
224	providerConfig, exists = s.config.Providers.Get(providerID)
225	if exists {
226		setKeyOrToken()
227		s.config.Providers.Set(providerID, providerConfig)
228		return nil
229	}
230
231	var foundProvider *catwalk.Provider
232	for _, p := range s.knownProviders {
233		if string(p.ID) == providerID {
234			foundProvider = &p
235			break
236		}
237	}
238
239	if foundProvider != nil {
240		providerConfig = ProviderConfig{
241			ID:           providerID,
242			Name:         foundProvider.Name,
243			BaseURL:      foundProvider.APIEndpoint,
244			Type:         foundProvider.Type,
245			Disable:      false,
246			ExtraHeaders: make(map[string]string),
247			ExtraParams:  make(map[string]string),
248			Models:       foundProvider.Models,
249		}
250		setKeyOrToken()
251	} else {
252		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
253	}
254	s.config.Providers.Set(providerID, providerConfig)
255	return nil
256}
257
258// RefreshOAuthToken refreshes the OAuth token for the given provider.
259func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
260	providerConfig, exists := s.config.Providers.Get(providerID)
261	if !exists {
262		return fmt.Errorf("provider %s not found", providerID)
263	}
264
265	if providerConfig.OAuthToken == nil {
266		return fmt.Errorf("provider %s does not have an OAuth token", providerID)
267	}
268
269	var newToken *oauth.Token
270	var refreshErr error
271	switch providerID {
272	case string(catwalk.InferenceProviderCopilot):
273		newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
274	case hyperp.Name:
275		newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
276	default:
277		return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
278	}
279	if refreshErr != nil {
280		return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
281	}
282
283	slog.Info("Successfully refreshed OAuth token", "provider", providerID)
284	providerConfig.OAuthToken = newToken
285	providerConfig.APIKey = newToken.AccessToken
286
287	switch providerID {
288	case string(catwalk.InferenceProviderCopilot):
289		providerConfig.SetupGitHubCopilot()
290	}
291
292	s.config.Providers.Set(providerID, providerConfig)
293
294	if err := cmp.Or(
295		s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
296		s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
297	); err != nil {
298		return fmt.Errorf("failed to persist refreshed token: %w", err)
299	}
300
301	return nil
302}
303
304// recordRecentModel records a model in the recent models list.
305func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
306	if model.Provider == "" || model.Model == "" {
307		return nil
308	}
309
310	if s.config.RecentModels == nil {
311		s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
312	}
313
314	eq := func(a, b SelectedModel) bool {
315		return a.Provider == b.Provider && a.Model == b.Model
316	}
317
318	entry := SelectedModel{
319		Provider: model.Provider,
320		Model:    model.Model,
321	}
322
323	current := s.config.RecentModels[modelType]
324	withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
325		return eq(existing, entry)
326	})
327
328	updated := append([]SelectedModel{entry}, withoutCurrent...)
329	if len(updated) > maxRecentModelsPerType {
330		updated = updated[:maxRecentModelsPerType]
331	}
332
333	if slices.EqualFunc(current, updated, eq) {
334		return nil
335	}
336
337	s.config.RecentModels[modelType] = updated
338
339	if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
340		return fmt.Errorf("failed to persist recent models: %w", err)
341	}
342
343	return nil
344}
345
346// NewTestStore creates a ConfigStore for testing purposes.
347func NewTestStore(cfg *Config, loadedPaths ...string) *ConfigStore {
348	return &ConfigStore{
349		config:      cfg,
350		loadedPaths: loadedPaths,
351	}
352}
353
354// ImportCopilot attempts to import a GitHub Copilot token from disk.
355func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
356	if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
357		return nil, false
358	}
359
360	diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
361	if !hasDiskToken {
362		return nil, false
363	}
364
365	slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
366	token, err := copilot.RefreshToken(context.TODO(), diskToken)
367	if err != nil {
368		slog.Error("Unable to import GitHub Copilot token", "error", err)
369		return nil, false
370	}
371
372	if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
373		return token, false
374	}
375
376	if err := cmp.Or(
377		s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
378		s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
379	); err != nil {
380		slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
381	}
382
383	slog.Info("GitHub Copilot successfully imported")
384	return token, true
385}