Handle GetDefinition via LspCommand trait

Max Brunsfeld created

Change summary

crates/project/src/lsp_command.rs | 249 +++++++++++++++++++++++++-------
crates/project/src/project.rs     | 192 +-----------------------
2 files changed, 204 insertions(+), 237 deletions(-)

Detailed changes

crates/project/src/lsp_command.rs 🔗

@@ -1,11 +1,12 @@
-use crate::{Project, ProjectTransaction};
+use crate::{Definition, Project, ProjectTransaction};
 use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use client::{proto, PeerId};
-use gpui::{AppContext, AsyncAppContext, ModelContext, ModelHandle};
+use gpui::{AppContext, AsyncAppContext, ModelHandle};
 use language::{
-    proto::deserialize_anchor, range_from_lsp, Anchor, Bias, Buffer, PointUtf16, ToLspPosition,
-    ToPointUtf16,
+    point_from_lsp,
+    proto::{deserialize_anchor, serialize_anchor},
+    range_from_lsp, Anchor, Bias, Buffer, PointUtf16, ToLspPosition, ToPointUtf16,
 };
 use std::{ops::Range, path::Path};
 
@@ -28,26 +29,18 @@ pub(crate) trait LspCommand: 'static + Sized {
         cx: AsyncAppContext,
     ) -> Result<Self::Response>;
 
-    fn to_proto(
-        &self,
-        project_id: u64,
-        buffer: &ModelHandle<Buffer>,
-        cx: &AppContext,
-    ) -> Self::ProtoRequest;
+    fn to_proto(&self, project_id: u64, buffer: &Buffer) -> Self::ProtoRequest;
     fn from_proto(
         message: Self::ProtoRequest,
         project: &mut Project,
-        buffer: &ModelHandle<Buffer>,
-        cx: &mut ModelContext<Project>,
+        buffer: &Buffer,
     ) -> Result<Self>;
-    fn buffer_id_from_proto(message: &Self::ProtoRequest) -> u64;
-
     fn response_to_proto(
         response: Self::Response,
         project: &mut Project,
         peer_id: PeerId,
         buffer_version: &clock::Global,
-        cx: &mut ModelContext<Project>,
+        cx: &AppContext,
     ) -> <Self::ProtoRequest as proto::RequestMessage>::Response;
     async fn response_from_proto(
         self,
@@ -56,6 +49,7 @@ pub(crate) trait LspCommand: 'static + Sized {
         buffer: ModelHandle<Buffer>,
         cx: AsyncAppContext,
     ) -> Result<Self::Response>;
+    fn buffer_id_from_proto(message: &Self::ProtoRequest) -> u64;
 }
 
 pub(crate) struct PrepareRename {
@@ -68,6 +62,10 @@ pub(crate) struct PerformRename {
     pub push_to_history: bool,
 }
 
+pub(crate) struct GetDefinition {
+    pub position: PointUtf16,
+}
+
 #[async_trait(?Send)]
 impl LspCommand for PrepareRename {
     type Response = Option<Range<Anchor>>;
@@ -107,32 +105,17 @@ impl LspCommand for PrepareRename {
         })
     }
 
-    fn to_proto(
-        &self,
-        project_id: u64,
-        buffer: &ModelHandle<Buffer>,
-        cx: &AppContext,
-    ) -> proto::PrepareRename {
+    fn to_proto(&self, project_id: u64, buffer: &Buffer) -> proto::PrepareRename {
         proto::PrepareRename {
             project_id,
-            buffer_id: buffer.read(cx).remote_id(),
+            buffer_id: buffer.remote_id(),
             position: Some(language::proto::serialize_anchor(
-                &buffer.read(cx).anchor_before(self.position),
+                &buffer.anchor_before(self.position),
             )),
         }
     }
 
-    fn buffer_id_from_proto(message: &proto::PrepareRename) -> u64 {
-        message.buffer_id
-    }
-
-    fn from_proto(
-        message: proto::PrepareRename,
-        _: &mut Project,
-        buffer: &ModelHandle<Buffer>,
-        cx: &mut ModelContext<Project>,
-    ) -> Result<Self> {
-        let buffer = buffer.read(cx);
+    fn from_proto(message: proto::PrepareRename, _: &mut Project, buffer: &Buffer) -> Result<Self> {
         let position = message
             .position
             .and_then(deserialize_anchor)
@@ -150,7 +133,7 @@ impl LspCommand for PrepareRename {
         _: &mut Project,
         _: PeerId,
         buffer_version: &clock::Global,
-        _: &mut ModelContext<Project>,
+        _: &AppContext,
     ) -> proto::PrepareRenameResponse {
         proto::PrepareRenameResponse {
             can_rename: range.is_some(),
@@ -184,6 +167,10 @@ impl LspCommand for PrepareRename {
             Ok(None)
         }
     }
+
+    fn buffer_id_from_proto(message: &proto::PrepareRename) -> u64 {
+        message.buffer_id
+    }
 }
 
 #[async_trait(?Send)]
@@ -237,17 +224,10 @@ impl LspCommand for PerformRename {
         }
     }
 
-    fn to_proto(
-        &self,
-        project_id: u64,
-        buffer: &ModelHandle<Buffer>,
-        cx: &AppContext,
-    ) -> proto::PerformRename {
-        let buffer = buffer.read(cx);
-        let buffer_id = buffer.remote_id();
+    fn to_proto(&self, project_id: u64, buffer: &Buffer) -> proto::PerformRename {
         proto::PerformRename {
             project_id,
-            buffer_id,
+            buffer_id: buffer.remote_id(),
             position: Some(language::proto::serialize_anchor(
                 &buffer.anchor_before(self.position),
             )),
@@ -255,21 +235,11 @@ impl LspCommand for PerformRename {
         }
     }
 
-    fn buffer_id_from_proto(message: &proto::PerformRename) -> u64 {
-        message.buffer_id
-    }
-
-    fn from_proto(
-        message: proto::PerformRename,
-        _: &mut Project,
-        buffer: &ModelHandle<Buffer>,
-        cx: &mut ModelContext<Project>,
-    ) -> Result<Self> {
+    fn from_proto(message: proto::PerformRename, _: &mut Project, buffer: &Buffer) -> Result<Self> {
         let position = message
             .position
             .and_then(deserialize_anchor)
             .ok_or_else(|| anyhow!("invalid position"))?;
-        let buffer = buffer.read(cx);
         if !buffer.can_resolve(&position) {
             Err(anyhow!("cannot resolve position"))?;
         }
@@ -285,7 +255,7 @@ impl LspCommand for PerformRename {
         project: &mut Project,
         peer_id: PeerId,
         _: &clock::Global,
-        cx: &mut ModelContext<Project>,
+        cx: &AppContext,
     ) -> proto::PerformRenameResponse {
         let transaction = project.serialize_project_transaction_for_peer(response, peer_id, cx);
         proto::PerformRenameResponse {
@@ -309,4 +279,171 @@ impl LspCommand for PerformRename {
             })
             .await
     }
+
+    fn buffer_id_from_proto(message: &proto::PerformRename) -> u64 {
+        message.buffer_id
+    }
+}
+
+#[async_trait(?Send)]
+impl LspCommand for GetDefinition {
+    type Response = Vec<Definition>;
+    type LspRequest = lsp::request::GotoDefinition;
+    type ProtoRequest = proto::GetDefinition;
+
+    fn to_lsp(&self, path: &Path, _: &AppContext) -> lsp::GotoDefinitionParams {
+        lsp::GotoDefinitionParams {
+            text_document_position_params: lsp::TextDocumentPositionParams {
+                text_document: lsp::TextDocumentIdentifier {
+                    uri: lsp::Url::from_file_path(path).unwrap(),
+                },
+                position: self.position.to_lsp_position(),
+            },
+            work_done_progress_params: Default::default(),
+            partial_result_params: Default::default(),
+        }
+    }
+
+    async fn response_from_lsp(
+        self,
+        message: Option<lsp::GotoDefinitionResponse>,
+        project: ModelHandle<Project>,
+        buffer: ModelHandle<Buffer>,
+        mut cx: AsyncAppContext,
+    ) -> Result<Vec<Definition>> {
+        let mut definitions = Vec::new();
+        let (language, language_server) = buffer
+            .read_with(&cx, |buffer, _| {
+                buffer
+                    .language()
+                    .cloned()
+                    .zip(buffer.language_server().cloned())
+            })
+            .ok_or_else(|| anyhow!("buffer no longer has language server"))?;
+
+        if let Some(message) = message {
+            let mut unresolved_locations = Vec::new();
+            match message {
+                lsp::GotoDefinitionResponse::Scalar(loc) => {
+                    unresolved_locations.push((loc.uri, loc.range));
+                }
+                lsp::GotoDefinitionResponse::Array(locs) => {
+                    unresolved_locations.extend(locs.into_iter().map(|l| (l.uri, l.range)));
+                }
+                lsp::GotoDefinitionResponse::Link(links) => {
+                    unresolved_locations.extend(
+                        links
+                            .into_iter()
+                            .map(|l| (l.target_uri, l.target_selection_range)),
+                    );
+                }
+            }
+
+            for (target_uri, target_range) in unresolved_locations {
+                let target_buffer_handle = project
+                    .update(&mut cx, |this, cx| {
+                        this.open_local_buffer_from_lsp_path(
+                            target_uri,
+                            language.name().to_string(),
+                            language_server.clone(),
+                            cx,
+                        )
+                    })
+                    .await?;
+
+                cx.read(|cx| {
+                    let target_buffer = target_buffer_handle.read(cx);
+                    let target_start = target_buffer
+                        .clip_point_utf16(point_from_lsp(target_range.start), Bias::Left);
+                    let target_end = target_buffer
+                        .clip_point_utf16(point_from_lsp(target_range.end), Bias::Left);
+                    definitions.push(Definition {
+                        target_buffer: target_buffer_handle,
+                        target_range: target_buffer.anchor_after(target_start)
+                            ..target_buffer.anchor_before(target_end),
+                    });
+                });
+            }
+        }
+
+        Ok(definitions)
+    }
+
+    fn to_proto(&self, project_id: u64, buffer: &Buffer) -> proto::GetDefinition {
+        proto::GetDefinition {
+            project_id,
+            buffer_id: buffer.remote_id(),
+            position: Some(language::proto::serialize_anchor(
+                &buffer.anchor_before(self.position),
+            )),
+        }
+    }
+
+    fn from_proto(message: proto::GetDefinition, _: &mut Project, buffer: &Buffer) -> Result<Self> {
+        let position = message
+            .position
+            .and_then(deserialize_anchor)
+            .ok_or_else(|| anyhow!("invalid position"))?;
+        if !buffer.can_resolve(&position) {
+            Err(anyhow!("cannot resolve position"))?;
+        }
+        Ok(Self {
+            position: position.to_point_utf16(buffer),
+        })
+    }
+
+    fn response_to_proto(
+        response: Vec<Definition>,
+        project: &mut Project,
+        peer_id: PeerId,
+        _: &clock::Global,
+        cx: &AppContext,
+    ) -> proto::GetDefinitionResponse {
+        let definitions = response
+            .into_iter()
+            .map(|definition| {
+                let buffer =
+                    project.serialize_buffer_for_peer(&definition.target_buffer, peer_id, cx);
+                proto::Definition {
+                    target_start: Some(serialize_anchor(&definition.target_range.start)),
+                    target_end: Some(serialize_anchor(&definition.target_range.end)),
+                    buffer: Some(buffer),
+                }
+            })
+            .collect();
+        proto::GetDefinitionResponse { definitions }
+    }
+
+    async fn response_from_proto(
+        self,
+        message: proto::GetDefinitionResponse,
+        project: ModelHandle<Project>,
+        _: ModelHandle<Buffer>,
+        mut cx: AsyncAppContext,
+    ) -> Result<Vec<Definition>> {
+        let mut definitions = Vec::new();
+        for definition in message.definitions {
+            let buffer = definition.buffer.ok_or_else(|| anyhow!("missing buffer"))?;
+            let target_buffer = project
+                .update(&mut cx, |this, cx| this.deserialize_buffer(buffer, cx))
+                .await?;
+            let target_start = definition
+                .target_start
+                .and_then(deserialize_anchor)
+                .ok_or_else(|| anyhow!("missing target start"))?;
+            let target_end = definition
+                .target_end
+                .and_then(deserialize_anchor)
+                .ok_or_else(|| anyhow!("missing target end"))?;
+            definitions.push(Definition {
+                target_buffer,
+                target_range: target_start..target_end,
+            })
+        }
+        Ok(definitions)
+    }
+
+    fn buffer_id_from_proto(message: &proto::GetDefinition) -> u64 {
+        message.buffer_id
+    }
 }

crates/project/src/project.rs 🔗

@@ -14,8 +14,6 @@ use gpui::{
     UpgradeModelHandle, WeakModelHandle,
 };
 use language::{
-    point_from_lsp,
-    proto::{deserialize_anchor, serialize_anchor},
     range_from_lsp, Anchor, AnchorRangeExt, Bias, Buffer, CodeAction, Completion, CompletionLabel,
     Diagnostic, DiagnosticEntry, File as _, Language, LanguageRegistry, Operation, PointUtf16,
     ToLspPosition, ToOffset, ToPointUtf16, Transaction,
@@ -183,9 +181,9 @@ impl Project {
         client.add_entity_request_handler(Self::handle_format_buffers);
         client.add_entity_request_handler(Self::handle_get_code_actions);
         client.add_entity_request_handler(Self::handle_get_completions);
-        client.add_entity_request_handler(Self::handle_get_definition);
-        client.add_entity_request_handler(Self::handle_lsp_command::<lsp_command::PrepareRename>);
-        client.add_entity_request_handler(Self::handle_lsp_command::<lsp_command::PerformRename>);
+        client.add_entity_request_handler(Self::handle_lsp_command::<GetDefinition>);
+        client.add_entity_request_handler(Self::handle_lsp_command::<PrepareRename>);
+        client.add_entity_request_handler(Self::handle_lsp_command::<PerformRename>);
         client.add_entity_request_handler(Self::handle_open_buffer);
         client.add_entity_request_handler(Self::handle_save_buffer);
     }
@@ -1175,137 +1173,12 @@ impl Project {
 
     pub fn definition<T: ToPointUtf16>(
         &self,
-        source_buffer_handle: &ModelHandle<Buffer>,
+        buffer: &ModelHandle<Buffer>,
         position: T,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<Definition>>> {
-        let source_buffer_handle = source_buffer_handle.clone();
-        let source_buffer = source_buffer_handle.read(cx);
-        let worktree;
-        let buffer_abs_path;
-        if let Some(file) = File::from_dyn(source_buffer.file()) {
-            worktree = file.worktree.clone();
-            buffer_abs_path = file.as_local().map(|f| f.abs_path(cx));
-        } else {
-            return Task::ready(Ok(Default::default()));
-        };
-
-        let position = position.to_point_utf16(source_buffer);
-
-        if worktree.read(cx).as_local().is_some() {
-            let buffer_abs_path = buffer_abs_path.unwrap();
-            let lang_name;
-            let lang_server;
-            if let Some(lang) = source_buffer.language() {
-                lang_name = lang.name().to_string();
-                if let Some(server) = self
-                    .language_servers
-                    .get(&(worktree.read(cx).id(), lang_name.clone()))
-                {
-                    lang_server = server.clone();
-                } else {
-                    return Task::ready(Ok(Default::default()));
-                };
-            } else {
-                return Task::ready(Ok(Default::default()));
-            }
-
-            cx.spawn(|this, mut cx| async move {
-                let response = lang_server
-                    .request::<lsp::request::GotoDefinition>(lsp::GotoDefinitionParams {
-                        text_document_position_params: lsp::TextDocumentPositionParams {
-                            text_document: lsp::TextDocumentIdentifier::new(
-                                lsp::Url::from_file_path(&buffer_abs_path).unwrap(),
-                            ),
-                            position: lsp::Position::new(position.row, position.column),
-                        },
-                        work_done_progress_params: Default::default(),
-                        partial_result_params: Default::default(),
-                    })
-                    .await?;
-
-                let mut definitions = Vec::new();
-                if let Some(response) = response {
-                    let mut unresolved_locations = Vec::new();
-                    match response {
-                        lsp::GotoDefinitionResponse::Scalar(loc) => {
-                            unresolved_locations.push((loc.uri, loc.range));
-                        }
-                        lsp::GotoDefinitionResponse::Array(locs) => {
-                            unresolved_locations.extend(locs.into_iter().map(|l| (l.uri, l.range)));
-                        }
-                        lsp::GotoDefinitionResponse::Link(links) => {
-                            unresolved_locations.extend(
-                                links
-                                    .into_iter()
-                                    .map(|l| (l.target_uri, l.target_selection_range)),
-                            );
-                        }
-                    }
-
-                    for (target_uri, target_range) in unresolved_locations {
-                        let target_buffer_handle = this
-                            .update(&mut cx, |this, cx| {
-                                this.open_local_buffer_from_lsp_path(
-                                    target_uri,
-                                    lang_name.clone(),
-                                    lang_server.clone(),
-                                    cx,
-                                )
-                            })
-                            .await?;
-
-                        cx.read(|cx| {
-                            let target_buffer = target_buffer_handle.read(cx);
-                            let target_start = target_buffer
-                                .clip_point_utf16(point_from_lsp(target_range.start), Bias::Left);
-                            let target_end = target_buffer
-                                .clip_point_utf16(point_from_lsp(target_range.end), Bias::Left);
-                            definitions.push(Definition {
-                                target_buffer: target_buffer_handle,
-                                target_range: target_buffer.anchor_after(target_start)
-                                    ..target_buffer.anchor_before(target_end),
-                            });
-                        });
-                    }
-                }
-
-                Ok(definitions)
-            })
-        } else if let Some(project_id) = self.remote_id() {
-            let client = self.client.clone();
-            let request = proto::GetDefinition {
-                project_id,
-                buffer_id: source_buffer.remote_id(),
-                position: Some(serialize_anchor(&source_buffer.anchor_before(position))),
-            };
-            cx.spawn(|this, mut cx| async move {
-                let response = client.request(request).await?;
-                let mut definitions = Vec::new();
-                for definition in response.definitions {
-                    let buffer = definition.buffer.ok_or_else(|| anyhow!("missing buffer"))?;
-                    let target_buffer = this
-                        .update(&mut cx, |this, cx| this.deserialize_buffer(buffer, cx))
-                        .await?;
-                    let target_start = definition
-                        .target_start
-                        .and_then(deserialize_anchor)
-                        .ok_or_else(|| anyhow!("missing target start"))?;
-                    let target_end = definition
-                        .target_end
-                        .and_then(deserialize_anchor)
-                        .ok_or_else(|| anyhow!("missing target end"))?;
-                    definitions.push(Definition {
-                        target_buffer,
-                        target_range: target_start..target_end,
-                    })
-                }
-
-                Ok(definitions)
-            })
-        } else {
-            Task::ready(Ok(Default::default()))
-        }
+        let position = position.to_point_utf16(buffer.read(cx));
+        self.request_lsp(buffer.clone(), GetDefinition { position }, cx)
     }
 
     pub fn completions<T: ToPointUtf16>(
@@ -1861,8 +1734,8 @@ impl Project {
     where
         <R::LspRequest as lsp::request::Request>::Result: Send,
     {
+        let buffer = buffer_handle.read(cx);
         if self.is_local() {
-            let buffer = buffer_handle.read(cx);
             let file = File::from_dyn(buffer.file()).and_then(File::as_local);
             if let Some((file, language_server)) = file.zip(buffer.language_server().cloned()) {
                 let lsp_params = request.to_lsp(&file.abs_path(cx), cx);
@@ -1878,7 +1751,7 @@ impl Project {
             }
         } else if let Some(project_id) = self.remote_id() {
             let rpc = self.client.clone();
-            let message = request.to_proto(project_id, &buffer_handle, cx);
+            let message = request.to_proto(project_id, buffer);
             return cx.spawn(|this, cx| async move {
                 let response = rpc.request(message).await?;
                 request
@@ -2578,50 +2451,6 @@ impl Project {
         })
     }
 
-    async fn handle_get_definition(
-        this: ModelHandle<Self>,
-        envelope: TypedEnvelope<proto::GetDefinition>,
-        _: Arc<Client>,
-        mut cx: AsyncAppContext,
-    ) -> Result<proto::GetDefinitionResponse> {
-        let sender_id = envelope.original_sender_id()?;
-        let position = envelope
-            .payload
-            .position
-            .and_then(deserialize_anchor)
-            .ok_or_else(|| anyhow!("invalid position"))?;
-        let definitions = this.update(&mut cx, |this, cx| {
-            let source_buffer = this
-                .shared_buffers
-                .get(&sender_id)
-                .and_then(|shared_buffers| shared_buffers.get(&envelope.payload.buffer_id).cloned())
-                .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id))?;
-            if source_buffer.read(cx).can_resolve(&position) {
-                Ok(this.definition(&source_buffer, position, cx))
-            } else {
-                Err(anyhow!("cannot resolve position"))
-            }
-        })?;
-
-        let definitions = definitions.await?;
-
-        this.update(&mut cx, |this, cx| {
-            let mut response = proto::GetDefinitionResponse {
-                definitions: Default::default(),
-            };
-            for definition in definitions {
-                let buffer =
-                    this.serialize_buffer_for_peer(&definition.target_buffer, sender_id, cx);
-                response.definitions.push(proto::Definition {
-                    target_start: Some(serialize_anchor(&definition.target_range.start)),
-                    target_end: Some(serialize_anchor(&definition.target_range.end)),
-                    buffer: Some(buffer),
-                });
-            }
-            Ok(response)
-        })
-    }
-
     async fn handle_lsp_command<T: LspCommand>(
         this: ModelHandle<Self>,
         envelope: TypedEnvelope<T::ProtoRequest>,
@@ -2639,8 +2468,9 @@ impl Project {
                 .get(&sender_id)
                 .and_then(|shared_buffers| shared_buffers.get(&buffer_id).cloned())
                 .ok_or_else(|| anyhow!("unknown buffer id {}", buffer_id))?;
-            let buffer_version = buffer_handle.read(cx).version();
-            let request = T::from_proto(envelope.payload, this, &buffer_handle, cx)?;
+            let buffer = buffer_handle.read(cx);
+            let buffer_version = buffer.version();
+            let request = T::from_proto(envelope.payload, this, buffer)?;
             Ok::<_, anyhow::Error>((this.request_lsp(buffer_handle, request, cx), buffer_version))
         })?;
         let response = request.await?;