diff --git a/Cargo.lock b/Cargo.lock index 6f05512b76d4a574ce15e391426e027bce3e7088..d7c02497983110b042cda2807d1c179549dfd559 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1341,14 +1341,17 @@ dependencies = [ "anyhow", "async-compression", "async-tar", + "clock", "collections", "context_menu", + "fs", "futures 0.3.25", "gpui", "language", "log", "lsp", "node_runtime", + "rpc", "serde", "serde_derive", "settings", @@ -4687,6 +4690,7 @@ dependencies = [ "client", "clock", "collections", + "copilot", "ctor", "db", "env_logger", diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 9d68edd6b5b6a626d7a082d132c1d135a8fd2d84..bfafdbc0ca9522fe1553ee4ac38f4fc934007020 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -38,10 +38,13 @@ smol = "1.2.5" futures = "0.3" [dev-dependencies] +clock = { path = "../clock" } collections = { path = "../collections", features = ["test-support"] } +fs = { path = "../fs", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] } language = { path = "../language", features = ["test-support"] } -settings = { path = "../settings", features = ["test-support"] } lsp = { path = "../lsp", features = ["test-support"] } +rpc = { path = "../rpc", features = ["test-support"] } +settings = { path = "../settings", features = ["test-support"] } util = { path = "../util", features = ["test-support"] } workspace = { path = "../workspace", features = ["test-support"] } diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 1967c3cd14d0045110811680d5cce39d191a8d60..c3ec63c43ce0c4da875c01cf90a08c13fb9837c8 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -5,9 +5,14 @@ use anyhow::{anyhow, Context, Result}; use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; use collections::HashMap; -use futures::{future::Shared, Future, FutureExt, TryFutureExt}; -use gpui::{actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task}; -use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16}; +use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt}; +use gpui::{ + actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle, +}; +use language::{ + point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, + ToPointUtf16, +}; use log::{debug, error}; use lsp::LanguageServer; use node_runtime::NodeRuntime; @@ -16,6 +21,7 @@ use settings::Settings; use smol::{fs, io::BufReader, stream::StreamExt}; use std::{ ffi::OsString, + mem, ops::Range, path::{Path, PathBuf}, sync::Arc, @@ -98,15 +104,38 @@ pub fn init(http: Arc, node_runtime: Arc, cx: &mut enum CopilotServer { Disabled, - Starting { - task: Shared>, - }, + Starting { task: Shared> }, Error(Arc), - Started { - server: Arc, - status: SignInStatus, - subscriptions_by_buffer_id: HashMap, - }, + Running(RunningCopilotServer), +} + +impl CopilotServer { + fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> { + let server = self.as_running()?; + if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) { + Ok(server) + } else { + Err(anyhow!("must sign in before using copilot")) + } + } + + fn as_running(&mut self) -> Result<&mut RunningCopilotServer> { + match self { + CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")), + CopilotServer::Disabled => Err(anyhow!("copilot is disabled")), + CopilotServer::Error(error) => Err(anyhow!( + "copilot was not started because of an error: {}", + error + )), + CopilotServer::Running(server) => Ok(server), + } + } +} + +struct RunningCopilotServer { + lsp: Arc, + sign_in_status: SignInStatus, + registered_buffers: HashMap, } #[derive(Clone, Debug)] @@ -141,8 +170,104 @@ impl Status { } } -#[derive(Debug, PartialEq, Eq)] +struct RegisteredBuffer { + id: usize, + uri: lsp::Url, + language_id: String, + snapshot: BufferSnapshot, + snapshot_version: i32, + _subscriptions: [gpui::Subscription; 2], + pending_buffer_change: Task>, +} + +impl RegisteredBuffer { + fn report_changes( + &mut self, + buffer: &ModelHandle, + cx: &mut ModelContext, + ) -> oneshot::Receiver<(i32, BufferSnapshot)> { + let id = self.id; + let (done_tx, done_rx) = oneshot::channel(); + + if buffer.read(cx).version() == self.snapshot.version { + let _ = done_tx.send((self.snapshot_version, self.snapshot.clone())); + } else { + let buffer = buffer.downgrade(); + let prev_pending_change = + mem::replace(&mut self.pending_buffer_change, Task::ready(None)); + self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move { + prev_pending_change.await; + + let old_version = copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| { + let server = copilot.server.as_authenticated().log_err()?; + let buffer = server.registered_buffers.get_mut(&id)?; + Some(buffer.snapshot.version.clone()) + })?; + let new_snapshot = buffer + .upgrade(&cx)? + .read_with(&cx, |buffer, _| buffer.snapshot()); + + let content_changes = cx + .background() + .spawn({ + let new_snapshot = new_snapshot.clone(); + async move { + new_snapshot + .edits_since::<(PointUtf16, usize)>(&old_version) + .map(|edit| { + let edit_start = edit.new.start.0; + let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0); + let new_text = new_snapshot + .text_for_range(edit.new.start.1..edit.new.end.1) + .collect(); + lsp::TextDocumentContentChangeEvent { + range: Some(lsp::Range::new( + point_to_lsp(edit_start), + point_to_lsp(edit_end), + )), + range_length: None, + text: new_text, + } + }) + .collect::>() + } + }) + .await; + + copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| { + let server = copilot.server.as_authenticated().log_err()?; + let buffer = server.registered_buffers.get_mut(&id)?; + if !content_changes.is_empty() { + buffer.snapshot_version += 1; + buffer.snapshot = new_snapshot; + server + .lsp + .notify::( + lsp::DidChangeTextDocumentParams { + text_document: lsp::VersionedTextDocumentIdentifier::new( + buffer.uri.clone(), + buffer.snapshot_version, + ), + content_changes, + }, + ) + .log_err(); + } + let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone())); + Some(()) + })?; + + Some(()) + }); + } + + done_rx + } +} + +#[derive(Debug)] pub struct Completion { + uuid: String, pub range: Range, pub text: String, } @@ -151,6 +276,7 @@ pub struct Copilot { http: Arc, node_runtime: Arc, server: CopilotServer, + buffers: HashMap>, } impl Entity for Copilot { @@ -212,12 +338,14 @@ impl Copilot { http, node_runtime, server: CopilotServer::Starting { task: start_task }, + buffers: Default::default(), } } else { Self { http, node_runtime, server: CopilotServer::Disabled, + buffers: Default::default(), } } } @@ -230,11 +358,12 @@ impl Copilot { let this = cx.add_model(|cx| Self { http: http.clone(), node_runtime: NodeRuntime::new(http, cx.background().clone()), - server: CopilotServer::Started { - server: Arc::new(server), - status: SignInStatus::Authorized, - subscriptions_by_buffer_id: Default::default(), - }, + server: CopilotServer::Running(RunningCopilotServer { + lsp: Arc::new(server), + sign_in_status: SignInStatus::Authorized, + registered_buffers: Default::default(), + }), + buffers: Default::default(), }); (this, fake_server) } @@ -286,6 +415,19 @@ impl Copilot { ) .detach(); + server + .request::(request::SetEditorInfoParams { + editor_info: request::EditorInfo { + name: "zed".into(), + version: env!("CARGO_PKG_VERSION").into(), + }, + editor_plugin_info: request::EditorPluginInfo { + name: "zed-copilot".into(), + version: "0.0.1".into(), + }, + }) + .await?; + anyhow::Ok((server, status)) }; @@ -294,11 +436,11 @@ impl Copilot { cx.notify(); match server { Ok((server, status)) => { - this.server = CopilotServer::Started { - server, - status: SignInStatus::SignedOut, - subscriptions_by_buffer_id: Default::default(), - }; + this.server = CopilotServer::Running(RunningCopilotServer { + lsp: server, + sign_in_status: SignInStatus::SignedOut, + registered_buffers: Default::default(), + }); this.update_sign_in_status(status, cx); } Err(error) => { @@ -311,8 +453,8 @@ impl Copilot { } fn sign_in(&mut self, cx: &mut ModelContext) -> Task> { - if let CopilotServer::Started { server, status, .. } = &mut self.server { - let task = match status { + if let CopilotServer::Running(server) = &mut self.server { + let task = match &server.sign_in_status { SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => { Task::ready(Ok(())).shared() } @@ -321,11 +463,11 @@ impl Copilot { task.clone() } SignInStatus::SignedOut => { - let server = server.clone(); + let lsp = server.lsp.clone(); let task = cx .spawn(|this, mut cx| async move { let sign_in = async { - let sign_in = server + let sign_in = lsp .request::( request::SignInInitiateParams {}, ) @@ -336,8 +478,10 @@ impl Copilot { } request::SignInInitiateResult::PromptUserDeviceFlow(flow) => { this.update(&mut cx, |this, cx| { - if let CopilotServer::Started { status, .. } = - &mut this.server + if let CopilotServer::Running(RunningCopilotServer { + sign_in_status: status, + .. + }) = &mut this.server { if let SignInStatus::SigningIn { prompt: prompt_flow, @@ -349,7 +493,7 @@ impl Copilot { } } }); - let response = server + let response = lsp .request::( request::SignInConfirmParams { user_code: flow.user_code, @@ -377,7 +521,7 @@ impl Copilot { }) }) .shared(); - *status = SignInStatus::SigningIn { + server.sign_in_status = SignInStatus::SigningIn { prompt: None, task: task.clone(), }; @@ -396,10 +540,8 @@ impl Copilot { } fn sign_out(&mut self, cx: &mut ModelContext) -> Task> { - if let CopilotServer::Started { server, status, .. } = &mut self.server { - *status = SignInStatus::SignedOut; - cx.notify(); - + self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx); + if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server { let server = server.clone(); cx.background().spawn(async move { server @@ -433,6 +575,135 @@ impl Copilot { cx.foreground().spawn(start_task) } + pub fn register_buffer(&mut self, buffer: &ModelHandle, cx: &mut ModelContext) { + let buffer_id = buffer.id(); + self.buffers.insert(buffer_id, buffer.downgrade()); + + if let CopilotServer::Running(RunningCopilotServer { + lsp: server, + sign_in_status: status, + registered_buffers, + .. + }) = &mut self.server + { + if !matches!(status, SignInStatus::Authorized { .. }) { + return; + } + + registered_buffers.entry(buffer.id()).or_insert_with(|| { + let uri: lsp::Url = uri_for_buffer(buffer, cx); + let language_id = id_for_language(buffer.read(cx).language()); + let snapshot = buffer.read(cx).snapshot(); + server + .notify::( + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem { + uri: uri.clone(), + language_id: language_id.clone(), + version: 0, + text: snapshot.text(), + }, + }, + ) + .log_err(); + + RegisteredBuffer { + id: buffer_id, + uri, + language_id, + snapshot, + snapshot_version: 0, + pending_buffer_change: Task::ready(Some(())), + _subscriptions: [ + cx.subscribe(buffer, |this, buffer, event, cx| { + this.handle_buffer_event(buffer, event, cx).log_err(); + }), + cx.observe_release(buffer, move |this, _buffer, _cx| { + this.buffers.remove(&buffer_id); + this.unregister_buffer(buffer_id); + }), + ], + } + }); + } + } + + fn handle_buffer_event( + &mut self, + buffer: ModelHandle, + event: &language::Event, + cx: &mut ModelContext, + ) -> Result<()> { + if let Ok(server) = self.server.as_running() { + if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.id()) { + match event { + language::Event::Edited => { + let _ = registered_buffer.report_changes(&buffer, cx); + } + language::Event::Saved => { + server + .lsp + .notify::( + lsp::DidSaveTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new( + registered_buffer.uri.clone(), + ), + text: None, + }, + )?; + } + language::Event::FileHandleChanged | language::Event::LanguageChanged => { + let new_language_id = id_for_language(buffer.read(cx).language()); + let new_uri = uri_for_buffer(&buffer, cx); + if new_uri != registered_buffer.uri + || new_language_id != registered_buffer.language_id + { + let old_uri = mem::replace(&mut registered_buffer.uri, new_uri); + registered_buffer.language_id = new_language_id; + server + .lsp + .notify::( + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(old_uri), + }, + )?; + server + .lsp + .notify::( + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( + registered_buffer.uri.clone(), + registered_buffer.language_id.clone(), + registered_buffer.snapshot_version, + registered_buffer.snapshot.text(), + ), + }, + )?; + } + } + _ => {} + } + } + } + + Ok(()) + } + + fn unregister_buffer(&mut self, buffer_id: usize) { + if let Ok(server) = self.server.as_running() { + if let Some(buffer) = server.registered_buffers.remove(&buffer_id) { + server + .lsp + .notify::( + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(buffer.uri), + }, + ) + .log_err(); + } + } + } + pub fn completions( &mut self, buffer: &ModelHandle, @@ -457,6 +728,51 @@ impl Copilot { self.request_completions::(buffer, position, cx) } + pub fn accept_completion( + &mut self, + completion: &Completion, + cx: &mut ModelContext, + ) -> Task> { + let server = match self.server.as_authenticated() { + Ok(server) => server, + Err(error) => return Task::ready(Err(error)), + }; + let request = + server + .lsp + .request::(request::NotifyAcceptedParams { + uuid: completion.uuid.clone(), + }); + cx.background().spawn(async move { + request.await?; + Ok(()) + }) + } + + pub fn discard_completions( + &mut self, + completions: &[Completion], + cx: &mut ModelContext, + ) -> Task> { + let server = match self.server.as_authenticated() { + Ok(server) => server, + Err(error) => return Task::ready(Err(error)), + }; + let request = + server + .lsp + .request::(request::NotifyRejectedParams { + uuids: completions + .iter() + .map(|completion| completion.uuid.clone()) + .collect(), + }); + cx.background().spawn(async move { + request.await?; + Ok(()) + }) + } + fn request_completions( &mut self, buffer: &ModelHandle, @@ -464,116 +780,48 @@ impl Copilot { cx: &mut ModelContext, ) -> Task>> where - R: lsp::request::Request< - Params = request::GetCompletionsParams, - Result = request::GetCompletionsResult, - >, + R: 'static + + lsp::request::Request< + Params = request::GetCompletionsParams, + Result = request::GetCompletionsResult, + >, T: ToPointUtf16, { - let buffer_id = buffer.id(); - let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap(); - let snapshot = buffer.read(cx).snapshot(); - let server = match &mut self.server { - CopilotServer::Starting { .. } => { - return Task::ready(Err(anyhow!("copilot is still starting"))) - } - CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))), - CopilotServer::Error(error) => { - return Task::ready(Err(anyhow!( - "copilot was not started because of an error: {}", - error - ))) - } - CopilotServer::Started { - server, - status, - subscriptions_by_buffer_id, - } => { - if matches!(status, SignInStatus::Authorized { .. }) { - subscriptions_by_buffer_id - .entry(buffer_id) - .or_insert_with(|| { - server - .notify::( - lsp::DidOpenTextDocumentParams { - text_document: lsp::TextDocumentItem { - uri: uri.clone(), - language_id: id_for_language( - buffer.read(cx).language(), - ), - version: 0, - text: snapshot.text(), - }, - }, - ) - .log_err(); - - let uri = uri.clone(); - cx.observe_release(buffer, move |this, _, _| { - if let CopilotServer::Started { - server, - subscriptions_by_buffer_id, - .. - } = &mut this.server - { - server - .notify::( - lsp::DidCloseTextDocumentParams { - text_document: lsp::TextDocumentIdentifier::new( - uri.clone(), - ), - }, - ) - .log_err(); - subscriptions_by_buffer_id.remove(&buffer_id); - } - }) - }); + self.register_buffer(buffer, cx); - server.clone() - } else { - return Task::ready(Err(anyhow!("must sign in before using copilot"))); - } - } + let server = match self.server.as_authenticated() { + Ok(server) => server, + Err(error) => return Task::ready(Err(error)), }; - + let lsp = server.lsp.clone(); + let registered_buffer = server.registered_buffers.get_mut(&buffer.id()).unwrap(); + let snapshot = registered_buffer.report_changes(buffer, cx); + let buffer = buffer.read(cx); + let uri = registered_buffer.uri.clone(); let settings = cx.global::(); - let position = position.to_point_utf16(&snapshot); - let language = snapshot.language_at(position); + let position = position.to_point_utf16(buffer); + let language = buffer.language_at(position); let language_name = language.map(|language| language.name()); let language_name = language_name.as_deref(); let tab_size = settings.tab_size(language_name); let hard_tabs = settings.hard_tabs(language_name); - let language_id = id_for_language(language); - - let path; - let relative_path; - if let Some(file) = snapshot.file() { - if let Some(file) = file.as_local() { - path = file.abs_path(cx); - } else { - path = file.full_path(cx); - } - relative_path = file.path().to_path_buf(); - } else { - path = PathBuf::new(); - relative_path = PathBuf::new(); - } + let relative_path = buffer + .file() + .map(|file| file.path().to_path_buf()) + .unwrap_or_default(); - cx.background().spawn(async move { - let result = server + cx.foreground().spawn(async move { + let (version, snapshot) = snapshot.await?; + let result = lsp .request::(request::GetCompletionsParams { doc: request::GetCompletionsDocument { - source: snapshot.text(), + uri, tab_size: tab_size.into(), indent_size: 1, insert_spaces: !hard_tabs, - uri, - path: path.to_string_lossy().into(), relative_path: relative_path.to_string_lossy().into(), - language_id, position: point_to_lsp(position), - version: 0, + version: version.try_into().unwrap(), }, }) .await?; @@ -586,6 +834,7 @@ impl Copilot { let end = snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left); Completion { + uuid: completion.uuid, range: snapshot.anchor_before(start)..snapshot.anchor_after(end), text: completion.text, } @@ -600,14 +849,16 @@ impl Copilot { CopilotServer::Starting { task } => Status::Starting { task: task.clone() }, CopilotServer::Disabled => Status::Disabled, CopilotServer::Error(error) => Status::Error(error.clone()), - CopilotServer::Started { status, .. } => match status { - SignInStatus::Authorized { .. } => Status::Authorized, - SignInStatus::Unauthorized { .. } => Status::Unauthorized, - SignInStatus::SigningIn { prompt, .. } => Status::SigningIn { - prompt: prompt.clone(), - }, - SignInStatus::SignedOut => Status::SignedOut, - }, + CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => { + match sign_in_status { + SignInStatus::Authorized { .. } => Status::Authorized, + SignInStatus::Unauthorized { .. } => Status::Unauthorized, + SignInStatus::SigningIn { prompt, .. } => Status::SigningIn { + prompt: prompt.clone(), + }, + SignInStatus::SignedOut => Status::SignedOut, + } + } } } @@ -616,14 +867,34 @@ impl Copilot { lsp_status: request::SignInStatus, cx: &mut ModelContext, ) { - if let CopilotServer::Started { status, .. } = &mut self.server { - *status = match lsp_status { + self.buffers.retain(|_, buffer| buffer.is_upgradable(cx)); + + if let Ok(server) = self.server.as_running() { + match lsp_status { request::SignInStatus::Ok { .. } | request::SignInStatus::MaybeOk { .. } - | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized, - request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized, - request::SignInStatus::NotSignedIn => SignInStatus::SignedOut, - }; + | request::SignInStatus::AlreadySignedIn { .. } => { + server.sign_in_status = SignInStatus::Authorized; + for buffer in self.buffers.values().cloned().collect::>() { + if let Some(buffer) = buffer.upgrade(cx) { + self.register_buffer(&buffer, cx); + } + } + } + request::SignInStatus::NotAuthorized { .. } => { + server.sign_in_status = SignInStatus::Unauthorized; + for buffer_id in self.buffers.keys().copied().collect::>() { + self.unregister_buffer(buffer_id); + } + } + request::SignInStatus::NotSignedIn => { + server.sign_in_status = SignInStatus::SignedOut; + for buffer_id in self.buffers.keys().copied().collect::>() { + self.unregister_buffer(buffer_id); + } + } + } + cx.notify(); } } @@ -638,6 +909,14 @@ fn id_for_language(language: Option<&Arc>) -> String { } } +fn uri_for_buffer(buffer: &ModelHandle, cx: &AppContext) -> lsp::Url { + if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) { + lsp::Url::from_file_path(file.abs_path(cx)).unwrap() + } else { + format!("buffer://{}", buffer.id()).parse().unwrap() + } +} + async fn clear_copilot_dir() { remove_matching(&paths::COPILOT_DIR, |_| true).await } @@ -709,3 +988,226 @@ async fn get_copilot_lsp(http: Arc) -> anyhow::Result { } } } + +#[cfg(test)] +mod tests { + use super::*; + use gpui::{executor::Deterministic, TestAppContext}; + + #[gpui::test(iterations = 10)] + async fn test_buffer_management(deterministic: Arc, cx: &mut TestAppContext) { + deterministic.forbid_parking(); + let (copilot, mut lsp) = Copilot::fake(cx); + + let buffer_1 = cx.add_model(|cx| Buffer::new(0, "Hello", cx)); + let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.id()).parse().unwrap(); + copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx)); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( + buffer_1_uri.clone(), + "plaintext".into(), + 0, + "Hello".into() + ), + } + ); + + let buffer_2 = cx.add_model(|cx| Buffer::new(0, "Goodbye", cx)); + let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.id()).parse().unwrap(); + copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx)); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( + buffer_2_uri.clone(), + "plaintext".into(), + 0, + "Goodbye".into() + ), + } + ); + + buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx)); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidChangeTextDocumentParams { + text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1), + content_changes: vec![lsp::TextDocumentContentChangeEvent { + range: Some(lsp::Range::new( + lsp::Position::new(0, 5), + lsp::Position::new(0, 5) + )), + range_length: None, + text: " world".into(), + }], + } + ); + + // Ensure updates to the file are reflected in the LSP. + buffer_1 + .update(cx, |buffer, cx| { + buffer.file_updated( + Arc::new(File { + abs_path: "/root/child/buffer-1".into(), + path: Path::new("child/buffer-1").into(), + }), + cx, + ) + }) + .await; + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri), + } + ); + let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap(); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( + buffer_1_uri.clone(), + "plaintext".into(), + 1, + "Hello world".into() + ), + } + ); + + // Ensure all previously-registered buffers are closed when signing out. + lsp.handle_request::(|_, _| async { + Ok(request::SignOutResult {}) + }); + copilot + .update(cx, |copilot, cx| copilot.sign_out(cx)) + .await + .unwrap(); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()), + } + ); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()), + } + ); + + // Ensure all previously-registered buffers are re-opened when signing in. + lsp.handle_request::(|_, _| async { + Ok(request::SignInInitiateResult::AlreadySignedIn { + user: "user-1".into(), + }) + }); + copilot + .update(cx, |copilot, cx| copilot.sign_in(cx)) + .await + .unwrap(); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( + buffer_2_uri.clone(), + "plaintext".into(), + 0, + "Goodbye".into() + ), + } + ); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( + buffer_1_uri.clone(), + "plaintext".into(), + 0, + "Hello world".into() + ), + } + ); + + // Dropping a buffer causes it to be closed on the LSP side as well. + cx.update(|_| drop(buffer_2)); + assert_eq!( + lsp.receive_notification::() + .await, + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri), + } + ); + } + + struct File { + abs_path: PathBuf, + path: Arc, + } + + impl language::File for File { + fn as_local(&self) -> Option<&dyn language::LocalFile> { + Some(self) + } + + fn mtime(&self) -> std::time::SystemTime { + todo!() + } + + fn path(&self) -> &Arc { + &self.path + } + + fn full_path(&self, _: &AppContext) -> PathBuf { + todo!() + } + + fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr { + todo!() + } + + fn is_deleted(&self) -> bool { + todo!() + } + + fn as_any(&self) -> &dyn std::any::Any { + todo!() + } + + fn to_proto(&self) -> rpc::proto::File { + todo!() + } + } + + impl language::LocalFile for File { + fn abs_path(&self, _: &AppContext) -> PathBuf { + self.abs_path.clone() + } + + fn load(&self, _: &AppContext) -> Task> { + todo!() + } + + fn buffer_reloaded( + &self, + _: u64, + _: &clock::Global, + _: language::RopeFingerprint, + _: ::fs::LineEnding, + _: std::time::SystemTime, + _: &mut AppContext, + ) { + todo!() + } + } +} diff --git a/crates/copilot/src/request.rs b/crates/copilot/src/request.rs index 415f160ea3a9335d8272c2122e2db285f3f16290..43b5109d027dd2bcc62d1e9a5cb48a5996e84fd7 100644 --- a/crates/copilot/src/request.rs +++ b/crates/copilot/src/request.rs @@ -99,14 +99,11 @@ pub struct GetCompletionsParams { #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GetCompletionsDocument { - pub source: String, pub tab_size: u32, pub indent_size: u32, pub insert_spaces: bool, pub uri: lsp::Url, - pub path: String, pub relative_path: String, - pub language_id: String, pub position: lsp::Position, pub version: usize, } @@ -169,3 +166,60 @@ impl lsp::notification::Notification for StatusNotification { type Params = StatusNotificationParams; const METHOD: &'static str = "statusNotification"; } + +pub enum SetEditorInfo {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SetEditorInfoParams { + pub editor_info: EditorInfo, + pub editor_plugin_info: EditorPluginInfo, +} + +impl lsp::request::Request for SetEditorInfo { + type Params = SetEditorInfoParams; + type Result = String; + const METHOD: &'static str = "setEditorInfo"; +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct EditorInfo { + pub name: String, + pub version: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct EditorPluginInfo { + pub name: String, + pub version: String, +} + +pub enum NotifyAccepted {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct NotifyAcceptedParams { + pub uuid: String, +} + +impl lsp::request::Request for NotifyAccepted { + type Params = NotifyAcceptedParams; + type Result = String; + const METHOD: &'static str = "notifyAccepted"; +} + +pub enum NotifyRejected {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct NotifyRejectedParams { + pub uuids: Vec, +} + +impl lsp::request::Request for NotifyRejected { + type Params = NotifyRejectedParams; + type Result = String; + const METHOD: &'static str = "notifyRejected"; +} diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 742541293e96346919f995bec8641ec347de2025..59f8ab6acb434b24d41c048903d37f1041491779 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -52,7 +52,7 @@ pub use language::{char_kind, CharKind}; use language::{ AutoindentMode, BracketPair, Buffer, CodeAction, CodeLabel, Completion, CursorShape, Diagnostic, DiagnosticSeverity, IndentKind, IndentSize, Language, OffsetRangeExt, OffsetUtf16, - Point, Rope, Selection, SelectionGoal, TransactionId, + Point, Selection, SelectionGoal, TransactionId, }; use link_go_to_definition::{ hide_link_definition, show_link_definition, LinkDefinitionKind, LinkGoToDefinitionState, @@ -1037,6 +1037,10 @@ impl Default for CopilotState { } impl CopilotState { + fn active_completion(&self) -> Option<&copilot::Completion> { + self.completions.get(self.active_completion_index) + } + fn text_for_active_completion( &self, cursor: Anchor, @@ -1044,7 +1048,7 @@ impl CopilotState { ) -> Option<&str> { use language::ToOffset as _; - let completion = self.completions.get(self.active_completion_index)?; + let completion = self.active_completion()?; let excerpt_id = self.excerpt_id?; let completion_buffer = buffer.buffer_for_excerpt(excerpt_id)?; if excerpt_id != cursor.excerpt_id @@ -1097,7 +1101,7 @@ impl CopilotState { fn push_completion(&mut self, new_completion: copilot::Completion) { for completion in &self.completions { - if *completion == new_completion { + if completion.text == new_completion.text && completion.range == new_completion.range { return; } } @@ -1496,7 +1500,7 @@ impl Editor { self.refresh_code_actions(cx); self.refresh_document_highlights(cx); refresh_matching_bracket_highlights(self, cx); - self.hide_copilot_suggestion(cx); + self.discard_copilot_suggestion(cx); } self.blink_manager.update(cx, BlinkManager::pause_blinking); @@ -1870,7 +1874,7 @@ impl Editor { return; } - if self.hide_copilot_suggestion(cx).is_some() { + if self.discard_copilot_suggestion(cx) { return; } @@ -2969,7 +2973,7 @@ impl Editor { Some(()) } - fn cycle_suggestions( + fn cycle_copilot_suggestions( &mut self, direction: Direction, cx: &mut ViewContext, @@ -3020,7 +3024,7 @@ impl Editor { fn next_copilot_suggestion(&mut self, _: &copilot::NextSuggestion, cx: &mut ViewContext) { if self.has_active_copilot_suggestion(cx) { - self.cycle_suggestions(Direction::Next, cx); + self.cycle_copilot_suggestions(Direction::Next, cx); } else { self.refresh_copilot_suggestions(false, cx); } @@ -3032,37 +3036,55 @@ impl Editor { cx: &mut ViewContext, ) { if self.has_active_copilot_suggestion(cx) { - self.cycle_suggestions(Direction::Prev, cx); + self.cycle_copilot_suggestions(Direction::Prev, cx); } else { self.refresh_copilot_suggestions(false, cx); } } fn accept_copilot_suggestion(&mut self, cx: &mut ViewContext) -> bool { - if let Some(text) = self.hide_copilot_suggestion(cx) { - self.insert_with_autoindent_mode(&text.to_string(), None, cx); + if let Some(suggestion) = self + .display_map + .update(cx, |map, cx| map.replace_suggestion::(None, cx)) + { + if let Some((copilot, completion)) = + Copilot::global(cx).zip(self.copilot_state.active_completion()) + { + copilot + .update(cx, |copilot, cx| copilot.accept_completion(completion, cx)) + .detach_and_log_err(cx); + } + self.insert_with_autoindent_mode(&suggestion.text.to_string(), None, cx); + cx.notify(); true } else { false } } - fn has_active_copilot_suggestion(&self, cx: &AppContext) -> bool { - self.display_map.read(cx).has_suggestion() - } - - fn hide_copilot_suggestion(&mut self, cx: &mut ViewContext) -> Option { + fn discard_copilot_suggestion(&mut self, cx: &mut ViewContext) -> bool { if self.has_active_copilot_suggestion(cx) { - let old_suggestion = self - .display_map + if let Some(copilot) = Copilot::global(cx) { + copilot + .update(cx, |copilot, cx| { + copilot.discard_completions(&self.copilot_state.completions, cx) + }) + .detach_and_log_err(cx); + } + + self.display_map .update(cx, |map, cx| map.replace_suggestion::(None, cx)); cx.notify(); - old_suggestion.map(|suggestion| suggestion.text) + true } else { - None + false } } + fn has_active_copilot_suggestion(&self, cx: &AppContext) -> bool { + self.display_map.read(cx).has_suggestion() + } + fn update_visible_copilot_suggestion(&mut self, cx: &mut ViewContext) { let snapshot = self.buffer.read(cx).snapshot(cx); let selection = self.selections.newest_anchor(); @@ -3072,7 +3094,7 @@ impl Editor { || !self.completion_tasks.is_empty() || selection.start != selection.end { - self.hide_copilot_suggestion(cx); + self.discard_copilot_suggestion(cx); } else if let Some(text) = self .copilot_state .text_for_active_completion(cursor, &snapshot) @@ -3088,13 +3110,13 @@ impl Editor { }); cx.notify(); } else { - self.hide_copilot_suggestion(cx); + self.discard_copilot_suggestion(cx); } } fn clear_copilot_suggestions(&mut self, cx: &mut ViewContext) { self.copilot_state = Default::default(); - self.hide_copilot_suggestion(cx); + self.discard_copilot_suggestion(cx); } pub fn render_code_actions_indicator( @@ -3212,7 +3234,7 @@ impl Editor { self.completion_tasks.clear(); } self.context_menu = Some(menu); - self.hide_copilot_suggestion(cx); + self.discard_copilot_suggestion(cx); cx.notify(); } @@ -6643,6 +6665,7 @@ impl Editor { multi_buffer::Event::DiagnosticsUpdated => { self.refresh_active_diagnostics(cx); } + multi_buffer::Event::LanguageChanged => {} } } diff --git a/crates/editor/src/multi_buffer.rs b/crates/editor/src/multi_buffer.rs index f8a56557abb633f832ce2d23dcba264fc0535fa7..824c108e46806f77d03002311b0cfc66fa67a20b 100644 --- a/crates/editor/src/multi_buffer.rs +++ b/crates/editor/src/multi_buffer.rs @@ -64,6 +64,7 @@ pub enum Event { }, Edited, Reloaded, + LanguageChanged, Reparsed, Saved, FileHandleChanged, @@ -1302,6 +1303,7 @@ impl MultiBuffer { language::Event::Saved => Event::Saved, language::Event::FileHandleChanged => Event::FileHandleChanged, language::Event::Reloaded => Event::Reloaded, + language::Event::LanguageChanged => Event::LanguageChanged, language::Event::Reparsed => Event::Reparsed, language::Event::DiagnosticsUpdated => Event::DiagnosticsUpdated, language::Event::Closed => Event::Closed, diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index fb16d416409523261aeb010674a6e94473e6bbc9..7325ca9af53dcee7d163533305f18bf81e48d61f 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -187,6 +187,7 @@ pub enum Event { Saved, FileHandleChanged, Reloaded, + LanguageChanged, Reparsed, DiagnosticsUpdated, Closed, @@ -536,6 +537,7 @@ impl Buffer { self.syntax_map.lock().clear(); self.language = language; self.reparse(cx); + cx.emit(Event::LanguageChanged); } pub fn set_language_registry(&mut self, language_registry: Arc) { diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index f5c144a3adb915c61ae52af38fa78cc961b00639..e30ab56e45d1b088aa5ce811dfe246094c00c413 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -19,6 +19,7 @@ test-support = [ [dependencies] text = { path = "../text" } +copilot = { path = "../copilot" } client = { path = "../client" } clock = { path = "../clock" } collections = { path = "../collections" } diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index d126cb4994b86f592004290de29923900cca1a14..d5b7ac3f3f80afb4c85b1776e974587c35e9913d 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -12,6 +12,7 @@ use anyhow::{anyhow, Context, Result}; use client::{proto, Client, TypedEnvelope, UserStore}; use clock::ReplicaId; use collections::{hash_map, BTreeMap, HashMap, HashSet}; +use copilot::Copilot; use futures::{ channel::mpsc::{self, UnboundedReceiver}, future::{try_join_all, Shared}, @@ -129,6 +130,7 @@ pub struct Project { _maintain_buffer_languages: Task<()>, _maintain_workspace_config: Task<()>, terminals: Terminals, + copilot_enabled: bool, } enum BufferMessage { @@ -472,6 +474,7 @@ impl Project { terminals: Terminals { local_handles: Vec::new(), }, + copilot_enabled: Copilot::global(cx).is_some(), } }) } @@ -559,6 +562,7 @@ impl Project { terminals: Terminals { local_handles: Vec::new(), }, + copilot_enabled: Copilot::global(cx).is_some(), }; for worktree in worktrees { let _ = this.add_worktree(&worktree, cx); @@ -664,6 +668,15 @@ impl Project { self.start_language_server(worktree_id, worktree_path, language, cx); } + if !self.copilot_enabled && Copilot::global(cx).is_some() { + self.copilot_enabled = true; + for buffer in self.opened_buffers.values() { + if let Some(buffer) = buffer.upgrade(cx) { + self.register_buffer_with_copilot(&buffer, cx); + } + } + } + cx.notify(); } @@ -1616,6 +1629,7 @@ impl Project { self.detect_language_for_buffer(buffer, cx); self.register_buffer_with_language_server(buffer, cx); + self.register_buffer_with_copilot(buffer, cx); cx.observe_release(buffer, |this, buffer, cx| { if let Some(file) = File::from_dyn(buffer.file()) { if file.is_local() { @@ -1731,6 +1745,16 @@ impl Project { }); } + fn register_buffer_with_copilot( + &self, + buffer_handle: &ModelHandle, + cx: &mut ModelContext, + ) { + if let Some(copilot) = Copilot::global(cx) { + copilot.update(cx, |copilot, cx| copilot.register_buffer(buffer_handle, cx)); + } + } + async fn send_buffer_messages( this: WeakModelHandle, rx: UnboundedReceiver,