1package config
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "slices"
11
12 "charm.land/catwalk/pkg/catwalk"
13 hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
14 "github.com/charmbracelet/crush/internal/oauth"
15 "github.com/charmbracelet/crush/internal/oauth/copilot"
16 "github.com/charmbracelet/crush/internal/oauth/hyper"
17 "github.com/tidwall/gjson"
18 "github.com/tidwall/sjson"
19)
20
// RuntimeOverrides holds per-session settings that are never persisted to
// disk. They are applied on top of the loaded Config and survive only for
// the lifetime of the process (or workspace).
type RuntimeOverrides struct {
	// SkipPermissionRequests presumably makes the session bypass interactive
	// permission prompts — NOTE(review): confirm against its consumers.
	SkipPermissionRequests bool
}
27
// ConfigStore is the single entry point for all config access. It owns the
// pure-data Config, runtime state (working directory, resolver, known
// providers), and persistence to both global and workspace config files.
type ConfigStore struct {
	// config is the in-memory configuration; the Set*/Update* methods mutate
	// it alongside writing to the config files.
	config *Config
	// workingDir is the directory reported by WorkingDir.
	workingDir string
	// resolver backs Resolve. It may be nil (NewTestStore leaves it unset),
	// in which case Resolve returns an error.
	resolver       VariableResolver
	globalDataPath string   // ~/.local/share/crush/crush.json
	workspacePath  string   // .crush/crush.json
	loadedPaths    []string // config files that were successfully loaded
	// knownProviders is consulted by SetProviderAPIKey when a provider is not
	// yet present in config.
	knownProviders []catwalk.Provider
	// overrides holds the session-only settings exposed via Overrides.
	overrides RuntimeOverrides
}
41
42// Config returns the pure-data config struct (read-only after load).
43func (s *ConfigStore) Config() *Config {
44 return s.config
45}
46
47// WorkingDir returns the current working directory.
48func (s *ConfigStore) WorkingDir() string {
49 return s.workingDir
50}
51
52// Resolver returns the variable resolver.
53func (s *ConfigStore) Resolver() VariableResolver {
54 return s.resolver
55}
56
57// Resolve resolves a variable reference using the configured resolver.
58func (s *ConfigStore) Resolve(key string) (string, error) {
59 if s.resolver == nil {
60 return "", fmt.Errorf("no variable resolver configured")
61 }
62 return s.resolver.ResolveValue(key)
63}
64
65// KnownProviders returns the list of known providers.
66func (s *ConfigStore) KnownProviders() []catwalk.Provider {
67 return s.knownProviders
68}
69
70// SetupAgents configures the coder and task agents on the config.
71func (s *ConfigStore) SetupAgents() {
72 s.config.SetupAgents()
73}
74
75// Overrides returns the runtime overrides for this store.
76func (s *ConfigStore) Overrides() *RuntimeOverrides {
77 return &s.overrides
78}
79
80// LoadedPaths returns the config file paths that were successfully loaded.
81func (s *ConfigStore) LoadedPaths() []string {
82 return slices.Clone(s.loadedPaths)
83}
84
85// configPath returns the file path for the given scope.
86func (s *ConfigStore) configPath(scope Scope) (string, error) {
87 switch scope {
88 case ScopeWorkspace:
89 if s.workspacePath == "" {
90 return "", ErrNoWorkspaceConfig
91 }
92 return s.workspacePath, nil
93 default:
94 return s.globalDataPath, nil
95 }
96}
97
98// HasConfigField checks whether a key exists in the config file for the given
99// scope.
100func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
101 path, err := s.configPath(scope)
102 if err != nil {
103 return false
104 }
105 data, err := os.ReadFile(path)
106 if err != nil {
107 return false
108 }
109 return gjson.Get(string(data), key).Exists()
110}
111
112// SetConfigField sets a key/value pair in the config file for the given scope.
113func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
114 path, err := s.configPath(scope)
115 if err != nil {
116 return fmt.Errorf("%s: %w", key, err)
117 }
118 data, err := os.ReadFile(path)
119 if err != nil {
120 if os.IsNotExist(err) {
121 data = []byte("{}")
122 } else {
123 return fmt.Errorf("failed to read config file: %w", err)
124 }
125 }
126
127 newValue, err := sjson.Set(string(data), key, value)
128 if err != nil {
129 return fmt.Errorf("failed to set config field %s: %w", key, err)
130 }
131 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
132 return fmt.Errorf("failed to create config directory %q: %w", path, err)
133 }
134 if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
135 return fmt.Errorf("failed to write config file: %w", err)
136 }
137 return nil
138}
139
140// RemoveConfigField removes a key from the config file for the given scope.
141func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
142 path, err := s.configPath(scope)
143 if err != nil {
144 return fmt.Errorf("%s: %w", key, err)
145 }
146 data, err := os.ReadFile(path)
147 if err != nil {
148 return fmt.Errorf("failed to read config file: %w", err)
149 }
150
151 newValue, err := sjson.Delete(string(data), key)
152 if err != nil {
153 return fmt.Errorf("failed to delete config field %s: %w", key, err)
154 }
155 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
156 return fmt.Errorf("failed to create config directory %q: %w", path, err)
157 }
158 if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
159 return fmt.Errorf("failed to write config file: %w", err)
160 }
161 return nil
162}
163
164// UpdatePreferredModel updates the preferred model for the given type and
165// persists it to the config file at the given scope.
166func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
167 s.config.Models[modelType] = model
168 if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
169 return fmt.Errorf("failed to update preferred model: %w", err)
170 }
171 if err := s.recordRecentModel(scope, modelType, model); err != nil {
172 return err
173 }
174 return nil
175}
176
177// SetCompactMode sets the compact mode setting and persists it.
178func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
179 if s.config.Options == nil {
180 s.config.Options = &Options{}
181 }
182 s.config.Options.TUI.CompactMode = enabled
183 return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
184}
185
186// SetTransparentBackground sets the transparent background setting and persists it.
187func (s *ConfigStore) SetTransparentBackground(scope Scope, enabled bool) error {
188 if s.config.Options == nil {
189 s.config.Options = &Options{}
190 }
191 s.config.Options.TUI.Transparent = &enabled
192 return s.SetConfigField(scope, "options.tui.transparent", enabled)
193}
194
195// SetProviderAPIKey sets the API key for a provider and persists it.
196func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
197 var providerConfig ProviderConfig
198 var exists bool
199 var setKeyOrToken func()
200
201 switch v := apiKey.(type) {
202 case string:
203 if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
204 return fmt.Errorf("failed to save api key to config file: %w", err)
205 }
206 setKeyOrToken = func() { providerConfig.APIKey = v }
207 case *oauth.Token:
208 if err := cmp.Or(
209 s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
210 s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
211 ); err != nil {
212 return err
213 }
214 setKeyOrToken = func() {
215 providerConfig.APIKey = v.AccessToken
216 providerConfig.OAuthToken = v
217 switch providerID {
218 case string(catwalk.InferenceProviderCopilot):
219 providerConfig.SetupGitHubCopilot()
220 }
221 }
222 }
223
224 providerConfig, exists = s.config.Providers.Get(providerID)
225 if exists {
226 setKeyOrToken()
227 s.config.Providers.Set(providerID, providerConfig)
228 return nil
229 }
230
231 var foundProvider *catwalk.Provider
232 for _, p := range s.knownProviders {
233 if string(p.ID) == providerID {
234 foundProvider = &p
235 break
236 }
237 }
238
239 if foundProvider != nil {
240 providerConfig = ProviderConfig{
241 ID: providerID,
242 Name: foundProvider.Name,
243 BaseURL: foundProvider.APIEndpoint,
244 Type: foundProvider.Type,
245 Disable: false,
246 ExtraHeaders: make(map[string]string),
247 ExtraParams: make(map[string]string),
248 Models: foundProvider.Models,
249 }
250 setKeyOrToken()
251 } else {
252 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
253 }
254 s.config.Providers.Set(providerID, providerConfig)
255 return nil
256}
257
258// RefreshOAuthToken refreshes the OAuth token for the given provider.
259func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
260 providerConfig, exists := s.config.Providers.Get(providerID)
261 if !exists {
262 return fmt.Errorf("provider %s not found", providerID)
263 }
264
265 if providerConfig.OAuthToken == nil {
266 return fmt.Errorf("provider %s does not have an OAuth token", providerID)
267 }
268
269 var newToken *oauth.Token
270 var refreshErr error
271 switch providerID {
272 case string(catwalk.InferenceProviderCopilot):
273 newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
274 case hyperp.Name:
275 newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
276 default:
277 return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
278 }
279 if refreshErr != nil {
280 return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
281 }
282
283 slog.Info("Successfully refreshed OAuth token", "provider", providerID)
284 providerConfig.OAuthToken = newToken
285 providerConfig.APIKey = newToken.AccessToken
286
287 switch providerID {
288 case string(catwalk.InferenceProviderCopilot):
289 providerConfig.SetupGitHubCopilot()
290 }
291
292 s.config.Providers.Set(providerID, providerConfig)
293
294 if err := cmp.Or(
295 s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
296 s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
297 ); err != nil {
298 return fmt.Errorf("failed to persist refreshed token: %w", err)
299 }
300
301 return nil
302}
303
304// recordRecentModel records a model in the recent models list.
305func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
306 if model.Provider == "" || model.Model == "" {
307 return nil
308 }
309
310 if s.config.RecentModels == nil {
311 s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
312 }
313
314 eq := func(a, b SelectedModel) bool {
315 return a.Provider == b.Provider && a.Model == b.Model
316 }
317
318 entry := SelectedModel{
319 Provider: model.Provider,
320 Model: model.Model,
321 }
322
323 current := s.config.RecentModels[modelType]
324 withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
325 return eq(existing, entry)
326 })
327
328 updated := append([]SelectedModel{entry}, withoutCurrent...)
329 if len(updated) > maxRecentModelsPerType {
330 updated = updated[:maxRecentModelsPerType]
331 }
332
333 if slices.EqualFunc(current, updated, eq) {
334 return nil
335 }
336
337 s.config.RecentModels[modelType] = updated
338
339 if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
340 return fmt.Errorf("failed to persist recent models: %w", err)
341 }
342
343 return nil
344}
345
346// NewTestStore creates a ConfigStore for testing purposes.
347func NewTestStore(cfg *Config, loadedPaths ...string) *ConfigStore {
348 return &ConfigStore{
349 config: cfg,
350 loadedPaths: loadedPaths,
351 }
352}
353
354// ImportCopilot attempts to import a GitHub Copilot token from disk.
355func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
356 if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
357 return nil, false
358 }
359
360 diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
361 if !hasDiskToken {
362 return nil, false
363 }
364
365 slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
366 token, err := copilot.RefreshToken(context.TODO(), diskToken)
367 if err != nil {
368 slog.Error("Unable to import GitHub Copilot token", "error", err)
369 return nil, false
370 }
371
372 if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
373 return token, false
374 }
375
376 if err := cmp.Or(
377 s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
378 s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
379 ); err != nil {
380 slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
381 }
382
383 slog.Info("GitHub Copilot successfully imported")
384 return token, true
385}