Add generic update method to SettingsStore

Max Brunsfeld created

Change summary

crates/settings/src/settings_store.rs | 416 ++++++++++++++++++++++++++++
1 file changed, 408 insertions(+), 8 deletions(-)

Detailed changes

crates/settings/src/settings_store.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::{anyhow, Result};
 use collections::{hash_map, BTreeMap, HashMap, HashSet};
+use lazy_static::lazy_static;
 use schemars::JsonSchema;
 use serde::{de::DeserializeOwned, Deserialize as _, Serialize};
 use smallvec::SmallVec;
@@ -7,10 +8,12 @@ use std::{
     any::{type_name, Any, TypeId},
     fmt::Debug,
     mem,
+    ops::Range,
     path::Path,
+    str,
     sync::Arc,
 };
-use util::{merge_non_null_json_value_into, ResultExt as _};
+use util::{merge_non_null_json_value_into, RangeExt, ResultExt as _};
 
 /// A value that can be defined as a user setting.
 ///
@@ -22,7 +25,7 @@ pub trait Setting: 'static + Debug {
     const KEY: Option<&'static str> = None;
 
     /// The type that is stored in an individual JSON file.
-    type FileContent: Serialize + DeserializeOwned + JsonSchema;
+    type FileContent: Clone + Serialize + DeserializeOwned + JsonSchema;
 
     /// The logic for combining together values from one or more JSON files into the
     /// final value for this setting.
@@ -37,7 +40,6 @@ pub trait Setting: 'static + Debug {
     ) -> Self
     where
         Self: DeserializeOwned,
-        Self::FileContent: Serialize,
     {
         let mut merged = serde_json::Value::Null;
         for value in [default_value].iter().chain(user_values) {
@@ -55,6 +57,7 @@ pub struct SettingsStore {
     user_deserialized_settings: Option<DeserializedSettingMap>,
     local_deserialized_settings: BTreeMap<Arc<Path>, DeserializedSettingMap>,
     changed_setting_types: HashSet<TypeId>,
+    tab_size_callback: Option<(TypeId, Box<dyn Fn(&dyn Any) -> Option<usize>>)>,
 }
 
 #[derive(Debug)]
@@ -129,6 +132,81 @@ impl SettingsStore {
             .unwrap()
     }
 
+    /// Update the value of a setting.
+    ///
+    /// `text` is the JSON text to edit, and `update` mutates a copy of the
+    /// user-level file content currently stored for setting `T`. The old and
+    /// new contents are serialized to JSON and compared, and only the parts
+    /// that differ produce edits.
+    ///
+    /// Returns a list of edits to apply to the JSON file, as
+    /// `(byte range in text, replacement)` pairs sorted by range start.
+    ///
+    /// # Panics
+    ///
+    /// Panics if no user settings have been loaded, if the user settings hold
+    /// no deserialized value for `T`, or if the file content cannot be
+    /// serialized to JSON.
+    pub fn update<T: Setting>(
+        &self,
+        text: &str,
+        update: impl Fn(&mut T::FileContent),
+    ) -> Vec<(Range<usize>, String)> {
+        // Borrow the stored content and clone it only once, for the copy that
+        // the caller mutates.
+        let old_content = self
+            .user_deserialized_settings
+            .as_ref()
+            .expect("user settings must be loaded before calling `update`")
+            .typed
+            .get(&TypeId::of::<T>())
+            .expect("no deserialized user value for this setting type")
+            .0
+            .downcast_ref::<T::FileContent>()
+            .expect("stored user value is not a `T::FileContent`");
+        let mut new_content = old_content.clone();
+        update(&mut new_content);
+
+        let mut parser = tree_sitter::Parser::new();
+        parser
+            .set_language(tree_sitter_json::language())
+            .expect("failed to load the tree-sitter JSON grammar");
+        let tree = parser
+            .parse(text, None)
+            .expect("tree-sitter returned no syntax tree");
+
+        let old_value =
+            serde_json::to_value(old_content).expect("old file content must serialize to JSON");
+        let new_value =
+            serde_json::to_value(new_content).expect("new file content must serialize to JSON");
+
+        // Settings with a `KEY` live under that key in the file; the rest are
+        // stored at the top level.
+        let mut key_path = Vec::new();
+        if let Some(key) = T::KEY {
+            key_path.push(key);
+        }
+
+        let mut edits = Vec::new();
+        update_value_in_json_text(
+            text,
+            &tree,
+            &mut key_path,
+            self.json_tab_size(),
+            &old_value,
+            &new_value,
+            &mut edits,
+        );
+        edits.sort_unstable_by_key(|edit| edit.0.start);
+        edits
+    }
+
+    /// Configure the tab size used when updating JSON files.
+    ///
+    /// `get_tab_size` receives the current value of setting `T`. The stored
+    /// callback downcasts the type-erased setting value to `T` before
+    /// delegating to it.
+    pub fn set_json_tab_size_callback<T: Setting>(
+        &mut self,
+        get_tab_size: fn(&T) -> Option<usize>,
+    ) {
+        let callback: Box<dyn Fn(&dyn Any) -> Option<usize>> = Box::new(move |setting| {
+            let typed = setting.downcast_ref::<T>().unwrap();
+            get_tab_size(typed)
+        });
+        self.tab_size_callback = Some((TypeId::of::<T>(), callback));
+    }
+
+    /// Returns the indentation width to use when writing JSON settings text.
+    ///
+    /// Asks the registered tab-size callback, if there is one, and falls back
+    /// to a default of 2 when no callback is set or the callback returns `None`.
+    fn json_tab_size(&self) -> usize {
+        const DEFAULT_JSON_TAB_SIZE: usize = 2;
+
+        self.tab_size_callback
+            .as_ref()
+            .and_then(|(type_id, callback)| {
+                let registered = self.setting_values.get(type_id).unwrap();
+                callback(registered.value_for_path(None))
+            })
+            .unwrap_or(DEFAULT_JSON_TAB_SIZE)
+    }
+
     /// Set the default settings via a JSON string.
     ///
     /// The string should contain a JSON object with a default value for every setting.
@@ -277,8 +355,8 @@ impl SettingsStore {
     /// Returns an error if the string doesn't contain a valid JSON object.
     fn load_setting_map(&self, json: &str) -> Result<DeserializedSettingMap> {
         let mut map = DeserializedSettingMap {
+            untyped: parse_json_with_comments(json)?,
             typed: HashMap::default(),
-            untyped: serde_json::from_str(json)?,
         };
         for (setting_type_id, setting_value) in self.setting_values.iter() {
             Self::load_setting_in_map(*setting_type_id, setting_value, &mut map);
@@ -374,10 +452,231 @@ impl Debug for SettingsStore {
     }
 }
 
+/// Recursively diffs `old_value` against `new_value` and records the text
+/// edits needed to turn one into the other.
+///
+/// `key_path` names the location of the values being compared. It is extended
+/// while descending into objects and restored before returning. Each edit is
+/// appended to `edits` as a `(byte range, replacement)` pair.
+fn update_value_in_json_text<'a>(
+    text: &str,
+    syntax_tree: &tree_sitter::Tree,
+    key_path: &mut Vec<&'a str>,
+    tab_size: usize,
+    old_value: &'a serde_json::Value,
+    new_value: &'a serde_json::Value,
+    edits: &mut Vec<(Range<usize>, String)>,
+) {
+    match (old_value, new_value) {
+        // Two objects are compared key by key, so that comments and formatting
+        // around unchanged entries are left alone.
+        (serde_json::Value::Object(old_map), serde_json::Value::Object(new_map)) => {
+            // Keys present before: a key missing from the new object is
+            // compared against `null`.
+            for (key, old_child) in old_map.iter() {
+                let new_child = new_map.get(key).unwrap_or(&serde_json::Value::Null);
+                key_path.push(key);
+                update_value_in_json_text(
+                    text,
+                    syntax_tree,
+                    key_path,
+                    tab_size,
+                    old_child,
+                    new_child,
+                    edits,
+                );
+                key_path.pop();
+            }
+
+            // Keys that only exist in the new object are treated as replacing `null`.
+            for (key, new_child) in new_map.iter() {
+                if old_map.contains_key(key) {
+                    continue;
+                }
+                key_path.push(key);
+                update_value_in_json_text(
+                    text,
+                    syntax_tree,
+                    key_path,
+                    tab_size,
+                    &serde_json::Value::Null,
+                    new_child,
+                    edits,
+                );
+                key_path.pop();
+            }
+        }
+        // Identical values need no edit.
+        _ if old_value == new_value => {}
+        // Anything else is replaced wholesale.
+        _ => {
+            let edit = replace_value_in_json_text(
+                text,
+                syntax_tree,
+                key_path.as_slice(),
+                tab_size,
+                new_value,
+            );
+            edits.push(edit);
+        }
+    }
+}
+
+lazy_static! {
+    /// Tree-sitter query capturing the key and value node of every JSON pair.
+    static ref PAIR_QUERY: tree_sitter::Query = tree_sitter::Query::new(
+        tree_sitter_json::language(),
+        "(pair key: (string) @key value: (_) @value)",
+    )
+    .unwrap();
+}
+
+/// Computes a single edit that sets the value at `key_path` in `text` to
+/// `new_value`.
+///
+/// Walks the key/value pairs of `syntax_tree`, following `key_path` one level
+/// at a time. If the full path exists, the edit replaces the existing value.
+/// Otherwise the edit inserts the missing key (with nested objects for any
+/// remaining path segments) before the first key of the deepest object that
+/// was found, or replaces that object's whole range when it has no keys.
+///
+/// When `text` contains the string `language_overrides`, the path segment
+/// `languages` is matched against, and written as, `language_overrides`.
+///
+/// Returns the byte range of `text` to replace and the replacement string.
+fn replace_value_in_json_text(
+    text: &str,
+    syntax_tree: &tree_sitter::Tree,
+    key_path: &[&str],
+    tab_size: usize,
+    new_value: impl Serialize,
+) -> (Range<usize>, String) {
+    const LANGUAGE_OVERRIDES: &str = "language_overrides";
+    const LANGUAGES: &str = "languages";
+
+    // An empty key path addresses the whole document. Handle it up front: the
+    // loop below indexes `key_path[depth]` and would panic on the first pair.
+    if key_path.is_empty() {
+        return (0..text.len(), to_pretty_json(&new_value, tab_size, 0));
+    }
+
+    let mut cursor = tree_sitter::QueryCursor::new();
+
+    let has_language_overrides = text.contains(LANGUAGE_OVERRIDES);
+
+    let mut depth = 0;
+    let mut last_value_range = 0..0;
+    let mut first_key_start = None;
+    let mut existing_value_range = 0..text.len();
+    let matches = cursor.matches(&PAIR_QUERY, syntax_tree.root_node(), text.as_bytes());
+    for mat in matches {
+        if mat.captures.len() != 2 {
+            continue;
+        }
+
+        let key_range = mat.captures[0].node.byte_range();
+        let value_range = mat.captures[1].node.byte_range();
+
+        // Don't enter sub objects until we find an exact
+        // match for the current keypath
+        if last_value_range.contains_inclusive(&value_range) {
+            continue;
+        }
+
+        last_value_range = value_range.clone();
+
+        // Past the end of the object being searched: the key is not there.
+        if key_range.start > existing_value_range.end {
+            break;
+        }
+
+        // Remember where the first key of the current object starts, so that
+        // a missing key can be inserted in front of it.
+        first_key_start.get_or_insert_with(|| key_range.start);
+
+        let expected_key = if key_path[depth] == LANGUAGES && has_language_overrides {
+            LANGUAGE_OVERRIDES
+        } else {
+            key_path[depth]
+        };
+        // The key node's text includes its surrounding quotes.
+        let found_key = text
+            .get(key_range.clone())
+            .map_or(false, |key_text| key_text == format!("\"{expected_key}\""));
+
+        if found_key {
+            existing_value_range = value_range;
+            // Reset last value range when increasing in depth
+            last_value_range = existing_value_range.start..existing_value_range.start;
+            depth += 1;
+
+            if depth == key_path.len() {
+                break;
+            } else {
+                first_key_start = None;
+            }
+        }
+    }
+
+    // We found the exact key we want, insert the new value
+    if depth == key_path.len() {
+        let new_val = to_pretty_json(&new_value, tab_size, tab_size * depth);
+        (existing_value_range, new_val)
+    } else {
+        // We have key paths, construct the sub objects
+        let new_key = if has_language_overrides && key_path[depth] == LANGUAGES {
+            LANGUAGE_OVERRIDES
+        } else {
+            key_path[depth]
+        };
+
+        // We don't have the key, construct the nested objects
+        let mut new_value = serde_json::to_value(new_value).unwrap();
+        for key in key_path[(depth + 1)..].iter().rev() {
+            if has_language_overrides && key == &LANGUAGES {
+                new_value = serde_json::json!({ LANGUAGE_OVERRIDES.to_string(): new_value });
+            } else {
+                new_value = serde_json::json!({ key.to_string(): new_value });
+            }
+        }
+
+        if let Some(first_key_start) = first_key_start {
+            // Find the row and (byte) column of the first key, so the inserted
+            // entry can match the indentation of the existing ones.
+            let mut row = 0;
+            let mut column = 0;
+            for (ix, char) in text.char_indices() {
+                if ix == first_key_start {
+                    break;
+                }
+                if char == '\n' {
+                    row += 1;
+                    column = 0;
+                } else {
+                    column += char.len_utf8();
+                }
+            }
+
+            if row > 0 {
+                // depth is 0 based, but division needs to be 1 based.
+                let new_val = to_pretty_json(&new_value, column / (depth + 1), column);
+                // Build the indent explicitly: a format width is a minimum, so
+                // padding a space to width 0 would still emit one space.
+                let indent = " ".repeat(column);
+                let content = format!("\"{new_key}\": {new_val},\n{indent}");
+                (first_key_start..first_key_start, content)
+            } else {
+                // The object sits on the first line: keep the insertion on one line.
+                let new_val = serde_json::to_string(&new_value).unwrap();
+                let mut content = format!(r#""{new_key}": {new_val},"#);
+                content.push(' ');
+                (first_key_start..first_key_start, content)
+            }
+        } else {
+            // No keys were seen in the current object: replace its whole range
+            // with a freshly formatted object containing the new key.
+            new_value = serde_json::json!({ new_key.to_string(): new_value });
+            let indent_prefix_len = 4 * depth;
+            let mut new_val = to_pretty_json(&new_value, 4, indent_prefix_len);
+            if depth == 0 {
+                new_val.push('\n');
+            }
+
+            (existing_value_range, new_val)
+        }
+    }
+}
+
+/// Serializes `value` as pretty-printed JSON.
+///
+/// Nested levels are indented by `indent_size` spaces, and every line except
+/// the first is additionally prefixed with `indent_prefix_len` spaces, so the
+/// result can be spliced into text that is already indented. The returned
+/// string has no trailing newline.
+///
+/// Both sizes may be arbitrarily large. The padding is built on demand rather
+/// than sliced from a fixed-size buffer, which would panic for prefixes wider
+/// than the buffer.
+fn to_pretty_json(value: &impl Serialize, indent_size: usize, indent_prefix_len: usize) -> String {
+    let indent = " ".repeat(indent_size);
+    let line_prefix = " ".repeat(indent_prefix_len);
+
+    let mut output = Vec::new();
+    let mut ser = serde_json::Serializer::with_formatter(
+        &mut output,
+        serde_json::ser::PrettyFormatter::with_indent(indent.as_bytes()),
+    );
+    value.serialize(&mut ser).unwrap();
+    let text = str::from_utf8(&output).unwrap();
+
+    // The first line starts wherever the caller splices it in; only the
+    // following lines carry the extra prefix.
+    let mut lines = text.split('\n');
+    let mut adjusted_text = String::with_capacity(text.len());
+    if let Some(first_line) = lines.next() {
+        adjusted_text.push_str(first_line);
+    }
+    for line in lines {
+        adjusted_text.push('\n');
+        adjusted_text.push_str(&line_prefix);
+        adjusted_text.push_str(line);
+    }
+    adjusted_text
+}
+
+/// Deserializes `content` as JSON after stripping C-style comments from it.
+///
+/// Returns an error if the remaining text cannot be deserialized into `T`.
+fn parse_json_with_comments<T: DeserializeOwned>(content: &str) -> Result<T> {
+    let without_comments =
+        json_comments::CommentSettings::c_style().strip_comments(content.as_bytes());
+    let parsed = serde_json::from_reader(without_comments)?;
+    Ok(parsed)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use serde_derive::Deserialize;
+    use unindent::Unindent;
 
     #[test]
     fn test_settings_store_basic() {
@@ -518,7 +817,7 @@ mod tests {
     }
 
     #[test]
-    fn test_setting_store_load_before_register() {
+    fn test_setting_store_assign_json_before_register() {
         let mut store = SettingsStore::default();
         store
             .set_default_settings(
@@ -529,7 +828,7 @@ mod tests {
                         "age": 30,
                         "staff": false
                     },
-                    "key1": "x
+                    "key1": "x"
                 }"#,
             )
             .unwrap();
@@ -557,6 +856,86 @@ mod tests {
         );
     }
 
+    // Exercises `SettingsStore::update` against three inputs: a nested object
+    // with an added and a changed entry, a single-line object with unusual
+    // spacing, and empty text.
+    #[test]
+    fn test_setting_store_update() {
+        let mut store = SettingsStore::default();
+        store.register_setting::<UserSettings>();
+        store.register_setting::<LanguageSettings>();
+
+        // Entries added and updated: the existing "JSON" entry is changed in
+        // place and the new "Rust" entry is inserted before it.
+        check_settings_update::<LanguageSettings>(
+            &mut store,
+            r#"{
+                "languages": {
+                    "JSON": {
+                        "is_enabled": true
+                    }
+                }
+            }"#
+            .unindent(),
+            |settings| {
+                settings.languages.get_mut("JSON").unwrap().is_enabled = false;
+                settings
+                    .languages
+                    .insert("Rust".into(), LanguageSettingEntry { is_enabled: true });
+            },
+            r#"{
+                "languages": {
+                    "Rust": {
+                        "is_enabled": true
+                    },
+                    "JSON": {
+                        "is_enabled": false
+                    }
+                }
+            }"#
+            .unindent(),
+        );
+
+        // Weird formatting: only the changed value is rewritten, so the extra
+        // spaces and the single-line layout survive.
+        check_settings_update::<UserSettings>(
+            &mut store,
+            r#"{
+                "user":   { "age": 36, "name": "Max", "staff": true }
+            }"#
+            .unindent(),
+            |settings| settings.age = Some(37),
+            r#"{
+                "user":   { "age": 37, "name": "Max", "staff": true }
+            }"#
+            .unindent(),
+        );
+
+        // No content: the whole object is generated, with a trailing newline.
+        check_settings_update::<UserSettings>(
+            &mut store,
+            r#""#.unindent(),
+            |settings| settings.age = Some(37),
+            r#"{
+                "user": {
+                    "age": 37
+                }
+            }
+            "#
+            .unindent(),
+        );
+    }
+
+    /// Loads `old_json` as the user settings, applies `update` through
+    /// `SettingsStore::update`, and asserts that applying the returned edits
+    /// to `old_json` yields `expected_new_json`.
+    fn check_settings_update<T: Setting>(
+        store: &mut SettingsStore,
+        old_json: String,
+        update: fn(&mut T::FileContent),
+        expected_new_json: String,
+    ) {
+        // Errors from loading the old JSON are ignored.
+        let _ = store.set_user_settings(&old_json);
+
+        let edits = store.update::<T>(&old_json, update);
+
+        // Apply the edits back to front so that earlier byte ranges stay valid.
+        let mut actual_json = old_json;
+        edits
+            .into_iter()
+            .rev()
+            .for_each(|(range, replacement)| actual_json.replace_range(range, &replacement));
+
+        pretty_assertions::assert_eq!(actual_json, expected_new_json);
+    }
+
     #[derive(Debug, PartialEq, Deserialize)]
     struct UserSettings {
         name: String,
@@ -564,7 +943,7 @@ mod tests {
         staff: bool,
     }
 
-    #[derive(Serialize, Deserialize, JsonSchema)]
+    #[derive(Clone, Serialize, Deserialize, JsonSchema)]
     struct UserSettingsJson {
         name: Option<String>,
         age: Option<u32>,
@@ -600,7 +979,7 @@ mod tests {
         key2: String,
     }
 
-    #[derive(Serialize, Deserialize, JsonSchema)]
+    #[derive(Clone, Serialize, Deserialize, JsonSchema)]
     struct MultiKeySettingsJson {
         key1: Option<String>,
         key2: Option<String>,
@@ -645,4 +1024,25 @@ mod tests {
             Self::load_via_json_merge(default_value, user_values)
         }
     }
+
+    // Test fixture: a setting with no top-level key, holding a map from
+    // language name to per-language entry.
+    #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
+    struct LanguageSettings {
+        // Defaults to an empty map when the key is absent from the JSON.
+        #[serde(default)]
+        languages: HashMap<String, LanguageSettingEntry>,
+    }
+
+    // Test fixture: the value stored for one language in `LanguageSettings`.
+    #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
+    struct LanguageSettingEntry {
+        is_enabled: bool,
+    }
+
+    impl Setting for LanguageSettings {
+        // No key: the fields of this setting sit at the top level of the file.
+        const KEY: Option<&'static str> = None;
+
+        // The in-file representation is the setting type itself.
+        type FileContent = Self;
+
+        fn load(default_value: &Self, user_values: &[&Self]) -> Self {
+            Self::load_via_json_merge(default_value, user_values)
+        }
+    }
 }