diff --git a/internal/config/store.go b/internal/config/store.go index 7c3d0f80cdebf545d2601ea78f490eddef3067e7..c44fceba0ac87133a787b1351de213802a4e17cc 100644 --- a/internal/config/store.go +++ b/internal/config/store.go @@ -3,6 +3,7 @@ package config import ( "cmp" "context" + "encoding/json" "fmt" "log/slog" "os" @@ -288,6 +289,9 @@ func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey a } // RefreshOAuthToken refreshes the OAuth token for the given provider. +// Before making an external refresh request, it checks the config file on +// disk to see if another Crush session has already refreshed the token. If +// a newer token is found, it is used instead of refreshing. func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error { providerConfig, exists := s.config.Providers.Get(providerID) if !exists { @@ -298,13 +302,29 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid return fmt.Errorf("provider %s does not have an OAuth token", providerID) } - var newToken *oauth.Token + // Check if another session refreshed the token recently by reading + // the current token from the config file on disk. + newToken, err := s.loadTokenFromDisk(scope, providerID) + if err != nil { + slog.Warn("Failed to read token from config file, proceeding with refresh", "provider", providerID, "error", err) + } else if newToken != nil && newToken.AccessToken != providerConfig.OAuthToken.AccessToken { + slog.Info("Using token refreshed by another session", "provider", providerID) + providerConfig.OAuthToken = newToken + providerConfig.APIKey = newToken.AccessToken + if providerID == string(catwalk.InferenceProviderCopilot) { + providerConfig.SetupGitHubCopilot() + } + s.config.Providers.Set(providerID, providerConfig) + return nil + } + + var refreshedToken *oauth.Token var refreshErr error switch providerID { case string(catwalk.InferenceProviderCopilot): - newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) + refreshedToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) case hyperp.Name: - newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken) + refreshedToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken) default: return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) } @@ -313,8 +333,8 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid } slog.Info("Successfully refreshed OAuth token", "provider", providerID) - providerConfig.OAuthToken = newToken - providerConfig.APIKey = newToken.AccessToken + providerConfig.OAuthToken = refreshedToken + providerConfig.APIKey = refreshedToken.AccessToken switch providerID { case string(catwalk.InferenceProviderCopilot): @@ -324,8 +344,8 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid s.config.Providers.Set(providerID, providerConfig) if err := cmp.Or( - s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken), - s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken), + s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), refreshedToken.AccessToken), + s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), refreshedToken), ); err != nil { return fmt.Errorf("failed to persist refreshed token: %w", err) } @@ -333,6 +353,41 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid return nil } +// loadTokenFromDisk reads the OAuth token for the given provider from the +// config file on disk. Returns nil if the token is not found or matches the +// current in-memory token. +func (s *ConfigStore) loadTokenFromDisk(scope Scope, providerID string) (*oauth.Token, error) { + path, err := s.configPath(scope) + if err != nil { + return nil, err + } + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + + oauthKey := fmt.Sprintf("providers.%s.oauth", providerID) + oauthResult := gjson.Get(string(data), oauthKey) + if !oauthResult.Exists() { + return nil, nil + } + + var token oauth.Token + if err := json.Unmarshal([]byte(oauthResult.Raw), &token); err != nil { + return nil, err + } + + if token.AccessToken == "" { + return nil, nil + } + + return &token, nil +} + // recordRecentModel records a model in the recent models list. func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error { if model.Provider == "" || model.Model == "" { diff --git a/internal/config/store_test.go b/internal/config/store_test.go index 46d51440870c0361b6e0246881d4522fa044363c..3f9441ea02315a6ac75012d956acbcab217d1c91 100644 --- a/internal/config/store_test.go +++ b/internal/config/store_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/oauth" "github.com/stretchr/testify/require" ) @@ -510,3 +512,183 @@ func TestAutoReloadDisabledDuringReload(t *testing.T) { // Verify reload completed successfully require.False(t, store.autoReloadDisabled, "autoReloadDisabled should be false after ReloadFromDisk") } + +func TestLoadTokenFromDisk_ReturnsNewerToken(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + configPath := filepath.Join(dir, "crush.json") + + // Create config file with a newer token on disk + configContent := `{ + "providers": { + "hyper": { + "oauth": { + "access_token": "newer-token-from-disk", + "refresh_token": "refresh-abc", + "expires_in": 3600, + "expires_at": 9999999999 + } + } + } + }` + require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600)) + + store := &ConfigStore{ + config: &Config{}, + globalDataPath: configPath, + } + + token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper") + require.NoError(t, err) + require.NotNil(t, token) + require.Equal(t, "newer-token-from-disk", token.AccessToken) + require.Equal(t, "refresh-abc", token.RefreshToken) + require.Equal(t, 3600, token.ExpiresIn) + require.Equal(t, int64(9999999999), token.ExpiresAt) +} + +func TestLoadTokenFromDisk_ReturnsNilWhenSameToken(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + configPath := filepath.Join(dir, "crush.json") + + // Create config file with the same token + configContent := `{ + "providers": { + "hyper": { + "oauth": { + "access_token": "same-token", + "refresh_token": "refresh-abc", + "expires_in": 3600, + "expires_at": 9999999999 + } + } + } + }` + require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600)) + + store := &ConfigStore{ + config: &Config{}, + globalDataPath: configPath, + } + + token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper") + require.NoError(t, err) + require.NotNil(t, token) + require.Equal(t, "same-token", token.AccessToken) +} + +func TestLoadTokenFromDisk_ReturnsNilWhenFileMissing(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + configPath := filepath.Join(dir, "nonexistent.json") + + store := &ConfigStore{ + config: &Config{}, + globalDataPath: configPath, + } + + token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper") + require.NoError(t, err) + require.Nil(t, token) +} + +func TestLoadTokenFromDisk_ReturnsNilWhenProviderMissing(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + configPath := filepath.Join(dir, "crush.json") + + // Create config file without the hyper provider + configContent := `{"providers": {"openai": {"api_key": "test-key"}}}` + require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600)) + + store := &ConfigStore{ + config: &Config{}, + globalDataPath: configPath, + } + + token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper") + require.NoError(t, err) + require.Nil(t, token) +} + +func TestLoadTokenFromDisk_ReturnsNilWhenOAuthMissing(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + configPath := filepath.Join(dir, "crush.json") + + // Create config file with provider but no OAuth token + configContent := `{"providers": {"hyper": {"api_key": "test-key"}}}` + require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600)) + + store := &ConfigStore{ + config: &Config{}, + globalDataPath: configPath, + } + + token, err := store.loadTokenFromDisk(ScopeGlobal, "hyper") + require.NoError(t, err) + require.Nil(t, token) +} + +func TestRefreshOAuthToken_UsesDiskTokenWhenDifferent(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + configPath := filepath.Join(dir, "crush.json") + + // Create config file with a newer token on disk + configContent := `{ + "providers": { + "hyper": { + "api_key": "newer-access-token", + "oauth": { + "access_token": "newer-access-token", + "refresh_token": "refresh-abc", + "expires_in": 3600, + "expires_at": 9999999999 + } + } + } + }` + require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600)) + + // Set up store with an older in-memory token + oldToken := &oauth.Token{ + AccessToken: "older-access-token", + RefreshToken: "refresh-abc", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(-time.Hour).Unix(), // Expired + } + + providers := csync.NewMap[string, ProviderConfig]() + providers.Set("hyper", ProviderConfig{ + ID: "hyper", + Name: "Hyper", + APIKey: oldToken.AccessToken, + OAuthToken: oldToken, + }) + + store := &ConfigStore{ + config: &Config{ + Providers: providers, + }, + globalDataPath: configPath, + } + + // Refresh should use the disk token without making an external call + err := store.RefreshOAuthToken(context.Background(), ScopeGlobal, "hyper") + require.NoError(t, err) + + // Verify the in-memory token was updated to the disk token + updatedConfig, ok := store.config.Providers.Get("hyper") + require.True(t, ok) + require.Equal(t, "newer-access-token", updatedConfig.APIKey) + require.Equal(t, "newer-access-token", updatedConfig.OAuthToken.AccessToken) + require.Equal(t, "refresh-abc", updatedConfig.OAuthToken.RefreshToken) +}