Fix leak when project is unshared while LSP handler waits for edits

Max Brunsfeld created

Change summary

crates/language/src/buffer.rs     | 12 ++++--
crates/language/src/proto.rs      |  8 ++--
crates/project/src/lsp_command.rs | 40 ++++++++++++------------
crates/project/src/project.rs     | 42 ++++++++++++++-----------
crates/project/src/worktree.rs    |  2 
crates/text/src/text.rs           | 54 +++++++++++++++++++++++---------
6 files changed, 95 insertions(+), 63 deletions(-)

Detailed changes

crates/language/src/buffer.rs 🔗

@@ -377,7 +377,7 @@ impl Buffer {
             rpc::proto::LineEnding::from_i32(message.line_ending)
                 .ok_or_else(|| anyhow!("missing line_ending"))?,
         ));
-        this.saved_version = proto::deserialize_version(message.saved_version);
+        this.saved_version = proto::deserialize_version(&message.saved_version);
         this.saved_version_fingerprint =
             proto::deserialize_fingerprint(&message.saved_version_fingerprint)?;
         this.saved_mtime = message
@@ -1309,21 +1309,25 @@ impl Buffer {
     pub fn wait_for_edits(
         &mut self,
         edit_ids: impl IntoIterator<Item = clock::Local>,
-    ) -> impl Future<Output = ()> {
+    ) -> impl Future<Output = Result<()>> {
         self.text.wait_for_edits(edit_ids)
     }
 
     pub fn wait_for_anchors<'a>(
         &mut self,
         anchors: impl IntoIterator<Item = &'a Anchor>,
-    ) -> impl Future<Output = ()> {
+    ) -> impl Future<Output = Result<()>> {
         self.text.wait_for_anchors(anchors)
     }
 
-    pub fn wait_for_version(&mut self, version: clock::Global) -> impl Future<Output = ()> {
+    pub fn wait_for_version(&mut self, version: clock::Global) -> impl Future<Output = Result<()>> {
         self.text.wait_for_version(version)
     }
 
+    pub fn give_up_waiting(&mut self) {
+        self.text.give_up_waiting();
+    }
+
     pub fn set_active_selections(
         &mut self,
         selections: Arc<[Selection<Anchor>]>,

crates/language/src/proto.rs 🔗

@@ -220,7 +220,7 @@ pub fn deserialize_operation(message: proto::Operation) -> Result<crate::Operati
                             replica_id: undo.replica_id as ReplicaId,
                             value: undo.local_timestamp,
                         },
-                        version: deserialize_version(undo.version),
+                        version: deserialize_version(&undo.version),
                         counts: undo
                             .counts
                             .into_iter()
@@ -294,7 +294,7 @@ pub fn deserialize_edit_operation(edit: proto::operation::Edit) -> EditOperation
             local: edit.local_timestamp,
             lamport: edit.lamport_timestamp,
         },
-        version: deserialize_version(edit.version),
+        version: deserialize_version(&edit.version),
         ranges: edit.ranges.into_iter().map(deserialize_range).collect(),
         new_text: edit.new_text.into_iter().map(Arc::from).collect(),
     }
@@ -509,7 +509,7 @@ pub fn deserialize_transaction(transaction: proto::Transaction) -> Result<Transa
             .into_iter()
             .map(deserialize_local_timestamp)
             .collect(),
-        start: deserialize_version(transaction.start),
+        start: deserialize_version(&transaction.start),
     })
 }
 
@@ -538,7 +538,7 @@ pub fn deserialize_range(range: proto::Range) -> Range<FullOffset> {
     FullOffset(range.start as usize)..FullOffset(range.end as usize)
 }
 
-pub fn deserialize_version(message: Vec<proto::VectorClockEntry>) -> clock::Global {
+pub fn deserialize_version(message: &[proto::VectorClockEntry]) -> clock::Global {
     let mut version = clock::Global::new();
     for entry in message {
         version.observe(clock::Local {

crates/project/src/lsp_command.rs 🔗

@@ -161,9 +161,9 @@ impl LspCommand for PrepareRename {
             .ok_or_else(|| anyhow!("invalid position"))?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(deserialize_version(message.version))
+                buffer.wait_for_version(deserialize_version(&message.version))
             })
-            .await;
+            .await?;
 
         Ok(Self {
             position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer)),
@@ -199,9 +199,9 @@ impl LspCommand for PrepareRename {
         if message.can_rename {
             buffer
                 .update(&mut cx, |buffer, _| {
-                    buffer.wait_for_version(deserialize_version(message.version))
+                    buffer.wait_for_version(deserialize_version(&message.version))
                 })
-                .await;
+                .await?;
             let start = message.start.and_then(deserialize_anchor);
             let end = message.end.and_then(deserialize_anchor);
             Ok(start.zip(end).map(|(start, end)| start..end))
@@ -281,9 +281,9 @@ impl LspCommand for PerformRename {
             .ok_or_else(|| anyhow!("invalid position"))?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(deserialize_version(message.version))
+                buffer.wait_for_version(deserialize_version(&message.version))
             })
-            .await;
+            .await?;
         Ok(Self {
             position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer)),
             new_name: message.new_name,
@@ -378,9 +378,9 @@ impl LspCommand for GetDefinition {
             .ok_or_else(|| anyhow!("invalid position"))?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(deserialize_version(message.version))
+                buffer.wait_for_version(deserialize_version(&message.version))
             })
-            .await;
+            .await?;
         Ok(Self {
             position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer)),
         })
@@ -464,9 +464,9 @@ impl LspCommand for GetTypeDefinition {
             .ok_or_else(|| anyhow!("invalid position"))?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(deserialize_version(message.version))
+                buffer.wait_for_version(deserialize_version(&message.version))
             })
-            .await;
+            .await?;
         Ok(Self {
             position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer)),
         })
@@ -537,7 +537,7 @@ async fn location_links_from_proto(
                     .ok_or_else(|| anyhow!("missing origin end"))?;
                 buffer
                     .update(&mut cx, |buffer, _| buffer.wait_for_anchors([&start, &end]))
-                    .await;
+                    .await?;
                 Some(Location {
                     buffer,
                     range: start..end,
@@ -562,7 +562,7 @@ async fn location_links_from_proto(
             .ok_or_else(|| anyhow!("missing target end"))?;
         buffer
             .update(&mut cx, |buffer, _| buffer.wait_for_anchors([&start, &end]))
-            .await;
+            .await?;
         let target = Location {
             buffer,
             range: start..end,
@@ -774,9 +774,9 @@ impl LspCommand for GetReferences {
             .ok_or_else(|| anyhow!("invalid position"))?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(deserialize_version(message.version))
+                buffer.wait_for_version(deserialize_version(&message.version))
             })
-            .await;
+            .await?;
         Ok(Self {
             position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer)),
         })
@@ -827,7 +827,7 @@ impl LspCommand for GetReferences {
                 .ok_or_else(|| anyhow!("missing target end"))?;
             target_buffer
                 .update(&mut cx, |buffer, _| buffer.wait_for_anchors([&start, &end]))
-                .await;
+                .await?;
             locations.push(Location {
                 buffer: target_buffer,
                 range: start..end,
@@ -915,9 +915,9 @@ impl LspCommand for GetDocumentHighlights {
             .ok_or_else(|| anyhow!("invalid position"))?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(deserialize_version(message.version))
+                buffer.wait_for_version(deserialize_version(&message.version))
             })
-            .await;
+            .await?;
         Ok(Self {
             position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer)),
         })
@@ -965,7 +965,7 @@ impl LspCommand for GetDocumentHighlights {
                 .ok_or_else(|| anyhow!("missing target end"))?;
             buffer
                 .update(&mut cx, |buffer, _| buffer.wait_for_anchors([&start, &end]))
-                .await;
+                .await?;
             let kind = match proto::document_highlight::Kind::from_i32(highlight.kind) {
                 Some(proto::document_highlight::Kind::Text) => DocumentHighlightKind::TEXT,
                 Some(proto::document_highlight::Kind::Read) => DocumentHighlightKind::READ,
@@ -1117,9 +1117,9 @@ impl LspCommand for GetHover {
             .ok_or_else(|| anyhow!("invalid position"))?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(deserialize_version(message.version))
+                buffer.wait_for_version(deserialize_version(&message.version))
             })
-            .await;
+            .await?;
         Ok(Self {
             position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer)),
         })

crates/project/src/project.rs 🔗

@@ -1182,6 +1182,11 @@ impl Project {
             }
 
             for open_buffer in self.opened_buffers.values_mut() {
+                // Wake up any tasks waiting for peers' edits to this buffer.
+                if let Some(buffer) = open_buffer.upgrade(cx) {
+                    buffer.update(cx, |buffer, _| buffer.give_up_waiting());
+                }
+
                 if let OpenBuffer::Strong(buffer) = open_buffer {
                     *open_buffer = OpenBuffer::Weak(buffer.downgrade());
                 }
@@ -3738,9 +3743,9 @@ impl Project {
                 } else {
                     source_buffer_handle
                         .update(&mut cx, |buffer, _| {
-                            buffer.wait_for_version(deserialize_version(response.version))
+                            buffer.wait_for_version(deserialize_version(&response.version))
                         })
-                        .await;
+                        .await?;
 
                     let completions = response.completions.into_iter().map(|completion| {
                         language::proto::deserialize_completion(completion, language.clone())
@@ -3831,7 +3836,7 @@ impl Project {
                         .update(&mut cx, |buffer, _| {
                             buffer.wait_for_edits(transaction.edit_ids.iter().copied())
                         })
-                        .await;
+                        .await?;
                     if push_to_history {
                         buffer_handle.update(&mut cx, |buffer, _| {
                             buffer.push_transaction(transaction.clone(), Instant::now());
@@ -3939,9 +3944,9 @@ impl Project {
                 } else {
                     buffer_handle
                         .update(&mut cx, |buffer, _| {
-                            buffer.wait_for_version(deserialize_version(response.version))
+                            buffer.wait_for_version(deserialize_version(&response.version))
                         })
-                        .await;
+                        .await?;
 
                     response
                         .actions
@@ -5425,8 +5430,6 @@ impl Project {
         mut cx: AsyncAppContext,
     ) -> Result<proto::BufferSaved> {
         let buffer_id = envelope.payload.buffer_id;
-        let requested_version = deserialize_version(envelope.payload.version);
-
         let (project_id, buffer) = this.update(&mut cx, |this, cx| {
             let project_id = this.remote_id().ok_or_else(|| anyhow!("not connected"))?;
             let buffer = this
@@ -5434,13 +5437,14 @@ impl Project {
                 .get(&buffer_id)
                 .and_then(|buffer| buffer.upgrade(cx))
                 .ok_or_else(|| anyhow!("unknown buffer id {}", buffer_id))?;
-            Ok::<_, anyhow::Error>((project_id, buffer))
+            anyhow::Ok((project_id, buffer))
         })?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(requested_version)
+                buffer.wait_for_version(deserialize_version(&envelope.payload.version))
             })
-            .await;
+            .await?;
+        let buffer_id = buffer.read_with(&cx, |buffer, _| buffer.remote_id());
 
         let (saved_version, fingerprint, mtime) = this
             .update(&mut cx, |this, cx| this.save_buffer(buffer, cx))
@@ -5503,7 +5507,7 @@ impl Project {
             this.shared_buffers.entry(guest_id).or_default().clear();
             for buffer in envelope.payload.buffers {
                 let buffer_id = buffer.id;
-                let remote_version = language::proto::deserialize_version(buffer.version);
+                let remote_version = language::proto::deserialize_version(&buffer.version);
                 if let Some(buffer) = this.buffer_for_id(buffer_id, cx) {
                     this.shared_buffers
                         .entry(guest_id)
@@ -5619,10 +5623,10 @@ impl Project {
                 .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id))
         })?;
 
-        let version = deserialize_version(envelope.payload.version);
+        let version = deserialize_version(&envelope.payload.version);
         buffer
             .update(&mut cx, |buffer, _| buffer.wait_for_version(version))
-            .await;
+            .await?;
         let version = buffer.read_with(&cx, |buffer, _| buffer.version());
 
         let position = envelope
@@ -5710,9 +5714,9 @@ impl Project {
         })?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(deserialize_version(envelope.payload.version))
+                buffer.wait_for_version(deserialize_version(&envelope.payload.version))
             })
-            .await;
+            .await?;
 
         let version = buffer.read_with(&cx, |buffer, _| buffer.version());
         let code_actions = this.update(&mut cx, |this, cx| {
@@ -5979,7 +5983,7 @@ impl Project {
                     .update(&mut cx, |buffer, _| {
                         buffer.wait_for_edits(transaction.edit_ids.iter().copied())
                     })
-                    .await;
+                    .await?;
 
                 if push_to_history {
                     buffer.update(&mut cx, |buffer, _| {
@@ -6098,7 +6102,7 @@ impl Project {
             let send_updates_for_buffers = response.buffers.into_iter().map(|buffer| {
                 let client = client.clone();
                 let buffer_id = buffer.id;
-                let remote_version = language::proto::deserialize_version(buffer.version);
+                let remote_version = language::proto::deserialize_version(&buffer.version);
                 this.read_with(&cx, |this, cx| {
                     if let Some(buffer) = this.buffer_for_id(buffer_id, cx) {
                         let operations = buffer.read(cx).serialize_ops(Some(remote_version), cx);
@@ -6263,7 +6267,7 @@ impl Project {
         mut cx: AsyncAppContext,
     ) -> Result<()> {
         let fingerprint = deserialize_fingerprint(&envelope.payload.fingerprint)?;
-        let version = deserialize_version(envelope.payload.version);
+        let version = deserialize_version(&envelope.payload.version);
         let mtime = envelope
             .payload
             .mtime
@@ -6296,7 +6300,7 @@ impl Project {
         mut cx: AsyncAppContext,
     ) -> Result<()> {
         let payload = envelope.payload;
-        let version = deserialize_version(payload.version);
+        let version = deserialize_version(&payload.version);
         let fingerprint = deserialize_fingerprint(&payload.fingerprint)?;
         let line_ending = deserialize_line_ending(
             proto::LineEnding::from_i32(payload.line_ending)

crates/project/src/worktree.rs 🔗

@@ -1064,7 +1064,7 @@ impl RemoteWorktree {
                     version: serialize_version(&version),
                 })
                 .await?;
-            let version = deserialize_version(response.version);
+            let version = deserialize_version(&response.version);
             let fingerprint = deserialize_fingerprint(&response.fingerprint)?;
             let mtime = response
                 .mtime

crates/text/src/text.rs 🔗

@@ -11,14 +11,14 @@ mod tests;
 mod undo_map;
 
 pub use anchor::*;
-use anyhow::Result;
+use anyhow::{anyhow, Result};
 use clock::ReplicaId;
 use collections::{HashMap, HashSet};
 use fs::LineEnding;
 use locator::Locator;
 use operation_queue::OperationQueue;
 pub use patch::Patch;
-use postage::{barrier, oneshot, prelude::*};
+use postage::{oneshot, prelude::*};
 
 pub use rope::*;
 pub use selection::*;
@@ -52,7 +52,7 @@ pub struct Buffer {
     pub lamport_clock: clock::Lamport,
     subscriptions: Topic,
     edit_id_resolvers: HashMap<clock::Local, Vec<oneshot::Sender<()>>>,
-    version_barriers: Vec<(clock::Global, barrier::Sender)>,
+    wait_for_version_txs: Vec<(clock::Global, oneshot::Sender<()>)>,
 }
 
 #[derive(Clone)]
@@ -522,7 +522,7 @@ impl Buffer {
             lamport_clock,
             subscriptions: Default::default(),
             edit_id_resolvers: Default::default(),
-            version_barriers: Default::default(),
+            wait_for_version_txs: Default::default(),
         }
     }
 
@@ -793,8 +793,14 @@ impl Buffer {
                 }
             }
         }
-        self.version_barriers
-            .retain(|(version, _)| !self.snapshot.version().observed_all(version));
+        self.wait_for_version_txs.retain_mut(|(version, tx)| {
+            if self.snapshot.version().observed_all(version) {
+                tx.try_send(()).ok();
+                false
+            } else {
+                true
+            }
+        });
         Ok(())
     }
 
@@ -1305,7 +1311,7 @@ impl Buffer {
     pub fn wait_for_edits(
         &mut self,
         edit_ids: impl IntoIterator<Item = clock::Local>,
-    ) -> impl 'static + Future<Output = ()> {
+    ) -> impl 'static + Future<Output = Result<()>> {
         let mut futures = Vec::new();
         for edit_id in edit_ids {
             if !self.version.observed(edit_id) {
@@ -1317,15 +1323,18 @@ impl Buffer {
 
         async move {
             for mut future in futures {
-                future.recv().await;
+                if future.recv().await.is_none() {
+                    Err(anyhow!("gave up waiting for edits"))?;
+                }
             }
+            Ok(())
         }
     }
 
     pub fn wait_for_anchors<'a>(
         &mut self,
         anchors: impl IntoIterator<Item = &'a Anchor>,
-    ) -> impl 'static + Future<Output = ()> {
+    ) -> impl 'static + Future<Output = Result<()>> {
         let mut futures = Vec::new();
         for anchor in anchors {
             if !self.version.observed(anchor.timestamp)
@@ -1343,21 +1352,36 @@ impl Buffer {
 
         async move {
             for mut future in futures {
-                future.recv().await;
+                if future.recv().await.is_none() {
+                    Err(anyhow!("gave up waiting for anchors"))?;
+                }
             }
+            Ok(())
         }
     }
 
-    pub fn wait_for_version(&mut self, version: clock::Global) -> impl Future<Output = ()> {
-        let (tx, mut rx) = barrier::channel();
+    pub fn wait_for_version(&mut self, version: clock::Global) -> impl Future<Output = Result<()>> {
+        let mut rx = None;
         if !self.snapshot.version.observed_all(&version) {
-            self.version_barriers.push((version, tx));
+            let channel = oneshot::channel();
+            self.wait_for_version_txs.push((version, channel.0));
+            rx = Some(channel.1);
         }
         async move {
-            rx.recv().await;
+            if let Some(mut rx) = rx {
+                if rx.recv().await.is_none() {
+                    Err(anyhow!("gave up waiting for version"))?;
+                }
+            }
+            Ok(())
         }
     }
 
+    pub fn give_up_waiting(&mut self) {
+        self.edit_id_resolvers.clear();
+        self.wait_for_version_txs.clear();
+    }
+
     fn resolve_edit(&mut self, edit_id: clock::Local) {
         for mut tx in self
             .edit_id_resolvers
@@ -1365,7 +1389,7 @@ impl Buffer {
             .into_iter()
             .flatten()
         {
-            let _ = tx.try_send(());
+            tx.try_send(()).ok();
         }
     }
 }