Implement a topological sort for references in themes

Max Brunsfeld created

Change summary

zed/src/theme.rs | 529 ++++++++++++++++++++++++++++++++++++++++++-------
1 file changed, 452 insertions(+), 77 deletions(-)

Detailed changes

zed/src/theme.rs 🔗

@@ -9,7 +9,7 @@ use json::{Map, Value};
 use parking_lot::Mutex;
 use serde::{Deserialize, Deserializer};
 use serde_json as json;
-use std::{cmp::Ordering, collections::HashMap, sync::Arc};
+use std::{collections::HashMap, fmt, mem, sync::Arc};
 
 const DEFAULT_HIGHLIGHT_ID: HighlightId = HighlightId(u32::MAX);
 pub const DEFAULT_THEME_NAME: &'static str = "dark";
@@ -91,6 +91,30 @@ pub struct SelectorItem {
     pub label: LabelStyle,
 }
 
+// A collection of key path references together with the bookkeeping used to
+// order them topologically.
+#[derive(Default)]
+struct KeyPathReferenceSet {
+    // Every inserted reference. A reference's id is its index in this vector.
+    references: Vec<KeyPathReference>,
+    // Reference ids, kept ordered by each reference's source key path.
+    reference_ids_by_source: Vec<usize>,
+    // Reference ids, kept ordered by each reference's target key path.
+    reference_ids_by_target: Vec<usize>,
+    // Sorted `(predecessor id, successor id)` pairs: the predecessor has to be
+    // emitted before the successor.
+    dependencies: Vec<(usize, usize)>,
+    // Indexed by reference id; bumped whenever a dependency is recorded with
+    // that reference as the successor.
+    dependency_counts: Vec<usize>,
+}
+
+// A reference from one key path to another. The derived ordering follows the
+// field order: `target` is compared first, then `source`.
+#[derive(Clone, Default, PartialEq, Eq, PartialOrd, Ord)]
+struct KeyPathReference {
+    // The key path whose value is replaced by (or merged with) the value
+    // found at `source`.
+    target: KeyPath,
+    // The key path being referenced.
+    source: KeyPath,
+}
+
+// A path to a value nested inside JSON data, as a sequence of keys that is
+// followed starting from the top-level object.
+#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)]
+struct KeyPath(Vec<Key>);
+
+// One step of a `KeyPath`.
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
+enum Key {
+    // An index into a JSON array.
+    Array(usize),
+    // A field name within a JSON object.
+    Object(String),
+}
+
 impl Default for Editor {
     fn default() -> Self {
         Self {
@@ -178,39 +202,33 @@ impl ThemeRegistry {
         // Find all of the key path references in the object, and then sort them according
         // to their dependencies.
         if evaluate_references {
-            let mut references = Vec::new();
-            let mut key_path = Vec::new();
+            let mut key_path = KeyPath::default();
+            let mut references = KeyPathReferenceSet::default();
             for (key, value) in theme_data.iter() {
-                key_path.push(Key::Object(key.clone()));
+                key_path.0.push(Key::Object(key.clone()));
                 find_references(value, &mut key_path, &mut references);
-                key_path.pop();
+                key_path.0.pop();
             }
-            sort_references(&mut references);
+            let sorted_references = references
+                .top_sort()
+                .map_err(|key_paths| anyhow!("cycle for key paths: {:?}", key_paths))?;
 
             // Now update objects to include the fields of objects they extend
-            for KeyPathReference {
-                source_path,
-                target_path,
-            } in references
-            {
-                let source = value_at(&mut theme_data, &source_path).cloned();
-                let target = value_at(&mut theme_data, &target_path).unwrap();
-                if let Some(source) = source {
+            for KeyPathReference { source, target } in sorted_references {
+                if let Some(source) = value_at(&mut theme_data, &source).cloned() {
+                    let target = value_at(&mut theme_data, &target).unwrap();
                     if let Value::Object(target_object) = target.take() {
                         if let Value::Object(mut source_object) = source {
                             deep_merge_json(&mut source_object, target_object);
                             *target = Value::Object(source_object);
                         } else {
-                            Err(anyhow!(
-                                "extended key path {:?} is not an object",
-                                source_path
-                            ))?;
+                            Err(anyhow!("extended key path {} is not an object", source))?;
                         }
                     } else {
                         *target = source;
                     }
                 } else {
-                    Err(anyhow!("invalid key path {:?}", source_path))?;
+                    Err(anyhow!("invalid key path '{}'", source))?;
                 }
             }
         }
@@ -281,6 +299,295 @@ impl HighlightMap {
     }
 }
 
+impl KeyPathReferenceSet {
+    // Add a reference to the set and record its ordering constraints against
+    // every reference that is already present. The new reference's id is its
+    // index in `references`.
+    fn insert(&mut self, reference: KeyPathReference) {
+        let id = self.references.len();
+        let source_ix = self
+            .reference_ids_by_source
+            .binary_search_by_key(&&reference.source, |id| &self.references[*id].source)
+            .unwrap_or_else(|i| i);
+        let target_ix = self
+            .reference_ids_by_target
+            .binary_search_by_key(&&reference.target, |id| &self.references[*id].target)
+            .unwrap_or_else(|i| i);
+
+        // Compute the dependencies before the new id is added to the sorted
+        // id lists, so that the lookups only see previously inserted references.
+        self.populate_dependencies(id, &reference);
+        self.reference_ids_by_source.insert(source_ix, id);
+        self.reference_ids_by_target.insert(target_ix, id);
+        self.references.push(reference);
+    }
+
+    // Consume the set and return its references ordered so that each one comes
+    // after every reference it depends on. References that are ready at the
+    // same time are emitted according to `KeyPathReference`'s derived ordering.
+    //
+    // If the references contain cycles, returns `Err` with the sorted target
+    // and source key paths of the references directly involved in them.
+    fn top_sort(mut self) -> Result<Vec<KeyPathReference>, Vec<KeyPath>> {
+        let mut results = Vec::with_capacity(self.references.len());
+        let mut root_ids = Vec::with_capacity(self.references.len());
+
+        // Find the initial set of references that have no dependencies.
+        for (id, dep_count) in self.dependency_counts.iter().enumerate() {
+            if *dep_count == 0 {
+                root_ids.push(id);
+            }
+        }
+
+        root_ids.sort_by_key(|id| &self.references[*id]);
+
+        // `root_ids[..results.len()]` have already been emitted; the rest of
+        // `root_ids` is the queue of references that are ready to be emitted.
+        while results.len() < root_ids.len() {
+            let root_id = root_ids[results.len()];
+            let root = mem::take(&mut self.references[root_id]);
+            results.push(root);
+
+            // Remove this reference as a dependency from any of its dependent references.
+            if let Ok(dep_ix) = self
+                .dependencies
+                .binary_search_by_key(&root_id, |edge| edge.0)
+            {
+                // The binary search may land anywhere within the run of edges
+                // that start at `root_id`, so widen it in both directions.
+                let mut first_dep_ix = dep_ix;
+                let mut last_dep_ix = dep_ix + 1;
+                while first_dep_ix > 0 && self.dependencies[first_dep_ix - 1].0 == root_id {
+                    first_dep_ix -= 1;
+                }
+                while last_dep_ix < self.dependencies.len()
+                    && self.dependencies[last_dep_ix].0 == root_id
+                {
+                    last_dep_ix += 1;
+                }
+
+                // If any reference no longer has any dependencies, then mark it as a root.
+                // A count reaches zero at most once, so the id cannot already be queued;
+                // its position is settled by the sort below.
+                for (_, successor_id) in self.dependencies.drain(first_dep_ix..last_dep_ix) {
+                    self.dependency_counts[successor_id] -= 1;
+                    if self.dependency_counts[successor_id] == 0 {
+                        root_ids.push(successor_id);
+                    }
+                }
+
+                // Keep the not-yet-emitted roots ordered by reference.
+                root_ids[results.len()..].sort_by_key(|id| &self.references[*id]);
+            }
+        }
+
+        // If any references never became roots, then there are reference cycles
+        // in the set. Return an error containing all of the key paths that are
+        // directly involved in cycles.
+        if results.len() < self.references.len() {
+            let mut cycle_ref_ids = (0..self.references.len())
+                .filter(|id| !root_ids.contains(id))
+                .collect::<Vec<_>>();
+
+            // Iteratively remove any reference that no remaining reference
+            // depends on (it is not the predecessor of any edge), together with
+            // the edges that point at it, so that the error will only indicate
+            // which key paths are directly involved in the cycles.
+            let mut done = false;
+            while !done {
+                done = true;
+                cycle_ref_ids.retain(|id| {
+                    if self.dependencies.iter().any(|dep| dep.0 == *id) {
+                        true
+                    } else {
+                        done = false;
+                        self.dependencies.retain(|dep| dep.1 != *id);
+                        false
+                    }
+                });
+            }
+
+            let mut cycle_key_paths = Vec::new();
+            for id in cycle_ref_ids {
+                let reference = &self.references[id];
+                cycle_key_paths.push(reference.target.clone());
+                cycle_key_paths.push(reference.source.clone());
+            }
+            cycle_key_paths.sort_unstable();
+            return Err(cycle_key_paths);
+        }
+
+        Ok(results)
+    }
+
+    // Record the ordering constraints between a reference that is about to be
+    // inserted (with id `new_id`) and all of the existing references.
+    fn populate_dependencies(&mut self, new_id: usize, new_reference: &KeyPathReference) {
+        self.dependency_counts.push(0);
+
+        // If an existing reference's source path starts with the new reference's
+        // target path, then insert this new reference before that existing reference.
+        for id in Self::reference_ids_for_key_path(
+            &new_reference.target.0,
+            &self.references,
+            &self.reference_ids_by_source,
+            KeyPathReference::source,
+            KeyPath::starts_with,
+        ) {
+            Self::add_dependency(
+                (new_id, id),
+                &mut self.dependencies,
+                &mut self.dependency_counts,
+            );
+        }
+
+        // If an existing reference's target path starts with the new reference's
+        // source path, then insert this new reference after that existing reference.
+        for id in Self::reference_ids_for_key_path(
+            &new_reference.source.0,
+            &self.references,
+            &self.reference_ids_by_target,
+            KeyPathReference::target,
+            KeyPath::starts_with,
+        ) {
+            Self::add_dependency(
+                (id, new_id),
+                &mut self.dependencies,
+                &mut self.dependency_counts,
+            );
+        }
+
+        // If an existing reference's source path is a prefix of the new reference's
+        // target path, then insert this new reference before that existing reference.
+        for prefix in new_reference.target.prefixes() {
+            for id in Self::reference_ids_for_key_path(
+                prefix,
+                &self.references,
+                &self.reference_ids_by_source,
+                KeyPathReference::source,
+                PartialEq::eq,
+            ) {
+                Self::add_dependency(
+                    (new_id, id),
+                    &mut self.dependencies,
+                    &mut self.dependency_counts,
+                );
+            }
+        }
+
+        // If an existing reference's target path is a prefix of the new reference's
+        // source path, then insert this new reference after that existing reference.
+        for prefix in new_reference.source.prefixes() {
+            for id in Self::reference_ids_for_key_path(
+                prefix,
+                &self.references,
+                &self.reference_ids_by_target,
+                KeyPathReference::target,
+                PartialEq::eq,
+            ) {
+                Self::add_dependency(
+                    (id, new_id),
+                    &mut self.dependencies,
+                    &mut self.dependency_counts,
+                );
+            }
+        }
+    }
+
+    // Find all existing references that satisfy a given predicate with respect
+    // to a given key path. Use a sorted array of reference ids in order to avoid
+    // performing unnecessary comparisons.
+    //
+    // `sorted_reference_ids` must be ordered by the key path that
+    // `reference_attribute` extracts. The result is the contiguous run of ids
+    // around the key path's sorted position for which `predicate` holds.
+    fn reference_ids_for_key_path<'a>(
+        key_path: &[Key],
+        references: &[KeyPathReference],
+        sorted_reference_ids: &'a [usize],
+        reference_attribute: impl Fn(&KeyPathReference) -> &KeyPath,
+        predicate: impl Fn(&KeyPath, &[Key]) -> bool,
+    ) -> impl Iterator<Item = usize> + 'a {
+        let ix = sorted_reference_ids
+            .binary_search_by_key(&key_path, |id| &reference_attribute(&references[*id]).0)
+            .unwrap_or_else(|i| i);
+
+        // Extend the run backwards for as long as the predicate holds.
+        let mut start_ix = ix;
+        while start_ix > 0 {
+            let reference_id = sorted_reference_ids[start_ix - 1];
+            let reference = &references[reference_id];
+            if !predicate(&reference_attribute(reference), key_path) {
+                break;
+            }
+            start_ix -= 1;
+        }
+
+        // Extend the run forwards for as long as the predicate holds.
+        let mut end_ix = ix;
+        while end_ix < sorted_reference_ids.len() {
+            let reference_id = sorted_reference_ids[end_ix];
+            let reference = &references[reference_id];
+            if !predicate(&reference_attribute(reference), key_path) {
+                break;
+            }
+            end_ix += 1;
+        }
+
+        sorted_reference_ids[start_ix..end_ix].iter().copied()
+    }
+
+    // Record that `predecessor` must be emitted before `successor`, keeping
+    // `dependencies` sorted. The successor's count is only bumped when the edge
+    // is new, so that the count always equals the number of stored edges that
+    // point at it; otherwise a repeated edge could never be fully released by
+    // `top_sort` and would be reported as a cycle.
+    fn add_dependency(
+        (predecessor, successor): (usize, usize),
+        dependencies: &mut Vec<(usize, usize)>,
+        dependency_counts: &mut Vec<usize>,
+    ) {
+        let dependency = (predecessor, successor);
+        if let Err(i) = dependencies.binary_search(&dependency) {
+            dependencies.insert(i, dependency);
+            dependency_counts[successor] += 1;
+        }
+    }
+}
+
+impl KeyPathReference {
+    // Borrow the key path that this reference points at.
+    fn source(&self) -> &KeyPath {
+        let Self { source, .. } = self;
+        source
+    }
+
+    // Borrow the key path at which this reference appears.
+    fn target(&self) -> &KeyPath {
+        let Self { target, .. } = self;
+        target
+    }
+}
+
+impl KeyPath {
+    // Build a key path from a dot-separated string. Every segment becomes an
+    // object key; array indices are never produced here.
+    fn new(string: &str) -> Self {
+        let mut keys = Vec::new();
+        for segment in string.split(".") {
+            keys.push(Key::Object(segment.to_string()));
+        }
+        Self(keys)
+    }
+
+    // Whether this key path begins with (or is equal to) the keys in `other`.
+    fn starts_with(&self, other: &[Key]) -> bool {
+        let Self(keys) = self;
+        keys.starts_with(other)
+    }
+
+    // Iterate over this key path's non-empty proper prefixes, shortest first.
+    // The full key path itself is not included.
+    fn prefixes(&self) -> impl Iterator<Item = &[Key]> {
+        let Self(keys) = self;
+        (1..keys.len()).map(move |prefix_len| &keys[..prefix_len])
+    }
+}
+
+// Allows a key path to be compared directly against a borrowed slice of keys.
+impl PartialEq<[Key]> for KeyPath {
+    fn eq(&self, other: &[Key]) -> bool {
+        self.0.as_slice() == other
+    }
+}
+
+// Compact debug form: `KeyPathReference { <target> <- <source> }`, with both
+// key paths rendered through their `Display` implementation.
+impl fmt::Debug for KeyPathReference {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let Self { target, source } = self;
+        f.write_fmt(format_args!(
+            "KeyPathReference {{ {} <- {} }}",
+            target, source
+        ))
+    }
+}
+
+// Renders object keys separated by `.` (no separator before the first key)
+// and array indices as `[index]`.
+impl fmt::Display for KeyPath {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.0
+            .iter()
+            .enumerate()
+            .try_for_each(|(position, key)| match key {
+                Key::Array(index) => write!(f, "[{}]", index),
+                Key::Object(name) => {
+                    // Only keys after the first get a leading separator.
+                    if position > 0 {
+                        ".".fmt(f)?;
+                    }
+                    name.fmt(f)
+                }
+            })
+    }
+}
+
 impl Default for HighlightMap {
     fn default() -> Self {
         Self(Arc::new([]))
@@ -307,70 +614,36 @@ fn deep_merge_json(base: &mut Map<String, Value>, extension: Map<String, Value>)
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq)]
-enum Key {
-    Array(usize),
-    Object(String),
-}
-
-#[derive(Debug, PartialEq, Eq)]
-struct KeyPathReference {
-    source_path: Vec<Key>,
-    target_path: Vec<Key>,
-}
-
-impl PartialOrd for KeyPathReference {
-    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
-        if self.target_path.starts_with(&other.source_path)
-            || other.source_path.starts_with(&self.target_path)
-        {
-            Some(Ordering::Less)
-        } else if other.target_path.starts_with(&self.source_path)
-            || self.source_path.starts_with(&other.target_path)
-        {
-            Some(Ordering::Greater)
-        } else {
-            None
-        }
-    }
-}
-
-fn find_references(value: &Value, key_path: &mut Vec<Key>, references: &mut Vec<KeyPathReference>) {
+fn find_references(value: &Value, key_path: &mut KeyPath, references: &mut KeyPathReferenceSet) {
     match value {
         Value::Array(vec) => {
             for (ix, value) in vec.iter().enumerate() {
-                key_path.push(Key::Array(ix));
+                key_path.0.push(Key::Array(ix));
                 find_references(value, key_path, references);
-                key_path.pop();
+                key_path.0.pop();
             }
         }
         Value::Object(map) => {
             for (key, value) in map.iter() {
                 if key == "extends" {
                     if let Some(source_path) = value.as_str().and_then(|s| s.strip_prefix("$")) {
-                        references.push(KeyPathReference {
-                            source_path: source_path
-                                .split(".")
-                                .map(|key| Key::Object(key.to_string()))
-                                .collect(),
-                            target_path: key_path.clone(),
+                        references.insert(KeyPathReference {
+                            source: KeyPath::new(source_path),
+                            target: key_path.clone(),
                         });
                     }
                 } else {
-                    key_path.push(Key::Object(key.to_string()));
+                    key_path.0.push(Key::Object(key.to_string()));
                     find_references(value, key_path, references);
-                    key_path.pop();
+                    key_path.0.pop();
                 }
             }
         }
         Value::String(string) => {
             if let Some(source_path) = string.strip_prefix("$") {
-                references.push(KeyPathReference {
-                    source_path: source_path
-                        .split(".")
-                        .map(|key| Key::Object(key.to_string()))
-                        .collect(),
-                    target_path: key_path.clone(),
+                references.insert(KeyPathReference {
+                    source: KeyPath::new(source_path),
+                    target: key_path.clone(),
                 });
             }
         }
@@ -378,18 +651,8 @@ fn find_references(value: &Value, key_path: &mut Vec<Key>, references: &mut Vec<
     }
 }
 
-fn sort_references(references: &mut Vec<KeyPathReference>) {
-    for i in 0..references.len() {
-        for j in (i + 1)..references.len() {
-            if let Some(Ordering::Greater) = &references[i].partial_cmp(&references[j]) {
-                references.swap(i, j)
-            }
-        }
-    }
-}
-
-fn value_at<'a>(object: &'a mut Map<String, Value>, key_path: &Vec<Key>) -> Option<&'a mut Value> {
-    let mut key_path = key_path.iter();
+fn value_at<'a>(object: &'a mut Map<String, Value>, key_path: &KeyPath) -> Option<&'a mut Value> {
+    let mut key_path = key_path.0.iter();
     if let Some(Key::Object(first_key)) = key_path.next() {
         let mut cur_value = object.get_mut(first_key);
         for key in key_path {
@@ -430,9 +693,10 @@ where
 
 #[cfg(test)]
 mod tests {
-    use crate::assets::Assets;
+    use rand::{prelude::StdRng, Rng};
 
     use super::*;
+    use crate::assets::Assets;
 
     #[test]
     fn test_bundled_themes() {
@@ -565,6 +829,117 @@ mod tests {
         assert_eq!(theme.highlight_name(map.get(2)), Some("variable.builtin"));
     }
 
+    #[test]
+    fn test_key_path_reference_set_simple() {
+        // `(target, source)` pairs in the order in which they are inserted.
+        let inserted_rows = [
+            ("r", "a"),
+            ("a.b.c", "d"),
+            ("d.e", "f"),
+            ("t.u", "v"),
+            ("v.w", "x"),
+            ("v.y", "x"),
+            ("d.h", "i"),
+            ("v.z", "x"),
+            ("f.g", "d.h"),
+        ];
+        // The same pairs in the order that the topological sort should produce.
+        let sorted_rows = [
+            ("d.h", "i"),
+            ("f.g", "d.h"),
+            ("d.e", "f"),
+            ("a.b.c", "d"),
+            ("r", "a"),
+            ("v.w", "x"),
+            ("v.y", "x"),
+            ("v.z", "x"),
+            ("t.u", "v"),
+        ];
+
+        let mut set = KeyPathReferenceSet::default();
+        build_refs(&inserted_rows).for_each(|reference| set.insert(reference));
+
+        let expected: Vec<_> = build_refs(&sorted_rows).collect();
+        assert_eq!(set.top_sort().unwrap(), expected);
+    }
+
+    #[test]
+    fn test_key_path_reference_set_with_cycles() {
+        // The last three rows form a cycle; the first two merely depend on it.
+        let rows = [
+            ("x", "a.b"),
+            ("y", "x.c"),
+            ("a.b.c", "d.e"),
+            ("d.e.f", "g.h"),
+            ("g.h.i", "a"),
+        ];
+
+        let mut set = KeyPathReferenceSet::default();
+        for reference in build_refs(&rows) {
+            set.insert(reference);
+        }
+
+        // Only the key paths of the references inside the cycle are reported.
+        let expected_key_paths = vec![
+            KeyPath::new("a"),
+            KeyPath::new("a.b.c"),
+            KeyPath::new("d.e"),
+            KeyPath::new("d.e.f"),
+            KeyPath::new("g.h"),
+            KeyPath::new("g.h.i"),
+        ];
+        let cycle_key_paths = set.top_sort().unwrap_err();
+        assert_eq!(cycle_key_paths, expected_key_paths);
+    }
+
+    // Each example lists `(target, source)` pairs in the order that the
+    // topological sort is expected to produce. The pairs are inserted in a
+    // random order, and the sorted output must match the listed order.
+    #[gpui::test(iterations = 20)]
+    async fn test_key_path_reference_set_random(mut rng: StdRng) {
+        let examples: &[&[_]] = &[
+            &[
+                ("n.d.h", "i"),
+                ("f.g", "n.d.h"),
+                ("n.d.e", "f"),
+                ("a.b.c", "n.d"),
+                ("r", "a"),
+                ("v.w", "x"),
+                ("v.y", "x"),
+                ("v.z", "x"),
+                ("t.u", "v"),
+            ],
+            &[
+                ("w.x.y.z", "t.u.z"),
+                ("x", "w.x"),
+                ("a.b.c1", "x.b1.c"),
+                ("a.b.c2", "x.b2.c"),
+            ],
+            &[
+                ("x.y", "m.n.n.o.q"),
+                ("x.y.z", "m.n.n.o.p"),
+                ("u.v.w", "x.y.z"),
+                ("a.b.c.d", "u.v"),
+                ("a.b.c.d.e", "u.v"),
+                ("a.b.c.d.f", "u.v"),
+                ("a.b.c.d.g", "u.v"),
+            ],
+        ];
+
+        for example in examples {
+            let expected_references = build_refs(example).collect::<Vec<_>>();
+            let mut input_references = expected_references.clone();
+            // Shuffle by sorting on a random key. The key must be computed
+            // exactly once per element: a key that changes between comparisons
+            // is not a consistent ordering, which the sort is entitled to
+            // reject.
+            input_references.sort_by_cached_key(|_| rng.gen_range(0..1000));
+            let mut reference_set = KeyPathReferenceSet::default();
+            for reference in input_references {
+                reference_set.insert(reference);
+            }
+            assert_eq!(reference_set.top_sort().unwrap(), expected_references);
+        }
+    }
+
+    // Turn `(target, source)` string pairs into references, parsing each
+    // string with `KeyPath::new`.
+    fn build_refs<'a>(rows: &'a [(&str, &str)]) -> impl Iterator<Item = KeyPathReference> + 'a {
+        rows.iter().map(|row| {
+            let target = KeyPath::new(row.0);
+            let source = KeyPath::new(row.1);
+            KeyPathReference { target, source }
+        })
+    }
+
     struct TestAssets(&'static [(&'static str, &'static str)]);
 
     impl AssetSource for TestAssets {