Assign base text language earlier to fix missing highlighting in deletion hunks (#24413)

Cole Miller and Max created

Release Notes:

- Fixed deletion diff hunks not being syntax highlighted in some cases

Co-authored-by: Max <max@zed.dev>

Change summary

crates/project/src/buffer_store.rs  | 73 +++++++++++++++++-------------
crates/project/src/project_tests.rs | 15 +++++
2 files changed, 56 insertions(+), 32 deletions(-)

Detailed changes

crates/project/src/buffer_store.rs 🔗

@@ -74,7 +74,6 @@ struct BufferDiffState {
     language: Option<Arc<Language>>,
     language_registry: Option<Arc<LanguageRegistry>>,
     diff_updated_futures: Vec<oneshot::Sender<()>>,
-    buffer_subscription: Option<Subscription>,
 
     head_text: Option<Arc<String>>,
     index_text: Option<Arc<String>>,
@@ -1447,13 +1446,14 @@ impl BufferStore {
         this: WeakEntity<Self>,
         kind: DiffKind,
         texts: Result<DiffBasesChange>,
-        buffer: Entity<Buffer>,
+        buffer_entity: Entity<Buffer>,
         mut cx: AsyncApp,
     ) -> Result<Entity<BufferDiff>> {
         let diff_bases_change = match texts {
             Err(e) => {
                 this.update(&mut cx, |this, cx| {
-                    let buffer_id = buffer.read(cx).remote_id();
+                    let buffer = buffer_entity.read(cx);
+                    let buffer_id = buffer.remote_id();
                     this.loading_diffs.remove(&(buffer_id, kind));
                 })?;
                 return Err(e);
@@ -1462,26 +1462,23 @@ impl BufferStore {
         };
 
         this.update(&mut cx, |this, cx| {
-            let buffer_id = buffer.read(cx).remote_id();
+            let buffer = buffer_entity.read(cx);
+            let buffer_id = buffer.remote_id();
+            let language = buffer.language().cloned();
+            let language_registry = buffer.language_registry();
+            let text_snapshot = buffer.text_snapshot();
             this.loading_diffs.remove(&(buffer_id, kind));
 
             if let Some(OpenBuffer::Complete { diff_state, .. }) =
-                this.opened_buffers.get_mut(&buffer.read(cx).remote_id())
+                this.opened_buffers.get_mut(&buffer_id)
             {
                 diff_state.update(cx, |diff_state, cx| {
-                    let buffer_id = buffer.read(cx).remote_id();
-                    diff_state.buffer_subscription.get_or_insert_with(|| {
-                        cx.subscribe(&buffer, |this, buffer, event, cx| match event {
-                            BufferEvent::LanguageChanged => {
-                                this.buffer_language_changed(buffer, cx)
-                            }
-                            _ => {}
-                        })
-                    });
+                    diff_state.language = language;
+                    diff_state.language_registry = language_registry;
 
-                    let diff = cx.new(|cx| BufferDiff {
+                    let diff = cx.new(|_| BufferDiff {
                         buffer_id,
-                        snapshot: BufferDiffSnapshot::new(&buffer.read(cx).text_snapshot()),
+                        snapshot: BufferDiffSnapshot::new(&text_snapshot),
                         unstaged_diff: None,
                     });
                     match kind {
@@ -1490,11 +1487,9 @@ impl BufferStore {
                             let unstaged_diff = if let Some(diff) = diff_state.unstaged_diff() {
                                 diff
                             } else {
-                                let unstaged_diff = cx.new(|cx| BufferDiff {
+                                let unstaged_diff = cx.new(|_| BufferDiff {
                                     buffer_id,
-                                    snapshot: BufferDiffSnapshot::new(
-                                        &buffer.read(cx).text_snapshot(),
-                                    ),
+                                    snapshot: BufferDiffSnapshot::new(&text_snapshot),
                                     unstaged_diff: None,
                                 });
                                 diff_state.unstaged_diff = Some(unstaged_diff.downgrade());
@@ -1508,8 +1503,7 @@ impl BufferStore {
                         }
                     };
 
-                    let buffer = buffer.read(cx).text_snapshot();
-                    let rx = diff_state.diff_bases_changed(buffer, diff_bases_change, cx);
+                    let rx = diff_state.diff_bases_changed(text_snapshot, diff_bases_change, cx);
 
                     Ok(async move {
                         rx.await.ok();
@@ -1721,16 +1715,23 @@ impl BufferStore {
         }
     }
 
-    fn add_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) -> Result<()> {
-        let remote_id = buffer.read(cx).remote_id();
-        let is_remote = buffer.read(cx).replica_id() != 0;
+    fn add_buffer(&mut self, buffer_entity: Entity<Buffer>, cx: &mut Context<Self>) -> Result<()> {
+        let buffer = buffer_entity.read(cx);
+        let language = buffer.language().cloned();
+        let language_registry = buffer.language_registry();
+        let remote_id = buffer.remote_id();
+        let is_remote = buffer.replica_id() != 0;
         let open_buffer = OpenBuffer::Complete {
-            buffer: buffer.downgrade(),
-            diff_state: cx.new(|_| BufferDiffState::default()),
+            buffer: buffer_entity.downgrade(),
+            diff_state: cx.new(|_| BufferDiffState {
+                language,
+                language_registry,
+                ..Default::default()
+            }),
         };
 
         let handle = cx.entity().downgrade();
-        buffer.update(cx, move |_, cx| {
+        buffer_entity.update(cx, move |_, cx| {
             cx.on_release(move |buffer, cx| {
                 handle
                     .update(cx, |_, cx| {
@@ -1747,7 +1748,7 @@ impl BufferStore {
             }
             hash_map::Entry::Occupied(mut entry) => {
                 if let OpenBuffer::Operations(operations) = entry.get_mut() {
-                    buffer.update(cx, |b, cx| b.apply_ops(operations.drain(..), cx));
+                    buffer_entity.update(cx, |b, cx| b.apply_ops(operations.drain(..), cx));
                 } else if entry.get().upgrade().is_some() {
                     if is_remote {
                         return Ok(());
@@ -1760,8 +1761,8 @@ impl BufferStore {
             }
         }
 
-        cx.subscribe(&buffer, Self::on_buffer_event).detach();
-        cx.emit(BufferStoreEvent::BufferAdded(buffer));
+        cx.subscribe(&buffer_entity, Self::on_buffer_event).detach();
+        cx.emit(BufferStoreEvent::BufferAdded(buffer_entity));
         Ok(())
     }
 
@@ -1982,6 +1983,16 @@ impl BufferStore {
                     })
                     .log_err();
             }
+            BufferEvent::LanguageChanged => {
+                let buffer_id = buffer.read(cx).remote_id();
+                if let Some(OpenBuffer::Complete { diff_state, .. }) =
+                    self.opened_buffers.get(&buffer_id)
+                {
+                    diff_state.update(cx, |diff_state, cx| {
+                        diff_state.buffer_language_changed(buffer, cx);
+                    });
+                }
+            }
             _ => {}
         }
     }

crates/project/src/project_tests.rs 🔗

@@ -5776,6 +5776,9 @@ async fn test_uncommitted_diff_for_buffer(cx: &mut gpui::TestAppContext) {
     );
 
     let project = Project::test(fs.clone(), ["/dir".as_ref()], cx).await;
+    let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+    let language = rust_lang();
+    language_registry.add(language.clone());
 
     let buffer = project
         .update(cx, |project, cx| {
@@ -5790,13 +5793,23 @@ async fn test_uncommitted_diff_for_buffer(cx: &mut gpui::TestAppContext) {
         .await
         .unwrap();
 
+    uncommitted_diff.read_with(cx, |diff, _| {
+        assert_eq!(
+            diff.snapshot
+                .base_text
+                .as_ref()
+                .and_then(|base| base.language().cloned()),
+            Some(language)
+        )
+    });
+
     cx.run_until_parked();
     uncommitted_diff.update(cx, |uncommitted_diff, cx| {
         let snapshot = buffer.read(cx).snapshot();
         assert_hunks(
             uncommitted_diff.diff_hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot),
             &snapshot,
-            &uncommitted_diff.snapshot.base_text.as_ref().unwrap().text(),
+            &uncommitted_diff.base_text_string().unwrap(),
             &[
                 (0..1, "", "// print goodbye\n"),
                 (