store.go

  1package config
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"slices"
 11
 12	"charm.land/catwalk/pkg/catwalk"
 13	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 14	"github.com/charmbracelet/crush/internal/oauth"
 15	"github.com/charmbracelet/crush/internal/oauth/copilot"
 16	"github.com/charmbracelet/crush/internal/oauth/hyper"
 17	"github.com/tidwall/gjson"
 18	"github.com/tidwall/sjson"
 19)
 20
 21// ConfigStore is the single entry point for all config access. It owns the
 22// pure-data Config, runtime state (working directory, resolver, known
 23// providers), and persistence to both global and workspace config files.
 24type ConfigStore struct {
 25	config         *Config
 26	workingDir     string
 27	resolver       VariableResolver
 28	globalDataPath string // ~/.local/share/crush/crush.json
 29	workspacePath  string // .crush/crush.json
 30	knownProviders []catwalk.Provider
 31}
 32
 33// Config returns the pure-data config struct (read-only after load).
 34func (s *ConfigStore) Config() *Config {
 35	return s.config
 36}
 37
 38// WorkingDir returns the current working directory.
 39func (s *ConfigStore) WorkingDir() string {
 40	return s.workingDir
 41}
 42
 43// Resolver returns the variable resolver.
 44func (s *ConfigStore) Resolver() VariableResolver {
 45	return s.resolver
 46}
 47
 48// Resolve resolves a variable reference using the configured resolver.
 49func (s *ConfigStore) Resolve(key string) (string, error) {
 50	if s.resolver == nil {
 51		return "", fmt.Errorf("no variable resolver configured")
 52	}
 53	return s.resolver.ResolveValue(key)
 54}
 55
 56// KnownProviders returns the list of known providers.
 57func (s *ConfigStore) KnownProviders() []catwalk.Provider {
 58	return s.knownProviders
 59}
 60
 61// SetupAgents configures the coder and task agents on the config.
 62func (s *ConfigStore) SetupAgents() {
 63	s.config.SetupAgents()
 64}
 65
 66// configPath returns the file path for the given scope.
 67func (s *ConfigStore) configPath(scope Scope) string {
 68	switch scope {
 69	case ScopeWorkspace:
 70		return s.workspacePath
 71	default:
 72		return s.globalDataPath
 73	}
 74}
 75
 76// HasConfigField checks whether a key exists in the config file for the given
 77// scope.
 78func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
 79	data, err := os.ReadFile(s.configPath(scope))
 80	if err != nil {
 81		return false
 82	}
 83	return gjson.Get(string(data), key).Exists()
 84}
 85
 86// SetConfigField sets a key/value pair in the config file for the given scope.
 87func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
 88	path := s.configPath(scope)
 89	data, err := os.ReadFile(path)
 90	if err != nil {
 91		if os.IsNotExist(err) {
 92			data = []byte("{}")
 93		} else {
 94			return fmt.Errorf("failed to read config file: %w", err)
 95		}
 96	}
 97
 98	newValue, err := sjson.Set(string(data), key, value)
 99	if err != nil {
100		return fmt.Errorf("failed to set config field %s: %w", key, err)
101	}
102	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
103		return fmt.Errorf("failed to create config directory %q: %w", path, err)
104	}
105	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
106		return fmt.Errorf("failed to write config file: %w", err)
107	}
108	return nil
109}
110
111// RemoveConfigField removes a key from the config file for the given scope.
112func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
113	path := s.configPath(scope)
114	data, err := os.ReadFile(path)
115	if err != nil {
116		return fmt.Errorf("failed to read config file: %w", err)
117	}
118
119	newValue, err := sjson.Delete(string(data), key)
120	if err != nil {
121		return fmt.Errorf("failed to delete config field %s: %w", key, err)
122	}
123	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
124		return fmt.Errorf("failed to create config directory %q: %w", path, err)
125	}
126	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
127		return fmt.Errorf("failed to write config file: %w", err)
128	}
129	return nil
130}
131
132// UpdatePreferredModel updates the preferred model for the given type and
133// persists it to the config file at the given scope.
134func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
135	s.config.Models[modelType] = model
136	if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
137		return fmt.Errorf("failed to update preferred model: %w", err)
138	}
139	if err := s.recordRecentModel(scope, modelType, model); err != nil {
140		return err
141	}
142	return nil
143}
144
145// SetCompactMode sets the compact mode setting and persists it.
146func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
147	if s.config.Options == nil {
148		s.config.Options = &Options{}
149	}
150	s.config.Options.TUI.CompactMode = enabled
151	return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
152}
153
154// SetProviderAPIKey sets the API key for a provider and persists it.
155func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
156	var providerConfig ProviderConfig
157	var exists bool
158	var setKeyOrToken func()
159
160	switch v := apiKey.(type) {
161	case string:
162		if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
163			return fmt.Errorf("failed to save api key to config file: %w", err)
164		}
165		setKeyOrToken = func() { providerConfig.APIKey = v }
166	case *oauth.Token:
167		if err := cmp.Or(
168			s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
169			s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
170		); err != nil {
171			return err
172		}
173		setKeyOrToken = func() {
174			providerConfig.APIKey = v.AccessToken
175			providerConfig.OAuthToken = v
176			switch providerID {
177			case string(catwalk.InferenceProviderCopilot):
178				providerConfig.SetupGitHubCopilot()
179			}
180		}
181	}
182
183	providerConfig, exists = s.config.Providers.Get(providerID)
184	if exists {
185		setKeyOrToken()
186		s.config.Providers.Set(providerID, providerConfig)
187		return nil
188	}
189
190	var foundProvider *catwalk.Provider
191	for _, p := range s.knownProviders {
192		if string(p.ID) == providerID {
193			foundProvider = &p
194			break
195		}
196	}
197
198	if foundProvider != nil {
199		providerConfig = ProviderConfig{
200			ID:           providerID,
201			Name:         foundProvider.Name,
202			BaseURL:      foundProvider.APIEndpoint,
203			Type:         foundProvider.Type,
204			Disable:      false,
205			ExtraHeaders: make(map[string]string),
206			ExtraParams:  make(map[string]string),
207			Models:       foundProvider.Models,
208		}
209		setKeyOrToken()
210	} else {
211		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
212	}
213	s.config.Providers.Set(providerID, providerConfig)
214	return nil
215}
216
217// RefreshOAuthToken refreshes the OAuth token for the given provider.
218func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
219	providerConfig, exists := s.config.Providers.Get(providerID)
220	if !exists {
221		return fmt.Errorf("provider %s not found", providerID)
222	}
223
224	if providerConfig.OAuthToken == nil {
225		return fmt.Errorf("provider %s does not have an OAuth token", providerID)
226	}
227
228	var newToken *oauth.Token
229	var refreshErr error
230	switch providerID {
231	case string(catwalk.InferenceProviderCopilot):
232		newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
233	case hyperp.Name:
234		newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
235	default:
236		return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
237	}
238	if refreshErr != nil {
239		return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
240	}
241
242	slog.Info("Successfully refreshed OAuth token", "provider", providerID)
243	providerConfig.OAuthToken = newToken
244	providerConfig.APIKey = newToken.AccessToken
245
246	switch providerID {
247	case string(catwalk.InferenceProviderCopilot):
248		providerConfig.SetupGitHubCopilot()
249	}
250
251	s.config.Providers.Set(providerID, providerConfig)
252
253	if err := cmp.Or(
254		s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
255		s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
256	); err != nil {
257		return fmt.Errorf("failed to persist refreshed token: %w", err)
258	}
259
260	return nil
261}
262
263// recordRecentModel records a model in the recent models list.
264func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
265	if model.Provider == "" || model.Model == "" {
266		return nil
267	}
268
269	if s.config.RecentModels == nil {
270		s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
271	}
272
273	eq := func(a, b SelectedModel) bool {
274		return a.Provider == b.Provider && a.Model == b.Model
275	}
276
277	entry := SelectedModel{
278		Provider: model.Provider,
279		Model:    model.Model,
280	}
281
282	current := s.config.RecentModels[modelType]
283	withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
284		return eq(existing, entry)
285	})
286
287	updated := append([]SelectedModel{entry}, withoutCurrent...)
288	if len(updated) > maxRecentModelsPerType {
289		updated = updated[:maxRecentModelsPerType]
290	}
291
292	if slices.EqualFunc(current, updated, eq) {
293		return nil
294	}
295
296	s.config.RecentModels[modelType] = updated
297
298	if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
299		return fmt.Errorf("failed to persist recent models: %w", err)
300	}
301
302	return nil
303}
304
305// ImportCopilot attempts to import a GitHub Copilot token from disk.
306func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
307	if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
308		return nil, false
309	}
310
311	diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
312	if !hasDiskToken {
313		return nil, false
314	}
315
316	slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
317	token, err := copilot.RefreshToken(context.TODO(), diskToken)
318	if err != nil {
319		slog.Error("Unable to import GitHub Copilot token", "error", err)
320		return nil, false
321	}
322
323	if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
324		return token, false
325	}
326
327	if err := cmp.Or(
328		s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
329		s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
330	); err != nil {
331		slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
332	}
333
334	slog.Info("GitHub Copilot successfully imported")
335	return token, true
336}