store.go

  1package config
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"slices"
 11
 12	"charm.land/catwalk/pkg/catwalk"
 13	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 14	"github.com/charmbracelet/crush/internal/oauth"
 15	"github.com/charmbracelet/crush/internal/oauth/copilot"
 16	"github.com/charmbracelet/crush/internal/oauth/hyper"
 17	"github.com/tidwall/gjson"
 18	"github.com/tidwall/sjson"
 19)
 20
 21// RuntimeOverrides holds per-session settings that are never persisted to
 22// disk. They are applied on top of the loaded Config and survive only for
 23// the lifetime of the process (or workspace).
 24type RuntimeOverrides struct {
 25	SkipPermissionRequests bool
 26}
 27
 28// ConfigStore is the single entry point for all config access. It owns the
 29// pure-data Config, runtime state (working directory, resolver, known
 30// providers), and persistence to both global and workspace config files.
 31type ConfigStore struct {
 32	config         *Config
 33	workingDir     string
 34	resolver       VariableResolver
 35	globalDataPath string // ~/.local/share/crush/crush.json
 36	workspacePath  string // .crush/crush.json
 37	knownProviders []catwalk.Provider
 38	overrides      RuntimeOverrides
 39}
 40
 41// Config returns the pure-data config struct (read-only after load).
 42func (s *ConfigStore) Config() *Config {
 43	return s.config
 44}
 45
 46// WorkingDir returns the current working directory.
 47func (s *ConfigStore) WorkingDir() string {
 48	return s.workingDir
 49}
 50
 51// Resolver returns the variable resolver.
 52func (s *ConfigStore) Resolver() VariableResolver {
 53	return s.resolver
 54}
 55
 56// Resolve resolves a variable reference using the configured resolver.
 57func (s *ConfigStore) Resolve(key string) (string, error) {
 58	if s.resolver == nil {
 59		return "", fmt.Errorf("no variable resolver configured")
 60	}
 61	return s.resolver.ResolveValue(key)
 62}
 63
 64// KnownProviders returns the list of known providers.
 65func (s *ConfigStore) KnownProviders() []catwalk.Provider {
 66	return s.knownProviders
 67}
 68
 69// SetupAgents configures the coder and task agents on the config.
 70func (s *ConfigStore) SetupAgents() {
 71	s.config.SetupAgents()
 72}
 73
 74// Overrides returns the runtime overrides for this store.
 75func (s *ConfigStore) Overrides() *RuntimeOverrides {
 76	return &s.overrides
 77}
 78
 79// configPath returns the file path for the given scope.
 80func (s *ConfigStore) configPath(scope Scope) (string, error) {
 81	switch scope {
 82	case ScopeWorkspace:
 83		if s.workspacePath == "" {
 84			return "", ErrNoWorkspaceConfig
 85		}
 86		return s.workspacePath, nil
 87	default:
 88		return s.globalDataPath, nil
 89	}
 90}
 91
 92// HasConfigField checks whether a key exists in the config file for the given
 93// scope.
 94func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
 95	path, err := s.configPath(scope)
 96	if err != nil {
 97		return false
 98	}
 99	data, err := os.ReadFile(path)
100	if err != nil {
101		return false
102	}
103	return gjson.Get(string(data), key).Exists()
104}
105
106// SetConfigField sets a key/value pair in the config file for the given scope.
107func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
108	path, err := s.configPath(scope)
109	if err != nil {
110		return fmt.Errorf("%s: %w", key, err)
111	}
112	data, err := os.ReadFile(path)
113	if err != nil {
114		if os.IsNotExist(err) {
115			data = []byte("{}")
116		} else {
117			return fmt.Errorf("failed to read config file: %w", err)
118		}
119	}
120
121	newValue, err := sjson.Set(string(data), key, value)
122	if err != nil {
123		return fmt.Errorf("failed to set config field %s: %w", key, err)
124	}
125	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
126		return fmt.Errorf("failed to create config directory %q: %w", path, err)
127	}
128	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
129		return fmt.Errorf("failed to write config file: %w", err)
130	}
131	return nil
132}
133
134// RemoveConfigField removes a key from the config file for the given scope.
135func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
136	path, err := s.configPath(scope)
137	if err != nil {
138		return fmt.Errorf("%s: %w", key, err)
139	}
140	data, err := os.ReadFile(path)
141	if err != nil {
142		return fmt.Errorf("failed to read config file: %w", err)
143	}
144
145	newValue, err := sjson.Delete(string(data), key)
146	if err != nil {
147		return fmt.Errorf("failed to delete config field %s: %w", key, err)
148	}
149	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
150		return fmt.Errorf("failed to create config directory %q: %w", path, err)
151	}
152	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
153		return fmt.Errorf("failed to write config file: %w", err)
154	}
155	return nil
156}
157
158// UpdatePreferredModel updates the preferred model for the given type and
159// persists it to the config file at the given scope.
160func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
161	s.config.Models[modelType] = model
162	if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
163		return fmt.Errorf("failed to update preferred model: %w", err)
164	}
165	if err := s.recordRecentModel(scope, modelType, model); err != nil {
166		return err
167	}
168	return nil
169}
170
171// SetCompactMode sets the compact mode setting and persists it.
172func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
173	if s.config.Options == nil {
174		s.config.Options = &Options{}
175	}
176	s.config.Options.TUI.CompactMode = enabled
177	return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
178}
179
180// SetTransparentBackground sets the transparent background setting and persists it.
181func (s *ConfigStore) SetTransparentBackground(scope Scope, enabled bool) error {
182	if s.config.Options == nil {
183		s.config.Options = &Options{}
184	}
185	s.config.Options.TUI.Transparent = &enabled
186	return s.SetConfigField(scope, "options.tui.transparent", enabled)
187}
188
189// SetProviderAPIKey sets the API key for a provider and persists it.
190func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
191	var providerConfig ProviderConfig
192	var exists bool
193	var setKeyOrToken func()
194
195	switch v := apiKey.(type) {
196	case string:
197		if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
198			return fmt.Errorf("failed to save api key to config file: %w", err)
199		}
200		setKeyOrToken = func() { providerConfig.APIKey = v }
201	case *oauth.Token:
202		if err := cmp.Or(
203			s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
204			s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
205		); err != nil {
206			return err
207		}
208		setKeyOrToken = func() {
209			providerConfig.APIKey = v.AccessToken
210			providerConfig.OAuthToken = v
211			switch providerID {
212			case string(catwalk.InferenceProviderCopilot):
213				providerConfig.SetupGitHubCopilot()
214			}
215		}
216	}
217
218	providerConfig, exists = s.config.Providers.Get(providerID)
219	if exists {
220		setKeyOrToken()
221		s.config.Providers.Set(providerID, providerConfig)
222		return nil
223	}
224
225	var foundProvider *catwalk.Provider
226	for _, p := range s.knownProviders {
227		if string(p.ID) == providerID {
228			foundProvider = &p
229			break
230		}
231	}
232
233	if foundProvider != nil {
234		providerConfig = ProviderConfig{
235			ID:           providerID,
236			Name:         foundProvider.Name,
237			BaseURL:      foundProvider.APIEndpoint,
238			Type:         foundProvider.Type,
239			Disable:      false,
240			ExtraHeaders: make(map[string]string),
241			ExtraParams:  make(map[string]string),
242			Models:       foundProvider.Models,
243		}
244		setKeyOrToken()
245	} else {
246		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
247	}
248	s.config.Providers.Set(providerID, providerConfig)
249	return nil
250}
251
252// RefreshOAuthToken refreshes the OAuth token for the given provider.
253func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
254	providerConfig, exists := s.config.Providers.Get(providerID)
255	if !exists {
256		return fmt.Errorf("provider %s not found", providerID)
257	}
258
259	if providerConfig.OAuthToken == nil {
260		return fmt.Errorf("provider %s does not have an OAuth token", providerID)
261	}
262
263	var newToken *oauth.Token
264	var refreshErr error
265	switch providerID {
266	case string(catwalk.InferenceProviderCopilot):
267		newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
268	case hyperp.Name:
269		newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
270	default:
271		return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
272	}
273	if refreshErr != nil {
274		return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
275	}
276
277	slog.Info("Successfully refreshed OAuth token", "provider", providerID)
278	providerConfig.OAuthToken = newToken
279	providerConfig.APIKey = newToken.AccessToken
280
281	switch providerID {
282	case string(catwalk.InferenceProviderCopilot):
283		providerConfig.SetupGitHubCopilot()
284	}
285
286	s.config.Providers.Set(providerID, providerConfig)
287
288	if err := cmp.Or(
289		s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
290		s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
291	); err != nil {
292		return fmt.Errorf("failed to persist refreshed token: %w", err)
293	}
294
295	return nil
296}
297
298// recordRecentModel records a model in the recent models list.
299func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
300	if model.Provider == "" || model.Model == "" {
301		return nil
302	}
303
304	if s.config.RecentModels == nil {
305		s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
306	}
307
308	eq := func(a, b SelectedModel) bool {
309		return a.Provider == b.Provider && a.Model == b.Model
310	}
311
312	entry := SelectedModel{
313		Provider: model.Provider,
314		Model:    model.Model,
315	}
316
317	current := s.config.RecentModels[modelType]
318	withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
319		return eq(existing, entry)
320	})
321
322	updated := append([]SelectedModel{entry}, withoutCurrent...)
323	if len(updated) > maxRecentModelsPerType {
324		updated = updated[:maxRecentModelsPerType]
325	}
326
327	if slices.EqualFunc(current, updated, eq) {
328		return nil
329	}
330
331	s.config.RecentModels[modelType] = updated
332
333	if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
334		return fmt.Errorf("failed to persist recent models: %w", err)
335	}
336
337	return nil
338}
339
340// ImportCopilot attempts to import a GitHub Copilot token from disk.
341func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
342	if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
343		return nil, false
344	}
345
346	diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
347	if !hasDiskToken {
348		return nil, false
349	}
350
351	slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
352	token, err := copilot.RefreshToken(context.TODO(), diskToken)
353	if err != nil {
354		slog.Error("Unable to import GitHub Copilot token", "error", err)
355		return nil, false
356	}
357
358	if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
359		return token, false
360	}
361
362	if err := cmp.Or(
363		s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
364		s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
365	); err != nil {
366		slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
367	}
368
369	slog.Info("GitHub Copilot successfully imported")
370	return token, true
371}