1pub mod copilot_chat;
2mod copilot_completion_provider;
3pub mod request;
4mod sign_in;
5
6use crate::sign_in::initiate_sign_in_within_workspace;
7use ::fs::Fs;
8use anyhow::{Result, anyhow};
9use collections::{HashMap, HashSet};
10use command_palette_hooks::CommandPaletteFilter;
11use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared};
12use gpui::{
13 App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Global, Task,
14 WeakEntity, actions,
15};
16use http_client::HttpClient;
17use language::language_settings::CopilotSettings;
18use language::{
19 Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, ToPointUtf16,
20 language_settings::{EditPredictionProvider, all_language_settings, language_settings},
21 point_from_lsp, point_to_lsp,
22};
23use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
24use node_runtime::NodeRuntime;
25use parking_lot::Mutex;
26use request::StatusNotification;
27use settings::SettingsStore;
28use sign_in::{reinstall_and_sign_in_within_workspace, sign_out_within_workspace};
29use std::{
30 any::TypeId,
31 env,
32 ffi::OsString,
33 mem,
34 ops::Range,
35 path::{Path, PathBuf},
36 sync::Arc,
37};
38use util::{ResultExt, fs::remove_matching};
39use workspace::Workspace;
40
41pub use crate::copilot_completion_provider::CopilotCompletionProvider;
42pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in};
43
// Actions exposed under the `copilot` namespace. `init` shows or hides them in the
// command palette depending on the current Copilot `Status`.
actions!(
    copilot,
    [
        Suggest,
        NextSuggestion,
        PreviousSuggestion,
        Reinstall,
        SignIn,
        SignOut
    ]
);
55
56pub fn init(
57 new_server_id: LanguageServerId,
58 fs: Arc<dyn Fs>,
59 http: Arc<dyn HttpClient>,
60 node_runtime: NodeRuntime,
61 cx: &mut App,
62) {
63 copilot_chat::init(fs.clone(), http.clone(), cx);
64
65 let copilot = cx.new({
66 let node_runtime = node_runtime.clone();
67 move |cx| Copilot::start(new_server_id, fs, node_runtime, cx)
68 });
69 Copilot::set_global(copilot.clone(), cx);
70 cx.observe(&copilot, |handle, cx| {
71 let copilot_action_types = [
72 TypeId::of::<Suggest>(),
73 TypeId::of::<NextSuggestion>(),
74 TypeId::of::<PreviousSuggestion>(),
75 TypeId::of::<Reinstall>(),
76 ];
77 let copilot_auth_action_types = [TypeId::of::<SignOut>()];
78 let copilot_no_auth_action_types = [TypeId::of::<SignIn>()];
79 let status = handle.read(cx).status();
80 let filter = CommandPaletteFilter::global_mut(cx);
81
82 match status {
83 Status::Disabled => {
84 filter.hide_action_types(&copilot_action_types);
85 filter.hide_action_types(&copilot_auth_action_types);
86 filter.hide_action_types(&copilot_no_auth_action_types);
87 }
88 Status::Authorized => {
89 filter.hide_action_types(&copilot_no_auth_action_types);
90 filter.show_action_types(
91 copilot_action_types
92 .iter()
93 .chain(&copilot_auth_action_types),
94 );
95 }
96 _ => {
97 filter.hide_action_types(&copilot_action_types);
98 filter.hide_action_types(&copilot_auth_action_types);
99 filter.show_action_types(copilot_no_auth_action_types.iter());
100 }
101 }
102 })
103 .detach();
104
105 cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
106 workspace.register_action(|workspace, _: &SignIn, window, cx| {
107 if let Some(copilot) = Copilot::global(cx) {
108 let is_reinstall = false;
109 initiate_sign_in_within_workspace(workspace, copilot, is_reinstall, window, cx);
110 }
111 });
112 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
113 if let Some(copilot) = Copilot::global(cx) {
114 reinstall_and_sign_in_within_workspace(workspace, copilot, window, cx);
115 }
116 });
117 workspace.register_action(|workspace, _: &SignOut, _window, cx| {
118 if let Some(copilot) = Copilot::global(cx) {
119 sign_out_within_workspace(workspace, copilot, cx);
120 }
121 });
122 })
123 .detach();
124}
125
/// Lifecycle state of the Copilot language server.
enum CopilotServer {
    /// Not started (e.g. Copilot is not the configured edit prediction provider), or
    /// shut down on app quit.
    Disabled,
    /// The server is being installed and launched; `task` completes when that attempt
    /// has finished, whether it succeeded or failed.
    Starting { task: Shared<Task<()>> },
    /// Installing or launching the server failed; holds the error message.
    Error(Arc<str>),
    /// The server is running and has been initialized.
    Running(RunningCopilotServer),
}
132
133impl CopilotServer {
134 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
135 let server = self.as_running()?;
136 if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
137 Ok(server)
138 } else {
139 Err(anyhow!("must sign in before using copilot"))
140 }
141 }
142
143 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
144 match self {
145 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
146 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
147 CopilotServer::Error(error) => Err(anyhow!(
148 "copilot was not started because of an error: {}",
149 error
150 )),
151 CopilotServer::Running(server) => Ok(server),
152 }
153 }
154}
155
/// State kept while the Copilot language server is running.
struct RunningCopilotServer {
    /// Handle to the running language server.
    lsp: Arc<LanguageServer>,
    /// Current sign-in state of the user.
    sign_in_status: SignInStatus,
    /// Buffers opened on the server via `textDocument/didOpen`, keyed by buffer entity id.
    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
}
161
/// Sign-in state tracked for a running Copilot server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Signed in; buffers are registered and completions may be requested.
    Authorized,
    /// The server reported the user as not authorized.
    Unauthorized,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt to show the user, once the server has provided one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// The in-flight sign-in task. It is shared so that repeated `sign_in` calls await
        /// the same flow.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Not signed in.
    SignedOut {
        /// Copied from the `awaiting_sign_in_after_start` flag the server was started with.
        /// Presumably set when a sign-in was requested before the server finished
        /// starting; confirm against callers in `sign_in.rs`.
        awaiting_signing_in: bool,
    },
}
174
/// Public, simplified view of Copilot's state, as returned by `Copilot::status`.
#[derive(Debug, Clone)]
pub enum Status {
    /// The language server is being installed and launched.
    Starting {
        /// Completes when the start attempt finishes.
        task: Shared<Task<()>>,
    },
    /// The language server failed to start; holds the error message.
    Error(Arc<str>),
    /// Copilot's language server is not running.
    Disabled,
    /// The server is running but no user is signed in.
    SignedOut {
        /// See `SignInStatus::SignedOut`.
        awaiting_signing_in: bool,
    },
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt to show the user, once available.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// The server reported the user as not authorized.
    Unauthorized,
    /// Signed in and authorized.
    Authorized,
}
191
192impl Status {
193 pub fn is_authorized(&self) -> bool {
194 matches!(self, Status::Authorized)
195 }
196
197 pub fn is_disabled(&self) -> bool {
198 matches!(self, Status::Disabled)
199 }
200}
201
/// Bookkeeping for a buffer that has been opened on the Copilot server.
struct RegisteredBuffer {
    /// Document URI used in LSP notifications (see `uri_for_buffer`).
    uri: lsp::Url,
    /// LSP language id the document was opened with (see `id_for_language`).
    language_id: String,
    /// Buffer contents as last synchronized with the server.
    snapshot: BufferSnapshot,
    /// LSP document version matching `snapshot`. Starts at 0 and is incremented each
    /// time a non-empty `didChange` is sent.
    snapshot_version: i32,
    /// Keeps the buffer event and buffer release subscriptions alive.
    _subscriptions: [gpui::Subscription; 2],
    /// Most recent change-reporting task. Each new report awaits the previous one, so
    /// changes reach the server in order.
    pending_buffer_change: Task<Option<()>>,
}
210
impl RegisteredBuffer {
    /// Sends the edits made to `buffer` since the last synchronized snapshot to the server
    /// as a `textDocument/didChange` notification.
    ///
    /// Returns a receiver that yields the `(document version, snapshot)` pair the server
    /// has once the changes are reported. If the buffer is unchanged since the last
    /// snapshot, the current pair is sent immediately. If the report is abandoned (user no
    /// longer signed in, buffer unregistered or released), the sender is dropped and the
    /// receiver resolves with a cancellation error.
    fn report_changes(
        &mut self,
        buffer: &Entity<Buffer>,
        cx: &mut Context<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Chain onto the previous report so change notifications are sent in order.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(async move |copilot, cx| {
                prev_pending_change.await;

                // Read the last synchronized version only after the previous report has
                // finished, because that report may have advanced it. Bail out if the user is
                // no longer signed in or the buffer has been unregistered.
                let old_version = copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()).ok()?;

                // Build the incremental LSP changes on the background executor. Each change
                // starts at the edit's position in the new snapshot and spans the extent of the
                // text it replaced (the edit's old range).
                let content_changes = cx
                    .background_spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        // Only bump the document version when something actually changed.
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    &lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .ok();
                        }
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
296
/// A single inline completion suggested by Copilot.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned identifier, used to report acceptance or rejection.
    pub uuid: String,
    /// Buffer range the completion applies to. This is the range the server reported,
    /// clipped to the synchronized snapshot.
    pub range: Range<Anchor>,
    /// The suggested text.
    pub text: String,
}
303
/// Entity (stored as a gpui global) that manages the Copilot language server, the user's
/// sign-in state, and the buffers synchronized with the server.
pub struct Copilot {
    /// Filesystem used when installing the language server.
    fs: Arc<dyn Fs>,
    /// Node runtime used to install and run the language server.
    node_runtime: NodeRuntime,
    /// Current lifecycle state of the language server.
    server: CopilotServer,
    /// Every buffer passed to `register_buffer`, held weakly, so buffers can be
    /// (re)registered or unregistered when the sign-in status changes.
    buffers: HashSet<WeakEntity<Buffer>>,
    /// Id given to the language server, reused when it is restarted (e.g. on reinstall).
    server_id: LanguageServerId,
    /// Keeps the app-quit hook that shuts the language server down alive.
    _subscription: gpui::Subscription,
}
312
/// Events emitted by [`Copilot`].
pub enum Event {
    /// The language server finished starting and is now running.
    CopilotLanguageServerStarted,
    /// A signed-in status was applied; the user is now authorized.
    CopilotAuthSignedIn,
    /// A signed-out status was applied.
    CopilotAuthSignedOut,
}

impl EventEmitter<Event> for Copilot {}
320
/// Wrapper that stores the [`Copilot`] entity as a gpui global.
struct GlobalCopilot(Entity<Copilot>);

impl Global for GlobalCopilot {}
324
325impl Copilot {
326 pub fn global(cx: &App) -> Option<Entity<Self>> {
327 cx.try_global::<GlobalCopilot>()
328 .map(|model| model.0.clone())
329 }
330
331 pub fn set_global(copilot: Entity<Self>, cx: &mut App) {
332 cx.set_global(GlobalCopilot(copilot));
333 }
334
335 fn start(
336 new_server_id: LanguageServerId,
337 fs: Arc<dyn Fs>,
338 node_runtime: NodeRuntime,
339 cx: &mut Context<Self>,
340 ) -> Self {
341 let mut this = Self {
342 server_id: new_server_id,
343 fs,
344 node_runtime,
345 server: CopilotServer::Disabled,
346 buffers: Default::default(),
347 _subscription: cx.on_app_quit(Self::shutdown_language_server),
348 };
349 this.start_copilot(true, false, cx);
350 cx.observe_global::<SettingsStore>(move |this, cx| this.start_copilot(true, false, cx))
351 .detach();
352 this
353 }
354
355 fn shutdown_language_server(
356 &mut self,
357 _cx: &mut Context<Self>,
358 ) -> impl Future<Output = ()> + use<> {
359 let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
360 CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
361 _ => None,
362 };
363
364 async move {
365 if let Some(shutdown) = shutdown {
366 shutdown.await;
367 }
368 }
369 }
370
371 fn start_copilot(
372 &mut self,
373 check_edit_prediction_provider: bool,
374 awaiting_sign_in_after_start: bool,
375 cx: &mut Context<Self>,
376 ) {
377 if !matches!(self.server, CopilotServer::Disabled) {
378 return;
379 }
380 let language_settings = all_language_settings(None, cx);
381 if check_edit_prediction_provider
382 && language_settings.edit_predictions.provider != EditPredictionProvider::Copilot
383 {
384 return;
385 }
386 let server_id = self.server_id;
387 let fs = self.fs.clone();
388 let node_runtime = self.node_runtime.clone();
389 let env = self.build_env(&language_settings.edit_predictions.copilot);
390 let start_task = cx
391 .spawn(async move |this, cx| {
392 Self::start_language_server(
393 server_id,
394 fs,
395 node_runtime,
396 env,
397 this,
398 awaiting_sign_in_after_start,
399 cx,
400 )
401 .await
402 })
403 .shared();
404 self.server = CopilotServer::Starting { task: start_task };
405 cx.notify();
406 }
407
408 fn build_env(&self, copilot_settings: &CopilotSettings) -> Option<HashMap<String, String>> {
409 let proxy_url = copilot_settings.proxy.clone()?;
410 let no_verify = copilot_settings.proxy_no_verify;
411 let http_or_https_proxy = if proxy_url.starts_with("http:") {
412 "HTTP_PROXY"
413 } else if proxy_url.starts_with("https:") {
414 "HTTPS_PROXY"
415 } else {
416 log::error!(
417 "Unsupported protocol scheme for language server proxy (must be http or https)"
418 );
419 return None;
420 };
421
422 let mut env = HashMap::default();
423 env.insert(http_or_https_proxy.to_string(), proxy_url);
424
425 if let Some(true) = no_verify {
426 env.insert("NODE_TLS_REJECT_UNAUTHORIZED".to_string(), "0".to_string());
427 };
428
429 Some(env)
430 }
431
432 #[cfg(any(test, feature = "test-support"))]
433 pub fn fake(cx: &mut gpui::TestAppContext) -> (Entity<Self>, lsp::FakeLanguageServer) {
434 use fs::FakeFs;
435 use lsp::FakeLanguageServer;
436 use node_runtime::NodeRuntime;
437
438 let (server, fake_server) = FakeLanguageServer::new(
439 LanguageServerId(0),
440 LanguageServerBinary {
441 path: "path/to/copilot".into(),
442 arguments: vec![],
443 env: None,
444 },
445 "copilot".into(),
446 Default::default(),
447 &mut cx.to_async(),
448 );
449 let node_runtime = NodeRuntime::unavailable();
450 let this = cx.new(|cx| Self {
451 server_id: LanguageServerId(0),
452 fs: FakeFs::new(cx.background_executor().clone()),
453 node_runtime,
454 server: CopilotServer::Running(RunningCopilotServer {
455 lsp: Arc::new(server),
456 sign_in_status: SignInStatus::Authorized,
457 registered_buffers: Default::default(),
458 }),
459 _subscription: cx.on_app_quit(Self::shutdown_language_server),
460 buffers: Default::default(),
461 });
462 (this, fake_server)
463 }
464
    /// Installs the Copilot language server if needed, launches it, and stores the outcome
    /// in `this`.
    ///
    /// On success the server is:
    /// 1. initialized,
    /// 2. queried for its sign-in status,
    /// 3. told about the editor,
    /// 4. stored as `CopilotServer::Running`.
    ///
    /// It is stored signed out, with `awaiting_signing_in` set from
    /// `awaiting_sign_in_after_start`. `CopilotLanguageServerStarted` is then emitted and the
    /// reported sign-in status is applied. On failure the error message is stored as
    /// `CopilotServer::Error`. Nothing is stored if `this` has been released.
    async fn start_language_server(
        new_server_id: LanguageServerId,
        fs: Arc<dyn Fs>,
        node_runtime: NodeRuntime,
        env: Option<HashMap<String, String>>,
        this: WeakEntity<Self>,
        awaiting_sign_in_after_start: bool,
        cx: &mut AsyncApp,
    ) {
        let start_language_server = async {
            let server_path = get_copilot_lsp(fs, node_runtime.clone()).await?;
            let node_path = node_runtime.binary_path().await?;
            // Run the server script with Node, communicating over stdio.
            let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
            let binary = LanguageServerBinary {
                path: node_path,
                arguments,
                env,
            };

            // The filesystem root is used as the server's root path.
            let root_path = if cfg!(target_os = "windows") {
                Path::new("C:/")
            } else {
                Path::new("/")
            };

            let server_name = LanguageServerName("copilot".into());
            let server = LanguageServer::new(
                Arc::new(Mutex::new(None)),
                new_server_id,
                server_name,
                binary,
                root_path,
                None,
                Default::default(),
                cx,
            )?;

            server
                .on_notification::<StatusNotification, _>(|_, _| { /* Silence the notification */ })
                .detach();

            let configuration = lsp::DidChangeConfigurationParams {
                settings: Default::default(),
            };

            // Editor/plugin identification. It is sent both as initialization options and
            // through the `SetEditorInfo` request below.
            let editor_info = request::SetEditorInfoParams {
                editor_info: request::EditorInfo {
                    name: "zed".into(),
                    version: env!("CARGO_PKG_VERSION").into(),
                },
                editor_plugin_info: request::EditorPluginInfo {
                    name: "zed-copilot".into(),
                    version: "0.0.1".into(),
                },
            };
            let editor_info_json = serde_json::to_value(&editor_info)?;

            let server = cx
                .update(|cx| {
                    let mut params = server.default_initialize_params(cx);
                    params.initialization_options = Some(editor_info_json);
                    server.initialize(params, configuration.into(), cx)
                })?
                .await?;

            let status = server
                .request::<request::CheckStatus>(request::CheckStatusParams {
                    local_checks_only: false,
                })
                .await?;

            server
                .request::<request::SetEditorInfo>(editor_info)
                .await?;

            anyhow::Ok((server, status))
        };

        let server = start_language_server.await;
        this.update(cx, |this, cx| {
            cx.notify();
            match server {
                Ok((server, status)) => {
                    this.server = CopilotServer::Running(RunningCopilotServer {
                        lsp: server,
                        sign_in_status: SignInStatus::SignedOut {
                            awaiting_signing_in: awaiting_sign_in_after_start,
                        },
                        registered_buffers: Default::default(),
                    });
                    cx.emit(Event::CopilotLanguageServerStarted);
                    // Apply the status the server reported during startup.
                    this.update_sign_in_status(status, cx);
                }
                Err(error) => {
                    this.server = CopilotServer::Error(error.to_string().into());
                    cx.notify()
                }
            }
        })
        .ok();
    }
566
    /// Starts a sign-in, or joins the one already in progress.
    ///
    /// - Already authorized: resolves immediately.
    /// - Sign-in in progress: waits on the in-flight task.
    /// - Signed out or unauthorized: asks the server to initiate sign-in. If the server
    ///   returns a device-flow prompt, the prompt is stored in `SignInStatus::SigningIn` so
    ///   the UI can show it, and the server is then asked to confirm using the prompt's user
    ///   code. The resulting status is applied; on any error the user is treated as signed out.
    ///
    /// Fails immediately if the server is not running.
    pub(crate) fn sign_in(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut { .. } | SignInStatus::Unauthorized { .. } => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(async move |this, cx| {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user: Some(user) })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the prompt so the UI can display it while we
                                        // wait for confirmation.
                                        this.update(cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        })?;
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    Err(Arc::new(error))
                                }
                            })?
                        })
                        .shared();
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            // The shared task's error is an `Arc<anyhow::Error>` (shared outputs must be
            // `Clone`); convert it back into a plain `anyhow::Error` for callers.
            cx.background_spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // TODO: if the server is still downloading, wait until the download is finished;
            // if it is stuck, surface that state to the user.
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
650
651 pub(crate) fn sign_out(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
652 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
653 match &self.server {
654 CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) => {
655 let server = server.clone();
656 cx.background_spawn(async move {
657 server
658 .request::<request::SignOut>(request::SignOutParams {})
659 .await?;
660 anyhow::Ok(())
661 })
662 }
663 CopilotServer::Disabled => cx.background_spawn(async {
664 clear_copilot_config_dir().await;
665 anyhow::Ok(())
666 }),
667 _ => Task::ready(Err(anyhow!("copilot hasn't started yet"))),
668 }
669 }
670
671 pub(crate) fn reinstall(&mut self, cx: &mut Context<Self>) -> Shared<Task<()>> {
672 let language_settings = all_language_settings(None, cx);
673 let env = self.build_env(&language_settings.edit_predictions.copilot);
674 let start_task = cx
675 .spawn({
676 let fs = self.fs.clone();
677 let node_runtime = self.node_runtime.clone();
678 let server_id = self.server_id;
679 async move |this, cx| {
680 clear_copilot_dir().await;
681 Self::start_language_server(server_id, fs, node_runtime, env, this, false, cx)
682 .await
683 }
684 })
685 .shared();
686
687 self.server = CopilotServer::Starting {
688 task: start_task.clone(),
689 };
690
691 cx.notify();
692
693 start_task
694 }
695
696 pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
697 if let CopilotServer::Running(server) = &self.server {
698 Some(&server.lsp)
699 } else {
700 None
701 }
702 }
703
    /// Tracks `buffer` and, if the user is authorized, opens it on the server.
    ///
    /// The buffer is always remembered (weakly), so it can be registered later when the
    /// sign-in status becomes authorized. Registration happens at most once per buffer. It
    /// sends `textDocument/didOpen` with version 0 and subscribes to the buffer's events and
    /// release.
    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
        let weak_buffer = buffer.downgrade();
        self.buffers.insert(weak_buffer.clone());

        if let CopilotServer::Running(RunningCopilotServer {
            lsp: server,
            sign_in_status: status,
            registered_buffers,
            ..
        }) = &mut self.server
        {
            // Buffers are only opened on the server while the user is authorized.
            if !matches!(status, SignInStatus::Authorized { .. }) {
                return;
            }

            registered_buffers
                .entry(buffer.entity_id())
                .or_insert_with(|| {
                    let uri: lsp::Url = uri_for_buffer(buffer, cx);
                    let language_id = id_for_language(buffer.read(cx).language());
                    let snapshot = buffer.read(cx).snapshot();
                    server
                        .notify::<lsp::notification::DidOpenTextDocument>(
                            &lsp::DidOpenTextDocumentParams {
                                text_document: lsp::TextDocumentItem {
                                    uri: uri.clone(),
                                    language_id: language_id.clone(),
                                    version: 0,
                                    text: snapshot.text(),
                                },
                            },
                        )
                        .ok();

                    RegisteredBuffer {
                        uri,
                        language_id,
                        snapshot,
                        snapshot_version: 0,
                        pending_buffer_change: Task::ready(Some(())),
                        _subscriptions: [
                            // Forward edits, saves, and file/language changes to the server.
                            cx.subscribe(buffer, |this, buffer, event, cx| {
                                this.handle_buffer_event(buffer, event, cx).log_err();
                            }),
                            // Once the buffer is released, stop tracking it and close it on the server.
                            cx.observe_release(buffer, move |this, _buffer, _cx| {
                                this.buffers.remove(&weak_buffer);
                                this.unregister_buffer(&weak_buffer);
                            }),
                        ],
                    }
                });
        }
    }
757
    /// Forwards relevant events of a registered buffer to the language server.
    ///
    /// - `Edited`: reports incremental changes via `RegisteredBuffer::report_changes`.
    /// - `Saved`: sends `textDocument/didSave`.
    /// - `FileHandleChanged` / `LanguageChanged`: if the buffer's URI or language id changed,
    ///   closes the old document. It then reopens it under the new URI and language id, using
    ///   the last synchronized snapshot and version.
    ///
    /// Events for unregistered buffers, or events received while the server isn't running,
    /// are ignored. Returns an error if sending a notification fails.
    fn handle_buffer_event(
        &mut self,
        buffer: Entity<Buffer>,
        event: &language::BufferEvent,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        if let Ok(server) = self.server.as_running() {
            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
            {
                match event {
                    language::BufferEvent::Edited => {
                        // Nobody needs the resulting snapshot here. The report is still sent,
                        // because its task is stored in `pending_buffer_change`.
                        drop(registered_buffer.report_changes(&buffer, cx));
                    }
                    language::BufferEvent::Saved => {
                        server
                            .lsp
                            .notify::<lsp::notification::DidSaveTextDocument>(
                                &lsp::DidSaveTextDocumentParams {
                                    text_document: lsp::TextDocumentIdentifier::new(
                                        registered_buffer.uri.clone(),
                                    ),
                                    text: None,
                                },
                            )?;
                    }
                    language::BufferEvent::FileHandleChanged
                    | language::BufferEvent::LanguageChanged => {
                        let new_language_id = id_for_language(buffer.read(cx).language());
                        let new_uri = uri_for_buffer(&buffer, cx);
                        if new_uri != registered_buffer.uri
                            || new_language_id != registered_buffer.language_id
                        {
                            // LSP identifies documents by URI and fixes their language at
                            // open time, so close and reopen the document with the new values.
                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
                            registered_buffer.language_id = new_language_id;
                            server
                                .lsp
                                .notify::<lsp::notification::DidCloseTextDocument>(
                                    &lsp::DidCloseTextDocumentParams {
                                        text_document: lsp::TextDocumentIdentifier::new(old_uri),
                                    },
                                )?;
                            server
                                .lsp
                                .notify::<lsp::notification::DidOpenTextDocument>(
                                    &lsp::DidOpenTextDocumentParams {
                                        text_document: lsp::TextDocumentItem::new(
                                            registered_buffer.uri.clone(),
                                            registered_buffer.language_id.clone(),
                                            registered_buffer.snapshot_version,
                                            registered_buffer.snapshot.text(),
                                        ),
                                    },
                                )?;
                        }
                    }
                    _ => {}
                }
            }
        }

        Ok(())
    }
820
821 fn unregister_buffer(&mut self, buffer: &WeakEntity<Buffer>) {
822 if let Ok(server) = self.server.as_running() {
823 if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
824 server
825 .lsp
826 .notify::<lsp::notification::DidCloseTextDocument>(
827 &lsp::DidCloseTextDocumentParams {
828 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
829 },
830 )
831 .ok();
832 }
833 }
834 }
835
836 pub fn completions<T>(
837 &mut self,
838 buffer: &Entity<Buffer>,
839 position: T,
840 cx: &mut Context<Self>,
841 ) -> Task<Result<Vec<Completion>>>
842 where
843 T: ToPointUtf16,
844 {
845 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
846 }
847
848 pub fn completions_cycling<T>(
849 &mut self,
850 buffer: &Entity<Buffer>,
851 position: T,
852 cx: &mut Context<Self>,
853 ) -> Task<Result<Vec<Completion>>>
854 where
855 T: ToPointUtf16,
856 {
857 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
858 }
859
860 pub fn accept_completion(
861 &mut self,
862 completion: &Completion,
863 cx: &mut Context<Self>,
864 ) -> Task<Result<()>> {
865 let server = match self.server.as_authenticated() {
866 Ok(server) => server,
867 Err(error) => return Task::ready(Err(error)),
868 };
869 let request =
870 server
871 .lsp
872 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
873 uuid: completion.uuid.clone(),
874 });
875 cx.background_spawn(async move {
876 request.await?;
877 Ok(())
878 })
879 }
880
881 pub fn discard_completions(
882 &mut self,
883 completions: &[Completion],
884 cx: &mut Context<Self>,
885 ) -> Task<Result<()>> {
886 let server = match self.server.as_authenticated() {
887 Ok(server) => server,
888 Err(_) => return Task::ready(Ok(())),
889 };
890 let request =
891 server
892 .lsp
893 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
894 uuids: completions
895 .iter()
896 .map(|completion| completion.uuid.clone())
897 .collect(),
898 });
899 cx.background_spawn(async move {
900 request.await?;
901 Ok(())
902 })
903 }
904
    /// Shared implementation of `completions` and `completions_cycling`.
    ///
    /// Ensures `buffer` is registered and reports any unsent edits. It then sends request `R`
    /// for `position` together with the buffer's indentation settings. Completion ranges
    /// returned by the server are clipped to the synchronized snapshot and converted to
    /// anchors. Fails if the user is not signed in.
    fn request_completions<R, T>(
        &mut self,
        buffer: &Entity<Buffer>,
        position: T,
        cx: &mut Context<Self>,
    ) -> Task<Result<Vec<Completion>>>
    where
        R: 'static
            + lsp::request::Request<
                Params = request::GetCompletionsParams,
                Result = request::GetCompletionsResult,
            >,
        T: ToPointUtf16,
    {
        self.register_buffer(buffer, cx);

        let server = match self.server.as_authenticated() {
            Ok(server) => server,
            Err(error) => return Task::ready(Err(error)),
        };
        let lsp = server.lsp.clone();
        // `register_buffer` above registers the buffer whenever the user is authorized, which
        // `as_authenticated` has just confirmed, so the entry exists.
        let registered_buffer = server
            .registered_buffers
            .get_mut(&buffer.entity_id())
            .unwrap();
        // Flush pending edits. The request below waits for them, so the document version it
        // sends matches what the server has.
        let snapshot = registered_buffer.report_changes(buffer, cx);
        let buffer = buffer.read(cx);
        let uri = registered_buffer.uri.clone();
        let position = position.to_point_utf16(buffer);
        let settings = language_settings(
            buffer.language_at(position).map(|l| l.name()),
            buffer.file(),
            cx,
        );
        let tab_size = settings.tab_size;
        let hard_tabs = settings.hard_tabs;
        let relative_path = buffer
            .file()
            .map(|file| file.path().to_path_buf())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let (version, snapshot) = snapshot.await?;
            let result = lsp
                .request::<R>(request::GetCompletionsParams {
                    doc: request::GetCompletionsDocument {
                        uri,
                        tab_size: tab_size.into(),
                        indent_size: 1,
                        insert_spaces: !hard_tabs,
                        relative_path: relative_path.to_string_lossy().into(),
                        position: point_to_lsp(position),
                        version: version.try_into().unwrap(),
                    },
                })
                .await?;
            let completions = result
                .completions
                .into_iter()
                .map(|completion| {
                    // Clip the server's range to the snapshot before anchoring it.
                    let start = snapshot
                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
                    let end =
                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
                    Completion {
                        uuid: completion.uuid,
                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
                        text: completion.text,
                    }
                })
                .collect();
            anyhow::Ok(completions)
        })
    }
979
980 pub fn status(&self) -> Status {
981 match &self.server {
982 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
983 CopilotServer::Disabled => Status::Disabled,
984 CopilotServer::Error(error) => Status::Error(error.clone()),
985 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
986 match sign_in_status {
987 SignInStatus::Authorized { .. } => Status::Authorized,
988 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
989 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
990 prompt: prompt.clone(),
991 },
992 SignInStatus::SignedOut {
993 awaiting_signing_in,
994 } => Status::SignedOut {
995 awaiting_signing_in: *awaiting_signing_in,
996 },
997 }
998 }
999 }
1000 }
1001
    /// Applies a sign-in status reported by the server.
    ///
    /// - Signed in (`Ok` with a user, `MaybeOk`, `AlreadySignedIn`): marks the user
    ///   authorized, emits `CopilotAuthSignedIn`, and registers every tracked buffer that is
    ///   still alive.
    /// - `NotAuthorized`: marks the user unauthorized and unregisters all tracked buffers.
    /// - Signed out (`Ok` without a user, `NotSignedIn`): marks the user signed out, emits
    ///   `CopilotAuthSignedOut`, and unregisters all tracked buffers. An existing `SignedOut`
    ///   state is kept as-is, preserving its `awaiting_signing_in` flag.
    ///
    /// Released buffers are always pruned from the tracked set first; nothing else happens
    /// if the server is not running.
    fn update_sign_in_status(&mut self, lsp_status: request::SignInStatus, cx: &mut Context<Self>) {
        self.buffers.retain(|buffer| buffer.is_upgradable());

        if let Ok(server) = self.server.as_running() {
            match lsp_status {
                request::SignInStatus::Ok { user: Some(_) }
                | request::SignInStatus::MaybeOk { .. }
                | request::SignInStatus::AlreadySignedIn { .. } => {
                    server.sign_in_status = SignInStatus::Authorized;
                    cx.emit(Event::CopilotAuthSignedIn);
                    // Iterate over a copy: `register_buffer` inserts into `self.buffers`.
                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
                        if let Some(buffer) = buffer.upgrade() {
                            self.register_buffer(&buffer, cx);
                        }
                    }
                }
                request::SignInStatus::NotAuthorized { .. } => {
                    server.sign_in_status = SignInStatus::Unauthorized;
                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
                        self.unregister_buffer(&buffer);
                    }
                }
                request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
                    if !matches!(server.sign_in_status, SignInStatus::SignedOut { .. }) {
                        server.sign_in_status = SignInStatus::SignedOut {
                            awaiting_signing_in: false,
                        };
                    }
                    cx.emit(Event::CopilotAuthSignedOut);
                    for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
                        self.unregister_buffer(&buffer);
                    }
                }
            }

            cx.notify();
        }
    }
1040}
1041
1042fn id_for_language(language: Option<&Arc<Language>>) -> String {
1043 language
1044 .map(|language| language.lsp_id())
1045 .unwrap_or_else(|| "plaintext".to_string())
1046}
1047
1048fn uri_for_buffer(buffer: &Entity<Buffer>, cx: &App) -> lsp::Url {
1049 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
1050 lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
1051 } else {
1052 format!("buffer://{}", buffer.entity_id()).parse().unwrap()
1053 }
1054}
1055
1056async fn clear_copilot_dir() {
1057 remove_matching(paths::copilot_dir(), |_| true).await
1058}
1059
1060async fn clear_copilot_config_dir() {
1061 remove_matching(copilot_chat::copilot_chat_config_dir(), |_| true).await
1062}
1063
1064async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::Result<PathBuf> {
1065 const PACKAGE_NAME: &str = "@github/copilot-language-server";
1066 const SERVER_PATH: &str =
1067 "node_modules/@github/copilot-language-server/dist/language-server.js";
1068
1069 let latest_version = node_runtime
1070 .npm_package_latest_version(PACKAGE_NAME)
1071 .await?;
1072 let server_path = paths::copilot_dir().join(SERVER_PATH);
1073
1074 fs.create_dir(paths::copilot_dir()).await?;
1075
1076 let should_install = node_runtime
1077 .should_install_npm_package(
1078 PACKAGE_NAME,
1079 &server_path,
1080 paths::copilot_dir(),
1081 &latest_version,
1082 )
1083 .await;
1084 if should_install {
1085 node_runtime
1086 .npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)])
1087 .await?;
1088 }
1089
1090 Ok(server_path)
1091}
1092
// Unit tests for Copilot's buffer registration and its text-document
// synchronization with a fake Copilot language server.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use util::path;

    /// Walks two buffers through their lifecycle and asserts the exact sequence
    /// of text-document notifications the fake language server receives:
    /// - open on registration;
    /// - change on edit;
    /// - close + reopen when the backing file changes;
    /// - close on sign-out;
    /// - reopen on sign-in;
    /// - close when a buffer is dropped.
    ///
    /// `iterations = 10` runs the test repeatedly. Presumably each iteration
    /// uses a different scheduler seed; TODO confirm against `gpui::test`.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(cx: &mut TestAppContext) {
        let (copilot, mut lsp) = Copilot::fake(cx);

        // A buffer with no backing file is addressed by a `buffer://<entity id>` URI.
        let buffer_1 = cx.new(|cx| Buffer::local("Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64())
            .parse()
            .unwrap();
        // Registering a buffer opens it on the server as `plaintext` at version 0.
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        // A second registered buffer gets its own `buffer://` URI and open notification.
        let buffer_2 = cx.new(|cx| Buffer::local("Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // Edits are forwarded as incremental changes. Inserting " world" at
        // offset 5 bumps the version to 1 and reports an empty range at (0, 5).
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        // Attaching a file closes the document under its old `buffer://` URI and
        // reopens it under the file's URI. The reopen keeps the current version
        // (1) and the current text.
        buffer_1.update(cx, |buffer, cx| {
            buffer.file_updated(
                Arc::new(File {
                    abs_path: path!("/root/child/buffer-1").into(),
                    path: Path::new("child/buffer-1").into(),
                }),
                cx,
            )
        });
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        // From here on, buffer_1 is addressed by the URI built from its absolute path.
        let buffer_1_uri = lsp::Url::from_file_path(path!("/root/child/buffer-1")).unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        // The fake server answers the `SignOut` request with an empty success result.
        lsp.set_request_handler::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        // The close notifications are expected in registration order: buffer_1, then buffer_2.
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        // The fake server answers `SignInInitiate` with `AlreadySignedIn`.
        lsp.set_request_handler::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();

        // Reopened documents start again at version 0 with each buffer's current
        // text. They are expected in registration order.
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    /// Minimal `language::File` stub that gives a test buffer an on-disk location.
    ///
    /// Methods whose body is `unimplemented!()` panic if called. Presumably the
    /// code under test never reaches them.
    struct File {
        // Absolute path, returned by `LocalFile::abs_path`.
        abs_path: PathBuf,
        // Relative path (e.g. `child/buffer-1` above), returned by `File::path`.
        path: Arc<Path>,
    }

    impl language::File for File {
        // Exposes this stub as a local file via the `LocalFile` impl below.
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        // Always reports the file as present on disk, with a fixed mtime of
        // 100 seconds + 42 nanoseconds.
        fn disk_state(&self) -> language::DiskState {
            language::DiskState::Present {
                mtime: ::fs::MTime::from_seconds_and_nanos(100, 42),
            }
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        // Not implemented by the stub; panics if called.
        fn full_path(&self, _: &App) -> PathBuf {
            unimplemented!()
        }

        // Not implemented by the stub; panics if called.
        fn file_name<'a>(&'a self, _: &'a App) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        // Not implemented by the stub; panics if called.
        fn to_proto(&self, _: &App) -> rpc::proto::File {
            unimplemented!()
        }

        // Every stub file reports worktree id 0.
        fn worktree_id(&self, _: &App) -> settings::WorktreeId {
            settings::WorktreeId::from_usize(0)
        }

        // The stub is never marked private.
        fn is_private(&self) -> bool {
            false
        }
    }

    impl language::LocalFile for File {
        // Returns a copy of the stored absolute path.
        fn abs_path(&self, _: &App) -> PathBuf {
            self.abs_path.clone()
        }

        // Not implemented by the stub; panics if called.
        fn load(&self, _: &App) -> Task<Result<String>> {
            unimplemented!()
        }

        // Not implemented by the stub; panics if called.
        fn load_bytes(&self, _cx: &App) -> Task<Result<Vec<u8>>> {
            unimplemented!()
        }
    }
}
1311
/// Installs `env_logger` for the test binary, but only when `RUST_LOG` is set
/// (and readable as Unicode).
///
/// `#[ctor::ctor]` runs this at binary load time, before any test executes.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Guard clause: with no readable `RUST_LOG`, leave logging uninitialized.
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}