Detailed changes
@@ -62,6 +62,7 @@ type Coordinator interface {
type coordinator struct {
cfg *config.Config
+ cfgSvc *config.Service
sessions session.Service
messages message.Service
permissions permission.Service
@@ -77,7 +78,7 @@ type coordinator struct {
func NewCoordinator(
ctx context.Context,
- cfg *config.Config,
+ cfgSvc *config.Service,
sessions session.Service,
messages message.Service,
permissions permission.Service,
@@ -85,8 +86,10 @@ func NewCoordinator(
filetracker filetracker.Service,
lspClients *csync.Map[string, *lsp.Client],
) (Coordinator, error) {
+ cfg := cfgSvc.Config()
c := &coordinator{
cfg: cfg,
+ cfgSvc: cfgSvc,
sessions: sessions,
messages: messages,
permissions: permissions,
@@ -891,7 +894,7 @@ func (c *coordinator) isUnauthorized(err error) bool {
}
func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
- if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
+ if err := c.cfgSvc.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
return err
}
@@ -60,7 +60,8 @@ type App struct {
LSPClients *csync.Map[string, *lsp.Client]
- config *config.Config
+ configService *config.Service
+ config *config.Config
serviceEventsWG *sync.WaitGroup
eventsCtx context.Context
@@ -73,7 +74,8 @@ type App struct {
}
// New initializes a new application instance.
-func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
+func New(ctx context.Context, conn *sql.DB, cfgSvc *config.Service) (*App, error) {
+ cfg := cfgSvc.Config()
q := db.New(conn)
sessions := session.NewService(q, conn)
messages := message.NewService(q)
@@ -94,7 +96,8 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
globalCtx: ctx,
- config: cfg,
+ configService: cfgSvc,
+ config: cfg,
events: make(chan tea.Msg, 100),
serviceEventsWG: &sync.WaitGroup{},
@@ -125,6 +128,11 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
return app, nil
}
+// ConfigService returns the config service.
+func (app *App) ConfigService() *config.Service {
+ return app.configService
+}
+
// Config returns the application configuration.
func (app *App) Config() *config.Config {
return app.config
@@ -462,7 +470,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error {
var err error
app.AgentCoordinator, err = agent.NewCoordinator(
ctx,
- app.config,
+ app.configService,
app.Sessions,
app.Messages,
app.Permissions,
@@ -52,16 +52,16 @@ crush login copilot
}
switch provider {
case "hyper":
- return loginHyper(app.Config())
+ return loginHyper(app.ConfigService())
case "copilot", "github", "github-copilot":
- return loginCopilot(app.Config())
+ return loginCopilot(app.ConfigService())
default:
return fmt.Errorf("unknown platform: %s", args[0])
}
},
}
-func loginHyper(cfg *config.Config) error {
+func loginHyper(cfg *config.Service) error {
if !hyperp.Enabled() {
return fmt.Errorf("hyper not enabled")
}
@@ -123,7 +123,7 @@ func loginHyper(cfg *config.Config) error {
return nil
}
-func loginCopilot(cfg *config.Config) error {
+func loginCopilot(cfg *config.Service) error {
ctx := getLoginContext()
if cfg.HasConfigField("providers.copilot.oauth") {
@@ -222,7 +222,7 @@ func setupApp(cmd *cobra.Command) (*app.App, error) {
return nil, err
}
- appInstance, err := app.New(ctx, conn, cfg)
+ appInstance, err := app.New(ctx, conn, svc)
if err != nil {
slog.Error("Failed to create app instance", "error", err)
return nil, err
@@ -1,7 +1,6 @@
package config
import (
- "cmp"
"context"
"fmt"
"log/slog"
@@ -13,12 +12,10 @@ import (
"time"
"charm.land/catwalk/pkg/catwalk"
- hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/oauth"
"github.com/charmbracelet/crush/internal/oauth/copilot"
- "github.com/charmbracelet/crush/internal/oauth/hyper"
"github.com/invopop/jsonschema"
)
@@ -457,14 +454,6 @@ func (c *Config) SmallModel() *catwalk.Model {
return c.GetModel(model.Provider, model.Model)
}
-func (c *Config) SetCompactMode(enabled bool) error {
- if c.Options == nil {
- c.Options = &Options{}
- }
- c.Options.TUI.CompactMode = enabled
- return c.SetConfigField("options.tui.compact_mode", enabled)
-}
-
func (c *Config) Resolve(key string) (string, error) {
if c.resolver == nil {
return "", fmt.Errorf("no variable resolver configured")
@@ -472,187 +461,19 @@ func (c *Config) Resolve(key string) (string, error) {
return c.resolver.ResolveValue(key)
}
-func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
- c.Models[modelType] = model
- if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
- return fmt.Errorf("failed to update preferred model: %w", err)
- }
- if err := c.recordRecentModel(modelType, model); err != nil {
- return err
- }
- return nil
-}
-
-func (c *Config) configStore() Store {
- if c.store == nil {
- c.store = NewFileStore(c.dataConfigDir)
- }
- return c.store
-}
-
-func (c *Config) HasConfigField(key string) bool {
- return HasField(c.configStore(), key)
-}
-
-func (c *Config) SetConfigField(key string, value any) error {
+func (c *Config) setConfigField(key string, value any) error {
return SetField(c.configStore(), key, value)
}
-func (c *Config) RemoveConfigField(key string) error {
+func (c *Config) removeConfigField(key string) error {
return RemoveField(c.configStore(), key)
}
-// RefreshOAuthToken refreshes the OAuth token for the given provider.
-func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error {
- providerConfig, exists := c.Providers.Get(providerID)
- if !exists {
- return fmt.Errorf("provider %s not found", providerID)
- }
-
- if providerConfig.OAuthToken == nil {
- return fmt.Errorf("provider %s does not have an OAuth token", providerID)
- }
-
- var newToken *oauth.Token
- var refreshErr error
- switch providerID {
- case string(catwalk.InferenceProviderCopilot):
- newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
- case hyperp.Name:
- newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
- default:
- return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
- }
- if refreshErr != nil {
- return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
- }
-
- slog.Info("Successfully refreshed OAuth token", "provider", providerID)
- providerConfig.OAuthToken = newToken
- providerConfig.APIKey = newToken.AccessToken
-
- switch providerID {
- case string(catwalk.InferenceProviderCopilot):
- providerConfig.SetupGitHubCopilot()
- }
-
- c.Providers.Set(providerID, providerConfig)
-
- if err := cmp.Or(
- c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
- c.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), newToken),
- ); err != nil {
- return fmt.Errorf("failed to persist refreshed token: %w", err)
- }
-
- return nil
-}
-
-func (c *Config) SetProviderAPIKey(providerID string, apiKey any) error {
- var providerConfig ProviderConfig
- var exists bool
- var setKeyOrToken func()
-
- switch v := apiKey.(type) {
- case string:
- if err := c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
- return fmt.Errorf("failed to save api key to config file: %w", err)
- }
- setKeyOrToken = func() { providerConfig.APIKey = v }
- case *oauth.Token:
- if err := cmp.Or(
- c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
- c.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), v),
- ); err != nil {
- return err
- }
- setKeyOrToken = func() {
- providerConfig.APIKey = v.AccessToken
- providerConfig.OAuthToken = v
- switch providerID {
- case string(catwalk.InferenceProviderCopilot):
- providerConfig.SetupGitHubCopilot()
- }
- }
- }
-
- providerConfig, exists = c.Providers.Get(providerID)
- if exists {
- setKeyOrToken()
- c.Providers.Set(providerID, providerConfig)
- return nil
- }
-
- var foundProvider *catwalk.Provider
- for _, p := range c.knownProviders {
- if string(p.ID) == providerID {
- foundProvider = &p
- break
- }
- }
-
- if foundProvider != nil {
- // Create new provider config based on known provider
- providerConfig = ProviderConfig{
- ID: providerID,
- Name: foundProvider.Name,
- BaseURL: foundProvider.APIEndpoint,
- Type: foundProvider.Type,
- Disable: false,
- ExtraHeaders: make(map[string]string),
- ExtraParams: make(map[string]string),
- Models: foundProvider.Models,
- }
- setKeyOrToken()
- } else {
- return fmt.Errorf("provider with ID %s not found in known providers", providerID)
- }
- // Store the updated provider config
- c.Providers.Set(providerID, providerConfig)
- return nil
-}
-
-const maxRecentModelsPerType = 5
-
-func (c *Config) recordRecentModel(modelType SelectedModelType, model SelectedModel) error {
- if model.Provider == "" || model.Model == "" {
- return nil
- }
-
- if c.RecentModels == nil {
- c.RecentModels = make(map[SelectedModelType][]SelectedModel)
- }
-
- eq := func(a, b SelectedModel) bool {
- return a.Provider == b.Provider && a.Model == b.Model
- }
-
- entry := SelectedModel{
- Provider: model.Provider,
- Model: model.Model,
- }
-
- current := c.RecentModels[modelType]
- withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
- return eq(existing, entry)
- })
-
- updated := append([]SelectedModel{entry}, withoutCurrent...)
- if len(updated) > maxRecentModelsPerType {
- updated = updated[:maxRecentModelsPerType]
- }
-
- if slices.EqualFunc(current, updated, eq) {
- return nil
- }
-
- c.RecentModels[modelType] = updated
-
- if err := c.SetConfigField(fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
- return fmt.Errorf("failed to persist recent models: %w", err)
+func (c *Config) configStore() Store {
+ if c.store == nil {
+ c.store = NewFileStore(c.dataConfigDir)
}
-
- return nil
+ return c.store
}
func allToolNames() []string {
@@ -1,48 +0,0 @@
-package config
-
-import (
- "cmp"
- "context"
- "log/slog"
- "testing"
-
- "charm.land/catwalk/pkg/catwalk"
- "github.com/charmbracelet/crush/internal/oauth"
- "github.com/charmbracelet/crush/internal/oauth/copilot"
-)
-
-func (c *Config) ImportCopilot() (*oauth.Token, bool) {
- if testing.Testing() {
- return nil, false
- }
-
- if c.HasConfigField("providers.copilot.api_key") || c.HasConfigField("providers.copilot.oauth") {
- return nil, false
- }
-
- diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
- if !hasDiskToken {
- return nil, false
- }
-
- slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
- token, err := copilot.RefreshToken(context.TODO(), diskToken)
- if err != nil {
- slog.Error("Unable to import GitHub Copilot token", "error", err)
- return nil, false
- }
-
- if err := c.SetProviderAPIKey(string(catwalk.InferenceProviderCopilot), token); err != nil {
- return token, false
- }
-
- if err := cmp.Or(
- c.SetConfigField("providers.copilot.api_key", token.AccessToken),
- c.SetConfigField("providers.copilot.oauth", token),
- ); err != nil {
- slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
- }
-
- slog.Info("GitHub Copilot successfully imported")
- return token, true
-}
@@ -220,7 +220,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
switch {
case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil:
// Claude Code subscription is not supported anymore. Remove to show onboarding.
- c.RemoveConfigField("providers.anthropic")
+ c.removeConfigField("providers.anthropic")
c.Providers.Del(string(p.ID))
continue
case p.ID == catwalk.InferenceProviderCopilot && config.OAuthToken != nil:
@@ -558,9 +558,8 @@ func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) erro
model := c.GetModel(large.Provider, large.Model)
if model == nil {
large = defaultLarge
- // override the model type to large
- err := c.UpdatePreferredModel(SelectedModelTypeLarge, large)
- if err != nil {
+ c.Models[SelectedModelTypeLarge] = large
+ if err := c.setConfigField(fmt.Sprintf("models.%s", SelectedModelTypeLarge), large); err != nil {
return fmt.Errorf("failed to update preferred large model: %w", err)
}
} else {
@@ -602,9 +601,8 @@ func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) erro
model := c.GetModel(small.Provider, small.Model)
if model == nil {
small = defaultSmall
- // override the model type to small
- err := c.UpdatePreferredModel(SelectedModelTypeSmall, small)
- if err != nil {
+ c.Models[SelectedModelTypeSmall] = small
+ if err := c.setConfigField(fmt.Sprintf("models.%s", SelectedModelTypeSmall), small); err != nil {
return fmt.Errorf("failed to update preferred small model: %w", err)
}
} else {
@@ -31,15 +31,27 @@ func readRecentModels(t *testing.T, path string) map[string]any {
return rm
}
-func TestRecordRecentModel_AddsAndPersists(t *testing.T) {
- t.Parallel()
-
+func newTestService(t *testing.T) (*Service, string) {
+ t.Helper()
dir := t.TempDir()
cfg := &Config{}
cfg.setDefaults(dir, "")
- cfg.dataConfigDir = filepath.Join(dir, "config.json")
+ storePath := filepath.Join(dir, "config.json")
+ svc := &Service{
+ cfg: cfg,
+ store: NewFileStore(storePath),
+ workingDir: dir,
+ }
+ return svc, storePath
+}
+
+func TestRecordRecentModel_AddsAndPersists(t *testing.T) {
+ t.Parallel()
- err := cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})
+ svc, storePath := newTestService(t)
+ cfg := svc.cfg
+
+ err := svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})
require.NoError(t, err)
// in-memory state
@@ -48,7 +60,7 @@ func TestRecordRecentModel_AddsAndPersists(t *testing.T) {
require.Equal(t, "gpt-4o", cfg.RecentModels[SelectedModelTypeLarge][0].Model)
// persisted state
- rm := readRecentModels(t, cfg.dataConfigDir)
+ rm := readRecentModels(t, storePath)
large, ok := rm[string(SelectedModelTypeLarge)].([]any)
require.True(t, ok)
require.Len(t, large, 1)
@@ -61,16 +73,14 @@ func TestRecordRecentModel_AddsAndPersists(t *testing.T) {
func TestRecordRecentModel_DedupeAndMoveToFront(t *testing.T) {
t.Parallel()
- dir := t.TempDir()
- cfg := &Config{}
- cfg.setDefaults(dir, "")
- cfg.dataConfigDir = filepath.Join(dir, "config.json")
+ svc, _ := newTestService(t)
+ cfg := svc.cfg
// Add two entries
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}))
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "anthropic", Model: "claude"}))
+ require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}))
+ require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "anthropic", Model: "claude"}))
// Re-add first; should move to front and not duplicate
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}))
+ require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}))
got := cfg.RecentModels[SelectedModelTypeLarge]
require.Len(t, got, 2)
@@ -81,10 +91,8 @@ func TestRecordRecentModel_DedupeAndMoveToFront(t *testing.T) {
func TestRecordRecentModel_TrimsToMax(t *testing.T) {
t.Parallel()
- dir := t.TempDir()
- cfg := &Config{}
- cfg.setDefaults(dir, "")
- cfg.dataConfigDir = filepath.Join(dir, "config.json")
+ svc, storePath := newTestService(t)
+ cfg := svc.cfg
// Insert 6 unique models; max is 5
entries := []SelectedModel{
@@ -96,7 +104,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) {
{Provider: "p6", Model: "m6"},
}
for _, e := range entries {
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, e))
+ require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, e))
}
// in-memory state
@@ -110,7 +118,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) {
require.Equal(t, SelectedModel{Provider: "p2", Model: "m2"}, got[4])
// persisted state: verify trimmed to 5 and newest-first order
- rm := readRecentModels(t, cfg.dataConfigDir)
+ rm := readRecentModels(t, storePath)
large, ok := rm[string(SelectedModelTypeLarge)].([]any)
require.True(t, ok)
require.Len(t, large, 5)
@@ -126,15 +134,13 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) {
func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) {
t.Parallel()
- dir := t.TempDir()
- cfg := &Config{}
- cfg.setDefaults(dir, "")
- cfg.dataConfigDir = filepath.Join(dir, "config.json")
+ svc, storePath := newTestService(t)
+ cfg := svc.cfg
// Missing provider
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "", Model: "m"}))
+ require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "", Model: "m"}))
// Missing model
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "p", Model: ""}))
+ require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "p", Model: ""}))
_, ok := cfg.RecentModels[SelectedModelTypeLarge]
// Map may be initialized, but should have no entries
@@ -142,8 +148,8 @@ func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) {
require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 0)
}
// No file should be written (stat via fs.FS)
- baseDir := filepath.Dir(cfg.dataConfigDir)
- fileName := filepath.Base(cfg.dataConfigDir)
+ baseDir := filepath.Dir(storePath)
+ fileName := filepath.Base(storePath)
_, err := fs.Stat(os.DirFS(baseDir), fileName)
require.True(t, os.IsNotExist(err))
}
@@ -151,16 +157,13 @@ func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) {
func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) {
t.Parallel()
- dir := t.TempDir()
- cfg := &Config{}
- cfg.setDefaults(dir, "")
- cfg.dataConfigDir = filepath.Join(dir, "config.json")
+ svc, storePath := newTestService(t)
entry := SelectedModel{Provider: "openai", Model: "gpt-4o"}
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, entry))
+ require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, entry))
- baseDir := filepath.Dir(cfg.dataConfigDir)
- fileName := filepath.Base(cfg.dataConfigDir)
+ baseDir := filepath.Dir(storePath)
+ fileName := filepath.Base(storePath)
before, err := fs.ReadFile(os.DirFS(baseDir), fileName)
require.NoError(t, err)
@@ -170,7 +173,7 @@ func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) {
beforeMod := stBefore.ModTime()
// Re-record same entry should be a no-op (no write)
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, entry))
+ require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, entry))
after, err := fs.ReadFile(os.DirFS(baseDir), fileName)
require.NoError(t, err)
@@ -185,20 +188,18 @@ func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) {
func TestUpdatePreferredModel_UpdatesRecents(t *testing.T) {
t.Parallel()
- dir := t.TempDir()
- cfg := &Config{}
- cfg.setDefaults(dir, "")
- cfg.dataConfigDir = filepath.Join(dir, "config.json")
+ svc, storePath := newTestService(t)
+ cfg := svc.cfg
sel := SelectedModel{Provider: "openai", Model: "gpt-4o"}
- require.NoError(t, cfg.UpdatePreferredModel(SelectedModelTypeSmall, sel))
+ require.NoError(t, svc.UpdatePreferredModel(SelectedModelTypeSmall, sel))
// in-memory
require.Equal(t, sel, cfg.Models[SelectedModelTypeSmall])
require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1)
// persisted (read via fs.FS)
- rm := readRecentModels(t, cfg.dataConfigDir)
+ rm := readRecentModels(t, storePath)
small, ok := rm[string(SelectedModelTypeSmall)].([]any)
require.True(t, ok)
require.Len(t, small, 1)
@@ -207,17 +208,15 @@ func TestUpdatePreferredModel_UpdatesRecents(t *testing.T) {
func TestRecordRecentModel_TypeIsolation(t *testing.T) {
t.Parallel()
- dir := t.TempDir()
- cfg := &Config{}
- cfg.setDefaults(dir, "")
- cfg.dataConfigDir = filepath.Join(dir, "config.json")
+ svc, storePath := newTestService(t)
+ cfg := svc.cfg
// Add models to both large and small types
largeModel := SelectedModel{Provider: "openai", Model: "gpt-4o"}
smallModel := SelectedModel{Provider: "anthropic", Model: "claude"}
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, largeModel))
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeSmall, smallModel))
+ require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, largeModel))
+ require.NoError(t, svc.recordRecentModel(SelectedModelTypeSmall, smallModel))
// in-memory: verify types maintain separate histories
require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 1)
@@ -227,14 +226,14 @@ func TestRecordRecentModel_TypeIsolation(t *testing.T) {
// Add another to large, verify small unchanged
anotherLarge := SelectedModel{Provider: "google", Model: "gemini"}
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, anotherLarge))
+ require.NoError(t, svc.recordRecentModel(SelectedModelTypeLarge, anotherLarge))
require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 2)
require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1)
require.Equal(t, smallModel, cfg.RecentModels[SelectedModelTypeSmall][0])
// persisted state: verify both types exist with correct lengths and contents
- rm := readRecentModels(t, cfg.dataConfigDir)
+ rm := readRecentModels(t, storePath)
large, ok := rm[string(SelectedModelTypeLarge)].([]any)
require.True(t, ok)
@@ -1,6 +1,19 @@
package config
-import "charm.land/catwalk/pkg/catwalk"
+import (
+ "cmp"
+ "context"
+ "fmt"
+ "log/slog"
+ "slices"
+ "testing"
+
+ "charm.land/catwalk/pkg/catwalk"
+ hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
+ "github.com/charmbracelet/crush/internal/oauth"
+ "github.com/charmbracelet/crush/internal/oauth/copilot"
+ "github.com/charmbracelet/crush/internal/oauth/hyper"
+)
// Service is the central access point for configuration. It wraps the
// raw Config data and owns all internal state that was previously held
@@ -20,3 +33,242 @@ type Service struct {
func (s *Service) Config() *Config {
return s.cfg
}
+
+// HasConfigField returns true if the given dotted key path exists in
+// the persisted config data.
+func (s *Service) HasConfigField(key string) bool {
+ return HasField(s.store, key)
+}
+
+// SetConfigField sets a value at the given dotted key path and
+// persists it.
+func (s *Service) SetConfigField(key string, value any) error {
+ return SetField(s.store, key, value)
+}
+
+// RemoveConfigField deletes a value at the given dotted key path and
+// persists it.
+func (s *Service) RemoveConfigField(key string) error {
+ return RemoveField(s.store, key)
+}
+
+// SetCompactMode toggles compact mode and persists the change.
+func (s *Service) SetCompactMode(enabled bool) error {
+ cfg := s.cfg
+ if cfg.Options == nil {
+ cfg.Options = &Options{}
+ }
+ if cfg.Options.TUI == nil {
+ cfg.Options.TUI = &TUIOptions{}
+ }
+ cfg.Options.TUI.CompactMode = enabled
+ return s.SetConfigField("options.tui.compact_mode", enabled)
+}
+
+// UpdatePreferredModel updates the selected model for the given type
+// and persists the change, also recording it in the recent models
+// list.
+func (s *Service) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
+ s.cfg.Models[modelType] = model
+ if err := s.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
+ return fmt.Errorf("failed to update preferred model: %w", err)
+ }
+ if err := s.recordRecentModel(modelType, model); err != nil {
+ return err
+ }
+ return nil
+}
+
+const maxRecentModelsPerType = 5
+
+func (s *Service) recordRecentModel(modelType SelectedModelType, model SelectedModel) error {
+ if model.Provider == "" || model.Model == "" {
+ return nil
+ }
+
+ cfg := s.cfg
+ if cfg.RecentModels == nil {
+ cfg.RecentModels = make(map[SelectedModelType][]SelectedModel)
+ }
+
+ eq := func(a, b SelectedModel) bool {
+ return a.Provider == b.Provider && a.Model == b.Model
+ }
+
+ entry := SelectedModel{
+ Provider: model.Provider,
+ Model: model.Model,
+ }
+
+ current := cfg.RecentModels[modelType]
+ withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
+ return eq(existing, entry)
+ })
+
+ updated := append([]SelectedModel{entry}, withoutCurrent...)
+ if len(updated) > maxRecentModelsPerType {
+ updated = updated[:maxRecentModelsPerType]
+ }
+
+ if slices.EqualFunc(current, updated, eq) {
+ return nil
+ }
+
+ cfg.RecentModels[modelType] = updated
+
+ if err := s.SetConfigField(fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
+ return fmt.Errorf("failed to persist recent models: %w", err)
+ }
+
+ return nil
+}
+
+// RefreshOAuthToken refreshes the OAuth token for the given provider.
+func (s *Service) RefreshOAuthToken(ctx context.Context, providerID string) error {
+ cfg := s.cfg
+ providerConfig, exists := cfg.Providers.Get(providerID)
+ if !exists {
+ return fmt.Errorf("provider %s not found", providerID)
+ }
+
+ if providerConfig.OAuthToken == nil {
+ return fmt.Errorf("provider %s does not have an OAuth token", providerID)
+ }
+
+ var newToken *oauth.Token
+ var refreshErr error
+ switch providerID {
+ case string(catwalk.InferenceProviderCopilot):
+ newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
+ case hyperp.Name:
+ newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
+ default:
+ return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
+ }
+ if refreshErr != nil {
+ return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
+ }
+
+ slog.Info("Successfully refreshed OAuth token", "provider", providerID)
+ providerConfig.OAuthToken = newToken
+ providerConfig.APIKey = newToken.AccessToken
+
+ switch providerID {
+ case string(catwalk.InferenceProviderCopilot):
+ providerConfig.SetupGitHubCopilot()
+ }
+
+ cfg.Providers.Set(providerID, providerConfig)
+
+ if err := cmp.Or(
+ s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
+ s.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), newToken),
+ ); err != nil {
+ return fmt.Errorf("failed to persist refreshed token: %w", err)
+ }
+
+ return nil
+}
+
+// SetProviderAPIKey sets the API key (string or *oauth.Token) for a
+// provider and persists the change.
+func (s *Service) SetProviderAPIKey(providerID string, apiKey any) error {
+ cfg := s.cfg
+ var providerConfig ProviderConfig
+ var exists bool
+ var setKeyOrToken func()
+
+ switch v := apiKey.(type) {
+ case string:
+ if err := s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
+ return fmt.Errorf("failed to save api key to config file: %w", err)
+ }
+ setKeyOrToken = func() { providerConfig.APIKey = v }
+ case *oauth.Token:
+ if err := cmp.Or(
+ s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
+ s.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), v),
+ ); err != nil {
+ return err
+ }
+ setKeyOrToken = func() {
+ providerConfig.APIKey = v.AccessToken
+ providerConfig.OAuthToken = v
+ switch providerID {
+ case string(catwalk.InferenceProviderCopilot):
+ providerConfig.SetupGitHubCopilot()
+ }
+ }
+ }
+
+ providerConfig, exists = cfg.Providers.Get(providerID)
+ if exists {
+ setKeyOrToken()
+ cfg.Providers.Set(providerID, providerConfig)
+ return nil
+ }
+
+ var foundProvider *catwalk.Provider
+ for _, p := range s.knownProviders {
+ if string(p.ID) == providerID {
+ foundProvider = &p
+ break
+ }
+ }
+
+ if foundProvider != nil {
+ providerConfig = ProviderConfig{
+ ID: providerID,
+ Name: foundProvider.Name,
+ BaseURL: foundProvider.APIEndpoint,
+ Type: foundProvider.Type,
+ Disable: false,
+ ExtraHeaders: make(map[string]string),
+ ExtraParams: make(map[string]string),
+ Models: foundProvider.Models,
+ }
+ setKeyOrToken()
+ } else {
+ return fmt.Errorf("provider with ID %s not found in known providers", providerID)
+ }
+ cfg.Providers.Set(providerID, providerConfig)
+ return nil
+}
+
+// ImportCopilot imports an existing GitHub Copilot token from disk if
+// available and not already configured.
+func (s *Service) ImportCopilot() (*oauth.Token, bool) {
+ if testing.Testing() {
+ return nil, false
+ }
+
+ if s.HasConfigField("providers.copilot.api_key") || s.HasConfigField("providers.copilot.oauth") {
+ return nil, false
+ }
+
+ diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
+ if !hasDiskToken {
+ return nil, false
+ }
+
+ slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
+ token, err := copilot.RefreshToken(context.TODO(), diskToken)
+ if err != nil {
+ slog.Error("Unable to import GitHub Copilot token", "error", err)
+ return nil, false
+ }
+
+ if err := s.SetProviderAPIKey(string(catwalk.InferenceProviderCopilot), token); err != nil {
+ return token, false
+ }
+
+ if err := cmp.Or(
+ s.SetConfigField("providers.copilot.api_key", token.AccessToken),
+ s.SetConfigField("providers.copilot.oauth", token),
+ ); err != nil {
+ slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
+ }
+
+ slog.Info("GitHub Copilot successfully imported")
+ return token, true
+}
@@ -31,6 +31,11 @@ func (c *Common) Config() *config.Config {
return c.App.Config()
}
+// ConfigService returns the config service associated with this [Common] instance.
+func (c *Common) ConfigService() *config.Service {
+ return c.App.ConfigService()
+}
+
// DefaultCommon returns the default common UI configurations.
func DefaultCommon(app *app.App) *Common {
s := styles.DefaultStyles()
@@ -312,9 +312,7 @@ func (m *APIKeyInput) verifyAPIKey() tea.Msg {
}
func (m *APIKeyInput) saveKeyAndContinue() Action {
- cfg := m.com.Config()
-
- err := cfg.SetProviderAPIKey(string(m.provider.ID), m.input.Value())
+ err := m.com.ConfigService().SetProviderAPIKey(string(m.provider.ID), m.input.Value())
if err != nil {
return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))}
}
@@ -482,7 +482,7 @@ func (m *Models) setProviderItems() error {
if len(validRecentItems) != len(recentItems) {
// FIXME: Does this need to be here? Is it mutating the config during a read?
- if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
+ if err := m.com.ConfigService().SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
return fmt.Errorf("failed to update recent models: %w", err)
}
}
@@ -373,9 +373,7 @@ func (d *OAuth) copyCodeAndOpenURL() tea.Cmd {
}
func (m *OAuth) saveKeyAndContinue() Action {
- cfg := m.com.Config()
-
- err := cfg.SetProviderAPIKey(string(m.provider.ID), m.token)
+ err := m.com.ConfigService().SetProviderAPIKey(string(m.provider.ID), m.token)
if err != nil {
return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))}
}
@@ -1208,7 +1208,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
currentModel := cfg.Models[agentCfg.Model]
currentModel.Think = !currentModel.Think
- if err := cfg.UpdatePreferredModel(agentCfg.Model, currentModel); err != nil {
+ if err := m.com.ConfigService().UpdatePreferredModel(agentCfg.Model, currentModel); err != nil {
return util.ReportError(err)()
}
m.com.App.UpdateAgentModel(context.TODO())
@@ -1249,7 +1249,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
// Attempt to import GitHub Copilot tokens from VSCode if available.
if isCopilot && !isConfigured() {
- m.com.Config().ImportCopilot()
+ m.com.ConfigService().ImportCopilot()
}
if !isConfigured() {
@@ -1260,12 +1260,12 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
break
}
- if err := cfg.UpdatePreferredModel(msg.ModelType, msg.Model); err != nil {
+ if err := m.com.ConfigService().UpdatePreferredModel(msg.ModelType, msg.Model); err != nil {
cmds = append(cmds, util.ReportError(err))
} else if _, ok := cfg.Models[config.SelectedModelTypeSmall]; !ok {
// Ensure small model is set is unset.
smallModel := m.com.App.GetDefaultSmallModel(providerID)
- if err := cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallModel); err != nil {
+ if err := m.com.ConfigService().UpdatePreferredModel(config.SelectedModelTypeSmall, smallModel); err != nil {
cmds = append(cmds, util.ReportError(err))
}
}
@@ -1311,7 +1311,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
currentModel := cfg.Models[agentCfg.Model]
currentModel.ReasoningEffort = msg.Effort
- if err := cfg.UpdatePreferredModel(agentCfg.Model, currentModel); err != nil {
+ if err := m.com.ConfigService().UpdatePreferredModel(agentCfg.Model, currentModel); err != nil {
cmds = append(cmds, util.ReportError(err))
break
}
@@ -2157,7 +2157,7 @@ func (m *UI) FullHelp() [][]key.Binding {
func (m *UI) toggleCompactMode() tea.Cmd {
m.forceCompactMode = !m.forceCompactMode
- err := m.com.Config().SetCompactMode(m.forceCompactMode)
+ err := m.com.ConfigService().SetCompactMode(m.forceCompactMode)
if err != nil {
return util.ReportError(err)
}