1pub mod copilot_chat;
2mod copilot_completion_provider;
3pub mod request;
4mod sign_in;
5
6use crate::sign_in::initiate_sign_in_within_workspace;
7use ::fs::Fs;
8use anyhow::{Context as _, Result, anyhow};
9use client::DisableAiSettings;
10use collections::{HashMap, HashSet};
11use command_palette_hooks::CommandPaletteFilter;
12use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared};
13use gpui::{
14 App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Task,
15 WeakEntity, actions,
16};
17use http_client::HttpClient;
18use language::language_settings::CopilotSettings;
19use language::{
20 Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, ToPointUtf16,
21 language_settings::{EditPredictionProvider, all_language_settings, language_settings},
22 point_from_lsp, point_to_lsp,
23};
24use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
25use node_runtime::NodeRuntime;
26use parking_lot::Mutex;
27use request::StatusNotification;
28use serde_json::json;
29use settings::Settings;
30use settings::SettingsStore;
31use sign_in::{reinstall_and_sign_in_within_workspace, sign_out_within_workspace};
32use std::collections::hash_map::Entry;
33use std::{
34 any::TypeId,
35 env,
36 ffi::OsString,
37 mem,
38 ops::Range,
39 path::{Path, PathBuf},
40 sync::Arc,
41};
42use util::{ResultExt, fs::remove_matching};
43use workspace::Workspace;
44
45pub use crate::copilot_completion_provider::CopilotCompletionProvider;
46pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in};
47
// Actions in the `copilot` namespace. `init` shows or hides them in the command
// palette depending on Copilot's status and the `disable_ai` setting, and registers
// workspace handlers for `SignIn`, `Reinstall` and `SignOut`.
actions!(
    copilot,
    [
        /// Requests a code completion suggestion from Copilot.
        Suggest,
        /// Cycles to the next Copilot suggestion.
        NextSuggestion,
        /// Cycles to the previous Copilot suggestion.
        PreviousSuggestion,
        /// Reinstalls the Copilot language server.
        Reinstall,
        /// Signs in to GitHub Copilot.
        SignIn,
        /// Signs out of GitHub Copilot.
        SignOut
    ]
);
65
66pub fn init(
67 new_server_id: LanguageServerId,
68 fs: Arc<dyn Fs>,
69 http: Arc<dyn HttpClient>,
70 node_runtime: NodeRuntime,
71 cx: &mut App,
72) {
73 let language_settings = all_language_settings(None, cx);
74 let configuration = copilot_chat::CopilotChatConfiguration {
75 enterprise_uri: language_settings
76 .edit_predictions
77 .copilot
78 .enterprise_uri
79 .clone(),
80 };
81 copilot_chat::init(fs.clone(), http.clone(), configuration, cx);
82
83 let copilot = cx.new({
84 let node_runtime = node_runtime.clone();
85 move |cx| Copilot::start(new_server_id, fs, node_runtime, cx)
86 });
87 Copilot::set_global(copilot.clone(), cx);
88 cx.observe(&copilot, |handle, cx| {
89 let copilot_action_types = [
90 TypeId::of::<Suggest>(),
91 TypeId::of::<NextSuggestion>(),
92 TypeId::of::<PreviousSuggestion>(),
93 TypeId::of::<Reinstall>(),
94 ];
95 let copilot_auth_action_types = [TypeId::of::<SignOut>()];
96 let copilot_no_auth_action_types = [TypeId::of::<SignIn>()];
97 let status = handle.read(cx).status();
98
99 let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
100 let filter = CommandPaletteFilter::global_mut(cx);
101
102 if is_ai_disabled {
103 filter.hide_action_types(&copilot_action_types);
104 filter.hide_action_types(&copilot_auth_action_types);
105 filter.hide_action_types(&copilot_no_auth_action_types);
106 } else {
107 match status {
108 Status::Disabled => {
109 filter.hide_action_types(&copilot_action_types);
110 filter.hide_action_types(&copilot_auth_action_types);
111 filter.hide_action_types(&copilot_no_auth_action_types);
112 }
113 Status::Authorized => {
114 filter.hide_action_types(&copilot_no_auth_action_types);
115 filter.show_action_types(
116 copilot_action_types
117 .iter()
118 .chain(&copilot_auth_action_types),
119 );
120 }
121 _ => {
122 filter.hide_action_types(&copilot_action_types);
123 filter.hide_action_types(&copilot_auth_action_types);
124 filter.show_action_types(copilot_no_auth_action_types.iter());
125 }
126 }
127 }
128 })
129 .detach();
130
131 cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
132 workspace.register_action(|workspace, _: &SignIn, window, cx| {
133 if let Some(copilot) = Copilot::global(cx) {
134 let is_reinstall = false;
135 initiate_sign_in_within_workspace(workspace, copilot, is_reinstall, window, cx);
136 }
137 });
138 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
139 if let Some(copilot) = Copilot::global(cx) {
140 reinstall_and_sign_in_within_workspace(workspace, copilot, window, cx);
141 }
142 });
143 workspace.register_action(|workspace, _: &SignOut, _window, cx| {
144 if let Some(copilot) = Copilot::global(cx) {
145 sign_out_within_workspace(workspace, copilot, cx);
146 }
147 });
148 })
149 .detach();
150}
151
/// Lifecycle state of the Copilot language server.
enum CopilotServer {
    /// Not started: the initial state, and the state left behind by the app-quit shutdown.
    Disabled,
    /// The server is being installed/started; `task` completes when startup finishes.
    Starting { task: Shared<Task<()>> },
    /// Startup failed; holds the error message.
    Error(Arc<str>),
    /// The server is up and can be talked to.
    Running(RunningCopilotServer),
}
158
159impl CopilotServer {
160 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
161 let server = self.as_running()?;
162 anyhow::ensure!(
163 matches!(server.sign_in_status, SignInStatus::Authorized { .. }),
164 "must sign in before using copilot"
165 );
166 Ok(server)
167 }
168
169 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
170 match self {
171 CopilotServer::Starting { .. } => anyhow::bail!("copilot is still starting"),
172 CopilotServer::Disabled => anyhow::bail!("copilot is disabled"),
173 CopilotServer::Error(error) => {
174 anyhow::bail!("copilot was not started because of an error: {error}")
175 }
176 CopilotServer::Running(server) => Ok(server),
177 }
178 }
179}
180
/// State kept while the Copilot language server is running.
struct RunningCopilotServer {
    /// Handle to the language server process.
    lsp: Arc<LanguageServer>,
    /// The user's sign-in state as last reported by the server.
    sign_in_status: SignInStatus,
    /// Buffers that have been opened on the server (`didOpen` sent), keyed by entity id.
    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
}
186
/// Internal sign-in state of a running Copilot server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Signed in; completions may be requested.
    Authorized,
    /// Signed in, but the server reported the account as not authorized.
    Unauthorized,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt to show the user, once the server has provided one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// The shared sign-in task; concurrent `sign_in` calls await this same task.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Not signed in.
    SignedOut {
        /// Set when the server was started with the intent of signing in right afterwards.
        awaiting_signing_in: bool,
    },
}
199
/// Public snapshot of Copilot's state, combining the server lifecycle and sign-in status.
#[derive(Debug, Clone)]
pub enum Status {
    /// The language server is starting.
    Starting {
        /// Completes when startup finishes (successfully or not).
        task: Shared<Task<()>>,
    },
    /// The language server failed to start; holds the error message.
    Error(Arc<str>),
    /// Copilot is not running.
    Disabled,
    /// The server is running but the user is not signed in.
    SignedOut {
        /// Whether a sign-in is expected to follow the server start.
        awaiting_signing_in: bool,
    },
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt to show the user, if available yet.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// Signed in, but the account is not authorized to use Copilot.
    Unauthorized,
    /// Signed in and authorized.
    Authorized,
}
216
217impl Status {
218 pub fn is_authorized(&self) -> bool {
219 matches!(self, Status::Authorized)
220 }
221
222 pub fn is_configured(&self) -> bool {
223 matches!(
224 self,
225 Status::Starting { .. }
226 | Status::Error(_)
227 | Status::SigningIn { .. }
228 | Status::Authorized
229 )
230 }
231}
232
/// Server-side view of a buffer that has been opened on the Copilot language server.
struct RegisteredBuffer {
    /// URI the buffer was opened under (a file URI, or `buffer://<entity id>`).
    uri: lsp::Url,
    /// LSP language id sent with `didOpen`.
    language_id: String,
    /// The buffer contents as last reported to the server.
    snapshot: BufferSnapshot,
    /// Document version sent with the last notification; incremented on each change.
    snapshot_version: i32,
    /// Buffer event and release subscriptions; kept alive for as long as the buffer is registered.
    _subscriptions: [gpui::Subscription; 2],
    /// The most recent change-reporting task; each new one awaits its predecessor.
    pending_buffer_change: Task<Option<()>>,
}
241
impl RegisteredBuffer {
    /// Brings the language server's copy of this buffer up to date with `buffer`.
    ///
    /// Returns a receiver for the `(version, snapshot)` pair the server holds once every
    /// pending change has been sent. If the buffer's version already matches the last
    /// reported snapshot, the receiver completes immediately. Otherwise a task computes the
    /// diff on the background executor and sends an incremental `didChange` notification.
    /// If Copilot is no longer authenticated, the buffer was unregistered, or the buffer
    /// entity is gone, the sender is dropped without a value and the receiver yields
    /// `Canceled`.
    fn report_changes(
        &mut self,
        buffer: &Entity<Buffer>,
        cx: &mut Context<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Chain onto the previous pending change so notifications go out in order.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(async move |copilot, cx| {
                prev_pending_change.await;

                // Read the server-side version only after the previous change landed, so the
                // diff starts from what the server actually has.
                let old_version = copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()).ok()?;

                let content_changes = cx
                    .background_spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // LSP applies content changes one after another, so each
                                    // range is expressed in the document as already modified
                                    // by the preceding changes: it starts at the edit's new
                                    // position and spans the edit's old extent.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        // Only bump the version and notify when something actually changed.
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    &lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .ok();
                        }
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
327
/// A completion suggestion returned by Copilot.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned identifier, used when reporting acceptance or rejection.
    pub uuid: String,
    /// Buffer range the completion replaces, anchored in the snapshot it was computed for.
    pub range: Range<Anchor>,
    /// Text to insert in place of `range`.
    pub text: String,
}
334
/// Manages the GitHub Copilot language server: its lifecycle, the user's sign-in
/// state, and the buffers synchronized with it.
pub struct Copilot {
    fs: Arc<dyn Fs>,
    node_runtime: NodeRuntime,
    /// Current lifecycle state of the language server.
    server: CopilotServer,
    /// Buffers passed to `register_buffer`; used to (re)register or unregister them
    /// when the sign-in status changes. Entries are removed when a buffer is released.
    buffers: HashSet<WeakEntity<Buffer>>,
    /// Id used for the language server whenever it is (re)started.
    server_id: LanguageServerId,
    /// App-quit hook that shuts the language server down.
    _subscription: gpui::Subscription,
}
343
/// Events emitted by [`Copilot`].
pub enum Event {
    /// The language server finished starting and is now running.
    CopilotLanguageServerStarted,
    /// The server reported the user as signed in.
    CopilotAuthSignedIn,
    /// The server reported the user as signed out.
    CopilotAuthSignedOut,
}

impl EventEmitter<Event> for Copilot {}
351
/// App-global holder for the [`Copilot`] entity (see `Copilot::global` / `Copilot::set_global`).
struct GlobalCopilot(Entity<Copilot>);

impl Global for GlobalCopilot {}
355
356impl Copilot {
357 pub fn global(cx: &App) -> Option<Entity<Self>> {
358 cx.try_global::<GlobalCopilot>()
359 .map(|model| model.0.clone())
360 }
361
362 pub fn set_global(copilot: Entity<Self>, cx: &mut App) {
363 cx.set_global(GlobalCopilot(copilot));
364 }
365
366 fn start(
367 new_server_id: LanguageServerId,
368 fs: Arc<dyn Fs>,
369 node_runtime: NodeRuntime,
370 cx: &mut Context<Self>,
371 ) -> Self {
372 let mut this = Self {
373 server_id: new_server_id,
374 fs,
375 node_runtime,
376 server: CopilotServer::Disabled,
377 buffers: Default::default(),
378 _subscription: cx.on_app_quit(Self::shutdown_language_server),
379 };
380 this.start_copilot(true, false, cx);
381 cx.observe_global::<SettingsStore>(move |this, cx| {
382 this.start_copilot(true, false, cx);
383 this.send_configuration_update(cx);
384 })
385 .detach();
386 this
387 }
388
389 fn shutdown_language_server(
390 &mut self,
391 _cx: &mut Context<Self>,
392 ) -> impl Future<Output = ()> + use<> {
393 let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
394 CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
395 _ => None,
396 };
397
398 async move {
399 if let Some(shutdown) = shutdown {
400 shutdown.await;
401 }
402 }
403 }
404
405 fn start_copilot(
406 &mut self,
407 check_edit_prediction_provider: bool,
408 awaiting_sign_in_after_start: bool,
409 cx: &mut Context<Self>,
410 ) {
411 if !matches!(self.server, CopilotServer::Disabled) {
412 return;
413 }
414 let language_settings = all_language_settings(None, cx);
415 if check_edit_prediction_provider
416 && language_settings.edit_predictions.provider != EditPredictionProvider::Copilot
417 {
418 return;
419 }
420 let server_id = self.server_id;
421 let fs = self.fs.clone();
422 let node_runtime = self.node_runtime.clone();
423 let env = self.build_env(&language_settings.edit_predictions.copilot);
424 let start_task = cx
425 .spawn(async move |this, cx| {
426 Self::start_language_server(
427 server_id,
428 fs,
429 node_runtime,
430 env,
431 this,
432 awaiting_sign_in_after_start,
433 cx,
434 )
435 .await
436 })
437 .shared();
438 self.server = CopilotServer::Starting { task: start_task };
439 cx.notify();
440 }
441
442 fn build_env(&self, copilot_settings: &CopilotSettings) -> Option<HashMap<String, String>> {
443 let proxy_url = copilot_settings.proxy.clone()?;
444 let no_verify = copilot_settings.proxy_no_verify;
445 let http_or_https_proxy = if proxy_url.starts_with("http:") {
446 Some("HTTP_PROXY")
447 } else if proxy_url.starts_with("https:") {
448 Some("HTTPS_PROXY")
449 } else {
450 log::error!(
451 "Unsupported protocol scheme for language server proxy (must be http or https)"
452 );
453 None
454 };
455
456 let mut env = HashMap::default();
457
458 if let Some(proxy_type) = http_or_https_proxy {
459 env.insert(proxy_type.to_string(), proxy_url);
460 if let Some(true) = no_verify {
461 env.insert("NODE_TLS_REJECT_UNAUTHORIZED".to_string(), "0".to_string());
462 };
463 }
464
465 if let Ok(oauth_token) = env::var(copilot_chat::COPILOT_OAUTH_ENV_VAR) {
466 env.insert(copilot_chat::COPILOT_OAUTH_ENV_VAR.to_string(), oauth_token);
467 }
468
469 if env.is_empty() { None } else { Some(env) }
470 }
471
472 fn send_configuration_update(&mut self, cx: &mut Context<Self>) {
473 let copilot_settings = all_language_settings(None, cx)
474 .edit_predictions
475 .copilot
476 .clone();
477
478 let settings = json!({
479 "http": {
480 "proxy": copilot_settings.proxy,
481 "proxyStrictSSL": !copilot_settings.proxy_no_verify.unwrap_or(false)
482 },
483 "github-enterprise": {
484 "uri": copilot_settings.enterprise_uri
485 }
486 });
487
488 if let Some(copilot_chat) = copilot_chat::CopilotChat::global(cx) {
489 copilot_chat.update(cx, |chat, cx| {
490 chat.set_configuration(
491 copilot_chat::CopilotChatConfiguration {
492 enterprise_uri: copilot_settings.enterprise_uri.clone(),
493 },
494 cx,
495 );
496 });
497 }
498
499 if let Ok(server) = self.server.as_running() {
500 server
501 .lsp
502 .notify::<lsp::notification::DidChangeConfiguration>(
503 &lsp::DidChangeConfigurationParams { settings },
504 )
505 .log_err();
506 }
507 }
508
509 #[cfg(any(test, feature = "test-support"))]
510 pub fn fake(cx: &mut gpui::TestAppContext) -> (Entity<Self>, lsp::FakeLanguageServer) {
511 use fs::FakeFs;
512 use lsp::FakeLanguageServer;
513 use node_runtime::NodeRuntime;
514
515 let (server, fake_server) = FakeLanguageServer::new(
516 LanguageServerId(0),
517 LanguageServerBinary {
518 path: "path/to/copilot".into(),
519 arguments: vec![],
520 env: None,
521 },
522 "copilot".into(),
523 Default::default(),
524 &mut cx.to_async(),
525 );
526 let node_runtime = NodeRuntime::unavailable();
527 let this = cx.new(|cx| Self {
528 server_id: LanguageServerId(0),
529 fs: FakeFs::new(cx.background_executor().clone()),
530 node_runtime,
531 server: CopilotServer::Running(RunningCopilotServer {
532 lsp: Arc::new(server),
533 sign_in_status: SignInStatus::Authorized,
534 registered_buffers: Default::default(),
535 }),
536 _subscription: cx.on_app_quit(Self::shutdown_language_server),
537 buffers: Default::default(),
538 });
539 (this, fake_server)
540 }
541
542 async fn start_language_server(
543 new_server_id: LanguageServerId,
544 fs: Arc<dyn Fs>,
545 node_runtime: NodeRuntime,
546 env: Option<HashMap<String, String>>,
547 this: WeakEntity<Self>,
548 awaiting_sign_in_after_start: bool,
549 cx: &mut AsyncApp,
550 ) {
551 let start_language_server = async {
552 let server_path = get_copilot_lsp(fs, node_runtime.clone()).await?;
553 let node_path = node_runtime.binary_path().await?;
554 let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
555 let binary = LanguageServerBinary {
556 path: node_path,
557 arguments,
558 env,
559 };
560
561 let root_path = if cfg!(target_os = "windows") {
562 Path::new("C:/")
563 } else {
564 Path::new("/")
565 };
566
567 let server_name = LanguageServerName("copilot".into());
568 let server = LanguageServer::new(
569 Arc::new(Mutex::new(None)),
570 new_server_id,
571 server_name,
572 binary,
573 root_path,
574 None,
575 Default::default(),
576 cx,
577 )?;
578
579 server
580 .on_notification::<StatusNotification, _>(|_, _| { /* Silence the notification */ })
581 .detach();
582
583 let configuration = lsp::DidChangeConfigurationParams {
584 settings: Default::default(),
585 };
586
587 let editor_info = request::SetEditorInfoParams {
588 editor_info: request::EditorInfo {
589 name: "zed".into(),
590 version: env!("CARGO_PKG_VERSION").into(),
591 },
592 editor_plugin_info: request::EditorPluginInfo {
593 name: "zed-copilot".into(),
594 version: "0.0.1".into(),
595 },
596 };
597 let editor_info_json = serde_json::to_value(&editor_info)?;
598
599 let server = cx
600 .update(|cx| {
601 let mut params = server.default_initialize_params(false, cx);
602 params.initialization_options = Some(editor_info_json);
603 server.initialize(params, configuration.into(), cx)
604 })?
605 .await?;
606
607 let status = server
608 .request::<request::CheckStatus>(request::CheckStatusParams {
609 local_checks_only: false,
610 })
611 .await
612 .into_response()
613 .context("copilot: check status")?;
614
615 anyhow::Ok((server, status))
616 };
617
618 let server = start_language_server.await;
619 this.update(cx, |this, cx| {
620 cx.notify();
621 match server {
622 Ok((server, status)) => {
623 this.server = CopilotServer::Running(RunningCopilotServer {
624 lsp: server,
625 sign_in_status: SignInStatus::SignedOut {
626 awaiting_signing_in: awaiting_sign_in_after_start,
627 },
628 registered_buffers: Default::default(),
629 });
630 cx.emit(Event::CopilotLanguageServerStarted);
631 this.update_sign_in_status(status, cx);
632 // Send configuration now that the LSP is fully started
633 this.send_configuration_update(cx);
634 }
635 Err(error) => {
636 this.server = CopilotServer::Error(error.to_string().into());
637 cx.notify()
638 }
639 }
640 })
641 .ok();
642 }
643
644 pub(crate) fn sign_in(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
645 if let CopilotServer::Running(server) = &mut self.server {
646 let task = match &server.sign_in_status {
647 SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
648 SignInStatus::SigningIn { task, .. } => {
649 cx.notify();
650 task.clone()
651 }
652 SignInStatus::SignedOut { .. } | SignInStatus::Unauthorized { .. } => {
653 let lsp = server.lsp.clone();
654 let task = cx
655 .spawn(async move |this, cx| {
656 let sign_in = async {
657 let sign_in = lsp
658 .request::<request::SignInInitiate>(
659 request::SignInInitiateParams {},
660 )
661 .await
662 .into_response()
663 .context("copilot sign-in")?;
664 match sign_in {
665 request::SignInInitiateResult::AlreadySignedIn { user } => {
666 Ok(request::SignInStatus::Ok { user: Some(user) })
667 }
668 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
669 this.update(cx, |this, cx| {
670 if let CopilotServer::Running(RunningCopilotServer {
671 sign_in_status: status,
672 ..
673 }) = &mut this.server
674 {
675 if let SignInStatus::SigningIn {
676 prompt: prompt_flow,
677 ..
678 } = status
679 {
680 *prompt_flow = Some(flow.clone());
681 cx.notify();
682 }
683 }
684 })?;
685 let response = lsp
686 .request::<request::SignInConfirm>(
687 request::SignInConfirmParams {
688 user_code: flow.user_code,
689 },
690 )
691 .await
692 .into_response()
693 .context("copilot: sign in confirm")?;
694 Ok(response)
695 }
696 }
697 };
698
699 let sign_in = sign_in.await;
700 this.update(cx, |this, cx| match sign_in {
701 Ok(status) => {
702 this.update_sign_in_status(status, cx);
703 Ok(())
704 }
705 Err(error) => {
706 this.update_sign_in_status(
707 request::SignInStatus::NotSignedIn,
708 cx,
709 );
710 Err(Arc::new(error))
711 }
712 })?
713 })
714 .shared();
715 server.sign_in_status = SignInStatus::SigningIn {
716 prompt: None,
717 task: task.clone(),
718 };
719 cx.notify();
720 task
721 }
722 };
723
724 cx.background_spawn(task.map_err(|err| anyhow!("{err:?}")))
725 } else {
726 // If we're downloading, wait until download is finished
727 // If we're in a stuck state, display to the user
728 Task::ready(Err(anyhow!("copilot hasn't started yet")))
729 }
730 }
731
732 pub(crate) fn sign_out(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
733 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
734 match &self.server {
735 CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) => {
736 let server = server.clone();
737 cx.background_spawn(async move {
738 server
739 .request::<request::SignOut>(request::SignOutParams {})
740 .await
741 .into_response()
742 .context("copilot: sign in confirm")?;
743 anyhow::Ok(())
744 })
745 }
746 CopilotServer::Disabled => cx.background_spawn(async {
747 clear_copilot_config_dir().await;
748 anyhow::Ok(())
749 }),
750 _ => Task::ready(Err(anyhow!("copilot hasn't started yet"))),
751 }
752 }
753
754 pub(crate) fn reinstall(&mut self, cx: &mut Context<Self>) -> Shared<Task<()>> {
755 let language_settings = all_language_settings(None, cx);
756 let env = self.build_env(&language_settings.edit_predictions.copilot);
757 let start_task = cx
758 .spawn({
759 let fs = self.fs.clone();
760 let node_runtime = self.node_runtime.clone();
761 let server_id = self.server_id;
762 async move |this, cx| {
763 clear_copilot_dir().await;
764 Self::start_language_server(server_id, fs, node_runtime, env, this, false, cx)
765 .await
766 }
767 })
768 .shared();
769
770 self.server = CopilotServer::Starting {
771 task: start_task.clone(),
772 };
773
774 cx.notify();
775
776 start_task
777 }
778
779 pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
780 if let CopilotServer::Running(server) = &self.server {
781 Some(&server.lsp)
782 } else {
783 None
784 }
785 }
786
787 pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
788 let weak_buffer = buffer.downgrade();
789 self.buffers.insert(weak_buffer.clone());
790
791 if let CopilotServer::Running(RunningCopilotServer {
792 lsp: server,
793 sign_in_status: status,
794 registered_buffers,
795 ..
796 }) = &mut self.server
797 {
798 if !matches!(status, SignInStatus::Authorized { .. }) {
799 return;
800 }
801
802 let entry = registered_buffers.entry(buffer.entity_id());
803 if let Entry::Vacant(e) = entry {
804 let Ok(uri) = uri_for_buffer(buffer, cx) else {
805 return;
806 };
807 let language_id = id_for_language(buffer.read(cx).language());
808 let snapshot = buffer.read(cx).snapshot();
809 server
810 .notify::<lsp::notification::DidOpenTextDocument>(
811 &lsp::DidOpenTextDocumentParams {
812 text_document: lsp::TextDocumentItem {
813 uri: uri.clone(),
814 language_id: language_id.clone(),
815 version: 0,
816 text: snapshot.text(),
817 },
818 },
819 )
820 .ok();
821
822 e.insert(RegisteredBuffer {
823 uri,
824 language_id,
825 snapshot,
826 snapshot_version: 0,
827 pending_buffer_change: Task::ready(Some(())),
828 _subscriptions: [
829 cx.subscribe(buffer, |this, buffer, event, cx| {
830 this.handle_buffer_event(buffer, event, cx).log_err();
831 }),
832 cx.observe_release(buffer, move |this, _buffer, _cx| {
833 this.buffers.remove(&weak_buffer);
834 this.unregister_buffer(&weak_buffer);
835 }),
836 ],
837 });
838 }
839 }
840 }
841
842 fn handle_buffer_event(
843 &mut self,
844 buffer: Entity<Buffer>,
845 event: &language::BufferEvent,
846 cx: &mut Context<Self>,
847 ) -> Result<()> {
848 if let Ok(server) = self.server.as_running() {
849 if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
850 {
851 match event {
852 language::BufferEvent::Edited => {
853 drop(registered_buffer.report_changes(&buffer, cx));
854 }
855 language::BufferEvent::Saved => {
856 server
857 .lsp
858 .notify::<lsp::notification::DidSaveTextDocument>(
859 &lsp::DidSaveTextDocumentParams {
860 text_document: lsp::TextDocumentIdentifier::new(
861 registered_buffer.uri.clone(),
862 ),
863 text: None,
864 },
865 )?;
866 }
867 language::BufferEvent::FileHandleChanged
868 | language::BufferEvent::LanguageChanged => {
869 let new_language_id = id_for_language(buffer.read(cx).language());
870 let Ok(new_uri) = uri_for_buffer(&buffer, cx) else {
871 return Ok(());
872 };
873 if new_uri != registered_buffer.uri
874 || new_language_id != registered_buffer.language_id
875 {
876 let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
877 registered_buffer.language_id = new_language_id;
878 server
879 .lsp
880 .notify::<lsp::notification::DidCloseTextDocument>(
881 &lsp::DidCloseTextDocumentParams {
882 text_document: lsp::TextDocumentIdentifier::new(old_uri),
883 },
884 )?;
885 server
886 .lsp
887 .notify::<lsp::notification::DidOpenTextDocument>(
888 &lsp::DidOpenTextDocumentParams {
889 text_document: lsp::TextDocumentItem::new(
890 registered_buffer.uri.clone(),
891 registered_buffer.language_id.clone(),
892 registered_buffer.snapshot_version,
893 registered_buffer.snapshot.text(),
894 ),
895 },
896 )?;
897 }
898 }
899 _ => {}
900 }
901 }
902 }
903
904 Ok(())
905 }
906
907 fn unregister_buffer(&mut self, buffer: &WeakEntity<Buffer>) {
908 if let Ok(server) = self.server.as_running() {
909 if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
910 server
911 .lsp
912 .notify::<lsp::notification::DidCloseTextDocument>(
913 &lsp::DidCloseTextDocumentParams {
914 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
915 },
916 )
917 .ok();
918 }
919 }
920 }
921
922 pub fn completions<T>(
923 &mut self,
924 buffer: &Entity<Buffer>,
925 position: T,
926 cx: &mut Context<Self>,
927 ) -> Task<Result<Vec<Completion>>>
928 where
929 T: ToPointUtf16,
930 {
931 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
932 }
933
934 pub fn completions_cycling<T>(
935 &mut self,
936 buffer: &Entity<Buffer>,
937 position: T,
938 cx: &mut Context<Self>,
939 ) -> Task<Result<Vec<Completion>>>
940 where
941 T: ToPointUtf16,
942 {
943 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
944 }
945
946 pub fn accept_completion(
947 &mut self,
948 completion: &Completion,
949 cx: &mut Context<Self>,
950 ) -> Task<Result<()>> {
951 let server = match self.server.as_authenticated() {
952 Ok(server) => server,
953 Err(error) => return Task::ready(Err(error)),
954 };
955 let request =
956 server
957 .lsp
958 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
959 uuid: completion.uuid.clone(),
960 });
961 cx.background_spawn(async move {
962 request
963 .await
964 .into_response()
965 .context("copilot: notify accepted")?;
966 Ok(())
967 })
968 }
969
970 pub fn discard_completions(
971 &mut self,
972 completions: &[Completion],
973 cx: &mut Context<Self>,
974 ) -> Task<Result<()>> {
975 let server = match self.server.as_authenticated() {
976 Ok(server) => server,
977 Err(_) => return Task::ready(Ok(())),
978 };
979 let request =
980 server
981 .lsp
982 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
983 uuids: completions
984 .iter()
985 .map(|completion| completion.uuid.clone())
986 .collect(),
987 });
988 cx.background_spawn(async move {
989 request
990 .await
991 .into_response()
992 .context("copilot: notify rejected")?;
993 Ok(())
994 })
995 }
996
997 fn request_completions<R, T>(
998 &mut self,
999 buffer: &Entity<Buffer>,
1000 position: T,
1001 cx: &mut Context<Self>,
1002 ) -> Task<Result<Vec<Completion>>>
1003 where
1004 R: 'static
1005 + lsp::request::Request<
1006 Params = request::GetCompletionsParams,
1007 Result = request::GetCompletionsResult,
1008 >,
1009 T: ToPointUtf16,
1010 {
1011 self.register_buffer(buffer, cx);
1012
1013 let server = match self.server.as_authenticated() {
1014 Ok(server) => server,
1015 Err(error) => return Task::ready(Err(error)),
1016 };
1017 let lsp = server.lsp.clone();
1018 let registered_buffer = server
1019 .registered_buffers
1020 .get_mut(&buffer.entity_id())
1021 .unwrap();
1022 let snapshot = registered_buffer.report_changes(buffer, cx);
1023 let buffer = buffer.read(cx);
1024 let uri = registered_buffer.uri.clone();
1025 let position = position.to_point_utf16(buffer);
1026 let settings = language_settings(
1027 buffer.language_at(position).map(|l| l.name()),
1028 buffer.file(),
1029 cx,
1030 );
1031 let tab_size = settings.tab_size;
1032 let hard_tabs = settings.hard_tabs;
1033 let relative_path = buffer
1034 .file()
1035 .map(|file| file.path().to_path_buf())
1036 .unwrap_or_default();
1037
1038 cx.background_spawn(async move {
1039 let (version, snapshot) = snapshot.await?;
1040 let result = lsp
1041 .request::<R>(request::GetCompletionsParams {
1042 doc: request::GetCompletionsDocument {
1043 uri,
1044 tab_size: tab_size.into(),
1045 indent_size: 1,
1046 insert_spaces: !hard_tabs,
1047 relative_path: relative_path.to_string_lossy().into(),
1048 position: point_to_lsp(position),
1049 version: version.try_into().unwrap(),
1050 },
1051 })
1052 .await
1053 .into_response()
1054 .context("copilot: get completions")?;
1055 let completions = result
1056 .completions
1057 .into_iter()
1058 .map(|completion| {
1059 let start = snapshot
1060 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
1061 let end =
1062 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
1063 Completion {
1064 uuid: completion.uuid,
1065 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
1066 text: completion.text,
1067 }
1068 })
1069 .collect();
1070 anyhow::Ok(completions)
1071 })
1072 }
1073
1074 pub fn status(&self) -> Status {
1075 match &self.server {
1076 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
1077 CopilotServer::Disabled => Status::Disabled,
1078 CopilotServer::Error(error) => Status::Error(error.clone()),
1079 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
1080 match sign_in_status {
1081 SignInStatus::Authorized { .. } => Status::Authorized,
1082 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
1083 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
1084 prompt: prompt.clone(),
1085 },
1086 SignInStatus::SignedOut {
1087 awaiting_signing_in,
1088 } => Status::SignedOut {
1089 awaiting_signing_in: *awaiting_signing_in,
1090 },
1091 }
1092 }
1093 }
1094 }
1095
1096 fn update_sign_in_status(&mut self, lsp_status: request::SignInStatus, cx: &mut Context<Self>) {
1097 self.buffers.retain(|buffer| buffer.is_upgradable());
1098
1099 if let Ok(server) = self.server.as_running() {
1100 match lsp_status {
1101 request::SignInStatus::Ok { user: Some(_) }
1102 | request::SignInStatus::MaybeOk { .. }
1103 | request::SignInStatus::AlreadySignedIn { .. } => {
1104 server.sign_in_status = SignInStatus::Authorized;
1105 cx.emit(Event::CopilotAuthSignedIn);
1106 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1107 if let Some(buffer) = buffer.upgrade() {
1108 self.register_buffer(&buffer, cx);
1109 }
1110 }
1111 }
1112 request::SignInStatus::NotAuthorized { .. } => {
1113 server.sign_in_status = SignInStatus::Unauthorized;
1114 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1115 self.unregister_buffer(&buffer);
1116 }
1117 }
1118 request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
1119 if !matches!(server.sign_in_status, SignInStatus::SignedOut { .. }) {
1120 server.sign_in_status = SignInStatus::SignedOut {
1121 awaiting_signing_in: false,
1122 };
1123 }
1124 cx.emit(Event::CopilotAuthSignedOut);
1125 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
1126 self.unregister_buffer(&buffer);
1127 }
1128 }
1129 }
1130
1131 cx.notify();
1132 }
1133 }
1134}
1135
1136fn id_for_language(language: Option<&Arc<Language>>) -> String {
1137 language
1138 .map(|language| language.lsp_id())
1139 .unwrap_or_else(|| "plaintext".to_string())
1140}
1141
1142fn uri_for_buffer(buffer: &Entity<Buffer>, cx: &App) -> Result<lsp::Url, ()> {
1143 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
1144 lsp::Url::from_file_path(file.abs_path(cx))
1145 } else {
1146 format!("buffer://{}", buffer.entity_id())
1147 .parse()
1148 .map_err(|_| ())
1149 }
1150}
1151
1152async fn clear_copilot_dir() {
1153 remove_matching(paths::copilot_dir(), |_| true).await
1154}
1155
1156async fn clear_copilot_config_dir() {
1157 remove_matching(copilot_chat::copilot_chat_config_dir(), |_| true).await
1158}
1159
1160async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::Result<PathBuf> {
1161 const PACKAGE_NAME: &str = "@github/copilot-language-server";
1162 const SERVER_PATH: &str =
1163 "node_modules/@github/copilot-language-server/dist/language-server.js";
1164
1165 let latest_version = node_runtime
1166 .npm_package_latest_version(PACKAGE_NAME)
1167 .await?;
1168 let server_path = paths::copilot_dir().join(SERVER_PATH);
1169
1170 fs.create_dir(paths::copilot_dir()).await?;
1171
1172 let should_install = node_runtime
1173 .should_install_npm_package(
1174 PACKAGE_NAME,
1175 &server_path,
1176 paths::copilot_dir(),
1177 &latest_version,
1178 )
1179 .await;
1180 if should_install {
1181 node_runtime
1182 .npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)])
1183 .await?;
1184 }
1185
1186 Ok(server_path)
1187}
1188
// Tests for `Copilot`'s buffer bookkeeping. They check which `didOpen` / `didChange` /
// `didClose` notifications reach the language server as buffers are registered, edited,
// attached to files and dropped, and as the user signs out and back in.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::path;

    // Walks two buffers through their lifecycle and asserts, in order, the exact LSP
    // notifications observed on the fake server handle. Runs for 10 iterations
    // (presumably with varying scheduler seeds — see `gpui::test`).
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(cx: &mut TestAppContext) {
        // `lsp` is the server-side handle: the test reads notifications sent by `copilot`
        // from it and installs handlers for the requests `copilot` makes.
        let (copilot, mut lsp) = Copilot::fake(cx);

        // A buffer without a file is opened under a synthetic `buffer://<entity id>` URI,
        // with language id "plaintext" and version 0.
        let buffer_1 = cx.new(|cx| Buffer::local("Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.new(|cx| Buffer::local("Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is forwarded as an incremental `didChange` that carries the edited range,
        // the inserted text and the bumped document version (0 -> 1).
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP: attaching a local file closes
        // the document under its old `buffer://` URI and reopens it under the file's URI,
        // keeping the current version (1) and text.
        buffer_1.update(cx, |buffer, cx| {
            buffer.file_updated(
                Arc::new(File {
                    abs_path: path!("/root/child/buffer-1").into(),
                    path: Path::new("child/buffer-1").into(),
                }),
                cx,
            )
        });
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path(path!("/root/child/buffer-1")).unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        lsp.set_request_handler::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in. The
        // server answers the sign-in request with `AlreadySignedIn`. As asserted below,
        // re-opened documents start again at version 0.
        lsp.set_request_handler::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();

        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    // Minimal local `language::File` used to give a test buffer an absolute path.
    // Methods this test does not exercise are left as `unimplemented!()`.
    struct File {
        // Returned by `LocalFile::abs_path`.
        abs_path: PathBuf,
        // Relative path returned by `File::path`.
        path: Arc<Path>,
    }

    impl language::File for File {
        // Always reports itself as a local file, so `uri_for_buffer` uses `abs_path`.
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        // Reports the file as present on disk with a fixed modification time.
        fn disk_state(&self) -> language::DiskState {
            language::DiskState::Present {
                mtime: ::fs::MTime::from_seconds_and_nanos(100, 42),
            }
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &App) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a App) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn to_proto(&self, _: &App) -> rpc::proto::File {
            unimplemented!()
        }

        // Every test file claims to belong to worktree 0.
        fn worktree_id(&self, _: &App) -> settings::WorktreeId {
            settings::WorktreeId::from_usize(0)
        }

        fn is_private(&self) -> bool {
            false
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &App) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &App) -> Task<Result<String>> {
            unimplemented!()
        }

        fn load_bytes(&self, _cx: &App) -> Task<Result<Vec<u8>>> {
            unimplemented!()
        }
    }
}
1407
// Test-only logger setup. `#[ctor::ctor]` runs this function before the tests start,
// so every test in this crate gets `zlog`'s test logging without calling it explicitly.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}