fix(auth): add better atomic refresh for hyper

Kieran Klukas created

Change summary

internal/config/atomicwrite.go      | 38 +++++++++++++++++++++++++++
internal/config/atomicwrite_test.go | 43 ++++++++++++++++++++++++++++++
internal/config/store.go            | 40 ++++++++++++++++++++-------
3 files changed, 110 insertions(+), 11 deletions(-)

Detailed changes

internal/config/atomicwrite.go 🔗

@@ -0,0 +1,38 @@
+package config
+
+import (
+	"os"
+	"path/filepath"
+)
+
+// atomicWriteFile writes data to a file atomically by writing to a unique
+// temporary file in the same directory and renaming it into place. This
+// prevents concurrent readers from observing a partially-written file.
+func atomicWriteFile(path string, data []byte, perm os.FileMode) error {
+	path = filepath.Clean(path)
+	dir := filepath.Dir(path)
+	f, err := os.CreateTemp(dir, filepath.Base(path)+".*.tmp")
+	if err != nil {
+		return err
+	}
+	tmp := f.Name()
+	if _, err := f.Write(data); err != nil {
+		f.Close()
+		os.Remove(tmp)
+		return err
+	}
+	if err := f.Chmod(perm); err != nil {
+		f.Close()
+		os.Remove(tmp)
+		return err
+	}
+	if err := f.Close(); err != nil {
+		os.Remove(tmp)
+		return err
+	}
+	if err := os.Rename(tmp, path); err != nil {
+		os.Remove(tmp)
+		return err
+	}
+	return nil
+}

internal/config/atomicwrite_test.go 🔗

@@ -0,0 +1,43 @@
+package config
+
+import (
+	"os"
+	"path/filepath"
+	"runtime"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestAtomicWriteFile(t *testing.T) {
+	t.Parallel()
+	dir := t.TempDir()
+	path := filepath.Join(dir, "test.json")
+
+	require.NoError(t, atomicWriteFile(path, []byte(`{"key":"value"}`), 0o600))
+
+	data, err := os.ReadFile(path)
+	require.NoError(t, err)
+	require.Equal(t, `{"key":"value"}`, string(data))
+
+	// No temp files should linger.
+	entries, err := os.ReadDir(dir)
+	require.NoError(t, err)
+	require.Len(t, entries, 1)
+	require.Equal(t, "test.json", entries[0].Name())
+}
+
+func TestAtomicWriteFile_PermissionsApplied(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("Windows does not support Unix file permissions")
+	}
+	t.Parallel()
+	dir := t.TempDir()
+	path := filepath.Join(dir, "test.json")
+
+	require.NoError(t, atomicWriteFile(path, []byte(`{}`), 0o600))
+
+	info, err := os.Stat(path)
+	require.NoError(t, err)
+	require.Equal(t, os.FileMode(0o600), info.Mode().Perm())
+}

internal/config/store.go 🔗

@@ -158,7 +158,7 @@ func (s *ConfigStore) SetConfigFields(scope Scope, kv map[string]any) error {
 	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
 		return fmt.Errorf("failed to create config directory %q: %w", path, err)
 	}
-	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
+	if err := atomicWriteFile(path, []byte(newValue), 0o600); err != nil {
 		return fmt.Errorf("failed to write config file: %w", err)
 	}
 
@@ -193,7 +193,7 @@ func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
 	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
 		return fmt.Errorf("failed to create config directory %q: %w", path, err)
 	}
-	if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
+	if err := atomicWriteFile(path, []byte(newValue), 0o600); err != nil {
 		return fmt.Errorf("failed to write config file: %w", err)
 	}
 
@@ -302,7 +302,10 @@ func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey a
 // RefreshOAuthToken refreshes the OAuth token for the given provider.
 // Before making an external refresh request, it checks the config file on
 // disk to see if another Crush session has already refreshed the token. If
-// a newer token is found, it is used instead of refreshing.
+// a newer, unexpired token is found, it is used instead of refreshing. If
+// the exchange fails (e.g. because another session already rotated the
+// refresh token), the disk is re-checked to recover the other session's
+// token.
 func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
 	providerConfig, exists := s.config.Providers.Get(providerID)
 	if !exists {
@@ -318,15 +321,9 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
 	newToken, err := s.loadTokenFromDisk(scope, providerID)
 	if err != nil {
 		slog.Warn("Failed to read token from config file, proceeding with refresh", "provider", providerID, "error", err)
-	} else if newToken != nil && newToken.AccessToken != providerConfig.OAuthToken.AccessToken {
+	} else if newToken != nil && !newToken.IsExpired() && newToken.AccessToken != providerConfig.OAuthToken.AccessToken {
 		slog.Info("Using token refreshed by another session", "provider", providerID)
-		providerConfig.OAuthToken = newToken
-		providerConfig.APIKey = newToken.AccessToken
-		if providerID == string(catwalk.InferenceProviderCopilot) {
-			providerConfig.SetupGitHubCopilot()
-		}
-		s.config.Providers.Set(providerID, providerConfig)
-		return nil
+		return s.applyToken(providerConfig, newToken, providerID)
 	}
 
 	var refreshedToken *oauth.Token
@@ -340,6 +337,16 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
 		return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
 	}
 	if refreshErr != nil {
+		// The exchange may have failed because another session already
+		// rotated the refresh token. Re-read the config file and use the
+		// other session's token if available.
+		if diskToken, diskErr := s.loadTokenFromDisk(scope, providerID); diskErr == nil &&
+			diskToken != nil &&
+			!diskToken.IsExpired() &&
+			diskToken.AccessToken != providerConfig.OAuthToken.AccessToken {
+			slog.Info("Using token refreshed by another session after exchange failure", "provider", providerID)
+			return s.applyToken(providerConfig, diskToken, providerID)
+		}
 		return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
 	}
 
@@ -364,6 +371,17 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
 	return nil
 }
 
+// applyToken updates the in-memory provider config with the given token.
+func (s *ConfigStore) applyToken(providerConfig ProviderConfig, token *oauth.Token, providerID string) error {
+	providerConfig.OAuthToken = token
+	providerConfig.APIKey = token.AccessToken
+	if providerID == string(catwalk.InferenceProviderCopilot) {
+		providerConfig.SetupGitHubCopilot()
+	}
+	s.config.Providers.Set(providerID, providerConfig)
+	return nil
+}
+
 // loadTokenFromDisk reads the OAuth token for the given provider from the
 // config file on disk. Returns nil if the token is not found or matches the
 // current in-memory token.