1package config
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "slices"
11
12 "charm.land/catwalk/pkg/catwalk"
13 hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
14 "github.com/charmbracelet/crush/internal/oauth"
15 "github.com/charmbracelet/crush/internal/oauth/copilot"
16 "github.com/charmbracelet/crush/internal/oauth/hyper"
17 "github.com/tidwall/gjson"
18 "github.com/tidwall/sjson"
19)
20
21// ConfigStore is the single entry point for all config access. It owns the
22// pure-data Config, runtime state (working directory, resolver, known
23// providers), and persistence to both global and workspace config files.
24type ConfigStore struct {
25 config *Config
26 workingDir string
27 resolver VariableResolver
28 globalDataPath string // ~/.local/share/crush/crush.json
29 workspacePath string // .crush/crush.json
30 knownProviders []catwalk.Provider
31}
32
33// Config returns the pure-data config struct (read-only after load).
34func (s *ConfigStore) Config() *Config {
35 return s.config
36}
37
38// WorkingDir returns the current working directory.
39func (s *ConfigStore) WorkingDir() string {
40 return s.workingDir
41}
42
43// Resolver returns the variable resolver.
44func (s *ConfigStore) Resolver() VariableResolver {
45 return s.resolver
46}
47
48// Resolve resolves a variable reference using the configured resolver.
49func (s *ConfigStore) Resolve(key string) (string, error) {
50 if s.resolver == nil {
51 return "", fmt.Errorf("no variable resolver configured")
52 }
53 return s.resolver.ResolveValue(key)
54}
55
56// KnownProviders returns the list of known providers.
57func (s *ConfigStore) KnownProviders() []catwalk.Provider {
58 return s.knownProviders
59}
60
61// SetupAgents configures the coder and task agents on the config.
62func (s *ConfigStore) SetupAgents() {
63 s.config.SetupAgents()
64}
65
66// configPath returns the file path for the given scope.
67func (s *ConfigStore) configPath(scope Scope) string {
68 switch scope {
69 case ScopeWorkspace:
70 return s.workspacePath
71 default:
72 return s.globalDataPath
73 }
74}
75
76// HasConfigField checks whether a key exists in the config file for the given
77// scope.
78func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
79 data, err := os.ReadFile(s.configPath(scope))
80 if err != nil {
81 return false
82 }
83 return gjson.Get(string(data), key).Exists()
84}
85
86// SetConfigField sets a key/value pair in the config file for the given scope.
87func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
88 path := s.configPath(scope)
89 data, err := os.ReadFile(path)
90 if err != nil {
91 if os.IsNotExist(err) {
92 data = []byte("{}")
93 } else {
94 return fmt.Errorf("failed to read config file: %w", err)
95 }
96 }
97
98 newValue, err := sjson.Set(string(data), key, value)
99 if err != nil {
100 return fmt.Errorf("failed to set config field %s: %w", key, err)
101 }
102 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
103 return fmt.Errorf("failed to create config directory %q: %w", path, err)
104 }
105 if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
106 return fmt.Errorf("failed to write config file: %w", err)
107 }
108 return nil
109}
110
111// RemoveConfigField removes a key from the config file for the given scope.
112func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
113 path := s.configPath(scope)
114 data, err := os.ReadFile(path)
115 if err != nil {
116 return fmt.Errorf("failed to read config file: %w", err)
117 }
118
119 newValue, err := sjson.Delete(string(data), key)
120 if err != nil {
121 return fmt.Errorf("failed to delete config field %s: %w", key, err)
122 }
123 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
124 return fmt.Errorf("failed to create config directory %q: %w", path, err)
125 }
126 if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
127 return fmt.Errorf("failed to write config file: %w", err)
128 }
129 return nil
130}
131
132// UpdatePreferredModel updates the preferred model for the given type and
133// persists it to the config file at the given scope.
134func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
135 s.config.Models[modelType] = model
136 if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
137 return fmt.Errorf("failed to update preferred model: %w", err)
138 }
139 if err := s.recordRecentModel(scope, modelType, model); err != nil {
140 return err
141 }
142 return nil
143}
144
145// SetCompactMode sets the compact mode setting and persists it.
146func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
147 if s.config.Options == nil {
148 s.config.Options = &Options{}
149 }
150 s.config.Options.TUI.CompactMode = enabled
151 return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
152}
153
154// SetProviderAPIKey sets the API key for a provider and persists it.
155func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
156 var providerConfig ProviderConfig
157 var exists bool
158 var setKeyOrToken func()
159
160 switch v := apiKey.(type) {
161 case string:
162 if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
163 return fmt.Errorf("failed to save api key to config file: %w", err)
164 }
165 setKeyOrToken = func() { providerConfig.APIKey = v }
166 case *oauth.Token:
167 if err := cmp.Or(
168 s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
169 s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
170 ); err != nil {
171 return err
172 }
173 setKeyOrToken = func() {
174 providerConfig.APIKey = v.AccessToken
175 providerConfig.OAuthToken = v
176 switch providerID {
177 case string(catwalk.InferenceProviderCopilot):
178 providerConfig.SetupGitHubCopilot()
179 }
180 }
181 }
182
183 providerConfig, exists = s.config.Providers.Get(providerID)
184 if exists {
185 setKeyOrToken()
186 s.config.Providers.Set(providerID, providerConfig)
187 return nil
188 }
189
190 var foundProvider *catwalk.Provider
191 for _, p := range s.knownProviders {
192 if string(p.ID) == providerID {
193 foundProvider = &p
194 break
195 }
196 }
197
198 if foundProvider != nil {
199 providerConfig = ProviderConfig{
200 ID: providerID,
201 Name: foundProvider.Name,
202 BaseURL: foundProvider.APIEndpoint,
203 Type: foundProvider.Type,
204 Disable: false,
205 ExtraHeaders: make(map[string]string),
206 ExtraParams: make(map[string]string),
207 Models: foundProvider.Models,
208 }
209 setKeyOrToken()
210 } else {
211 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
212 }
213 s.config.Providers.Set(providerID, providerConfig)
214 return nil
215}
216
217// RefreshOAuthToken refreshes the OAuth token for the given provider.
218func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
219 providerConfig, exists := s.config.Providers.Get(providerID)
220 if !exists {
221 return fmt.Errorf("provider %s not found", providerID)
222 }
223
224 if providerConfig.OAuthToken == nil {
225 return fmt.Errorf("provider %s does not have an OAuth token", providerID)
226 }
227
228 var newToken *oauth.Token
229 var refreshErr error
230 switch providerID {
231 case string(catwalk.InferenceProviderCopilot):
232 newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
233 case hyperp.Name:
234 newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
235 default:
236 return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
237 }
238 if refreshErr != nil {
239 return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
240 }
241
242 slog.Info("Successfully refreshed OAuth token", "provider", providerID)
243 providerConfig.OAuthToken = newToken
244 providerConfig.APIKey = newToken.AccessToken
245
246 switch providerID {
247 case string(catwalk.InferenceProviderCopilot):
248 providerConfig.SetupGitHubCopilot()
249 }
250
251 s.config.Providers.Set(providerID, providerConfig)
252
253 if err := cmp.Or(
254 s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
255 s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
256 ); err != nil {
257 return fmt.Errorf("failed to persist refreshed token: %w", err)
258 }
259
260 return nil
261}
262
263// recordRecentModel records a model in the recent models list.
264func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
265 if model.Provider == "" || model.Model == "" {
266 return nil
267 }
268
269 if s.config.RecentModels == nil {
270 s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
271 }
272
273 eq := func(a, b SelectedModel) bool {
274 return a.Provider == b.Provider && a.Model == b.Model
275 }
276
277 entry := SelectedModel{
278 Provider: model.Provider,
279 Model: model.Model,
280 }
281
282 current := s.config.RecentModels[modelType]
283 withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
284 return eq(existing, entry)
285 })
286
287 updated := append([]SelectedModel{entry}, withoutCurrent...)
288 if len(updated) > maxRecentModelsPerType {
289 updated = updated[:maxRecentModelsPerType]
290 }
291
292 if slices.EqualFunc(current, updated, eq) {
293 return nil
294 }
295
296 s.config.RecentModels[modelType] = updated
297
298 if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
299 return fmt.Errorf("failed to persist recent models: %w", err)
300 }
301
302 return nil
303}
304
305// ImportCopilot attempts to import a GitHub Copilot token from disk.
306func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
307 if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
308 return nil, false
309 }
310
311 diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
312 if !hasDiskToken {
313 return nil, false
314 }
315
316 slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
317 token, err := copilot.RefreshToken(context.TODO(), diskToken)
318 if err != nil {
319 slog.Error("Unable to import GitHub Copilot token", "error", err)
320 return nil, false
321 }
322
323 if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
324 return token, false
325 }
326
327 if err := cmp.Or(
328 s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
329 s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
330 ); err != nil {
331 slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
332 }
333
334 slog.Info("GitHub Copilot successfully imported")
335 return token, true
336}