From 8bc4a75289efc3e819a95f9c16f351f9cd4a091f Mon Sep 17 00:00:00 2001 From: Kieran Klukas Date: Wed, 6 May 2026 12:10:08 -0400 Subject: [PATCH] fix(config): atomically update multiple fields during oauth --- internal/config/store.go | 45 ++++++++++++++++++++++------------- internal/config/store_test.go | 32 +++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 17 deletions(-) diff --git a/internal/config/store.go b/internal/config/store.go index c44fceba0ac87133a787b1351de213802a4e17cc..376efb93feada30d4710929d9799a672a436f6e4 100644 --- a/internal/config/store.go +++ b/internal/config/store.go @@ -1,7 +1,6 @@ package config import ( - "cmp" "context" "encoding/json" "fmt" @@ -127,9 +126,18 @@ func (s *ConfigStore) HasConfigField(scope Scope, key string) bool { // After a successful write, it automatically reloads config to keep in-memory // state fresh. func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error { + return s.SetConfigFields(scope, map[string]any{key: value}) +} + +// SetConfigFields sets multiple key/value pairs in the config file for the given +// scope in a single write. After a successful write, it automatically reloads +// config to keep in-memory state fresh. This is preferred over multiple +// SetConfigField calls when writing several fields atomically to avoid +// intermediate reloads with partial state. +func (s *ConfigStore) SetConfigFields(scope Scope, kv map[string]any) error { path, err := s.configPath(scope) if err != nil { - return fmt.Errorf("%s: %w", key, err) + return fmt.Errorf("%v: %w", kv, err) } data, err := os.ReadFile(path) if err != nil { @@ -140,9 +148,12 @@ func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error { } } - newValue, err := sjson.Set(string(data), key, value) - if err != nil { - return fmt.Errorf("failed to set config field %s: %w", key, err) + newValue := string(data) + for key, value := range kv { + newValue, err = sjson.Set(newValue, key, value) + if err != nil { + return fmt.Errorf("failed to set config field %s: %w", key, err) + } } if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { return fmt.Errorf("failed to create config directory %q: %w", path, err) @@ -238,10 +249,10 @@ func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey a } setKeyOrToken = func() { providerConfig.APIKey = v } case *oauth.Token: - if err := cmp.Or( - s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken), - s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v), - ); err != nil { + if err := s.SetConfigFields(scope, map[string]any{ + fmt.Sprintf("providers.%s.api_key", providerID): v.AccessToken, + fmt.Sprintf("providers.%s.oauth", providerID): v, + }); err != nil { return err } setKeyOrToken = func() { @@ -343,10 +354,10 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid s.config.Providers.Set(providerID, providerConfig) - if err := cmp.Or( - s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), refreshedToken.AccessToken), - s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), refreshedToken), - ); err != nil { + if err := s.SetConfigFields(scope, map[string]any{ + fmt.Sprintf("providers.%s.api_key", providerID): refreshedToken.AccessToken, + fmt.Sprintf("providers.%s.oauth", providerID): refreshedToken, + }); err != nil { return fmt.Errorf("failed to persist refreshed token: %w", err) } @@ -460,10 +471,10 @@ func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) { return token, false } - if err := cmp.Or( - s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken), - s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token), - ); err != nil { + if err := s.SetConfigFields(ScopeGlobal, map[string]any{ + "providers.copilot.api_key": token.AccessToken, + "providers.copilot.oauth": token, + }); err != nil { slog.Error("Unable to save GitHub Copilot token to disk", "error", err) } diff --git a/internal/config/store_test.go b/internal/config/store_test.go index 3f9441ea02315a6ac75012d956acbcab217d1c91..e977547c74388f632993b1101eb6aacdaf1e74d1 100644 --- a/internal/config/store_test.go +++ b/internal/config/store_test.go @@ -513,6 +513,38 @@ func TestAutoReloadDisabledDuringReload(t *testing.T) { require.False(t, store.autoReloadDisabled, "autoReloadDisabled should be false after ReloadFromDisk") } +// TestSetConfigFields_AutoReloadsAtomically verifies that SetConfigFields writes +// multiple fields in a single disk write and triggers only one auto-reload, +// avoiding intermediate states where only some fields are persisted. +func TestSetConfigFields_AutoReloadsAtomically(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + configPath := filepath.Join(dir, "crush.json") + + // Create initial config file. + initialConfig := `{"options": {"debug": false}}` + require.NoError(t, os.WriteFile(configPath, []byte(initialConfig), 0o600)) + + // Load initial config. + store, err := Load(dir, dir, false) + require.NoError(t, err) + + // Set globalDataPath and capture snapshot. + store.globalDataPath = configPath + store.CaptureStalenessSnapshot([]string{configPath}) + + // Write multiple fields atomically. + err = store.SetConfigFields(ScopeGlobal, map[string]any{ + "options.debug": true, + "options.custom": "hello", + }) + require.NoError(t, err) + + // Verify both fields are reflected in memory. + require.True(t, store.config.Options.Debug) +} + func TestLoadTokenFromDisk_ReturnsNewerToken(t *testing.T) { t.Parallel()