1package config
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "log/slog"
8 "slices"
9 "sync"
10 "testing"
11
12 "charm.land/catwalk/pkg/catwalk"
13 hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
14 "github.com/charmbracelet/crush/internal/oauth"
15 "github.com/charmbracelet/crush/internal/oauth/copilot"
16 "github.com/charmbracelet/crush/internal/oauth/hyper"
17)
18
// Service is the central access point for configuration. It wraps the
// raw Config data and owns all internal state that was previously held
// as unexported fields on Config (resolver, store, known providers,
// working directory).
//
// Locking: mu guards cfg and agents for the methods that take it. Several
// simple getters read cfg without locking, and workingDir, resolver and
// store are always read unlocked, which assumes they are not reassigned
// after construction (NOTE(review): confirm against the constructor).
// sync.RWMutex is not reentrant, so a method holding mu must not call
// another method that also locks it.
type Service struct {
	mu             sync.RWMutex
	cfg            *Config
	store          Store // target of HasField/SetField/RemoveField in the *ConfigField methods
	resolver       VariableResolver
	workingDir     string
	knownProviders []catwalk.Provider // consulted by SetProviderAPIKey for providers missing from cfg
	agents         map[string]Agent   // built by SetupAgents
}
32
33// WorkingDir returns the working directory.
34func (s *Service) WorkingDir() string {
35 return s.workingDir
36}
37
38// EnabledProviders returns all non-disabled provider configs.
39func (s *Service) EnabledProviders() []ProviderConfig {
40 s.mu.RLock()
41 defer s.mu.RUnlock()
42 return s.cfg.EnabledProviders()
43}
44
45// IsConfigured returns true if at least one provider is enabled.
46func (s *Service) IsConfigured() bool {
47 s.mu.RLock()
48 defer s.mu.RUnlock()
49 return s.cfg.IsConfigured()
50}
51
52// GetModel returns the catwalk model for the given provider and model
53// ID, or nil if not found.
54func (s *Service) GetModel(provider, model string) *catwalk.Model {
55 s.mu.RLock()
56 defer s.mu.RUnlock()
57 return s.cfg.GetModel(provider, model)
58}
59
60// GetProviderForModel returns the provider config for the given model
61// type, or nil.
62func (s *Service) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
63 s.mu.RLock()
64 defer s.mu.RUnlock()
65 return s.cfg.GetProviderForModel(modelType)
66}
67
68// GetModelByType returns the catwalk model for the given model type,
69// or nil.
70func (s *Service) GetModelByType(modelType SelectedModelType) *catwalk.Model {
71 s.mu.RLock()
72 defer s.mu.RUnlock()
73 return s.cfg.GetModelByType(modelType)
74}
75
76// LargeModel returns the catwalk model for the large model type.
77func (s *Service) LargeModel() *catwalk.Model {
78 s.mu.RLock()
79 defer s.mu.RUnlock()
80 return s.cfg.LargeModel()
81}
82
83// SmallModel returns the catwalk model for the small model type.
84func (s *Service) SmallModel() *catwalk.Model {
85 s.mu.RLock()
86 defer s.mu.RUnlock()
87 return s.cfg.SmallModel()
88}
89
90// Resolve resolves a variable value using the configured resolver.
91func (s *Service) Resolve(key string) (string, error) {
92 if s.resolver == nil {
93 return "", fmt.Errorf("no variable resolver configured")
94 }
95 return s.resolver.ResolveValue(key)
96}
97
98// Resolver returns the variable resolver.
99func (s *Service) Resolver() VariableResolver {
100 return s.resolver
101}
102
103// SetupAgents builds the agent configurations from the current config
104// options.
105func (s *Service) SetupAgents() {
106 s.mu.Lock()
107 defer s.mu.Unlock()
108 allowedTools := resolveAllowedTools(allToolNames(), s.cfg.Options.DisabledTools)
109
110 s.agents = map[string]Agent{
111 AgentCoder: {
112 ID: AgentCoder,
113 Name: "Coder",
114 Description: "An agent that helps with executing coding tasks.",
115 Model: SelectedModelTypeLarge,
116 ContextPaths: s.cfg.Options.ContextPaths,
117 AllowedTools: allowedTools,
118 },
119
120 AgentTask: {
121 ID: AgentCoder,
122 Name: "Task",
123 Description: "An agent that helps with searching for context and finding implementation details.",
124 Model: SelectedModelTypeLarge,
125 ContextPaths: s.cfg.Options.ContextPaths,
126 AllowedTools: resolveReadOnlyTools(allowedTools),
127 AllowedMCP: map[string][]string{},
128 },
129 }
130}
131
132// Agents returns the agent configuration map.
133func (s *Service) Agents() map[string]Agent {
134 s.mu.RLock()
135 defer s.mu.RUnlock()
136 return s.agents
137}
138
139// Agent returns the agent configuration for the given name and
140// whether it exists.
141func (s *Service) Agent(name string) (Agent, bool) {
142 s.mu.RLock()
143 defer s.mu.RUnlock()
144 a, ok := s.agents[name]
145 return a, ok
146}
147
148// DataDirectory returns the data directory path.
149func (s *Service) DataDirectory() string {
150 return s.cfg.Options.DataDirectory
151}
152
153// Debug returns whether debug mode is enabled.
154func (s *Service) Debug() bool {
155 return s.cfg.Options.Debug
156}
157
158// DebugLSP returns whether LSP debug mode is enabled.
159func (s *Service) DebugLSP() bool {
160 return s.cfg.Options.DebugLSP
161}
162
163// DisableAutoSummarize returns whether auto-summarization is
164// disabled.
165func (s *Service) DisableAutoSummarize() bool {
166 return s.cfg.Options.DisableAutoSummarize
167}
168
169// Attribution returns the attribution settings.
170func (s *Service) Attribution() *Attribution {
171 s.mu.RLock()
172 defer s.mu.RUnlock()
173 return s.cfg.Options.Attribution
174}
175
176// ContextPaths returns the configured context paths.
177func (s *Service) ContextPaths() []string {
178 return s.cfg.Options.ContextPaths
179}
180
181// SkillsPaths returns the configured skills paths.
182func (s *Service) SkillsPaths() []string {
183 s.mu.RLock()
184 defer s.mu.RUnlock()
185 return s.cfg.Options.SkillsPaths
186}
187
188// Progress returns the progress setting pointer.
189func (s *Service) Progress() *bool {
190 return s.cfg.Options.Progress
191}
192
193// DisableMetrics returns whether metrics are disabled.
194func (s *Service) DisableMetrics() bool {
195 return s.cfg.Options.DisableMetrics
196}
197
198// SelectedModel returns the selected model for the given type and
199// whether it exists.
200func (s *Service) SelectedModel(modelType SelectedModelType) (SelectedModel, bool) {
201 s.mu.RLock()
202 defer s.mu.RUnlock()
203 m, ok := s.cfg.Models[modelType]
204 return m, ok
205}
206
207// Provider returns the provider config for the given ID and whether
208// it exists.
209func (s *Service) Provider(id string) (ProviderConfig, bool) {
210 s.mu.RLock()
211 defer s.mu.RUnlock()
212 p, ok := s.cfg.Providers[id]
213 return p, ok
214}
215
216// SetProvider sets the provider config for the given ID.
217func (s *Service) SetProvider(id string, p ProviderConfig) {
218 s.mu.Lock()
219 defer s.mu.Unlock()
220 s.cfg.Providers[id] = p
221}
222
223// Providers returns all provider configs.
224func (s *Service) AllProviders() map[string]ProviderConfig {
225 s.mu.RLock()
226 defer s.mu.RUnlock()
227 return s.cfg.Providers
228}
229
230// MCP returns the MCP configurations.
231func (s *Service) MCP() MCPs {
232 return s.cfg.MCP
233}
234
235// LSP returns the LSP configurations.
236func (s *Service) LSP() LSPs {
237 s.mu.RLock()
238 defer s.mu.RUnlock()
239 return s.cfg.LSP
240}
241
242// Permissions returns the permissions configuration.
243func (s *Service) Permissions() *Permissions {
244 s.mu.RLock()
245 defer s.mu.RUnlock()
246 return s.cfg.Permissions
247}
248
249// SetAttribution sets the attribution settings.
250func (s *Service) SetAttribution(a *Attribution) {
251 s.mu.Lock()
252 defer s.mu.Unlock()
253 s.cfg.Options.Attribution = a
254}
255
256// SetSkillsPaths sets the skills paths.
257func (s *Service) SetSkillsPaths(paths []string) {
258 s.mu.Lock()
259 defer s.mu.Unlock()
260 s.cfg.Options.SkillsPaths = paths
261}
262
263// SetLSP sets the LSP configurations.
264func (s *Service) SetLSP(lsp LSPs) {
265 s.mu.Lock()
266 defer s.mu.Unlock()
267 s.cfg.LSP = lsp
268}
269
270// SetPermissions sets the permissions configuration.
271func (s *Service) SetPermissions(p *Permissions) {
272 s.mu.Lock()
273 defer s.mu.Unlock()
274 s.cfg.Permissions = p
275}
276
277// OverrideModel overrides the in-memory model for the given type
278// without persisting. Used for non-interactive model overrides.
279func (s *Service) OverrideModel(modelType SelectedModelType, model SelectedModel) {
280 s.mu.Lock()
281 defer s.mu.Unlock()
282 s.cfg.Models[modelType] = model
283}
284
285// ToolLsConfig returns the ls tool configuration.
286func (s *Service) ToolLsConfig() ToolLs {
287 return s.cfg.Tools.Ls
288}
289
290// CompactMode returns whether compact mode is enabled.
291func (s *Service) CompactMode() bool {
292 s.mu.RLock()
293 defer s.mu.RUnlock()
294 if s.cfg.Options.TUI == nil {
295 return false
296 }
297 return s.cfg.Options.TUI.CompactMode
298}
299
300// DiffMode returns the diff mode setting.
301func (s *Service) DiffMode() string {
302 if s.cfg.Options.TUI == nil {
303 return ""
304 }
305 return s.cfg.Options.TUI.DiffMode
306}
307
308// CompletionLimits returns the completion depth and items limits.
309func (s *Service) CompletionLimits() (depth, items int) {
310 if s.cfg.Options.TUI == nil {
311 return 0, 0
312 }
313 return s.cfg.Options.TUI.Completions.Limits()
314}
315
316// DisableDefaultProviders returns whether default providers are
317// disabled.
318func (s *Service) DisableDefaultProviders() bool {
319 return s.cfg.Options.DisableDefaultProviders
320}
321
322// DisableProviderAutoUpdate returns whether provider auto-update is
323// disabled.
324func (s *Service) DisableProviderAutoUpdate() bool {
325 return s.cfg.Options.DisableProviderAutoUpdate
326}
327
328// InitializeAs returns the initialization file name.
329func (s *Service) InitializeAs() string {
330 return s.cfg.Options.InitializeAs
331}
332
333// AutoLSP returns the auto-LSP setting pointer.
334func (s *Service) AutoLSP() *bool {
335 return s.cfg.Options.AutoLSP
336}
337
338// RecentModels returns recent models for the given type.
339func (s *Service) RecentModels(modelType SelectedModelType) []SelectedModel {
340 s.mu.RLock()
341 defer s.mu.RUnlock()
342 return s.cfg.RecentModels[modelType]
343}
344
345// Options returns the full options struct. This is a temporary
346// accessor for callers that need multiple option fields.
347func (s *Service) Options() *Options {
348 return s.cfg.Options
349}
350
351// HasConfigField returns true if the given dotted key path exists in
352// the persisted config data.
353func (s *Service) HasConfigField(key string) bool {
354 return HasField(s.store, key)
355}
356
357// SetConfigField sets a value at the given dotted key path and
358// persists it.
359func (s *Service) SetConfigField(key string, value any) error {
360 return SetField(s.store, key, value)
361}
362
363// RemoveConfigField deletes a value at the given dotted key path and
364// persists it.
365func (s *Service) RemoveConfigField(key string) error {
366 return RemoveField(s.store, key)
367}
368
369// SetCompactMode toggles compact mode and persists the change.
370func (s *Service) SetCompactMode(enabled bool) error {
371 s.mu.Lock()
372 defer s.mu.Unlock()
373 cfg := s.cfg
374 if cfg.Options == nil {
375 cfg.Options = &Options{}
376 }
377 if cfg.Options.TUI == nil {
378 cfg.Options.TUI = &TUIOptions{}
379 }
380 cfg.Options.TUI.CompactMode = enabled
381 return s.SetConfigField("options.tui.compact_mode", enabled)
382}
383
384// UpdatePreferredModel updates the selected model for the given type
385// and persists the change, also recording it in the recent models
386// list.
387func (s *Service) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
388 s.mu.Lock()
389 defer s.mu.Unlock()
390 s.cfg.Models[modelType] = model
391 if err := s.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
392 return fmt.Errorf("failed to update preferred model: %w", err)
393 }
394 if err := s.recordRecentModel(modelType, model); err != nil {
395 return err
396 }
397 return nil
398}
399
// maxRecentModelsPerType caps how many entries recordRecentModel keeps
// in the recent-models list for each model type.
const maxRecentModelsPerType = 5
401
402func (s *Service) recordRecentModel(modelType SelectedModelType, model SelectedModel) error {
403 if model.Provider == "" || model.Model == "" {
404 return nil
405 }
406
407 cfg := s.cfg
408 if cfg.RecentModels == nil {
409 cfg.RecentModels = make(map[SelectedModelType][]SelectedModel)
410 }
411
412 eq := func(a, b SelectedModel) bool {
413 return a.Provider == b.Provider && a.Model == b.Model
414 }
415
416 entry := SelectedModel{
417 Provider: model.Provider,
418 Model: model.Model,
419 }
420
421 current := cfg.RecentModels[modelType]
422 withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
423 return eq(existing, entry)
424 })
425
426 updated := append([]SelectedModel{entry}, withoutCurrent...)
427 if len(updated) > maxRecentModelsPerType {
428 updated = updated[:maxRecentModelsPerType]
429 }
430
431 if slices.EqualFunc(current, updated, eq) {
432 return nil
433 }
434
435 cfg.RecentModels[modelType] = updated
436
437 if err := s.SetConfigField(fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
438 return fmt.Errorf("failed to persist recent models: %w", err)
439 }
440
441 return nil
442}
443
444// RefreshOAuthToken refreshes the OAuth token for the given provider.
445func (s *Service) RefreshOAuthToken(ctx context.Context, providerID string) error {
446 s.mu.Lock()
447 defer s.mu.Unlock()
448 cfg := s.cfg
449 providerConfig, exists := cfg.Providers[providerID]
450 if !exists {
451 return fmt.Errorf("provider %s not found", providerID)
452 }
453
454 if providerConfig.OAuthToken == nil {
455 return fmt.Errorf("provider %s does not have an OAuth token", providerID)
456 }
457
458 var newToken *oauth.Token
459 var refreshErr error
460 switch providerID {
461 case string(catwalk.InferenceProviderCopilot):
462 newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
463 case hyperp.Name:
464 newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
465 default:
466 return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
467 }
468 if refreshErr != nil {
469 return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
470 }
471
472 slog.Info("Successfully refreshed OAuth token", "provider", providerID)
473 providerConfig.OAuthToken = newToken
474 providerConfig.APIKey = newToken.AccessToken
475
476 switch providerID {
477 case string(catwalk.InferenceProviderCopilot):
478 providerConfig.SetupGitHubCopilot()
479 }
480
481 cfg.Providers[providerID] = providerConfig
482
483 if err := cmp.Or(
484 s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
485 s.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), newToken),
486 ); err != nil {
487 return fmt.Errorf("failed to persist refreshed token: %w", err)
488 }
489
490 return nil
491}
492
493// SetProviderAPIKey sets the API key (string or *oauth.Token) for a
494// provider and persists the change.
495func (s *Service) SetProviderAPIKey(providerID string, apiKey any) error {
496 s.mu.Lock()
497 defer s.mu.Unlock()
498 cfg := s.cfg
499 var providerConfig ProviderConfig
500 var exists bool
501 var setKeyOrToken func()
502
503 switch v := apiKey.(type) {
504 case string:
505 if err := s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
506 return fmt.Errorf("failed to save api key to config file: %w", err)
507 }
508 setKeyOrToken = func() { providerConfig.APIKey = v }
509 case *oauth.Token:
510 if err := cmp.Or(
511 s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
512 s.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), v),
513 ); err != nil {
514 return err
515 }
516 setKeyOrToken = func() {
517 providerConfig.APIKey = v.AccessToken
518 providerConfig.OAuthToken = v
519 switch providerID {
520 case string(catwalk.InferenceProviderCopilot):
521 providerConfig.SetupGitHubCopilot()
522 }
523 }
524 }
525
526 providerConfig, exists = cfg.Providers[providerID]
527 if exists {
528 setKeyOrToken()
529 cfg.Providers[providerID] = providerConfig
530 return nil
531 }
532
533 var foundProvider *catwalk.Provider
534 for _, p := range s.knownProviders {
535 if string(p.ID) == providerID {
536 foundProvider = &p
537 break
538 }
539 }
540
541 if foundProvider != nil {
542 providerConfig = ProviderConfig{
543 ID: providerID,
544 Name: foundProvider.Name,
545 BaseURL: foundProvider.APIEndpoint,
546 Type: foundProvider.Type,
547 Disable: false,
548 ExtraHeaders: make(map[string]string),
549 ExtraParams: make(map[string]string),
550 Models: foundProvider.Models,
551 }
552 setKeyOrToken()
553 } else {
554 return fmt.Errorf("provider with ID %s not found in known providers", providerID)
555 }
556 cfg.Providers[providerID] = providerConfig
557 return nil
558}
559
560// ImportCopilot imports an existing GitHub Copilot token from disk if
561// available and not already configured.
562func (s *Service) ImportCopilot() (*oauth.Token, bool) {
563 s.mu.Lock()
564 defer s.mu.Unlock()
565 if testing.Testing() {
566 return nil, false
567 }
568
569 if s.HasConfigField("providers.copilot.api_key") || s.HasConfigField("providers.copilot.oauth") {
570 return nil, false
571 }
572
573 diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
574 if !hasDiskToken {
575 return nil, false
576 }
577
578 slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
579 token, err := copilot.RefreshToken(context.TODO(), diskToken)
580 if err != nil {
581 slog.Error("Unable to import GitHub Copilot token", "error", err)
582 return nil, false
583 }
584
585 if err := s.SetProviderAPIKey(string(catwalk.InferenceProviderCopilot), token); err != nil {
586 return token, false
587 }
588
589 if err := cmp.Or(
590 s.SetConfigField("providers.copilot.api_key", token.AccessToken),
591 s.SetConfigField("providers.copilot.oauth", token),
592 ); err != nil {
593 slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
594 }
595
596 slog.Info("GitHub Copilot successfully imported")
597 return token, true
598}