@@ -0,0 +1,38 @@
+package config
+
+import (
+ "os"
+ "path/filepath"
+)
+
+// atomicWriteFile writes data to a file atomically by writing to a unique
+// temporary file in the same directory and renaming it into place. This
+// prevents concurrent readers from observing a partially-written file.
+func atomicWriteFile(path string, data []byte, perm os.FileMode) error {
+ path = filepath.Clean(path)
+ dir := filepath.Dir(path)
+ f, err := os.CreateTemp(dir, filepath.Base(path)+".*.tmp")
+ if err != nil {
+ return err
+ }
+ tmp := f.Name()
+ if _, err := f.Write(data); err != nil {
+ f.Close()
+ os.Remove(tmp)
+ return err
+ }
+ if err := f.Chmod(perm); err != nil {
+ f.Close()
+ os.Remove(tmp)
+ return err
+ }
+ if err := f.Close(); err != nil {
+ os.Remove(tmp)
+ return err
+ }
+ if err := os.Rename(tmp, path); err != nil {
+ os.Remove(tmp)
+ return err
+ }
+ return nil
+}
@@ -0,0 +1,43 @@
+package config
+
+import (
+ "os"
+ "path/filepath"
+ "runtime"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAtomicWriteFile(t *testing.T) {
+ t.Parallel()
+ dir := t.TempDir()
+ path := filepath.Join(dir, "test.json")
+
+ require.NoError(t, atomicWriteFile(path, []byte(`{"key":"value"}`), 0o600))
+
+ data, err := os.ReadFile(path)
+ require.NoError(t, err)
+ require.Equal(t, `{"key":"value"}`, string(data))
+
+ // No temp files should linger.
+ entries, err := os.ReadDir(dir)
+ require.NoError(t, err)
+ require.Len(t, entries, 1)
+ require.Equal(t, "test.json", entries[0].Name())
+}
+
+func TestAtomicWriteFile_PermissionsApplied(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("Windows does not support Unix file permissions")
+ }
+ t.Parallel()
+ dir := t.TempDir()
+ path := filepath.Join(dir, "test.json")
+
+ require.NoError(t, atomicWriteFile(path, []byte(`{}`), 0o600))
+
+ info, err := os.Stat(path)
+ require.NoError(t, err)
+ require.Equal(t, os.FileMode(0o600), info.Mode().Perm())
+}
@@ -158,7 +158,7 @@ func (s *ConfigStore) SetConfigFields(scope Scope, kv map[string]any) error {
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return fmt.Errorf("failed to create config directory %q: %w", path, err)
}
- if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
+ if err := atomicWriteFile(path, []byte(newValue), 0o600); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
@@ -193,7 +193,7 @@ func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return fmt.Errorf("failed to create config directory %q: %w", path, err)
}
- if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
+ if err := atomicWriteFile(path, []byte(newValue), 0o600); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
@@ -302,7 +302,10 @@ func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey a
// RefreshOAuthToken refreshes the OAuth token for the given provider.
// Before making an external refresh request, it checks the config file on
// disk to see if another Crush session has already refreshed the token. If
-// a newer token is found, it is used instead of refreshing.
+// a newer, unexpired token is found, it is used instead of refreshing. If
+// the exchange fails (e.g. because another session already rotated the
+// refresh token), the disk is re-checked to recover the other session's
+// token.
func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
providerConfig, exists := s.config.Providers.Get(providerID)
if !exists {
@@ -318,15 +321,9 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
newToken, err := s.loadTokenFromDisk(scope, providerID)
if err != nil {
slog.Warn("Failed to read token from config file, proceeding with refresh", "provider", providerID, "error", err)
- } else if newToken != nil && newToken.AccessToken != providerConfig.OAuthToken.AccessToken {
+ } else if newToken != nil && !newToken.IsExpired() && newToken.AccessToken != providerConfig.OAuthToken.AccessToken {
slog.Info("Using token refreshed by another session", "provider", providerID)
- providerConfig.OAuthToken = newToken
- providerConfig.APIKey = newToken.AccessToken
- if providerID == string(catwalk.InferenceProviderCopilot) {
- providerConfig.SetupGitHubCopilot()
- }
- s.config.Providers.Set(providerID, providerConfig)
- return nil
+ return s.applyToken(providerConfig, newToken, providerID)
}
var refreshedToken *oauth.Token
@@ -340,6 +337,16 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
}
if refreshErr != nil {
+ // The exchange may have failed because another session already
+ // rotated the refresh token. Re-read the config file and use the
+ // other session's token if available.
+ if diskToken, diskErr := s.loadTokenFromDisk(scope, providerID); diskErr == nil &&
+ diskToken != nil &&
+ !diskToken.IsExpired() &&
+ diskToken.AccessToken != providerConfig.OAuthToken.AccessToken {
+ slog.Info("Using token refreshed by another session after exchange failure", "provider", providerID)
+ return s.applyToken(providerConfig, diskToken, providerID)
+ }
return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
}
@@ -364,6 +371,17 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
return nil
}
+// applyToken updates the in-memory provider config with the given token.
+func (s *ConfigStore) applyToken(providerConfig ProviderConfig, token *oauth.Token, providerID string) error {
+ providerConfig.OAuthToken = token
+ providerConfig.APIKey = token.AccessToken
+ if providerID == string(catwalk.InferenceProviderCopilot) {
+ providerConfig.SetupGitHubCopilot()
+ }
+ s.config.Providers.Set(providerID, providerConfig)
+ return nil
+}
+
// loadTokenFromDisk reads the OAuth token for the given provider from the
// config file on disk. Returns nil if the token is not found or matches the
// current in-memory token.