store.go

  1package config
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"slices"
 11
 12	"charm.land/catwalk/pkg/catwalk"
 13	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 14	"github.com/charmbracelet/crush/internal/oauth"
 15	"github.com/charmbracelet/crush/internal/oauth/copilot"
 16	"github.com/charmbracelet/crush/internal/oauth/hyper"
 17	"github.com/tidwall/gjson"
 18	"github.com/tidwall/sjson"
 19)
 20
// RuntimeOverrides holds per-session settings that are never persisted to
// disk. They are applied on top of the loaded Config and survive only for
// the lifetime of the process (or workspace).
//
// A ConfigStore embeds one value of this type and exposes it through
// ConfigStore.Overrides.
type RuntimeOverrides struct {
	// NOTE(review): presumably disables interactive permission prompts for
	// the session — confirm against the code that reads this flag.
	SkipPermissionRequests bool
}
 27
// ConfigStore is the single entry point for all config access. It owns the
// pure-data Config, runtime state (working directory, resolver, known
// providers), and persistence to both global and workspace config files.
//
// Fields:
//   - config: the loaded Config, returned as-is by Config and mutated in
//     place by the setter methods.
//   - workingDir: returned by WorkingDir.
//   - resolver: used by Resolve; may be nil, in which case Resolve fails.
//   - globalDataPath: file used by configPath for every scope other than
//     ScopeWorkspace.
//   - workspacePath: file used for ScopeWorkspace; when empty, configPath
//     returns ErrNoWorkspaceConfig.
//   - knownProviders: consulted by SetProviderAPIKey to build a provider
//     entry when the provider is not configured yet.
//   - overrides: non-persisted runtime settings, exposed via Overrides.
type ConfigStore struct {
	config         *Config
	workingDir     string
	resolver       VariableResolver
	globalDataPath string // ~/.local/share/crush/crush.json
	workspacePath  string // .crush/crush.json
	knownProviders []catwalk.Provider
	overrides      RuntimeOverrides
}
 40
// Config returns the pure-data config struct (read-only after load).
//
// The returned pointer is the store's own Config, not a copy, so updates
// made through the store's setter methods are visible through it.
func (s *ConfigStore) Config() *Config {
	return s.config
}
 45
// WorkingDir returns the working directory recorded in the store.
func (s *ConfigStore) WorkingDir() string {
	return s.workingDir
}
 50
// Resolver returns the variable resolver held by the store. The result is
// nil when no resolver has been configured; use Resolve for a nil-safe
// lookup.
func (s *ConfigStore) Resolver() VariableResolver {
	return s.resolver
}
 55
 56// Resolve resolves a variable reference using the configured resolver.
 57func (s *ConfigStore) Resolve(key string) (string, error) {
 58	if s.resolver == nil {
 59		return "", fmt.Errorf("no variable resolver configured")
 60	}
 61	return s.resolver.ResolveValue(key)
 62}
 63
// KnownProviders returns the list of known providers.
//
// The slice is returned without copying, so it shares its backing array
// with the store.
func (s *ConfigStore) KnownProviders() []catwalk.Provider {
	return s.knownProviders
}
 68
// SetupAgents configures the coder and task agents on the config by
// delegating to Config.SetupAgents on the store's Config.
func (s *ConfigStore) SetupAgents() {
	s.config.SetupAgents()
}
 73
// Overrides returns the runtime overrides for this store.
//
// The result points at the store's own RuntimeOverrides value, so writes
// through the returned pointer change the store's overrides directly.
func (s *ConfigStore) Overrides() *RuntimeOverrides {
	return &s.overrides
}
 78
 79// configPath returns the file path for the given scope.
 80func (s *ConfigStore) configPath(scope Scope) (string, error) {
 81	switch scope {
 82	case ScopeWorkspace:
 83		if s.workspacePath == "" {
 84			return "", ErrNoWorkspaceConfig
 85		}
 86		return s.workspacePath, nil
 87	default:
 88		return s.globalDataPath, nil
 89	}
 90}
 91
 92// HasConfigField checks whether a key exists in the config file for the given
 93// scope.
 94func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
 95	path, err := s.configPath(scope)
 96	if err != nil {
 97		return false
 98	}
 99	data, err := os.ReadFile(path)
100	if err != nil {
101		return false
102	}
103	return gjson.Get(string(data), key).Exists()
104}
105
106// SetConfigField sets a key/value pair in the config file for the given scope.
107func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
108	path, err := s.configPath(scope)
109	if err != nil {
110		return fmt.Errorf("%s: %w", key, err)
111	}
112	data, err := os.ReadFile(path)
113	if err != nil {
114		if os.IsNotExist(err) {
115			data = []byte("{}")
116		} else {
117			return fmt.Errorf("failed to read config file: %w", err)
118		}
119	}
120
121	newValue, err := sjson.Set(string(data), key, value)
122	if err != nil {
123		return fmt.Errorf("failed to set config field %s: %w", key, err)
124	}
125	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
126		return fmt.Errorf("failed to create config directory %q: %w", path, err)
127	}
128	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
129		return fmt.Errorf("failed to write config file: %w", err)
130	}
131	return nil
132}
133
134// RemoveConfigField removes a key from the config file for the given scope.
135func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
136	path, err := s.configPath(scope)
137	if err != nil {
138		return fmt.Errorf("%s: %w", key, err)
139	}
140	data, err := os.ReadFile(path)
141	if err != nil {
142		return fmt.Errorf("failed to read config file: %w", err)
143	}
144
145	newValue, err := sjson.Delete(string(data), key)
146	if err != nil {
147		return fmt.Errorf("failed to delete config field %s: %w", key, err)
148	}
149	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
150		return fmt.Errorf("failed to create config directory %q: %w", path, err)
151	}
152	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
153		return fmt.Errorf("failed to write config file: %w", err)
154	}
155	return nil
156}
157
158// UpdatePreferredModel updates the preferred model for the given type and
159// persists it to the config file at the given scope.
160func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
161	s.config.Models[modelType] = model
162	if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
163		return fmt.Errorf("failed to update preferred model: %w", err)
164	}
165	if err := s.recordRecentModel(scope, modelType, model); err != nil {
166		return err
167	}
168	return nil
169}
170
171// SetCompactMode sets the compact mode setting and persists it.
172func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
173	if s.config.Options == nil {
174		s.config.Options = &Options{}
175	}
176	s.config.Options.TUI.CompactMode = enabled
177	return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
178}
179
180// SetProviderAPIKey sets the API key for a provider and persists it.
181func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
182	var providerConfig ProviderConfig
183	var exists bool
184	var setKeyOrToken func()
185
186	switch v := apiKey.(type) {
187	case string:
188		if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
189			return fmt.Errorf("failed to save api key to config file: %w", err)
190		}
191		setKeyOrToken = func() { providerConfig.APIKey = v }
192	case *oauth.Token:
193		if err := cmp.Or(
194			s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
195			s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
196		); err != nil {
197			return err
198		}
199		setKeyOrToken = func() {
200			providerConfig.APIKey = v.AccessToken
201			providerConfig.OAuthToken = v
202			switch providerID {
203			case string(catwalk.InferenceProviderCopilot):
204				providerConfig.SetupGitHubCopilot()
205			}
206		}
207	}
208
209	providerConfig, exists = s.config.Providers.Get(providerID)
210	if exists {
211		setKeyOrToken()
212		s.config.Providers.Set(providerID, providerConfig)
213		return nil
214	}
215
216	var foundProvider *catwalk.Provider
217	for _, p := range s.knownProviders {
218		if string(p.ID) == providerID {
219			foundProvider = &p
220			break
221		}
222	}
223
224	if foundProvider != nil {
225		providerConfig = ProviderConfig{
226			ID:           providerID,
227			Name:         foundProvider.Name,
228			BaseURL:      foundProvider.APIEndpoint,
229			Type:         foundProvider.Type,
230			Disable:      false,
231			ExtraHeaders: make(map[string]string),
232			ExtraParams:  make(map[string]string),
233			Models:       foundProvider.Models,
234		}
235		setKeyOrToken()
236	} else {
237		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
238	}
239	s.config.Providers.Set(providerID, providerConfig)
240	return nil
241}
242
243// RefreshOAuthToken refreshes the OAuth token for the given provider.
244func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
245	providerConfig, exists := s.config.Providers.Get(providerID)
246	if !exists {
247		return fmt.Errorf("provider %s not found", providerID)
248	}
249
250	if providerConfig.OAuthToken == nil {
251		return fmt.Errorf("provider %s does not have an OAuth token", providerID)
252	}
253
254	var newToken *oauth.Token
255	var refreshErr error
256	switch providerID {
257	case string(catwalk.InferenceProviderCopilot):
258		newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
259	case hyperp.Name:
260		newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
261	default:
262		return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
263	}
264	if refreshErr != nil {
265		return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
266	}
267
268	slog.Info("Successfully refreshed OAuth token", "provider", providerID)
269	providerConfig.OAuthToken = newToken
270	providerConfig.APIKey = newToken.AccessToken
271
272	switch providerID {
273	case string(catwalk.InferenceProviderCopilot):
274		providerConfig.SetupGitHubCopilot()
275	}
276
277	s.config.Providers.Set(providerID, providerConfig)
278
279	if err := cmp.Or(
280		s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
281		s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
282	); err != nil {
283		return fmt.Errorf("failed to persist refreshed token: %w", err)
284	}
285
286	return nil
287}
288
289// recordRecentModel records a model in the recent models list.
290func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
291	if model.Provider == "" || model.Model == "" {
292		return nil
293	}
294
295	if s.config.RecentModels == nil {
296		s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
297	}
298
299	eq := func(a, b SelectedModel) bool {
300		return a.Provider == b.Provider && a.Model == b.Model
301	}
302
303	entry := SelectedModel{
304		Provider: model.Provider,
305		Model:    model.Model,
306	}
307
308	current := s.config.RecentModels[modelType]
309	withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
310		return eq(existing, entry)
311	})
312
313	updated := append([]SelectedModel{entry}, withoutCurrent...)
314	if len(updated) > maxRecentModelsPerType {
315		updated = updated[:maxRecentModelsPerType]
316	}
317
318	if slices.EqualFunc(current, updated, eq) {
319		return nil
320	}
321
322	s.config.RecentModels[modelType] = updated
323
324	if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
325		return fmt.Errorf("failed to persist recent models: %w", err)
326	}
327
328	return nil
329}
330
331// ImportCopilot attempts to import a GitHub Copilot token from disk.
332func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
333	if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
334		return nil, false
335	}
336
337	diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
338	if !hasDiskToken {
339		return nil, false
340	}
341
342	slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
343	token, err := copilot.RefreshToken(context.TODO(), diskToken)
344	if err != nil {
345		slog.Error("Unable to import GitHub Copilot token", "error", err)
346		return nil, false
347	}
348
349	if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
350		return token, false
351	}
352
353	if err := cmp.Or(
354		s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
355		s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
356	); err != nil {
357		slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
358	}
359
360	slog.Info("GitHub Copilot successfully imported")
361	return token, true
362}