service.go

  1package config
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"sync"
 10	"testing"
 11
 12	"charm.land/catwalk/pkg/catwalk"
 13	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 14	"github.com/charmbracelet/crush/internal/oauth"
 15	"github.com/charmbracelet/crush/internal/oauth/copilot"
 16	"github.com/charmbracelet/crush/internal/oauth/hyper"
 17)
 18
// Service is the central access point for configuration. It wraps the
// raw Config data and owns all internal state that was previously held
// as unexported fields on Config (resolver, store, known providers,
// working directory).
//
// Most accessors take mu around reads and writes of cfg and agents. Some
// simple getters (several option getters, WorkingDir, Resolve/Resolver)
// and the *ConfigField methods read without it. Service holds a
// sync.RWMutex by value, so it must not be copied after first use.
type Service struct {
	mu             sync.RWMutex       // guards cfg and agents in the methods that lock it
	cfg            *Config            // wrapped configuration data
	store          Store              // persisted config data used by Has/Set/RemoveConfigField
	resolver       VariableResolver   // used by Resolve; Resolve treats nil as "not configured"
	workingDir     string             // returned as-is by WorkingDir
	knownProviders []catwalk.Provider // catalog SetProviderAPIKey uses for providers missing from cfg.Providers
	agents         map[string]Agent   // built (and replaced wholesale) by SetupAgents
}
 32
 33// WorkingDir returns the working directory.
 34func (s *Service) WorkingDir() string {
 35	return s.workingDir
 36}
 37
 38// EnabledProviders returns all non-disabled provider configs.
 39func (s *Service) EnabledProviders() []ProviderConfig {
 40	s.mu.RLock()
 41	defer s.mu.RUnlock()
 42	return s.cfg.EnabledProviders()
 43}
 44
 45// IsConfigured returns true if at least one provider is enabled.
 46func (s *Service) IsConfigured() bool {
 47	s.mu.RLock()
 48	defer s.mu.RUnlock()
 49	return s.cfg.IsConfigured()
 50}
 51
 52// GetModel returns the catwalk model for the given provider and model
 53// ID, or nil if not found.
 54func (s *Service) GetModel(provider, model string) *catwalk.Model {
 55	s.mu.RLock()
 56	defer s.mu.RUnlock()
 57	return s.cfg.GetModel(provider, model)
 58}
 59
 60// GetProviderForModel returns the provider config for the given model
 61// type, or nil.
 62func (s *Service) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
 63	s.mu.RLock()
 64	defer s.mu.RUnlock()
 65	return s.cfg.GetProviderForModel(modelType)
 66}
 67
 68// GetModelByType returns the catwalk model for the given model type,
 69// or nil.
 70func (s *Service) GetModelByType(modelType SelectedModelType) *catwalk.Model {
 71	s.mu.RLock()
 72	defer s.mu.RUnlock()
 73	return s.cfg.GetModelByType(modelType)
 74}
 75
 76// LargeModel returns the catwalk model for the large model type.
 77func (s *Service) LargeModel() *catwalk.Model {
 78	s.mu.RLock()
 79	defer s.mu.RUnlock()
 80	return s.cfg.LargeModel()
 81}
 82
 83// SmallModel returns the catwalk model for the small model type.
 84func (s *Service) SmallModel() *catwalk.Model {
 85	s.mu.RLock()
 86	defer s.mu.RUnlock()
 87	return s.cfg.SmallModel()
 88}
 89
 90// Resolve resolves a variable value using the configured resolver.
 91func (s *Service) Resolve(key string) (string, error) {
 92	if s.resolver == nil {
 93		return "", fmt.Errorf("no variable resolver configured")
 94	}
 95	return s.resolver.ResolveValue(key)
 96}
 97
 98// Resolver returns the variable resolver.
 99func (s *Service) Resolver() VariableResolver {
100	return s.resolver
101}
102
103// SetupAgents builds the agent configurations from the current config
104// options.
105func (s *Service) SetupAgents() {
106	s.mu.Lock()
107	defer s.mu.Unlock()
108	allowedTools := resolveAllowedTools(allToolNames(), s.cfg.Options.DisabledTools)
109
110	s.agents = map[string]Agent{
111		AgentCoder: {
112			ID:           AgentCoder,
113			Name:         "Coder",
114			Description:  "An agent that helps with executing coding tasks.",
115			Model:        SelectedModelTypeLarge,
116			ContextPaths: s.cfg.Options.ContextPaths,
117			AllowedTools: allowedTools,
118		},
119
120		AgentTask: {
121			ID:           AgentCoder,
122			Name:         "Task",
123			Description:  "An agent that helps with searching for context and finding implementation details.",
124			Model:        SelectedModelTypeLarge,
125			ContextPaths: s.cfg.Options.ContextPaths,
126			AllowedTools: resolveReadOnlyTools(allowedTools),
127			AllowedMCP:   map[string][]string{},
128		},
129	}
130}
131
132// Agents returns the agent configuration map.
133func (s *Service) Agents() map[string]Agent {
134	s.mu.RLock()
135	defer s.mu.RUnlock()
136	return s.agents
137}
138
139// Agent returns the agent configuration for the given name and
140// whether it exists.
141func (s *Service) Agent(name string) (Agent, bool) {
142	s.mu.RLock()
143	defer s.mu.RUnlock()
144	a, ok := s.agents[name]
145	return a, ok
146}
147
148// DataDirectory returns the data directory path.
149func (s *Service) DataDirectory() string {
150	return s.cfg.Options.DataDirectory
151}
152
153// Debug returns whether debug mode is enabled.
154func (s *Service) Debug() bool {
155	return s.cfg.Options.Debug
156}
157
158// DebugLSP returns whether LSP debug mode is enabled.
159func (s *Service) DebugLSP() bool {
160	return s.cfg.Options.DebugLSP
161}
162
163// DisableAutoSummarize returns whether auto-summarization is
164// disabled.
165func (s *Service) DisableAutoSummarize() bool {
166	return s.cfg.Options.DisableAutoSummarize
167}
168
169// Attribution returns the attribution settings.
170func (s *Service) Attribution() *Attribution {
171	s.mu.RLock()
172	defer s.mu.RUnlock()
173	return s.cfg.Options.Attribution
174}
175
176// ContextPaths returns the configured context paths.
177func (s *Service) ContextPaths() []string {
178	return s.cfg.Options.ContextPaths
179}
180
181// SkillsPaths returns the configured skills paths.
182func (s *Service) SkillsPaths() []string {
183	s.mu.RLock()
184	defer s.mu.RUnlock()
185	return s.cfg.Options.SkillsPaths
186}
187
188// Progress returns the progress setting pointer.
189func (s *Service) Progress() *bool {
190	return s.cfg.Options.Progress
191}
192
193// DisableMetrics returns whether metrics are disabled.
194func (s *Service) DisableMetrics() bool {
195	return s.cfg.Options.DisableMetrics
196}
197
198// SelectedModel returns the selected model for the given type and
199// whether it exists.
200func (s *Service) SelectedModel(modelType SelectedModelType) (SelectedModel, bool) {
201	s.mu.RLock()
202	defer s.mu.RUnlock()
203	m, ok := s.cfg.Models[modelType]
204	return m, ok
205}
206
207// Provider returns the provider config for the given ID and whether
208// it exists.
209func (s *Service) Provider(id string) (ProviderConfig, bool) {
210	s.mu.RLock()
211	defer s.mu.RUnlock()
212	p, ok := s.cfg.Providers[id]
213	return p, ok
214}
215
216// SetProvider sets the provider config for the given ID.
217func (s *Service) SetProvider(id string, p ProviderConfig) {
218	s.mu.Lock()
219	defer s.mu.Unlock()
220	s.cfg.Providers[id] = p
221}
222
223// Providers returns all provider configs.
224func (s *Service) AllProviders() map[string]ProviderConfig {
225	s.mu.RLock()
226	defer s.mu.RUnlock()
227	return s.cfg.Providers
228}
229
230// MCP returns the MCP configurations.
231func (s *Service) MCP() MCPs {
232	return s.cfg.MCP
233}
234
235// LSP returns the LSP configurations.
236func (s *Service) LSP() LSPs {
237	s.mu.RLock()
238	defer s.mu.RUnlock()
239	return s.cfg.LSP
240}
241
242// Permissions returns the permissions configuration.
243func (s *Service) Permissions() *Permissions {
244	s.mu.RLock()
245	defer s.mu.RUnlock()
246	return s.cfg.Permissions
247}
248
249// SetAttribution sets the attribution settings.
250func (s *Service) SetAttribution(a *Attribution) {
251	s.mu.Lock()
252	defer s.mu.Unlock()
253	s.cfg.Options.Attribution = a
254}
255
256// SetSkillsPaths sets the skills paths.
257func (s *Service) SetSkillsPaths(paths []string) {
258	s.mu.Lock()
259	defer s.mu.Unlock()
260	s.cfg.Options.SkillsPaths = paths
261}
262
263// SetLSP sets the LSP configurations.
264func (s *Service) SetLSP(lsp LSPs) {
265	s.mu.Lock()
266	defer s.mu.Unlock()
267	s.cfg.LSP = lsp
268}
269
270// SetPermissions sets the permissions configuration.
271func (s *Service) SetPermissions(p *Permissions) {
272	s.mu.Lock()
273	defer s.mu.Unlock()
274	s.cfg.Permissions = p
275}
276
277// OverrideModel overrides the in-memory model for the given type
278// without persisting. Used for non-interactive model overrides.
279func (s *Service) OverrideModel(modelType SelectedModelType, model SelectedModel) {
280	s.mu.Lock()
281	defer s.mu.Unlock()
282	s.cfg.Models[modelType] = model
283}
284
285// ToolLsConfig returns the ls tool configuration.
286func (s *Service) ToolLsConfig() ToolLs {
287	return s.cfg.Tools.Ls
288}
289
290// CompactMode returns whether compact mode is enabled.
291func (s *Service) CompactMode() bool {
292	s.mu.RLock()
293	defer s.mu.RUnlock()
294	if s.cfg.Options.TUI == nil {
295		return false
296	}
297	return s.cfg.Options.TUI.CompactMode
298}
299
300// DiffMode returns the diff mode setting.
301func (s *Service) DiffMode() string {
302	if s.cfg.Options.TUI == nil {
303		return ""
304	}
305	return s.cfg.Options.TUI.DiffMode
306}
307
308// CompletionLimits returns the completion depth and items limits.
309func (s *Service) CompletionLimits() (depth, items int) {
310	if s.cfg.Options.TUI == nil {
311		return 0, 0
312	}
313	return s.cfg.Options.TUI.Completions.Limits()
314}
315
316// DisableDefaultProviders returns whether default providers are
317// disabled.
318func (s *Service) DisableDefaultProviders() bool {
319	return s.cfg.Options.DisableDefaultProviders
320}
321
322// DisableProviderAutoUpdate returns whether provider auto-update is
323// disabled.
324func (s *Service) DisableProviderAutoUpdate() bool {
325	return s.cfg.Options.DisableProviderAutoUpdate
326}
327
328// InitializeAs returns the initialization file name.
329func (s *Service) InitializeAs() string {
330	return s.cfg.Options.InitializeAs
331}
332
333// AutoLSP returns the auto-LSP setting pointer.
334func (s *Service) AutoLSP() *bool {
335	return s.cfg.Options.AutoLSP
336}
337
338// RecentModels returns recent models for the given type.
339func (s *Service) RecentModels(modelType SelectedModelType) []SelectedModel {
340	s.mu.RLock()
341	defer s.mu.RUnlock()
342	return s.cfg.RecentModels[modelType]
343}
344
345// Options returns the full options struct. This is a temporary
346// accessor for callers that need multiple option fields.
347func (s *Service) Options() *Options {
348	return s.cfg.Options
349}
350
351// HasConfigField returns true if the given dotted key path exists in
352// the persisted config data.
353func (s *Service) HasConfigField(key string) bool {
354	return HasField(s.store, key)
355}
356
357// SetConfigField sets a value at the given dotted key path and
358// persists it.
359func (s *Service) SetConfigField(key string, value any) error {
360	return SetField(s.store, key, value)
361}
362
363// RemoveConfigField deletes a value at the given dotted key path and
364// persists it.
365func (s *Service) RemoveConfigField(key string) error {
366	return RemoveField(s.store, key)
367}
368
369// SetCompactMode toggles compact mode and persists the change.
370func (s *Service) SetCompactMode(enabled bool) error {
371	s.mu.Lock()
372	defer s.mu.Unlock()
373	cfg := s.cfg
374	if cfg.Options == nil {
375		cfg.Options = &Options{}
376	}
377	if cfg.Options.TUI == nil {
378		cfg.Options.TUI = &TUIOptions{}
379	}
380	cfg.Options.TUI.CompactMode = enabled
381	return s.SetConfigField("options.tui.compact_mode", enabled)
382}
383
384// UpdatePreferredModel updates the selected model for the given type
385// and persists the change, also recording it in the recent models
386// list.
387func (s *Service) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
388	s.mu.Lock()
389	defer s.mu.Unlock()
390	s.cfg.Models[modelType] = model
391	if err := s.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
392		return fmt.Errorf("failed to update preferred model: %w", err)
393	}
394	if err := s.recordRecentModel(modelType, model); err != nil {
395		return err
396	}
397	return nil
398}
399
400const maxRecentModelsPerType = 5
401
402func (s *Service) recordRecentModel(modelType SelectedModelType, model SelectedModel) error {
403	if model.Provider == "" || model.Model == "" {
404		return nil
405	}
406
407	cfg := s.cfg
408	if cfg.RecentModels == nil {
409		cfg.RecentModels = make(map[SelectedModelType][]SelectedModel)
410	}
411
412	eq := func(a, b SelectedModel) bool {
413		return a.Provider == b.Provider && a.Model == b.Model
414	}
415
416	entry := SelectedModel{
417		Provider: model.Provider,
418		Model:    model.Model,
419	}
420
421	current := cfg.RecentModels[modelType]
422	withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
423		return eq(existing, entry)
424	})
425
426	updated := append([]SelectedModel{entry}, withoutCurrent...)
427	if len(updated) > maxRecentModelsPerType {
428		updated = updated[:maxRecentModelsPerType]
429	}
430
431	if slices.EqualFunc(current, updated, eq) {
432		return nil
433	}
434
435	cfg.RecentModels[modelType] = updated
436
437	if err := s.SetConfigField(fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
438		return fmt.Errorf("failed to persist recent models: %w", err)
439	}
440
441	return nil
442}
443
444// RefreshOAuthToken refreshes the OAuth token for the given provider.
445func (s *Service) RefreshOAuthToken(ctx context.Context, providerID string) error {
446	s.mu.Lock()
447	defer s.mu.Unlock()
448	cfg := s.cfg
449	providerConfig, exists := cfg.Providers[providerID]
450	if !exists {
451		return fmt.Errorf("provider %s not found", providerID)
452	}
453
454	if providerConfig.OAuthToken == nil {
455		return fmt.Errorf("provider %s does not have an OAuth token", providerID)
456	}
457
458	var newToken *oauth.Token
459	var refreshErr error
460	switch providerID {
461	case string(catwalk.InferenceProviderCopilot):
462		newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
463	case hyperp.Name:
464		newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
465	default:
466		return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
467	}
468	if refreshErr != nil {
469		return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
470	}
471
472	slog.Info("Successfully refreshed OAuth token", "provider", providerID)
473	providerConfig.OAuthToken = newToken
474	providerConfig.APIKey = newToken.AccessToken
475
476	switch providerID {
477	case string(catwalk.InferenceProviderCopilot):
478		providerConfig.SetupGitHubCopilot()
479	}
480
481	cfg.Providers[providerID] = providerConfig
482
483	if err := cmp.Or(
484		s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
485		s.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), newToken),
486	); err != nil {
487		return fmt.Errorf("failed to persist refreshed token: %w", err)
488	}
489
490	return nil
491}
492
493// SetProviderAPIKey sets the API key (string or *oauth.Token) for a
494// provider and persists the change.
495func (s *Service) SetProviderAPIKey(providerID string, apiKey any) error {
496	s.mu.Lock()
497	defer s.mu.Unlock()
498	cfg := s.cfg
499	var providerConfig ProviderConfig
500	var exists bool
501	var setKeyOrToken func()
502
503	switch v := apiKey.(type) {
504	case string:
505		if err := s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
506			return fmt.Errorf("failed to save api key to config file: %w", err)
507		}
508		setKeyOrToken = func() { providerConfig.APIKey = v }
509	case *oauth.Token:
510		if err := cmp.Or(
511			s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
512			s.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), v),
513		); err != nil {
514			return err
515		}
516		setKeyOrToken = func() {
517			providerConfig.APIKey = v.AccessToken
518			providerConfig.OAuthToken = v
519			switch providerID {
520			case string(catwalk.InferenceProviderCopilot):
521				providerConfig.SetupGitHubCopilot()
522			}
523		}
524	}
525
526	providerConfig, exists = cfg.Providers[providerID]
527	if exists {
528		setKeyOrToken()
529		cfg.Providers[providerID] = providerConfig
530		return nil
531	}
532
533	var foundProvider *catwalk.Provider
534	for _, p := range s.knownProviders {
535		if string(p.ID) == providerID {
536			foundProvider = &p
537			break
538		}
539	}
540
541	if foundProvider != nil {
542		providerConfig = ProviderConfig{
543			ID:           providerID,
544			Name:         foundProvider.Name,
545			BaseURL:      foundProvider.APIEndpoint,
546			Type:         foundProvider.Type,
547			Disable:      false,
548			ExtraHeaders: make(map[string]string),
549			ExtraParams:  make(map[string]string),
550			Models:       foundProvider.Models,
551		}
552		setKeyOrToken()
553	} else {
554		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
555	}
556	cfg.Providers[providerID] = providerConfig
557	return nil
558}
559
560// ImportCopilot imports an existing GitHub Copilot token from disk if
561// available and not already configured.
562func (s *Service) ImportCopilot() (*oauth.Token, bool) {
563	s.mu.Lock()
564	defer s.mu.Unlock()
565	if testing.Testing() {
566		return nil, false
567	}
568
569	if s.HasConfigField("providers.copilot.api_key") || s.HasConfigField("providers.copilot.oauth") {
570		return nil, false
571	}
572
573	diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
574	if !hasDiskToken {
575		return nil, false
576	}
577
578	slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
579	token, err := copilot.RefreshToken(context.TODO(), diskToken)
580	if err != nil {
581		slog.Error("Unable to import GitHub Copilot token", "error", err)
582		return nil, false
583	}
584
585	if err := s.SetProviderAPIKey(string(catwalk.InferenceProviderCopilot), token); err != nil {
586		return token, false
587	}
588
589	if err := cmp.Or(
590		s.SetConfigField("providers.copilot.api_key", token.AccessToken),
591		s.SetConfigField("providers.copilot.oauth", token),
592	); err != nil {
593		slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
594	}
595
596	slog.Info("GitHub Copilot successfully imported")
597	return token, true
598}