1pub mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use collections::HashMap;
8use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{
10 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
11};
12use language::{
13 point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
14 ToPointUtf16,
15};
16use log::{debug, error};
17use lsp::{LanguageServer, LanguageServerId};
18use node_runtime::NodeRuntime;
19use request::{LogMessage, StatusNotification};
20use settings::Settings;
21use smol::{fs, io::BufReader, stream::StreamExt};
22use std::{
23 ffi::OsString,
24 mem,
25 ops::Range,
26 path::{Path, PathBuf},
27 pin::Pin,
28 sync::Arc,
29};
30use util::{
31 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
32};
33
34const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
35actions!(copilot_auth, [SignIn, SignOut]);
36
37const COPILOT_NAMESPACE: &'static str = "copilot";
38actions!(
39 copilot,
40 [Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
41);
42
43pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
44 let copilot = cx.add_model({
45 let node_runtime = node_runtime.clone();
46 move |cx| Copilot::start(http, node_runtime, cx)
47 });
48 cx.set_global(copilot.clone());
49
50 //////////////////////////////////////
51 // SUBSCRIBE TO COPILOT EVENTS HERE //
52 //////////////////////////////////////
53
54 cx.observe(&copilot, |handle, cx| {
55 let status = handle.read(cx).status();
56 cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
57 match status {
58 Status::Disabled => {
59 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
60 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
61 }
62 Status::Authorized => {
63 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
64 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
65 }
66 _ => {
67 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
68 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
69 }
70 }
71 });
72 })
73 .detach();
74
75 sign_in::init(cx);
76 cx.add_global_action(|_: &SignIn, cx| {
77 if let Some(copilot) = Copilot::global(cx) {
78 copilot
79 .update(cx, |copilot, cx| copilot.sign_in(cx))
80 .detach_and_log_err(cx);
81 }
82 });
83 cx.add_global_action(|_: &SignOut, cx| {
84 if let Some(copilot) = Copilot::global(cx) {
85 copilot
86 .update(cx, |copilot, cx| copilot.sign_out(cx))
87 .detach_and_log_err(cx);
88 }
89 });
90
91 cx.add_global_action(|_: &Reinstall, cx| {
92 if let Some(copilot) = Copilot::global(cx) {
93 copilot
94 .update(cx, |copilot, cx| copilot.reinstall(cx))
95 .detach();
96 }
97 });
98}
99
100enum CopilotServer {
101 Disabled,
102 Starting { task: Shared<Task<()>> },
103 Error(Arc<str>),
104 Running(RunningCopilotServer),
105}
106
107impl CopilotServer {
108 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
109 let server = self.as_running()?;
110 if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
111 Ok(server)
112 } else {
113 Err(anyhow!("must sign in before using copilot"))
114 }
115 }
116
117 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
118 match self {
119 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
120 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
121 CopilotServer::Error(error) => Err(anyhow!(
122 "copilot was not started because of an error: {}",
123 error
124 )),
125 CopilotServer::Running(server) => Ok(server),
126 }
127 }
128}
129
130struct RunningCopilotServer {
131 lsp: Arc<LanguageServer>,
132 sign_in_status: SignInStatus,
133 registered_buffers: HashMap<u64, RegisteredBuffer>,
134}
135
136#[derive(Clone, Debug)]
137enum SignInStatus {
138 Authorized,
139 Unauthorized,
140 SigningIn {
141 prompt: Option<request::PromptUserDeviceFlow>,
142 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
143 },
144 SignedOut,
145}
146
147#[derive(Debug, Clone)]
148pub enum Status {
149 Starting {
150 task: Shared<Task<()>>,
151 },
152 Error(Arc<str>),
153 Disabled,
154 SignedOut,
155 SigningIn {
156 prompt: Option<request::PromptUserDeviceFlow>,
157 },
158 Unauthorized,
159 Authorized,
160}
161
162impl Status {
163 pub fn is_authorized(&self) -> bool {
164 matches!(self, Status::Authorized)
165 }
166}
167
168struct RegisteredBuffer {
169 id: u64,
170 uri: lsp::Url,
171 language_id: String,
172 snapshot: BufferSnapshot,
173 snapshot_version: i32,
174 _subscriptions: [gpui::Subscription; 2],
175 pending_buffer_change: Task<Option<()>>,
176}
177
178impl RegisteredBuffer {
179 fn report_changes(
180 &mut self,
181 buffer: &ModelHandle<Buffer>,
182 cx: &mut ModelContext<Copilot>,
183 ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
184 let id = self.id;
185 let (done_tx, done_rx) = oneshot::channel();
186
187 if buffer.read(cx).version() == self.snapshot.version {
188 let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
189 } else {
190 let buffer = buffer.downgrade();
191 let prev_pending_change =
192 mem::replace(&mut self.pending_buffer_change, Task::ready(None));
193 self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move {
194 prev_pending_change.await;
195
196 let old_version = copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
197 let server = copilot.server.as_authenticated().log_err()?;
198 let buffer = server.registered_buffers.get_mut(&id)?;
199 Some(buffer.snapshot.version.clone())
200 })?;
201 let new_snapshot = buffer
202 .upgrade(&cx)?
203 .read_with(&cx, |buffer, _| buffer.snapshot());
204
205 let content_changes = cx
206 .background()
207 .spawn({
208 let new_snapshot = new_snapshot.clone();
209 async move {
210 new_snapshot
211 .edits_since::<(PointUtf16, usize)>(&old_version)
212 .map(|edit| {
213 let edit_start = edit.new.start.0;
214 let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
215 let new_text = new_snapshot
216 .text_for_range(edit.new.start.1..edit.new.end.1)
217 .collect();
218 lsp::TextDocumentContentChangeEvent {
219 range: Some(lsp::Range::new(
220 point_to_lsp(edit_start),
221 point_to_lsp(edit_end),
222 )),
223 range_length: None,
224 text: new_text,
225 }
226 })
227 .collect::<Vec<_>>()
228 }
229 })
230 .await;
231
232 copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
233 let server = copilot.server.as_authenticated().log_err()?;
234 let buffer = server.registered_buffers.get_mut(&id)?;
235 if !content_changes.is_empty() {
236 buffer.snapshot_version += 1;
237 buffer.snapshot = new_snapshot;
238 server
239 .lsp
240 .notify::<lsp::notification::DidChangeTextDocument>(
241 lsp::DidChangeTextDocumentParams {
242 text_document: lsp::VersionedTextDocumentIdentifier::new(
243 buffer.uri.clone(),
244 buffer.snapshot_version,
245 ),
246 content_changes,
247 },
248 )
249 .log_err();
250 }
251 let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
252 Some(())
253 })?;
254
255 Some(())
256 });
257 }
258
259 done_rx
260 }
261}
262
263#[derive(Debug)]
264pub struct Completion {
265 uuid: String,
266 pub range: Range<Anchor>,
267 pub text: String,
268}
269
270pub struct Copilot {
271 http: Arc<dyn HttpClient>,
272 node_runtime: Arc<NodeRuntime>,
273 server: CopilotServer,
274 buffers: HashMap<u64, WeakModelHandle<Buffer>>,
275}
276
/// Events emitted by the [`Copilot`] model.
pub enum Event {
    /// Emitted by `accept_completion` for the accepted completion.
    CompletionAccepted {
        uuid: String,
        file_type: Option<Arc<str>>,
    },
    /// Emitted by `discard_completions` with the uuids of all discarded completions.
    CompletionsDiscarded {
        uuids: Vec<String>,
        file_type: Option<Arc<str>>,
    },
}
287
288impl Entity for Copilot {
289 type Event = Event;
290
291 fn app_will_quit(
292 &mut self,
293 _: &mut AppContext,
294 ) -> Option<Pin<Box<dyn 'static + Future<Output = ()>>>> {
295 match mem::replace(&mut self.server, CopilotServer::Disabled) {
296 CopilotServer::Running(server) => Some(Box::pin(async move {
297 if let Some(shutdown) = server.lsp.shutdown() {
298 shutdown.await;
299 }
300 })),
301 _ => None,
302 }
303 }
304}
305
306impl Copilot {
307 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
308 if cx.has_global::<ModelHandle<Self>>() {
309 Some(cx.global::<ModelHandle<Self>>().clone())
310 } else {
311 None
312 }
313 }
314
315 fn start(
316 http: Arc<dyn HttpClient>,
317 node_runtime: Arc<NodeRuntime>,
318 cx: &mut ModelContext<Self>,
319 ) -> Self {
320 cx.observe_global::<Settings, _>({
321 let http = http.clone();
322 let node_runtime = node_runtime.clone();
323 move |this, cx| {
324 if cx.global::<Settings>().features.copilot {
325 if matches!(this.server, CopilotServer::Disabled) {
326 let start_task = cx
327 .spawn({
328 let http = http.clone();
329 let node_runtime = node_runtime.clone();
330 move |this, cx| {
331 Self::start_language_server(http, node_runtime, this, cx)
332 }
333 })
334 .shared();
335 this.server = CopilotServer::Starting { task: start_task };
336 cx.notify();
337 }
338 } else {
339 this.server = CopilotServer::Disabled;
340 cx.notify();
341 }
342 }
343 })
344 .detach();
345
346 if cx.global::<Settings>().features.copilot {
347 let start_task = cx
348 .spawn({
349 let http = http.clone();
350 let node_runtime = node_runtime.clone();
351 move |this, cx| async {
352 Self::start_language_server(http, node_runtime, this, cx).await
353 }
354 })
355 .shared();
356
357 Self {
358 http,
359 node_runtime,
360 server: CopilotServer::Starting { task: start_task },
361 buffers: Default::default(),
362 }
363 } else {
364 Self {
365 http,
366 node_runtime,
367 server: CopilotServer::Disabled,
368 buffers: Default::default(),
369 }
370 }
371 }
372
373 #[cfg(any(test, feature = "test-support"))]
374 pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
375 let (server, fake_server) =
376 LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
377 let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
378 let this = cx.add_model(|cx| Self {
379 http: http.clone(),
380 node_runtime: NodeRuntime::new(http, cx.background().clone()),
381 server: CopilotServer::Running(RunningCopilotServer {
382 lsp: Arc::new(server),
383 sign_in_status: SignInStatus::Authorized,
384 registered_buffers: Default::default(),
385 }),
386 buffers: Default::default(),
387 });
388 (this, fake_server)
389 }
390
391 fn start_language_server(
392 http: Arc<dyn HttpClient>,
393 node_runtime: Arc<NodeRuntime>,
394 this: ModelHandle<Self>,
395 mut cx: AsyncAppContext,
396 ) -> impl Future<Output = ()> {
397 async move {
398 let start_language_server = async {
399 let server_path = get_copilot_lsp(http).await?;
400 let node_path = node_runtime.binary_path().await?;
401 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
402 let server = LanguageServer::new(
403 LanguageServerId(0),
404 &node_path,
405 arguments,
406 Path::new("/"),
407 None,
408 cx.clone(),
409 )?;
410
411 server
412 .on_notification::<LogMessage, _>(|params, _cx| {
413 match params.level {
414 // Copilot is pretty agressive about logging
415 0 => debug!("copilot: {}", params.message),
416 1 => debug!("copilot: {}", params.message),
417 _ => error!("copilot: {}", params.message),
418 }
419
420 debug!("copilot metadata: {}", params.metadata_str);
421 debug!("copilot extra: {:?}", params.extra);
422 })
423 .detach();
424
425 server
426 .on_notification::<StatusNotification, _>(
427 |_, _| { /* Silence the notification */ },
428 )
429 .detach();
430
431 let server = server.initialize(Default::default()).await?;
432
433 let status = server
434 .request::<request::CheckStatus>(request::CheckStatusParams {
435 local_checks_only: false,
436 })
437 .await?;
438
439 server
440 .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
441 editor_info: request::EditorInfo {
442 name: "zed".into(),
443 version: env!("CARGO_PKG_VERSION").into(),
444 },
445 editor_plugin_info: request::EditorPluginInfo {
446 name: "zed-copilot".into(),
447 version: "0.0.1".into(),
448 },
449 })
450 .await?;
451
452 anyhow::Ok((server, status))
453 };
454
455 let server = start_language_server.await;
456 this.update(&mut cx, |this, cx| {
457 cx.notify();
458 match server {
459 Ok((server, status)) => {
460 this.server = CopilotServer::Running(RunningCopilotServer {
461 lsp: server,
462 sign_in_status: SignInStatus::SignedOut,
463 registered_buffers: Default::default(),
464 });
465 this.update_sign_in_status(status, cx);
466 }
467 Err(error) => {
468 this.server = CopilotServer::Error(error.to_string().into());
469 cx.notify()
470 }
471 }
472 })
473 }
474 }
475
476 pub fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
477 if let CopilotServer::Running(server) = &mut self.server {
478 let task = match &server.sign_in_status {
479 SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
480 SignInStatus::SigningIn { task, .. } => {
481 cx.notify();
482 task.clone()
483 }
484 SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => {
485 let lsp = server.lsp.clone();
486 let task = cx
487 .spawn(|this, mut cx| async move {
488 let sign_in = async {
489 let sign_in = lsp
490 .request::<request::SignInInitiate>(
491 request::SignInInitiateParams {},
492 )
493 .await?;
494 match sign_in {
495 request::SignInInitiateResult::AlreadySignedIn { user } => {
496 Ok(request::SignInStatus::Ok { user })
497 }
498 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
499 this.update(&mut cx, |this, cx| {
500 if let CopilotServer::Running(RunningCopilotServer {
501 sign_in_status: status,
502 ..
503 }) = &mut this.server
504 {
505 if let SignInStatus::SigningIn {
506 prompt: prompt_flow,
507 ..
508 } = status
509 {
510 *prompt_flow = Some(flow.clone());
511 cx.notify();
512 }
513 }
514 });
515 let response = lsp
516 .request::<request::SignInConfirm>(
517 request::SignInConfirmParams {
518 user_code: flow.user_code,
519 },
520 )
521 .await?;
522 Ok(response)
523 }
524 }
525 };
526
527 let sign_in = sign_in.await;
528 this.update(&mut cx, |this, cx| match sign_in {
529 Ok(status) => {
530 this.update_sign_in_status(status, cx);
531 Ok(())
532 }
533 Err(error) => {
534 this.update_sign_in_status(
535 request::SignInStatus::NotSignedIn,
536 cx,
537 );
538 Err(Arc::new(error))
539 }
540 })
541 })
542 .shared();
543 server.sign_in_status = SignInStatus::SigningIn {
544 prompt: None,
545 task: task.clone(),
546 };
547 cx.notify();
548 task
549 }
550 };
551
552 cx.foreground()
553 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
554 } else {
555 // If we're downloading, wait until download is finished
556 // If we're in a stuck state, display to the user
557 Task::ready(Err(anyhow!("copilot hasn't started yet")))
558 }
559 }
560
561 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
562 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
563 if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
564 let server = server.clone();
565 cx.background().spawn(async move {
566 server
567 .request::<request::SignOut>(request::SignOutParams {})
568 .await?;
569 anyhow::Ok(())
570 })
571 } else {
572 Task::ready(Err(anyhow!("copilot hasn't started yet")))
573 }
574 }
575
576 pub fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
577 let start_task = cx
578 .spawn({
579 let http = self.http.clone();
580 let node_runtime = self.node_runtime.clone();
581 move |this, cx| async move {
582 clear_copilot_dir().await;
583 Self::start_language_server(http, node_runtime, this, cx).await
584 }
585 })
586 .shared();
587
588 self.server = CopilotServer::Starting {
589 task: start_task.clone(),
590 };
591
592 cx.notify();
593
594 cx.foreground().spawn(start_task)
595 }
596
597 pub fn register_buffer(&mut self, buffer: &ModelHandle<Buffer>, cx: &mut ModelContext<Self>) {
598 let buffer_id = buffer.read(cx).remote_id();
599 self.buffers.insert(buffer_id, buffer.downgrade());
600
601 if let CopilotServer::Running(RunningCopilotServer {
602 lsp: server,
603 sign_in_status: status,
604 registered_buffers,
605 ..
606 }) = &mut self.server
607 {
608 if !matches!(status, SignInStatus::Authorized { .. }) {
609 return;
610 }
611
612 let buffer_id = buffer.read(cx).remote_id();
613 registered_buffers.entry(buffer_id).or_insert_with(|| {
614 let uri: lsp::Url = uri_for_buffer(buffer, cx);
615 let language_id = id_for_language(buffer.read(cx).language());
616 let snapshot = buffer.read(cx).snapshot();
617 server
618 .notify::<lsp::notification::DidOpenTextDocument>(
619 lsp::DidOpenTextDocumentParams {
620 text_document: lsp::TextDocumentItem {
621 uri: uri.clone(),
622 language_id: language_id.clone(),
623 version: 0,
624 text: snapshot.text(),
625 },
626 },
627 )
628 .log_err();
629
630 RegisteredBuffer {
631 id: buffer_id,
632 uri,
633 language_id,
634 snapshot,
635 snapshot_version: 0,
636 pending_buffer_change: Task::ready(Some(())),
637 _subscriptions: [
638 cx.subscribe(buffer, |this, buffer, event, cx| {
639 this.handle_buffer_event(buffer, event, cx).log_err();
640 }),
641 cx.observe_release(buffer, move |this, _buffer, _cx| {
642 this.buffers.remove(&buffer_id);
643 this.unregister_buffer(buffer_id);
644 }),
645 ],
646 }
647 });
648 }
649 }
650
651 fn handle_buffer_event(
652 &mut self,
653 buffer: ModelHandle<Buffer>,
654 event: &language::Event,
655 cx: &mut ModelContext<Self>,
656 ) -> Result<()> {
657 if let Ok(server) = self.server.as_running() {
658 let buffer_id = buffer.read(cx).remote_id();
659 if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer_id) {
660 match event {
661 language::Event::Edited => {
662 let _ = registered_buffer.report_changes(&buffer, cx);
663 }
664 language::Event::Saved => {
665 server
666 .lsp
667 .notify::<lsp::notification::DidSaveTextDocument>(
668 lsp::DidSaveTextDocumentParams {
669 text_document: lsp::TextDocumentIdentifier::new(
670 registered_buffer.uri.clone(),
671 ),
672 text: None,
673 },
674 )?;
675 }
676 language::Event::FileHandleChanged | language::Event::LanguageChanged => {
677 let new_language_id = id_for_language(buffer.read(cx).language());
678 let new_uri = uri_for_buffer(&buffer, cx);
679 if new_uri != registered_buffer.uri
680 || new_language_id != registered_buffer.language_id
681 {
682 let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
683 registered_buffer.language_id = new_language_id;
684 server
685 .lsp
686 .notify::<lsp::notification::DidCloseTextDocument>(
687 lsp::DidCloseTextDocumentParams {
688 text_document: lsp::TextDocumentIdentifier::new(old_uri),
689 },
690 )?;
691 server
692 .lsp
693 .notify::<lsp::notification::DidOpenTextDocument>(
694 lsp::DidOpenTextDocumentParams {
695 text_document: lsp::TextDocumentItem::new(
696 registered_buffer.uri.clone(),
697 registered_buffer.language_id.clone(),
698 registered_buffer.snapshot_version,
699 registered_buffer.snapshot.text(),
700 ),
701 },
702 )?;
703 }
704 }
705 _ => {}
706 }
707 }
708 }
709
710 Ok(())
711 }
712
713 fn unregister_buffer(&mut self, buffer_id: u64) {
714 if let Ok(server) = self.server.as_running() {
715 if let Some(buffer) = server.registered_buffers.remove(&buffer_id) {
716 server
717 .lsp
718 .notify::<lsp::notification::DidCloseTextDocument>(
719 lsp::DidCloseTextDocumentParams {
720 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
721 },
722 )
723 .log_err();
724 }
725 }
726 }
727
728 pub fn completions<T>(
729 &mut self,
730 buffer: &ModelHandle<Buffer>,
731 position: T,
732 cx: &mut ModelContext<Self>,
733 ) -> Task<Result<Vec<Completion>>>
734 where
735 T: ToPointUtf16,
736 {
737 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
738 }
739
740 pub fn completions_cycling<T>(
741 &mut self,
742 buffer: &ModelHandle<Buffer>,
743 position: T,
744 cx: &mut ModelContext<Self>,
745 ) -> Task<Result<Vec<Completion>>>
746 where
747 T: ToPointUtf16,
748 {
749 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
750 }
751
752 pub fn accept_completion(
753 &mut self,
754 completion: &Completion,
755 file_type: Option<Arc<str>>,
756 cx: &mut ModelContext<Self>,
757 ) -> Task<Result<()>> {
758 let server = match self.server.as_authenticated() {
759 Ok(server) => server,
760 Err(error) => return Task::ready(Err(error)),
761 };
762
763 cx.emit(Event::CompletionAccepted {
764 uuid: completion.uuid.clone(),
765 file_type,
766 });
767
768 let request =
769 server
770 .lsp
771 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
772 uuid: completion.uuid.clone(),
773 });
774
775 cx.background().spawn(async move {
776 request.await?;
777 Ok(())
778 })
779 }
780
781 pub fn discard_completions(
782 &mut self,
783 completions: &[Completion],
784 file_type: Option<Arc<str>>,
785 cx: &mut ModelContext<Self>,
786 ) -> Task<Result<()>> {
787 let server = match self.server.as_authenticated() {
788 Ok(server) => server,
789 Err(error) => return Task::ready(Err(error)),
790 };
791
792 cx.emit(Event::CompletionsDiscarded {
793 uuids: completions
794 .iter()
795 .map(|completion| completion.uuid.clone())
796 .collect(),
797 file_type: file_type.clone(),
798 });
799
800 let request =
801 server
802 .lsp
803 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
804 uuids: completions
805 .iter()
806 .map(|completion| completion.uuid.clone())
807 .collect(),
808 });
809
810 cx.background().spawn(async move {
811 request.await?;
812 Ok(())
813 })
814 }
815
816 fn request_completions<R, T>(
817 &mut self,
818 buffer: &ModelHandle<Buffer>,
819 position: T,
820 cx: &mut ModelContext<Self>,
821 ) -> Task<Result<Vec<Completion>>>
822 where
823 R: 'static
824 + lsp::request::Request<
825 Params = request::GetCompletionsParams,
826 Result = request::GetCompletionsResult,
827 >,
828 T: ToPointUtf16,
829 {
830 self.register_buffer(buffer, cx);
831
832 let server = match self.server.as_authenticated() {
833 Ok(server) => server,
834 Err(error) => return Task::ready(Err(error)),
835 };
836 let lsp = server.lsp.clone();
837 let buffer_id = buffer.read(cx).remote_id();
838 let registered_buffer = server.registered_buffers.get_mut(&buffer_id).unwrap();
839 let snapshot = registered_buffer.report_changes(buffer, cx);
840 let buffer = buffer.read(cx);
841 let uri = registered_buffer.uri.clone();
842 let settings = cx.global::<Settings>();
843 let position = position.to_point_utf16(buffer);
844 let language = buffer.language_at(position);
845 let language_name = language.map(|language| language.name());
846 let language_name = language_name.as_deref();
847 let tab_size = settings.tab_size(language_name);
848 let hard_tabs = settings.hard_tabs(language_name);
849 let relative_path = buffer
850 .file()
851 .map(|file| file.path().to_path_buf())
852 .unwrap_or_default();
853
854 cx.foreground().spawn(async move {
855 let (version, snapshot) = snapshot.await?;
856 let result = lsp
857 .request::<R>(request::GetCompletionsParams {
858 doc: request::GetCompletionsDocument {
859 uri,
860 tab_size: tab_size.into(),
861 indent_size: 1,
862 insert_spaces: !hard_tabs,
863 relative_path: relative_path.to_string_lossy().into(),
864 position: point_to_lsp(position),
865 version: version.try_into().unwrap(),
866 },
867 })
868 .await?;
869 let completions = result
870 .completions
871 .into_iter()
872 .map(|completion| {
873 let start = snapshot
874 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
875 let end =
876 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
877 Completion {
878 uuid: completion.uuid,
879 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
880 text: completion.text,
881 }
882 })
883 .collect();
884 anyhow::Ok(completions)
885 })
886 }
887
888 pub fn status(&self) -> Status {
889 match &self.server {
890 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
891 CopilotServer::Disabled => Status::Disabled,
892 CopilotServer::Error(error) => Status::Error(error.clone()),
893 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
894 match sign_in_status {
895 SignInStatus::Authorized { .. } => Status::Authorized,
896 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
897 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
898 prompt: prompt.clone(),
899 },
900 SignInStatus::SignedOut => Status::SignedOut,
901 }
902 }
903 }
904 }
905
906 fn update_sign_in_status(
907 &mut self,
908 lsp_status: request::SignInStatus,
909 cx: &mut ModelContext<Self>,
910 ) {
911 self.buffers.retain(|_, buffer| buffer.is_upgradable(cx));
912
913 if let Ok(server) = self.server.as_running() {
914 match lsp_status {
915 request::SignInStatus::Ok { .. }
916 | request::SignInStatus::MaybeOk { .. }
917 | request::SignInStatus::AlreadySignedIn { .. } => {
918 server.sign_in_status = SignInStatus::Authorized;
919 for buffer in self.buffers.values().cloned().collect::<Vec<_>>() {
920 if let Some(buffer) = buffer.upgrade(cx) {
921 self.register_buffer(&buffer, cx);
922 }
923 }
924 }
925 request::SignInStatus::NotAuthorized { .. } => {
926 server.sign_in_status = SignInStatus::Unauthorized;
927 for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
928 self.unregister_buffer(buffer_id);
929 }
930 }
931 request::SignInStatus::NotSignedIn => {
932 server.sign_in_status = SignInStatus::SignedOut;
933 for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
934 self.unregister_buffer(buffer_id);
935 }
936 }
937 }
938
939 cx.notify();
940 }
941 }
942}
943
944fn id_for_language(language: Option<&Arc<Language>>) -> String {
945 let language_name = language.map(|language| language.name());
946 match language_name.as_deref() {
947 Some("Plain Text") => "plaintext".to_string(),
948 Some(language_name) => language_name.to_lowercase(),
949 None => "plaintext".to_string(),
950 }
951}
952
953fn uri_for_buffer(buffer: &ModelHandle<Buffer>, cx: &AppContext) -> lsp::Url {
954 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
955 lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
956 } else {
957 format!("buffer://{}", buffer.read(cx).remote_id())
958 .parse()
959 .unwrap()
960 }
961}
962
963async fn clear_copilot_dir() {
964 remove_matching(&paths::COPILOT_DIR, |_| true).await
965}
966
967async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
968 const SERVER_PATH: &'static str = "dist/agent.js";
969
970 ///Check for the latest copilot language server and download it if we haven't already
971 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
972 let release = latest_github_release("zed-industries/copilot", false, http.clone()).await?;
973
974 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
975
976 fs::create_dir_all(version_dir).await?;
977 let server_path = version_dir.join(SERVER_PATH);
978
979 if fs::metadata(&server_path).await.is_err() {
980 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
981 let dist_dir = version_dir.join("dist");
982 fs::create_dir_all(dist_dir.as_path()).await?;
983
984 let url = &release
985 .assets
986 .get(0)
987 .context("Github release for copilot contained no assets")?
988 .browser_download_url;
989
990 let mut response = http
991 .get(&url, Default::default(), true)
992 .await
993 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
994 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
995 let archive = Archive::new(decompressed_bytes);
996 archive.unpack(dist_dir).await?;
997
998 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
999 }
1000
1001 Ok(server_path)
1002 }
1003
1004 match fetch_latest(http).await {
1005 ok @ Result::Ok(..) => ok,
1006 e @ Err(..) => {
1007 e.log_err();
1008 // Fetch a cached binary, if it exists
1009 (|| async move {
1010 let mut last_version_dir = None;
1011 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
1012 while let Some(entry) = entries.next().await {
1013 let entry = entry?;
1014 if entry.file_type().await?.is_dir() {
1015 last_version_dir = Some(entry.path());
1016 }
1017 }
1018 let last_version_dir =
1019 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
1020 let server_path = last_version_dir.join(SERVER_PATH);
1021 if server_path.exists() {
1022 Ok(server_path)
1023 } else {
1024 Err(anyhow!(
1025 "missing executable in directory {:?}",
1026 last_version_dir
1027 ))
1028 }
1029 })()
1030 .await
1031 }
1032 }
1033}
1034
1035#[cfg(test)]
1036mod tests {
1037 use super::*;
1038 use gpui::{executor::Deterministic, TestAppContext};
1039
1040 #[gpui::test(iterations = 10)]
1041 async fn test_buffer_management(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
1042 deterministic.forbid_parking();
1043 let (copilot, mut lsp) = Copilot::fake(cx);
1044
1045 let buffer_1 = cx.add_model(|cx| Buffer::new(0, "Hello", cx));
1046 let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.id()).parse().unwrap();
1047 copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
1048 assert_eq!(
1049 lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1050 .await,
1051 lsp::DidOpenTextDocumentParams {
1052 text_document: lsp::TextDocumentItem::new(
1053 buffer_1_uri.clone(),
1054 "plaintext".into(),
1055 0,
1056 "Hello".into()
1057 ),
1058 }
1059 );
1060
1061 let buffer_2 = cx.add_model(|cx| Buffer::new(0, "Goodbye", cx));
1062 let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.id()).parse().unwrap();
1063 copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
1064 assert_eq!(
1065 lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1066 .await,
1067 lsp::DidOpenTextDocumentParams {
1068 text_document: lsp::TextDocumentItem::new(
1069 buffer_2_uri.clone(),
1070 "plaintext".into(),
1071 0,
1072 "Goodbye".into()
1073 ),
1074 }
1075 );
1076
1077 buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
1078 assert_eq!(
1079 lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
1080 .await,
1081 lsp::DidChangeTextDocumentParams {
1082 text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
1083 content_changes: vec![lsp::TextDocumentContentChangeEvent {
1084 range: Some(lsp::Range::new(
1085 lsp::Position::new(0, 5),
1086 lsp::Position::new(0, 5)
1087 )),
1088 range_length: None,
1089 text: " world".into(),
1090 }],
1091 }
1092 );
1093
1094 // Ensure updates to the file are reflected in the LSP.
1095 buffer_1
1096 .update(cx, |buffer, cx| {
1097 buffer.file_updated(
1098 Arc::new(File {
1099 abs_path: "/root/child/buffer-1".into(),
1100 path: Path::new("child/buffer-1").into(),
1101 }),
1102 cx,
1103 )
1104 })
1105 .await;
1106 assert_eq!(
1107 lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1108 .await,
1109 lsp::DidCloseTextDocumentParams {
1110 text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
1111 }
1112 );
1113 let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
1114 assert_eq!(
1115 lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1116 .await,
1117 lsp::DidOpenTextDocumentParams {
1118 text_document: lsp::TextDocumentItem::new(
1119 buffer_1_uri.clone(),
1120 "plaintext".into(),
1121 1,
1122 "Hello world".into()
1123 ),
1124 }
1125 );
1126
1127 // Ensure all previously-registered buffers are closed when signing out.
1128 lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
1129 Ok(request::SignOutResult {})
1130 });
1131 copilot
1132 .update(cx, |copilot, cx| copilot.sign_out(cx))
1133 .await
1134 .unwrap();
1135 assert_eq!(
1136 lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1137 .await,
1138 lsp::DidCloseTextDocumentParams {
1139 text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
1140 }
1141 );
1142 assert_eq!(
1143 lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1144 .await,
1145 lsp::DidCloseTextDocumentParams {
1146 text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
1147 }
1148 );
1149
1150 // Ensure all previously-registered buffers are re-opened when signing in.
1151 lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
1152 Ok(request::SignInInitiateResult::AlreadySignedIn {
1153 user: "user-1".into(),
1154 })
1155 });
1156 copilot
1157 .update(cx, |copilot, cx| copilot.sign_in(cx))
1158 .await
1159 .unwrap();
1160 assert_eq!(
1161 lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1162 .await,
1163 lsp::DidOpenTextDocumentParams {
1164 text_document: lsp::TextDocumentItem::new(
1165 buffer_2_uri.clone(),
1166 "plaintext".into(),
1167 0,
1168 "Goodbye".into()
1169 ),
1170 }
1171 );
1172 assert_eq!(
1173 lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1174 .await,
1175 lsp::DidOpenTextDocumentParams {
1176 text_document: lsp::TextDocumentItem::new(
1177 buffer_1_uri.clone(),
1178 "plaintext".into(),
1179 0,
1180 "Hello world".into()
1181 ),
1182 }
1183 );
1184
1185 // Dropping a buffer causes it to be closed on the LSP side as well.
1186 cx.update(|_| drop(buffer_2));
1187 assert_eq!(
1188 lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1189 .await,
1190 lsp::DidCloseTextDocumentParams {
1191 text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
1192 }
1193 );
1194 }
1195
1196 struct File {
1197 abs_path: PathBuf,
1198 path: Arc<Path>,
1199 }
1200
1201 impl language::File for File {
1202 fn as_local(&self) -> Option<&dyn language::LocalFile> {
1203 Some(self)
1204 }
1205
1206 fn mtime(&self) -> std::time::SystemTime {
1207 unimplemented!()
1208 }
1209
1210 fn path(&self) -> &Arc<Path> {
1211 &self.path
1212 }
1213
1214 fn full_path(&self, _: &AppContext) -> PathBuf {
1215 unimplemented!()
1216 }
1217
1218 fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
1219 unimplemented!()
1220 }
1221
1222 fn is_deleted(&self) -> bool {
1223 unimplemented!()
1224 }
1225
1226 fn as_any(&self) -> &dyn std::any::Any {
1227 unimplemented!()
1228 }
1229
1230 fn to_proto(&self) -> rpc::proto::File {
1231 unimplemented!()
1232 }
1233 }
1234
1235 impl language::LocalFile for File {
1236 fn abs_path(&self, _: &AppContext) -> PathBuf {
1237 self.abs_path.clone()
1238 }
1239
1240 fn load(&self, _: &AppContext) -> Task<Result<String>> {
1241 unimplemented!()
1242 }
1243
1244 fn buffer_reloaded(
1245 &self,
1246 _: u64,
1247 _: &clock::Global,
1248 _: language::RopeFingerprint,
1249 _: ::fs::LineEnding,
1250 _: std::time::SystemTime,
1251 _: &mut AppContext,
1252 ) {
1253 unimplemented!()
1254 }
1255 }
1256}