Reparse unknown injection ranges in buffer when adding a new language

Antonio Scandurra created

Change summary

crates/language/src/buffer.rs     |  33 ++++---
crates/language/src/language.rs   |   7 +
crates/language/src/syntax_map.rs | 151 ++++++++++++++++++++------------
crates/project/src/project.rs     |  14 ++
4 files changed, 129 insertions(+), 76 deletions(-)

Detailed changes

crates/language/src/buffer.rs 🔗

@@ -797,12 +797,16 @@ impl Buffer {
         self.parsing_in_background
     }
 
+    pub fn contains_unknown_injections(&self) -> bool {
+        self.syntax_map.lock().contains_unknown_injections()
+    }
+
     #[cfg(test)]
     pub fn set_sync_parse_timeout(&mut self, timeout: Duration) {
         self.sync_parse_timeout = timeout;
     }
 
-    fn reparse(&mut self, cx: &mut ModelContext<Self>) {
+    pub fn reparse(&mut self, cx: &mut ModelContext<Self>) {
         if self.parsing_in_background {
             return;
         }
@@ -819,13 +823,13 @@ impl Buffer {
         syntax_map.interpolate(&text);
         let language_registry = syntax_map.language_registry();
         let mut syntax_snapshot = syntax_map.snapshot();
-        let syntax_map_version = syntax_map.parsed_version();
         drop(syntax_map);
 
         let parse_task = cx.background().spawn({
             let language = language.clone();
+            let language_registry = language_registry.clone();
             async move {
-                syntax_snapshot.reparse(&syntax_map_version, &text, language_registry, language);
+                syntax_snapshot.reparse(&text, language_registry, language);
                 syntax_snapshot
             }
         });
@@ -835,7 +839,7 @@ impl Buffer {
             .block_with_timeout(self.sync_parse_timeout, parse_task)
         {
             Ok(new_syntax_snapshot) => {
-                self.did_finish_parsing(new_syntax_snapshot, parsed_version, cx);
+                self.did_finish_parsing(new_syntax_snapshot, cx);
                 return;
             }
             Err(parse_task) => {
@@ -847,9 +851,15 @@ impl Buffer {
                             this.language.as_ref().map_or(true, |current_language| {
                                 !Arc::ptr_eq(&language, current_language)
                             });
-                        let parse_again =
-                            this.version.changed_since(&parsed_version) || grammar_changed;
-                        this.did_finish_parsing(new_syntax_map, parsed_version, cx);
+                        let language_registry_changed = new_syntax_map
+                            .contains_unknown_injections()
+                            && language_registry.map_or(false, |registry| {
+                                registry.version() != new_syntax_map.language_registry_version()
+                            });
+                        let parse_again = language_registry_changed
+                            || grammar_changed
+                            || this.version.changed_since(&parsed_version);
+                        this.did_finish_parsing(new_syntax_map, cx);
                         this.parsing_in_background = false;
                         if parse_again {
                             this.reparse(cx);
@@ -861,14 +871,9 @@ impl Buffer {
         }
     }
 
-    fn did_finish_parsing(
-        &mut self,
-        syntax_snapshot: SyntaxSnapshot,
-        version: clock::Global,
-        cx: &mut ModelContext<Self>,
-    ) {
+    fn did_finish_parsing(&mut self, syntax_snapshot: SyntaxSnapshot, cx: &mut ModelContext<Self>) {
         self.parse_count += 1;
-        self.syntax_map.lock().did_parse(syntax_snapshot, version);
+        self.syntax_map.lock().did_parse(syntax_snapshot);
         self.request_autoindent(cx);
         cx.emit(Event::Reparsed);
         cx.notify();

crates/language/src/language.rs 🔗

@@ -422,6 +422,7 @@ pub struct LanguageRegistry {
     >,
     subscription: RwLock<(watch::Sender<()>, watch::Receiver<()>)>,
     theme: RwLock<Option<Arc<Theme>>>,
+    version: AtomicUsize,
 }
 
 impl LanguageRegistry {
@@ -436,6 +437,7 @@ impl LanguageRegistry {
             lsp_binary_paths: Default::default(),
             subscription: RwLock::new(watch::channel()),
             theme: Default::default(),
+            version: Default::default(),
         }
     }
 
@@ -449,6 +451,7 @@ impl LanguageRegistry {
             language.set_theme(&theme.editor.syntax);
         }
         self.languages.write().push(language);
+        self.version.fetch_add(1, SeqCst);
         *self.subscription.write().0.borrow_mut() = ();
     }
 
@@ -456,6 +459,10 @@ impl LanguageRegistry {
         self.subscription.read().1.clone()
     }
 
+    pub fn version(&self) -> usize {
+        self.version.load(SeqCst)
+    }
+
     pub fn set_theme(&self, theme: Arc<Theme>) {
         *self.theme.write() = Some(theme.clone());
         for language in self.languages.read().iter() {

crates/language/src/syntax_map.rs 🔗

@@ -27,8 +27,6 @@ lazy_static! {
 
 #[derive(Default)]
 pub struct SyntaxMap {
-    parsed_version: clock::Global,
-    interpolated_version: clock::Global,
     snapshot: SyntaxSnapshot,
     language_registry: Option<Arc<LanguageRegistry>>,
 }
@@ -36,6 +34,9 @@ pub struct SyntaxMap {
 #[derive(Clone, Default)]
 pub struct SyntaxSnapshot {
     layers: SumTree<SyntaxLayer>,
+    parsed_version: clock::Global,
+    interpolated_version: clock::Global,
+    language_registry_version: usize,
 }
 
 #[derive(Default)]
@@ -134,7 +135,7 @@ struct SyntaxLayerSummary {
     range: Range<Anchor>,
     last_layer_range: Range<Anchor>,
     last_layer_language: Option<usize>,
-    contains_pending_layer: bool,
+    contains_unknown_injections: bool,
 }
 
 #[derive(Clone, Debug)]
@@ -218,30 +219,17 @@ impl SyntaxMap {
         self.language_registry.clone()
     }
 
-    pub fn parsed_version(&self) -> clock::Global {
-        self.parsed_version.clone()
-    }
-
     pub fn interpolate(&mut self, text: &BufferSnapshot) {
-        self.snapshot.interpolate(&self.interpolated_version, text);
-        self.interpolated_version = text.version.clone();
+        self.snapshot.interpolate(text);
     }
 
     #[cfg(test)]
     pub fn reparse(&mut self, language: Arc<Language>, text: &BufferSnapshot) {
-        self.snapshot.reparse(
-            &self.parsed_version,
-            text,
-            self.language_registry.clone(),
-            language,
-        );
-        self.parsed_version = text.version.clone();
-        self.interpolated_version = text.version.clone();
+        self.snapshot
+            .reparse(text, self.language_registry.clone(), language);
     }
 
-    pub fn did_parse(&mut self, snapshot: SyntaxSnapshot, version: clock::Global) {
-        self.interpolated_version = version.clone();
-        self.parsed_version = version;
+    pub fn did_parse(&mut self, snapshot: SyntaxSnapshot) {
         self.snapshot = snapshot;
     }
 
@@ -255,10 +243,12 @@ impl SyntaxSnapshot {
         self.layers.is_empty()
     }
 
-    pub fn interpolate(&mut self, from_version: &clock::Global, text: &BufferSnapshot) {
+    fn interpolate(&mut self, text: &BufferSnapshot) {
         let edits = text
-            .anchored_edits_since::<(usize, Point)>(&from_version)
+            .anchored_edits_since::<(usize, Point)>(&self.interpolated_version)
             .collect::<Vec<_>>();
+        self.interpolated_version = text.version().clone();
+
         if edits.is_empty() {
             return;
         }
@@ -372,12 +362,53 @@ impl SyntaxSnapshot {
 
     pub fn reparse(
         &mut self,
-        from_version: &clock::Global,
         text: &BufferSnapshot,
         registry: Option<Arc<LanguageRegistry>>,
         root_language: Arc<Language>,
     ) {
-        let edits = text.edits_since::<usize>(from_version).collect::<Vec<_>>();
+        let edit_ranges = text
+            .edits_since::<usize>(&self.parsed_version)
+            .map(|edit| edit.new)
+            .collect::<Vec<_>>();
+        self.reparse_with_ranges(text, root_language.clone(), edit_ranges, registry.as_ref());
+
+        if let Some(registry) = registry {
+            if registry.version() != self.language_registry_version {
+                let mut resolved_injection_ranges = Vec::new();
+                let mut cursor = self
+                    .layers
+                    .filter::<_, ()>(|summary| summary.contains_unknown_injections);
+                cursor.next(text);
+                while let Some(layer) = cursor.item() {
+                    let SyntaxLayerContent::Pending { language_name } = &layer.content else { unreachable!() };
+                    if language_for_injection(language_name, &registry).is_some() {
+                        resolved_injection_ranges.push(layer.range.to_offset(text));
+                    }
+
+                    cursor.next(text);
+                }
+                drop(cursor);
+
+                if !resolved_injection_ranges.is_empty() {
+                    self.reparse_with_ranges(
+                        text,
+                        root_language,
+                        resolved_injection_ranges,
+                        Some(&registry),
+                    );
+                }
+                self.language_registry_version = registry.version();
+            }
+        }
+    }
+
+    fn reparse_with_ranges(
+        &mut self,
+        text: &BufferSnapshot,
+        root_language: Arc<Language>,
+        invalidated_ranges: Vec<Range<usize>>,
+        registry: Option<&Arc<LanguageRegistry>>,
+    ) {
         let max_depth = self.layers.summary().max_depth;
         let mut cursor = self.layers.cursor::<SyntaxLayerSummary>();
         cursor.next(&text);
@@ -503,7 +534,7 @@ impl SyntaxSnapshot {
                             Some(old_tree.clone()),
                         );
                         changed_ranges = join_ranges(
-                            edits.iter().map(|e| e.new.clone()).filter(|range| {
+                            invalidated_ranges.iter().cloned().filter(|range| {
                                 range.start <= step_end_byte && range.end >= step_start_byte
                             }),
                             old_tree.changed_ranges(&tree).map(|r| {
@@ -570,6 +601,8 @@ impl SyntaxSnapshot {
 
         drop(cursor);
         self.layers = layers;
+        self.interpolated_version = text.version.clone();
+        self.parsed_version = text.version.clone();
     }
 
     pub fn single_tree_captures<'a>(
@@ -665,25 +698,12 @@ impl SyntaxSnapshot {
         })
     }
 
-    pub fn unknown_injection_languages<'a>(
-        &'a self,
-        buffer: &'a BufferSnapshot,
-    ) -> impl 'a + Iterator<Item = &Arc<str>> {
-        let mut cursor = self
-            .layers
-            .filter::<_, ()>(|summary| summary.contains_pending_layer);
-        cursor.next(buffer);
-        iter::from_fn(move || {
-            while let Some(layer) = cursor.item() {
-                if let SyntaxLayerContent::Pending { language_name } = &layer.content {
-                    cursor.next(buffer);
-                    return Some(language_name);
-                } else {
-                    cursor.next(buffer);
-                }
-            }
-            None
-        })
+    pub fn contains_unknown_injections(&self) -> bool {
+        self.layers.summary().contains_unknown_injections
+    }
+
+    pub fn language_registry_version(&self) -> usize {
+        self.language_registry_version
     }
 }
 
@@ -1058,7 +1078,7 @@ fn get_injections(
     combined_injection_ranges.clear();
     for pattern in &config.patterns {
         if let (Some(language_name), true) = (pattern.language.as_ref(), pattern.combined) {
-            if let Some(language) = language_registry.language_for_name(language_name) {
+            if let Some(language) = language_for_injection(language_name, language_registry) {
                 combined_injection_ranges.insert(language, Vec::new());
             }
         }
@@ -1103,9 +1123,7 @@ fn get_injections(
             };
 
             if let Some(language_name) = language_name {
-                let language = language_registry
-                    .language_for_name(&language_name)
-                    .or_else(|| language_registry.language_for_extension(&language_name));
+                let language = language_for_injection(&language_name, language_registry);
                 let range = text.anchor_before(step_range.start)..text.anchor_after(step_range.end);
                 if let Some(language) = language {
                     if combined {
@@ -1153,6 +1171,15 @@ fn get_injections(
     }
 }
 
+fn language_for_injection(
+    language_name: &str,
+    language_registry: &LanguageRegistry,
+) -> Option<Arc<Language>> {
+    language_registry
+        .language_for_name(language_name)
+        .or_else(|| language_registry.language_for_extension(language_name))
+}
+
 fn splice_included_ranges(
     mut ranges: Vec<tree_sitter::Range>,
     changed_ranges: &[Range<usize>],
@@ -1379,7 +1406,7 @@ impl Default for SyntaxLayerSummary {
             range: Anchor::MAX..Anchor::MIN,
             last_layer_range: Anchor::MIN..Anchor::MAX,
             last_layer_language: None,
-            contains_pending_layer: false,
+            contains_unknown_injections: false,
         }
     }
 }
@@ -1401,7 +1428,7 @@ impl sum_tree::Summary for SyntaxLayerSummary {
         }
         self.last_layer_range = other.last_layer_range.clone();
         self.last_layer_language = other.last_layer_language;
-        self.contains_pending_layer |= other.contains_pending_layer;
+        self.contains_unknown_injections |= other.contains_unknown_injections;
     }
 }
 
@@ -1452,7 +1479,7 @@ impl sum_tree::Item for SyntaxLayer {
             range: self.range.clone(),
             last_layer_range: self.range.clone(),
             last_layer_language: self.content.language_id(),
-            contains_pending_layer: matches!(self.content, SyntaxLayerContent::Pending { .. }),
+            contains_unknown_injections: matches!(self.content, SyntaxLayerContent::Pending { .. }),
         }
     }
 }
@@ -1744,7 +1771,7 @@ mod tests {
 
         // Replace Ruby with a language that hasn't been loaded yet.
         let macro_name_range = range_for_text(&buffer, "ruby");
-        buffer.edit([(macro_name_range, "erb")]);
+        buffer.edit([(macro_name_range, "html")]);
         syntax_map.interpolate(&buffer);
         syntax_map.reparse(markdown.clone(), &buffer);
         assert_layers_for_range(
@@ -1755,14 +1782,20 @@ mod tests {
                 "...(fenced_code_block (fenced_code_block_delimiter) (info_string (language)) (code_fence_content) (fenced_code_block_delimiter..."
             ],
         );
-        assert_eq!(
-            syntax_map
-                .unknown_injection_languages(&buffer)
-                .collect::<Vec<_>>(),
-            vec![&Arc::from("erb")]
-        );
+        assert!(syntax_map.contains_unknown_injections());
 
-        registry.add(Arc::new(erb_lang()));
+        registry.add(Arc::new(html_lang()));
+        syntax_map.reparse(markdown.clone(), &buffer);
+        assert_layers_for_range(
+            &syntax_map,
+            &buffer,
+            Point::new(3, 0)..Point::new(3, 0),
+            &[
+                "...(fenced_code_block (fenced_code_block_delimiter) (info_string (language)) (code_fence_content) (fenced_code_block_delimiter...",
+                "(fragment (text))",
+            ],
+        );
+        assert!(!syntax_map.contains_unknown_injections());
     }
 
     #[gpui::test]

crates/project/src/project.rs 🔗

@@ -1765,10 +1765,14 @@ impl Project {
                 if let Some(project) = project.upgrade(&cx) {
                     project.update(&mut cx, |project, cx| {
                         let mut buffers_without_language = Vec::new();
+                        let mut buffers_with_unknown_injections = Vec::new();
                         for buffer in project.opened_buffers.values() {
-                            if let Some(buffer) = buffer.upgrade(cx) {
-                                if buffer.read(cx).language().is_none() {
-                                    buffers_without_language.push(buffer);
+                            if let Some(handle) = buffer.upgrade(cx) {
+                                let buffer = &handle.read(cx);
+                                if buffer.language().is_none() {
+                                    buffers_without_language.push(handle);
+                                } else if buffer.contains_unknown_injections() {
+                                    buffers_with_unknown_injections.push(handle);
                                 }
                             }
                         }
@@ -1777,6 +1781,10 @@ impl Project {
                             project.assign_language_to_buffer(&buffer, cx);
                             project.register_buffer_with_language_server(&buffer, cx);
                         }
+
+                        for buffer in buffers_with_unknown_injections {
+                            buffer.update(cx, |buffer, cx| buffer.reparse(cx));
+                        }
                     });
                 }
             }