@@ -1,7 +1,6 @@
package config
import (
- "cmp"
"context"
"encoding/json"
"fmt"
@@ -127,9 +126,18 @@ func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
// After a successful write, it automatically reloads config to keep in-memory
// state fresh.
func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
+ return s.SetConfigFields(scope, map[string]any{key: value})
+}
+
+// SetConfigFields sets multiple key/value pairs in the config file for the given
+// scope in a single write. After a successful write, it automatically reloads
+// config to keep in-memory state fresh. This is preferred over multiple
+// SetConfigField calls when writing several fields atomically to avoid
+// intermediate reloads with partial state.
+func (s *ConfigStore) SetConfigFields(scope Scope, kv map[string]any) error {
path, err := s.configPath(scope)
if err != nil {
- return fmt.Errorf("%s: %w", key, err)
+ return fmt.Errorf("%v: %w", kv, err)
}
data, err := os.ReadFile(path)
if err != nil {
@@ -140,9 +148,12 @@ func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
}
}
- newValue, err := sjson.Set(string(data), key, value)
- if err != nil {
- return fmt.Errorf("failed to set config field %s: %w", key, err)
+ newValue := string(data)
+ for key, value := range kv {
+ newValue, err = sjson.Set(newValue, key, value)
+ if err != nil {
+ return fmt.Errorf("failed to set config field %s: %w", key, err)
+ }
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return fmt.Errorf("failed to create config directory %q: %w", path, err)
@@ -238,10 +249,10 @@ func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey a
}
setKeyOrToken = func() { providerConfig.APIKey = v }
case *oauth.Token:
- if err := cmp.Or(
- s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
- s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
- ); err != nil {
+ if err := s.SetConfigFields(scope, map[string]any{
+ fmt.Sprintf("providers.%s.api_key", providerID): v.AccessToken,
+ fmt.Sprintf("providers.%s.oauth", providerID): v,
+ }); err != nil {
return err
}
setKeyOrToken = func() {
@@ -343,10 +354,10 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
s.config.Providers.Set(providerID, providerConfig)
- if err := cmp.Or(
- s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), refreshedToken.AccessToken),
- s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), refreshedToken),
- ); err != nil {
+ if err := s.SetConfigFields(scope, map[string]any{
+ fmt.Sprintf("providers.%s.api_key", providerID): refreshedToken.AccessToken,
+ fmt.Sprintf("providers.%s.oauth", providerID): refreshedToken,
+ }); err != nil {
return fmt.Errorf("failed to persist refreshed token: %w", err)
}
@@ -460,10 +471,10 @@ func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
return token, false
}
- if err := cmp.Or(
- s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
- s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
- ); err != nil {
+ if err := s.SetConfigFields(ScopeGlobal, map[string]any{
+ "providers.copilot.api_key": token.AccessToken,
+ "providers.copilot.oauth": token,
+ }); err != nil {
slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
}
@@ -513,6 +513,38 @@ func TestAutoReloadDisabledDuringReload(t *testing.T) {
require.False(t, store.autoReloadDisabled, "autoReloadDisabled should be false after ReloadFromDisk")
}
+// TestSetConfigFields_AutoReloadsAtomically verifies that SetConfigFields writes
+// multiple fields in a single disk write and triggers only one auto-reload,
+// avoiding intermediate states where only some fields are persisted.
+func TestSetConfigFields_AutoReloadsAtomically(t *testing.T) {
+ t.Parallel()
+
+ dir := t.TempDir()
+ configPath := filepath.Join(dir, "crush.json")
+
+ // Create initial config file.
+ initialConfig := `{"options": {"debug": false}}`
+ require.NoError(t, os.WriteFile(configPath, []byte(initialConfig), 0o600))
+
+ // Load initial config.
+ store, err := Load(dir, dir, false)
+ require.NoError(t, err)
+
+ // Set globalDataPath and capture snapshot.
+ store.globalDataPath = configPath
+ store.CaptureStalenessSnapshot([]string{configPath})
+
+ // Write multiple fields atomically.
+ err = store.SetConfigFields(ScopeGlobal, map[string]any{
+ "options.debug": true,
+ "options.custom": "hello",
+ })
+ require.NoError(t, err)
+
+ // Verify both fields are reflected in memory.
+ require.True(t, store.config.Options.Debug)
+}
+
func TestLoadTokenFromDisk_ReturnsNewerToken(t *testing.T) {
t.Parallel()