Fix race condition when opening a buffer and getting a definition to it

Antonio Scandurra created

Change summary

crates/language/src/buffer.rs  |   6 
crates/language/src/proto.rs   |   2 
crates/project/src/project.rs  | 185 ++++++++++++++++++++++++-----------
crates/project/src/worktree.rs |  46 --------
crates/rpc/proto/zed.proto     |  16 +-
crates/rpc/src/peer.rs         |  18 --
crates/server/src/rpc.rs       |  22 +++-
7 files changed, 160 insertions(+), 135 deletions(-)

Detailed changes

crates/language/src/buffer.rs 🔗

@@ -305,7 +305,7 @@ impl Buffer {
 
     pub fn from_proto(
         replica_id: ReplicaId,
-        message: proto::Buffer,
+        message: proto::BufferState,
         file: Option<Box<dyn File>>,
         cx: &mut ModelContext<Self>,
     ) -> Result<Self> {
@@ -359,8 +359,8 @@ impl Buffer {
         Ok(this)
     }
 
-    pub fn to_proto(&self) -> proto::Buffer {
-        proto::Buffer {
+    pub fn to_proto(&self) -> proto::BufferState {
+        proto::BufferState {
             id: self.remote_id(),
             file: self.file.as_ref().map(|f| f.to_proto()),
             visible_text: self.text.text(),

crates/language/src/proto.rs 🔗

@@ -7,7 +7,7 @@ use rpc::proto;
 use std::sync::Arc;
 use text::*;
 
-pub use proto::{Buffer, SelectionSet};
+pub use proto::{Buffer, BufferState, SelectionSet};
 
 pub fn serialize_operation(operation: &Operation) -> proto::Operation {
     proto::Operation {

crates/project/src/project.rs 🔗

@@ -518,17 +518,18 @@ impl Project {
                 let (mut tx, rx) = postage::watch::channel();
                 entry.insert(rx.clone());
 
-                let load_buffer = worktree.update(cx, |worktree, cx| {
-                    worktree.load_buffer(&project_path.path, cx)
-                });
+                let load_buffer = if worktree.read(cx).is_local() {
+                    self.open_local_buffer(&project_path.path, &worktree, cx)
+                } else {
+                    self.open_remote_buffer(&project_path.path, &worktree, cx)
+                };
 
                 cx.spawn(move |this, mut cx| async move {
                     let load_result = load_buffer.await;
-                    *tx.borrow_mut() = Some(this.update(&mut cx, |this, cx| {
+                    *tx.borrow_mut() = Some(this.update(&mut cx, |this, _| {
                         // Record the fact that the buffer is no longer loading.
                         this.loading_buffers.remove(&project_path);
                         let buffer = load_result.map_err(Arc::new)?;
-                        this.register_buffer(&buffer, &worktree, cx)?;
                         Ok(buffer)
                     }));
                 })
@@ -550,6 +551,55 @@ impl Project {
         })
     }
 
+    fn open_local_buffer(
+        &mut self,
+        path: &Arc<Path>,
+        worktree: &ModelHandle<Worktree>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<ModelHandle<Buffer>>> {
+        let load_buffer = worktree.update(cx, |worktree, cx| {
+            let worktree = worktree.as_local_mut().unwrap();
+            worktree.load_buffer(path, cx)
+        });
+        let worktree = worktree.downgrade();
+        cx.spawn(|this, mut cx| async move {
+            let buffer = load_buffer.await?;
+            let worktree = worktree
+                .upgrade(&cx)
+                .ok_or_else(|| anyhow!("worktree was removed"))?;
+            this.update(&mut cx, |this, cx| {
+                this.register_buffer(&buffer, Some(&worktree), cx)
+            })?;
+            Ok(buffer)
+        })
+    }
+
+    fn open_remote_buffer(
+        &mut self,
+        path: &Arc<Path>,
+        worktree: &ModelHandle<Worktree>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<ModelHandle<Buffer>>> {
+        let rpc = self.client.clone();
+        let project_id = self.remote_id().unwrap();
+        let remote_worktree_id = worktree.read(cx).id();
+        let path = path.clone();
+        let path_string = path.to_string_lossy().to_string();
+        cx.spawn(|this, mut cx| async move {
+            let response = rpc
+                .request(proto::OpenBuffer {
+                    project_id,
+                    worktree_id: remote_worktree_id.to_proto(),
+                    path: path_string,
+                })
+                .await?;
+            let buffer = response.buffer.ok_or_else(|| anyhow!("missing buffer"))?;
+            this.update(&mut cx, |this, cx| {
+                this.deserialize_remote_buffer(buffer, cx)
+            })
+        })
+    }
+
     pub fn save_buffer_as(
         &self,
         buffer: ModelHandle<Buffer>,
@@ -568,7 +618,7 @@ impl Project {
                 })
                 .await?;
             this.update(&mut cx, |this, cx| {
-                this.assign_language_to_buffer(&buffer, &worktree, cx);
+                this.assign_language_to_buffer(&buffer, Some(&worktree), cx);
             });
             Ok(())
         })
@@ -619,7 +669,7 @@ impl Project {
     fn register_buffer(
         &mut self,
         buffer: &ModelHandle<Buffer>,
-        worktree: &ModelHandle<Worktree>,
+        worktree: Option<&ModelHandle<Worktree>>,
         cx: &mut ModelContext<Self>,
     ) -> Result<()> {
         match self.open_buffers.insert(
@@ -627,21 +677,21 @@ impl Project {
             OpenBuffer::Loaded(buffer.downgrade()),
         ) {
             Some(OpenBuffer::Operations(pending_ops)) => {
-                buffer.update(cx, |buf, cx| buf.apply_ops(pending_ops, cx))?;
+                // buffer.update(cx, |buf, cx| buf.apply_ops(pending_ops, cx))?;
             }
             Some(OpenBuffer::Loaded(_)) => {
                 return Err(anyhow!("registered the same buffer twice"));
             }
             None => {}
         }
-        self.assign_language_to_buffer(&buffer, &worktree, cx);
+        self.assign_language_to_buffer(&buffer, worktree, cx);
         Ok(())
     }
 
     fn assign_language_to_buffer(
         &mut self,
         buffer: &ModelHandle<Buffer>,
-        worktree: &ModelHandle<Worktree>,
+        worktree: Option<&ModelHandle<Worktree>>,
         cx: &mut ModelContext<Self>,
     ) -> Option<()> {
         let (path, full_path) = {
@@ -657,7 +707,7 @@ impl Project {
 
             // For local worktrees, start a language server if needed.
             // Also assign the language server and any previously stored diagnostics to the buffer.
-            if let Some(local_worktree) = worktree.read(cx).as_local() {
+            if let Some(local_worktree) = worktree.and_then(|w| w.read(cx).as_local()) {
                 let worktree_id = local_worktree.id();
                 let worktree_abs_path = local_worktree.abs_path().clone();
 
@@ -681,7 +731,7 @@ impl Project {
             }
         }
 
-        if let Some(local_worktree) = worktree.read(cx).as_local() {
+        if let Some(local_worktree) = worktree.and_then(|w| w.read(cx).as_local()) {
             if let Some(diagnostics) = local_worktree.diagnostics_for_path(&path) {
                 buffer.update(cx, |buffer, cx| {
                     buffer.update_diagnostics(None, diagnostics, cx).log_err();
@@ -1067,7 +1117,6 @@ impl Project {
             })
         } else if let Some(project_id) = self.remote_id() {
             let client = self.client.clone();
-            let replica_id = self.replica_id();
             let request = proto::GetDefinition {
                 project_id,
                 buffer_id: source_buffer.remote_id(),
@@ -1078,35 +1127,10 @@ impl Project {
                 this.update(&mut cx, |this, cx| {
                     let mut definitions = Vec::new();
                     for definition in response.definitions {
-                        let target_buffer = match definition
-                            .buffer
-                            .ok_or_else(|| anyhow!("missing buffer"))?
-                        {
-                            proto::definition::Buffer::Id(id) => this
-                                .open_buffers
-                                .get(&(id as usize))
-                                .and_then(|buffer| buffer.upgrade(cx))
-                                .ok_or_else(|| anyhow!("no buffer exists for id {}", id))?,
-                            proto::definition::Buffer::State(mut buffer) => {
-                                let file = if let Some(file) = buffer.file.take() {
-                                    let worktree_id = WorktreeId::from_proto(file.worktree_id);
-                                    let worktree =
-                                        this.worktree_for_id(worktree_id, cx).ok_or_else(|| {
-                                            anyhow!("no worktree found for id {}", file.worktree_id)
-                                        })?;
-                                    let file = File::from_proto(file, worktree.clone(), cx)?;
-                                    Some(Box::new(file) as Box<dyn language::File>)
-                                } else {
-                                    None
-                                };
-
-                                let buffer = cx.add_model(|cx| {
-                                    Buffer::from_proto(replica_id, buffer, file, cx).unwrap()
-                                });
-                                this.register_buffer(&buffer, &worktree, cx)?;
-                                buffer
-                            }
-                        };
+                        let target_buffer = this.deserialize_remote_buffer(
+                            definition.buffer.ok_or_else(|| anyhow!("missing buffer"))?,
+                            cx,
+                        )?;
                         let target_start = definition
                             .target_start
                             .and_then(deserialize_anchor)
@@ -1712,17 +1736,8 @@ impl Project {
             };
             this.update(&mut cx, |this, cx| {
                 for definition in definitions {
-                    let buffer_id = definition.target_buffer.read(cx).remote_id();
-                    let shared_buffers = this.shared_buffers.entry(sender_id).or_default();
-                    let buffer = match shared_buffers.entry(buffer_id) {
-                        hash_map::Entry::Occupied(_) => proto::definition::Buffer::Id(buffer_id),
-                        hash_map::Entry::Vacant(entry) => {
-                            entry.insert(definition.target_buffer.clone());
-                            proto::definition::Buffer::State(
-                                definition.target_buffer.read(cx).to_proto(),
-                            )
-                        }
-                    };
+                    let buffer =
+                        this.serialize_buffer_for_peer(&definition.target_buffer, sender_id, cx);
                     response.definitions.push(proto::Definition {
                         target_start: Some(serialize_anchor(&definition.target_range.start)),
                         target_end: Some(serialize_anchor(&definition.target_range.end)),
@@ -1757,17 +1772,13 @@ impl Project {
         cx.spawn(|this, mut cx| {
             async move {
                 let buffer = open_buffer.await?;
-                this.update(&mut cx, |this, _| {
-                    this.shared_buffers
-                        .entry(peer_id)
-                        .or_default()
-                        .insert(buffer.id() as u64, buffer.clone());
+                let buffer = this.update(&mut cx, |this, cx| {
+                    this.serialize_buffer_for_peer(&buffer, peer_id, cx)
                 });
-                let message = buffer.read_with(&cx, |buffer, _| buffer.to_proto());
                 rpc.respond(
                     receipt,
                     proto::OpenBufferResponse {
-                        buffer: Some(message),
+                        buffer: Some(buffer),
                     },
                 )
                 .await
@@ -1778,6 +1789,60 @@ impl Project {
         Ok(())
     }
 
+    fn serialize_buffer_for_peer(
+        &mut self,
+        buffer: &ModelHandle<Buffer>,
+        peer_id: PeerId,
+        cx: &AppContext,
+    ) -> proto::Buffer {
+        let buffer_id = buffer.read(cx).remote_id();
+        let shared_buffers = self.shared_buffers.entry(peer_id).or_default();
+        match shared_buffers.entry(buffer_id) {
+            hash_map::Entry::Occupied(_) => proto::Buffer {
+                variant: Some(proto::buffer::Variant::Id(buffer_id)),
+            },
+            hash_map::Entry::Vacant(entry) => {
+                entry.insert(buffer.clone());
+                proto::Buffer {
+                    variant: Some(proto::buffer::Variant::State(buffer.read(cx).to_proto())),
+                }
+            }
+        }
+    }
+
+    fn deserialize_remote_buffer(
+        &mut self,
+        buffer: proto::Buffer,
+        cx: &mut ModelContext<Self>,
+    ) -> Result<ModelHandle<Buffer>> {
+        match buffer.variant.ok_or_else(|| anyhow!("missing buffer"))? {
+            proto::buffer::Variant::Id(id) => self
+                .open_buffers
+                .get(&(id as usize))
+                .and_then(|buffer| buffer.upgrade(cx))
+                .ok_or_else(|| anyhow!("no buffer exists for id {}", id)),
+            proto::buffer::Variant::State(mut buffer) => {
+                let mut buffer_worktree = None;
+                let mut buffer_file = None;
+                if let Some(file) = buffer.file.take() {
+                    let worktree_id = WorktreeId::from_proto(file.worktree_id);
+                    let worktree = self
+                        .worktree_for_id(worktree_id, cx)
+                        .ok_or_else(|| anyhow!("no worktree found for id {}", file.worktree_id))?;
+                    buffer_file = Some(Box::new(File::from_proto(file, worktree.clone(), cx)?)
+                        as Box<dyn language::File>);
+                    buffer_worktree = Some(worktree);
+                }
+
+                let buffer = cx.add_model(|cx| {
+                    Buffer::from_proto(self.replica_id(), buffer, buffer_file, cx).unwrap()
+                });
+                self.register_buffer(&buffer, buffer_worktree.as_ref(), cx)?;
+                Ok(buffer)
+            }
+        }
+    }
+
     pub fn handle_close_buffer(
         &mut self,
         envelope: TypedEnvelope<proto::CloseBuffer>,

crates/project/src/worktree.rs 🔗

@@ -367,17 +367,6 @@ impl Worktree {
         }
     }
 
-    pub fn load_buffer(
-        &mut self,
-        path: &Path,
-        cx: &mut ModelContext<Self>,
-    ) -> Task<Result<ModelHandle<Buffer>>> {
-        match self {
-            Worktree::Local(worktree) => worktree.load_buffer(path, cx),
-            Worktree::Remote(worktree) => worktree.load_buffer(path, cx),
-        }
-    }
-
     pub fn diagnostic_summaries<'a>(
         &'a self,
     ) -> impl Iterator<Item = (Arc<Path>, DiagnosticSummary)> + 'a {
@@ -834,41 +823,6 @@ impl LocalWorktree {
 }
 
 impl RemoteWorktree {
-    pub(crate) fn load_buffer(
-        &mut self,
-        path: &Path,
-        cx: &mut ModelContext<Worktree>,
-    ) -> Task<Result<ModelHandle<Buffer>>> {
-        let rpc = self.client.clone();
-        let replica_id = self.replica_id;
-        let project_id = self.project_id;
-        let remote_worktree_id = self.id();
-        let path: Arc<Path> = Arc::from(path);
-        let path_string = path.to_string_lossy().to_string();
-        cx.spawn_weak(move |this, mut cx| async move {
-            let response = rpc
-                .request(proto::OpenBuffer {
-                    project_id,
-                    worktree_id: remote_worktree_id.to_proto(),
-                    path: path_string,
-                })
-                .await?;
-
-            let this = this
-                .upgrade(&cx)
-                .ok_or_else(|| anyhow!("worktree was closed"))?;
-            let mut remote_buffer = response.buffer.ok_or_else(|| anyhow!("empty buffer"))?;
-            let file = remote_buffer
-                .file
-                .take()
-                .map(|proto| cx.read(|cx| File::from_proto(proto, this.clone(), cx)))
-                .transpose()?
-                .map(|file| Box::new(file) as Box<dyn language::File>);
-
-            Ok(cx.add_model(|cx| Buffer::from_proto(replica_id, remote_buffer, file, cx).unwrap()))
-        })
-    }
-
     fn snapshot(&self) -> Snapshot {
         self.snapshot.clone()
     }

crates/rpc/proto/zed.proto 🔗

@@ -147,12 +147,9 @@ message GetDefinitionResponse {
 }
 
 message Definition {
-    oneof buffer {
-        uint64 id = 1;
-        Buffer state = 2;
-    }
-    Anchor target_start = 3;
-    Anchor target_end = 4;
+    Buffer buffer = 1;
+    Anchor target_start = 2;
+    Anchor target_end = 3;
 }
 
 message OpenBuffer {
@@ -324,6 +321,13 @@ message Entry {
 }
 
 message Buffer {
+    oneof variant {
+        uint64 id = 1;
+        BufferState state = 2;
+    }
+}
+
+message BufferState {
     uint64 id = 1;
     optional File file = 2;
     string visible_text = 3;

crates/rpc/src/peer.rs 🔗

@@ -410,9 +410,7 @@ mod tests {
                 .unwrap(),
             proto::OpenBufferResponse {
                 buffer: Some(proto::Buffer {
-                    id: 101,
-                    visible_text: "path/one content".to_string(),
-                    ..Default::default()
+                    variant: Some(proto::buffer::Variant::Id(0))
                 }),
             }
         );
@@ -431,10 +429,8 @@ mod tests {
                 .unwrap(),
             proto::OpenBufferResponse {
                 buffer: Some(proto::Buffer {
-                    id: 102,
-                    visible_text: "path/two content".to_string(),
-                    ..Default::default()
-                }),
+                    variant: Some(proto::buffer::Variant::Id(1))
+                })
             }
         );
 
@@ -460,9 +456,7 @@ mod tests {
                             assert_eq!(message.worktree_id, 1);
                             proto::OpenBufferResponse {
                                 buffer: Some(proto::Buffer {
-                                    id: 101,
-                                    visible_text: "path/one content".to_string(),
-                                    ..Default::default()
+                                    variant: Some(proto::buffer::Variant::Id(0)),
                                 }),
                             }
                         }
@@ -470,9 +464,7 @@ mod tests {
                             assert_eq!(message.worktree_id, 2);
                             proto::OpenBufferResponse {
                                 buffer: Some(proto::Buffer {
-                                    id: 102,
-                                    visible_text: "path/two content".to_string(),
-                                    ..Default::default()
+                                    variant: Some(proto::buffer::Variant::Id(1)),
                                 }),
                             }
                         }

crates/server/src/rpc.rs 🔗

@@ -1168,6 +1168,7 @@ mod tests {
     use gpui::{executor, ModelHandle, TestAppContext};
     use parking_lot::Mutex;
     use postage::{mpsc, watch};
+    use rand::prelude::*;
     use rpc::PeerId;
     use serde_json::json;
     use sqlx::types::time::OffsetDateTime;
@@ -2507,10 +2508,11 @@ mod tests {
             .await;
     }
 
-    #[gpui::test(iterations = 100, seed = 1)]
+    #[gpui::test]
     async fn test_open_buffer_while_getting_definition_pointing_to_it(
         mut cx_a: TestAppContext,
         mut cx_b: TestAppContext,
+        mut rng: StdRng,
     ) {
         cx_a.foreground().forbid_parking();
         let mut lang_registry = Arc::new(LanguageRegistry::new());
@@ -2589,7 +2591,18 @@ mod tests {
             .await
             .unwrap();
 
-        let definitions = project_b.update(&mut cx_b, |p, cx| p.definition(&buffer_b1, 23, cx));
+        let definitions;
+        let buffer_b2;
+        if rng.gen() {
+            definitions = project_b.update(&mut cx_b, |p, cx| p.definition(&buffer_b1, 23, cx));
+            buffer_b2 =
+                project_b.update(&mut cx_b, |p, cx| p.open_buffer((worktree_id, "b.rs"), cx));
+        } else {
+            buffer_b2 =
+                project_b.update(&mut cx_b, |p, cx| p.open_buffer((worktree_id, "b.rs"), cx));
+            definitions = project_b.update(&mut cx_b, |p, cx| p.definition(&buffer_b1, 23, cx));
+        }
+
         let (request_id, _) = fake_language_server
             .receive_request::<lsp::request::GotoDefinition>()
             .await;
@@ -2603,10 +2616,7 @@ mod tests {
             )
             .await;
 
-        let buffer_b2 = project_b
-            .update(&mut cx_b, |p, cx| p.open_buffer((worktree_id, "b.rs"), cx))
-            .await
-            .unwrap();
+        let buffer_b2 = buffer_b2.await.unwrap();
         let definitions = definitions.await.unwrap();
         assert_eq!(definitions.len(), 1);
         assert_eq!(definitions[0].target_buffer, buffer_b2);