Reopen file in Copilot language server when language or URI changes

Antonio Scandurra created

Change summary

crates/copilot/src/copilot.rs     | 85 ++++++++++++++++++++++----------
crates/editor/src/editor.rs       |  1 
crates/editor/src/multi_buffer.rs |  2 
crates/language/src/buffer.rs     |  2 
4 files changed, 64 insertions(+), 26 deletions(-)

Detailed changes

crates/copilot/src/copilot.rs 🔗

@@ -21,6 +21,7 @@ use settings::Settings;
 use smol::{fs, io::BufReader, stream::StreamExt};
 use std::{
     ffi::OsString,
+    mem,
     ops::Range,
     path::{Path, PathBuf},
     sync::Arc,
@@ -148,7 +149,9 @@ impl Status {
 
 struct RegisteredBuffer {
     uri: lsp::Url,
-    snapshot: Option<(i32, BufferSnapshot)>,
+    language_id: String,
+    snapshot: BufferSnapshot,
+    snapshot_version: i32,
     _subscriptions: [gpui::Subscription; 2],
 }
 
@@ -158,20 +161,15 @@ impl RegisteredBuffer {
         buffer: &ModelHandle<Buffer>,
         server: &LanguageServer,
         cx: &AppContext,
-    ) -> Result<(i32, BufferSnapshot)> {
+    ) -> Result<()> {
         let buffer = buffer.read(cx);
-        let (version, prev_snapshot) = self
-            .snapshot
-            .as_ref()
-            .ok_or_else(|| anyhow!("expected at least one snapshot"))?;
-        let next_snapshot = buffer.snapshot();
-
+        let new_snapshot = buffer.snapshot();
         let content_changes = buffer
-            .edits_since::<(PointUtf16, usize)>(prev_snapshot.version())
+            .edits_since::<(PointUtf16, usize)>(self.snapshot.version())
             .map(|edit| {
                 let edit_start = edit.new.start.0;
                 let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
-                let new_text = next_snapshot
+                let new_text = new_snapshot
                     .text_for_range(edit.new.start.1..edit.new.end.1)
                     .collect();
                 lsp::TextDocumentContentChangeEvent {
@@ -185,24 +183,21 @@ impl RegisteredBuffer {
             })
             .collect::<Vec<_>>();
 
-        if content_changes.is_empty() {
-            Ok((*version, prev_snapshot.clone()))
-        } else {
-            let next_version = version + 1;
-            self.snapshot = Some((next_version, next_snapshot.clone()));
-
+        if !content_changes.is_empty() {
+            self.snapshot_version += 1;
+            self.snapshot = new_snapshot;
             server.notify::<lsp::notification::DidChangeTextDocument>(
                 lsp::DidChangeTextDocumentParams {
                     text_document: lsp::VersionedTextDocumentIdentifier::new(
                         self.uri.clone(),
-                        next_version,
+                        self.snapshot_version,
                     ),
                     content_changes,
                 },
             )?;
-
-            Ok((next_version, next_snapshot))
         }
+
+        Ok(())
     }
 }
 
@@ -515,15 +510,16 @@ impl Copilot {
                 return;
             }
 
-            let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
             registered_buffers.entry(buffer.id()).or_insert_with(|| {
+                let uri: lsp::Url = uri_for_buffer(buffer, cx);
+                let language_id = id_for_language(buffer.read(cx).language());
                 let snapshot = buffer.read(cx).snapshot();
                 server
                     .notify::<lsp::notification::DidOpenTextDocument>(
                         lsp::DidOpenTextDocumentParams {
                             text_document: lsp::TextDocumentItem {
                                 uri: uri.clone(),
-                                language_id: id_for_language(buffer.read(cx).language()),
+                                language_id: language_id.clone(),
                                 version: 0,
                                 text: snapshot.text(),
                             },
@@ -533,7 +529,9 @@ impl Copilot {
 
                 RegisteredBuffer {
                     uri,
-                    snapshot: Some((0, snapshot)),
+                    language_id,
+                    snapshot,
+                    snapshot_version: 0,
                     _subscriptions: [
                         cx.subscribe(buffer, |this, buffer, event, cx| {
                             this.handle_buffer_event(buffer, event, cx).log_err();
@@ -575,6 +573,31 @@ impl Copilot {
                             },
                         )?;
                     }
+                    language::Event::FileHandleChanged | language::Event::LanguageChanged => {
+                        let new_language_id = id_for_language(buffer.read(cx).language());
+                        let new_uri = uri_for_buffer(&buffer, cx);
+                        if new_uri != registered_buffer.uri
+                            || new_language_id != registered_buffer.language_id
+                        {
+                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
+                            registered_buffer.language_id = new_language_id;
+                            server.notify::<lsp::notification::DidCloseTextDocument>(
+                                lsp::DidCloseTextDocumentParams {
+                                    text_document: lsp::TextDocumentIdentifier::new(old_uri),
+                                },
+                            )?;
+                            server.notify::<lsp::notification::DidOpenTextDocument>(
+                                lsp::DidOpenTextDocumentParams {
+                                    text_document: lsp::TextDocumentItem::new(
+                                        registered_buffer.uri.clone(),
+                                        registered_buffer.language_id.clone(),
+                                        registered_buffer.snapshot_version,
+                                        registered_buffer.snapshot.text(),
+                                    ),
+                                },
+                            )?;
+                        }
+                    }
                     _ => {}
                 }
             }
@@ -659,6 +682,10 @@ impl Copilot {
             } => {
                 if matches!(status, SignInStatus::Authorized { .. }) {
                     if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) {
+                        if let Err(error) = registered_buffer.report_changes(buffer, &server, cx) {
+                            return Task::ready(Err(error));
+                        }
+
                         (server.clone(), registered_buffer)
                     } else {
                         return Task::ready(Err(anyhow!(
@@ -671,11 +698,9 @@ impl Copilot {
             }
         };
 
-        let (version, snapshot) = match registered_buffer.report_changes(buffer, &server, cx) {
-            Ok((version, snapshot)) => (version, snapshot),
-            Err(error) => return Task::ready(Err(error)),
-        };
         let uri = registered_buffer.uri.clone();
+        let snapshot = registered_buffer.snapshot.clone();
+        let version = registered_buffer.snapshot_version;
         let settings = cx.global::<Settings>();
         let position = position.to_point_utf16(&snapshot);
         let language = snapshot.language_at(position);
@@ -784,6 +809,14 @@ fn id_for_language(language: Option<&Arc<Language>>) -> String {
     }
 }
 
+fn uri_for_buffer(buffer: &ModelHandle<Buffer>, cx: &AppContext) -> lsp::Url {
+    if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
+        lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
+    } else {
+        format!("buffer://{}", buffer.id()).parse().unwrap()
+    }
+}
+
 async fn clear_copilot_dir() {
     remove_matching(&paths::COPILOT_DIR, |_| true).await
 }

crates/editor/src/editor.rs 🔗

@@ -6643,6 +6643,7 @@ impl Editor {
             multi_buffer::Event::DiagnosticsUpdated => {
                 self.refresh_active_diagnostics(cx);
             }
+            multi_buffer::Event::LanguageChanged => {}
         }
     }
 

crates/editor/src/multi_buffer.rs 🔗

@@ -64,6 +64,7 @@ pub enum Event {
     },
     Edited,
     Reloaded,
+    LanguageChanged,
     Reparsed,
     Saved,
     FileHandleChanged,
@@ -1302,6 +1303,7 @@ impl MultiBuffer {
             language::Event::Saved => Event::Saved,
             language::Event::FileHandleChanged => Event::FileHandleChanged,
             language::Event::Reloaded => Event::Reloaded,
+            language::Event::LanguageChanged => Event::LanguageChanged,
             language::Event::Reparsed => Event::Reparsed,
             language::Event::DiagnosticsUpdated => Event::DiagnosticsUpdated,
             language::Event::Closed => Event::Closed,

crates/language/src/buffer.rs 🔗

@@ -187,6 +187,7 @@ pub enum Event {
     Saved,
     FileHandleChanged,
     Reloaded,
+    LanguageChanged,
     Reparsed,
     DiagnosticsUpdated,
     Closed,
@@ -536,6 +537,7 @@ impl Buffer {
         self.syntax_map.lock().clear();
         self.language = language;
         self.reparse(cx);
+        cx.emit(Event::LanguageChanged);
     }
 
     pub fn set_language_registry(&mut self, language_registry: Arc<LanguageRegistry>) {