Implement `Copilot::completions_cycling`

Antonio Scandurra and Mikayla Maki created

Co-Authored-By: Mikayla Maki <mikayla@zed.dev>

Change summary

crates/copilot/src/copilot.rs | 155 ++++++++++++++++++++++++++----------
crates/copilot/src/request.rs |  20 +++-
2 files changed, 126 insertions(+), 49 deletions(-)

Detailed changes

crates/copilot/src/copilot.rs 🔗

@@ -4,7 +4,7 @@ use anyhow::{anyhow, Result};
 use async_compression::futures::bufread::GzipDecoder;
 use client::Client;
 use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
-use language::{point_to_lsp, Buffer, ToPointUtf16};
+use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
 use lsp::LanguageServer;
 use settings::Settings;
 use smol::{fs, io::BufReader, stream::StreamExt};
@@ -77,6 +77,12 @@ impl Status {
     }
 }
 
+#[derive(Debug)]
+pub struct Completion {
+    pub position: Anchor,
+    pub text: String,
+}
+
 struct Copilot {
     server: CopilotServer,
 }
@@ -186,12 +192,12 @@ impl Copilot {
         }
     }
 
-    pub fn completions<T>(
+    pub fn completion<T>(
         &self,
         buffer: &ModelHandle<Buffer>,
         position: T,
         cx: &mut ModelContext<Self>,
-    ) -> Task<Result<()>>
+    ) -> Task<Result<Option<Completion>>>
     where
         T: ToPointUtf16,
     {
@@ -201,43 +207,45 @@ impl Copilot {
         };
 
         let buffer = buffer.read(cx).snapshot();
-        let position = position.to_point_utf16(&buffer);
-        let language_name = buffer.language_at(position).map(|language| language.name());
-        let language_name = language_name.as_deref();
-
-        let path;
-        let relative_path;
-        if let Some(file) = buffer.file() {
-            if let Some(file) = file.as_local() {
-                path = file.abs_path(cx);
-            } else {
-                path = file.full_path(cx);
-            }
-            relative_path = file.path().to_path_buf();
-        } else {
-            path = PathBuf::from("/untitled");
-            relative_path = PathBuf::from("untitled");
-        }
+        let request = server
+            .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
+        cx.background().spawn(async move {
+            let result = request.await?;
+            let completion = result
+                .completions
+                .into_iter()
+                .next()
+                .map(|completion| completion_from_lsp(completion, &buffer));
+            anyhow::Ok(completion)
+        })
+    }
 
-        let settings = cx.global::<Settings>();
-        let request = server.request::<request::GetCompletions>(request::GetCompletionsParams {
-            doc: request::GetCompletionsDocument {
-                source: buffer.text(),
-                tab_size: settings.tab_size(language_name).into(),
-                indent_size: 1,
-                insert_spaces: !settings.hard_tabs(language_name),
-                uri: lsp::Url::from_file_path(&path).unwrap(),
-                path: path.to_string_lossy().into(),
-                relative_path: relative_path.to_string_lossy().into(),
-                language_id: "csharp".into(),
-                position: point_to_lsp(position),
-                version: 0,
-            },
-        });
-        cx.spawn(|this, cx| async move {
-            dbg!(request.await?);
+    pub fn completions_cycling<T>(
+        &self,
+        buffer: &ModelHandle<Buffer>,
+        position: T,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Vec<Completion>>>
+    where
+        T: ToPointUtf16,
+    {
+        let server = match self.authenticated_server() {
+            Ok(server) => server,
+            Err(error) => return Task::ready(Err(error)),
+        };
 
-            anyhow::Ok(())
+        let buffer = buffer.read(cx).snapshot();
+        let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
+            &buffer, position, cx,
+        ));
+        cx.background().spawn(async move {
+            let result = request.await?;
+            let completions = result
+                .completions
+                .into_iter()
+                .map(|completion| completion_from_lsp(completion, &buffer))
+                .collect();
+            anyhow::Ok(completions)
         })
     }
 
@@ -290,6 +298,62 @@ impl Copilot {
     }
 }
 
+fn build_completion_params<T>(
+    buffer: &BufferSnapshot,
+    position: T,
+    cx: &AppContext,
+) -> request::GetCompletionsParams
+where
+    T: ToPointUtf16,
+{
+    let position = position.to_point_utf16(&buffer);
+    let language_name = buffer.language_at(position).map(|language| language.name());
+    let language_name = language_name.as_deref();
+
+    let path;
+    let relative_path;
+    if let Some(file) = buffer.file() {
+        if let Some(file) = file.as_local() {
+            path = file.abs_path(cx);
+        } else {
+            path = file.full_path(cx);
+        }
+        relative_path = file.path().to_path_buf();
+    } else {
+        path = PathBuf::from("/untitled");
+        relative_path = PathBuf::from("untitled");
+    }
+
+    let settings = cx.global::<Settings>();
+    let language_id = match language_name {
+        Some("Plain Text") => "plaintext".to_string(),
+        Some(language_name) => language_name.to_lowercase(),
+        None => "plaintext".to_string(),
+    };
+    request::GetCompletionsParams {
+        doc: request::GetCompletionsDocument {
+            source: buffer.text(),
+            tab_size: settings.tab_size(language_name).into(),
+            indent_size: 1,
+            insert_spaces: !settings.hard_tabs(language_name),
+            uri: lsp::Url::from_file_path(&path).unwrap(),
+            path: path.to_string_lossy().into(),
+            relative_path: relative_path.to_string_lossy().into(),
+            language_id,
+            position: point_to_lsp(position),
+            version: 0,
+        },
+    }
+}
+
+fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
+    let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
+    Completion {
+        position: buffer.anchor_before(position),
+        text: completion.display_text,
+    }
+}
+
 async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
     ///Check for the latest copilot language server and download it if we haven't already
     async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
@@ -354,17 +418,22 @@ mod tests {
         Settings::test_async(cx);
         let http = http::client();
         let copilot = cx.add_model(|cx| Copilot::start(http, cx));
-        smol::Timer::after(std::time::Duration::from_secs(5)).await;
+        smol::Timer::after(std::time::Duration::from_secs(2)).await;
         copilot
             .update(cx, |copilot, cx| copilot.sign_in(cx))
             .await
             .unwrap();
         dbg!(copilot.read_with(cx, |copilot, _| copilot.status()));
 
-        let buffer = cx.add_model(|cx| language::Buffer::new(0, "Lorem ipsum dol", cx));
-        copilot
-            .update(cx, |copilot, cx| copilot.completions(&buffer, 15, cx))
+        let buffer = cx.add_model(|cx| language::Buffer::new(0, "fn foo() -> ", cx));
+        dbg!(copilot
+            .update(cx, |copilot, cx| copilot.completion(&buffer, 12, cx))
             .await
-            .unwrap();
+            .unwrap());
+        dbg!(copilot
+            .update(cx, |copilot, cx| copilot
+                .completions_cycling(&buffer, 12, cx))
+            .await
+            .unwrap());
     }
 }

crates/copilot/src/request.rs 🔗

@@ -114,17 +114,17 @@ pub struct GetCompletionsDocument {
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct GetCompletionsResult {
-    completions: Vec<Completion>,
+    pub completions: Vec<Completion>,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct Completion {
-    text: String,
-    position: lsp::Position,
-    uuid: String,
-    range: lsp::Range,
-    display_text: String,
+    pub text: String,
+    pub position: lsp::Position,
+    pub uuid: String,
+    pub range: lsp::Range,
+    pub display_text: String,
 }
 
 impl lsp::request::Request for GetCompletions {
@@ -132,3 +132,11 @@ impl lsp::request::Request for GetCompletions {
     type Result = GetCompletionsResult;
     const METHOD: &'static str = "getCompletions";
 }
+
+pub enum GetCompletionsCycling {}
+
+impl lsp::request::Request for GetCompletionsCycling {
+    type Params = GetCompletionsParams;
+    type Result = GetCompletionsResult;
+    const METHOD: &'static str = "getCompletionsCycling";
+}