store.go

  1package config
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"slices"
 11
 12	"charm.land/catwalk/pkg/catwalk"
 13	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 14	"github.com/charmbracelet/crush/internal/oauth"
 15	"github.com/charmbracelet/crush/internal/oauth/copilot"
 16	"github.com/charmbracelet/crush/internal/oauth/hyper"
 17	"github.com/tidwall/gjson"
 18	"github.com/tidwall/sjson"
 19)
 20
// ConfigStore is the single entry point for all config access. It owns the
// pure-data Config, runtime state (working directory, resolver, known
// providers), and persistence to both global and workspace config files.
type ConfigStore struct {
	// config is the in-memory configuration. Config returns it, and the
	// store's setter methods modify it alongside the on-disk files.
	config         *Config
	// workingDir is the directory reported by WorkingDir.
	workingDir     string
	// resolver backs Resolve. It may be nil, in which case Resolve returns
	// an error.
	resolver       VariableResolver
	globalDataPath string // ~/.local/share/crush/crush.json
	workspacePath  string // .crush/crush.json
	// knownProviders is the provider catalogue returned by KnownProviders.
	// SetProviderAPIKey searches it when a provider has no entry in config
	// yet.
	knownProviders []catwalk.Provider
}
 32
// Config returns the pure-data config struct (read-only after load). The
// stored pointer is returned directly, not a copy, so callers observe the
// same struct that the store's own methods operate on.
func (s *ConfigStore) Config() *Config {
	return s.config
}
 37
// WorkingDir returns the working directory recorded on the store.
func (s *ConfigStore) WorkingDir() string {
	return s.workingDir
}
 42
// Resolver returns the store's variable resolver. The result is nil when no
// resolver has been set on the store.
func (s *ConfigStore) Resolver() VariableResolver {
	return s.resolver
}
 47
 48// Resolve resolves a variable reference using the configured resolver.
 49func (s *ConfigStore) Resolve(key string) (string, error) {
 50	if s.resolver == nil {
 51		return "", fmt.Errorf("no variable resolver configured")
 52	}
 53	return s.resolver.ResolveValue(key)
 54}
 55
// KnownProviders returns the list of known providers. The store's slice is
// returned as-is, without copying.
func (s *ConfigStore) KnownProviders() []catwalk.Provider {
	return s.knownProviders
}
 60
// SetupAgents configures the coder and task agents on the config. It simply
// delegates to SetupAgents on the store's Config.
func (s *ConfigStore) SetupAgents() {
	s.config.SetupAgents()
}
 65
 66// configPath returns the file path for the given scope.
 67func (s *ConfigStore) configPath(scope Scope) string {
 68	switch scope {
 69	case ScopeWorkspace:
 70		return s.workspacePath
 71	default:
 72		return s.globalDataPath
 73	}
 74}
 75
 76// HasConfigField checks whether a key exists in the config file for the given
 77// scope.
 78func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
 79	data, err := os.ReadFile(s.configPath(scope))
 80	if err != nil {
 81		return false
 82	}
 83	return gjson.Get(string(data), key).Exists()
 84}
 85
 86// SetConfigField sets a key/value pair in the config file for the given scope.
 87func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
 88	path := s.configPath(scope)
 89	data, err := os.ReadFile(path)
 90	if err != nil {
 91		if os.IsNotExist(err) {
 92			data = []byte("{}")
 93		} else {
 94			return fmt.Errorf("failed to read config file: %w", err)
 95		}
 96	}
 97
 98	newValue, err := sjson.Set(string(data), key, value)
 99	if err != nil {
100		return fmt.Errorf("failed to set config field %s: %w", key, err)
101	}
102	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
103		return fmt.Errorf("failed to create config directory %q: %w", path, err)
104	}
105	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
106		return fmt.Errorf("failed to write config file: %w", err)
107	}
108	return nil
109}
110
111// RemoveConfigField removes a key from the config file for the given scope.
112func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
113	path := s.configPath(scope)
114	data, err := os.ReadFile(path)
115	if err != nil {
116		return fmt.Errorf("failed to read config file: %w", err)
117	}
118
119	newValue, err := sjson.Delete(string(data), key)
120	if err != nil {
121		return fmt.Errorf("failed to delete config field %s: %w", key, err)
122	}
123	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
124		return fmt.Errorf("failed to create config directory %q: %w", path, err)
125	}
126	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
127		return fmt.Errorf("failed to write config file: %w", err)
128	}
129	return nil
130}
131
132// UpdatePreferredModel updates the preferred model for the given type and
133// persists it to the config file at the given scope.
134func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
135	s.config.Models[modelType] = model
136	if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
137		return fmt.Errorf("failed to update preferred model: %w", err)
138	}
139	if err := s.recordRecentModel(scope, modelType, model); err != nil {
140		return err
141	}
142	return nil
143}
144
145// SetCompactMode sets the compact mode setting and persists it.
146func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
147	if s.config.Options == nil {
148		s.config.Options = &Options{}
149	}
150	s.config.Options.TUI.CompactMode = enabled
151	return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
152}
153
154// SetTransparentBackground sets the transparent background setting and persists it.
155func (s *ConfigStore) SetTransparentBackground(scope Scope, enabled bool) error {
156	if s.config.Options == nil {
157		s.config.Options = &Options{}
158	}
159	s.config.Options.TUI.Transparent = &enabled
160	return s.SetConfigField(scope, "options.tui.transparent", enabled)
161}
162
163// SetProviderAPIKey sets the API key for a provider and persists it.
164func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
165	var providerConfig ProviderConfig
166	var exists bool
167	var setKeyOrToken func()
168
169	switch v := apiKey.(type) {
170	case string:
171		if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
172			return fmt.Errorf("failed to save api key to config file: %w", err)
173		}
174		setKeyOrToken = func() { providerConfig.APIKey = v }
175	case *oauth.Token:
176		if err := cmp.Or(
177			s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
178			s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
179		); err != nil {
180			return err
181		}
182		setKeyOrToken = func() {
183			providerConfig.APIKey = v.AccessToken
184			providerConfig.OAuthToken = v
185			switch providerID {
186			case string(catwalk.InferenceProviderCopilot):
187				providerConfig.SetupGitHubCopilot()
188			}
189		}
190	}
191
192	providerConfig, exists = s.config.Providers.Get(providerID)
193	if exists {
194		setKeyOrToken()
195		s.config.Providers.Set(providerID, providerConfig)
196		return nil
197	}
198
199	var foundProvider *catwalk.Provider
200	for _, p := range s.knownProviders {
201		if string(p.ID) == providerID {
202			foundProvider = &p
203			break
204		}
205	}
206
207	if foundProvider != nil {
208		providerConfig = ProviderConfig{
209			ID:           providerID,
210			Name:         foundProvider.Name,
211			BaseURL:      foundProvider.APIEndpoint,
212			Type:         foundProvider.Type,
213			Disable:      false,
214			ExtraHeaders: make(map[string]string),
215			ExtraParams:  make(map[string]string),
216			Models:       foundProvider.Models,
217		}
218		setKeyOrToken()
219	} else {
220		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
221	}
222	s.config.Providers.Set(providerID, providerConfig)
223	return nil
224}
225
226// RefreshOAuthToken refreshes the OAuth token for the given provider.
227func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
228	providerConfig, exists := s.config.Providers.Get(providerID)
229	if !exists {
230		return fmt.Errorf("provider %s not found", providerID)
231	}
232
233	if providerConfig.OAuthToken == nil {
234		return fmt.Errorf("provider %s does not have an OAuth token", providerID)
235	}
236
237	var newToken *oauth.Token
238	var refreshErr error
239	switch providerID {
240	case string(catwalk.InferenceProviderCopilot):
241		newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
242	case hyperp.Name:
243		newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
244	default:
245		return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
246	}
247	if refreshErr != nil {
248		return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
249	}
250
251	slog.Info("Successfully refreshed OAuth token", "provider", providerID)
252	providerConfig.OAuthToken = newToken
253	providerConfig.APIKey = newToken.AccessToken
254
255	switch providerID {
256	case string(catwalk.InferenceProviderCopilot):
257		providerConfig.SetupGitHubCopilot()
258	}
259
260	s.config.Providers.Set(providerID, providerConfig)
261
262	if err := cmp.Or(
263		s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
264		s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
265	); err != nil {
266		return fmt.Errorf("failed to persist refreshed token: %w", err)
267	}
268
269	return nil
270}
271
272// recordRecentModel records a model in the recent models list.
273func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
274	if model.Provider == "" || model.Model == "" {
275		return nil
276	}
277
278	if s.config.RecentModels == nil {
279		s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
280	}
281
282	eq := func(a, b SelectedModel) bool {
283		return a.Provider == b.Provider && a.Model == b.Model
284	}
285
286	entry := SelectedModel{
287		Provider: model.Provider,
288		Model:    model.Model,
289	}
290
291	current := s.config.RecentModels[modelType]
292	withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
293		return eq(existing, entry)
294	})
295
296	updated := append([]SelectedModel{entry}, withoutCurrent...)
297	if len(updated) > maxRecentModelsPerType {
298		updated = updated[:maxRecentModelsPerType]
299	}
300
301	if slices.EqualFunc(current, updated, eq) {
302		return nil
303	}
304
305	s.config.RecentModels[modelType] = updated
306
307	if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
308		return fmt.Errorf("failed to persist recent models: %w", err)
309	}
310
311	return nil
312}
313
314// ImportCopilot attempts to import a GitHub Copilot token from disk.
315func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
316	if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
317		return nil, false
318	}
319
320	diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
321	if !hasDiskToken {
322		return nil, false
323	}
324
325	slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
326	token, err := copilot.RefreshToken(context.TODO(), diskToken)
327	if err != nil {
328		slog.Error("Unable to import GitHub Copilot token", "error", err)
329		return nil, false
330	}
331
332	if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
333		return token, false
334	}
335
336	if err := cmp.Or(
337		s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
338		s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
339	); err != nil {
340		slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
341	}
342
343	slog.Info("GitHub Copilot successfully imported")
344	return token, true
345}