1pub mod copilot_chat;
2mod copilot_completion_provider;
3pub mod request;
4mod sign_in;
5
6use crate::sign_in::initiate_sign_in_within_workspace;
7use ::fs::Fs;
8use anyhow::{Context as _, Result, anyhow};
9use collections::{HashMap, HashSet};
10use command_palette_hooks::CommandPaletteFilter;
11use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared};
12use gpui::{
13 App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Task,
14 WeakEntity, actions,
15};
16use http_client::HttpClient;
17use language::language_settings::CopilotSettings;
18use language::{
19 Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, ToPointUtf16,
20 language_settings::{EditPredictionProvider, all_language_settings, language_settings},
21 point_from_lsp, point_to_lsp,
22};
23use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
24use node_runtime::NodeRuntime;
25use parking_lot::Mutex;
26use request::StatusNotification;
27use serde_json::json;
28use settings::SettingsStore;
29use sign_in::{reinstall_and_sign_in_within_workspace, sign_out_within_workspace};
30use std::collections::hash_map::Entry;
31use std::{
32 any::TypeId,
33 env,
34 ffi::OsString,
35 mem,
36 ops::Range,
37 path::{Path, PathBuf},
38 sync::Arc,
39};
40use util::{ResultExt, fs::remove_matching};
41use workspace::Workspace;
42
43pub use crate::copilot_completion_provider::CopilotCompletionProvider;
44pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in};
45
46actions!(
47 copilot,
48 [
49 /// Requests a code completion suggestion from Copilot.
50 Suggest,
51 /// Cycles to the next Copilot suggestion.
52 NextSuggestion,
53 /// Cycles to the previous Copilot suggestion.
54 PreviousSuggestion,
55 /// Reinstalls the Copilot language server.
56 Reinstall,
57 /// Signs in to GitHub Copilot.
58 SignIn,
59 /// Signs out of GitHub Copilot.
60 SignOut
61 ]
62);
63
64pub fn init(
65 new_server_id: LanguageServerId,
66 fs: Arc<dyn Fs>,
67 http: Arc<dyn HttpClient>,
68 node_runtime: NodeRuntime,
69 cx: &mut App,
70) {
71 let language_settings = all_language_settings(None, cx);
72 let configuration = copilot_chat::CopilotChatConfiguration {
73 enterprise_uri: language_settings
74 .edit_predictions
75 .copilot
76 .enterprise_uri
77 .clone(),
78 };
79 copilot_chat::init(fs.clone(), http.clone(), configuration, cx);
80
81 let copilot = cx.new({
82 let node_runtime = node_runtime.clone();
83 move |cx| Copilot::start(new_server_id, fs, node_runtime, cx)
84 });
85 Copilot::set_global(copilot.clone(), cx);
86 cx.observe(&copilot, |handle, cx| {
87 let copilot_action_types = [
88 TypeId::of::<Suggest>(),
89 TypeId::of::<NextSuggestion>(),
90 TypeId::of::<PreviousSuggestion>(),
91 TypeId::of::<Reinstall>(),
92 ];
93 let copilot_auth_action_types = [TypeId::of::<SignOut>()];
94 let copilot_no_auth_action_types = [TypeId::of::<SignIn>()];
95 let status = handle.read(cx).status();
96 let filter = CommandPaletteFilter::global_mut(cx);
97
98 match status {
99 Status::Disabled => {
100 filter.hide_action_types(&copilot_action_types);
101 filter.hide_action_types(&copilot_auth_action_types);
102 filter.hide_action_types(&copilot_no_auth_action_types);
103 }
104 Status::Authorized => {
105 filter.hide_action_types(&copilot_no_auth_action_types);
106 filter.show_action_types(
107 copilot_action_types
108 .iter()
109 .chain(&copilot_auth_action_types),
110 );
111 }
112 _ => {
113 filter.hide_action_types(&copilot_action_types);
114 filter.hide_action_types(&copilot_auth_action_types);
115 filter.show_action_types(copilot_no_auth_action_types.iter());
116 }
117 }
118 })
119 .detach();
120
121 cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
122 workspace.register_action(|workspace, _: &SignIn, window, cx| {
123 if let Some(copilot) = Copilot::global(cx) {
124 let is_reinstall = false;
125 initiate_sign_in_within_workspace(workspace, copilot, is_reinstall, window, cx);
126 }
127 });
128 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
129 if let Some(copilot) = Copilot::global(cx) {
130 reinstall_and_sign_in_within_workspace(workspace, copilot, window, cx);
131 }
132 });
133 workspace.register_action(|workspace, _: &SignOut, _window, cx| {
134 if let Some(copilot) = Copilot::global(cx) {
135 sign_out_within_workspace(workspace, copilot, cx);
136 }
137 });
138 })
139 .detach();
140}
141
142enum CopilotServer {
143 Disabled,
144 Starting { task: Shared<Task<()>> },
145 Error(Arc<str>),
146 Running(RunningCopilotServer),
147}
148
149impl CopilotServer {
150 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
151 let server = self.as_running()?;
152 anyhow::ensure!(
153 matches!(server.sign_in_status, SignInStatus::Authorized { .. }),
154 "must sign in before using copilot"
155 );
156 Ok(server)
157 }
158
159 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
160 match self {
161 CopilotServer::Starting { .. } => anyhow::bail!("copilot is still starting"),
162 CopilotServer::Disabled => anyhow::bail!("copilot is disabled"),
163 CopilotServer::Error(error) => {
164 anyhow::bail!("copilot was not started because of an error: {error}")
165 }
166 CopilotServer::Running(server) => Ok(server),
167 }
168 }
169}
170
/// State held while the Copilot language server process is up.
struct RunningCopilotServer {
    /// Handle to the Copilot language server.
    lsp: Arc<LanguageServer>,
    /// The user's current authentication state with this server.
    sign_in_status: SignInStatus,
    /// Buffers opened on the server via `textDocument/didOpen`, keyed by buffer entity id.
    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
}

/// Internal authentication state of a running Copilot server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Signed in; completions may be requested.
    Authorized,
    /// The server reported the user as not authorized.
    Unauthorized,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt for the user, filled in once the server provides it.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared task that resolves when the sign-in flow finishes.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Not signed in.
    SignedOut {
        /// Initialized from the `awaiting_sign_in_after_start` flag when the server starts;
        /// reset to `false` when transitioning into this state from another one.
        awaiting_signing_in: bool,
    },
}
189
190#[derive(Debug, Clone)]
191pub enum Status {
192 Starting {
193 task: Shared<Task<()>>,
194 },
195 Error(Arc<str>),
196 Disabled,
197 SignedOut {
198 awaiting_signing_in: bool,
199 },
200 SigningIn {
201 prompt: Option<request::PromptUserDeviceFlow>,
202 },
203 Unauthorized,
204 Authorized,
205}
206
207impl Status {
208 pub fn is_authorized(&self) -> bool {
209 matches!(self, Status::Authorized)
210 }
211
212 pub fn is_configured(&self) -> bool {
213 matches!(
214 self,
215 Status::Starting { .. }
216 | Status::Error(_)
217 | Status::SigningIn { .. }
218 | Status::Authorized
219 )
220 }
221}
222
/// Per-buffer state for a buffer opened on the Copilot server via `textDocument/didOpen`.
struct RegisteredBuffer {
    /// URI the buffer is registered under (built by `uri_for_buffer`).
    uri: lsp::Url,
    /// LSP language id sent to the server for this buffer.
    language_id: String,
    /// Most recent buffer snapshot whose contents have been sent to the server.
    snapshot: BufferSnapshot,
    /// LSP document version matching `snapshot`. It starts at 0 and is incremented for every
    /// `didChange` notification sent.
    snapshot_version: i32,
    /// Keep-alive handles for the buffer-event and buffer-release subscriptions.
    _subscriptions: [gpui::Subscription; 2],
    /// The latest change-reporting task. Each new task awaits the previous one so that
    /// `didChange` notifications are sent in order.
    pending_buffer_change: Task<Option<()>>,
}

impl RegisteredBuffer {
    /// Brings the server's copy of `buffer` up to date and returns a receiver that yields
    /// the `(version, snapshot)` pair registered for the buffer once this call's changes
    /// (if any) have been sent.
    ///
    /// If the buffer's version equals that of `snapshot`, the receiver is resolved
    /// immediately. Otherwise a task is chained after the pending one. That task computes
    /// incremental edits on a background thread and sends `didChange`. If the server is no
    /// longer authenticated, or the buffer is no longer registered or alive, the sender is
    /// dropped and the receiver reports cancellation.
    fn report_changes(
        &mut self,
        buffer: &Entity<Buffer>,
        cx: &mut Context<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            // Nothing changed since the last report; answer with the current state.
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Chain onto the previous report so notifications keep their order.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(async move |copilot, cx| {
                prev_pending_change.await;

                // Re-read the registered version: the previous task may have advanced it.
                let old_version = copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()).ok()?;

                let content_changes = cx
                    .background_spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // Changes are applied sequentially by the server, so the
                                    // replaced range starts at the edit's new start and spans
                                    // the extent of the old text.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    &lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .ok();
                        }
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
317
/// A completion suggestion returned by the Copilot language server.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned id, sent back when the completion is accepted or rejected.
    pub uuid: String,
    /// Buffer range the suggestion applies to, anchored in the snapshot it was computed for.
    pub range: Range<Anchor>,
    /// Suggested text for `range`.
    pub text: String,
}
324
/// Entity managing the Copilot language server.
///
/// It owns the server's lifecycle, the user's sign-in state, and the synchronization of
/// registered buffers with the server.
pub struct Copilot {
    /// Filesystem used when installing the language server package.
    fs: Arc<dyn Fs>,
    /// Node runtime used to install and run the language server.
    node_runtime: NodeRuntime,
    /// Current lifecycle state of the language server.
    server: CopilotServer,
    /// Buffers passed to `register_buffer`. They are (re)registered with the server when
    /// the user becomes authorized.
    buffers: HashSet<WeakEntity<Buffer>>,
    /// Id used for the language server when it is (re)started.
    server_id: LanguageServerId,
    /// App-quit subscription that shuts the language server down.
    _subscription: gpui::Subscription,
}

/// Events emitted by [`Copilot`].
pub enum Event {
    /// The language server finished starting.
    CopilotLanguageServerStarted,
    /// The user became signed in and authorized.
    CopilotAuthSignedIn,
    /// The user is now treated as signed out.
    CopilotAuthSignedOut,
}

impl EventEmitter<Event> for Copilot {}

/// Wrapper that stores the global [`Copilot`] entity in gpui's global state.
struct GlobalCopilot(Entity<Copilot>);

impl Global for GlobalCopilot {}
345
346impl Copilot {
347 pub fn global(cx: &App) -> Option<Entity<Self>> {
348 cx.try_global::<GlobalCopilot>()
349 .map(|model| model.0.clone())
350 }
351
352 pub fn set_global(copilot: Entity<Self>, cx: &mut App) {
353 cx.set_global(GlobalCopilot(copilot));
354 }
355
356 fn start(
357 new_server_id: LanguageServerId,
358 fs: Arc<dyn Fs>,
359 node_runtime: NodeRuntime,
360 cx: &mut Context<Self>,
361 ) -> Self {
362 let mut this = Self {
363 server_id: new_server_id,
364 fs,
365 node_runtime,
366 server: CopilotServer::Disabled,
367 buffers: Default::default(),
368 _subscription: cx.on_app_quit(Self::shutdown_language_server),
369 };
370 this.start_copilot(true, false, cx);
371 cx.observe_global::<SettingsStore>(move |this, cx| {
372 this.start_copilot(true, false, cx);
373 this.send_configuration_update(cx);
374 })
375 .detach();
376 this
377 }
378
379 fn shutdown_language_server(
380 &mut self,
381 _cx: &mut Context<Self>,
382 ) -> impl Future<Output = ()> + use<> {
383 let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
384 CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
385 _ => None,
386 };
387
388 async move {
389 if let Some(shutdown) = shutdown {
390 shutdown.await;
391 }
392 }
393 }
394
395 fn start_copilot(
396 &mut self,
397 check_edit_prediction_provider: bool,
398 awaiting_sign_in_after_start: bool,
399 cx: &mut Context<Self>,
400 ) {
401 if !matches!(self.server, CopilotServer::Disabled) {
402 return;
403 }
404 let language_settings = all_language_settings(None, cx);
405 if check_edit_prediction_provider
406 && language_settings.edit_predictions.provider != EditPredictionProvider::Copilot
407 {
408 return;
409 }
410 let server_id = self.server_id;
411 let fs = self.fs.clone();
412 let node_runtime = self.node_runtime.clone();
413 let env = self.build_env(&language_settings.edit_predictions.copilot);
414 let start_task = cx
415 .spawn(async move |this, cx| {
416 Self::start_language_server(
417 server_id,
418 fs,
419 node_runtime,
420 env,
421 this,
422 awaiting_sign_in_after_start,
423 cx,
424 )
425 .await
426 })
427 .shared();
428 self.server = CopilotServer::Starting { task: start_task };
429 cx.notify();
430 }
431
432 fn build_env(&self, copilot_settings: &CopilotSettings) -> Option<HashMap<String, String>> {
433 let proxy_url = copilot_settings.proxy.clone()?;
434 let no_verify = copilot_settings.proxy_no_verify;
435 let http_or_https_proxy = if proxy_url.starts_with("http:") {
436 Some("HTTP_PROXY")
437 } else if proxy_url.starts_with("https:") {
438 Some("HTTPS_PROXY")
439 } else {
440 log::error!(
441 "Unsupported protocol scheme for language server proxy (must be http or https)"
442 );
443 None
444 };
445
446 let mut env = HashMap::default();
447
448 if let Some(proxy_type) = http_or_https_proxy {
449 env.insert(proxy_type.to_string(), proxy_url);
450 if let Some(true) = no_verify {
451 env.insert("NODE_TLS_REJECT_UNAUTHORIZED".to_string(), "0".to_string());
452 };
453 }
454
455 if let Ok(oauth_token) = env::var(copilot_chat::COPILOT_OAUTH_ENV_VAR) {
456 env.insert(copilot_chat::COPILOT_OAUTH_ENV_VAR.to_string(), oauth_token);
457 }
458
459 if env.is_empty() { None } else { Some(env) }
460 }
461
462 fn send_configuration_update(&mut self, cx: &mut Context<Self>) {
463 let copilot_settings = all_language_settings(None, cx)
464 .edit_predictions
465 .copilot
466 .clone();
467
468 let settings = json!({
469 "http": {
470 "proxy": copilot_settings.proxy,
471 "proxyStrictSSL": !copilot_settings.proxy_no_verify.unwrap_or(false)
472 },
473 "github-enterprise": {
474 "uri": copilot_settings.enterprise_uri
475 }
476 });
477
478 if let Some(copilot_chat) = copilot_chat::CopilotChat::global(cx) {
479 copilot_chat.update(cx, |chat, cx| {
480 chat.set_configuration(
481 copilot_chat::CopilotChatConfiguration {
482 enterprise_uri: copilot_settings.enterprise_uri.clone(),
483 },
484 cx,
485 );
486 });
487 }
488
489 if let Ok(server) = self.server.as_running() {
490 server
491 .lsp
492 .notify::<lsp::notification::DidChangeConfiguration>(
493 &lsp::DidChangeConfigurationParams { settings },
494 )
495 .log_err();
496 }
497 }
498
499 #[cfg(any(test, feature = "test-support"))]
500 pub fn fake(cx: &mut gpui::TestAppContext) -> (Entity<Self>, lsp::FakeLanguageServer) {
501 use fs::FakeFs;
502 use lsp::FakeLanguageServer;
503 use node_runtime::NodeRuntime;
504
505 let (server, fake_server) = FakeLanguageServer::new(
506 LanguageServerId(0),
507 LanguageServerBinary {
508 path: "path/to/copilot".into(),
509 arguments: vec![],
510 env: None,
511 },
512 "copilot".into(),
513 Default::default(),
514 &mut cx.to_async(),
515 );
516 let node_runtime = NodeRuntime::unavailable();
517 let this = cx.new(|cx| Self {
518 server_id: LanguageServerId(0),
519 fs: FakeFs::new(cx.background_executor().clone()),
520 node_runtime,
521 server: CopilotServer::Running(RunningCopilotServer {
522 lsp: Arc::new(server),
523 sign_in_status: SignInStatus::Authorized,
524 registered_buffers: Default::default(),
525 }),
526 _subscription: cx.on_app_quit(Self::shutdown_language_server),
527 buffers: Default::default(),
528 });
529 (this, fake_server)
530 }
531
532 async fn start_language_server(
533 new_server_id: LanguageServerId,
534 fs: Arc<dyn Fs>,
535 node_runtime: NodeRuntime,
536 env: Option<HashMap<String, String>>,
537 this: WeakEntity<Self>,
538 awaiting_sign_in_after_start: bool,
539 cx: &mut AsyncApp,
540 ) {
541 let start_language_server = async {
542 let server_path = get_copilot_lsp(fs, node_runtime.clone()).await?;
543 let node_path = node_runtime.binary_path().await?;
544 let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
545 let binary = LanguageServerBinary {
546 path: node_path,
547 arguments,
548 env,
549 };
550
551 let root_path = if cfg!(target_os = "windows") {
552 Path::new("C:/")
553 } else {
554 Path::new("/")
555 };
556
557 let server_name = LanguageServerName("copilot".into());
558 let server = LanguageServer::new(
559 Arc::new(Mutex::new(None)),
560 new_server_id,
561 server_name,
562 binary,
563 root_path,
564 None,
565 Default::default(),
566 cx,
567 )?;
568
569 server
570 .on_notification::<StatusNotification, _>(|_, _| { /* Silence the notification */ })
571 .detach();
572
573 let configuration = lsp::DidChangeConfigurationParams {
574 settings: Default::default(),
575 };
576
577 let editor_info = request::SetEditorInfoParams {
578 editor_info: request::EditorInfo {
579 name: "zed".into(),
580 version: env!("CARGO_PKG_VERSION").into(),
581 },
582 editor_plugin_info: request::EditorPluginInfo {
583 name: "zed-copilot".into(),
584 version: "0.0.1".into(),
585 },
586 };
587 let editor_info_json = serde_json::to_value(&editor_info)?;
588
589 let server = cx
590 .update(|cx| {
591 let mut params = server.default_initialize_params(false, cx);
592 params.initialization_options = Some(editor_info_json);
593 server.initialize(params, configuration.into(), cx)
594 })?
595 .await?;
596
597 let status = server
598 .request::<request::CheckStatus>(request::CheckStatusParams {
599 local_checks_only: false,
600 })
601 .await
602 .into_response()
603 .context("copilot: check status")?;
604
605 anyhow::Ok((server, status))
606 };
607
608 let server = start_language_server.await;
609 this.update(cx, |this, cx| {
610 cx.notify();
611 match server {
612 Ok((server, status)) => {
613 this.server = CopilotServer::Running(RunningCopilotServer {
614 lsp: server,
615 sign_in_status: SignInStatus::SignedOut {
616 awaiting_signing_in: awaiting_sign_in_after_start,
617 },
618 registered_buffers: Default::default(),
619 });
620 cx.emit(Event::CopilotLanguageServerStarted);
621 this.update_sign_in_status(status, cx);
622 // Send configuration now that the LSP is fully started
623 this.send_configuration_update(cx);
624 }
625 Err(error) => {
626 this.server = CopilotServer::Error(error.to_string().into());
627 cx.notify()
628 }
629 }
630 })
631 .ok();
632 }
633
634 pub(crate) fn sign_in(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
635 if let CopilotServer::Running(server) = &mut self.server {
636 let task = match &server.sign_in_status {
637 SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
638 SignInStatus::SigningIn { task, .. } => {
639 cx.notify();
640 task.clone()
641 }
642 SignInStatus::SignedOut { .. } | SignInStatus::Unauthorized { .. } => {
643 let lsp = server.lsp.clone();
644 let task = cx
645 .spawn(async move |this, cx| {
646 let sign_in = async {
647 let sign_in = lsp
648 .request::<request::SignInInitiate>(
649 request::SignInInitiateParams {},
650 )
651 .await
652 .into_response()
653 .context("copilot sign-in")?;
654 match sign_in {
655 request::SignInInitiateResult::AlreadySignedIn { user } => {
656 Ok(request::SignInStatus::Ok { user: Some(user) })
657 }
658 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
659 this.update(cx, |this, cx| {
660 if let CopilotServer::Running(RunningCopilotServer {
661 sign_in_status: status,
662 ..
663 }) = &mut this.server
664 {
665 if let SignInStatus::SigningIn {
666 prompt: prompt_flow,
667 ..
668 } = status
669 {
670 *prompt_flow = Some(flow.clone());
671 cx.notify();
672 }
673 }
674 })?;
675 let response = lsp
676 .request::<request::SignInConfirm>(
677 request::SignInConfirmParams {
678 user_code: flow.user_code,
679 },
680 )
681 .await
682 .into_response()
683 .context("copilot: sign in confirm")?;
684 Ok(response)
685 }
686 }
687 };
688
689 let sign_in = sign_in.await;
690 this.update(cx, |this, cx| match sign_in {
691 Ok(status) => {
692 this.update_sign_in_status(status, cx);
693 Ok(())
694 }
695 Err(error) => {
696 this.update_sign_in_status(
697 request::SignInStatus::NotSignedIn,
698 cx,
699 );
700 Err(Arc::new(error))
701 }
702 })?
703 })
704 .shared();
705 server.sign_in_status = SignInStatus::SigningIn {
706 prompt: None,
707 task: task.clone(),
708 };
709 cx.notify();
710 task
711 }
712 };
713
714 cx.background_spawn(task.map_err(|err| anyhow!("{err:?}")))
715 } else {
716 // If we're downloading, wait until download is finished
717 // If we're in a stuck state, display to the user
718 Task::ready(Err(anyhow!("copilot hasn't started yet")))
719 }
720 }
721
722 pub(crate) fn sign_out(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
723 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
724 match &self.server {
725 CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) => {
726 let server = server.clone();
727 cx.background_spawn(async move {
728 server
729 .request::<request::SignOut>(request::SignOutParams {})
730 .await
731 .into_response()
732 .context("copilot: sign in confirm")?;
733 anyhow::Ok(())
734 })
735 }
736 CopilotServer::Disabled => cx.background_spawn(async {
737 clear_copilot_config_dir().await;
738 anyhow::Ok(())
739 }),
740 _ => Task::ready(Err(anyhow!("copilot hasn't started yet"))),
741 }
742 }
743
744 pub(crate) fn reinstall(&mut self, cx: &mut Context<Self>) -> Shared<Task<()>> {
745 let language_settings = all_language_settings(None, cx);
746 let env = self.build_env(&language_settings.edit_predictions.copilot);
747 let start_task = cx
748 .spawn({
749 let fs = self.fs.clone();
750 let node_runtime = self.node_runtime.clone();
751 let server_id = self.server_id;
752 async move |this, cx| {
753 clear_copilot_dir().await;
754 Self::start_language_server(server_id, fs, node_runtime, env, this, false, cx)
755 .await
756 }
757 })
758 .shared();
759
760 self.server = CopilotServer::Starting {
761 task: start_task.clone(),
762 };
763
764 cx.notify();
765
766 start_task
767 }
768
769 pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
770 if let CopilotServer::Running(server) = &self.server {
771 Some(&server.lsp)
772 } else {
773 None
774 }
775 }
776
777 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
778 let weak_buffer = buffer.downgrade();
779 self.buffers.insert(weak_buffer.clone());
780
781 if let CopilotServer::Running(RunningCopilotServer {
782 lsp: server,
783 sign_in_status: status,
784 registered_buffers,
785 ..
786 }) = &mut self.server
787 {
788 if !matches!(status, SignInStatus::Authorized { .. }) {
789 return;
790 }
791
792 let entry = registered_buffers.entry(buffer.entity_id());
793 if let Entry::Vacant(e) = entry {
794 let Ok(uri) = uri_for_buffer(buffer, cx) else {
795 return;
796 };
797 let language_id = id_for_language(buffer.read(cx).language());
798 let snapshot = buffer.read(cx).snapshot();
799 server
800 .notify::<lsp::notification::DidOpenTextDocument>(
801 &lsp::DidOpenTextDocumentParams {
802 text_document: lsp::TextDocumentItem {
803 uri: uri.clone(),
804 language_id: language_id.clone(),
805 version: 0,
806 text: snapshot.text(),
807 },
808 },
809 )
810 .ok();
811
812 e.insert(RegisteredBuffer {
813 uri,
814 language_id,
815 snapshot,
816 snapshot_version: 0,
817 pending_buffer_change: Task::ready(Some(())),
818 _subscriptions: [
819 cx.subscribe(buffer, |this, buffer, event, cx| {
820 this.handle_buffer_event(buffer, event, cx).log_err();
821 }),
822 cx.observe_release(buffer, move |this, _buffer, _cx| {
823 this.buffers.remove(&weak_buffer);
824 this.unregister_buffer(&weak_buffer);
825 }),
826 ],
827 });
828 }
829 }
830 }
831
832 fn handle_buffer_event(
833 &mut self,
834 buffer: Entity<Buffer>,
835 event: &language::BufferEvent,
836 cx: &mut Context<Self>,
837 ) -> Result<()> {
838 if let Ok(server) = self.server.as_running() {
839 if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
840 {
841 match event {
842 language::BufferEvent::Edited => {
843 drop(registered_buffer.report_changes(&buffer, cx));
844 }
845 language::BufferEvent::Saved => {
846 server
847 .lsp
848 .notify::<lsp::notification::DidSaveTextDocument>(
849 &lsp::DidSaveTextDocumentParams {
850 text_document: lsp::TextDocumentIdentifier::new(
851 registered_buffer.uri.clone(),
852 ),
853 text: None,
854 },
855 )?;
856 }
857 language::BufferEvent::FileHandleChanged
858 | language::BufferEvent::LanguageChanged => {
859 let new_language_id = id_for_language(buffer.read(cx).language());
860 let Ok(new_uri) = uri_for_buffer(&buffer, cx) else {
861 return Ok(());
862 };
863 if new_uri != registered_buffer.uri
864 || new_language_id != registered_buffer.language_id
865 {
866 let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
867 registered_buffer.language_id = new_language_id;
868 server
869 .lsp
870 .notify::<lsp::notification::DidCloseTextDocument>(
871 &lsp::DidCloseTextDocumentParams {
872 text_document: lsp::TextDocumentIdentifier::new(old_uri),
873 },
874 )?;
875 server
876 .lsp
877 .notify::<lsp::notification::DidOpenTextDocument>(
878 &lsp::DidOpenTextDocumentParams {
879 text_document: lsp::TextDocumentItem::new(
880 registered_buffer.uri.clone(),
881 registered_buffer.language_id.clone(),
882 registered_buffer.snapshot_version,
883 registered_buffer.snapshot.text(),
884 ),
885 },
886 )?;
887 }
888 }
889 _ => {}
890 }
891 }
892 }
893
894 Ok(())
895 }
896
897 fn unregister_buffer(&mut self, buffer: &WeakEntity<Buffer>) {
898 if let Ok(server) = self.server.as_running() {
899 if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
900 server
901 .lsp
902 .notify::<lsp::notification::DidCloseTextDocument>(
903 &lsp::DidCloseTextDocumentParams {
904 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
905 },
906 )
907 .ok();
908 }
909 }
910 }
911
912 pub fn completions<T>(
913 &mut self,
914 buffer: &Entity<Buffer>,
915 position: T,
916 cx: &mut Context<Self>,
917 ) -> Task<Result<Vec<Completion>>>
918 where
919 T: ToPointUtf16,
920 {
921 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
922 }
923
924 pub fn completions_cycling<T>(
925 &mut self,
926 buffer: &Entity<Buffer>,
927 position: T,
928 cx: &mut Context<Self>,
929 ) -> Task<Result<Vec<Completion>>>
930 where
931 T: ToPointUtf16,
932 {
933 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
934 }
935
936 pub fn accept_completion(
937 &mut self,
938 completion: &Completion,
939 cx: &mut Context<Self>,
940 ) -> Task<Result<()>> {
941 let server = match self.server.as_authenticated() {
942 Ok(server) => server,
943 Err(error) => return Task::ready(Err(error)),
944 };
945 let request =
946 server
947 .lsp
948 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
949 uuid: completion.uuid.clone(),
950 });
951 cx.background_spawn(async move {
952 request
953 .await
954 .into_response()
955 .context("copilot: notify accepted")?;
956 Ok(())
957 })
958 }
959
960 pub fn discard_completions(
961 &mut self,
962 completions: &[Completion],
963 cx: &mut Context<Self>,
964 ) -> Task<Result<()>> {
965 let server = match self.server.as_authenticated() {
966 Ok(server) => server,
967 Err(_) => return Task::ready(Ok(())),
968 };
969 let request =
970 server
971 .lsp
972 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
973 uuids: completions
974 .iter()
975 .map(|completion| completion.uuid.clone())
976 .collect(),
977 });
978 cx.background_spawn(async move {
979 request
980 .await
981 .into_response()
982 .context("copilot: notify rejected")?;
983 Ok(())
984 })
985 }
986
987 fn request_completions<R, T>(
988 &mut self,
989 buffer: &Entity<Buffer>,
990 position: T,
991 cx: &mut Context<Self>,
992 ) -> Task<Result<Vec<Completion>>>
993 where
994 R: 'static
995 + lsp::request::Request<
996 Params = request::GetCompletionsParams,
997 Result = request::GetCompletionsResult,
998 >,
999 T: ToPointUtf16,
1000 {
1001 self.register_buffer(buffer, cx);
1002
1003 let server = match self.server.as_authenticated() {
1004 Ok(server) => server,
1005 Err(error) => return Task::ready(Err(error)),
1006 };
1007 let lsp = server.lsp.clone();
1008 let registered_buffer = server
1009 .registered_buffers
1010 .get_mut(&buffer.entity_id())
1011 .unwrap();
1012 let snapshot = registered_buffer.report_changes(buffer, cx);
1013 let buffer = buffer.read(cx);
1014 let uri = registered_buffer.uri.clone();
1015 let position = position.to_point_utf16(buffer);
1016 let settings = language_settings(
1017 buffer.language_at(position).map(|l| l.name()),
1018 buffer.file(),
1019 cx,
1020 );
1021 let tab_size = settings.tab_size;
1022 let hard_tabs = settings.hard_tabs;
1023 let relative_path = buffer
1024 .file()
1025 .map(|file| file.path().to_path_buf())
1026 .unwrap_or_default();
1027
1028 cx.background_spawn(async move {
1029 let (version, snapshot) = snapshot.await?;
1030 let result = lsp
1031 .request::<R>(request::GetCompletionsParams {
1032 doc: request::GetCompletionsDocument {
1033 uri,
1034 tab_size: tab_size.into(),
1035 indent_size: 1,
1036 insert_spaces: !hard_tabs,
1037 relative_path: relative_path.to_string_lossy().into(),
1038 position: point_to_lsp(position),
1039 version: version.try_into().unwrap(),
1040 },
1041 })
1042 .await
1043 .into_response()
1044 .context("copilot: get completions")?;
1045 let completions = result
1046 .completions
1047 .into_iter()
1048 .map(|completion| {
1049 let start = snapshot
1050 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
1051 let end =
1052 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
1053 Completion {
1054 uuid: completion.uuid,
1055 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
1056 text: completion.text,
1057 }
1058 })
1059 .collect();
1060 anyhow::Ok(completions)
1061 })
1062 }
1063
1064 pub fn status(&self) -> Status {
1065 match &self.server {
1066 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
1067 CopilotServer::Disabled => Status::Disabled,
1068 CopilotServer::Error(error) => Status::Error(error.clone()),
1069 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
1070 match sign_in_status {
1071 SignInStatus::Authorized { .. } => Status::Authorized,
1072 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
1073 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
1074 prompt: prompt.clone(),
1075 },
1076 SignInStatus::SignedOut {
1077 awaiting_signing_in,
1078 } => Status::SignedOut {
1079 awaiting_signing_in: *awaiting_signing_in,
1080 },
1081 }
1082 }
1083 }
1084 }
1085
1086 fn update_sign_in_status(&mut self, lsp_status: request::SignInStatus, cx: &mut Context<Self>) {
1087 self.buffers.retain(|buffer| buffer.is_upgradable());
1088
1089 if let Ok(server) = self.server.as_running() {
1090 match lsp_status {
1091 request::SignInStatus::Ok { user: Some(_) }
1092 | request::SignInStatus::MaybeOk { .. }
1093 | request::SignInStatus::AlreadySignedIn { .. } => {
1094 server.sign_in_status = SignInStatus::Authorized;
1095 cx.emit(Event::CopilotAuthSignedIn);
1096 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1097 if let Some(buffer) = buffer.upgrade() {
1098 self.register_buffer(&buffer, cx);
1099 }
1100 }
1101 }
1102 request::SignInStatus::NotAuthorized { .. } => {
1103 server.sign_in_status = SignInStatus::Unauthorized;
1104 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1105 self.unregister_buffer(&buffer);
1106 }
1107 }
1108 request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
1109 if !matches!(server.sign_in_status, SignInStatus::SignedOut { .. }) {
1110 server.sign_in_status = SignInStatus::SignedOut {
1111 awaiting_signing_in: false,
1112 };
1113 }
1114 cx.emit(Event::CopilotAuthSignedOut);
1115 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1116 self.unregister_buffer(&buffer);
1117 }
1118 }
1119 }
1120
1121 cx.notify();
1122 }
1123 }
1124}
1125
1126fn id_for_language(language: Option<&Arc<Language>>) -> String {
1127 language
1128 .map(|language| language.lsp_id())
1129 .unwrap_or_else(|| "plaintext".to_string())
1130}
1131
1132fn uri_for_buffer(buffer: &Entity<Buffer>, cx: &App) -> Result<lsp::Url, ()> {
1133 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
1134 lsp::Url::from_file_path(file.abs_path(cx))
1135 } else {
1136 format!("buffer://{}", buffer.entity_id())
1137 .parse()
1138 .map_err(|_| ())
1139 }
1140}
1141
1142async fn clear_copilot_dir() {
1143 remove_matching(paths::copilot_dir(), |_| true).await
1144}
1145
1146async fn clear_copilot_config_dir() {
1147 remove_matching(copilot_chat::copilot_chat_config_dir(), |_| true).await
1148}
1149
1150async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::Result<PathBuf> {
1151 const PACKAGE_NAME: &str = "@github/copilot-language-server";
1152 const SERVER_PATH: &str =
1153 "node_modules/@github/copilot-language-server/dist/language-server.js";
1154
1155 let latest_version = node_runtime
1156 .npm_package_latest_version(PACKAGE_NAME)
1157 .await?;
1158 let server_path = paths::copilot_dir().join(SERVER_PATH);
1159
1160 fs.create_dir(paths::copilot_dir()).await?;
1161
1162 let should_install = node_runtime
1163 .should_install_npm_package(
1164 PACKAGE_NAME,
1165 &server_path,
1166 paths::copilot_dir(),
1167 &latest_version,
1168 )
1169 .await;
1170 if should_install {
1171 node_runtime
1172 .npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)])
1173 .await?;
1174 }
1175
1176 Ok(server_path)
1177}
1178
#[cfg(test)]
mod tests {
    //! Tests for how `Copilot` opens, updates and closes buffers on its
    //! (fake) language server across edits, file changes, sign-out/sign-in,
    //! and buffer drops.

    use super::*;
    use gpui::TestAppContext;
    use util::path;

    /// Checks the exact sequence of `textDocument/*` notifications sent for
    /// registered buffers:
    /// - `didOpen` on registration;
    /// - incremental `didChange` on edit;
    /// - `didClose` plus `didOpen` under the new URI when a local file is attached;
    /// - `didClose` for every buffer on sign-out and `didOpen` for every buffer on sign-in;
    /// - `didClose` when a buffer entity is dropped.
    ///
    /// `iterations = 10` asks the gpui test harness to run this test 10 times.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(cx: &mut TestAppContext) {
        let (copilot, mut lsp) = Copilot::fake(cx);

        // A file-less buffer is registered under a synthetic `buffer://<entity id>`
        // URI, with the "plaintext" language id and version 0.
        let buffer_1 = cx.new(|cx| Buffer::local("Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        // A second file-less buffer is opened the same way, under its own URI.
        let buffer_2 = cx.new(|cx| Buffer::local("Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // Inserting at offset 5 produces an incremental change: an empty range
        // at (0, 5), the inserted text, and the document version bumped to 1.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        // Attaching a local file changes the buffer's URI, so the old
        // `buffer://` document is closed. The buffer is then reopened under
        // its file URI with the current version (1) and text.
        buffer_1.update(cx, |buffer, cx| {
            buffer.file_updated(
                Arc::new(File {
                    abs_path: path!("/root/child/buffer-1").into(),
                    path: Path::new("child/buffer-1").into(),
                }),
                cx,
            )
        });
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path(path!("/root/child/buffer-1")).unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        // The fake server's SignOut handler is stubbed to succeed.
        lsp.set_request_handler::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        // The fake server answers SignInInitiate with `AlreadySignedIn`.
        // Reopened documents start again at version 0 while keeping their current text.
        lsp.set_request_handler::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();

        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    /// Minimal stand-in for an on-disk file, used to attach a local path to a
    /// test buffer.
    ///
    /// Several trait methods are `unimplemented!()` and will panic if called.
    struct File {
        /// Absolute path reported by `LocalFile::abs_path`.
        abs_path: PathBuf,
        /// Path returned by `language::File::path`.
        path: Arc<Path>,
    }

    impl language::File for File {
        // Always reports itself as local, so `uri_for_buffer` builds a
        // file URI from `abs_path`.
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        // Fixed "present on disk" state; the mtime value is presumably arbitrary.
        fn disk_state(&self) -> language::DiskState {
            language::DiskState::Present {
                mtime: ::fs::MTime::from_seconds_and_nanos(100, 42),
            }
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &App) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a App) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn to_proto(&self, _: &App) -> rpc::proto::File {
            unimplemented!()
        }

        // Every test file belongs to worktree 0.
        fn worktree_id(&self, _: &App) -> settings::WorktreeId {
            settings::WorktreeId::from_usize(0)
        }

        fn is_private(&self) -> bool {
            false
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &App) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &App) -> Task<Result<String>> {
            unimplemented!()
        }

        fn load_bytes(&self, _cx: &App) -> Task<Result<Vec<u8>>> {
            unimplemented!()
        }
    }
}
1397
/// Sets up logging for this crate's tests by calling `zlog::init_test`.
///
/// `#[ctor::ctor]` runs this function automatically when the test binary is
/// loaded, before any test runs, so individual tests need not call it.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}