Use local ids, not remote ids, to identify buffers to copilot

Max Brunsfeld created

Change summary

crates/copilot/src/copilot.rs | 47 +++++++++++++++---------------------
1 file changed, 20 insertions(+), 27 deletions(-)

Detailed changes

crates/copilot/src/copilot.rs 🔗

@@ -4,7 +4,7 @@ mod sign_in;
 use anyhow::{anyhow, Context, Result};
 use async_compression::futures::bufread::GzipDecoder;
 use async_tar::Archive;
-use collections::HashMap;
+use collections::{HashMap, HashSet};
 use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
 use gpui::{
     actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
@@ -127,7 +127,7 @@ impl CopilotServer {
 struct RunningCopilotServer {
     lsp: Arc<LanguageServer>,
     sign_in_status: SignInStatus,
-    registered_buffers: HashMap<u64, RegisteredBuffer>,
+    registered_buffers: HashMap<usize, RegisteredBuffer>,
 }
 
 #[derive(Clone, Debug)]
@@ -163,7 +163,6 @@ impl Status {
 }
 
 struct RegisteredBuffer {
-    id: u64,
     uri: lsp::Url,
     language_id: String,
     snapshot: BufferSnapshot,
@@ -178,13 +177,13 @@ impl RegisteredBuffer {
         buffer: &ModelHandle<Buffer>,
         cx: &mut ModelContext<Copilot>,
     ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
-        let id = self.id;
         let (done_tx, done_rx) = oneshot::channel();
 
         if buffer.read(cx).version() == self.snapshot.version {
             let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
         } else {
             let buffer = buffer.downgrade();
+            let id = buffer.id();
             let prev_pending_change =
                 mem::replace(&mut self.pending_buffer_change, Task::ready(None));
             self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move {
@@ -268,7 +267,7 @@ pub struct Copilot {
     http: Arc<dyn HttpClient>,
     node_runtime: Arc<NodeRuntime>,
     server: CopilotServer,
-    buffers: HashMap<u64, WeakModelHandle<Buffer>>,
+    buffers: HashSet<WeakModelHandle<Buffer>>,
 }
 
 impl Entity for Copilot {
@@ -559,8 +558,8 @@ impl Copilot {
     }
 
     pub fn register_buffer(&mut self, buffer: &ModelHandle<Buffer>, cx: &mut ModelContext<Self>) {
-        let buffer_id = buffer.read(cx).remote_id();
-        self.buffers.insert(buffer_id, buffer.downgrade());
+        let weak_buffer = buffer.downgrade();
+        self.buffers.insert(weak_buffer.clone());
 
         if let CopilotServer::Running(RunningCopilotServer {
             lsp: server,
@@ -573,8 +572,7 @@ impl Copilot {
                 return;
             }
 
-            let buffer_id = buffer.read(cx).remote_id();
-            registered_buffers.entry(buffer_id).or_insert_with(|| {
+            registered_buffers.entry(buffer.id()).or_insert_with(|| {
                 let uri: lsp::Url = uri_for_buffer(buffer, cx);
                 let language_id = id_for_language(buffer.read(cx).language());
                 let snapshot = buffer.read(cx).snapshot();
@@ -592,7 +590,6 @@ impl Copilot {
                     .log_err();
 
                 RegisteredBuffer {
-                    id: buffer_id,
                     uri,
                     language_id,
                     snapshot,
@@ -603,8 +600,8 @@ impl Copilot {
                             this.handle_buffer_event(buffer, event, cx).log_err();
                         }),
                         cx.observe_release(buffer, move |this, _buffer, _cx| {
-                            this.buffers.remove(&buffer_id);
-                            this.unregister_buffer(buffer_id);
+                            this.buffers.remove(&weak_buffer);
+                            this.unregister_buffer(&weak_buffer);
                         }),
                     ],
                 }
@@ -619,8 +616,7 @@ impl Copilot {
         cx: &mut ModelContext<Self>,
     ) -> Result<()> {
         if let Ok(server) = self.server.as_running() {
-            let buffer_id = buffer.read(cx).remote_id();
-            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer_id) {
+            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.id()) {
                 match event {
                     language::Event::Edited => {
                         let _ = registered_buffer.report_changes(&buffer, cx);
@@ -674,9 +670,9 @@ impl Copilot {
         Ok(())
     }
 
-    fn unregister_buffer(&mut self, buffer_id: u64) {
+    fn unregister_buffer(&mut self, buffer: &WeakModelHandle<Buffer>) {
         if let Ok(server) = self.server.as_running() {
-            if let Some(buffer) = server.registered_buffers.remove(&buffer_id) {
+            if let Some(buffer) = server.registered_buffers.remove(&buffer.id()) {
                 server
                     .lsp
                     .notify::<lsp::notification::DidCloseTextDocument>(
@@ -779,8 +775,7 @@ impl Copilot {
             Err(error) => return Task::ready(Err(error)),
         };
         let lsp = server.lsp.clone();
-        let buffer_id = buffer.read(cx).remote_id();
-        let registered_buffer = server.registered_buffers.get_mut(&buffer_id).unwrap();
+        let registered_buffer = server.registered_buffers.get_mut(&buffer.id()).unwrap();
         let snapshot = registered_buffer.report_changes(buffer, cx);
         let buffer = buffer.read(cx);
         let uri = registered_buffer.uri.clone();
@@ -850,7 +845,7 @@ impl Copilot {
         lsp_status: request::SignInStatus,
         cx: &mut ModelContext<Self>,
     ) {
-        self.buffers.retain(|_, buffer| buffer.is_upgradable(cx));
+        self.buffers.retain(|buffer| buffer.is_upgradable(cx));
 
         if let Ok(server) = self.server.as_running() {
             match lsp_status {
@@ -858,7 +853,7 @@ impl Copilot {
                 | request::SignInStatus::MaybeOk { .. }
                 | request::SignInStatus::AlreadySignedIn { .. } => {
                     server.sign_in_status = SignInStatus::Authorized;
-                    for buffer in self.buffers.values().cloned().collect::<Vec<_>>() {
+                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
                         if let Some(buffer) = buffer.upgrade(cx) {
                             self.register_buffer(&buffer, cx);
                         }
@@ -866,14 +861,14 @@ impl Copilot {
                 }
                 request::SignInStatus::NotAuthorized { .. } => {
                     server.sign_in_status = SignInStatus::Unauthorized;
-                    for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
-                        self.unregister_buffer(buffer_id);
+                    for buffer in self.buffers.iter().copied().collect::<Vec<_>>() {
+                        self.unregister_buffer(&buffer);
                     }
                 }
                 request::SignInStatus::NotSignedIn => {
                     server.sign_in_status = SignInStatus::SignedOut;
-                    for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
-                        self.unregister_buffer(buffer_id);
+                    for buffer in self.buffers.iter().copied().collect::<Vec<_>>() {
+                        self.unregister_buffer(&buffer);
                     }
                 }
             }
@@ -896,9 +891,7 @@ fn uri_for_buffer(buffer: &ModelHandle<Buffer>, cx: &AppContext) -> lsp::Url {
     if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
         lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
     } else {
-        format!("buffer://{}", buffer.read(cx).remote_id())
-            .parse()
-            .unwrap()
+        format!("buffer://{}", buffer.id()).parse().unwrap()
     }
 }