From c8379c306f3c13b14225106d70c87a8b89ce16ec Mon Sep 17 00:00:00 2001 From: Kieran Klukas Date: Wed, 13 May 2026 17:56:05 -0400 Subject: [PATCH] fix(auth): add better atomic refresh for hyper --- internal/config/atomicwrite.go | 38 +++++++++++++++++++++++++ internal/config/atomicwrite_test.go | 43 +++++++++++++++++++++++++++++ internal/config/store.go | 40 +++++++++++++++++++-------- 3 files changed, 110 insertions(+), 11 deletions(-) create mode 100644 internal/config/atomicwrite.go create mode 100644 internal/config/atomicwrite_test.go diff --git a/internal/config/atomicwrite.go b/internal/config/atomicwrite.go new file mode 100644 index 0000000000000000000000000000000000000000..7e981fa11e89550184002fa1232104b846bf65c8 --- /dev/null +++ b/internal/config/atomicwrite.go @@ -0,0 +1,38 @@ +package config + +import ( + "os" + "path/filepath" +) + +// atomicWriteFile writes data to a file atomically by writing to a unique +// temporary file in the same directory and renaming it into place. This +// prevents concurrent readers from observing a partially-written file. +func atomicWriteFile(path string, data []byte, perm os.FileMode) error { + path = filepath.Clean(path) + dir := filepath.Dir(path) + f, err := os.CreateTemp(dir, filepath.Base(path)+".*.tmp") + if err != nil { + return err + } + tmp := f.Name() + if _, err := f.Write(data); err != nil { + f.Close() + os.Remove(tmp) + return err + } + if err := f.Chmod(perm); err != nil { + f.Close() + os.Remove(tmp) + return err + } + if err := f.Close(); err != nil { + os.Remove(tmp) + return err + } + if err := os.Rename(tmp, path); err != nil { + os.Remove(tmp) + return err + } + return nil +} diff --git a/internal/config/atomicwrite_test.go b/internal/config/atomicwrite_test.go new file mode 100644 index 0000000000000000000000000000000000000000..089ee5c78eae7fecdb5bc3cd3d8f74713098fc38 --- /dev/null +++ b/internal/config/atomicwrite_test.go @@ -0,0 +1,43 @@ +package config + +import ( + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAtomicWriteFile(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := filepath.Join(dir, "test.json") + + require.NoError(t, atomicWriteFile(path, []byte(`{"key":"value"}`), 0o600)) + + data, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, `{"key":"value"}`, string(data)) + + // No temp files should linger. + entries, err := os.ReadDir(dir) + require.NoError(t, err) + require.Len(t, entries, 1) + require.Equal(t, "test.json", entries[0].Name()) +} + +func TestAtomicWriteFile_PermissionsApplied(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Windows does not support Unix file permissions") + } + t.Parallel() + dir := t.TempDir() + path := filepath.Join(dir, "test.json") + + require.NoError(t, atomicWriteFile(path, []byte(`{}`), 0o600)) + + info, err := os.Stat(path) + require.NoError(t, err) + require.Equal(t, os.FileMode(0o600), info.Mode().Perm()) +} diff --git a/internal/config/store.go b/internal/config/store.go index 5501bdddafd206c4294a22666b59399338db7503..3e55509b7132e38805830818fca1e5265b7e03f9 100644 --- a/internal/config/store.go +++ b/internal/config/store.go @@ -158,7 +158,7 @@ func (s *ConfigStore) SetConfigFields(scope Scope, kv map[string]any) error { if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { return fmt.Errorf("failed to create config directory %q: %w", path, err) } - if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil { + if err := atomicWriteFile(path, []byte(newValue), 0o600); err != nil { return fmt.Errorf("failed to write config file: %w", err) } @@ -193,7 +193,7 @@ func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error { if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { return fmt.Errorf("failed to create config directory %q: %w", path, err) } - if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil { + if err := atomicWriteFile(path, []byte(newValue), 0o600); err != nil { return fmt.Errorf("failed to write config file: %w", err) } @@ -302,7 +302,10 @@ func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey a // RefreshOAuthToken refreshes the OAuth token for the given provider. // Before making an external refresh request, it checks the config file on // disk to see if another Crush session has already refreshed the token. If -// a newer token is found, it is used instead of refreshing. +// a newer, unexpired token is found, it is used instead of refreshing. If +// the exchange fails (e.g. because another session already rotated the +// refresh token), the disk is re-checked to recover the other session's +// token. func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error { providerConfig, exists := s.config.Providers.Get(providerID) if !exists { @@ -318,15 +321,9 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid newToken, err := s.loadTokenFromDisk(scope, providerID) if err != nil { slog.Warn("Failed to read token from config file, proceeding with refresh", "provider", providerID, "error", err) - } else if newToken != nil && newToken.AccessToken != providerConfig.OAuthToken.AccessToken { + } else if newToken != nil && !newToken.IsExpired() && newToken.AccessToken != providerConfig.OAuthToken.AccessToken { slog.Info("Using token refreshed by another session", "provider", providerID) - providerConfig.OAuthToken = newToken - providerConfig.APIKey = newToken.AccessToken - if providerID == string(catwalk.InferenceProviderCopilot) { - providerConfig.SetupGitHubCopilot() - } - s.config.Providers.Set(providerID, providerConfig) - return nil + return s.applyToken(providerConfig, newToken, providerID) } var refreshedToken *oauth.Token @@ -340,6 +337,16 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) } if refreshErr != nil { + // The exchange may have failed because another session already + // rotated the refresh token. Re-read the config file and use the + // other session's token if available. + if diskToken, diskErr := s.loadTokenFromDisk(scope, providerID); diskErr == nil && + diskToken != nil && + !diskToken.IsExpired() && + diskToken.AccessToken != providerConfig.OAuthToken.AccessToken { + slog.Info("Using token refreshed by another session after exchange failure", "provider", providerID) + return s.applyToken(providerConfig, diskToken, providerID) + } return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr) } @@ -364,6 +371,17 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid return nil } +// applyToken updates the in-memory provider config with the given token. +func (s *ConfigStore) applyToken(providerConfig ProviderConfig, token *oauth.Token, providerID string) error { + providerConfig.OAuthToken = token + providerConfig.APIKey = token.AccessToken + if providerID == string(catwalk.InferenceProviderCopilot) { + providerConfig.SetupGitHubCopilot() + } + s.config.Providers.Set(providerID, providerConfig) + return nil +} + // loadTokenFromDisk reads the OAuth token for the given provider from the // config file on disk. Returns nil if the token is not found or matches the // current in-memory token.