Merge pull request #2344 from zed-industries/copilot-collaboration

Antonio Scandurra created

Fix Copilot errors when opening buffers that don't exist locally

Change summary

crates/copilot/Cargo.toml     |   1 
crates/copilot/src/copilot.rs | 258 +++++++++++++++++++-----------------
crates/copilot/src/editor.rs  |   3 
crates/editor/src/editor.rs   |   4 
4 files changed, 141 insertions(+), 125 deletions(-)

Detailed changes

crates/copilot/Cargo.toml 🔗

@@ -19,7 +19,6 @@ lsp = { path = "../lsp" }
 node_runtime = { path = "../node_runtime"}
 util = { path = "../util" }
 client = { path = "../client" }
-workspace = { path = "../workspace" }
 async-compression = { version = "0.3", features = ["gzip", "futures-bufread"] }
 async-tar = "0.4.2"
 anyhow = "1.0"

crates/copilot/src/copilot.rs 🔗

@@ -1,16 +1,17 @@
 mod request;
 mod sign_in;
 
-use anyhow::{anyhow, bail, Context, Result};
+use anyhow::{anyhow, Context, Result};
 use async_compression::futures::bufread::GzipDecoder;
 use async_tar::Archive;
 use client::Client;
+use collections::HashMap;
 use futures::{future::Shared, Future, FutureExt, TryFutureExt};
 use gpui::{
     actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
     Task,
 };
-use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
+use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
 use log::{debug, error};
 use lsp::LanguageServer;
 use node_runtime::NodeRuntime;
@@ -92,6 +93,7 @@ enum CopilotServer {
     Started {
         server: Arc<LanguageServer>,
         status: SignInStatus,
+        subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
     },
 }
 
@@ -275,6 +277,7 @@ impl Copilot {
                         this.server = CopilotServer::Started {
                             server,
                             status: SignInStatus::SignedOut,
+                            subscriptions_by_buffer_id: Default::default(),
                         };
                         this.update_sign_in_status(status, cx);
                     }
@@ -288,7 +291,7 @@ impl Copilot {
     }
 
     fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
-        if let CopilotServer::Started { server, status } = &mut self.server {
+        if let CopilotServer::Started { server, status, .. } = &mut self.server {
             let task = match status {
                 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
                     Task::ready(Ok(())).shared()
@@ -373,7 +376,7 @@ impl Copilot {
     }
 
     fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
-        if let CopilotServer::Started { server, status } = &mut self.server {
+        if let CopilotServer::Started { server, status, .. } = &mut self.server {
             *status = SignInStatus::SignedOut;
             cx.notify();
 
@@ -410,43 +413,20 @@ impl Copilot {
         cx.foreground().spawn(start_task)
     }
 
-    pub fn completion<T>(
-        &self,
+    pub fn completions<T>(
+        &mut self,
         buffer: &ModelHandle<Buffer>,
         position: T,
         cx: &mut ModelContext<Self>,
-    ) -> Task<Result<Option<Completion>>>
+    ) -> Task<Result<Vec<Completion>>>
     where
         T: ToPointUtf16,
     {
-        let server = match self.authorized_server() {
-            Ok(server) => server,
-            Err(error) => return Task::ready(Err(error)),
-        };
-
-        let buffer = buffer.read(cx);
-
-        if !buffer.file().map(|file| file.is_local()).unwrap_or(true) {
-            return Task::ready(Err(anyhow!("Copilot only works locally")));
-        }
-
-        let buffer = buffer.snapshot();
-        let request = server.request::<request::GetCompletions>(
-            build_completion_params(&buffer, position, cx).unwrap(),
-        );
-        cx.background().spawn(async move {
-            let result = request.await?;
-            let completion = result
-                .completions
-                .into_iter()
-                .next()
-                .map(|completion| completion_from_lsp(completion, &buffer));
-            anyhow::Ok(completion)
-        })
+        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
     }
 
     pub fn completions_cycling<T>(
-        &self,
+        &mut self,
         buffer: &ModelHandle<Buffer>,
         position: T,
         cx: &mut ModelContext<Self>,
@@ -454,27 +434,138 @@ impl Copilot {
     where
         T: ToPointUtf16,
     {
-        let server = match self.authorized_server() {
-            Ok(server) => server,
-            Err(error) => return Task::ready(Err(error)),
-        };
+        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
+    }
 
-        let buffer = buffer.read(cx);
+    fn request_completions<R, T>(
+        &mut self,
+        buffer: &ModelHandle<Buffer>,
+        position: T,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Vec<Completion>>>
+    where
+        R: lsp::request::Request<
+            Params = request::GetCompletionsParams,
+            Result = request::GetCompletionsResult,
+        >,
+        T: ToPointUtf16,
+    {
+        let buffer_id = buffer.id();
+        let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
+        let snapshot = buffer.read(cx).snapshot();
+        let server = match &mut self.server {
+            CopilotServer::Starting { .. } => {
+                return Task::ready(Err(anyhow!("copilot is still starting")))
+            }
+            CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
+            CopilotServer::Error(error) => {
+                return Task::ready(Err(anyhow!(
+                    "copilot was not started because of an error: {}",
+                    error
+                )))
+            }
+            CopilotServer::Started {
+                server,
+                status,
+                subscriptions_by_buffer_id,
+            } => {
+                if matches!(status, SignInStatus::Authorized { .. }) {
+                    subscriptions_by_buffer_id
+                        .entry(buffer_id)
+                        .or_insert_with(|| {
+                            server
+                                .notify::<lsp::notification::DidOpenTextDocument>(
+                                    lsp::DidOpenTextDocumentParams {
+                                        text_document: lsp::TextDocumentItem {
+                                            uri: uri.clone(),
+                                            language_id: id_for_language(
+                                                buffer.read(cx).language(),
+                                            ),
+                                            version: 0,
+                                            text: snapshot.text(),
+                                        },
+                                    },
+                                )
+                                .log_err();
+
+                            let uri = uri.clone();
+                            cx.observe_release(buffer, move |this, _, _| {
+                                if let CopilotServer::Started {
+                                    server,
+                                    subscriptions_by_buffer_id,
+                                    ..
+                                } = &mut this.server
+                                {
+                                    server
+                                        .notify::<lsp::notification::DidCloseTextDocument>(
+                                            lsp::DidCloseTextDocumentParams {
+                                                text_document: lsp::TextDocumentIdentifier::new(
+                                                    uri.clone(),
+                                                ),
+                                            },
+                                        )
+                                        .log_err();
+                                    subscriptions_by_buffer_id.remove(&buffer_id);
+                                }
+                            })
+                        });
 
-        if !buffer.file().map(|file| file.is_local()).unwrap_or(true) {
-            return Task::ready(Err(anyhow!("Copilot only works locally")));
+                    server.clone()
+                } else {
+                    return Task::ready(Err(anyhow!("must sign in before using copilot")));
+                }
+            }
+        };
+
+        let settings = cx.global::<Settings>();
+        let position = position.to_point_utf16(&snapshot);
+        let language = snapshot.language_at(position);
+        let language_name = language.map(|language| language.name());
+        let language_name = language_name.as_deref();
+
+        let path;
+        let relative_path;
+        if let Some(file) = snapshot.file() {
+            if let Some(file) = file.as_local() {
+                path = file.abs_path(cx);
+            } else {
+                path = file.full_path(cx);
+            }
+            relative_path = file.path().to_path_buf();
+        } else {
+            path = PathBuf::new();
+            relative_path = PathBuf::new();
         }
 
-        let buffer = buffer.snapshot();
-        let request = server.request::<request::GetCompletionsCycling>(
-            build_completion_params(&buffer, position, cx).unwrap(),
-        );
+        let params = request::GetCompletionsParams {
+            doc: request::GetCompletionsDocument {
+                source: snapshot.text(),
+                tab_size: settings.tab_size(language_name).into(),
+                indent_size: 1,
+                insert_spaces: !settings.hard_tabs(language_name),
+                uri,
+                path: path.to_string_lossy().into(),
+                relative_path: relative_path.to_string_lossy().into(),
+                language_id: id_for_language(language),
+                position: point_to_lsp(position),
+                version: 0,
+            },
+        };
         cx.background().spawn(async move {
-            let result = request.await?;
+            let result = server.request::<R>(params).await?;
             let completions = result
                 .completions
                 .into_iter()
-                .map(|completion| completion_from_lsp(completion, &buffer))
+                .map(|completion| {
+                    let start = snapshot
+                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
+                    let end =
+                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
+                    Completion {
+                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
+                        text: completion.text,
+                    }
+                })
                 .collect();
             anyhow::Ok(completions)
         })
@@ -516,85 +607,14 @@ impl Copilot {
             cx.notify();
         }
     }
-
-    fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
-        match &self.server {
-            CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
-            CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
-            CopilotServer::Error(error) => Err(anyhow!(
-                "copilot was not started because of an error: {}",
-                error
-            )),
-            CopilotServer::Started { server, status } => {
-                if matches!(status, SignInStatus::Authorized { .. }) {
-                    Ok(server.clone())
-                } else {
-                    Err(anyhow!("must sign in before using copilot"))
-                }
-            }
-        }
-    }
 }
 
-fn build_completion_params<T>(
-    buffer: &BufferSnapshot,
-    position: T,
-    cx: &AppContext,
-) -> anyhow::Result<request::GetCompletionsParams>
-where
-    T: ToPointUtf16,
-{
-    let position = position.to_point_utf16(&buffer);
-    let language_name = buffer.language_at(position).map(|language| language.name());
-    let language_name = language_name.as_deref();
-
-    let path;
-    let relative_path;
-    if let Some(file) = buffer.file() {
-        if let Some(file) = file.as_local() {
-            path = file.abs_path(cx);
-        } else {
-            path = file.full_path(cx);
-        }
-        relative_path = file.path().to_path_buf();
-    } else {
-        path = PathBuf::from("/untitled");
-        relative_path = PathBuf::from("untitled");
-    }
-
-    let settings = cx.global::<Settings>();
-    let language_id = match language_name {
+fn id_for_language(language: Option<&Arc<Language>>) -> String {
+    let language_name = language.map(|language| language.name());
+    match language_name.as_deref() {
         Some("Plain Text") => "plaintext".to_string(),
         Some(language_name) => language_name.to_lowercase(),
         None => "plaintext".to_string(),
-    };
-
-    let Ok(uri) = lsp::Url::from_file_path(&path) else {
-        bail!("Failed convert file path")
-    };
-
-    Ok(request::GetCompletionsParams {
-        doc: request::GetCompletionsDocument {
-            source: buffer.text(),
-            tab_size: settings.tab_size(language_name).into(),
-            indent_size: 1,
-            insert_spaces: !settings.hard_tabs(language_name),
-            uri,
-            path: path.to_string_lossy().into(),
-            relative_path: relative_path.to_string_lossy().into(),
-            language_id,
-            position: point_to_lsp(position),
-            version: 0,
-        },
-    })
-}
-
-fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
-    let start = buffer.clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
-    let end = buffer.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
-    Completion {
-        range: buffer.anchor_before(start)..buffer.anchor_after(end),
-        text: completion.text,
     }
 }
 

crates/editor/src/editor.rs 🔗

@@ -2843,14 +2843,14 @@ impl Editor {
         self.copilot_state.pending_refresh = cx.spawn_weak(|this, mut cx| async move {
             let (completion, completions_cycling) = copilot.update(&mut cx, |copilot, cx| {
                 (
-                    copilot.completion(&buffer, buffer_position, cx),
+                    copilot.completions(&buffer, buffer_position, cx),
                     copilot.completions_cycling(&buffer, buffer_position, cx),
                 )
             });
 
             let (completion, completions_cycling) = futures::join!(completion, completions_cycling);
             let mut completions = Vec::new();
-            completions.extend(completion.log_err().flatten());
+            completions.extend(completion.log_err().into_iter().flatten());
             completions.extend(completions_cycling.log_err().into_iter().flatten());
             this.upgrade(&cx)?.update(&mut cx, |this, cx| {
                 if !completions.is_empty() {