Move buffer change reporting to a background task

Antonio Scandurra created

Change summary

crates/copilot/src/copilot.rs | 165 +++++++++++++++++++++++-------------
1 file changed, 104 insertions(+), 61 deletions(-)

Detailed changes

crates/copilot/src/copilot.rs 🔗

@@ -5,7 +5,7 @@ use anyhow::{anyhow, Context, Result};
 use async_compression::futures::bufread::GzipDecoder;
 use async_tar::Archive;
 use collections::HashMap;
-use futures::{future::Shared, Future, FutureExt, TryFutureExt};
+use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
 use gpui::{
     actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
 };
@@ -171,56 +171,97 @@ impl Status {
 }
 
 struct RegisteredBuffer {
+    id: usize,
     uri: lsp::Url,
     language_id: String,
     snapshot: BufferSnapshot,
     snapshot_version: i32,
     _subscriptions: [gpui::Subscription; 2],
+    pending_buffer_change: Task<Option<()>>,
 }
 
 impl RegisteredBuffer {
     fn report_changes(
         &mut self,
         buffer: &ModelHandle<Buffer>,
-        server: &LanguageServer,
-        cx: &AppContext,
-    ) -> Result<()> {
-        let buffer = buffer.read(cx);
-        let new_snapshot = buffer.snapshot();
-        let content_changes = buffer
-            .edits_since::<(PointUtf16, usize)>(self.snapshot.version())
-            .map(|edit| {
-                let edit_start = edit.new.start.0;
-                let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
-                let new_text = new_snapshot
-                    .text_for_range(edit.new.start.1..edit.new.end.1)
-                    .collect();
-                lsp::TextDocumentContentChangeEvent {
-                    range: Some(lsp::Range::new(
-                        point_to_lsp(edit_start),
-                        point_to_lsp(edit_end),
-                    )),
-                    range_length: None,
-                    text: new_text,
-                }
-            })
-            .collect::<Vec<_>>();
-
-        if !content_changes.is_empty() {
-            self.snapshot_version += 1;
-            self.snapshot = new_snapshot;
-            server.notify::<lsp::notification::DidChangeTextDocument>(
-                lsp::DidChangeTextDocumentParams {
-                    text_document: lsp::VersionedTextDocumentIdentifier::new(
-                        self.uri.clone(),
-                        self.snapshot_version,
-                    ),
-                    content_changes,
-                },
-            )?;
+        cx: &mut ModelContext<Copilot>,
+    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
+        let id = self.id;
+        let (done_tx, done_rx) = oneshot::channel();
+
+        if buffer.read(cx).version() == self.snapshot.version {
+            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
+        } else {
+            let buffer = buffer.downgrade();
+            let prev_pending_change =
+                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
+            self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move {
+                prev_pending_change.await;
+
+                let old_version = copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
+                    let server = copilot.server.as_authenticated().log_err()?;
+                    let buffer = server.registered_buffers.get_mut(&id)?;
+                    Some(buffer.snapshot.version.clone())
+                })?;
+                let new_snapshot = buffer
+                    .upgrade(&cx)?
+                    .read_with(&cx, |buffer, _| buffer.snapshot());
+
+                let content_changes = cx
+                    .background()
+                    .spawn({
+                        let new_snapshot = new_snapshot.clone();
+                        async move {
+                            new_snapshot
+                                .edits_since::<(PointUtf16, usize)>(&old_version)
+                                .map(|edit| {
+                                    let edit_start = edit.new.start.0;
+                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
+                                    let new_text = new_snapshot
+                                        .text_for_range(edit.new.start.1..edit.new.end.1)
+                                        .collect();
+                                    lsp::TextDocumentContentChangeEvent {
+                                        range: Some(lsp::Range::new(
+                                            point_to_lsp(edit_start),
+                                            point_to_lsp(edit_end),
+                                        )),
+                                        range_length: None,
+                                        text: new_text,
+                                    }
+                                })
+                                .collect::<Vec<_>>()
+                        }
+                    })
+                    .await;
+
+                copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
+                    let server = copilot.server.as_authenticated().log_err()?;
+                    let buffer = server.registered_buffers.get_mut(&id)?;
+                    if !content_changes.is_empty() {
+                        buffer.snapshot_version += 1;
+                        buffer.snapshot = new_snapshot;
+                        server
+                            .lsp
+                            .notify::<lsp::notification::DidChangeTextDocument>(
+                                lsp::DidChangeTextDocumentParams {
+                                    text_document: lsp::VersionedTextDocumentIdentifier::new(
+                                        buffer.uri.clone(),
+                                        buffer.snapshot_version,
+                                    ),
+                                    content_changes,
+                                },
+                            )
+                            .log_err();
+                    }
+                    let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
+                    Some(())
+                })?;
+
+                Some(())
+            });
         }
 
-        Ok(())
+        done_rx
     }
 }
 
@@ -567,10 +608,12 @@ impl Copilot {
                     .log_err();
 
                 RegisteredBuffer {
+                    id: buffer_id,
                     uri,
                     language_id,
                     snapshot,
                     snapshot_version: 0,
+                    pending_buffer_change: Task::ready(Some(())),
                     _subscriptions: [
                         cx.subscribe(buffer, |this, buffer, event, cx| {
                             this.handle_buffer_event(buffer, event, cx).log_err();
@@ -595,7 +638,7 @@ impl Copilot {
             if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.id()) {
                 match event {
                     language::Event::Edited => {
-                        registered_buffer.report_changes(&buffer, &server.lsp, cx)?;
+                        let _ = registered_buffer.report_changes(&buffer, cx);
                     }
                     language::Event::Saved => {
                         server
@@ -750,38 +793,38 @@ impl Copilot {
             Ok(server) => server,
             Err(error) => return Task::ready(Err(error)),
         };
+        let lsp = server.lsp.clone();
         let registered_buffer = server.registered_buffers.get_mut(&buffer.id()).unwrap();
-        if let Err(error) = registered_buffer.report_changes(buffer, &server.lsp, cx) {
-            return Task::ready(Err(error));
-        }
-
+        let snapshot = registered_buffer.report_changes(buffer, cx);
+        let buffer = buffer.read(cx);
         let uri = registered_buffer.uri.clone();
-        let snapshot = registered_buffer.snapshot.clone();
-        let version = registered_buffer.snapshot_version;
         let settings = cx.global::<Settings>();
-        let position = position.to_point_utf16(&snapshot);
-        let language = snapshot.language_at(position);
+        let position = position.to_point_utf16(buffer);
+        let language = buffer.language_at(position);
         let language_name = language.map(|language| language.name());
         let language_name = language_name.as_deref();
         let tab_size = settings.tab_size(language_name);
         let hard_tabs = settings.hard_tabs(language_name);
-        let relative_path = snapshot
+        let relative_path = buffer
             .file()
             .map(|file| file.path().to_path_buf())
             .unwrap_or_default();
-        let request = server.lsp.request::<R>(request::GetCompletionsParams {
-            doc: request::GetCompletionsDocument {
-                uri,
-                tab_size: tab_size.into(),
-                indent_size: 1,
-                insert_spaces: !hard_tabs,
-                relative_path: relative_path.to_string_lossy().into(),
-                position: point_to_lsp(position),
-                version: version.try_into().unwrap(),
-            },
-        });
-        cx.background().spawn(async move {
-            let result = request.await?;
+
+        cx.foreground().spawn(async move {
+            let (version, snapshot) = snapshot.await?;
+            let result = lsp
+                .request::<R>(request::GetCompletionsParams {
+                    doc: request::GetCompletionsDocument {
+                        uri,
+                        tab_size: tab_size.into(),
+                        indent_size: 1,
+                        insert_spaces: !hard_tabs,
+                        relative_path: relative_path.to_string_lossy().into(),
+                        position: point_to_lsp(position),
+                        version: version.try_into().unwrap(),
+                    },
+                })
+                .await?;
             let completions = result
                 .completions
                 .into_iter()