1package config
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "slices"
11
12 "charm.land/catwalk/pkg/catwalk"
13 hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
14 "github.com/charmbracelet/crush/internal/oauth"
15 "github.com/charmbracelet/crush/internal/oauth/copilot"
16 "github.com/charmbracelet/crush/internal/oauth/hyper"
17 "github.com/tidwall/gjson"
18 "github.com/tidwall/sjson"
19)
20
// RuntimeOverrides carries per-session settings that are layered on top of
// the loaded Config. They live only in memory, are never written to disk, and
// therefore last only as long as the process (or workspace).
type RuntimeOverrides struct {
	// SkipPermissionRequests is a session-only flag; presumably it makes the
	// app skip permission prompts — confirm against its readers.
	SkipPermissionRequests bool
}
27
28// ConfigStore is the single entry point for all config access. It owns the
29// pure-data Config, runtime state (working directory, resolver, known
30// providers), and persistence to both global and workspace config files.
31type ConfigStore struct {
32 config *Config
33 workingDir string
34 resolver VariableResolver
35 globalDataPath string // ~/.local/share/crush/crush.json
36 workspacePath string // .crush/crush.json
37 knownProviders []catwalk.Provider
38 overrides RuntimeOverrides
39}
40
41// Config returns the pure-data config struct (read-only after load).
42func (s *ConfigStore) Config() *Config {
43 return s.config
44}
45
46// WorkingDir returns the current working directory.
47func (s *ConfigStore) WorkingDir() string {
48 return s.workingDir
49}
50
51// Resolver returns the variable resolver.
52func (s *ConfigStore) Resolver() VariableResolver {
53 return s.resolver
54}
55
56// Resolve resolves a variable reference using the configured resolver.
57func (s *ConfigStore) Resolve(key string) (string, error) {
58 if s.resolver == nil {
59 return "", fmt.Errorf("no variable resolver configured")
60 }
61 return s.resolver.ResolveValue(key)
62}
63
64// KnownProviders returns the list of known providers.
65func (s *ConfigStore) KnownProviders() []catwalk.Provider {
66 return s.knownProviders
67}
68
69// SetupAgents configures the coder and task agents on the config.
70func (s *ConfigStore) SetupAgents() {
71 s.config.SetupAgents()
72}
73
74// Overrides returns the runtime overrides for this store.
75func (s *ConfigStore) Overrides() *RuntimeOverrides {
76 return &s.overrides
77}
78
79// configPath returns the file path for the given scope.
80func (s *ConfigStore) configPath(scope Scope) (string, error) {
81 switch scope {
82 case ScopeWorkspace:
83 if s.workspacePath == "" {
84 return "", ErrNoWorkspaceConfig
85 }
86 return s.workspacePath, nil
87 default:
88 return s.globalDataPath, nil
89 }
90}
91
92// HasConfigField checks whether a key exists in the config file for the given
93// scope.
94func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
95 path, err := s.configPath(scope)
96 if err != nil {
97 return false
98 }
99 data, err := os.ReadFile(path)
100 if err != nil {
101 return false
102 }
103 return gjson.Get(string(data), key).Exists()
104}
105
106// SetConfigField sets a key/value pair in the config file for the given scope.
107func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
108 path, err := s.configPath(scope)
109 if err != nil {
110 return fmt.Errorf("%s: %w", key, err)
111 }
112 data, err := os.ReadFile(path)
113 if err != nil {
114 if os.IsNotExist(err) {
115 data = []byte("{}")
116 } else {
117 return fmt.Errorf("failed to read config file: %w", err)
118 }
119 }
120
121 newValue, err := sjson.Set(string(data), key, value)
122 if err != nil {
123 return fmt.Errorf("failed to set config field %s: %w", key, err)
124 }
125 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
126 return fmt.Errorf("failed to create config directory %q: %w", path, err)
127 }
128 if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
129 return fmt.Errorf("failed to write config file: %w", err)
130 }
131 return nil
132}
133
134// RemoveConfigField removes a key from the config file for the given scope.
135func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
136 path, err := s.configPath(scope)
137 if err != nil {
138 return fmt.Errorf("%s: %w", key, err)
139 }
140 data, err := os.ReadFile(path)
141 if err != nil {
142 return fmt.Errorf("failed to read config file: %w", err)
143 }
144
145 newValue, err := sjson.Delete(string(data), key)
146 if err != nil {
147 return fmt.Errorf("failed to delete config field %s: %w", key, err)
148 }
149 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
150 return fmt.Errorf("failed to create config directory %q: %w", path, err)
151 }
152 if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
153 return fmt.Errorf("failed to write config file: %w", err)
154 }
155 return nil
156}
157
158// UpdatePreferredModel updates the preferred model for the given type and
159// persists it to the config file at the given scope.
160func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
161 s.config.Models[modelType] = model
162 if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
163 return fmt.Errorf("failed to update preferred model: %w", err)
164 }
165 if err := s.recordRecentModel(scope, modelType, model); err != nil {
166 return err
167 }
168 return nil
169}
170
171// SetCompactMode sets the compact mode setting and persists it.
172func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
173 if s.config.Options == nil {
174 s.config.Options = &Options{}
175 }
176 s.config.Options.TUI.CompactMode = enabled
177 return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
178}
179
180// SetTransparentBackground sets the transparent background setting and persists it.
181func (s *ConfigStore) SetTransparentBackground(scope Scope, enabled bool) error {
182 if s.config.Options == nil {
183 s.config.Options = &Options{}
184 }
185 s.config.Options.TUI.Transparent = &enabled
186 return s.SetConfigField(scope, "options.tui.transparent", enabled)
187}
188
189// SetProviderAPIKey sets the API key for a provider and persists it.
190func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
191 var providerConfig ProviderConfig
192 var exists bool
193 var setKeyOrToken func()
194
195 switch v := apiKey.(type) {
196 case string:
197 if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
198 return fmt.Errorf("failed to save api key to config file: %w", err)
199 }
200 setKeyOrToken = func() { providerConfig.APIKey = v }
201 case *oauth.Token:
202 if err := cmp.Or(
203 s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
204 s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
205 ); err != nil {
206 return err
207 }
208 setKeyOrToken = func() {
209 providerConfig.APIKey = v.AccessToken
210 providerConfig.OAuthToken = v
211 switch providerID {
212 case string(catwalk.InferenceProviderCopilot):
213 providerConfig.SetupGitHubCopilot()
214 }
215 }
216 }
217
218 providerConfig, exists = s.config.Providers.Get(providerID)
219 if exists {
220 setKeyOrToken()
221 s.config.Providers.Set(providerID, providerConfig)
222 return nil
223 }
224
225 var foundProvider *catwalk.Provider
226 for _, p := range s.knownProviders {
227 if string(p.ID) == providerID {
228 foundProvider = &p
229 break
230 }
231 }
232
233 if foundProvider != nil {
234 providerConfig = ProviderConfig{
235 ID: providerID,
236 Name: foundProvider.Name,
237 BaseURL: foundProvider.APIEndpoint,
238 Type: foundProvider.Type,
239 Disable: false,
240 ExtraHeaders: make(map[string]string),
241 ExtraParams: make(map[string]string),
242 Models: foundProvider.Models,
243 }
244 setKeyOrToken()
245 } else {
246 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
247 }
248 s.config.Providers.Set(providerID, providerConfig)
249 return nil
250}
251
252// RefreshOAuthToken refreshes the OAuth token for the given provider.
253func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
254 providerConfig, exists := s.config.Providers.Get(providerID)
255 if !exists {
256 return fmt.Errorf("provider %s not found", providerID)
257 }
258
259 if providerConfig.OAuthToken == nil {
260 return fmt.Errorf("provider %s does not have an OAuth token", providerID)
261 }
262
263 var newToken *oauth.Token
264 var refreshErr error
265 switch providerID {
266 case string(catwalk.InferenceProviderCopilot):
267 newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
268 case hyperp.Name:
269 newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
270 default:
271 return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
272 }
273 if refreshErr != nil {
274 return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
275 }
276
277 slog.Info("Successfully refreshed OAuth token", "provider", providerID)
278 providerConfig.OAuthToken = newToken
279 providerConfig.APIKey = newToken.AccessToken
280
281 switch providerID {
282 case string(catwalk.InferenceProviderCopilot):
283 providerConfig.SetupGitHubCopilot()
284 }
285
286 s.config.Providers.Set(providerID, providerConfig)
287
288 if err := cmp.Or(
289 s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
290 s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
291 ); err != nil {
292 return fmt.Errorf("failed to persist refreshed token: %w", err)
293 }
294
295 return nil
296}
297
298// recordRecentModel records a model in the recent models list.
299func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
300 if model.Provider == "" || model.Model == "" {
301 return nil
302 }
303
304 if s.config.RecentModels == nil {
305 s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
306 }
307
308 eq := func(a, b SelectedModel) bool {
309 return a.Provider == b.Provider && a.Model == b.Model
310 }
311
312 entry := SelectedModel{
313 Provider: model.Provider,
314 Model: model.Model,
315 }
316
317 current := s.config.RecentModels[modelType]
318 withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
319 return eq(existing, entry)
320 })
321
322 updated := append([]SelectedModel{entry}, withoutCurrent...)
323 if len(updated) > maxRecentModelsPerType {
324 updated = updated[:maxRecentModelsPerType]
325 }
326
327 if slices.EqualFunc(current, updated, eq) {
328 return nil
329 }
330
331 s.config.RecentModels[modelType] = updated
332
333 if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
334 return fmt.Errorf("failed to persist recent models: %w", err)
335 }
336
337 return nil
338}
339
340// ImportCopilot attempts to import a GitHub Copilot token from disk.
341func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
342 if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
343 return nil, false
344 }
345
346 diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
347 if !hasDiskToken {
348 return nil, false
349 }
350
351 slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
352 token, err := copilot.RefreshToken(context.TODO(), diskToken)
353 if err != nil {
354 slog.Error("Unable to import GitHub Copilot token", "error", err)
355 return nil, false
356 }
357
358 if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
359 return token, false
360 }
361
362 if err := cmp.Or(
363 s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
364 s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
365 ); err != nil {
366 slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
367 }
368
369 slog.Info("GitHub Copilot successfully imported")
370 return token, true
371}