From 678c8f78487777e5fa50c7e401112fda79d8b29b Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 6 Feb 2026 12:39:02 +0100 Subject: [PATCH] refactor: move mutation methods from Config to Service MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move SetCompactMode, UpdatePreferredModel, SetProviderAPIKey, RefreshOAuthToken, ImportCopilot, SetConfigField, RemoveConfigField, HasConfigField, and recordRecentModel to Service. Add ConfigService() accessor to App and Common for callers that need mutations. Config retains only unexported setConfigField/removeConfigField for the internal load flow. 🐾 Generated with Crush Assisted-by: Claude Opus 4.6 via Crush --- internal/agent/coordinator.go | 7 +- internal/app/app.go | 16 +- internal/cmd/login.go | 8 +- internal/cmd/root.go | 2 +- internal/config/config.go | 191 +------------------ internal/config/copilot.go | 48 ----- internal/config/load.go | 12 +- internal/config/recent_models_test.go | 97 +++++----- internal/config/service.go | 254 +++++++++++++++++++++++++- internal/ui/common/common.go | 5 + internal/ui/dialog/api_key_input.go | 4 +- internal/ui/dialog/models.go | 2 +- internal/ui/dialog/oauth.go | 4 +- internal/ui/model/ui.go | 12 +- 14 files changed, 348 insertions(+), 314 deletions(-) delete mode 100644 internal/config/copilot.go diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index ad57b20c4470aa0180120798034db1bdb1de601a..97963fe9a728d0febb8da3ad0439fafba86babbe 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -62,6 +62,7 @@ type Coordinator interface { type coordinator struct { cfg *config.Config + cfgSvc *config.Service sessions session.Service messages message.Service permissions permission.Service @@ -77,7 +78,7 @@ type coordinator struct { func NewCoordinator( ctx context.Context, - cfg *config.Config, + cfgSvc *config.Service, sessions session.Service, messages message.Service, permissions permission.Service, @@ -85,8 +86,10 @@ func NewCoordinator( filetracker filetracker.Service, lspClients *csync.Map[string, *lsp.Client], ) (Coordinator, error) { + cfg := cfgSvc.Config() c := &coordinator{ cfg: cfg, + cfgSvc: cfgSvc, sessions: sessions, messages: messages, permissions: permissions, @@ -891,7 +894,7 @@ func (c *coordinator) isUnauthorized(err error) bool { } func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error { - if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil { + if err := c.cfgSvc.RefreshOAuthToken(ctx, providerCfg.ID); err != nil { slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err) return err } diff --git a/internal/app/app.go b/internal/app/app.go index f0cabfa534a58401280fb5e9b973aa6f5a9d91c9..399e8a1abaf1ff59d8136e19daa64a50bfc1b81a 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -60,7 +60,8 @@ type App struct { LSPClients *csync.Map[string, *lsp.Client] - config *config.Config + configService *config.Service + config *config.Config serviceEventsWG *sync.WaitGroup eventsCtx context.Context @@ -73,7 +74,8 @@ type App struct { } // New initializes a new application instance. -func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { +func New(ctx context.Context, conn *sql.DB, cfgSvc *config.Service) (*App, error) { + cfg := cfgSvc.Config() q := db.New(conn) sessions := session.NewService(q, conn) messages := message.NewService(q) @@ -94,7 +96,8 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { globalCtx: ctx, - config: cfg, + configService: cfgSvc, + config: cfg, events: make(chan tea.Msg, 100), serviceEventsWG: &sync.WaitGroup{}, @@ -125,6 +128,11 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { return app, nil } +// ConfigService returns the config service. +func (app *App) ConfigService() *config.Service { + return app.configService +} + // Config returns the application configuration. func (app *App) Config() *config.Config { return app.config @@ -462,7 +470,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error { var err error app.AgentCoordinator, err = agent.NewCoordinator( ctx, - app.config, + app.configService, app.Sessions, app.Messages, app.Permissions, diff --git a/internal/cmd/login.go b/internal/cmd/login.go index bdad4547d6f583b5ae7e5a97bbbbd88a1421e6ee..9da3e5775ddd9888f54231180b68667c8d8bbce4 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -52,16 +52,16 @@ crush login copilot } switch provider { case "hyper": - return loginHyper(app.Config()) + return loginHyper(app.ConfigService()) case "copilot", "github", "github-copilot": - return loginCopilot(app.Config()) + return loginCopilot(app.ConfigService()) default: return fmt.Errorf("unknown platform: %s", args[0]) } }, } -func loginHyper(cfg *config.Config) error { +func loginHyper(cfg *config.Service) error { if !hyperp.Enabled() { return fmt.Errorf("hyper not enabled") } @@ -123,7 +123,7 @@ func loginHyper(cfg *config.Config) error { return nil } -func loginCopilot(cfg *config.Config) error { +func loginCopilot(cfg *config.Service) error { ctx := getLoginContext() if cfg.HasConfigField("providers.copilot.oauth") { diff --git a/internal/cmd/root.go b/internal/cmd/root.go index f88f8ed6fce3cad33080a6585399cd08ab93e27f..02c05db511f6a28b4b7b907c2fbc1edcabede549 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -222,7 +222,7 @@ func setupApp(cmd *cobra.Command) (*app.App, error) { return nil, err } - appInstance, err := app.New(ctx, conn, cfg) + appInstance, err := app.New(ctx, conn, svc) if err != nil { slog.Error("Failed to create app instance", "error", err) return nil, err diff --git a/internal/config/config.go b/internal/config/config.go index 860143c9f805952409a1cb8572c5ac1c629d81a2..0dc07196f08ffffeffd96c532a7bded73409a35f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,6 @@ package config import ( - "cmp" "context" "fmt" "log/slog" @@ -13,12 +12,10 @@ import ( "time" "charm.land/catwalk/pkg/catwalk" - hyperp "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/oauth/copilot" - "github.com/charmbracelet/crush/internal/oauth/hyper" "github.com/invopop/jsonschema" ) @@ -457,14 +454,6 @@ func (c *Config) SmallModel() *catwalk.Model { return c.GetModel(model.Provider, model.Model) } -func (c *Config) SetCompactMode(enabled bool) error { - if c.Options == nil { - c.Options = &Options{} - } - c.Options.TUI.CompactMode = enabled - return c.SetConfigField("options.tui.compact_mode", enabled) -} - func (c *Config) Resolve(key string) (string, error) { if c.resolver == nil { return "", fmt.Errorf("no variable resolver configured") @@ -472,187 +461,19 @@ func (c *Config) Resolve(key string) (string, error) { return c.resolver.ResolveValue(key) } -func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error { - c.Models[modelType] = model - if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil { - return fmt.Errorf("failed to update preferred model: %w", err) - } - if err := c.recordRecentModel(modelType, model); err != nil { - return err - } - return nil -} - -func (c *Config) configStore() Store { - if c.store == nil { - c.store = NewFileStore(c.dataConfigDir) - } - return c.store -} - -func (c *Config) HasConfigField(key string) bool { - return HasField(c.configStore(), key) -} - -func (c *Config) SetConfigField(key string, value any) error { +func (c *Config) setConfigField(key string, value any) error { return SetField(c.configStore(), key, value) } -func (c *Config) RemoveConfigField(key string) error { +func (c *Config) removeConfigField(key string) error { return RemoveField(c.configStore(), key) } -// RefreshOAuthToken refreshes the OAuth token for the given provider. -func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error { - providerConfig, exists := c.Providers.Get(providerID) - if !exists { - return fmt.Errorf("provider %s not found", providerID) - } - - if providerConfig.OAuthToken == nil { - return fmt.Errorf("provider %s does not have an OAuth token", providerID) - } - - var newToken *oauth.Token - var refreshErr error - switch providerID { - case string(catwalk.InferenceProviderCopilot): - newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) - case hyperp.Name: - newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken) - default: - return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) - } - if refreshErr != nil { - return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr) - } - - slog.Info("Successfully refreshed OAuth token", "provider", providerID) - providerConfig.OAuthToken = newToken - providerConfig.APIKey = newToken.AccessToken - - switch providerID { - case string(catwalk.InferenceProviderCopilot): - providerConfig.SetupGitHubCopilot() - } - - c.Providers.Set(providerID, providerConfig) - - if err := cmp.Or( - c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken), - c.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), newToken), - ); err != nil { - return fmt.Errorf("failed to persist refreshed token: %w", err) - } - - return nil -} - -func (c *Config) SetProviderAPIKey(providerID string, apiKey any) error { - var providerConfig ProviderConfig - var exists bool - var setKeyOrToken func() - - switch v := apiKey.(type) { - case string: - if err := c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil { - return fmt.Errorf("failed to save api key to config file: %w", err) - } - setKeyOrToken = func() { providerConfig.APIKey = v } - case *oauth.Token: - if err := cmp.Or( - c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken), - c.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), v), - ); err != nil { - return err - } - setKeyOrToken = func() { - providerConfig.APIKey = v.AccessToken - providerConfig.OAuthToken = v - switch providerID { - case string(catwalk.InferenceProviderCopilot): - providerConfig.SetupGitHubCopilot() - } - } - } - - providerConfig, exists = c.Providers.Get(providerID) - if exists { - setKeyOrToken() - c.Providers.Set(providerID, providerConfig) - return nil - } - - var foundProvider *catwalk.Provider - for _, p := range c.knownProviders { - if string(p.ID) == providerID { - foundProvider = &p - break - } - } - - if foundProvider != nil { - // Create new provider config based on known provider - providerConfig = ProviderConfig{ - ID: providerID, - Name: foundProvider.Name, - BaseURL: foundProvider.APIEndpoint, - Type: foundProvider.Type, - Disable: false, - ExtraHeaders: make(map[string]string), - ExtraParams: make(map[string]string), - Models: foundProvider.Models, - } - setKeyOrToken() - } else { - return fmt.Errorf("provider with ID %s not found in known providers", providerID) - } - // Store the updated provider config - c.Providers.Set(providerID, providerConfig) - return nil -} - -const maxRecentModelsPerType = 5 - -func (c *Config) recordRecentModel(modelType SelectedModelType, model SelectedModel) error { - if model.Provider == "" || model.Model == "" { - return nil - } - - if c.RecentModels == nil { - c.RecentModels = make(map[SelectedModelType][]SelectedModel) - } - - eq := func(a, b SelectedModel) bool { - return a.Provider == b.Provider && a.Model == b.Model - } - - entry := SelectedModel{ - Provider: model.Provider, - Model: model.Model, - } - - current := c.RecentModels[modelType] - withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool { - return eq(existing, entry) - }) - - updated := append([]SelectedModel{entry}, withoutCurrent...) - if len(updated) > maxRecentModelsPerType { - updated = updated[:maxRecentModelsPerType] - } - - if slices.EqualFunc(current, updated, eq) { - return nil - } - - c.RecentModels[modelType] = updated - - if err := c.SetConfigField(fmt.Sprintf("recent_models.%s", modelType), updated); err != nil { - return fmt.Errorf("failed to persist recent models: %w", err) +func (c *Config) configStore() Store { + if c.store == nil { + c.store = NewFileStore(c.dataConfigDir) } - - return nil + return c.store } func allToolNames() []string { diff --git a/internal/config/copilot.go b/internal/config/copilot.go deleted file mode 100644 index d72e7d5048ba4d31c88d7f7152a6b3a9510960a2..0000000000000000000000000000000000000000 --- a/internal/config/copilot.go +++ /dev/null @@ -1,48 +0,0 @@ -package config - -import ( - "cmp" - "context" - "log/slog" - "testing" - - "charm.land/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/oauth" - "github.com/charmbracelet/crush/internal/oauth/copilot" -) - -func (c *Config) ImportCopilot() (*oauth.Token, bool) { - if testing.Testing() { - return nil, false - } - - if c.HasConfigField("providers.copilot.api_key") || c.HasConfigField("providers.copilot.oauth") { - return nil, false - } - - diskToken, hasDiskToken := copilot.RefreshTokenFromDisk() - if !hasDiskToken { - return nil, false - } - - slog.Info("Found existing GitHub Copilot token on disk. Authenticating...") - token, err := copilot.RefreshToken(context.TODO(), diskToken) - if err != nil { - slog.Error("Unable to import GitHub Copilot token", "error", err) - return nil, false - } - - if err := c.SetProviderAPIKey(string(catwalk.InferenceProviderCopilot), token); err != nil { - return token, false - } - - if err := cmp.Or( - c.SetConfigField("providers.copilot.api_key", token.AccessToken), - c.SetConfigField("providers.copilot.oauth", token), - ); err != nil { - slog.Error("Unable to save GitHub Copilot token to disk", "error", err) - } - - slog.Info("GitHub Copilot successfully imported") - return token, true -} diff --git a/internal/config/load.go b/internal/config/load.go index 02095b5839822e05eef31ff1d9917418a66898a8..84530dbff179689102ee5e4e1863714674485e63 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -220,7 +220,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know switch { case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil: // Claude Code subscription is not supported anymore. Remove to show onboarding. - c.RemoveConfigField("providers.anthropic") + c.removeConfigField("providers.anthropic") c.Providers.Del(string(p.ID)) continue case p.ID == catwalk.InferenceProviderCopilot && config.OAuthToken != nil: @@ -558,9 +558,8 @@ func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) erro model := c.GetModel(large.Provider, large.Model) if model == nil { large = defaultLarge - // override the model type to large - err := c.UpdatePreferredModel(SelectedModelTypeLarge, large) - if err != nil { + c.Models[SelectedModelTypeLarge] = large + if err := c.setConfigField(fmt.Sprintf("models.%s", SelectedModelTypeLarge), large); err != nil { return fmt.Errorf("failed to update preferred large model: %w", err) } } else { @@ -602,9 +601,8 @@ func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) erro model := c.GetModel(small.Provider, small.Model) if model == nil { small = defaultSmall - // override the model type to small - err := c.UpdatePreferredModel(SelectedModelTypeSmall, small) - if err != nil { + c.Models[SelectedModelTypeSmall] = small + if err := c.setConfigField(fmt.Sprintf("models.%s", SelectedModelTypeSmall), small); err != nil { return fmt.Errorf("failed to update preferred small model: %w", err) } } else { diff --git a/internal/config/recent_models_test.go b/internal/config/recent_models_test.go index 739ddc0031a65cab261723772c3f38658dcd1561..ea7ded48601a2696f895bf0d2f01936f9897e7b2 100644 --- a/internal/config/recent_models_test.go +++ b/internal/config/recent_models_test.go @@ -31,15 +31,27 @@ func readRecentModels(t *testing.T, path string) map[string]any { return rm } -func TestRecordRecentModel_AddsAndPersists(t *testing.T) { - t.Parallel() - +func newTestService(t *testing.T) (*Service, string) { + t.Helper() dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + storePath := filepath.Join(dir, "config.json") + svc := &Service{ + cfg: cfg, + store: NewFileStore(storePath), + workingDir: dir, + } + return svc, storePath +} + +func TestRecordRecentModel_AddsAndPersists(t *testing.T) { + t.Parallel() - err := cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}) + svc, storePath := newTestService(t) + cfg := svc.cfg + + err := svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}) require.NoError(t, err) // in-memory state @@ -48,7 +60,7 @@ func TestRecordRecentModel_AddsAndPersists(t *testing.T) { require.Equal(t, "gpt-4o", cfg.RecentModels[SelectedModelTypeLarge][0].Model) // persisted state - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, storePath) large, ok := rm[string(SelectedModelTypeLarge)].([]any) require.True(t, ok) require.Len(t, large, 1) @@ -61,16 +73,14 @@ func TestRecordRecentModel_AddsAndPersists(t *testing.T) { func TestRecordRecentModel_DedupeAndMoveToFront(t *testing.T) { t.Parallel() - dir := t.TempDir() - cfg := &Config{} - cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + svc, _ := newTestService(t) + cfg := svc.cfg // Add two entries - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "anthropic", Model: "claude"})) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "anthropic", Model: "claude"})) // Re-add first; should move to front and not duplicate - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) got := cfg.RecentModels[SelectedModelTypeLarge] require.Len(t, got, 2) @@ -81,10 +91,8 @@ func TestRecordRecentModel_DedupeAndMoveToFront(t *testing.T) { func TestRecordRecentModel_TrimsToMax(t *testing.T) { t.Parallel() - dir := t.TempDir() - cfg := &Config{} - cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + svc, storePath := newTestService(t) + cfg := svc.cfg // Insert 6 unique models; max is 5 entries := []SelectedModel{ @@ -96,7 +104,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) { {Provider: "p6", Model: "m6"}, } for _, e := range entries { - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, e)) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, e)) } // in-memory state @@ -110,7 +118,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) { require.Equal(t, SelectedModel{Provider: "p2", Model: "m2"}, got[4]) // persisted state: verify trimmed to 5 and newest-first order - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, storePath) large, ok := rm[string(SelectedModelTypeLarge)].([]any) require.True(t, ok) require.Len(t, large, 5) @@ -126,15 +134,13 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) { func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) { t.Parallel() - dir := t.TempDir() - cfg := &Config{} - cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + svc, storePath := newTestService(t) + cfg := svc.cfg // Missing provider - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "", Model: "m"})) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "", Model: "m"})) // Missing model - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "p", Model: ""})) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "p", Model: ""})) _, ok := cfg.RecentModels[SelectedModelTypeLarge] // Map may be initialized, but should have no entries @@ -142,8 +148,8 @@ func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) { require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 0) } // No file should be written (stat via fs.FS) - baseDir := filepath.Dir(cfg.dataConfigDir) - fileName := filepath.Base(cfg.dataConfigDir) + baseDir := filepath.Dir(storePath) + fileName := filepath.Base(storePath) _, err := fs.Stat(os.DirFS(baseDir), fileName) require.True(t, os.IsNotExist(err)) } @@ -151,16 +157,13 @@ func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) { func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) { t.Parallel() - dir := t.TempDir() - cfg := &Config{} - cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + svc, storePath := newTestService(t) entry := SelectedModel{Provider: "openai", Model: "gpt-4o"} - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, entry)) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, entry)) - baseDir := filepath.Dir(cfg.dataConfigDir) - fileName := filepath.Base(cfg.dataConfigDir) + baseDir := filepath.Dir(storePath) + fileName := filepath.Base(storePath) before, err := fs.ReadFile(os.DirFS(baseDir), fileName) require.NoError(t, err) @@ -170,7 +173,7 @@ func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) { beforeMod := stBefore.ModTime() // Re-record same entry should be a no-op (no write) - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, entry)) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, entry)) after, err := fs.ReadFile(os.DirFS(baseDir), fileName) require.NoError(t, err) @@ -185,20 +188,18 @@ func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) { func TestUpdatePreferredModel_UpdatesRecents(t *testing.T) { t.Parallel() - dir := t.TempDir() - cfg := &Config{} - cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + svc, storePath := newTestService(t) + cfg := svc.cfg sel := SelectedModel{Provider: "openai", Model: "gpt-4o"} - require.NoError(t, cfg.UpdatePreferredModel(SelectedModelTypeSmall, sel)) + require.NoError(t, svc.UpdatePreferredModel(SelectedModelTypeSmall, sel)) // in-memory require.Equal(t, sel, cfg.Models[SelectedModelTypeSmall]) require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1) // persisted (read via fs.FS) - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, storePath) small, ok := rm[string(SelectedModelTypeSmall)].([]any) require.True(t, ok) require.Len(t, small, 1) @@ -207,17 +208,15 @@ func TestUpdatePreferredModel_UpdatesRecents(t *testing.T) { func TestRecordRecentModel_TypeIsolation(t *testing.T) { t.Parallel() - dir := t.TempDir() - cfg := &Config{} - cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + svc, storePath := newTestService(t) + cfg := svc.cfg // Add models to both large and small types largeModel := SelectedModel{Provider: "openai", Model: "gpt-4o"} smallModel := SelectedModel{Provider: "anthropic", Model: "claude"} - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, largeModel)) - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeSmall, smallModel)) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, largeModel)) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeSmall, smallModel)) // in-memory: verify types maintain separate histories require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 1) @@ -227,14 +226,14 @@ func TestRecordRecentModel_TypeIsolation(t *testing.T) { // Add another to large, verify small unchanged anotherLarge := SelectedModel{Provider: "google", Model: "gemini"} - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, anotherLarge)) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, anotherLarge)) require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 2) require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1) require.Equal(t, smallModel, cfg.RecentModels[SelectedModelTypeSmall][0]) // persisted state: verify both types exist with correct lengths and contents - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, storePath) large, ok := rm[string(SelectedModelTypeLarge)].([]any) require.True(t, ok) diff --git a/internal/config/service.go b/internal/config/service.go index 5fd98bcb91b7586f44773fdde8b2da498db7f86b..2d595b5d919315d0bc2904566cf73d8717cb5bce 100644 --- a/internal/config/service.go +++ b/internal/config/service.go @@ -1,6 +1,19 @@ package config -import "charm.land/catwalk/pkg/catwalk" +import ( + "cmp" + "context" + "fmt" + "log/slog" + "slices" + "testing" + + "charm.land/catwalk/pkg/catwalk" + hyperp "github.com/charmbracelet/crush/internal/agent/hyper" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/oauth/copilot" + "github.com/charmbracelet/crush/internal/oauth/hyper" +) // Service is the central access point for configuration. It wraps the // raw Config data and owns all internal state that was previously held @@ -20,3 +33,242 @@ type Service struct { func (s *Service) Config() *Config { return s.cfg } + +// HasConfigField returns true if the given dotted key path exists in +// the persisted config data. +func (s *Service) HasConfigField(key string) bool { + return HasField(s.store, key) +} + +// SetConfigField sets a value at the given dotted key path and +// persists it. +func (s *Service) SetConfigField(key string, value any) error { + return SetField(s.store, key, value) +} + +// RemoveConfigField deletes a value at the given dotted key path and +// persists it. +func (s *Service) RemoveConfigField(key string) error { + return RemoveField(s.store, key) +} + +// SetCompactMode toggles compact mode and persists the change. +func (s *Service) SetCompactMode(enabled bool) error { + cfg := s.cfg + if cfg.Options == nil { + cfg.Options = &Options{} + } + if cfg.Options.TUI == nil { + cfg.Options.TUI = &TUIOptions{} + } + cfg.Options.TUI.CompactMode = enabled + return s.SetConfigField("options.tui.compact_mode", enabled) +} + +// UpdatePreferredModel updates the selected model for the given type +// and persists the change, also recording it in the recent models +// list. +func (s *Service) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error { + s.cfg.Models[modelType] = model + if err := s.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil { + return fmt.Errorf("failed to update preferred model: %w", err) + } + if err := s.recordRecentModel(modelType, model); err != nil { + return err + } + return nil +} + +const maxRecentModelsPerType = 5 + +func (s *Service) recordRecentModel(modelType SelectedModelType, model SelectedModel) error { + if model.Provider == "" || model.Model == "" { + return nil + } + + cfg := s.cfg + if cfg.RecentModels == nil { + cfg.RecentModels = make(map[SelectedModelType][]SelectedModel) + } + + eq := func(a, b SelectedModel) bool { + return a.Provider == b.Provider && a.Model == b.Model + } + + entry := SelectedModel{ + Provider: model.Provider, + Model: model.Model, + } + + current := cfg.RecentModels[modelType] + withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool { + return eq(existing, entry) + }) + + updated := append([]SelectedModel{entry}, withoutCurrent...) + if len(updated) > maxRecentModelsPerType { + updated = updated[:maxRecentModelsPerType] + } + + if slices.EqualFunc(current, updated, eq) { + return nil + } + + cfg.RecentModels[modelType] = updated + + if err := s.SetConfigField(fmt.Sprintf("recent_models.%s", modelType), updated); err != nil { + return fmt.Errorf("failed to persist recent models: %w", err) + } + + return nil +} + +// RefreshOAuthToken refreshes the OAuth token for the given provider. +func (s *Service) RefreshOAuthToken(ctx context.Context, providerID string) error { + cfg := s.cfg + providerConfig, exists := cfg.Providers.Get(providerID) + if !exists { + return fmt.Errorf("provider %s not found", providerID) + } + + if providerConfig.OAuthToken == nil { + return fmt.Errorf("provider %s does not have an OAuth token", providerID) + } + + var newToken *oauth.Token + var refreshErr error + switch providerID { + case string(catwalk.InferenceProviderCopilot): + newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) + case hyperp.Name: + newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken) + default: + return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) + } + if refreshErr != nil { + return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr) + } + + slog.Info("Successfully refreshed OAuth token", "provider", providerID) + providerConfig.OAuthToken = newToken + providerConfig.APIKey = newToken.AccessToken + + switch providerID { + case string(catwalk.InferenceProviderCopilot): + providerConfig.SetupGitHubCopilot() + } + + cfg.Providers.Set(providerID, providerConfig) + + if err := cmp.Or( + s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken), + s.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), newToken), + ); err != nil { + return fmt.Errorf("failed to persist refreshed token: %w", err) + } + + return nil +} + +// SetProviderAPIKey sets the API key (string or *oauth.Token) for a +// provider and persists the change. +func (s *Service) SetProviderAPIKey(providerID string, apiKey any) error { + cfg := s.cfg + var providerConfig ProviderConfig + var exists bool + var setKeyOrToken func() + + switch v := apiKey.(type) { + case string: + if err := s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil { + return fmt.Errorf("failed to save api key to config file: %w", err) + } + setKeyOrToken = func() { providerConfig.APIKey = v } + case *oauth.Token: + if err := cmp.Or( + s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken), + s.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), v), + ); err != nil { + return err + } + setKeyOrToken = func() { + providerConfig.APIKey = v.AccessToken + providerConfig.OAuthToken = v + switch providerID { + case string(catwalk.InferenceProviderCopilot): + providerConfig.SetupGitHubCopilot() + } + } + } + + providerConfig, exists = cfg.Providers.Get(providerID) + if exists { + setKeyOrToken() + cfg.Providers.Set(providerID, providerConfig) + return nil + } + + var foundProvider *catwalk.Provider + for _, p := range s.knownProviders { + if string(p.ID) == providerID { + foundProvider = &p + break + } + } + + if foundProvider != nil { + providerConfig = ProviderConfig{ + ID: providerID, + Name: foundProvider.Name, + BaseURL: foundProvider.APIEndpoint, + Type: foundProvider.Type, + Disable: false, + ExtraHeaders: make(map[string]string), + ExtraParams: make(map[string]string), + Models: foundProvider.Models, + } + setKeyOrToken() + } else { + return fmt.Errorf("provider with ID %s not found in known providers", providerID) + } + cfg.Providers.Set(providerID, providerConfig) + return nil +} + +// ImportCopilot imports an existing GitHub Copilot token from disk if +// available and not already configured. +func (s *Service) ImportCopilot() (*oauth.Token, bool) { + if testing.Testing() { + return nil, false + } + + if s.HasConfigField("providers.copilot.api_key") || s.HasConfigField("providers.copilot.oauth") { + return nil, false + } + + diskToken, hasDiskToken := copilot.RefreshTokenFromDisk() + if !hasDiskToken { + return nil, false + } + + slog.Info("Found existing GitHub Copilot token on disk. Authenticating...") + token, err := copilot.RefreshToken(context.TODO(), diskToken) + if err != nil { + slog.Error("Unable to import GitHub Copilot token", "error", err) + return nil, false + } + + if err := s.SetProviderAPIKey(string(catwalk.InferenceProviderCopilot), token); err != nil { + return token, false + } + + if err := cmp.Or( + s.SetConfigField("providers.copilot.api_key", token.AccessToken), + s.SetConfigField("providers.copilot.oauth", token), + ); err != nil { + slog.Error("Unable to save GitHub Copilot token to disk", "error", err) + } + + slog.Info("GitHub Copilot successfully imported") + return token, true +} diff --git a/internal/ui/common/common.go b/internal/ui/common/common.go index 6e7c632474389aa5455295e4132818941bc18244..281bc7f8abcff0726cdc166c190e4ccb54c84d19 100644 --- a/internal/ui/common/common.go +++ b/internal/ui/common/common.go @@ -31,6 +31,11 @@ func (c *Common) Config() *config.Config { return c.App.Config() } +// ConfigService returns the config service associated with this [Common] instance. +func (c *Common) ConfigService() *config.Service { + return c.App.ConfigService() +} + // DefaultCommon returns the default common UI configurations. func DefaultCommon(app *app.App) *Common { s := styles.DefaultStyles() diff --git a/internal/ui/dialog/api_key_input.go b/internal/ui/dialog/api_key_input.go index 9677763b2f4f2436376f5bf16ab58aed79140c68..a9df8bb39938481a7743ad243a54cabfee25c4e4 100644 --- a/internal/ui/dialog/api_key_input.go +++ b/internal/ui/dialog/api_key_input.go @@ -312,9 +312,7 @@ func (m *APIKeyInput) verifyAPIKey() tea.Msg { } func (m *APIKeyInput) saveKeyAndContinue() Action { - cfg := m.com.Config() - - err := cfg.SetProviderAPIKey(string(m.provider.ID), m.input.Value()) + err := m.com.ConfigService().SetProviderAPIKey(string(m.provider.ID), m.input.Value()) if err != nil { return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))} } diff --git a/internal/ui/dialog/models.go b/internal/ui/dialog/models.go index 2f729e19995790fc1bb57fbea4b80191195df8da..e18b849bf78128be1dab3f62d65f04c103f37547 100644 --- a/internal/ui/dialog/models.go +++ b/internal/ui/dialog/models.go @@ -482,7 +482,7 @@ func (m *Models) setProviderItems() error { if len(validRecentItems) != len(recentItems) { // FIXME: Does this need to be here? Is it mutating the config during a read? - if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil { + if err := m.com.ConfigService().SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil { return fmt.Errorf("failed to update recent models: %w", err) } } diff --git a/internal/ui/dialog/oauth.go b/internal/ui/dialog/oauth.go index 93d5fe052db11d036d29d7790810807d5630bb57..ae18b04960418c6e8d7790f62aa7a2b7f24bd37a 100644 --- a/internal/ui/dialog/oauth.go +++ b/internal/ui/dialog/oauth.go @@ -373,9 +373,7 @@ func (d *OAuth) copyCodeAndOpenURL() tea.Cmd { } func (m *OAuth) saveKeyAndContinue() Action { - cfg := m.com.Config() - - err := cfg.SetProviderAPIKey(string(m.provider.ID), m.token) + err := m.com.ConfigService().SetProviderAPIKey(string(m.provider.ID), m.token) if err != nil { return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))} } diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 7f4f01c5bdc2e7240716cc5c41a27892a4bcedde..3bb552555b457e19a71f935a9ad7ea985f667857 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -1208,7 +1208,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { currentModel := cfg.Models[agentCfg.Model] currentModel.Think = !currentModel.Think - if err := cfg.UpdatePreferredModel(agentCfg.Model, currentModel); err != nil { + if err := m.com.ConfigService().UpdatePreferredModel(agentCfg.Model, currentModel); err != nil { return util.ReportError(err)() } m.com.App.UpdateAgentModel(context.TODO()) @@ -1249,7 +1249,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { // Attempt to import GitHub Copilot tokens from VSCode if available. if isCopilot && !isConfigured() { - m.com.Config().ImportCopilot() + m.com.ConfigService().ImportCopilot() } if !isConfigured() { @@ -1260,12 +1260,12 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { break } - if err := cfg.UpdatePreferredModel(msg.ModelType, msg.Model); err != nil { + if err := m.com.ConfigService().UpdatePreferredModel(msg.ModelType, msg.Model); err != nil { cmds = append(cmds, util.ReportError(err)) } else if _, ok := cfg.Models[config.SelectedModelTypeSmall]; !ok { // Ensure small model is set is unset. smallModel := m.com.App.GetDefaultSmallModel(providerID) - if err := cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallModel); err != nil { + if err := m.com.ConfigService().UpdatePreferredModel(config.SelectedModelTypeSmall, smallModel); err != nil { cmds = append(cmds, util.ReportError(err)) } } @@ -1311,7 +1311,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { currentModel := cfg.Models[agentCfg.Model] currentModel.ReasoningEffort = msg.Effort - if err := cfg.UpdatePreferredModel(agentCfg.Model, currentModel); err != nil { + if err := m.com.ConfigService().UpdatePreferredModel(agentCfg.Model, currentModel); err != nil { cmds = append(cmds, util.ReportError(err)) break } @@ -2157,7 +2157,7 @@ func (m *UI) FullHelp() [][]key.Binding { func (m *UI) toggleCompactMode() tea.Cmd { m.forceCompactMode = !m.forceCompactMode - err := m.com.Config().SetCompactMode(m.forceCompactMode) + err := m.com.ConfigService().SetCompactMode(m.forceCompactMode) if err != nil { return util.ReportError(err) }