From 672cf6b8c789a9882effab185ad120e6adff42ca Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 19 Apr 2023 12:19:24 +0200 Subject: [PATCH 1/8] Relay buffer change events to Copilot --- Cargo.lock | 1 + crates/copilot/src/copilot.rs | 338 ++++++++++++++++++++++++---------- crates/copilot/src/request.rs | 3 - crates/project/Cargo.toml | 1 + crates/project/src/project.rs | 24 +++ 5 files changed, 268 insertions(+), 99 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6f05512b76d4a574ce15e391426e027bce3e7088..bb931853fcac984bd2133bd4aaa6dbe038217933 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4687,6 +4687,7 @@ dependencies = [ "client", "clock", "collections", + "copilot", "ctor", "db", "env_logger", diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 1967c3cd14d0045110811680d5cce39d191a8d60..57abd0893917c6f6ecba313f5a18e51f5125b4d6 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -6,8 +6,13 @@ use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; use collections::HashMap; use futures::{future::Shared, Future, FutureExt, TryFutureExt}; -use gpui::{actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task}; -use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16}; +use gpui::{ + actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle, +}; +use language::{ + point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, + ToPointUtf16, +}; use log::{debug, error}; use lsp::LanguageServer; use node_runtime::NodeRuntime; @@ -105,7 +110,7 @@ enum CopilotServer { Started { server: Arc, status: SignInStatus, - subscriptions_by_buffer_id: HashMap, + registered_buffers: HashMap, }, } @@ -141,6 +146,66 @@ impl Status { } } +struct RegisteredBuffer { + uri: lsp::Url, + snapshot: Option<(i32, BufferSnapshot)>, + _subscriptions: [gpui::Subscription; 2], +} + +impl RegisteredBuffer { + fn report_changes( + &mut self, + buffer: &ModelHandle, + server: &LanguageServer, + cx: &AppContext, + ) -> Result<(i32, BufferSnapshot)> { + let buffer = buffer.read(cx); + let (version, prev_snapshot) = self + .snapshot + .as_ref() + .ok_or_else(|| anyhow!("expected at least one snapshot"))?; + let next_snapshot = buffer.snapshot(); + + let content_changes = buffer + .edits_since::<(PointUtf16, usize)>(prev_snapshot.version()) + .map(|edit| { + let edit_start = edit.new.start.0; + let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0); + let new_text = next_snapshot + .text_for_range(edit.new.start.1..edit.new.end.1) + .collect(); + lsp::TextDocumentContentChangeEvent { + range: Some(lsp::Range::new( + point_to_lsp(edit_start), + point_to_lsp(edit_end), + )), + range_length: None, + text: new_text, + } + }) + .collect::>(); + + if content_changes.is_empty() { + Ok((*version, prev_snapshot.clone())) + } else { + let next_version = version + 1; + self.snapshot = Some((next_version, next_snapshot.clone())); + + server.notify::( + lsp::DidChangeTextDocumentParams { + text_document: lsp::VersionedTextDocumentIdentifier::new( + self.uri.clone(), + next_version, + ), + content_changes, + }, + )?; + + Ok((next_version, next_snapshot)) + } + } +} + #[derive(Debug, PartialEq, Eq)] pub struct Completion { pub range: Range, @@ -151,6 +216,7 @@ pub struct Copilot { http: Arc, node_runtime: Arc, server: CopilotServer, + buffers: HashMap>, } impl Entity for Copilot { @@ -212,12 +278,14 @@ impl Copilot { http, node_runtime, server: CopilotServer::Starting { task: start_task }, + buffers: Default::default(), } } else { Self { http, node_runtime, server: CopilotServer::Disabled, + buffers: Default::default(), } } } @@ -233,8 +301,9 @@ impl Copilot { server: CopilotServer::Started { server: Arc::new(server), status: SignInStatus::Authorized, - subscriptions_by_buffer_id: Default::default(), + registered_buffers: Default::default(), }, + buffers: Default::default(), }); (this, fake_server) } @@ -297,7 +366,7 @@ impl Copilot { this.server = CopilotServer::Started { server, status: SignInStatus::SignedOut, - subscriptions_by_buffer_id: Default::default(), + registered_buffers: Default::default(), }; this.update_sign_in_status(status, cx); } @@ -396,10 +465,8 @@ impl Copilot { } fn sign_out(&mut self, cx: &mut ModelContext) -> Task> { - if let CopilotServer::Started { server, status, .. } = &mut self.server { - *status = SignInStatus::SignedOut; - cx.notify(); - + self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx); + if let CopilotServer::Started { server, .. } = &self.server { let server = server.clone(); cx.background().spawn(async move { server @@ -433,6 +500,108 @@ impl Copilot { cx.foreground().spawn(start_task) } + pub fn register_buffer(&mut self, buffer: &ModelHandle, cx: &mut ModelContext) { + let buffer_id = buffer.id(); + self.buffers.insert(buffer_id, buffer.downgrade()); + + if let CopilotServer::Started { + server, + status, + registered_buffers, + .. + } = &mut self.server + { + if !matches!(status, SignInStatus::Authorized { .. }) { + return; + } + + let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap(); + registered_buffers.entry(buffer.id()).or_insert_with(|| { + let snapshot = buffer.read(cx).snapshot(); + server + .notify::( + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem { + uri: uri.clone(), + language_id: id_for_language(buffer.read(cx).language()), + version: 0, + text: snapshot.text(), + }, + }, + ) + .log_err(); + + RegisteredBuffer { + uri, + snapshot: Some((0, snapshot)), + _subscriptions: [ + cx.subscribe(buffer, |this, buffer, event, cx| { + this.handle_buffer_event(buffer, event, cx).log_err(); + }), + cx.observe_release(buffer, move |this, _buffer, _cx| { + this.buffers.remove(&buffer_id); + this.unregister_buffer(buffer_id); + }), + ], + } + }); + } + } + + fn handle_buffer_event( + &mut self, + buffer: ModelHandle, + event: &language::Event, + cx: &mut ModelContext, + ) -> Result<()> { + if let CopilotServer::Started { + server, + registered_buffers, + .. + } = &mut self.server + { + if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) { + match event { + language::Event::Edited => { + registered_buffer.report_changes(&buffer, server, cx)?; + } + language::Event::Saved => { + server.notify::( + lsp::DidSaveTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new( + registered_buffer.uri.clone(), + ), + text: None, + }, + )?; + } + _ => {} + } + } + } + + Ok(()) + } + + fn unregister_buffer(&mut self, buffer_id: usize) { + if let CopilotServer::Started { + server, + registered_buffers, + .. + } = &mut self.server + { + if let Some(buffer) = registered_buffers.remove(&buffer_id) { + server + .notify::( + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(buffer.uri), + }, + ) + .log_err(); + } + } + } + pub fn completions( &mut self, buffer: &ModelHandle, @@ -464,16 +633,14 @@ impl Copilot { cx: &mut ModelContext, ) -> Task>> where - R: lsp::request::Request< - Params = request::GetCompletionsParams, - Result = request::GetCompletionsResult, - >, + R: 'static + + lsp::request::Request< + Params = request::GetCompletionsParams, + Result = request::GetCompletionsResult, + >, T: ToPointUtf16, { - let buffer_id = buffer.id(); - let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap(); - let snapshot = buffer.read(cx).snapshot(); - let server = match &mut self.server { + let (server, registered_buffer) = match &mut self.server { CopilotServer::Starting { .. } => { return Task::ready(Err(anyhow!("copilot is still starting"))) } @@ -487,56 +654,28 @@ impl Copilot { CopilotServer::Started { server, status, - subscriptions_by_buffer_id, + registered_buffers, + .. } => { if matches!(status, SignInStatus::Authorized { .. }) { - subscriptions_by_buffer_id - .entry(buffer_id) - .or_insert_with(|| { - server - .notify::( - lsp::DidOpenTextDocumentParams { - text_document: lsp::TextDocumentItem { - uri: uri.clone(), - language_id: id_for_language( - buffer.read(cx).language(), - ), - version: 0, - text: snapshot.text(), - }, - }, - ) - .log_err(); - - let uri = uri.clone(); - cx.observe_release(buffer, move |this, _, _| { - if let CopilotServer::Started { - server, - subscriptions_by_buffer_id, - .. - } = &mut this.server - { - server - .notify::( - lsp::DidCloseTextDocumentParams { - text_document: lsp::TextDocumentIdentifier::new( - uri.clone(), - ), - }, - ) - .log_err(); - subscriptions_by_buffer_id.remove(&buffer_id); - } - }) - }); - - server.clone() + if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) { + (server.clone(), registered_buffer) + } else { + return Task::ready(Err(anyhow!( + "requested completions for an unregistered buffer" + ))); + } } else { return Task::ready(Err(anyhow!("must sign in before using copilot"))); } } }; + let (version, snapshot) = match registered_buffer.report_changes(buffer, &server, cx) { + Ok((version, snapshot)) => (version, snapshot), + Err(error) => return Task::ready(Err(error)), + }; + let uri = registered_buffer.uri.clone(); let settings = cx.global::(); let position = position.to_point_utf16(&snapshot); let language = snapshot.language_at(position); @@ -544,39 +683,23 @@ impl Copilot { let language_name = language_name.as_deref(); let tab_size = settings.tab_size(language_name); let hard_tabs = settings.hard_tabs(language_name); - let language_id = id_for_language(language); - - let path; - let relative_path; - if let Some(file) = snapshot.file() { - if let Some(file) = file.as_local() { - path = file.abs_path(cx); - } else { - path = file.full_path(cx); - } - relative_path = file.path().to_path_buf(); - } else { - path = PathBuf::new(); - relative_path = PathBuf::new(); - } - + let relative_path = snapshot + .file() + .map(|file| file.path().to_path_buf()) + .unwrap_or_default(); + let request = server.request::(request::GetCompletionsParams { + doc: request::GetCompletionsDocument { + uri, + tab_size: tab_size.into(), + indent_size: 1, + insert_spaces: !hard_tabs, + relative_path: relative_path.to_string_lossy().into(), + position: point_to_lsp(position), + version: version.try_into().unwrap(), + }, + }); cx.background().spawn(async move { - let result = server - .request::(request::GetCompletionsParams { - doc: request::GetCompletionsDocument { - source: snapshot.text(), - tab_size: tab_size.into(), - indent_size: 1, - insert_spaces: !hard_tabs, - uri, - path: path.to_string_lossy().into(), - relative_path: relative_path.to_string_lossy().into(), - language_id, - position: point_to_lsp(position), - version: 0, - }, - }) - .await?; + let result = request.await?; let completions = result .completions .into_iter() @@ -616,14 +739,37 @@ impl Copilot { lsp_status: request::SignInStatus, cx: &mut ModelContext, ) { + self.buffers.retain(|_, buffer| buffer.is_upgradable(cx)); + if let CopilotServer::Started { status, .. } = &mut self.server { - *status = match lsp_status { + match lsp_status { request::SignInStatus::Ok { .. } | request::SignInStatus::MaybeOk { .. } - | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized, - request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized, - request::SignInStatus::NotSignedIn => SignInStatus::SignedOut, - }; + | request::SignInStatus::AlreadySignedIn { .. } => { + *status = SignInStatus::Authorized; + + for buffer in self.buffers.values().cloned().collect::>() { + if let Some(buffer) = buffer.upgrade(cx) { + self.register_buffer(&buffer, cx); + } + } + } + request::SignInStatus::NotAuthorized { .. } => { + *status = SignInStatus::Unauthorized; + + for buffer_id in self.buffers.keys().copied().collect::>() { + self.unregister_buffer(buffer_id); + } + } + request::SignInStatus::NotSignedIn => { + *status = SignInStatus::SignedOut; + + for buffer_id in self.buffers.keys().copied().collect::>() { + self.unregister_buffer(buffer_id); + } + } + } + cx.notify(); } } diff --git a/crates/copilot/src/request.rs b/crates/copilot/src/request.rs index 415f160ea3a9335d8272c2122e2db285f3f16290..08173c413aae5b0e8ddf011941cff850202b7465 100644 --- a/crates/copilot/src/request.rs +++ b/crates/copilot/src/request.rs @@ -99,14 +99,11 @@ pub struct GetCompletionsParams { #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GetCompletionsDocument { - pub source: String, pub tab_size: u32, pub indent_size: u32, pub insert_spaces: bool, pub uri: lsp::Url, - pub path: String, pub relative_path: String, - pub language_id: String, pub position: lsp::Position, pub version: usize, } diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index f5c144a3adb915c61ae52af38fa78cc961b00639..e30ab56e45d1b088aa5ce811dfe246094c00c413 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -19,6 +19,7 @@ test-support = [ [dependencies] text = { path = "../text" } +copilot = { path = "../copilot" } client = { path = "../client" } clock = { path = "../clock" } collections = { path = "../collections" } diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index d126cb4994b86f592004290de29923900cca1a14..d5b7ac3f3f80afb4c85b1776e974587c35e9913d 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -12,6 +12,7 @@ use anyhow::{anyhow, Context, Result}; use client::{proto, Client, TypedEnvelope, UserStore}; use clock::ReplicaId; use collections::{hash_map, BTreeMap, HashMap, HashSet}; +use copilot::Copilot; use futures::{ channel::mpsc::{self, UnboundedReceiver}, future::{try_join_all, Shared}, @@ -129,6 +130,7 @@ pub struct Project { _maintain_buffer_languages: Task<()>, _maintain_workspace_config: Task<()>, terminals: Terminals, + copilot_enabled: bool, } enum BufferMessage { @@ -472,6 +474,7 @@ impl Project { terminals: Terminals { local_handles: Vec::new(), }, + copilot_enabled: Copilot::global(cx).is_some(), } }) } @@ -559,6 +562,7 @@ impl Project { terminals: Terminals { local_handles: Vec::new(), }, + copilot_enabled: Copilot::global(cx).is_some(), }; for worktree in worktrees { let _ = this.add_worktree(&worktree, cx); @@ -664,6 +668,15 @@ impl Project { self.start_language_server(worktree_id, worktree_path, language, cx); } + if !self.copilot_enabled && Copilot::global(cx).is_some() { + self.copilot_enabled = true; + for buffer in self.opened_buffers.values() { + if let Some(buffer) = buffer.upgrade(cx) { + self.register_buffer_with_copilot(&buffer, cx); + } + } + } + cx.notify(); } @@ -1616,6 +1629,7 @@ impl Project { self.detect_language_for_buffer(buffer, cx); self.register_buffer_with_language_server(buffer, cx); + self.register_buffer_with_copilot(buffer, cx); cx.observe_release(buffer, |this, buffer, cx| { if let Some(file) = File::from_dyn(buffer.file()) { if file.is_local() { @@ -1731,6 +1745,16 @@ impl Project { }); } + fn register_buffer_with_copilot( + &self, + buffer_handle: &ModelHandle, + cx: &mut ModelContext, + ) { + if let Some(copilot) = Copilot::global(cx) { + copilot.update(cx, |copilot, cx| copilot.register_buffer(buffer_handle, cx)); + } + } + async fn send_buffer_messages( this: WeakModelHandle, rx: UnboundedReceiver, From 34bcf6f07263825da81ed762611856a002d2617d Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 19 Apr 2023 14:27:26 +0200 Subject: [PATCH 2/8] Reopen file in Copilot language server when language or URI changes --- crates/copilot/src/copilot.rs | 85 +++++++++++++++++++++---------- crates/editor/src/editor.rs | 1 + crates/editor/src/multi_buffer.rs | 2 + crates/language/src/buffer.rs | 2 + 4 files changed, 64 insertions(+), 26 deletions(-) diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 57abd0893917c6f6ecba313f5a18e51f5125b4d6..ebe139c1cf5652de5010c4270b0314869cddde58 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -21,6 +21,7 @@ use settings::Settings; use smol::{fs, io::BufReader, stream::StreamExt}; use std::{ ffi::OsString, + mem, ops::Range, path::{Path, PathBuf}, sync::Arc, @@ -148,7 +149,9 @@ impl Status { struct RegisteredBuffer { uri: lsp::Url, - snapshot: Option<(i32, BufferSnapshot)>, + language_id: String, + snapshot: BufferSnapshot, + snapshot_version: i32, _subscriptions: [gpui::Subscription; 2], } @@ -158,20 +161,15 @@ impl RegisteredBuffer { buffer: &ModelHandle, server: &LanguageServer, cx: &AppContext, - ) -> Result<(i32, BufferSnapshot)> { + ) -> Result<()> { let buffer = buffer.read(cx); - let (version, prev_snapshot) = self - .snapshot - .as_ref() - .ok_or_else(|| anyhow!("expected at least one snapshot"))?; - let next_snapshot = buffer.snapshot(); - + let new_snapshot = buffer.snapshot(); let content_changes = buffer - .edits_since::<(PointUtf16, usize)>(prev_snapshot.version()) + .edits_since::<(PointUtf16, usize)>(self.snapshot.version()) .map(|edit| { let edit_start = edit.new.start.0; let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0); - let new_text = next_snapshot + let new_text = new_snapshot .text_for_range(edit.new.start.1..edit.new.end.1) .collect(); lsp::TextDocumentContentChangeEvent { @@ -185,24 +183,21 @@ impl RegisteredBuffer { }) .collect::>(); - if content_changes.is_empty() { - Ok((*version, prev_snapshot.clone())) - } else { - let next_version = version + 1; - self.snapshot = Some((next_version, next_snapshot.clone())); - + if !content_changes.is_empty() { + self.snapshot_version += 1; + self.snapshot = new_snapshot; server.notify::( lsp::DidChangeTextDocumentParams { text_document: lsp::VersionedTextDocumentIdentifier::new( self.uri.clone(), - next_version, + self.snapshot_version, ), content_changes, }, )?; - - Ok((next_version, next_snapshot)) } + + Ok(()) } } @@ -515,15 +510,16 @@ impl Copilot { return; } - let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap(); registered_buffers.entry(buffer.id()).or_insert_with(|| { + let uri: lsp::Url = uri_for_buffer(buffer, cx); + let language_id = id_for_language(buffer.read(cx).language()); let snapshot = buffer.read(cx).snapshot(); server .notify::( lsp::DidOpenTextDocumentParams { text_document: lsp::TextDocumentItem { uri: uri.clone(), - language_id: id_for_language(buffer.read(cx).language()), + language_id: language_id.clone(), version: 0, text: snapshot.text(), }, @@ -533,7 +529,9 @@ impl Copilot { RegisteredBuffer { uri, - snapshot: Some((0, snapshot)), + language_id, + snapshot, + snapshot_version: 0, _subscriptions: [ cx.subscribe(buffer, |this, buffer, event, cx| { this.handle_buffer_event(buffer, event, cx).log_err(); @@ -575,6 +573,31 @@ impl Copilot { }, )?; } + language::Event::FileHandleChanged | language::Event::LanguageChanged => { + let new_language_id = id_for_language(buffer.read(cx).language()); + let new_uri = uri_for_buffer(&buffer, cx); + if new_uri != registered_buffer.uri + || new_language_id != registered_buffer.language_id + { + let old_uri = mem::replace(&mut registered_buffer.uri, new_uri); + registered_buffer.language_id = new_language_id; + server.notify::( + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(old_uri), + }, + )?; + server.notify::( + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( + registered_buffer.uri.clone(), + registered_buffer.language_id.clone(), + registered_buffer.snapshot_version, + registered_buffer.snapshot.text(), + ), + }, + )?; + } + } _ => {} } } @@ -659,6 +682,10 @@ impl Copilot { } => { if matches!(status, SignInStatus::Authorized { .. }) { if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) { + if let Err(error) = registered_buffer.report_changes(buffer, &server, cx) { + return Task::ready(Err(error)); + } + (server.clone(), registered_buffer) } else { return Task::ready(Err(anyhow!( @@ -671,11 +698,9 @@ impl Copilot { } }; - let (version, snapshot) = match registered_buffer.report_changes(buffer, &server, cx) { - Ok((version, snapshot)) => (version, snapshot), - Err(error) => return Task::ready(Err(error)), - }; let uri = registered_buffer.uri.clone(); + let snapshot = registered_buffer.snapshot.clone(); + let version = registered_buffer.snapshot_version; let settings = cx.global::(); let position = position.to_point_utf16(&snapshot); let language = snapshot.language_at(position); @@ -784,6 +809,14 @@ fn id_for_language(language: Option<&Arc>) -> String { } } +fn uri_for_buffer(buffer: &ModelHandle, cx: &AppContext) -> lsp::Url { + if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) { + lsp::Url::from_file_path(file.abs_path(cx)).unwrap() + } else { + format!("buffer://{}", buffer.id()).parse().unwrap() + } +} + async fn clear_copilot_dir() { remove_matching(&paths::COPILOT_DIR, |_| true).await } diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 742541293e96346919f995bec8641ec347de2025..83b92cc06c4a358ae2f7773a3bfdc20ca7df1f83 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -6643,6 +6643,7 @@ impl Editor { multi_buffer::Event::DiagnosticsUpdated => { self.refresh_active_diagnostics(cx); } + multi_buffer::Event::LanguageChanged => {} } } diff --git a/crates/editor/src/multi_buffer.rs b/crates/editor/src/multi_buffer.rs index f8a56557abb633f832ce2d23dcba264fc0535fa7..824c108e46806f77d03002311b0cfc66fa67a20b 100644 --- a/crates/editor/src/multi_buffer.rs +++ b/crates/editor/src/multi_buffer.rs @@ -64,6 +64,7 @@ pub enum Event { }, Edited, Reloaded, + LanguageChanged, Reparsed, Saved, FileHandleChanged, @@ -1302,6 +1303,7 @@ impl MultiBuffer { language::Event::Saved => Event::Saved, language::Event::FileHandleChanged => Event::FileHandleChanged, language::Event::Reloaded => Event::Reloaded, + language::Event::LanguageChanged => Event::LanguageChanged, language::Event::Reparsed => Event::Reparsed, language::Event::DiagnosticsUpdated => Event::DiagnosticsUpdated, language::Event::Closed => Event::Closed, diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index fb16d416409523261aeb010674a6e94473e6bbc9..7325ca9af53dcee7d163533305f18bf81e48d61f 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -187,6 +187,7 @@ pub enum Event { Saved, FileHandleChanged, Reloaded, + LanguageChanged, Reparsed, DiagnosticsUpdated, Closed, @@ -536,6 +537,7 @@ impl Buffer { self.syntax_map.lock().clear(); self.language = language; self.reparse(cx); + cx.emit(Event::LanguageChanged); } pub fn set_language_registry(&mut self, language_registry: Arc) { From b9a7b70e52870aa121bd3551abe5ef38d06bf4c5 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 19 Apr 2023 14:30:23 +0200 Subject: [PATCH 3/8] Register unknown buffer on the fly if completions are requested for it --- crates/copilot/src/copilot.rs | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index ebe139c1cf5652de5010c4270b0314869cddde58..71843402a63bd25a0d792be46baa6dfdbc47e5f2 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -663,6 +663,7 @@ impl Copilot { >, T: ToPointUtf16, { + self.register_buffer(buffer, cx); let (server, registered_buffer) = match &mut self.server { CopilotServer::Starting { .. } => { return Task::ready(Err(anyhow!("copilot is still starting"))) @@ -681,17 +682,11 @@ impl Copilot { .. } => { if matches!(status, SignInStatus::Authorized { .. }) { - if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) { - if let Err(error) = registered_buffer.report_changes(buffer, &server, cx) { - return Task::ready(Err(error)); - } - - (server.clone(), registered_buffer) - } else { - return Task::ready(Err(anyhow!( - "requested completions for an unregistered buffer" - ))); + let registered_buffer = registered_buffers.get_mut(&buffer.id()).unwrap(); + if let Err(error) = registered_buffer.report_changes(buffer, &server, cx) { + return Task::ready(Err(error)); } + (server.clone(), registered_buffer) } else { return Task::ready(Err(anyhow!("must sign in before using copilot"))); } From 4c3d6c854afd284630105f75d7e5d7081923e7d1 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 19 Apr 2023 15:45:55 +0200 Subject: [PATCH 4/8] Send editor information to copilot --- crates/copilot/src/copilot.rs | 13 +++++++++++++ crates/copilot/src/request.rs | 29 +++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 71843402a63bd25a0d792be46baa6dfdbc47e5f2..4b67bc16504b454348b9675d3d7cd1b3eeb18bf9 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -350,6 +350,19 @@ impl Copilot { ) .detach(); + server + .request::(request::SetEditorInfoParams { + editor_info: request::EditorInfo { + name: "zed".into(), + version: env!("CARGO_PKG_VERSION").into(), + }, + editor_plugin_info: request::EditorPluginInfo { + name: "zed-copilot".into(), + version: "0.0.1".into(), + }, + }) + .await?; + anyhow::Ok((server, status)) }; diff --git a/crates/copilot/src/request.rs b/crates/copilot/src/request.rs index 08173c413aae5b0e8ddf011941cff850202b7465..0d43bb7debc5401094b1b96bdd93aad994ab056a 100644 --- a/crates/copilot/src/request.rs +++ b/crates/copilot/src/request.rs @@ -166,3 +166,32 @@ impl lsp::notification::Notification for StatusNotification { type Params = StatusNotificationParams; const METHOD: &'static str = "statusNotification"; } + +pub enum SetEditorInfo {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SetEditorInfoParams { + pub editor_info: EditorInfo, + pub editor_plugin_info: EditorPluginInfo, +} + +impl lsp::request::Request for SetEditorInfo { + type Params = SetEditorInfoParams; + type Result = String; + const METHOD: &'static str = "setEditorInfo"; +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct EditorInfo { + pub name: String, + pub version: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct EditorPluginInfo { + pub name: String, + pub version: String, +} From 5d571673025491a800336c0d398c1140f50886e7 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 20 Apr 2023 09:59:19 +0200 Subject: [PATCH 5/8] Make it easier to access a running/authenticated copilot server --- crates/copilot/src/copilot.rs | 223 +++++++++++++++++----------------- 1 file changed, 113 insertions(+), 110 deletions(-) diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 4b67bc16504b454348b9675d3d7cd1b3eeb18bf9..28bc7977c101e1ba51933d3408ef5e98010fb746 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -104,15 +104,38 @@ pub fn init(http: Arc, node_runtime: Arc, cx: &mut enum CopilotServer { Disabled, - Starting { - task: Shared>, - }, + Starting { task: Shared> }, Error(Arc), - Started { - server: Arc, - status: SignInStatus, - registered_buffers: HashMap, - }, + Running(RunningCopilotServer), +} + +impl CopilotServer { + fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> { + let server = self.as_running()?; + if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) { + Ok(server) + } else { + Err(anyhow!("must sign in before using copilot")) + } + } + + fn as_running(&mut self) -> Result<&mut RunningCopilotServer> { + match self { + CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")), + CopilotServer::Disabled => Err(anyhow!("copilot is disabled")), + CopilotServer::Error(error) => Err(anyhow!( + "copilot was not started because of an error: {}", + error + )), + CopilotServer::Running(server) => Ok(server), + } + } +} + +struct RunningCopilotServer { + lsp: Arc, + sign_in_status: SignInStatus, + registered_buffers: HashMap, } #[derive(Clone, Debug)] @@ -293,11 +316,11 @@ impl Copilot { let this = cx.add_model(|cx| Self { http: http.clone(), node_runtime: NodeRuntime::new(http, cx.background().clone()), - server: CopilotServer::Started { - server: Arc::new(server), - status: SignInStatus::Authorized, + server: CopilotServer::Running(RunningCopilotServer { + lsp: Arc::new(server), + sign_in_status: SignInStatus::Authorized, registered_buffers: Default::default(), - }, + }), buffers: Default::default(), }); (this, fake_server) @@ -371,11 +394,11 @@ impl Copilot { cx.notify(); match server { Ok((server, status)) => { - this.server = CopilotServer::Started { - server, - status: SignInStatus::SignedOut, + this.server = CopilotServer::Running(RunningCopilotServer { + lsp: server, + sign_in_status: SignInStatus::SignedOut, registered_buffers: Default::default(), - }; + }); this.update_sign_in_status(status, cx); } Err(error) => { @@ -388,8 +411,8 @@ impl Copilot { } fn sign_in(&mut self, cx: &mut ModelContext) -> Task> { - if let CopilotServer::Started { server, status, .. } = &mut self.server { - let task = match status { + if let CopilotServer::Running(server) = &mut self.server { + let task = match &server.sign_in_status { SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => { Task::ready(Ok(())).shared() } @@ -398,11 +421,11 @@ impl Copilot { task.clone() } SignInStatus::SignedOut => { - let server = server.clone(); + let lsp = server.lsp.clone(); let task = cx .spawn(|this, mut cx| async move { let sign_in = async { - let sign_in = server + let sign_in = lsp .request::( request::SignInInitiateParams {}, ) @@ -413,8 +436,10 @@ impl Copilot { } request::SignInInitiateResult::PromptUserDeviceFlow(flow) => { this.update(&mut cx, |this, cx| { - if let CopilotServer::Started { status, .. } = - &mut this.server + if let CopilotServer::Running(RunningCopilotServer { + sign_in_status: status, + .. + }) = &mut this.server { if let SignInStatus::SigningIn { prompt: prompt_flow, @@ -426,7 +451,7 @@ impl Copilot { } } }); - let response = server + let response = lsp .request::( request::SignInConfirmParams { user_code: flow.user_code, @@ -454,7 +479,7 @@ impl Copilot { }) }) .shared(); - *status = SignInStatus::SigningIn { + server.sign_in_status = SignInStatus::SigningIn { prompt: None, task: task.clone(), }; @@ -474,7 +499,7 @@ impl Copilot { fn sign_out(&mut self, cx: &mut ModelContext) -> Task> { self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx); - if let CopilotServer::Started { server, .. } = &self.server { + if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server { let server = server.clone(); cx.background().spawn(async move { server @@ -512,12 +537,12 @@ impl Copilot { let buffer_id = buffer.id(); self.buffers.insert(buffer_id, buffer.downgrade()); - if let CopilotServer::Started { - server, - status, + if let CopilotServer::Running(RunningCopilotServer { + lsp: server, + sign_in_status: status, registered_buffers, .. - } = &mut self.server + }) = &mut self.server { if !matches!(status, SignInStatus::Authorized { .. }) { return; @@ -565,26 +590,23 @@ impl Copilot { event: &language::Event, cx: &mut ModelContext, ) -> Result<()> { - if let CopilotServer::Started { - server, - registered_buffers, - .. - } = &mut self.server - { - if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) { + if let Ok(server) = self.server.as_running() { + if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.id()) { match event { language::Event::Edited => { - registered_buffer.report_changes(&buffer, server, cx)?; + registered_buffer.report_changes(&buffer, &server.lsp, cx)?; } language::Event::Saved => { - server.notify::( - lsp::DidSaveTextDocumentParams { - text_document: lsp::TextDocumentIdentifier::new( - registered_buffer.uri.clone(), - ), - text: None, - }, - )?; + server + .lsp + .notify::( + lsp::DidSaveTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new( + registered_buffer.uri.clone(), + ), + text: None, + }, + )?; } language::Event::FileHandleChanged | language::Event::LanguageChanged => { let new_language_id = id_for_language(buffer.read(cx).language()); @@ -594,21 +616,25 @@ impl Copilot { { let old_uri = mem::replace(&mut registered_buffer.uri, new_uri); registered_buffer.language_id = new_language_id; - server.notify::( - lsp::DidCloseTextDocumentParams { - text_document: lsp::TextDocumentIdentifier::new(old_uri), - }, - )?; - server.notify::( - lsp::DidOpenTextDocumentParams { - text_document: lsp::TextDocumentItem::new( - registered_buffer.uri.clone(), - registered_buffer.language_id.clone(), - registered_buffer.snapshot_version, - registered_buffer.snapshot.text(), - ), - }, - )?; + server + .lsp + .notify::( + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(old_uri), + }, + )?; + server + .lsp + .notify::( + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( + registered_buffer.uri.clone(), + registered_buffer.language_id.clone(), + registered_buffer.snapshot_version, + registered_buffer.snapshot.text(), + ), + }, + )?; } } _ => {} @@ -620,14 +646,10 @@ impl Copilot { } fn unregister_buffer(&mut self, buffer_id: usize) { - if let CopilotServer::Started { - server, - registered_buffers, - .. - } = &mut self.server - { - if let Some(buffer) = registered_buffers.remove(&buffer_id) { + if let Ok(server) = self.server.as_running() { + if let Some(buffer) = server.registered_buffers.remove(&buffer_id) { server + .lsp .notify::( lsp::DidCloseTextDocumentParams { text_document: lsp::TextDocumentIdentifier::new(buffer.uri), @@ -677,34 +699,15 @@ impl Copilot { T: ToPointUtf16, { self.register_buffer(buffer, cx); - let (server, registered_buffer) = match &mut self.server { - CopilotServer::Starting { .. } => { - return Task::ready(Err(anyhow!("copilot is still starting"))) - } - CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))), - CopilotServer::Error(error) => { - return Task::ready(Err(anyhow!( - "copilot was not started because of an error: {}", - error - ))) - } - CopilotServer::Started { - server, - status, - registered_buffers, - .. - } => { - if matches!(status, SignInStatus::Authorized { .. }) { - let registered_buffer = registered_buffers.get_mut(&buffer.id()).unwrap(); - if let Err(error) = registered_buffer.report_changes(buffer, &server, cx) { - return Task::ready(Err(error)); - } - (server.clone(), registered_buffer) - } else { - return Task::ready(Err(anyhow!("must sign in before using copilot"))); - } - } + + let server = match self.server.as_authenticated() { + Ok(server) => server, + Err(error) => return Task::ready(Err(error)), }; + let registered_buffer = server.registered_buffers.get_mut(&buffer.id()).unwrap(); + if let Err(error) = registered_buffer.report_changes(buffer, &server.lsp, cx) { + return Task::ready(Err(error)); + } let uri = registered_buffer.uri.clone(); let snapshot = registered_buffer.snapshot.clone(); @@ -720,7 +723,7 @@ impl Copilot { .file() .map(|file| file.path().to_path_buf()) .unwrap_or_default(); - let request = server.request::(request::GetCompletionsParams { + let request = server.lsp.request::(request::GetCompletionsParams { doc: request::GetCompletionsDocument { uri, tab_size: tab_size.into(), @@ -742,6 +745,7 @@ impl Copilot { let end = snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left); Completion { + uuid: completion.uuid, range: snapshot.anchor_before(start)..snapshot.anchor_after(end), text: completion.text, } @@ -756,14 +760,16 @@ impl Copilot { CopilotServer::Starting { task } => Status::Starting { task: task.clone() }, CopilotServer::Disabled => Status::Disabled, CopilotServer::Error(error) => Status::Error(error.clone()), - CopilotServer::Started { status, .. } => match status { - SignInStatus::Authorized { .. } => Status::Authorized, - SignInStatus::Unauthorized { .. } => Status::Unauthorized, - SignInStatus::SigningIn { prompt, .. } => Status::SigningIn { - prompt: prompt.clone(), - }, - SignInStatus::SignedOut => Status::SignedOut, - }, + CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => { + match sign_in_status { + SignInStatus::Authorized { .. } => Status::Authorized, + SignInStatus::Unauthorized { .. } => Status::Unauthorized, + SignInStatus::SigningIn { prompt, .. } => Status::SigningIn { + prompt: prompt.clone(), + }, + SignInStatus::SignedOut => Status::SignedOut, + } + } } } @@ -774,13 +780,12 @@ impl Copilot { ) { self.buffers.retain(|_, buffer| buffer.is_upgradable(cx)); - if let CopilotServer::Started { status, .. } = &mut self.server { + if let Ok(server) = self.server.as_running() { match lsp_status { request::SignInStatus::Ok { .. } | request::SignInStatus::MaybeOk { .. } | request::SignInStatus::AlreadySignedIn { .. } => { - *status = SignInStatus::Authorized; - + server.sign_in_status = SignInStatus::Authorized; for buffer in self.buffers.values().cloned().collect::>() { if let Some(buffer) = buffer.upgrade(cx) { self.register_buffer(&buffer, cx); @@ -788,15 +793,13 @@ impl Copilot { } } request::SignInStatus::NotAuthorized { .. } => { - *status = SignInStatus::Unauthorized; - + server.sign_in_status = SignInStatus::Unauthorized; for buffer_id in self.buffers.keys().copied().collect::>() { self.unregister_buffer(buffer_id); } } request::SignInStatus::NotSignedIn => { - *status = SignInStatus::SignedOut; - + server.sign_in_status = SignInStatus::SignedOut; for buffer_id in self.buffers.keys().copied().collect::>() { self.unregister_buffer(buffer_id); } From 4d207981ae0ca73800afc6c4634b6c996d421be9 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 20 Apr 2023 10:12:13 +0200 Subject: [PATCH 6/8] Notify LSP when Copilot suggestions are accepted/rejected --- crates/copilot/src/copilot.rs | 48 ++++++++++++++++++++++++- crates/copilot/src/request.rs | 28 +++++++++++++++ crates/editor/src/editor.rs | 68 +++++++++++++++++++++++------------ 3 files changed, 120 insertions(+), 24 deletions(-) diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 28bc7977c101e1ba51933d3408ef5e98010fb746..6b32cc260372fd6091e1ffd99d986d52f1333c5b 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -224,8 +224,9 @@ impl RegisteredBuffer { } } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug)] pub struct Completion { + uuid: String, pub range: Range, pub text: String, } @@ -684,6 +685,51 @@ impl Copilot { self.request_completions::(buffer, position, cx) } + pub fn accept_completion( + &mut self, + completion: &Completion, + cx: &mut ModelContext, + ) -> Task> { + let server = match self.server.as_authenticated() { + Ok(server) => server, + Err(error) => return Task::ready(Err(error)), + }; + let request = + server + .lsp + .request::(request::NotifyAcceptedParams { + uuid: completion.uuid.clone(), + }); + cx.background().spawn(async move { + request.await?; + Ok(()) + }) + } + + pub fn discard_completions( + &mut self, + completions: &[Completion], + cx: &mut ModelContext, + ) -> Task> { + let server = match self.server.as_authenticated() { + Ok(server) => server, + Err(error) => return Task::ready(Err(error)), + }; + let request = + server + .lsp + .request::(request::NotifyRejectedParams { + uuids: completions + .iter() + .map(|completion| completion.uuid.clone()) + .collect(), + }); + cx.background().spawn(async move { + request.await?; + Ok(()) + }) + } + fn request_completions( &mut self, buffer: &ModelHandle, diff --git a/crates/copilot/src/request.rs b/crates/copilot/src/request.rs index 0d43bb7debc5401094b1b96bdd93aad994ab056a..43b5109d027dd2bcc62d1e9a5cb48a5996e84fd7 100644 --- a/crates/copilot/src/request.rs +++ b/crates/copilot/src/request.rs @@ -195,3 +195,31 @@ pub struct EditorPluginInfo { pub name: String, pub version: String, } + +pub enum NotifyAccepted {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct NotifyAcceptedParams { + pub uuid: String, +} + +impl lsp::request::Request for NotifyAccepted { + type Params = NotifyAcceptedParams; + type Result = String; + const METHOD: &'static str = "notifyAccepted"; +} + +pub enum NotifyRejected {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct NotifyRejectedParams { + pub uuids: Vec, +} + +impl lsp::request::Request for NotifyRejected { + type Params = NotifyRejectedParams; + type Result = String; + const METHOD: &'static str = "notifyRejected"; +} diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 83b92cc06c4a358ae2f7773a3bfdc20ca7df1f83..59f8ab6acb434b24d41c048903d37f1041491779 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -52,7 +52,7 @@ pub use language::{char_kind, CharKind}; use language::{ AutoindentMode, BracketPair, Buffer, CodeAction, CodeLabel, Completion, CursorShape, Diagnostic, DiagnosticSeverity, IndentKind, IndentSize, Language, OffsetRangeExt, OffsetUtf16, - Point, Rope, Selection, SelectionGoal, TransactionId, + Point, Selection, SelectionGoal, TransactionId, }; use link_go_to_definition::{ hide_link_definition, show_link_definition, LinkDefinitionKind, LinkGoToDefinitionState, @@ -1037,6 +1037,10 @@ impl Default for CopilotState { } impl CopilotState { + fn active_completion(&self) -> Option<&copilot::Completion> { + self.completions.get(self.active_completion_index) + } + fn text_for_active_completion( &self, cursor: Anchor, @@ -1044,7 +1048,7 @@ impl CopilotState { ) -> Option<&str> { use language::ToOffset as _; - let completion = self.completions.get(self.active_completion_index)?; + let completion = self.active_completion()?; let excerpt_id = self.excerpt_id?; let completion_buffer = buffer.buffer_for_excerpt(excerpt_id)?; if excerpt_id != cursor.excerpt_id @@ -1097,7 +1101,7 @@ impl CopilotState { fn push_completion(&mut self, new_completion: copilot::Completion) { for completion in &self.completions { - if *completion == new_completion { + if completion.text == new_completion.text && completion.range == new_completion.range { return; } } @@ -1496,7 +1500,7 @@ impl Editor { self.refresh_code_actions(cx); self.refresh_document_highlights(cx); refresh_matching_bracket_highlights(self, cx); - self.hide_copilot_suggestion(cx); + self.discard_copilot_suggestion(cx); } self.blink_manager.update(cx, BlinkManager::pause_blinking); @@ -1870,7 +1874,7 @@ impl Editor { return; } - if self.hide_copilot_suggestion(cx).is_some() { + if self.discard_copilot_suggestion(cx) { return; } @@ -2969,7 +2973,7 @@ impl Editor { Some(()) } - fn cycle_suggestions( + fn cycle_copilot_suggestions( &mut self, direction: Direction, cx: &mut ViewContext, @@ -3020,7 +3024,7 @@ impl Editor { fn next_copilot_suggestion(&mut self, _: &copilot::NextSuggestion, cx: &mut ViewContext) { if self.has_active_copilot_suggestion(cx) { - self.cycle_suggestions(Direction::Next, cx); + self.cycle_copilot_suggestions(Direction::Next, cx); } else { self.refresh_copilot_suggestions(false, cx); } @@ -3032,37 +3036,55 @@ impl Editor { cx: &mut ViewContext, ) { if self.has_active_copilot_suggestion(cx) { - self.cycle_suggestions(Direction::Prev, cx); + self.cycle_copilot_suggestions(Direction::Prev, cx); } else { self.refresh_copilot_suggestions(false, cx); } } fn accept_copilot_suggestion(&mut self, cx: &mut ViewContext) -> bool { - if let Some(text) = self.hide_copilot_suggestion(cx) { - self.insert_with_autoindent_mode(&text.to_string(), None, cx); + if let Some(suggestion) = self + .display_map + .update(cx, |map, cx| map.replace_suggestion::(None, cx)) + { + if let Some((copilot, completion)) = + Copilot::global(cx).zip(self.copilot_state.active_completion()) + { + copilot + .update(cx, |copilot, cx| copilot.accept_completion(completion, cx)) + .detach_and_log_err(cx); + } + self.insert_with_autoindent_mode(&suggestion.text.to_string(), None, cx); + cx.notify(); true } else { false } } - fn has_active_copilot_suggestion(&self, cx: &AppContext) -> bool { - self.display_map.read(cx).has_suggestion() - } - - fn hide_copilot_suggestion(&mut self, cx: &mut ViewContext) -> Option { + fn discard_copilot_suggestion(&mut self, cx: &mut ViewContext) -> bool { if self.has_active_copilot_suggestion(cx) { - let old_suggestion = self - .display_map + if let Some(copilot) = Copilot::global(cx) { + copilot + .update(cx, |copilot, cx| { + copilot.discard_completions(&self.copilot_state.completions, cx) + }) + .detach_and_log_err(cx); + } + + self.display_map .update(cx, |map, cx| map.replace_suggestion::(None, cx)); cx.notify(); - old_suggestion.map(|suggestion| suggestion.text) + true } else { - None + false } } + fn has_active_copilot_suggestion(&self, cx: &AppContext) -> bool { + self.display_map.read(cx).has_suggestion() + } + fn update_visible_copilot_suggestion(&mut self, cx: &mut ViewContext) { let snapshot = self.buffer.read(cx).snapshot(cx); let selection = self.selections.newest_anchor(); @@ -3072,7 +3094,7 @@ impl Editor { || !self.completion_tasks.is_empty() || selection.start != selection.end { - self.hide_copilot_suggestion(cx); + self.discard_copilot_suggestion(cx); } else if let Some(text) = self .copilot_state .text_for_active_completion(cursor, &snapshot) @@ -3088,13 +3110,13 @@ impl Editor { }); cx.notify(); } else { - self.hide_copilot_suggestion(cx); + self.discard_copilot_suggestion(cx); } } fn clear_copilot_suggestions(&mut self, cx: &mut ViewContext) { self.copilot_state = Default::default(); - self.hide_copilot_suggestion(cx); + self.discard_copilot_suggestion(cx); } pub fn render_code_actions_indicator( @@ -3212,7 +3234,7 @@ impl Editor { self.completion_tasks.clear(); } self.context_menu = Some(menu); - self.hide_copilot_suggestion(cx); + self.discard_copilot_suggestion(cx); cx.notify(); } From 4151bd39da112ee95fa5e8744a0e034128744103 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 20 Apr 2023 10:51:50 +0200 Subject: [PATCH 7/8] Add buffer management test to Copilot --- Cargo.lock | 3 + crates/copilot/Cargo.toml | 5 +- crates/copilot/src/copilot.rs | 223 ++++++++++++++++++++++++++++++++++ 3 files changed, 230 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index bb931853fcac984bd2133bd4aaa6dbe038217933..d7c02497983110b042cda2807d1c179549dfd559 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1341,14 +1341,17 @@ dependencies = [ "anyhow", "async-compression", "async-tar", + "clock", "collections", "context_menu", + "fs", "futures 0.3.25", "gpui", "language", "log", "lsp", "node_runtime", + "rpc", "serde", "serde_derive", "settings", diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 9d68edd6b5b6a626d7a082d132c1d135a8fd2d84..bfafdbc0ca9522fe1553ee4ac38f4fc934007020 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -38,10 +38,13 @@ smol = "1.2.5" futures = "0.3" [dev-dependencies] +clock = { path = "../clock" } collections = { path = "../collections", features = ["test-support"] } +fs = { path = "../fs", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] } language = { path = "../language", features = ["test-support"] } -settings = { path = "../settings", features = ["test-support"] } lsp = { path = "../lsp", features = ["test-support"] } +rpc = { path = "../rpc", features = ["test-support"] } +settings = { path = "../settings", features = ["test-support"] } util = { path = "../util", features = ["test-support"] } workspace = { path = "../workspace", features = ["test-support"] } diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 6b32cc260372fd6091e1ffd99d986d52f1333c5b..a558307a70fc783c7312be170f5231b818e255a2 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -945,3 +945,226 @@ async fn get_copilot_lsp(http: Arc) -> anyhow::Result { } } } + +#[cfg(test)] +mod tests { + use super::*; + use gpui::{executor::Deterministic, TestAppContext}; + + #[gpui::test(iterations = 10)] + async fn test_buffer_management(deterministic: Arc, cx: &mut TestAppContext) { + deterministic.forbid_parking(); + let (copilot, mut lsp) = Copilot::fake(cx); + + let buffer_1 = cx.add_model(|cx| Buffer::new(0, "Hello", cx)); + let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.id()).parse().unwrap(); + copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx)); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( + buffer_1_uri.clone(), + "plaintext".into(), + 0, + "Hello".into() + ), + } + ); + + let buffer_2 = cx.add_model(|cx| Buffer::new(0, "Goodbye", cx)); + let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.id()).parse().unwrap(); + copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx)); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( + buffer_2_uri.clone(), + "plaintext".into(), + 0, + "Goodbye".into() + ), + } + ); + + buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx)); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidChangeTextDocumentParams { + text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1), + content_changes: vec![lsp::TextDocumentContentChangeEvent { + range: Some(lsp::Range::new( + lsp::Position::new(0, 5), + lsp::Position::new(0, 5) + )), + range_length: None, + text: " world".into(), + }], + } + ); + + // Ensure updates to the file are reflected in the LSP. + buffer_1 + .update(cx, |buffer, cx| { + buffer.file_updated( + Arc::new(File { + abs_path: "/root/child/buffer-1".into(), + path: Path::new("child/buffer-1").into(), + }), + cx, + ) + }) + .await; + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri), + } + ); + let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap(); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( + buffer_1_uri.clone(), + "plaintext".into(), + 1, + "Hello world".into() + ), + } + ); + + // Ensure all previously-registered buffers are closed when signing out. + lsp.handle_request::(|_, _| async { + Ok(request::SignOutResult {}) + }); + copilot + .update(cx, |copilot, cx| copilot.sign_out(cx)) + .await + .unwrap(); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()), + } + ); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()), + } + ); + + // Ensure all previously-registered buffers are re-opened when signing in. + lsp.handle_request::(|_, _| async { + Ok(request::SignInInitiateResult::AlreadySignedIn { + user: "user-1".into(), + }) + }); + copilot + .update(cx, |copilot, cx| copilot.sign_in(cx)) + .await + .unwrap(); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( + buffer_2_uri.clone(), + "plaintext".into(), + 0, + "Goodbye".into() + ), + } + ); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( + buffer_1_uri.clone(), + "plaintext".into(), + 0, + "Hello world".into() + ), + } + ); + + // Dropping a buffer causes it to be closed on the LSP side as well. + cx.update(|_| drop(buffer_2)); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri), + } + ); + } + + struct File { + abs_path: PathBuf, + path: Arc, + } + + impl language::File for File { + fn as_local(&self) -> Option<&dyn language::LocalFile> { + Some(self) + } + + fn mtime(&self) -> std::time::SystemTime { + todo!() + } + + fn path(&self) -> &Arc { + &self.path + } + + fn full_path(&self, _: &AppContext) -> PathBuf { + todo!() + } + + fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr { + todo!() + } + + fn is_deleted(&self) -> bool { + todo!() + } + + fn as_any(&self) -> &dyn std::any::Any { + todo!() + } + + fn to_proto(&self) -> rpc::proto::File { + todo!() + } + } + + impl language::LocalFile for File { + fn abs_path(&self, _: &AppContext) -> PathBuf { + self.abs_path.clone() + } + + fn load(&self, _: &AppContext) -> Task> { + todo!() + } + + fn buffer_reloaded( + &self, + _: u64, + _: &clock::Global, + _: language::RopeFingerprint, + _: ::fs::LineEnding, + _: std::time::SystemTime, + _: &mut AppContext, + ) { + todo!() + } + } +} From df71a9cfaeb6c68f90db3d8d3c28c391f859be6e Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 20 Apr 2023 11:57:37 +0200 Subject: [PATCH 8/8] Move buffer change reporting to a background task --- crates/copilot/src/copilot.rs | 165 +++++++++++++++++++++------------- 1 file changed, 104 insertions(+), 61 deletions(-) diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index a558307a70fc783c7312be170f5231b818e255a2..c3ec63c43ce0c4da875c01cf90a08c13fb9837c8 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -5,7 +5,7 @@ use anyhow::{anyhow, Context, Result}; use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; use collections::HashMap; -use futures::{future::Shared, Future, FutureExt, TryFutureExt}; +use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt}; use gpui::{ actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle, }; @@ -171,56 +171,97 @@ impl Status { } struct RegisteredBuffer { + id: usize, uri: lsp::Url, language_id: String, snapshot: BufferSnapshot, snapshot_version: i32, _subscriptions: [gpui::Subscription; 2], + pending_buffer_change: Task>, } impl RegisteredBuffer { fn report_changes( &mut self, buffer: &ModelHandle, - server: &LanguageServer, - cx: &AppContext, - ) -> Result<()> { - let buffer = buffer.read(cx); - let new_snapshot = buffer.snapshot(); - let content_changes = buffer - .edits_since::<(PointUtf16, usize)>(self.snapshot.version()) - .map(|edit| { - let edit_start = edit.new.start.0; - let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0); - let new_text = new_snapshot - .text_for_range(edit.new.start.1..edit.new.end.1) - .collect(); - lsp::TextDocumentContentChangeEvent { - range: Some(lsp::Range::new( - point_to_lsp(edit_start), - point_to_lsp(edit_end), - )), - range_length: None, - text: new_text, - } - }) - .collect::>(); - - if !content_changes.is_empty() { - self.snapshot_version += 1; - self.snapshot = new_snapshot; - server.notify::( - lsp::DidChangeTextDocumentParams { - text_document: lsp::VersionedTextDocumentIdentifier::new( - self.uri.clone(), - self.snapshot_version, - ), - content_changes, - }, - )?; + cx: &mut ModelContext, + ) -> oneshot::Receiver<(i32, BufferSnapshot)> { + let id = self.id; + let (done_tx, done_rx) = oneshot::channel(); + + if buffer.read(cx).version() == self.snapshot.version { + let _ = done_tx.send((self.snapshot_version, self.snapshot.clone())); + } else { + let buffer = buffer.downgrade(); + let prev_pending_change = + mem::replace(&mut self.pending_buffer_change, Task::ready(None)); + self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move { + prev_pending_change.await; + + let old_version = copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| { + let server = copilot.server.as_authenticated().log_err()?; + let buffer = server.registered_buffers.get_mut(&id)?; + Some(buffer.snapshot.version.clone()) + })?; + let new_snapshot = buffer + .upgrade(&cx)? + .read_with(&cx, |buffer, _| buffer.snapshot()); + + let content_changes = cx + .background() + .spawn({ + let new_snapshot = new_snapshot.clone(); + async move { + new_snapshot + .edits_since::<(PointUtf16, usize)>(&old_version) + .map(|edit| { + let edit_start = edit.new.start.0; + let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0); + let new_text = new_snapshot + .text_for_range(edit.new.start.1..edit.new.end.1) + .collect(); + lsp::TextDocumentContentChangeEvent { + range: Some(lsp::Range::new( + point_to_lsp(edit_start), + point_to_lsp(edit_end), + )), + range_length: None, + text: new_text, + } + }) + .collect::>() + } + }) + .await; + + copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| { + let server = copilot.server.as_authenticated().log_err()?; + let buffer = server.registered_buffers.get_mut(&id)?; + if !content_changes.is_empty() { + buffer.snapshot_version += 1; + buffer.snapshot = new_snapshot; + server + .lsp + .notify::( + lsp::DidChangeTextDocumentParams { + text_document: lsp::VersionedTextDocumentIdentifier::new( + buffer.uri.clone(), + buffer.snapshot_version, + ), + content_changes, + }, + ) + .log_err(); + } + let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone())); + Some(()) + })?; + + Some(()) + }); } - Ok(()) + done_rx } } @@ -567,10 +608,12 @@ impl Copilot { .log_err(); RegisteredBuffer { + id: buffer_id, uri, language_id, snapshot, snapshot_version: 0, + pending_buffer_change: Task::ready(Some(())), _subscriptions: [ cx.subscribe(buffer, |this, buffer, event, cx| { this.handle_buffer_event(buffer, event, cx).log_err(); @@ -595,7 +638,7 @@ impl Copilot { if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.id()) { match event { language::Event::Edited => { - registered_buffer.report_changes(&buffer, &server.lsp, cx)?; + let _ = registered_buffer.report_changes(&buffer, cx); } language::Event::Saved => { server @@ -750,38 +793,38 @@ impl Copilot { Ok(server) => server, Err(error) => return Task::ready(Err(error)), }; + let lsp = server.lsp.clone(); let registered_buffer = server.registered_buffers.get_mut(&buffer.id()).unwrap(); - if let Err(error) = registered_buffer.report_changes(buffer, &server.lsp, cx) { - return Task::ready(Err(error)); - } - + let snapshot = registered_buffer.report_changes(buffer, cx); + let buffer = buffer.read(cx); let uri = registered_buffer.uri.clone(); - let snapshot = registered_buffer.snapshot.clone(); - let version = registered_buffer.snapshot_version; let settings = cx.global::(); - let position = position.to_point_utf16(&snapshot); - let language = snapshot.language_at(position); + let position = position.to_point_utf16(buffer); + let language = buffer.language_at(position); let language_name = language.map(|language| language.name()); let language_name = language_name.as_deref(); let tab_size = settings.tab_size(language_name); let hard_tabs = settings.hard_tabs(language_name); - let relative_path = snapshot + let relative_path = buffer .file() .map(|file| file.path().to_path_buf()) .unwrap_or_default(); - let request = server.lsp.request::(request::GetCompletionsParams { - doc: request::GetCompletionsDocument { - uri, - tab_size: tab_size.into(), - indent_size: 1, - insert_spaces: !hard_tabs, - relative_path: relative_path.to_string_lossy().into(), - position: point_to_lsp(position), - version: version.try_into().unwrap(), - }, - }); - cx.background().spawn(async move { - let result = request.await?; + + cx.foreground().spawn(async move { + let (version, snapshot) = snapshot.await?; + let result = lsp + .request::(request::GetCompletionsParams { + doc: request::GetCompletionsDocument { + uri, + tab_size: tab_size.into(), + indent_size: 1, + insert_spaces: !hard_tabs, + relative_path: relative_path.to_string_lossy().into(), + position: point_to_lsp(position), + version: version.try_into().unwrap(), + }, + }) + .await?; let completions = result .completions .into_iter()