fix(config): check config file for newer token before OAuth refresh

Andrey Nering created

When refreshing an OAuth token (e.g. for Hyper), check the config file
on disk first to see if another Crush session already refreshed it.
If the disk token differs from the in-memory one, use the disk token
and skip the external refresh request. Prevents unnecessary token
churn and 401s when multiple Crush sessions are running.

💘 Generated with Crush

Assisted-by: Kimi K2.6 via Crush <crush@charm.land>

Change summary

internal/config/store.go      |  69 ++++++++++++-
internal/config/store_test.go | 182 +++++++++++++++++++++++++++++++++++++
2 files changed, 244 insertions(+), 7 deletions(-)

Detailed changes

internal/config/store.go 🔗

@@ -3,6 +3,7 @@ package config
 import (
 	"cmp"
 	"context"
+	"encoding/json"
 	"fmt"
 	"log/slog"
 	"os"
@@ -288,6 +289,9 @@ func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey a
 }
 
 // RefreshOAuthToken refreshes the OAuth token for the given provider.
+// Before making an external refresh request, it checks the config file on
+// disk to see if another Crush session has already refreshed the token. If
+// a newer token is found, it is used instead of refreshing.
 func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
 	providerConfig, exists := s.config.Providers.Get(providerID)
 	if !exists {
@@ -298,13 +302,29 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
 		return fmt.Errorf("provider %s does not have an OAuth token", providerID)
 	}
 
-	var newToken *oauth.Token
+	// Check if another session refreshed the token recently by reading
+	// the current token from the config file on disk.
+	newToken, err := s.loadTokenFromDisk(scope, providerID)
+	if err != nil {
+		slog.Warn("Failed to read token from config file, proceeding with refresh", "provider", providerID, "error", err)
+	} else if newToken != nil && newToken.AccessToken != providerConfig.OAuthToken.AccessToken {
+		slog.Info("Using token refreshed by another session", "provider", providerID)
+		providerConfig.OAuthToken = newToken
+		providerConfig.APIKey = newToken.AccessToken
+		if providerID == string(catwalk.InferenceProviderCopilot) {
+			providerConfig.SetupGitHubCopilot()
+		}
+		s.config.Providers.Set(providerID, providerConfig)
+		return nil
+	}
+
+	var refreshedToken *oauth.Token
 	var refreshErr error
 	switch providerID {
 	case string(catwalk.InferenceProviderCopilot):
-		newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
+		refreshedToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
 	case hyperp.Name:
-		newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
+		refreshedToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
 	default:
 		return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
 	}
@@ -313,8 +333,8 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
 	}
 
 	slog.Info("Successfully refreshed OAuth token", "provider", providerID)
-	providerConfig.OAuthToken = newToken
-	providerConfig.APIKey = newToken.AccessToken
+	providerConfig.OAuthToken = refreshedToken
+	providerConfig.APIKey = refreshedToken.AccessToken
 
 	switch providerID {
 	case string(catwalk.InferenceProviderCopilot):
@@ -324,8 +344,8 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
 	s.config.Providers.Set(providerID, providerConfig)
 
 	if err := cmp.Or(
-		s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
-		s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
+		s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), refreshedToken.AccessToken),
+		s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), refreshedToken),
 	); err != nil {
 		return fmt.Errorf("failed to persist refreshed token: %w", err)
 	}
@@ -333,6 +353,41 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
 	return nil
 }
 
+// loadTokenFromDisk reads the OAuth token for the given provider from the
+// config file on disk. Returns nil if the token is not found or matches the
+// current in-memory token.
+func (s *ConfigStore) loadTokenFromDisk(scope Scope, providerID string) (*oauth.Token, error) {
+	path, err := s.configPath(scope)
+	if err != nil {
+		return nil, err
+	}
+
+	data, err := os.ReadFile(path)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil, nil
+		}
+		return nil, err
+	}
+
+	oauthKey := fmt.Sprintf("providers.%s.oauth", providerID)
+	oauthResult := gjson.Get(string(data), oauthKey)
+	if !oauthResult.Exists() {
+		return nil, nil
+	}
+
+	var token oauth.Token
+	if err := json.Unmarshal([]byte(oauthResult.Raw), &token); err != nil {
+		return nil, err
+	}
+
+	if token.AccessToken == "" {
+		return nil, nil
+	}
+
+	return &token, nil
+}
+
 // recordRecentModel records a model in the recent models list.
 func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
 	if model.Provider == "" || model.Model == "" {

internal/config/store_test.go 🔗

@@ -8,6 +8,8 @@ import (
 	"testing"
 	"time"
 
+	"github.com/charmbracelet/crush/internal/csync"
+	"github.com/charmbracelet/crush/internal/oauth"
 	"github.com/stretchr/testify/require"
 )
 
@@ -510,3 +512,183 @@ func TestAutoReloadDisabledDuringReload(t *testing.T) {
 	// Verify reload completed successfully
 	require.False(t, store.autoReloadDisabled, "autoReloadDisabled should be false after ReloadFromDisk")
 }
+
+func TestLoadTokenFromDisk_ReturnsNewerToken(t *testing.T) {
+	t.Parallel()
+
+	dir := t.TempDir()
+	configPath := filepath.Join(dir, "crush.json")
+
+	// Create config file with a newer token on disk
+	configContent := `{
+		"providers": {
+			"hyper": {
+				"oauth": {
+					"access_token": "newer-token-from-disk",
+					"refresh_token": "refresh-abc",
+					"expires_in": 3600,
+					"expires_at": 9999999999
+				}
+			}
+		}
+	}`
+	require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600))
+
+	store := &ConfigStore{
+		config:         &Config{},
+		globalDataPath: configPath,
+	}
+
+	token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper")
+	require.NoError(t, err)
+	require.NotNil(t, token)
+	require.Equal(t, "newer-token-from-disk", token.AccessToken)
+	require.Equal(t, "refresh-abc", token.RefreshToken)
+	require.Equal(t, 3600, token.ExpiresIn)
+	require.Equal(t, int64(9999999999), token.ExpiresAt)
+}
+
+func TestLoadTokenFromDisk_ReturnsNilWhenSameToken(t *testing.T) {
+	t.Parallel()
+
+	dir := t.TempDir()
+	configPath := filepath.Join(dir, "crush.json")
+
+	// Create config file with the same token
+	configContent := `{
+		"providers": {
+			"hyper": {
+				"oauth": {
+					"access_token": "same-token",
+					"refresh_token": "refresh-abc",
+					"expires_in": 3600,
+					"expires_at": 9999999999
+				}
+			}
+		}
+	}`
+	require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600))
+
+	store := &ConfigStore{
+		config:         &Config{},
+		globalDataPath: configPath,
+	}
+
+	token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper")
+	require.NoError(t, err)
+	require.NotNil(t, token)
+	require.Equal(t, "same-token", token.AccessToken)
+}
+
+func TestLoadTokenFromDisk_ReturnsNilWhenFileMissing(t *testing.T) {
+	t.Parallel()
+
+	dir := t.TempDir()
+	configPath := filepath.Join(dir, "nonexistent.json")
+
+	store := &ConfigStore{
+		config:         &Config{},
+		globalDataPath: configPath,
+	}
+
+	token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper")
+	require.NoError(t, err)
+	require.Nil(t, token)
+}
+
+func TestLoadTokenFromDisk_ReturnsNilWhenProviderMissing(t *testing.T) {
+	t.Parallel()
+
+	dir := t.TempDir()
+	configPath := filepath.Join(dir, "crush.json")
+
+	// Create config file without the hyper provider
+	configContent := `{"providers": {"openai": {"api_key": "test-key"}}}`
+	require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600))
+
+	store := &ConfigStore{
+		config:         &Config{},
+		globalDataPath: configPath,
+	}
+
+	token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper")
+	require.NoError(t, err)
+	require.Nil(t, token)
+}
+
+func TestLoadTokenFromDisk_ReturnsNilWhenOAuthMissing(t *testing.T) {
+	t.Parallel()
+
+	dir := t.TempDir()
+	configPath := filepath.Join(dir, "crush.json")
+
+	// Create config file with provider but no OAuth token
+	configContent := `{"providers": {"hyper": {"api_key": "test-key"}}}`
+	require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600))
+
+	store := &ConfigStore{
+		config:         &Config{},
+		globalDataPath: configPath,
+	}
+
+	token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper")
+	require.NoError(t, err)
+	require.Nil(t, token)
+}
+
+func TestRefreshOAuthToken_UsesDiskTokenWhenDifferent(t *testing.T) {
+	t.Parallel()
+
+	dir := t.TempDir()
+	configPath := filepath.Join(dir, "crush.json")
+
+	// Create config file with a newer token on disk
+	configContent := `{
+		"providers": {
+			"hyper": {
+				"api_key": "newer-access-token",
+				"oauth": {
+					"access_token": "newer-access-token",
+					"refresh_token": "refresh-abc",
+					"expires_in": 3600,
+					"expires_at": 9999999999
+				}
+			}
+		}
+	}`
+	require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600))
+
+	// Set up store with an older in-memory token
+	oldToken := &oauth.Token{
+		AccessToken:  "older-access-token",
+		RefreshToken: "refresh-abc",
+		ExpiresIn:    3600,
+		ExpiresAt:    time.Now().Add(-time.Hour).Unix(), // Expired
+	}
+
+	providers := csync.NewMap[string, ProviderConfig]()
+	providers.Set("hyper", ProviderConfig{
+		ID:         "hyper",
+		Name:       "Hyper",
+		APIKey:     oldToken.AccessToken,
+		OAuthToken: oldToken,
+	})
+
+	store := &ConfigStore{
+		config: &Config{
+			Providers: providers,
+		},
+		globalDataPath: configPath,
+	}
+
+	// Refresh should use the disk token without making an external call
+	err := store.RefreshOAuthToken(context.Background(), ScopeGlobal, "hyper")
+	require.NoError(t, err)
+
+	// Verify the in-memory token was updated to the disk token
+	updatedConfig, ok := store.config.Providers.Get("hyper")
+	require.True(t, ok)
+	require.Equal(t, "newer-access-token", updatedConfig.APIKey)
+	require.Equal(t, "newer-access-token", updatedConfig.OAuthToken.AccessToken)
+	require.Equal(t, "refresh-abc", updatedConfig.OAuthToken.RefreshToken)
+}