1package config
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "slices"
11
12 "charm.land/catwalk/pkg/catwalk"
13 hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
14 "github.com/charmbracelet/crush/internal/oauth"
15 "github.com/charmbracelet/crush/internal/oauth/copilot"
16 "github.com/charmbracelet/crush/internal/oauth/hyper"
17 "github.com/tidwall/gjson"
18 "github.com/tidwall/sjson"
19)
20
21// ConfigStore is the single entry point for all config access. It owns the
22// pure-data Config, runtime state (working directory, resolver, known
23// providers), and persistence to both global and workspace config files.
24type ConfigStore struct {
25 config *Config
26 workingDir string
27 resolver VariableResolver
28 globalDataPath string // ~/.local/share/crush/crush.json
29 workspacePath string // .crush/crush.json
30 knownProviders []catwalk.Provider
31}
32
33// Config returns the pure-data config struct (read-only after load).
34func (s *ConfigStore) Config() *Config {
35 return s.config
36}
37
38// WorkingDir returns the current working directory.
39func (s *ConfigStore) WorkingDir() string {
40 return s.workingDir
41}
42
43// Resolver returns the variable resolver.
44func (s *ConfigStore) Resolver() VariableResolver {
45 return s.resolver
46}
47
48// Resolve resolves a variable reference using the configured resolver.
49func (s *ConfigStore) Resolve(key string) (string, error) {
50 if s.resolver == nil {
51 return "", fmt.Errorf("no variable resolver configured")
52 }
53 return s.resolver.ResolveValue(key)
54}
55
56// KnownProviders returns the list of known providers.
57func (s *ConfigStore) KnownProviders() []catwalk.Provider {
58 return s.knownProviders
59}
60
61// SetupAgents configures the coder and task agents on the config.
62func (s *ConfigStore) SetupAgents() {
63 s.config.SetupAgents()
64}
65
66// configPath returns the file path for the given scope.
67func (s *ConfigStore) configPath(scope Scope) string {
68 switch scope {
69 case ScopeWorkspace:
70 return s.workspacePath
71 default:
72 return s.globalDataPath
73 }
74}
75
76// HasConfigField checks whether a key exists in the config file for the given
77// scope.
78func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
79 data, err := os.ReadFile(s.configPath(scope))
80 if err != nil {
81 return false
82 }
83 return gjson.Get(string(data), key).Exists()
84}
85
86// SetConfigField sets a key/value pair in the config file for the given scope.
87func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
88 path := s.configPath(scope)
89 data, err := os.ReadFile(path)
90 if err != nil {
91 if os.IsNotExist(err) {
92 data = []byte("{}")
93 } else {
94 return fmt.Errorf("failed to read config file: %w", err)
95 }
96 }
97
98 newValue, err := sjson.Set(string(data), key, value)
99 if err != nil {
100 return fmt.Errorf("failed to set config field %s: %w", key, err)
101 }
102 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
103 return fmt.Errorf("failed to create config directory %q: %w", path, err)
104 }
105 if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
106 return fmt.Errorf("failed to write config file: %w", err)
107 }
108 return nil
109}
110
111// RemoveConfigField removes a key from the config file for the given scope.
112func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
113 path := s.configPath(scope)
114 data, err := os.ReadFile(path)
115 if err != nil {
116 return fmt.Errorf("failed to read config file: %w", err)
117 }
118
119 newValue, err := sjson.Delete(string(data), key)
120 if err != nil {
121 return fmt.Errorf("failed to delete config field %s: %w", key, err)
122 }
123 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
124 return fmt.Errorf("failed to create config directory %q: %w", path, err)
125 }
126 if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
127 return fmt.Errorf("failed to write config file: %w", err)
128 }
129 return nil
130}
131
132// UpdatePreferredModel updates the preferred model for the given type and
133// persists it to the config file at the given scope.
134func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
135 s.config.Models[modelType] = model
136 if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
137 return fmt.Errorf("failed to update preferred model: %w", err)
138 }
139 if err := s.recordRecentModel(scope, modelType, model); err != nil {
140 return err
141 }
142 return nil
143}
144
145// SetCompactMode sets the compact mode setting and persists it.
146func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
147 if s.config.Options == nil {
148 s.config.Options = &Options{}
149 }
150 s.config.Options.TUI.CompactMode = enabled
151 return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
152}
153
154// SetTransparentBackground sets the transparent background setting and persists it.
155func (s *ConfigStore) SetTransparentBackground(scope Scope, enabled bool) error {
156 if s.config.Options == nil {
157 s.config.Options = &Options{}
158 }
159 s.config.Options.TUI.Transparent = &enabled
160 return s.SetConfigField(scope, "options.tui.transparent", enabled)
161}
162
163// SetProviderAPIKey sets the API key for a provider and persists it.
164func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
165 var providerConfig ProviderConfig
166 var exists bool
167 var setKeyOrToken func()
168
169 switch v := apiKey.(type) {
170 case string:
171 if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
172 return fmt.Errorf("failed to save api key to config file: %w", err)
173 }
174 setKeyOrToken = func() { providerConfig.APIKey = v }
175 case *oauth.Token:
176 if err := cmp.Or(
177 s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
178 s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
179 ); err != nil {
180 return err
181 }
182 setKeyOrToken = func() {
183 providerConfig.APIKey = v.AccessToken
184 providerConfig.OAuthToken = v
185 switch providerID {
186 case string(catwalk.InferenceProviderCopilot):
187 providerConfig.SetupGitHubCopilot()
188 }
189 }
190 }
191
192 providerConfig, exists = s.config.Providers.Get(providerID)
193 if exists {
194 setKeyOrToken()
195 s.config.Providers.Set(providerID, providerConfig)
196 return nil
197 }
198
199 var foundProvider *catwalk.Provider
200 for _, p := range s.knownProviders {
201 if string(p.ID) == providerID {
202 foundProvider = &p
203 break
204 }
205 }
206
207 if foundProvider != nil {
208 providerConfig = ProviderConfig{
209 ID: providerID,
210 Name: foundProvider.Name,
211 BaseURL: foundProvider.APIEndpoint,
212 Type: foundProvider.Type,
213 Disable: false,
214 ExtraHeaders: make(map[string]string),
215 ExtraParams: make(map[string]string),
216 Models: foundProvider.Models,
217 }
218 setKeyOrToken()
219 } else {
220 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
221 }
222 s.config.Providers.Set(providerID, providerConfig)
223 return nil
224}
225
226// RefreshOAuthToken refreshes the OAuth token for the given provider.
227func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
228 providerConfig, exists := s.config.Providers.Get(providerID)
229 if !exists {
230 return fmt.Errorf("provider %s not found", providerID)
231 }
232
233 if providerConfig.OAuthToken == nil {
234 return fmt.Errorf("provider %s does not have an OAuth token", providerID)
235 }
236
237 var newToken *oauth.Token
238 var refreshErr error
239 switch providerID {
240 case string(catwalk.InferenceProviderCopilot):
241 newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
242 case hyperp.Name:
243 newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
244 default:
245 return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
246 }
247 if refreshErr != nil {
248 return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
249 }
250
251 slog.Info("Successfully refreshed OAuth token", "provider", providerID)
252 providerConfig.OAuthToken = newToken
253 providerConfig.APIKey = newToken.AccessToken
254
255 switch providerID {
256 case string(catwalk.InferenceProviderCopilot):
257 providerConfig.SetupGitHubCopilot()
258 }
259
260 s.config.Providers.Set(providerID, providerConfig)
261
262 if err := cmp.Or(
263 s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
264 s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
265 ); err != nil {
266 return fmt.Errorf("failed to persist refreshed token: %w", err)
267 }
268
269 return nil
270}
271
272// recordRecentModel records a model in the recent models list.
273func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
274 if model.Provider == "" || model.Model == "" {
275 return nil
276 }
277
278 if s.config.RecentModels == nil {
279 s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
280 }
281
282 eq := func(a, b SelectedModel) bool {
283 return a.Provider == b.Provider && a.Model == b.Model
284 }
285
286 entry := SelectedModel{
287 Provider: model.Provider,
288 Model: model.Model,
289 }
290
291 current := s.config.RecentModels[modelType]
292 withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
293 return eq(existing, entry)
294 })
295
296 updated := append([]SelectedModel{entry}, withoutCurrent...)
297 if len(updated) > maxRecentModelsPerType {
298 updated = updated[:maxRecentModelsPerType]
299 }
300
301 if slices.EqualFunc(current, updated, eq) {
302 return nil
303 }
304
305 s.config.RecentModels[modelType] = updated
306
307 if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
308 return fmt.Errorf("failed to persist recent models: %w", err)
309 }
310
311 return nil
312}
313
314// ImportCopilot attempts to import a GitHub Copilot token from disk.
315func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
316 if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
317 return nil, false
318 }
319
320 diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
321 if !hasDiskToken {
322 return nil, false
323 }
324
325 slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
326 token, err := copilot.RefreshToken(context.TODO(), diskToken)
327 if err != nil {
328 slog.Error("Unable to import GitHub Copilot token", "error", err)
329 return nil, false
330 }
331
332 if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
333 return token, false
334 }
335
336 if err := cmp.Or(
337 s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
338 s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
339 ); err != nil {
340 slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
341 }
342
343 slog.Info("GitHub Copilot successfully imported")
344 return token, true
345}