Use `LspCommand` to handle completions

Antonio Scandurra created

Change summary

crates/project/src/lsp_command.rs | 209 ++++++++++++++++++++++++++++++
crates/project/src/project.rs     | 227 --------------------------------
2 files changed, 212 insertions(+), 224 deletions(-)

Detailed changes

crates/project/src/lsp_command.rs 🔗

@@ -4,11 +4,13 @@ use crate::{
 use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use client::proto::{self, PeerId};
+use fs::LineEnding;
 use gpui::{AppContext, AsyncAppContext, ModelHandle};
 use language::{
     point_from_lsp, point_to_lsp,
     proto::{deserialize_anchor, deserialize_version, serialize_anchor, serialize_version},
-    range_from_lsp, Anchor, Bias, Buffer, CachedLspAdapter, PointUtf16, ToPointUtf16,
+    range_from_lsp, Anchor, Bias, Buffer, CachedLspAdapter, CharKind, Completion, PointUtf16,
+    ToOffset, ToPointUtf16, Unclipped,
 };
 use lsp::{DocumentHighlightKind, LanguageServer, ServerCapabilities};
 use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag};
@@ -91,6 +93,10 @@ pub(crate) struct GetHover {
     pub position: PointUtf16,
 }
 
+pub(crate) struct GetCompletions {
+    pub position: PointUtf16,
+}
+
 #[async_trait(?Send)]
 impl LspCommand for PrepareRename {
     type Response = Option<Range<Anchor>>;
@@ -1199,3 +1205,204 @@ impl LspCommand for GetHover {
         message.buffer_id
     }
 }
+
+#[async_trait(?Send)]
+impl LspCommand for GetCompletions {
+    type Response = Vec<Completion>;
+    type LspRequest = lsp::request::Completion;
+    type ProtoRequest = proto::GetCompletions;
+
+    fn to_lsp(&self, path: &Path, _: &AppContext) -> lsp::CompletionParams {
+        lsp::CompletionParams {
+            text_document_position: lsp::TextDocumentPositionParams::new(
+                lsp::TextDocumentIdentifier::new(lsp::Url::from_file_path(path).unwrap()),
+                point_to_lsp(self.position),
+            ),
+            context: Default::default(),
+            work_done_progress_params: Default::default(),
+            partial_result_params: Default::default(),
+        }
+    }
+
+    async fn response_from_lsp(
+        self,
+        completions: Option<lsp::CompletionResponse>,
+        _: ModelHandle<Project>,
+        buffer: ModelHandle<Buffer>,
+        cx: AsyncAppContext,
+    ) -> Result<Vec<Completion>> {
+        let completions = if let Some(completions) = completions {
+            match completions {
+                lsp::CompletionResponse::Array(completions) => completions,
+                lsp::CompletionResponse::List(list) => list.items,
+            }
+        } else {
+            Default::default()
+        };
+
+        let completions = buffer.read_with(&cx, |buffer, _| {
+            let language = buffer.language().cloned();
+            let snapshot = buffer.snapshot();
+            let clipped_position = buffer.clip_point_utf16(Unclipped(self.position), Bias::Left);
+            let mut range_for_token = None;
+            completions
+                .into_iter()
+                .filter_map(move |mut lsp_completion| {
+                    // For now, we can only handle additional edits if they are returned
+                    // when resolving the completion, not if they are present initially.
+                    if lsp_completion
+                        .additional_text_edits
+                        .as_ref()
+                        .map_or(false, |edits| !edits.is_empty())
+                    {
+                        return None;
+                    }
+
+                    let (old_range, mut new_text) = match lsp_completion.text_edit.as_ref() {
+                        // If the language server provides a range to overwrite, then
+                        // check that the range is valid.
+                        Some(lsp::CompletionTextEdit::Edit(edit)) => {
+                            let range = range_from_lsp(edit.range);
+                            let start = snapshot.clip_point_utf16(range.start, Bias::Left);
+                            let end = snapshot.clip_point_utf16(range.end, Bias::Left);
+                            if start != range.start.0 || end != range.end.0 {
+                                log::info!("completion out of expected range");
+                                return None;
+                            }
+                            (
+                                snapshot.anchor_before(start)..snapshot.anchor_after(end),
+                                edit.new_text.clone(),
+                            )
+                        }
+                        // If the language server does not provide a range, then infer
+                        // the range based on the syntax tree.
+                        None => {
+                            if self.position != clipped_position {
+                                log::info!("completion out of expected range");
+                                return None;
+                            }
+                            let Range { start, end } = range_for_token
+                                .get_or_insert_with(|| {
+                                    let offset = self.position.to_offset(&snapshot);
+                                    let (range, kind) = snapshot.surrounding_word(offset);
+                                    if kind == Some(CharKind::Word) {
+                                        range
+                                    } else {
+                                        offset..offset
+                                    }
+                                })
+                                .clone();
+                            let text = lsp_completion
+                                .insert_text
+                                .as_ref()
+                                .unwrap_or(&lsp_completion.label)
+                                .clone();
+                            (
+                                snapshot.anchor_before(start)..snapshot.anchor_after(end),
+                                text,
+                            )
+                        }
+                        Some(lsp::CompletionTextEdit::InsertAndReplace(_)) => {
+                            log::info!("unsupported insert/replace completion");
+                            return None;
+                        }
+                    };
+
+                    let language = language.clone();
+                    LineEnding::normalize(&mut new_text);
+                    Some(async move {
+                        let mut label = None;
+                        if let Some(language) = language {
+                            language.process_completion(&mut lsp_completion).await;
+                            label = language.label_for_completion(&lsp_completion).await;
+                        }
+                        Completion {
+                            old_range,
+                            new_text,
+                            label: label.unwrap_or_else(|| {
+                                language::CodeLabel::plain(
+                                    lsp_completion.label.clone(),
+                                    lsp_completion.filter_text.as_deref(),
+                                )
+                            }),
+                            lsp_completion,
+                        }
+                    })
+                })
+        });
+
+        Ok(futures::future::join_all(completions).await)
+    }
+
+    fn to_proto(&self, project_id: u64, buffer: &Buffer) -> proto::GetCompletions {
+        let anchor = buffer.anchor_after(self.position);
+        proto::GetCompletions {
+            project_id,
+            buffer_id: buffer.remote_id(),
+            position: Some(language::proto::serialize_anchor(&anchor)),
+            version: serialize_version(&buffer.version()),
+        }
+    }
+
+    async fn from_proto(
+        message: proto::GetCompletions,
+        project: ModelHandle<Project>,
+        buffer: ModelHandle<Buffer>,
+        mut cx: AsyncAppContext,
+    ) -> Result<Self> {
+        let version = deserialize_version(&message.version);
+        buffer
+            .update(&mut cx, |buffer, _| buffer.wait_for_version(version))
+            .await?;
+        let position = message
+            .position
+            .and_then(language::proto::deserialize_anchor)
+            .map(|p| {
+                buffer.read_with(&cx, |buffer, _| {
+                    buffer.clip_point_utf16(Unclipped(p.to_point_utf16(buffer)), Bias::Left)
+                })
+            })
+            .ok_or_else(|| anyhow!("invalid position"))?;
+        Ok(Self { position })
+    }
+
+    fn response_to_proto(
+        completions: Vec<Completion>,
+        _: &mut Project,
+        _: PeerId,
+        buffer_version: &clock::Global,
+        _: &mut AppContext,
+    ) -> proto::GetCompletionsResponse {
+        proto::GetCompletionsResponse {
+            completions: completions
+                .iter()
+                .map(language::proto::serialize_completion)
+                .collect(),
+            version: serialize_version(&buffer_version),
+        }
+    }
+
+    async fn response_from_proto(
+        self,
+        message: proto::GetCompletionsResponse,
+        _: ModelHandle<Project>,
+        buffer: ModelHandle<Buffer>,
+        mut cx: AsyncAppContext,
+    ) -> Result<Vec<Completion>> {
+        buffer
+            .update(&mut cx, |buffer, _| {
+                buffer.wait_for_version(deserialize_version(&message.version))
+            })
+            .await?;
+
+        let language = buffer.read_with(&cx, |buffer, _| buffer.language().cloned());
+        let completions = message.completions.into_iter().map(|completion| {
+            language::proto::deserialize_completion(completion, language.clone())
+        });
+        futures::future::try_join_all(completions).await
+    }
+
+    fn buffer_id_from_proto(message: &proto::GetCompletions) -> u64 {
+        message.buffer_id
+    }
+}

crates/project/src/project.rs 🔗

@@ -410,7 +410,7 @@ impl Project {
         client.add_model_request_handler(Self::handle_synchronize_buffers);
         client.add_model_request_handler(Self::handle_format_buffers);
         client.add_model_request_handler(Self::handle_get_code_actions);
-        client.add_model_request_handler(Self::handle_get_completions);
+        client.add_model_request_handler(Self::handle_lsp_command::<GetCompletions>);
         client.add_model_request_handler(Self::handle_lsp_command::<GetHover>);
         client.add_model_request_handler(Self::handle_lsp_command::<GetDefinition>);
         client.add_model_request_handler(Self::handle_lsp_command::<GetTypeDefinition>);
@@ -3596,188 +3596,12 @@ impl Project {
 
     pub fn completions<T: ToPointUtf16>(
         &self,
-        source_buffer_handle: &ModelHandle<Buffer>,
+        buffer: &ModelHandle<Buffer>,
         position: T,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<Completion>>> {
-        let source_buffer_handle = source_buffer_handle.clone();
-        let source_buffer = source_buffer_handle.read(cx);
-        let buffer_id = source_buffer.remote_id();
-        let language = source_buffer.language().cloned();
-        let worktree;
-        let buffer_abs_path;
-        if let Some(file) = File::from_dyn(source_buffer.file()) {
-            worktree = file.worktree.clone();
-            buffer_abs_path = file.as_local().map(|f| f.abs_path(cx));
-        } else {
-            return Task::ready(Ok(Default::default()));
-        };
-
-        let position = Unclipped(position.to_point_utf16(source_buffer));
-        let anchor = source_buffer.anchor_after(position);
-
-        if worktree.read(cx).as_local().is_some() {
-            let buffer_abs_path = buffer_abs_path.unwrap();
-            let lang_server =
-                if let Some((_, server)) = self.language_server_for_buffer(source_buffer, cx) {
-                    server.clone()
-                } else {
-                    return Task::ready(Ok(Default::default()));
-                };
-
-            cx.spawn(|_, cx| async move {
-                let completions = lang_server
-                    .request::<lsp::request::Completion>(lsp::CompletionParams {
-                        text_document_position: lsp::TextDocumentPositionParams::new(
-                            lsp::TextDocumentIdentifier::new(
-                                lsp::Url::from_file_path(buffer_abs_path).unwrap(),
-                            ),
-                            point_to_lsp(position.0),
-                        ),
-                        context: Default::default(),
-                        work_done_progress_params: Default::default(),
-                        partial_result_params: Default::default(),
-                    })
-                    .await
-                    .context("lsp completion request failed")?;
-
-                let completions = if let Some(completions) = completions {
-                    match completions {
-                        lsp::CompletionResponse::Array(completions) => completions,
-                        lsp::CompletionResponse::List(list) => list.items,
-                    }
-                } else {
-                    Default::default()
-                };
-
-                let completions = source_buffer_handle.read_with(&cx, |this, _| {
-                    let snapshot = this.snapshot();
-                    let clipped_position = this.clip_point_utf16(position, Bias::Left);
-                    let mut range_for_token = None;
-                    completions
-                        .into_iter()
-                        .filter_map(move |mut lsp_completion| {
-                            // For now, we can only handle additional edits if they are returned
-                            // when resolving the completion, not if they are present initially.
-                            if lsp_completion
-                                .additional_text_edits
-                                .as_ref()
-                                .map_or(false, |edits| !edits.is_empty())
-                            {
-                                return None;
-                            }
-
-                            let (old_range, mut new_text) = match lsp_completion.text_edit.as_ref()
-                            {
-                                // If the language server provides a range to overwrite, then
-                                // check that the range is valid.
-                                Some(lsp::CompletionTextEdit::Edit(edit)) => {
-                                    let range = range_from_lsp(edit.range);
-                                    let start = snapshot.clip_point_utf16(range.start, Bias::Left);
-                                    let end = snapshot.clip_point_utf16(range.end, Bias::Left);
-                                    if start != range.start.0 || end != range.end.0 {
-                                        log::info!("completion out of expected range");
-                                        return None;
-                                    }
-                                    (
-                                        snapshot.anchor_before(start)..snapshot.anchor_after(end),
-                                        edit.new_text.clone(),
-                                    )
-                                }
-                                // If the language server does not provide a range, then infer
-                                // the range based on the syntax tree.
-                                None => {
-                                    if position.0 != clipped_position {
-                                        log::info!("completion out of expected range");
-                                        return None;
-                                    }
-                                    let Range { start, end } = range_for_token
-                                        .get_or_insert_with(|| {
-                                            let offset = position.to_offset(&snapshot);
-                                            let (range, kind) = snapshot.surrounding_word(offset);
-                                            if kind == Some(CharKind::Word) {
-                                                range
-                                            } else {
-                                                offset..offset
-                                            }
-                                        })
-                                        .clone();
-                                    let text = lsp_completion
-                                        .insert_text
-                                        .as_ref()
-                                        .unwrap_or(&lsp_completion.label)
-                                        .clone();
-                                    (
-                                        snapshot.anchor_before(start)..snapshot.anchor_after(end),
-                                        text,
-                                    )
-                                }
-                                Some(lsp::CompletionTextEdit::InsertAndReplace(_)) => {
-                                    log::info!("unsupported insert/replace completion");
-                                    return None;
-                                }
-                            };
-
-                            LineEnding::normalize(&mut new_text);
-                            let language = language.clone();
-                            Some(async move {
-                                let mut label = None;
-                                if let Some(language) = language {
-                                    language.process_completion(&mut lsp_completion).await;
-                                    label = language.label_for_completion(&lsp_completion).await;
-                                }
-                                Completion {
-                                    old_range,
-                                    new_text,
-                                    label: label.unwrap_or_else(|| {
-                                        CodeLabel::plain(
-                                            lsp_completion.label.clone(),
-                                            lsp_completion.filter_text.as_deref(),
-                                        )
-                                    }),
-                                    lsp_completion,
-                                }
-                            })
-                        })
-                });
-
-                Ok(futures::future::join_all(completions).await)
-            })
-        } else if let Some(project_id) = self.remote_id() {
-            let rpc = self.client.clone();
-            let message = proto::GetCompletions {
-                project_id,
-                buffer_id,
-                position: Some(language::proto::serialize_anchor(&anchor)),
-                version: serialize_version(&source_buffer.version()),
-            };
-            cx.spawn_weak(|this, mut cx| async move {
-                let response = rpc.request(message).await?;
-
-                if this
-                    .upgrade(&cx)
-                    .ok_or_else(|| anyhow!("project was dropped"))?
-                    .read_with(&cx, |this, _| this.is_read_only())
-                {
-                    return Err(anyhow!(
-                        "failed to get completions: project was disconnected"
-                    ));
-                } else {
-                    source_buffer_handle
-                        .update(&mut cx, |buffer, _| {
-                            buffer.wait_for_version(deserialize_version(&response.version))
-                        })
-                        .await?;
-
-                    let completions = response.completions.into_iter().map(|completion| {
-                        language::proto::deserialize_completion(completion, language.clone())
-                    });
-                    futures::future::try_join_all(completions).await
-                }
-            })
-        } else {
-            Task::ready(Ok(Default::default()))
-        }
+        let position = position.to_point_utf16(buffer.read(cx));
+        self.request_lsp(buffer.clone(), GetCompletions { position }, cx)
     }
 
     pub fn apply_additional_edits_for_completion(
@@ -5632,49 +5456,6 @@ impl Project {
         })
     }
 
-    async fn handle_get_completions(
-        this: ModelHandle<Self>,
-        envelope: TypedEnvelope<proto::GetCompletions>,
-        _: Arc<Client>,
-        mut cx: AsyncAppContext,
-    ) -> Result<proto::GetCompletionsResponse> {
-        let buffer = this.read_with(&cx, |this, cx| {
-            this.opened_buffers
-                .get(&envelope.payload.buffer_id)
-                .and_then(|buffer| buffer.upgrade(cx))
-                .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id))
-        })?;
-
-        let version = deserialize_version(&envelope.payload.version);
-        buffer
-            .update(&mut cx, |buffer, _| buffer.wait_for_version(version))
-            .await?;
-        let version = buffer.read_with(&cx, |buffer, _| buffer.version());
-
-        let position = envelope
-            .payload
-            .position
-            .and_then(language::proto::deserialize_anchor)
-            .map(|p| {
-                buffer.read_with(&cx, |buffer, _| {
-                    buffer.clip_point_utf16(Unclipped(p.to_point_utf16(buffer)), Bias::Left)
-                })
-            })
-            .ok_or_else(|| anyhow!("invalid position"))?;
-
-        let completions = this
-            .update(&mut cx, |this, cx| this.completions(&buffer, position, cx))
-            .await?;
-
-        Ok(proto::GetCompletionsResponse {
-            completions: completions
-                .iter()
-                .map(language::proto::serialize_completion)
-                .collect(),
-            version: serialize_version(&version),
-        })
-    }
-
     async fn handle_apply_additional_edits_for_completion(
         this: ModelHandle<Self>,
         envelope: TypedEnvelope<proto::ApplyCompletionAdditionalEdits>,