From 99bc5ce850c1da02207ceba39f92afee6cba73fa Mon Sep 17 00:00:00 2001 From: Andrey Nering Date: Mon, 4 May 2026 17:10:10 -0300 Subject: [PATCH] fix(config): check config file for newer token before OAuth refresh MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When refreshing an OAuth token (e.g. for Hyper), check the config file on disk first to see if another Crush session already refreshed it. If the disk token differs from the in-memory one, use the disk token and skip the external refresh request. Prevents unnecessary token churn and 401s when multiple Crush sessions are running. 💘 Generated with Crush Assisted-by: Kimi K2.6 via Crush --- internal/config/store.go | 69 +++++++++++-- internal/config/store_test.go | 182 ++++++++++++++++++++++++++++++++++ 2 files changed, 244 insertions(+), 7 deletions(-) diff --git a/internal/config/store.go b/internal/config/store.go index 7c3d0f80cdebf545d2601ea78f490eddef3067e7..c44fceba0ac87133a787b1351de213802a4e17cc 100644 --- a/internal/config/store.go +++ b/internal/config/store.go @@ -3,6 +3,7 @@ package config import ( "cmp" "context" + "encoding/json" "fmt" "log/slog" "os" @@ -288,6 +289,9 @@ func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey a } // RefreshOAuthToken refreshes the OAuth token for the given provider. +// Before making an external refresh request, it checks the config file on +// disk to see if another Crush session has already refreshed the token. If +// a newer token is found, it is used instead of refreshing. func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error { providerConfig, exists := s.config.Providers.Get(providerID) if !exists { @@ -298,13 +302,29 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid return fmt.Errorf("provider %s does not have an OAuth token", providerID) } - var newToken *oauth.Token + // Check if another session refreshed the token recently by reading + // the current token from the config file on disk. + newToken, err := s.loadTokenFromDisk(scope, providerID) + if err != nil { + slog.Warn("Failed to read token from config file, proceeding with refresh", "provider", providerID, "error", err) + } else if newToken != nil && newToken.AccessToken != providerConfig.OAuthToken.AccessToken { + slog.Info("Using token refreshed by another session", "provider", providerID) + providerConfig.OAuthToken = newToken + providerConfig.APIKey = newToken.AccessToken + if providerID == string(catwalk.InferenceProviderCopilot) { + providerConfig.SetupGitHubCopilot() + } + s.config.Providers.Set(providerID, providerConfig) + return nil + } + + var refreshedToken *oauth.Token var refreshErr error switch providerID { case string(catwalk.InferenceProviderCopilot): - newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) + refreshedToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) case hyperp.Name: - newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken) + refreshedToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken) default: return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) } @@ -313,8 +333,8 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid } slog.Info("Successfully refreshed OAuth token", "provider", providerID) - providerConfig.OAuthToken = newToken - providerConfig.APIKey = newToken.AccessToken + providerConfig.OAuthToken = refreshedToken + providerConfig.APIKey = refreshedToken.AccessToken switch providerID { case string(catwalk.InferenceProviderCopilot): @@ -324,8 +344,8 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid s.config.Providers.Set(providerID, providerConfig) if err := cmp.Or( - s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken), - s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken), + s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), refreshedToken.AccessToken), + s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), refreshedToken), ); err != nil { return fmt.Errorf("failed to persist refreshed token: %w", err) } @@ -333,6 +353,41 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid return nil } +// loadTokenFromDisk reads the OAuth token for the given provider from the +// config file on disk. Returns nil if the token is not found or matches the +// current in-memory token. +func (s *ConfigStore) loadTokenFromDisk(scope Scope, providerID string) (*oauth.Token, error) { + path, err := s.configPath(scope) + if err != nil { + return nil, err + } + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + + oauthKey := fmt.Sprintf("providers.%s.oauth", providerID) + oauthResult := gjson.Get(string(data), oauthKey) + if !oauthResult.Exists() { + return nil, nil + } + + var token oauth.Token + if err := json.Unmarshal([]byte(oauthResult.Raw), &token); err != nil { + return nil, err + } + + if token.AccessToken == "" { + return nil, nil + } + + return &token, nil +} + // recordRecentModel records a model in the recent models list. func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error { if model.Provider == "" || model.Model == "" { diff --git a/internal/config/store_test.go b/internal/config/store_test.go index 46d51440870c0361b6e0246881d4522fa044363c..3f9441ea02315a6ac75012d956acbcab217d1c91 100644 --- a/internal/config/store_test.go +++ b/internal/config/store_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/oauth" "github.com/stretchr/testify/require" ) @@ -510,3 +512,183 @@ func TestAutoReloadDisabledDuringReload(t *testing.T) { // Verify reload completed successfully require.False(t, store.autoReloadDisabled, "autoReloadDisabled should be false after ReloadFromDisk") } + +func TestLoadTokenFromDisk_ReturnsNewerToken(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + configPath := filepath.Join(dir, "crush.json") + + // Create config file with a newer token on disk + configContent := `{ + "providers": { + "hyper": { + "oauth": { + "access_token": "newer-token-from-disk", + "refresh_token": "refresh-abc", + "expires_in": 3600, + "expires_at": 9999999999 + } + } + } + }` + require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600)) + + store := &ConfigStore{ + config: &Config{}, + globalDataPath: configPath, + } + + token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper") + require.NoError(t, err) + require.NotNil(t, token) + require.Equal(t, "newer-token-from-disk", token.AccessToken) + require.Equal(t, "refresh-abc", token.RefreshToken) + require.Equal(t, 3600, token.ExpiresIn) + require.Equal(t, int64(9999999999), token.ExpiresAt) +} + +func TestLoadTokenFromDisk_ReturnsNilWhenSameToken(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + configPath := filepath.Join(dir, "crush.json") + + // Create config file with the same token + configContent := `{ + "providers": { + "hyper": { + "oauth": { + "access_token": "same-token", + "refresh_token": "refresh-abc", + "expires_in": 3600, + "expires_at": 9999999999 + } + } + } + }` + require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600)) + + store := &ConfigStore{ + config: &Config{}, + globalDataPath: configPath, + } + + token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper") + require.NoError(t, err) + require.NotNil(t, token) + require.Equal(t, "same-token", token.AccessToken) +} + +func TestLoadTokenFromDisk_ReturnsNilWhenFileMissing(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + configPath := filepath.Join(dir, "nonexistent.json") + + store := &ConfigStore{ + config: &Config{}, + globalDataPath: configPath, + } + + token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper") + require.NoError(t, err) + require.Nil(t, token) +} + +func TestLoadTokenFromDisk_ReturnsNilWhenProviderMissing(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + configPath := filepath.Join(dir, "crush.json") + + // Create config file without the hyper provider + configContent := `{"providers": {"openai": {"api_key": "test-key"}}}` + require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600)) + + store := &ConfigStore{ + config: &Config{}, + globalDataPath: configPath, + } + + token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper") + require.NoError(t, err) + require.Nil(t, token) +} + +func TestLoadTokenFromDisk_ReturnsNilWhenOAuthMissing(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + configPath := filepath.Join(dir, "crush.json") + + // Create config file with provider but no OAuth token + configContent := `{"providers": {"hyper": {"api_key": "test-key"}}}` + require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600)) + + store := &ConfigStore{ + config: &Config{}, + globalDataPath: configPath, + } + + token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper") + require.NoError(t, err) + require.Nil(t, token) +} + +func TestRefreshOAuthToken_UsesDiskTokenWhenDifferent(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + configPath := filepath.Join(dir, "crush.json") + + // Create config file with a newer token on disk + configContent := `{ + "providers": { + "hyper": { + "api_key": "newer-access-token", + "oauth": { + "access_token": "newer-access-token", + "refresh_token": "refresh-abc", + "expires_in": 3600, + "expires_at": 9999999999 + } + } + } + }` + require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600)) + + // Set up store with an older in-memory token + oldToken := &oauth.Token{ + AccessToken: "older-access-token", + RefreshToken: "refresh-abc", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(-time.Hour).Unix(), // Expired + } + + providers := csync.NewMap[string, ProviderConfig]() + providers.Set("hyper", ProviderConfig{ + ID: "hyper", + Name: "Hyper", + APIKey: oldToken.AccessToken, + OAuthToken: oldToken, + }) + + store := &ConfigStore{ + config: &Config{ + Providers: providers, + }, + globalDataPath: configPath, + } + + // Refresh should use the disk token without making an external call + err := store.RefreshOAuthToken(context.Background(), ScopeGlobal, "hyper") + require.NoError(t, err) + + // Verify the in-memory token was updated to the disk token + updatedConfig, ok := store.config.Providers.Get("hyper") + require.True(t, ok) + require.Equal(t, "newer-access-token", updatedConfig.APIKey) + require.Equal(t, "newer-access-token", updatedConfig.OAuthToken.AccessToken) + require.Equal(t, "refresh-abc", updatedConfig.OAuthToken.RefreshToken) +}