@@ -3,6 +3,7 @@ package config
import (
"cmp"
"context"
+ "encoding/json"
"fmt"
"log/slog"
"os"
@@ -288,6 +289,9 @@ func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey a
}
// RefreshOAuthToken refreshes the OAuth token for the given provider.
+// Before making an external refresh request, it checks the config file on
+// disk to see if another Crush session has already refreshed the token. If
+// a newer token is found, it is used instead of refreshing.
func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
providerConfig, exists := s.config.Providers.Get(providerID)
if !exists {
@@ -298,13 +302,29 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
return fmt.Errorf("provider %s does not have an OAuth token", providerID)
}
- var newToken *oauth.Token
+ // Check if another session refreshed the token recently by reading
+ // the current token from the config file on disk.
+ newToken, err := s.loadTokenFromDisk(scope, providerID)
+ if err != nil {
+ slog.Warn("Failed to read token from config file, proceeding with refresh", "provider", providerID, "error", err)
+ } else if newToken != nil && newToken.AccessToken != providerConfig.OAuthToken.AccessToken {
+ slog.Info("Using token refreshed by another session", "provider", providerID)
+ providerConfig.OAuthToken = newToken
+ providerConfig.APIKey = newToken.AccessToken
+ if providerID == string(catwalk.InferenceProviderCopilot) {
+ providerConfig.SetupGitHubCopilot()
+ }
+ s.config.Providers.Set(providerID, providerConfig)
+ return nil
+ }
+
+ var refreshedToken *oauth.Token
var refreshErr error
switch providerID {
case string(catwalk.InferenceProviderCopilot):
- newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
+ refreshedToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
case hyperp.Name:
- newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
+ refreshedToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
default:
return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
}
@@ -313,8 +333,8 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
}
slog.Info("Successfully refreshed OAuth token", "provider", providerID)
- providerConfig.OAuthToken = newToken
- providerConfig.APIKey = newToken.AccessToken
+ providerConfig.OAuthToken = refreshedToken
+ providerConfig.APIKey = refreshedToken.AccessToken
switch providerID {
case string(catwalk.InferenceProviderCopilot):
@@ -324,8 +344,8 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
s.config.Providers.Set(providerID, providerConfig)
if err := cmp.Or(
- s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
- s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
+ s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), refreshedToken.AccessToken),
+ s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), refreshedToken),
); err != nil {
return fmt.Errorf("failed to persist refreshed token: %w", err)
}
@@ -333,6 +353,41 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
return nil
}
+// loadTokenFromDisk reads the OAuth token for the given provider from the
+// config file on disk. Returns nil if the token is not found or matches the
+// current in-memory token.
+func (s *ConfigStore) loadTokenFromDisk(scope Scope, providerID string) (*oauth.Token, error) {
+ path, err := s.configPath(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ data, err := os.ReadFile(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+ return nil, err
+ }
+
+ oauthKey := fmt.Sprintf("providers.%s.oauth", providerID)
+ oauthResult := gjson.Get(string(data), oauthKey)
+ if !oauthResult.Exists() {
+ return nil, nil
+ }
+
+ var token oauth.Token
+ if err := json.Unmarshal([]byte(oauthResult.Raw), &token); err != nil {
+ return nil, err
+ }
+
+ if token.AccessToken == "" {
+ return nil, nil
+ }
+
+ return &token, nil
+}
+
// recordRecentModel records a model in the recent models list.
func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
if model.Provider == "" || model.Model == "" {
@@ -8,6 +8,8 @@ import (
"testing"
"time"
+ "github.com/charmbracelet/crush/internal/csync"
+ "github.com/charmbracelet/crush/internal/oauth"
"github.com/stretchr/testify/require"
)
@@ -510,3 +512,183 @@ func TestAutoReloadDisabledDuringReload(t *testing.T) {
// Verify reload completed successfully
require.False(t, store.autoReloadDisabled, "autoReloadDisabled should be false after ReloadFromDisk")
}
+
+func TestLoadTokenFromDisk_ReturnsNewerToken(t *testing.T) {
+ t.Parallel()
+
+ dir := t.TempDir()
+ configPath := filepath.Join(dir, "crush.json")
+
+ // Create config file with a newer token on disk
+ configContent := `{
+ "providers": {
+ "hyper": {
+ "oauth": {
+ "access_token": "newer-token-from-disk",
+ "refresh_token": "refresh-abc",
+ "expires_in": 3600,
+ "expires_at": 9999999999
+ }
+ }
+ }
+ }`
+ require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600))
+
+ store := &ConfigStore{
+ config: &Config{},
+ globalDataPath: configPath,
+ }
+
+ token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper")
+ require.NoError(t, err)
+ require.NotNil(t, token)
+ require.Equal(t, "newer-token-from-disk", token.AccessToken)
+ require.Equal(t, "refresh-abc", token.RefreshToken)
+ require.Equal(t, 3600, token.ExpiresIn)
+ require.Equal(t, int64(9999999999), token.ExpiresAt)
+}
+
+func TestLoadTokenFromDisk_ReturnsNilWhenSameToken(t *testing.T) {
+ t.Parallel()
+
+ dir := t.TempDir()
+ configPath := filepath.Join(dir, "crush.json")
+
+ // Create config file with the same token
+ configContent := `{
+ "providers": {
+ "hyper": {
+ "oauth": {
+ "access_token": "same-token",
+ "refresh_token": "refresh-abc",
+ "expires_in": 3600,
+ "expires_at": 9999999999
+ }
+ }
+ }
+ }`
+ require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600))
+
+ store := &ConfigStore{
+ config: &Config{},
+ globalDataPath: configPath,
+ }
+
+ token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper")
+ require.NoError(t, err)
+ require.NotNil(t, token)
+ require.Equal(t, "same-token", token.AccessToken)
+}
+
+func TestLoadTokenFromDisk_ReturnsNilWhenFileMissing(t *testing.T) {
+ t.Parallel()
+
+ dir := t.TempDir()
+ configPath := filepath.Join(dir, "nonexistent.json")
+
+ store := &ConfigStore{
+ config: &Config{},
+ globalDataPath: configPath,
+ }
+
+ token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper")
+ require.NoError(t, err)
+ require.Nil(t, token)
+}
+
+func TestLoadTokenFromDisk_ReturnsNilWhenProviderMissing(t *testing.T) {
+ t.Parallel()
+
+ dir := t.TempDir()
+ configPath := filepath.Join(dir, "crush.json")
+
+ // Create config file without the hyper provider
+ configContent := `{"providers": {"openai": {"api_key": "test-key"}}}`
+ require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600))
+
+ store := &ConfigStore{
+ config: &Config{},
+ globalDataPath: configPath,
+ }
+
+ token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper")
+ require.NoError(t, err)
+ require.Nil(t, token)
+}
+
+func TestLoadTokenFromDisk_ReturnsNilWhenOAuthMissing(t *testing.T) {
+ t.Parallel()
+
+ dir := t.TempDir()
+ configPath := filepath.Join(dir, "crush.json")
+
+ // Create config file with provider but no OAuth token
+ configContent := `{"providers": {"hyper": {"api_key": "test-key"}}}`
+ require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600))
+
+ store := &ConfigStore{
+ config: &Config{},
+ globalDataPath: configPath,
+ }
+
+ token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper")
+ require.NoError(t, err)
+ require.Nil(t, token)
+}
+
+func TestRefreshOAuthToken_UsesDiskTokenWhenDifferent(t *testing.T) {
+ t.Parallel()
+
+ dir := t.TempDir()
+ configPath := filepath.Join(dir, "crush.json")
+
+ // Create config file with a newer token on disk
+ configContent := `{
+ "providers": {
+ "hyper": {
+ "api_key": "newer-access-token",
+ "oauth": {
+ "access_token": "newer-access-token",
+ "refresh_token": "refresh-abc",
+ "expires_in": 3600,
+ "expires_at": 9999999999
+ }
+ }
+ }
+ }`
+ require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600))
+
+ // Set up store with an older in-memory token
+ oldToken := &oauth.Token{
+ AccessToken: "older-access-token",
+ RefreshToken: "refresh-abc",
+ ExpiresIn: 3600,
+ ExpiresAt: time.Now().Add(-time.Hour).Unix(), // Expired
+ }
+
+ providers := csync.NewMap[string, ProviderConfig]()
+ providers.Set("hyper", ProviderConfig{
+ ID: "hyper",
+ Name: "Hyper",
+ APIKey: oldToken.AccessToken,
+ OAuthToken: oldToken,
+ })
+
+ store := &ConfigStore{
+ config: &Config{
+ Providers: providers,
+ },
+ globalDataPath: configPath,
+ }
+
+ // Refresh should use the disk token without making an external call
+ err := store.RefreshOAuthToken(context.Background(), ScopeGlobal, "hyper")
+ require.NoError(t, err)
+
+ // Verify the in-memory token was updated to the disk token
+ updatedConfig, ok := store.config.Providers.Get("hyper")
+ require.True(t, ok)
+ require.Equal(t, "newer-access-token", updatedConfig.APIKey)
+ require.Equal(t, "newer-access-token", updatedConfig.OAuthToken.AccessToken)
+ require.Equal(t, "refresh-abc", updatedConfig.OAuthToken.RefreshToken)
+}