Start on requesting completions for remote buffers

Antonio Scandurra created

Change summary

crates/editor/src/editor.rs    |  2 
crates/language/src/buffer.rs  | 22 ++++++++++++
crates/project/src/project.rs  | 61 ++++++++++++++++++++++++++++++++++++
crates/project/src/worktree.rs | 48 +++++++++++++++++++++++++++
crates/rpc/proto/zed.proto     | 51 ++++++++++++++++++++---------
crates/rpc/src/proto.rs        |  4 ++
crates/server/src/rpc.rs       | 25 ++++++++++++++
7 files changed, 194 insertions(+), 19 deletions(-)

Detailed changes

crates/editor/src/editor.rs 🔗

@@ -1670,7 +1670,7 @@ impl Editor {
             .get(completion_state.selected_item)?;
         let completion = completion_state.completions.get(mat.candidate_id)?;
 
-        if completion.lsp_completion.insert_text_format == Some(lsp::InsertTextFormat::SNIPPET) {
+        if completion.is_snippet() {
             self.insert_snippet(completion.old_range.clone(), &completion.new_text, cx)
                 .log_err();
         } else {

crates/language/src/buffer.rs 🔗

@@ -195,6 +195,13 @@ pub trait File {
     fn format_remote(&self, buffer_id: u64, cx: &mut MutableAppContext)
         -> Option<Task<Result<()>>>;
 
+    fn completions(
+        &self,
+        buffer_id: u64,
+        position: Anchor,
+        cx: &mut MutableAppContext,
+    ) -> Task<Result<Vec<Completion<Anchor>>>>;
+
     fn buffer_updated(&self, buffer_id: u64, operation: Operation, cx: &mut MutableAppContext);
 
     fn buffer_removed(&self, buffer_id: u64, cx: &mut MutableAppContext);
@@ -264,6 +271,15 @@ impl File for FakeFile {
         None
     }
 
+    fn completions(
+        &self,
+        _: u64,
+        _: Anchor,
+        _: &mut MutableAppContext,
+    ) -> Task<Result<Vec<Completion<Anchor>>>> {
+        Task::ready(Ok(Default::default()))
+    }
+
     fn buffer_updated(&self, _: u64, _: Operation, _: &mut MutableAppContext) {}
 
     fn buffer_removed(&self, _: u64, _: &mut MutableAppContext) {}
@@ -1773,7 +1789,7 @@ impl Buffer {
                 })
             })
         } else {
-            Task::ready(Ok(Default::default()))
+            file.completions(self.remote_id(), self.anchor_before(position), cx.as_mut())
         }
     }
 
@@ -2555,6 +2571,10 @@ impl<T> Completion<T> {
         };
         (kind_key, &self.label()[self.filter_range()])
     }
+
+    pub fn is_snippet(&self) -> bool {
+        self.lsp_completion.insert_text_format == Some(lsp::InsertTextFormat::SNIPPET)
+    }
 }
 
 pub fn contiguous_ranges(

crates/project/src/project.rs 🔗

@@ -334,6 +334,7 @@ impl Project {
                 client.subscribe_to_entity(remote_id, cx, Self::handle_save_buffer),
                 client.subscribe_to_entity(remote_id, cx, Self::handle_buffer_saved),
                 client.subscribe_to_entity(remote_id, cx, Self::handle_format_buffer),
+                client.subscribe_to_entity(remote_id, cx, Self::handle_get_completions),
                 client.subscribe_to_entity(remote_id, cx, Self::handle_get_definition),
             ]);
         }
@@ -1683,6 +1684,66 @@ impl Project {
         Ok(())
     }
 
+    fn handle_get_completions(
+        &mut self,
+        envelope: TypedEnvelope<proto::GetCompletions>,
+        rpc: Arc<Client>,
+        cx: &mut ModelContext<Self>,
+    ) -> Result<()> {
+        let receipt = envelope.receipt();
+        let sender_id = envelope.original_sender_id()?;
+        let buffer = self
+            .shared_buffers
+            .get(&sender_id)
+            .and_then(|shared_buffers| shared_buffers.get(&envelope.payload.buffer_id).cloned())
+            .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id))?;
+        let position = envelope
+            .payload
+            .position
+            .and_then(language::proto::deserialize_anchor)
+            .ok_or_else(|| anyhow!("invalid position"))?;
+        cx.spawn(|_, mut cx| async move {
+            match buffer
+                .update(&mut cx, |buffer, cx| buffer.completions(position, cx))
+                .await
+            {
+                Ok(completions) => {
+                    rpc.respond(
+                        receipt,
+                        proto::GetCompletionsResponse {
+                            completions: completions
+                                .into_iter()
+                                .map(|completion| proto::Completion {
+                                    old_start: Some(language::proto::serialize_anchor(
+                                        &completion.old_range.start,
+                                    )),
+                                    old_end: Some(language::proto::serialize_anchor(
+                                        &completion.old_range.end,
+                                    )),
+                                    new_text: completion.new_text,
+                                    lsp_completion: serde_json::to_vec(&completion.lsp_completion)
+                                        .unwrap(),
+                                })
+                                .collect(),
+                        },
+                    )
+                    .await
+                }
+                Err(error) => {
+                    rpc.respond_with_error(
+                        receipt,
+                        proto::Error {
+                            message: error.to_string(),
+                        },
+                    )
+                    .await
+                }
+            }
+        })
+        .detach_and_log_err(cx);
+        Ok(())
+    }
+
     pub fn handle_get_definition(
         &mut self,
         envelope: TypedEnvelope<proto::GetDefinition>,

crates/project/src/worktree.rs 🔗

@@ -14,7 +14,7 @@ use gpui::{
     executor, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
     Task,
 };
-use language::{Buffer, DiagnosticEntry, Operation, PointUtf16, Rope};
+use language::{Anchor, Buffer, Completion, DiagnosticEntry, Operation, PointUtf16, Rope};
 use lazy_static::lazy_static;
 use parking_lot::Mutex;
 use postage::{
@@ -1421,6 +1421,52 @@ impl language::File for File {
         }))
     }
 
+    fn completions(
+        &self,
+        buffer_id: u64,
+        position: Anchor,
+        cx: &mut MutableAppContext,
+    ) -> Task<Result<Vec<Completion<Anchor>>>> {
+        let worktree = self.worktree.read(cx);
+        let worktree = if let Some(worktree) = worktree.as_remote() {
+            worktree
+        } else {
+            return Task::ready(Err(anyhow!(
+                "remote completions requested on a local worktree"
+            )));
+        };
+        let rpc = worktree.client.clone();
+        let project_id = worktree.project_id;
+        cx.foreground().spawn(async move {
+            let response = rpc
+                .request(proto::GetCompletions {
+                    project_id,
+                    buffer_id,
+                    position: Some(language::proto::serialize_anchor(&position)),
+                })
+                .await?;
+            response
+                .completions
+                .into_iter()
+                .map(|completion| {
+                    let old_start = completion
+                        .old_start
+                        .and_then(language::proto::deserialize_anchor)
+                        .ok_or_else(|| anyhow!("invalid old start"))?;
+                    let old_end = completion
+                        .old_end
+                        .and_then(language::proto::deserialize_anchor)
+                        .ok_or_else(|| anyhow!("invalid old end"))?;
+                    Ok(Completion {
+                        old_range: old_start..old_end,
+                        new_text: completion.new_text,
+                        lsp_completion: serde_json::from_slice(&completion.lsp_completion)?,
+                    })
+                })
+                .collect()
+        })
+    }
+
     fn buffer_updated(&self, buffer_id: u64, operation: Operation, cx: &mut MutableAppContext) {
         self.worktree.update(cx, |worktree, cx| {
             worktree.send_buffer_update(buffer_id, operation, cx);

crates/rpc/proto/zed.proto 🔗

@@ -40,22 +40,24 @@ message Envelope {
         BufferSaved buffer_saved = 32;
         BufferReloaded buffer_reloaded = 33;
         FormatBuffer format_buffer = 34;
-
-        GetChannels get_channels = 35;
-        GetChannelsResponse get_channels_response = 36;
-        JoinChannel join_channel = 37;
-        JoinChannelResponse join_channel_response = 38;
-        LeaveChannel leave_channel = 39;
-        SendChannelMessage send_channel_message = 40;
-        SendChannelMessageResponse send_channel_message_response = 41;
-        ChannelMessageSent channel_message_sent = 42;
-        GetChannelMessages get_channel_messages = 43;
-        GetChannelMessagesResponse get_channel_messages_response = 44;
-
-        UpdateContacts update_contacts = 45;
-
-        GetUsers get_users = 46;
-        GetUsersResponse get_users_response = 47;
+        GetCompletions get_completions = 35;
+        GetCompletionsResponse get_completions_response = 36;
+
+        GetChannels get_channels = 37;
+        GetChannelsResponse get_channels_response = 38;
+        JoinChannel join_channel = 39;
+        JoinChannelResponse join_channel_response = 40;
+        LeaveChannel leave_channel = 41;
+        SendChannelMessage send_channel_message = 42;
+        SendChannelMessageResponse send_channel_message_response = 43;
+        ChannelMessageSent channel_message_sent = 44;
+        GetChannelMessages get_channel_messages = 45;
+        GetChannelMessagesResponse get_channel_messages_response = 46;
+
+        UpdateContacts update_contacts = 47;
+
+        GetUsers get_users = 48;
+        GetUsersResponse get_users_response = 49;
     }
 }
 
@@ -203,6 +205,23 @@ message FormatBuffer {
     uint64 buffer_id = 2;
 }
 
+message GetCompletions {
+    uint64 project_id = 1;
+    uint64 buffer_id = 2;
+    Anchor position = 3;
+}
+
+message GetCompletionsResponse {
+    repeated Completion completions = 1;
+}
+
+message Completion {
+    Anchor old_start = 1;
+    Anchor old_end = 2;
+    string new_text = 3;
+    bytes lsp_completion = 4;
+}
+
 message UpdateDiagnosticSummary {
     uint64 project_id = 1;
     uint64 worktree_id = 2;

crates/rpc/src/proto.rs 🔗

@@ -134,6 +134,8 @@ messages!(
     GetChannelMessagesResponse,
     GetChannels,
     GetChannelsResponse,
+    GetCompletions,
+    GetCompletionsResponse,
     GetDefinition,
     GetDefinitionResponse,
     GetUsers,
@@ -170,6 +172,7 @@ request_messages!(
     (FormatBuffer, Ack),
     (GetChannelMessages, GetChannelMessagesResponse),
     (GetChannels, GetChannelsResponse),
+    (GetCompletions, GetCompletionsResponse),
     (GetDefinition, GetDefinitionResponse),
     (GetUsers, GetUsersResponse),
     (JoinChannel, JoinChannelResponse),
@@ -194,6 +197,7 @@ entity_messages!(
     DiskBasedDiagnosticsUpdated,
     DiskBasedDiagnosticsUpdating,
     FormatBuffer,
+    GetCompletions,
     GetDefinition,
     JoinProject,
     LeaveProject,

crates/server/src/rpc.rs 🔗

@@ -83,6 +83,7 @@ impl Server {
             .add_handler(Server::buffer_saved)
             .add_handler(Server::save_buffer)
             .add_handler(Server::format_buffer)
+            .add_handler(Server::get_completions)
             .add_handler(Server::get_channels)
             .add_handler(Server::get_users)
             .add_handler(Server::join_channel)
@@ -722,6 +723,30 @@ impl Server {
         Ok(())
     }
 
+    async fn get_completions(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::GetCompletions>,
+    ) -> tide::Result<()> {
+        let host;
+        {
+            let state = self.state();
+            let project = state
+                .read_project(request.payload.project_id, request.sender_id)
+                .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?;
+            host = project.host_connection_id;
+        }
+
+        let sender = request.sender_id;
+        let receipt = request.receipt();
+        let response = self
+            .peer
+            .forward_request(sender, host, request.payload.clone())
+            .await?;
+        self.peer.respond(receipt, response).await?;
+
+        Ok(())
+    }
+
     async fn update_buffer(
         self: Arc<Server>,
         request: TypedEnvelope<proto::UpdateBuffer>,