1package config
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "slices"
11
12 "charm.land/catwalk/pkg/catwalk"
13 hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
14 "github.com/charmbracelet/crush/internal/oauth"
15 "github.com/charmbracelet/crush/internal/oauth/copilot"
16 "github.com/charmbracelet/crush/internal/oauth/hyper"
17 "github.com/tidwall/gjson"
18 "github.com/tidwall/sjson"
19)
20
// RuntimeOverrides carries settings that apply only to the running process
// (or workspace). They are layered over the loaded Config and are never
// written back to any config file.
type RuntimeOverrides struct {
	// SkipPermissionRequests is a session-only toggle; as the name suggests it
	// presumably disables permission prompts (its consumers live outside this
	// file — confirm there).
	SkipPermissionRequests bool
}
27
28// ConfigStore is the single entry point for all config access. It owns the
29// pure-data Config, runtime state (working directory, resolver, known
30// providers), and persistence to both global and workspace config files.
31type ConfigStore struct {
32 config *Config
33 workingDir string
34 resolver VariableResolver
35 globalDataPath string // ~/.local/share/crush/crush.json
36 workspacePath string // .crush/crush.json
37 knownProviders []catwalk.Provider
38 overrides RuntimeOverrides
39}
40
41// Config returns the pure-data config struct (read-only after load).
42func (s *ConfigStore) Config() *Config {
43 return s.config
44}
45
46// WorkingDir returns the current working directory.
47func (s *ConfigStore) WorkingDir() string {
48 return s.workingDir
49}
50
51// Resolver returns the variable resolver.
52func (s *ConfigStore) Resolver() VariableResolver {
53 return s.resolver
54}
55
56// Resolve resolves a variable reference using the configured resolver.
57func (s *ConfigStore) Resolve(key string) (string, error) {
58 if s.resolver == nil {
59 return "", fmt.Errorf("no variable resolver configured")
60 }
61 return s.resolver.ResolveValue(key)
62}
63
64// KnownProviders returns the list of known providers.
65func (s *ConfigStore) KnownProviders() []catwalk.Provider {
66 return s.knownProviders
67}
68
69// SetupAgents configures the coder and task agents on the config.
70func (s *ConfigStore) SetupAgents() {
71 s.config.SetupAgents()
72}
73
74// Overrides returns the runtime overrides for this store.
75func (s *ConfigStore) Overrides() *RuntimeOverrides {
76 return &s.overrides
77}
78
79// configPath returns the file path for the given scope.
80func (s *ConfigStore) configPath(scope Scope) (string, error) {
81 switch scope {
82 case ScopeWorkspace:
83 if s.workspacePath == "" {
84 return "", ErrNoWorkspaceConfig
85 }
86 return s.workspacePath, nil
87 default:
88 return s.globalDataPath, nil
89 }
90}
91
92// HasConfigField checks whether a key exists in the config file for the given
93// scope.
94func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
95 path, err := s.configPath(scope)
96 if err != nil {
97 return false
98 }
99 data, err := os.ReadFile(path)
100 if err != nil {
101 return false
102 }
103 return gjson.Get(string(data), key).Exists()
104}
105
106// SetConfigField sets a key/value pair in the config file for the given scope.
107func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
108 path, err := s.configPath(scope)
109 if err != nil {
110 return fmt.Errorf("%s: %w", key, err)
111 }
112 data, err := os.ReadFile(path)
113 if err != nil {
114 if os.IsNotExist(err) {
115 data = []byte("{}")
116 } else {
117 return fmt.Errorf("failed to read config file: %w", err)
118 }
119 }
120
121 newValue, err := sjson.Set(string(data), key, value)
122 if err != nil {
123 return fmt.Errorf("failed to set config field %s: %w", key, err)
124 }
125 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
126 return fmt.Errorf("failed to create config directory %q: %w", path, err)
127 }
128 if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
129 return fmt.Errorf("failed to write config file: %w", err)
130 }
131 return nil
132}
133
134// RemoveConfigField removes a key from the config file for the given scope.
135func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
136 path, err := s.configPath(scope)
137 if err != nil {
138 return fmt.Errorf("%s: %w", key, err)
139 }
140 data, err := os.ReadFile(path)
141 if err != nil {
142 return fmt.Errorf("failed to read config file: %w", err)
143 }
144
145 newValue, err := sjson.Delete(string(data), key)
146 if err != nil {
147 return fmt.Errorf("failed to delete config field %s: %w", key, err)
148 }
149 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
150 return fmt.Errorf("failed to create config directory %q: %w", path, err)
151 }
152 if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
153 return fmt.Errorf("failed to write config file: %w", err)
154 }
155 return nil
156}
157
158// UpdatePreferredModel updates the preferred model for the given type and
159// persists it to the config file at the given scope.
160func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
161 s.config.Models[modelType] = model
162 if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
163 return fmt.Errorf("failed to update preferred model: %w", err)
164 }
165 if err := s.recordRecentModel(scope, modelType, model); err != nil {
166 return err
167 }
168 return nil
169}
170
171// SetCompactMode sets the compact mode setting and persists it.
172func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
173 if s.config.Options == nil {
174 s.config.Options = &Options{}
175 }
176 s.config.Options.TUI.CompactMode = enabled
177 return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
178}
179
180// SetProviderAPIKey sets the API key for a provider and persists it.
181func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
182 var providerConfig ProviderConfig
183 var exists bool
184 var setKeyOrToken func()
185
186 switch v := apiKey.(type) {
187 case string:
188 if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
189 return fmt.Errorf("failed to save api key to config file: %w", err)
190 }
191 setKeyOrToken = func() { providerConfig.APIKey = v }
192 case *oauth.Token:
193 if err := cmp.Or(
194 s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
195 s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
196 ); err != nil {
197 return err
198 }
199 setKeyOrToken = func() {
200 providerConfig.APIKey = v.AccessToken
201 providerConfig.OAuthToken = v
202 switch providerID {
203 case string(catwalk.InferenceProviderCopilot):
204 providerConfig.SetupGitHubCopilot()
205 }
206 }
207 }
208
209 providerConfig, exists = s.config.Providers.Get(providerID)
210 if exists {
211 setKeyOrToken()
212 s.config.Providers.Set(providerID, providerConfig)
213 return nil
214 }
215
216 var foundProvider *catwalk.Provider
217 for _, p := range s.knownProviders {
218 if string(p.ID) == providerID {
219 foundProvider = &p
220 break
221 }
222 }
223
224 if foundProvider != nil {
225 providerConfig = ProviderConfig{
226 ID: providerID,
227 Name: foundProvider.Name,
228 BaseURL: foundProvider.APIEndpoint,
229 Type: foundProvider.Type,
230 Disable: false,
231 ExtraHeaders: make(map[string]string),
232 ExtraParams: make(map[string]string),
233 Models: foundProvider.Models,
234 }
235 setKeyOrToken()
236 } else {
237 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
238 }
239 s.config.Providers.Set(providerID, providerConfig)
240 return nil
241}
242
243// RefreshOAuthToken refreshes the OAuth token for the given provider.
244func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
245 providerConfig, exists := s.config.Providers.Get(providerID)
246 if !exists {
247 return fmt.Errorf("provider %s not found", providerID)
248 }
249
250 if providerConfig.OAuthToken == nil {
251 return fmt.Errorf("provider %s does not have an OAuth token", providerID)
252 }
253
254 var newToken *oauth.Token
255 var refreshErr error
256 switch providerID {
257 case string(catwalk.InferenceProviderCopilot):
258 newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
259 case hyperp.Name:
260 newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
261 default:
262 return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
263 }
264 if refreshErr != nil {
265 return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
266 }
267
268 slog.Info("Successfully refreshed OAuth token", "provider", providerID)
269 providerConfig.OAuthToken = newToken
270 providerConfig.APIKey = newToken.AccessToken
271
272 switch providerID {
273 case string(catwalk.InferenceProviderCopilot):
274 providerConfig.SetupGitHubCopilot()
275 }
276
277 s.config.Providers.Set(providerID, providerConfig)
278
279 if err := cmp.Or(
280 s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
281 s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
282 ); err != nil {
283 return fmt.Errorf("failed to persist refreshed token: %w", err)
284 }
285
286 return nil
287}
288
289// recordRecentModel records a model in the recent models list.
290func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
291 if model.Provider == "" || model.Model == "" {
292 return nil
293 }
294
295 if s.config.RecentModels == nil {
296 s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
297 }
298
299 eq := func(a, b SelectedModel) bool {
300 return a.Provider == b.Provider && a.Model == b.Model
301 }
302
303 entry := SelectedModel{
304 Provider: model.Provider,
305 Model: model.Model,
306 }
307
308 current := s.config.RecentModels[modelType]
309 withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
310 return eq(existing, entry)
311 })
312
313 updated := append([]SelectedModel{entry}, withoutCurrent...)
314 if len(updated) > maxRecentModelsPerType {
315 updated = updated[:maxRecentModelsPerType]
316 }
317
318 if slices.EqualFunc(current, updated, eq) {
319 return nil
320 }
321
322 s.config.RecentModels[modelType] = updated
323
324 if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
325 return fmt.Errorf("failed to persist recent models: %w", err)
326 }
327
328 return nil
329}
330
331// ImportCopilot attempts to import a GitHub Copilot token from disk.
332func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
333 if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
334 return nil, false
335 }
336
337 diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
338 if !hasDiskToken {
339 return nil, false
340 }
341
342 slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
343 token, err := copilot.RefreshToken(context.TODO(), diskToken)
344 if err != nil {
345 slog.Error("Unable to import GitHub Copilot token", "error", err)
346 return nil, false
347 }
348
349 if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
350 return token, false
351 }
352
353 if err := cmp.Or(
354 s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
355 s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
356 ); err != nil {
357 slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
358 }
359
360 slog.Info("GitHub Copilot successfully imported")
361 return token, true
362}