fix(config): atomically update multiple fields during oauth

Kieran Klukas created

Change summary

internal/config/store.go      | 45 +++++++++++++++++++++++-------------
internal/config/store_test.go | 32 ++++++++++++++++++++++++++
2 files changed, 60 insertions(+), 17 deletions(-)

Detailed changes

internal/config/store.go 🔗

@@ -1,7 +1,6 @@
 package config
 
 import (
-	"cmp"
 	"context"
 	"encoding/json"
 	"fmt"
@@ -127,9 +126,18 @@ func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
 // After a successful write, it automatically reloads config to keep in-memory
 // state fresh.
 func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
+	return s.SetConfigFields(scope, map[string]any{key: value})
+}
+
+// SetConfigFields sets multiple key/value pairs in the config file for the given
+// scope in a single write. After a successful write, it automatically reloads
+// config to keep in-memory state fresh. This is preferred over multiple
+// SetConfigField calls when writing several fields atomically to avoid
+// intermediate reloads with partial state.
+func (s *ConfigStore) SetConfigFields(scope Scope, kv map[string]any) error {
 	path, err := s.configPath(scope)
 	if err != nil {
-		return fmt.Errorf("%s: %w", key, err)
+		return fmt.Errorf("%v: %w", kv, err)
 	}
 	data, err := os.ReadFile(path)
 	if err != nil {
@@ -140,9 +148,12 @@ func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
 		}
 	}
 
-	newValue, err := sjson.Set(string(data), key, value)
-	if err != nil {
-		return fmt.Errorf("failed to set config field %s: %w", key, err)
+	newValue := string(data)
+	for key, value := range kv {
+		newValue, err = sjson.Set(newValue, key, value)
+		if err != nil {
+			return fmt.Errorf("failed to set config field %s: %w", key, err)
+		}
 	}
 	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
 		return fmt.Errorf("failed to create config directory %q: %w", path, err)
@@ -238,10 +249,10 @@ func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey a
 		}
 		setKeyOrToken = func() { providerConfig.APIKey = v }
 	case *oauth.Token:
-		if err := cmp.Or(
-			s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
-			s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
-		); err != nil {
+		if err := s.SetConfigFields(scope, map[string]any{
+			fmt.Sprintf("providers.%s.api_key", providerID): v.AccessToken,
+			fmt.Sprintf("providers.%s.oauth", providerID):   v,
+		}); err != nil {
 			return err
 		}
 		setKeyOrToken = func() {
@@ -343,10 +354,10 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
 
 	s.config.Providers.Set(providerID, providerConfig)
 
-	if err := cmp.Or(
-		s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), refreshedToken.AccessToken),
-		s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), refreshedToken),
-	); err != nil {
+	if err := s.SetConfigFields(scope, map[string]any{
+		fmt.Sprintf("providers.%s.api_key", providerID): refreshedToken.AccessToken,
+		fmt.Sprintf("providers.%s.oauth", providerID):   refreshedToken,
+	}); err != nil {
 		return fmt.Errorf("failed to persist refreshed token: %w", err)
 	}
 
@@ -460,10 +471,10 @@ func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
 		return token, false
 	}
 
-	if err := cmp.Or(
-		s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
-		s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
-	); err != nil {
+	if err := s.SetConfigFields(ScopeGlobal, map[string]any{
+		"providers.copilot.api_key": token.AccessToken,
+		"providers.copilot.oauth":   token,
+	}); err != nil {
 		slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
 	}
 

internal/config/store_test.go 🔗

@@ -513,6 +513,38 @@ func TestAutoReloadDisabledDuringReload(t *testing.T) {
 	require.False(t, store.autoReloadDisabled, "autoReloadDisabled should be false after ReloadFromDisk")
 }
 
+// TestSetConfigFields_AutoReloadsAtomically verifies that SetConfigFields writes
+// multiple fields in a single disk write and triggers only one auto-reload,
+// avoiding intermediate states where only some fields are persisted.
+func TestSetConfigFields_AutoReloadsAtomically(t *testing.T) {
+	t.Parallel()
+
+	dir := t.TempDir()
+	configPath := filepath.Join(dir, "crush.json")
+
+	// Create initial config file.
+	initialConfig := `{"options": {"debug": false}}`
+	require.NoError(t, os.WriteFile(configPath, []byte(initialConfig), 0o600))
+
+	// Load initial config.
+	store, err := Load(dir, dir, false)
+	require.NoError(t, err)
+
+	// Set globalDataPath and capture snapshot.
+	store.globalDataPath = configPath
+	store.CaptureStalenessSnapshot([]string{configPath})
+
+	// Write multiple fields atomically.
+	err = store.SetConfigFields(ScopeGlobal, map[string]any{
+		"options.debug":  true,
+		"options.custom": "hello",
+	})
+	require.NoError(t, err)
+
+	// Verify both fields are reflected in memory.
+	require.True(t, store.config.Options.Debug)
+}
+
 func TestLoadTokenFromDisk_ReturnsNewerToken(t *testing.T) {
 	t.Parallel()