Relay buffer change events to Copilot

Antonio Scandurra created

Change summary

Cargo.lock                    |   1 
crates/copilot/src/copilot.rs | 338 ++++++++++++++++++++++++++----------
crates/copilot/src/request.rs |   3 
crates/project/Cargo.toml     |   1 
crates/project/src/project.rs |  24 ++
5 files changed, 268 insertions(+), 99 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -4687,6 +4687,7 @@ dependencies = [
  "client",
  "clock",
  "collections",
+ "copilot",
  "ctor",
  "db",
  "env_logger",

crates/copilot/src/copilot.rs 🔗

@@ -6,8 +6,13 @@ use async_compression::futures::bufread::GzipDecoder;
 use async_tar::Archive;
 use collections::HashMap;
 use futures::{future::Shared, Future, FutureExt, TryFutureExt};
-use gpui::{actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
-use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
+use gpui::{
+    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
+};
+use language::{
+    point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
+    ToPointUtf16,
+};
 use log::{debug, error};
 use lsp::LanguageServer;
 use node_runtime::NodeRuntime;
@@ -105,7 +110,7 @@ enum CopilotServer {
     Started {
         server: Arc<LanguageServer>,
         status: SignInStatus,
-        subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
+        registered_buffers: HashMap<usize, RegisteredBuffer>,
     },
 }
 
@@ -141,6 +146,66 @@ impl Status {
     }
 }
 
+struct RegisteredBuffer {
+    uri: lsp::Url,
+    snapshot: Option<(i32, BufferSnapshot)>,
+    _subscriptions: [gpui::Subscription; 2],
+}
+
+impl RegisteredBuffer {
+    fn report_changes(
+        &mut self,
+        buffer: &ModelHandle<Buffer>,
+        server: &LanguageServer,
+        cx: &AppContext,
+    ) -> Result<(i32, BufferSnapshot)> {
+        let buffer = buffer.read(cx);
+        let (version, prev_snapshot) = self
+            .snapshot
+            .as_ref()
+            .ok_or_else(|| anyhow!("expected at least one snapshot"))?;
+        let next_snapshot = buffer.snapshot();
+
+        let content_changes = buffer
+            .edits_since::<(PointUtf16, usize)>(prev_snapshot.version())
+            .map(|edit| {
+                let edit_start = edit.new.start.0;
+                let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
+                let new_text = next_snapshot
+                    .text_for_range(edit.new.start.1..edit.new.end.1)
+                    .collect();
+                lsp::TextDocumentContentChangeEvent {
+                    range: Some(lsp::Range::new(
+                        point_to_lsp(edit_start),
+                        point_to_lsp(edit_end),
+                    )),
+                    range_length: None,
+                    text: new_text,
+                }
+            })
+            .collect::<Vec<_>>();
+
+        if content_changes.is_empty() {
+            Ok((*version, prev_snapshot.clone()))
+        } else {
+            let next_version = version + 1;
+            self.snapshot = Some((next_version, next_snapshot.clone()));
+
+            server.notify::<lsp::notification::DidChangeTextDocument>(
+                lsp::DidChangeTextDocumentParams {
+                    text_document: lsp::VersionedTextDocumentIdentifier::new(
+                        self.uri.clone(),
+                        next_version,
+                    ),
+                    content_changes,
+                },
+            )?;
+
+            Ok((next_version, next_snapshot))
+        }
+    }
+}
+
 #[derive(Debug, PartialEq, Eq)]
 pub struct Completion {
     pub range: Range<Anchor>,
@@ -151,6 +216,7 @@ pub struct Copilot {
     http: Arc<dyn HttpClient>,
     node_runtime: Arc<NodeRuntime>,
     server: CopilotServer,
+    buffers: HashMap<usize, WeakModelHandle<Buffer>>,
 }
 
 impl Entity for Copilot {
@@ -212,12 +278,14 @@ impl Copilot {
                 http,
                 node_runtime,
                 server: CopilotServer::Starting { task: start_task },
+                buffers: Default::default(),
             }
         } else {
             Self {
                 http,
                 node_runtime,
                 server: CopilotServer::Disabled,
+                buffers: Default::default(),
             }
         }
     }
@@ -233,8 +301,9 @@ impl Copilot {
             server: CopilotServer::Started {
                 server: Arc::new(server),
                 status: SignInStatus::Authorized,
-                subscriptions_by_buffer_id: Default::default(),
+                registered_buffers: Default::default(),
             },
+            buffers: Default::default(),
         });
         (this, fake_server)
     }
@@ -297,7 +366,7 @@ impl Copilot {
                         this.server = CopilotServer::Started {
                             server,
                             status: SignInStatus::SignedOut,
-                            subscriptions_by_buffer_id: Default::default(),
+                            registered_buffers: Default::default(),
                         };
                         this.update_sign_in_status(status, cx);
                     }
@@ -396,10 +465,8 @@ impl Copilot {
     }
 
     fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
-        if let CopilotServer::Started { server, status, .. } = &mut self.server {
-            *status = SignInStatus::SignedOut;
-            cx.notify();
-
+        self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
+        if let CopilotServer::Started { server, .. } = &self.server {
             let server = server.clone();
             cx.background().spawn(async move {
                 server
@@ -433,6 +500,108 @@ impl Copilot {
         cx.foreground().spawn(start_task)
     }
 
+    pub fn register_buffer(&mut self, buffer: &ModelHandle<Buffer>, cx: &mut ModelContext<Self>) {
+        let buffer_id = buffer.id();
+        self.buffers.insert(buffer_id, buffer.downgrade());
+
+        if let CopilotServer::Started {
+            server,
+            status,
+            registered_buffers,
+            ..
+        } = &mut self.server
+        {
+            if !matches!(status, SignInStatus::Authorized { .. }) {
+                return;
+            }
+
+            let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
+            registered_buffers.entry(buffer.id()).or_insert_with(|| {
+                let snapshot = buffer.read(cx).snapshot();
+                server
+                    .notify::<lsp::notification::DidOpenTextDocument>(
+                        lsp::DidOpenTextDocumentParams {
+                            text_document: lsp::TextDocumentItem {
+                                uri: uri.clone(),
+                                language_id: id_for_language(buffer.read(cx).language()),
+                                version: 0,
+                                text: snapshot.text(),
+                            },
+                        },
+                    )
+                    .log_err();
+
+                RegisteredBuffer {
+                    uri,
+                    snapshot: Some((0, snapshot)),
+                    _subscriptions: [
+                        cx.subscribe(buffer, |this, buffer, event, cx| {
+                            this.handle_buffer_event(buffer, event, cx).log_err();
+                        }),
+                        cx.observe_release(buffer, move |this, _buffer, _cx| {
+                            this.buffers.remove(&buffer_id);
+                            this.unregister_buffer(buffer_id);
+                        }),
+                    ],
+                }
+            });
+        }
+    }
+
+    fn handle_buffer_event(
+        &mut self,
+        buffer: ModelHandle<Buffer>,
+        event: &language::Event,
+        cx: &mut ModelContext<Self>,
+    ) -> Result<()> {
+        if let CopilotServer::Started {
+            server,
+            registered_buffers,
+            ..
+        } = &mut self.server
+        {
+            if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) {
+                match event {
+                    language::Event::Edited => {
+                        registered_buffer.report_changes(&buffer, server, cx)?;
+                    }
+                    language::Event::Saved => {
+                        server.notify::<lsp::notification::DidSaveTextDocument>(
+                            lsp::DidSaveTextDocumentParams {
+                                text_document: lsp::TextDocumentIdentifier::new(
+                                    registered_buffer.uri.clone(),
+                                ),
+                                text: None,
+                            },
+                        )?;
+                    }
+                    _ => {}
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    fn unregister_buffer(&mut self, buffer_id: usize) {
+        if let CopilotServer::Started {
+            server,
+            registered_buffers,
+            ..
+        } = &mut self.server
+        {
+            if let Some(buffer) = registered_buffers.remove(&buffer_id) {
+                server
+                    .notify::<lsp::notification::DidCloseTextDocument>(
+                        lsp::DidCloseTextDocumentParams {
+                            text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
+                        },
+                    )
+                    .log_err();
+            }
+        }
+    }
+
     pub fn completions<T>(
         &mut self,
         buffer: &ModelHandle<Buffer>,
@@ -464,16 +633,14 @@ impl Copilot {
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<Completion>>>
     where
-        R: lsp::request::Request<
-            Params = request::GetCompletionsParams,
-            Result = request::GetCompletionsResult,
-        >,
+        R: 'static
+            + lsp::request::Request<
+                Params = request::GetCompletionsParams,
+                Result = request::GetCompletionsResult,
+            >,
         T: ToPointUtf16,
     {
-        let buffer_id = buffer.id();
-        let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
-        let snapshot = buffer.read(cx).snapshot();
-        let server = match &mut self.server {
+        let (server, registered_buffer) = match &mut self.server {
             CopilotServer::Starting { .. } => {
                 return Task::ready(Err(anyhow!("copilot is still starting")))
             }
@@ -487,56 +654,28 @@ impl Copilot {
             CopilotServer::Started {
                 server,
                 status,
-                subscriptions_by_buffer_id,
+                registered_buffers,
+                ..
             } => {
                 if matches!(status, SignInStatus::Authorized { .. }) {
-                    subscriptions_by_buffer_id
-                        .entry(buffer_id)
-                        .or_insert_with(|| {
-                            server
-                                .notify::<lsp::notification::DidOpenTextDocument>(
-                                    lsp::DidOpenTextDocumentParams {
-                                        text_document: lsp::TextDocumentItem {
-                                            uri: uri.clone(),
-                                            language_id: id_for_language(
-                                                buffer.read(cx).language(),
-                                            ),
-                                            version: 0,
-                                            text: snapshot.text(),
-                                        },
-                                    },
-                                )
-                                .log_err();
-
-                            let uri = uri.clone();
-                            cx.observe_release(buffer, move |this, _, _| {
-                                if let CopilotServer::Started {
-                                    server,
-                                    subscriptions_by_buffer_id,
-                                    ..
-                                } = &mut this.server
-                                {
-                                    server
-                                        .notify::<lsp::notification::DidCloseTextDocument>(
-                                            lsp::DidCloseTextDocumentParams {
-                                                text_document: lsp::TextDocumentIdentifier::new(
-                                                    uri.clone(),
-                                                ),
-                                            },
-                                        )
-                                        .log_err();
-                                    subscriptions_by_buffer_id.remove(&buffer_id);
-                                }
-                            })
-                        });
-
-                    server.clone()
+                    if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) {
+                        (server.clone(), registered_buffer)
+                    } else {
+                        return Task::ready(Err(anyhow!(
+                            "requested completions for an unregistered buffer"
+                        )));
+                    }
                 } else {
                     return Task::ready(Err(anyhow!("must sign in before using copilot")));
                 }
             }
         };
 
+        let (version, snapshot) = match registered_buffer.report_changes(buffer, &server, cx) {
+            Ok((version, snapshot)) => (version, snapshot),
+            Err(error) => return Task::ready(Err(error)),
+        };
+        let uri = registered_buffer.uri.clone();
         let settings = cx.global::<Settings>();
         let position = position.to_point_utf16(&snapshot);
         let language = snapshot.language_at(position);
@@ -544,39 +683,23 @@ impl Copilot {
         let language_name = language_name.as_deref();
         let tab_size = settings.tab_size(language_name);
         let hard_tabs = settings.hard_tabs(language_name);
-        let language_id = id_for_language(language);
-
-        let path;
-        let relative_path;
-        if let Some(file) = snapshot.file() {
-            if let Some(file) = file.as_local() {
-                path = file.abs_path(cx);
-            } else {
-                path = file.full_path(cx);
-            }
-            relative_path = file.path().to_path_buf();
-        } else {
-            path = PathBuf::new();
-            relative_path = PathBuf::new();
-        }
-
+        let relative_path = snapshot
+            .file()
+            .map(|file| file.path().to_path_buf())
+            .unwrap_or_default();
+        let request = server.request::<R>(request::GetCompletionsParams {
+            doc: request::GetCompletionsDocument {
+                uri,
+                tab_size: tab_size.into(),
+                indent_size: 1,
+                insert_spaces: !hard_tabs,
+                relative_path: relative_path.to_string_lossy().into(),
+                position: point_to_lsp(position),
+                version: version.try_into().unwrap(),
+            },
+        });
         cx.background().spawn(async move {
-            let result = server
-                .request::<R>(request::GetCompletionsParams {
-                    doc: request::GetCompletionsDocument {
-                        source: snapshot.text(),
-                        tab_size: tab_size.into(),
-                        indent_size: 1,
-                        insert_spaces: !hard_tabs,
-                        uri,
-                        path: path.to_string_lossy().into(),
-                        relative_path: relative_path.to_string_lossy().into(),
-                        language_id,
-                        position: point_to_lsp(position),
-                        version: 0,
-                    },
-                })
-                .await?;
+            let result = request.await?;
             let completions = result
                 .completions
                 .into_iter()
@@ -616,14 +739,37 @@ impl Copilot {
         lsp_status: request::SignInStatus,
         cx: &mut ModelContext<Self>,
     ) {
+        self.buffers.retain(|_, buffer| buffer.is_upgradable(cx));
+
         if let CopilotServer::Started { status, .. } = &mut self.server {
-            *status = match lsp_status {
+            match lsp_status {
                 request::SignInStatus::Ok { .. }
                 | request::SignInStatus::MaybeOk { .. }
-                | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized,
-                request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized,
-                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
-            };
+                | request::SignInStatus::AlreadySignedIn { .. } => {
+                    *status = SignInStatus::Authorized;
+
+                    for buffer in self.buffers.values().cloned().collect::<Vec<_>>() {
+                        if let Some(buffer) = buffer.upgrade(cx) {
+                            self.register_buffer(&buffer, cx);
+                        }
+                    }
+                }
+                request::SignInStatus::NotAuthorized { .. } => {
+                    *status = SignInStatus::Unauthorized;
+
+                    for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
+                        self.unregister_buffer(buffer_id);
+                    }
+                }
+                request::SignInStatus::NotSignedIn => {
+                    *status = SignInStatus::SignedOut;
+
+                    for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
+                        self.unregister_buffer(buffer_id);
+                    }
+                }
+            }
+
             cx.notify();
         }
     }

crates/copilot/src/request.rs 🔗

@@ -99,14 +99,11 @@ pub struct GetCompletionsParams {
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct GetCompletionsDocument {
-    pub source: String,
     pub tab_size: u32,
     pub indent_size: u32,
     pub insert_spaces: bool,
     pub uri: lsp::Url,
-    pub path: String,
     pub relative_path: String,
-    pub language_id: String,
     pub position: lsp::Position,
     pub version: usize,
 }

crates/project/Cargo.toml 🔗

@@ -19,6 +19,7 @@ test-support = [
 
 [dependencies]
 text = { path = "../text" }
+copilot = { path = "../copilot" }
 client = { path = "../client" }
 clock = { path = "../clock" }
 collections = { path = "../collections" }

crates/project/src/project.rs 🔗

@@ -12,6 +12,7 @@ use anyhow::{anyhow, Context, Result};
 use client::{proto, Client, TypedEnvelope, UserStore};
 use clock::ReplicaId;
 use collections::{hash_map, BTreeMap, HashMap, HashSet};
+use copilot::Copilot;
 use futures::{
     channel::mpsc::{self, UnboundedReceiver},
     future::{try_join_all, Shared},
@@ -129,6 +130,7 @@ pub struct Project {
     _maintain_buffer_languages: Task<()>,
     _maintain_workspace_config: Task<()>,
     terminals: Terminals,
+    copilot_enabled: bool,
 }
 
 enum BufferMessage {
@@ -472,6 +474,7 @@ impl Project {
                 terminals: Terminals {
                     local_handles: Vec::new(),
                 },
+                copilot_enabled: Copilot::global(cx).is_some(),
             }
         })
     }
@@ -559,6 +562,7 @@ impl Project {
                 terminals: Terminals {
                     local_handles: Vec::new(),
                 },
+                copilot_enabled: Copilot::global(cx).is_some(),
             };
             for worktree in worktrees {
                 let _ = this.add_worktree(&worktree, cx);
@@ -664,6 +668,15 @@ impl Project {
             self.start_language_server(worktree_id, worktree_path, language, cx);
         }
 
+        if !self.copilot_enabled && Copilot::global(cx).is_some() {
+            self.copilot_enabled = true;
+            for buffer in self.opened_buffers.values() {
+                if let Some(buffer) = buffer.upgrade(cx) {
+                    self.register_buffer_with_copilot(&buffer, cx);
+                }
+            }
+        }
+
         cx.notify();
     }
 
@@ -1616,6 +1629,7 @@ impl Project {
 
         self.detect_language_for_buffer(buffer, cx);
         self.register_buffer_with_language_server(buffer, cx);
+        self.register_buffer_with_copilot(buffer, cx);
         cx.observe_release(buffer, |this, buffer, cx| {
             if let Some(file) = File::from_dyn(buffer.file()) {
                 if file.is_local() {
@@ -1731,6 +1745,16 @@ impl Project {
         });
     }
 
+    fn register_buffer_with_copilot(
+        &self,
+        buffer_handle: &ModelHandle<Buffer>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        if let Some(copilot) = Copilot::global(cx) {
+            copilot.update(cx, |copilot, cx| copilot.register_buffer(buffer_handle, cx));
+        }
+    }
+
     async fn send_buffer_messages(
         this: WeakModelHandle<Self>,
         rx: UnboundedReceiver<BufferMessage>,