diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index ad57b20c4470aa0180120798034db1bdb1de601a..97963fe9a728d0febb8da3ad0439fafba86babbe 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -62,6 +62,7 @@ type Coordinator interface { type coordinator struct { cfg *config.Config + cfgSvc *config.Service sessions session.Service messages message.Service permissions permission.Service @@ -77,7 +78,7 @@ type coordinator struct { func NewCoordinator( ctx context.Context, - cfg *config.Config, + cfgSvc *config.Service, sessions session.Service, messages message.Service, permissions permission.Service, @@ -85,8 +86,10 @@ func NewCoordinator( filetracker filetracker.Service, lspClients *csync.Map[string, *lsp.Client], ) (Coordinator, error) { + cfg := cfgSvc.Config() c := &coordinator{ cfg: cfg, + cfgSvc: cfgSvc, sessions: sessions, messages: messages, permissions: permissions, @@ -891,7 +894,7 @@ func (c *coordinator) isUnauthorized(err error) bool { } func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error { - if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil { + if err := c.cfgSvc.RefreshOAuthToken(ctx, providerCfg.ID); err != nil { slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err) return err } diff --git a/internal/app/app.go b/internal/app/app.go index f0cabfa534a58401280fb5e9b973aa6f5a9d91c9..399e8a1abaf1ff59d8136e19daa64a50bfc1b81a 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -60,7 +60,8 @@ type App struct { LSPClients *csync.Map[string, *lsp.Client] - config *config.Config + configService *config.Service + config *config.Config serviceEventsWG *sync.WaitGroup eventsCtx context.Context @@ -73,7 +74,8 @@ type App struct { } // New initializes a new application instance. -func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { +func New(ctx context.Context, conn *sql.DB, cfgSvc *config.Service) (*App, error) { + cfg := cfgSvc.Config() q := db.New(conn) sessions := session.NewService(q, conn) messages := message.NewService(q) @@ -94,7 +96,8 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { globalCtx: ctx, - config: cfg, + configService: cfgSvc, + config: cfg, events: make(chan tea.Msg, 100), serviceEventsWG: &sync.WaitGroup{}, @@ -125,6 +128,11 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { return app, nil } +// ConfigService returns the config service. +func (app *App) ConfigService() *config.Service { + return app.configService +} + // Config returns the application configuration. func (app *App) Config() *config.Config { return app.config @@ -462,7 +470,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error { var err error app.AgentCoordinator, err = agent.NewCoordinator( ctx, - app.config, + app.configService, app.Sessions, app.Messages, app.Permissions, diff --git a/internal/cmd/login.go b/internal/cmd/login.go index bdad4547d6f583b5ae7e5a97bbbbd88a1421e6ee..9da3e5775ddd9888f54231180b68667c8d8bbce4 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -52,16 +52,16 @@ crush login copilot } switch provider { case "hyper": - return loginHyper(app.Config()) + return loginHyper(app.ConfigService()) case "copilot", "github", "github-copilot": - return loginCopilot(app.Config()) + return loginCopilot(app.ConfigService()) default: return fmt.Errorf("unknown platform: %s", args[0]) } }, } -func loginHyper(cfg *config.Config) error { +func loginHyper(cfg *config.Service) error { if !hyperp.Enabled() { return fmt.Errorf("hyper not enabled") } @@ -123,7 +123,7 @@ func loginHyper(cfg *config.Config) error { return nil } -func loginCopilot(cfg *config.Config) error { +func loginCopilot(cfg *config.Service) error { ctx := getLoginContext() if cfg.HasConfigField("providers.copilot.oauth") { diff --git a/internal/cmd/root.go b/internal/cmd/root.go index f88f8ed6fce3cad33080a6585399cd08ab93e27f..02c05db511f6a28b4b7b907c2fbc1edcabede549 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -222,7 +222,7 @@ func setupApp(cmd *cobra.Command) (*app.App, error) { return nil, err } - appInstance, err := app.New(ctx, conn, cfg) + appInstance, err := app.New(ctx, conn, svc) if err != nil { slog.Error("Failed to create app instance", "error", err) return nil, err diff --git a/internal/config/config.go b/internal/config/config.go index 860143c9f805952409a1cb8572c5ac1c629d81a2..0dc07196f08ffffeffd96c532a7bded73409a35f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,6 @@ package config import ( - "cmp" "context" "fmt" "log/slog" @@ -13,12 +12,10 @@ import ( "time" "charm.land/catwalk/pkg/catwalk" - hyperp "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/oauth/copilot" - "github.com/charmbracelet/crush/internal/oauth/hyper" "github.com/invopop/jsonschema" ) @@ -457,14 +454,6 @@ func (c *Config) SmallModel() *catwalk.Model { return c.GetModel(model.Provider, model.Model) } -func (c *Config) SetCompactMode(enabled bool) error { - if c.Options == nil { - c.Options = &Options{} - } - c.Options.TUI.CompactMode = enabled - return c.SetConfigField("options.tui.compact_mode", enabled) -} - func (c *Config) Resolve(key string) (string, error) { if c.resolver == nil { return "", fmt.Errorf("no variable resolver configured") @@ -472,187 +461,19 @@ func (c *Config) Resolve(key string) (string, error) { return c.resolver.ResolveValue(key) } -func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error { - c.Models[modelType] = model - if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil { - return fmt.Errorf("failed to update preferred model: %w", err) - } - if err := c.recordRecentModel(modelType, model); err != nil { - return err - } - return nil -} - -func (c *Config) configStore() Store { - if c.store == nil { - c.store = NewFileStore(c.dataConfigDir) - } - return c.store -} - -func (c *Config) HasConfigField(key string) bool { - return HasField(c.configStore(), key) -} - -func (c *Config) SetConfigField(key string, value any) error { +func (c *Config) setConfigField(key string, value any) error { return SetField(c.configStore(), key, value) } -func (c *Config) RemoveConfigField(key string) error { +func (c *Config) removeConfigField(key string) error { return RemoveField(c.configStore(), key) } -// RefreshOAuthToken refreshes the OAuth token for the given provider. -func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error { - providerConfig, exists := c.Providers.Get(providerID) - if !exists { - return fmt.Errorf("provider %s not found", providerID) - } - - if providerConfig.OAuthToken == nil { - return fmt.Errorf("provider %s does not have an OAuth token", providerID) - } - - var newToken *oauth.Token - var refreshErr error - switch providerID { - case string(catwalk.InferenceProviderCopilot): - newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) - case hyperp.Name: - newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken) - default: - return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) - } - if refreshErr != nil { - return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr) - } - - slog.Info("Successfully refreshed OAuth token", "provider", providerID) - providerConfig.OAuthToken = newToken - providerConfig.APIKey = newToken.AccessToken - - switch providerID { - case string(catwalk.InferenceProviderCopilot): - providerConfig.SetupGitHubCopilot() - } - - c.Providers.Set(providerID, providerConfig) - - if err := cmp.Or( - c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken), - c.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), newToken), - ); err != nil { - return fmt.Errorf("failed to persist refreshed token: %w", err) - } - - return nil -} - -func (c *Config) SetProviderAPIKey(providerID string, apiKey any) error { - var providerConfig ProviderConfig - var exists bool - var setKeyOrToken func() - - switch v := apiKey.(type) { - case string: - if err := c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil { - return fmt.Errorf("failed to save api key to config file: %w", err) - } - setKeyOrToken = func() { providerConfig.APIKey = v } - case *oauth.Token: - if err := cmp.Or( - c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken), - c.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), v), - ); err != nil { - return err - } - setKeyOrToken = func() { - providerConfig.APIKey = v.AccessToken - providerConfig.OAuthToken = v - switch providerID { - case string(catwalk.InferenceProviderCopilot): - providerConfig.SetupGitHubCopilot() - } - } - } - - providerConfig, exists = c.Providers.Get(providerID) - if exists { - setKeyOrToken() - c.Providers.Set(providerID, providerConfig) - return nil - } - - var foundProvider *catwalk.Provider - for _, p := range c.knownProviders { - if string(p.ID) == providerID { - foundProvider = &p - break - } - } - - if foundProvider != nil { - // Create new provider config based on known provider - providerConfig = ProviderConfig{ - ID: providerID, - Name: foundProvider.Name, - BaseURL: foundProvider.APIEndpoint, - Type: foundProvider.Type, - Disable: false, - ExtraHeaders: make(map[string]string), - ExtraParams: make(map[string]string), - Models: foundProvider.Models, - } - setKeyOrToken() - } else { - return fmt.Errorf("provider with ID %s not found in known providers", providerID) - } - // Store the updated provider config - c.Providers.Set(providerID, providerConfig) - return nil -} - -const maxRecentModelsPerType = 5 - -func (c *Config) recordRecentModel(modelType SelectedModelType, model SelectedModel) error { - if model.Provider == "" || model.Model == "" { - return nil - } - - if c.RecentModels == nil { - c.RecentModels = make(map[SelectedModelType][]SelectedModel) - } - - eq := func(a, b SelectedModel) bool { - return a.Provider == b.Provider && a.Model == b.Model - } - - entry := SelectedModel{ - Provider: model.Provider, - Model: model.Model, - } - - current := c.RecentModels[modelType] - withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool { - return eq(existing, entry) - }) - - updated := append([]SelectedModel{entry}, withoutCurrent...) - if len(updated) > maxRecentModelsPerType { - updated = updated[:maxRecentModelsPerType] - } - - if slices.EqualFunc(current, updated, eq) { - return nil - } - - c.RecentModels[modelType] = updated - - if err := c.SetConfigField(fmt.Sprintf("recent_models.%s", modelType), updated); err != nil { - return fmt.Errorf("failed to persist recent models: %w", err) +func (c *Config) configStore() Store { + if c.store == nil { + c.store = NewFileStore(c.dataConfigDir) } - - return nil + return c.store } func allToolNames() []string { diff --git a/internal/config/copilot.go b/internal/config/copilot.go deleted file mode 100644 index d72e7d5048ba4d31c88d7f7152a6b3a9510960a2..0000000000000000000000000000000000000000 --- a/internal/config/copilot.go +++ /dev/null @@ -1,48 +0,0 @@ -package config - -import ( - "cmp" - "context" - "log/slog" - "testing" - - "charm.land/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/oauth" - "github.com/charmbracelet/crush/internal/oauth/copilot" -) - -func (c *Config) ImportCopilot() (*oauth.Token, bool) { - if testing.Testing() { - return nil, false - } - - if c.HasConfigField("providers.copilot.api_key") || c.HasConfigField("providers.copilot.oauth") { - return nil, false - } - - diskToken, hasDiskToken := copilot.RefreshTokenFromDisk() - if !hasDiskToken { - return nil, false - } - - slog.Info("Found existing GitHub Copilot token on disk. Authenticating...") - token, err := copilot.RefreshToken(context.TODO(), diskToken) - if err != nil { - slog.Error("Unable to import GitHub Copilot token", "error", err) - return nil, false - } - - if err := c.SetProviderAPIKey(string(catwalk.InferenceProviderCopilot), token); err != nil { - return token, false - } - - if err := cmp.Or( - c.SetConfigField("providers.copilot.api_key", token.AccessToken), - c.SetConfigField("providers.copilot.oauth", token), - ); err != nil { - slog.Error("Unable to save GitHub Copilot token to disk", "error", err) - } - - slog.Info("GitHub Copilot successfully imported") - return token, true -} diff --git a/internal/config/load.go b/internal/config/load.go index 02095b5839822e05eef31ff1d9917418a66898a8..84530dbff179689102ee5e4e1863714674485e63 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -220,7 +220,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know switch { case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil: // Claude Code subscription is not supported anymore. Remove to show onboarding. - c.RemoveConfigField("providers.anthropic") + c.removeConfigField("providers.anthropic") c.Providers.Del(string(p.ID)) continue case p.ID == catwalk.InferenceProviderCopilot && config.OAuthToken != nil: @@ -558,9 +558,8 @@ func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) erro model := c.GetModel(large.Provider, large.Model) if model == nil { large = defaultLarge - // override the model type to large - err := c.UpdatePreferredModel(SelectedModelTypeLarge, large) - if err != nil { + c.Models[SelectedModelTypeLarge] = large + if err := c.setConfigField(fmt.Sprintf("models.%s", SelectedModelTypeLarge), large); err != nil { return fmt.Errorf("failed to update preferred large model: %w", err) } } else { @@ -602,9 +601,8 @@ func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) erro model := c.GetModel(small.Provider, small.Model) if model == nil { small = defaultSmall - // override the model type to small - err := c.UpdatePreferredModel(SelectedModelTypeSmall, small) - if err != nil { + c.Models[SelectedModelTypeSmall] = small + if err := c.setConfigField(fmt.Sprintf("models.%s", SelectedModelTypeSmall), small); err != nil { return fmt.Errorf("failed to update preferred small model: %w", err) } } else { diff --git a/internal/config/recent_models_test.go b/internal/config/recent_models_test.go index 739ddc0031a65cab261723772c3f38658dcd1561..ea7ded48601a2696f895bf0d2f01936f9897e7b2 100644 --- a/internal/config/recent_models_test.go +++ b/internal/config/recent_models_test.go @@ -31,15 +31,27 @@ func readRecentModels(t *testing.T, path string) map[string]any { return rm } -func TestRecordRecentModel_AddsAndPersists(t *testing.T) { - t.Parallel() - +func newTestService(t *testing.T) (*Service, string) { + t.Helper() dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + storePath := filepath.Join(dir, "config.json") + svc := &Service{ + cfg: cfg, + store: NewFileStore(storePath), + workingDir: dir, + } + return svc, storePath +} + +func TestRecordRecentModel_AddsAndPersists(t *testing.T) { + t.Parallel() - err := cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}) + svc, storePath := newTestService(t) + cfg := svc.cfg + + err := svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}) require.NoError(t, err) // in-memory state @@ -48,7 +60,7 @@ func TestRecordRecentModel_AddsAndPersists(t *testing.T) { require.Equal(t, "gpt-4o", cfg.RecentModels[SelectedModelTypeLarge][0].Model) // persisted state - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, storePath) large, ok := rm[string(SelectedModelTypeLarge)].([]any) require.True(t, ok) require.Len(t, large, 1) @@ -61,16 +73,14 @@ func TestRecordRecentModel_AddsAndPersists(t *testing.T) { func TestRecordRecentModel_DedupeAndMoveToFront(t *testing.T) { t.Parallel() - dir := t.TempDir() - cfg := &Config{} - cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + svc, _ := newTestService(t) + cfg := svc.cfg // Add two entries - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "anthropic", Model: "claude"})) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "anthropic", Model: "claude"})) // Re-add first; should move to front and not duplicate - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) got := cfg.RecentModels[SelectedModelTypeLarge] require.Len(t, got, 2) @@ -81,10 +91,8 @@ func TestRecordRecentModel_DedupeAndMoveToFront(t *testing.T) { func TestRecordRecentModel_TrimsToMax(t *testing.T) { t.Parallel() - dir := t.TempDir() - cfg := &Config{} - cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + svc, storePath := newTestService(t) + cfg := svc.cfg // Insert 6 unique models; max is 5 entries := []SelectedModel{ @@ -96,7 +104,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) { {Provider: "p6", Model: "m6"}, } for _, e := range entries { - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, e)) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, e)) } // in-memory state @@ -110,7 +118,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) { require.Equal(t, SelectedModel{Provider: "p2", Model: "m2"}, got[4]) // persisted state: verify trimmed to 5 and newest-first order - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, storePath) large, ok := rm[string(SelectedModelTypeLarge)].([]any) require.True(t, ok) require.Len(t, large, 5) @@ -126,15 +134,13 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) { func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) { t.Parallel() - dir := t.TempDir() - cfg := &Config{} - cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + svc, storePath := newTestService(t) + cfg := svc.cfg // Missing provider - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "", Model: "m"})) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "", Model: "m"})) // Missing model - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "p", Model: ""})) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "p", Model: ""})) _, ok := cfg.RecentModels[SelectedModelTypeLarge] // Map may be initialized, but should have no entries @@ -142,8 +148,8 @@ func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) { require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 0) } // No file should be written (stat via fs.FS) - baseDir := filepath.Dir(cfg.dataConfigDir) - fileName := filepath.Base(cfg.dataConfigDir) + baseDir := filepath.Dir(storePath) + fileName := filepath.Base(storePath) _, err := fs.Stat(os.DirFS(baseDir), fileName) require.True(t, os.IsNotExist(err)) } @@ -151,16 +157,13 @@ func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) { func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) { t.Parallel() - dir := t.TempDir() - cfg := &Config{} - cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + svc, storePath := newTestService(t) entry := SelectedModel{Provider: "openai", Model: "gpt-4o"} - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, entry)) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, entry)) - baseDir := filepath.Dir(cfg.dataConfigDir) - fileName := filepath.Base(cfg.dataConfigDir) + baseDir := filepath.Dir(storePath) + fileName := filepath.Base(storePath) before, err := fs.ReadFile(os.DirFS(baseDir), fileName) require.NoError(t, err) @@ -170,7 +173,7 @@ func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) { beforeMod := stBefore.ModTime() // Re-record same entry should be a no-op (no write) - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, entry)) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, entry)) after, err := fs.ReadFile(os.DirFS(baseDir), fileName) require.NoError(t, err) @@ -185,20 +188,18 @@ func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) { func TestUpdatePreferredModel_UpdatesRecents(t *testing.T) { t.Parallel() - dir := t.TempDir() - cfg := &Config{} - cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + svc, storePath := newTestService(t) + cfg := svc.cfg sel := SelectedModel{Provider: "openai", Model: "gpt-4o"} - require.NoError(t, cfg.UpdatePreferredModel(SelectedModelTypeSmall, sel)) + require.NoError(t, svc.UpdatePreferredModel(SelectedModelTypeSmall, sel)) // in-memory require.Equal(t, sel, cfg.Models[SelectedModelTypeSmall]) require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1) // persisted (read via fs.FS) - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, storePath) small, ok := rm[string(SelectedModelTypeSmall)].([]any) require.True(t, ok) require.Len(t, small, 1) @@ -207,17 +208,15 @@ func TestUpdatePreferredModel_UpdatesRecents(t *testing.T) { func TestRecordRecentModel_TypeIsolation(t *testing.T) { t.Parallel() - dir := t.TempDir() - cfg := &Config{} - cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + svc, storePath := newTestService(t) + cfg := svc.cfg // Add models to both large and small types largeModel := SelectedModel{Provider: "openai", Model: "gpt-4o"} smallModel := SelectedModel{Provider: "anthropic", Model: "claude"} - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, largeModel)) - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeSmall, smallModel)) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, largeModel)) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeSmall, smallModel)) // in-memory: verify types maintain separate histories require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 1) @@ -227,14 +226,14 @@ func TestRecordRecentModel_TypeIsolation(t *testing.T) { // Add another to large, verify small unchanged anotherLarge := SelectedModel{Provider: "google", Model: "gemini"} - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, anotherLarge)) + require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, anotherLarge)) require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 2) require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1) require.Equal(t, smallModel, cfg.RecentModels[SelectedModelTypeSmall][0]) // persisted state: verify both types exist with correct lengths and contents - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, storePath) large, ok := rm[string(SelectedModelTypeLarge)].([]any) require.True(t, ok) diff --git a/internal/config/service.go b/internal/config/service.go index 5fd98bcb91b7586f44773fdde8b2da498db7f86b..2d595b5d919315d0bc2904566cf73d8717cb5bce 100644 --- a/internal/config/service.go +++ b/internal/config/service.go @@ -1,6 +1,19 @@ package config -import "charm.land/catwalk/pkg/catwalk" +import ( + "cmp" + "context" + "fmt" + "log/slog" + "slices" + "testing" + + "charm.land/catwalk/pkg/catwalk" + hyperp "github.com/charmbracelet/crush/internal/agent/hyper" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/oauth/copilot" + "github.com/charmbracelet/crush/internal/oauth/hyper" +) // Service is the central access point for configuration. It wraps the // raw Config data and owns all internal state that was previously held @@ -20,3 +33,242 @@ type Service struct { func (s *Service) Config() *Config { return s.cfg } + +// HasConfigField returns true if the given dotted key path exists in +// the persisted config data. +func (s *Service) HasConfigField(key string) bool { + return HasField(s.store, key) +} + +// SetConfigField sets a value at the given dotted key path and +// persists it. +func (s *Service) SetConfigField(key string, value any) error { + return SetField(s.store, key, value) +} + +// RemoveConfigField deletes a value at the given dotted key path and +// persists it. +func (s *Service) RemoveConfigField(key string) error { + return RemoveField(s.store, key) +} + +// SetCompactMode toggles compact mode and persists the change. +func (s *Service) SetCompactMode(enabled bool) error { + cfg := s.cfg + if cfg.Options == nil { + cfg.Options = &Options{} + } + if cfg.Options.TUI == nil { + cfg.Options.TUI = &TUIOptions{} + } + cfg.Options.TUI.CompactMode = enabled + return s.SetConfigField("options.tui.compact_mode", enabled) +} + +// UpdatePreferredModel updates the selected model for the given type +// and persists the change, also recording it in the recent models +// list. +func (s *Service) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error { + s.cfg.Models[modelType] = model + if err := s.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil { + return fmt.Errorf("failed to update preferred model: %w", err) + } + if err := s.recordRecentModel(modelType, model); err != nil { + return err + } + return nil +} + +const maxRecentModelsPerType = 5 + +func (s *Service) recordRecentModel(modelType SelectedModelType, model SelectedModel) error { + if model.Provider == "" || model.Model == "" { + return nil + } + + cfg := s.cfg + if cfg.RecentModels == nil { + cfg.RecentModels = make(map[SelectedModelType][]SelectedModel) + } + + eq := func(a, b SelectedModel) bool { + return a.Provider == b.Provider && a.Model == b.Model + } + + entry := SelectedModel{ + Provider: model.Provider, + Model: model.Model, + } + + current := cfg.RecentModels[modelType] + withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool { + return eq(existing, entry) + }) + + updated := append([]SelectedModel{entry}, withoutCurrent...) + if len(updated) > maxRecentModelsPerType { + updated = updated[:maxRecentModelsPerType] + } + + if slices.EqualFunc(current, updated, eq) { + return nil + } + + cfg.RecentModels[modelType] = updated + + if err := s.SetConfigField(fmt.Sprintf("recent_models.%s", modelType), updated); err != nil { + return fmt.Errorf("failed to persist recent models: %w", err) + } + + return nil +} + +// RefreshOAuthToken refreshes the OAuth token for the given provider. +func (s *Service) RefreshOAuthToken(ctx context.Context, providerID string) error { + cfg := s.cfg + providerConfig, exists := cfg.Providers.Get(providerID) + if !exists { + return fmt.Errorf("provider %s not found", providerID) + } + + if providerConfig.OAuthToken == nil { + return fmt.Errorf("provider %s does not have an OAuth token", providerID) + } + + var newToken *oauth.Token + var refreshErr error + switch providerID { + case string(catwalk.InferenceProviderCopilot): + newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) + case hyperp.Name: + newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken) + default: + return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) + } + if refreshErr != nil { + return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr) + } + + slog.Info("Successfully refreshed OAuth token", "provider", providerID) + providerConfig.OAuthToken = newToken + providerConfig.APIKey = newToken.AccessToken + + switch providerID { + case string(catwalk.InferenceProviderCopilot): + providerConfig.SetupGitHubCopilot() + } + + cfg.Providers.Set(providerID, providerConfig) + + if err := cmp.Or( + s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken), + s.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), newToken), + ); err != nil { + return fmt.Errorf("failed to persist refreshed token: %w", err) + } + + return nil +} + +// SetProviderAPIKey sets the API key (string or *oauth.Token) for a +// provider and persists the change. +func (s *Service) SetProviderAPIKey(providerID string, apiKey any) error { + cfg := s.cfg + var providerConfig ProviderConfig + var exists bool + var setKeyOrToken func() + + switch v := apiKey.(type) { + case string: + if err := s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil { + return fmt.Errorf("failed to save api key to config file: %w", err) + } + setKeyOrToken = func() { providerConfig.APIKey = v } + case *oauth.Token: + if err := cmp.Or( + s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken), + s.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), v), + ); err != nil { + return err + } + setKeyOrToken = func() { + providerConfig.APIKey = v.AccessToken + providerConfig.OAuthToken = v + switch providerID { + case string(catwalk.InferenceProviderCopilot): + providerConfig.SetupGitHubCopilot() + } + } + } + + providerConfig, exists = cfg.Providers.Get(providerID) + if exists { + setKeyOrToken() + cfg.Providers.Set(providerID, providerConfig) + return nil + } + + var foundProvider *catwalk.Provider + for _, p := range s.knownProviders { + if string(p.ID) == providerID { + foundProvider = &p + break + } + } + + if foundProvider != nil { + providerConfig = ProviderConfig{ + ID: providerID, + Name: foundProvider.Name, + BaseURL: foundProvider.APIEndpoint, + Type: foundProvider.Type, + Disable: false, + ExtraHeaders: make(map[string]string), + ExtraParams: make(map[string]string), + Models: foundProvider.Models, + } + setKeyOrToken() + } else { + return fmt.Errorf("provider with ID %s not found in known providers", providerID) + } + cfg.Providers.Set(providerID, providerConfig) + return nil +} + +// ImportCopilot imports an existing GitHub Copilot token from disk if +// available and not already configured. +func (s *Service) ImportCopilot() (*oauth.Token, bool) { + if testing.Testing() { + return nil, false + } + + if s.HasConfigField("providers.copilot.api_key") || s.HasConfigField("providers.copilot.oauth") { + return nil, false + } + + diskToken, hasDiskToken := copilot.RefreshTokenFromDisk() + if !hasDiskToken { + return nil, false + } + + slog.Info("Found existing GitHub Copilot token on disk. Authenticating...") + token, err := copilot.RefreshToken(context.TODO(), diskToken) + if err != nil { + slog.Error("Unable to import GitHub Copilot token", "error", err) + return nil, false + } + + if err := s.SetProviderAPIKey(string(catwalk.InferenceProviderCopilot), token); err != nil { + return token, false + } + + if err := cmp.Or( + s.SetConfigField("providers.copilot.api_key", token.AccessToken), + s.SetConfigField("providers.copilot.oauth", token), + ); err != nil { + slog.Error("Unable to save GitHub Copilot token to disk", "error", err) + } + + slog.Info("GitHub Copilot successfully imported") + return token, true +} diff --git a/internal/ui/common/common.go b/internal/ui/common/common.go index 6e7c632474389aa5455295e4132818941bc18244..281bc7f8abcff0726cdc166c190e4ccb54c84d19 100644 --- a/internal/ui/common/common.go +++ b/internal/ui/common/common.go @@ -31,6 +31,11 @@ func (c *Common) Config() *config.Config { return c.App.Config() } +// ConfigService returns the config service associated with this [Common] instance. +func (c *Common) ConfigService() *config.Service { + return c.App.ConfigService() +} + // DefaultCommon returns the default common UI configurations. func DefaultCommon(app *app.App) *Common { s := styles.DefaultStyles() diff --git a/internal/ui/dialog/api_key_input.go b/internal/ui/dialog/api_key_input.go index 9677763b2f4f2436376f5bf16ab58aed79140c68..a9df8bb39938481a7743ad243a54cabfee25c4e4 100644 --- a/internal/ui/dialog/api_key_input.go +++ b/internal/ui/dialog/api_key_input.go @@ -312,9 +312,7 @@ func (m *APIKeyInput) verifyAPIKey() tea.Msg { } func (m *APIKeyInput) saveKeyAndContinue() Action { - cfg := m.com.Config() - - err := cfg.SetProviderAPIKey(string(m.provider.ID), m.input.Value()) + err := m.com.ConfigService().SetProviderAPIKey(string(m.provider.ID), m.input.Value()) if err != nil { return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))} } diff --git a/internal/ui/dialog/models.go b/internal/ui/dialog/models.go index 2f729e19995790fc1bb57fbea4b80191195df8da..e18b849bf78128be1dab3f62d65f04c103f37547 100644 --- a/internal/ui/dialog/models.go +++ b/internal/ui/dialog/models.go @@ -482,7 +482,7 @@ func (m *Models) setProviderItems() error { if len(validRecentItems) != len(recentItems) { // FIXME: Does this need to be here? Is it mutating the config during a read? - if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil { + if err := m.com.ConfigService().SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil { return fmt.Errorf("failed to update recent models: %w", err) } } diff --git a/internal/ui/dialog/oauth.go b/internal/ui/dialog/oauth.go index 93d5fe052db11d036d29d7790810807d5630bb57..ae18b04960418c6e8d7790f62aa7a2b7f24bd37a 100644 --- a/internal/ui/dialog/oauth.go +++ b/internal/ui/dialog/oauth.go @@ -373,9 +373,7 @@ func (d *OAuth) copyCodeAndOpenURL() tea.Cmd { } func (m *OAuth) saveKeyAndContinue() Action { - cfg := m.com.Config() - - err := cfg.SetProviderAPIKey(string(m.provider.ID), m.token) + err := m.com.ConfigService().SetProviderAPIKey(string(m.provider.ID), m.token) if err != nil { return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))} } diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 7f4f01c5bdc2e7240716cc5c41a27892a4bcedde..3bb552555b457e19a71f935a9ad7ea985f667857 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -1208,7 +1208,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { currentModel := cfg.Models[agentCfg.Model] currentModel.Think = !currentModel.Think - if err := cfg.UpdatePreferredModel(agentCfg.Model, currentModel); err != nil { + if err := m.com.ConfigService().UpdatePreferredModel(agentCfg.Model, currentModel); err != nil { return util.ReportError(err)() } m.com.App.UpdateAgentModel(context.TODO()) @@ -1249,7 +1249,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { // Attempt to import GitHub Copilot tokens from VSCode if available. if isCopilot && !isConfigured() { - m.com.Config().ImportCopilot() + m.com.ConfigService().ImportCopilot() } if !isConfigured() { @@ -1260,12 +1260,12 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { break } - if err := cfg.UpdatePreferredModel(msg.ModelType, msg.Model); err != nil { + if err := m.com.ConfigService().UpdatePreferredModel(msg.ModelType, msg.Model); err != nil { cmds = append(cmds, util.ReportError(err)) } else if _, ok := cfg.Models[config.SelectedModelTypeSmall]; !ok { // Ensure small model is set is unset. smallModel := m.com.App.GetDefaultSmallModel(providerID) - if err := cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallModel); err != nil { + if err := m.com.ConfigService().UpdatePreferredModel(config.SelectedModelTypeSmall, smallModel); err != nil { cmds = append(cmds, util.ReportError(err)) } } @@ -1311,7 +1311,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { currentModel := cfg.Models[agentCfg.Model] currentModel.ReasoningEffort = msg.Effort - if err := cfg.UpdatePreferredModel(agentCfg.Model, currentModel); err != nil { + if err := m.com.ConfigService().UpdatePreferredModel(agentCfg.Model, currentModel); err != nil { cmds = append(cmds, util.ReportError(err)) break } @@ -2157,7 +2157,7 @@ func (m *UI) FullHelp() [][]key.Binding { func (m *UI) toggleCompactMode() tea.Cmd { m.forceCompactMode = !m.forceCompactMode - err := m.com.Config().SetCompactMode(m.forceCompactMode) + err := m.com.ConfigService().SetCompactMode(m.forceCompactMode) if err != nil { return util.ReportError(err) }