diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 4b67bc16504b454348b9675d3d7cd1b3eeb18bf9..28bc7977c101e1ba51933d3408ef5e98010fb746 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -104,15 +104,38 @@ pub fn init(http: Arc, node_runtime: Arc, cx: &mut enum CopilotServer { Disabled, - Starting { - task: Shared>, - }, + Starting { task: Shared> }, Error(Arc), - Started { - server: Arc, - status: SignInStatus, - registered_buffers: HashMap, - }, + Running(RunningCopilotServer), +} + +impl CopilotServer { + fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> { + let server = self.as_running()?; + if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) { + Ok(server) + } else { + Err(anyhow!("must sign in before using copilot")) + } + } + + fn as_running(&mut self) -> Result<&mut RunningCopilotServer> { + match self { + CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")), + CopilotServer::Disabled => Err(anyhow!("copilot is disabled")), + CopilotServer::Error(error) => Err(anyhow!( + "copilot was not started because of an error: {}", + error + )), + CopilotServer::Running(server) => Ok(server), + } + } +} + +struct RunningCopilotServer { + lsp: Arc, + sign_in_status: SignInStatus, + registered_buffers: HashMap, } #[derive(Clone, Debug)] @@ -293,11 +316,11 @@ impl Copilot { let this = cx.add_model(|cx| Self { http: http.clone(), node_runtime: NodeRuntime::new(http, cx.background().clone()), - server: CopilotServer::Started { - server: Arc::new(server), - status: SignInStatus::Authorized, + server: CopilotServer::Running(RunningCopilotServer { + lsp: Arc::new(server), + sign_in_status: SignInStatus::Authorized, registered_buffers: Default::default(), - }, + }), buffers: Default::default(), }); (this, fake_server) @@ -371,11 +394,11 @@ impl Copilot { cx.notify(); match server { Ok((server, status)) => { - this.server = CopilotServer::Started { - server, - status: SignInStatus::SignedOut, + this.server = CopilotServer::Running(RunningCopilotServer { + lsp: server, + sign_in_status: SignInStatus::SignedOut, registered_buffers: Default::default(), - }; + }); this.update_sign_in_status(status, cx); } Err(error) => { @@ -388,8 +411,8 @@ impl Copilot { } fn sign_in(&mut self, cx: &mut ModelContext) -> Task> { - if let CopilotServer::Started { server, status, .. } = &mut self.server { - let task = match status { + if let CopilotServer::Running(server) = &mut self.server { + let task = match &server.sign_in_status { SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => { Task::ready(Ok(())).shared() } @@ -398,11 +421,11 @@ impl Copilot { task.clone() } SignInStatus::SignedOut => { - let server = server.clone(); + let lsp = server.lsp.clone(); let task = cx .spawn(|this, mut cx| async move { let sign_in = async { - let sign_in = server + let sign_in = lsp .request::( request::SignInInitiateParams {}, ) @@ -413,8 +436,10 @@ impl Copilot { } request::SignInInitiateResult::PromptUserDeviceFlow(flow) => { this.update(&mut cx, |this, cx| { - if let CopilotServer::Started { status, .. } = - &mut this.server + if let CopilotServer::Running(RunningCopilotServer { + sign_in_status: status, + .. + }) = &mut this.server { if let SignInStatus::SigningIn { prompt: prompt_flow, @@ -426,7 +451,7 @@ impl Copilot { } } }); - let response = server + let response = lsp .request::( request::SignInConfirmParams { user_code: flow.user_code, @@ -454,7 +479,7 @@ impl Copilot { }) }) .shared(); - *status = SignInStatus::SigningIn { + server.sign_in_status = SignInStatus::SigningIn { prompt: None, task: task.clone(), }; @@ -474,7 +499,7 @@ impl Copilot { fn sign_out(&mut self, cx: &mut ModelContext) -> Task> { self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx); - if let CopilotServer::Started { server, .. } = &self.server { + if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server { let server = server.clone(); cx.background().spawn(async move { server @@ -512,12 +537,12 @@ impl Copilot { let buffer_id = buffer.id(); self.buffers.insert(buffer_id, buffer.downgrade()); - if let CopilotServer::Started { - server, - status, + if let CopilotServer::Running(RunningCopilotServer { + lsp: server, + sign_in_status: status, registered_buffers, .. - } = &mut self.server + }) = &mut self.server { if !matches!(status, SignInStatus::Authorized { .. }) { return; @@ -565,26 +590,23 @@ impl Copilot { event: &language::Event, cx: &mut ModelContext, ) -> Result<()> { - if let CopilotServer::Started { - server, - registered_buffers, - .. - } = &mut self.server - { - if let Some(registered_buffer) = registered_buffers.get_mut(&buffer.id()) { + if let Ok(server) = self.server.as_running() { + if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.id()) { match event { language::Event::Edited => { - registered_buffer.report_changes(&buffer, server, cx)?; + registered_buffer.report_changes(&buffer, &server.lsp, cx)?; } language::Event::Saved => { - server.notify::( - lsp::DidSaveTextDocumentParams { - text_document: lsp::TextDocumentIdentifier::new( - registered_buffer.uri.clone(), - ), - text: None, - }, - )?; + server + .lsp + .notify::( + lsp::DidSaveTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new( + registered_buffer.uri.clone(), + ), + text: None, + }, + )?; } language::Event::FileHandleChanged | language::Event::LanguageChanged => { let new_language_id = id_for_language(buffer.read(cx).language()); @@ -594,21 +616,25 @@ impl Copilot { { let old_uri = mem::replace(&mut registered_buffer.uri, new_uri); registered_buffer.language_id = new_language_id; - server.notify::( - lsp::DidCloseTextDocumentParams { - text_document: lsp::TextDocumentIdentifier::new(old_uri), - }, - )?; - server.notify::( - lsp::DidOpenTextDocumentParams { - text_document: lsp::TextDocumentItem::new( - registered_buffer.uri.clone(), - registered_buffer.language_id.clone(), - registered_buffer.snapshot_version, - registered_buffer.snapshot.text(), - ), - }, - )?; + server + .lsp + .notify::( + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(old_uri), + }, + )?; + server + .lsp + .notify::( + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( + registered_buffer.uri.clone(), + registered_buffer.language_id.clone(), + registered_buffer.snapshot_version, + registered_buffer.snapshot.text(), + ), + }, + )?; } } _ => {} @@ -620,14 +646,10 @@ impl Copilot { } fn unregister_buffer(&mut self, buffer_id: usize) { - if let CopilotServer::Started { - server, - registered_buffers, - .. - } = &mut self.server - { - if let Some(buffer) = registered_buffers.remove(&buffer_id) { + if let Ok(server) = self.server.as_running() { + if let Some(buffer) = server.registered_buffers.remove(&buffer_id) { server + .lsp .notify::( lsp::DidCloseTextDocumentParams { text_document: lsp::TextDocumentIdentifier::new(buffer.uri), @@ -677,34 +699,15 @@ impl Copilot { T: ToPointUtf16, { self.register_buffer(buffer, cx); - let (server, registered_buffer) = match &mut self.server { - CopilotServer::Starting { .. } => { - return Task::ready(Err(anyhow!("copilot is still starting"))) - } - CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))), - CopilotServer::Error(error) => { - return Task::ready(Err(anyhow!( - "copilot was not started because of an error: {}", - error - ))) - } - CopilotServer::Started { - server, - status, - registered_buffers, - .. - } => { - if matches!(status, SignInStatus::Authorized { .. }) { - let registered_buffer = registered_buffers.get_mut(&buffer.id()).unwrap(); - if let Err(error) = registered_buffer.report_changes(buffer, &server, cx) { - return Task::ready(Err(error)); - } - (server.clone(), registered_buffer) - } else { - return Task::ready(Err(anyhow!("must sign in before using copilot"))); - } - } + + let server = match self.server.as_authenticated() { + Ok(server) => server, + Err(error) => return Task::ready(Err(error)), }; + let registered_buffer = server.registered_buffers.get_mut(&buffer.id()).unwrap(); + if let Err(error) = registered_buffer.report_changes(buffer, &server.lsp, cx) { + return Task::ready(Err(error)); + } let uri = registered_buffer.uri.clone(); let snapshot = registered_buffer.snapshot.clone(); @@ -720,7 +723,7 @@ impl Copilot { .file() .map(|file| file.path().to_path_buf()) .unwrap_or_default(); - let request = server.request::(request::GetCompletionsParams { + let request = server.lsp.request::(request::GetCompletionsParams { doc: request::GetCompletionsDocument { uri, tab_size: tab_size.into(), @@ -742,6 +745,7 @@ impl Copilot { let end = snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left); Completion { + uuid: completion.uuid, range: snapshot.anchor_before(start)..snapshot.anchor_after(end), text: completion.text, } @@ -756,14 +760,16 @@ impl Copilot { CopilotServer::Starting { task } => Status::Starting { task: task.clone() }, CopilotServer::Disabled => Status::Disabled, CopilotServer::Error(error) => Status::Error(error.clone()), - CopilotServer::Started { status, .. } => match status { - SignInStatus::Authorized { .. } => Status::Authorized, - SignInStatus::Unauthorized { .. } => Status::Unauthorized, - SignInStatus::SigningIn { prompt, .. } => Status::SigningIn { - prompt: prompt.clone(), - }, - SignInStatus::SignedOut => Status::SignedOut, - }, + CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => { + match sign_in_status { + SignInStatus::Authorized { .. } => Status::Authorized, + SignInStatus::Unauthorized { .. } => Status::Unauthorized, + SignInStatus::SigningIn { prompt, .. } => Status::SigningIn { + prompt: prompt.clone(), + }, + SignInStatus::SignedOut => Status::SignedOut, + } + } } } @@ -774,13 +780,12 @@ impl Copilot { ) { self.buffers.retain(|_, buffer| buffer.is_upgradable(cx)); - if let CopilotServer::Started { status, .. } = &mut self.server { + if let Ok(server) = self.server.as_running() { match lsp_status { request::SignInStatus::Ok { .. } | request::SignInStatus::MaybeOk { .. } | request::SignInStatus::AlreadySignedIn { .. } => { - *status = SignInStatus::Authorized; - + server.sign_in_status = SignInStatus::Authorized; for buffer in self.buffers.values().cloned().collect::>() { if let Some(buffer) = buffer.upgrade(cx) { self.register_buffer(&buffer, cx); @@ -788,15 +793,13 @@ impl Copilot { } } request::SignInStatus::NotAuthorized { .. } => { - *status = SignInStatus::Unauthorized; - + server.sign_in_status = SignInStatus::Unauthorized; for buffer_id in self.buffers.keys().copied().collect::>() { self.unregister_buffer(buffer_id); } } request::SignInStatus::NotSignedIn => { - *status = SignInStatus::SignedOut; - + server.sign_in_status = SignInStatus::SignedOut; for buffer_id in self.buffers.keys().copied().collect::>() { self.unregister_buffer(buffer_id); }