1pub mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use collections::HashMap;
8use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{
10 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
11};
12use language::{
13 point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
14 ToPointUtf16,
15};
16use log::{debug, error};
17use lsp::LanguageServer;
18use node_runtime::NodeRuntime;
19use request::{LogMessage, StatusNotification};
20use settings::Settings;
21use smol::{fs, io::BufReader, stream::StreamExt};
22use std::{
23 ffi::OsString,
24 mem,
25 ops::Range,
26 path::{Path, PathBuf},
27 sync::Arc,
28};
29use util::{
30 channel::ReleaseChannel, fs::remove_matching, github::latest_github_release, http::HttpClient,
31 paths, ResultExt,
32};
33
34const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
35actions!(copilot_auth, [SignIn, SignOut]);
36
37const COPILOT_NAMESPACE: &'static str = "copilot";
38actions!(
39 copilot,
40 [Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
41);
42
43pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
44 // Disable Copilot for stable releases.
45 if *cx.global::<ReleaseChannel>() == ReleaseChannel::Stable {
46 cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
47 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
48 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
49 });
50 return;
51 }
52
53 let copilot = cx.add_model({
54 let node_runtime = node_runtime.clone();
55 move |cx| Copilot::start(http, node_runtime, cx)
56 });
57 cx.set_global(copilot.clone());
58
59 cx.observe(&copilot, |handle, cx| {
60 let status = handle.read(cx).status();
61 cx.update_global::<collections::CommandPaletteFilter, _, _>(
62 move |filter, _cx| match status {
63 Status::Disabled => {
64 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
65 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
66 }
67 Status::Authorized => {
68 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
69 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
70 }
71 _ => {
72 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
73 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
74 }
75 },
76 );
77 })
78 .detach();
79
80 sign_in::init(cx);
81 cx.add_global_action(|_: &SignIn, cx| {
82 if let Some(copilot) = Copilot::global(cx) {
83 copilot
84 .update(cx, |copilot, cx| copilot.sign_in(cx))
85 .detach_and_log_err(cx);
86 }
87 });
88 cx.add_global_action(|_: &SignOut, cx| {
89 if let Some(copilot) = Copilot::global(cx) {
90 copilot
91 .update(cx, |copilot, cx| copilot.sign_out(cx))
92 .detach_and_log_err(cx);
93 }
94 });
95
96 cx.add_global_action(|_: &Reinstall, cx| {
97 if let Some(copilot) = Copilot::global(cx) {
98 copilot
99 .update(cx, |copilot, cx| copilot.reinstall(cx))
100 .detach();
101 }
102 });
103}
104
105enum CopilotServer {
106 Disabled,
107 Starting { task: Shared<Task<()>> },
108 Error(Arc<str>),
109 Running(RunningCopilotServer),
110}
111
112impl CopilotServer {
113 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
114 let server = self.as_running()?;
115 if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
116 Ok(server)
117 } else {
118 Err(anyhow!("must sign in before using copilot"))
119 }
120 }
121
122 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
123 match self {
124 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
125 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
126 CopilotServer::Error(error) => Err(anyhow!(
127 "copilot was not started because of an error: {}",
128 error
129 )),
130 CopilotServer::Running(server) => Ok(server),
131 }
132 }
133}
134
/// State kept while the Copilot language server process is running.
struct RunningCopilotServer {
    /// Connection to the language server.
    lsp: Arc<LanguageServer>,
    /// The user's authentication state as last reported by, or assumed for, the server.
    sign_in_status: SignInStatus,
    /// Buffers currently opened on the server via `didOpen`, keyed by buffer model id.
    registered_buffers: HashMap<usize, RegisteredBuffer>,
}
140
/// Authentication state tracked for a running server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Signed in and authorized; buffers are registered and completions may be requested.
    Authorized,
    /// The server reported the user as not authorized.
    Unauthorized,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt for the user, filled in once the server provides one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// The in-flight sign-in. It is shared so that repeated sign-in requests await the same
        /// flow; the error is wrapped in an `Arc` because shared outputs must be `Clone`.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// No user is signed in.
    SignedOut,
}
151
152#[derive(Debug, Clone)]
153pub enum Status {
154 Starting {
155 task: Shared<Task<()>>,
156 },
157 Error(Arc<str>),
158 Disabled,
159 SignedOut,
160 SigningIn {
161 prompt: Option<request::PromptUserDeviceFlow>,
162 },
163 Unauthorized,
164 Authorized,
165}
166
167impl Status {
168 pub fn is_authorized(&self) -> bool {
169 matches!(self, Status::Authorized)
170 }
171}
172
173struct RegisteredBuffer {
174 id: usize,
175 uri: lsp::Url,
176 language_id: String,
177 snapshot: BufferSnapshot,
178 snapshot_version: i32,
179 _subscriptions: [gpui::Subscription; 2],
180 pending_buffer_change: Task<Option<()>>,
181}
182
183impl RegisteredBuffer {
184 fn report_changes(
185 &mut self,
186 buffer: &ModelHandle<Buffer>,
187 cx: &mut ModelContext<Copilot>,
188 ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
189 let id = self.id;
190 let (done_tx, done_rx) = oneshot::channel();
191
192 if buffer.read(cx).version() == self.snapshot.version {
193 let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
194 } else {
195 let buffer = buffer.downgrade();
196 let prev_pending_change =
197 mem::replace(&mut self.pending_buffer_change, Task::ready(None));
198 self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move {
199 prev_pending_change.await;
200
201 let old_version = copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
202 let server = copilot.server.as_authenticated().log_err()?;
203 let buffer = server.registered_buffers.get_mut(&id)?;
204 Some(buffer.snapshot.version.clone())
205 })?;
206 let new_snapshot = buffer
207 .upgrade(&cx)?
208 .read_with(&cx, |buffer, _| buffer.snapshot());
209
210 let content_changes = cx
211 .background()
212 .spawn({
213 let new_snapshot = new_snapshot.clone();
214 async move {
215 new_snapshot
216 .edits_since::<(PointUtf16, usize)>(&old_version)
217 .map(|edit| {
218 let edit_start = edit.new.start.0;
219 let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
220 let new_text = new_snapshot
221 .text_for_range(edit.new.start.1..edit.new.end.1)
222 .collect();
223 lsp::TextDocumentContentChangeEvent {
224 range: Some(lsp::Range::new(
225 point_to_lsp(edit_start),
226 point_to_lsp(edit_end),
227 )),
228 range_length: None,
229 text: new_text,
230 }
231 })
232 .collect::<Vec<_>>()
233 }
234 })
235 .await;
236
237 copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
238 let server = copilot.server.as_authenticated().log_err()?;
239 let buffer = server.registered_buffers.get_mut(&id)?;
240 if !content_changes.is_empty() {
241 buffer.snapshot_version += 1;
242 buffer.snapshot = new_snapshot;
243 server
244 .lsp
245 .notify::<lsp::notification::DidChangeTextDocument>(
246 lsp::DidChangeTextDocumentParams {
247 text_document: lsp::VersionedTextDocumentIdentifier::new(
248 buffer.uri.clone(),
249 buffer.snapshot_version,
250 ),
251 content_changes,
252 },
253 )
254 .log_err();
255 }
256 let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
257 Some(())
258 })?;
259
260 Some(())
261 });
262 }
263
264 done_rx
265 }
266}
267
/// A completion suggested by Copilot.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned id, echoed back when the completion is accepted or discarded.
    uuid: String,
    /// Range in the buffer that the completion applies to.
    pub range: Range<Anchor>,
    /// The suggested text.
    pub text: String,
}
274
/// Model that owns the Copilot language server and the set of buffers shown to it.
pub struct Copilot {
    /// HTTP client used to download the language server.
    http: Arc<dyn HttpClient>,
    /// Node runtime whose binary launches the language server.
    node_runtime: Arc<NodeRuntime>,
    /// Current state of the language server.
    server: CopilotServer,
    /// Every buffer Copilot has been asked to track, held weakly. Kept so that buffers can be
    /// registered again after signing back in.
    buffers: HashMap<usize, WeakModelHandle<Buffer>>,
}

impl Entity for Copilot {
    // Copilot emits no events; observers react to `cx.notify()` instead.
    type Event = ();
}
285
286impl Copilot {
287 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
288 if cx.has_global::<ModelHandle<Self>>() {
289 Some(cx.global::<ModelHandle<Self>>().clone())
290 } else {
291 None
292 }
293 }
294
295 fn start(
296 http: Arc<dyn HttpClient>,
297 node_runtime: Arc<NodeRuntime>,
298 cx: &mut ModelContext<Self>,
299 ) -> Self {
300 cx.observe_global::<Settings, _>({
301 let http = http.clone();
302 let node_runtime = node_runtime.clone();
303 move |this, cx| {
304 if cx.global::<Settings>().features.copilot {
305 if matches!(this.server, CopilotServer::Disabled) {
306 let start_task = cx
307 .spawn({
308 let http = http.clone();
309 let node_runtime = node_runtime.clone();
310 move |this, cx| {
311 Self::start_language_server(http, node_runtime, this, cx)
312 }
313 })
314 .shared();
315 this.server = CopilotServer::Starting { task: start_task };
316 cx.notify();
317 }
318 } else {
319 this.server = CopilotServer::Disabled;
320 cx.notify();
321 }
322 }
323 })
324 .detach();
325
326 if cx.global::<Settings>().features.copilot {
327 let start_task = cx
328 .spawn({
329 let http = http.clone();
330 let node_runtime = node_runtime.clone();
331 move |this, cx| async {
332 Self::start_language_server(http, node_runtime, this, cx).await
333 }
334 })
335 .shared();
336
337 Self {
338 http,
339 node_runtime,
340 server: CopilotServer::Starting { task: start_task },
341 buffers: Default::default(),
342 }
343 } else {
344 Self {
345 http,
346 node_runtime,
347 server: CopilotServer::Disabled,
348 buffers: Default::default(),
349 }
350 }
351 }
352
353 #[cfg(any(test, feature = "test-support"))]
354 pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
355 let (server, fake_server) =
356 LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
357 let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
358 let this = cx.add_model(|cx| Self {
359 http: http.clone(),
360 node_runtime: NodeRuntime::new(http, cx.background().clone()),
361 server: CopilotServer::Running(RunningCopilotServer {
362 lsp: Arc::new(server),
363 sign_in_status: SignInStatus::Authorized,
364 registered_buffers: Default::default(),
365 }),
366 buffers: Default::default(),
367 });
368 (this, fake_server)
369 }
370
371 fn start_language_server(
372 http: Arc<dyn HttpClient>,
373 node_runtime: Arc<NodeRuntime>,
374 this: ModelHandle<Self>,
375 mut cx: AsyncAppContext,
376 ) -> impl Future<Output = ()> {
377 async move {
378 let start_language_server = async {
379 let server_path = get_copilot_lsp(http).await?;
380 let node_path = node_runtime.binary_path().await?;
381 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
382 let server = LanguageServer::new(
383 0,
384 &node_path,
385 arguments,
386 Path::new("/"),
387 None,
388 cx.clone(),
389 )?;
390
391 let server = server.initialize(Default::default()).await?;
392 let status = server
393 .request::<request::CheckStatus>(request::CheckStatusParams {
394 local_checks_only: false,
395 })
396 .await?;
397
398 server
399 .on_notification::<LogMessage, _>(|params, _cx| {
400 match params.level {
401 // Copilot is pretty agressive about logging
402 0 => debug!("copilot: {}", params.message),
403 1 => debug!("copilot: {}", params.message),
404 _ => error!("copilot: {}", params.message),
405 }
406
407 debug!("copilot metadata: {}", params.metadata_str);
408 debug!("copilot extra: {:?}", params.extra);
409 })
410 .detach();
411
412 server
413 .on_notification::<StatusNotification, _>(
414 |_, _| { /* Silence the notification */ },
415 )
416 .detach();
417
418 server
419 .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
420 editor_info: request::EditorInfo {
421 name: "zed".into(),
422 version: env!("CARGO_PKG_VERSION").into(),
423 },
424 editor_plugin_info: request::EditorPluginInfo {
425 name: "zed-copilot".into(),
426 version: "0.0.1".into(),
427 },
428 })
429 .await?;
430
431 anyhow::Ok((server, status))
432 };
433
434 let server = start_language_server.await;
435 this.update(&mut cx, |this, cx| {
436 cx.notify();
437 match server {
438 Ok((server, status)) => {
439 this.server = CopilotServer::Running(RunningCopilotServer {
440 lsp: server,
441 sign_in_status: SignInStatus::SignedOut,
442 registered_buffers: Default::default(),
443 });
444 this.update_sign_in_status(status, cx);
445 }
446 Err(error) => {
447 this.server = CopilotServer::Error(error.to_string().into());
448 cx.notify()
449 }
450 }
451 })
452 }
453 }
454
    /// Signs the user in, or joins a sign-in that is already in progress.
    ///
    /// - `Authorized` / `Unauthorized`: resolves immediately with `Ok(())`.
    /// - `SigningIn`: notifies observers again and awaits the in-flight sign-in task.
    /// - `SignedOut`: sends `SignInInitiate`. If the server asks for a device-flow prompt, the
    ///   prompt is stored on the `SigningIn` status (observers are notified) and `SignInConfirm`
    ///   is awaited. The resulting status is applied with `update_sign_in_status`; on any error
    ///   the status is reset to signed-out.
    ///
    /// Fails immediately if the server is not running.
    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
                    Task::ready(Ok(())).shared()
                }
                SignInStatus::SigningIn { task, .. } => {
                    // Re-notify so observers pick up the current sign-in state again.
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the prompt on the status, but only if a
                                        // sign-in is still in progress.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        });
                                        // Waits for the server's answer to the device flow.
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    // Wrapped in `Arc` because the shared task's output must
                                    // be `Clone`.
                                    Err(Arc::new(error))
                                }
                            })
                        })
                        .shared();
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            cx.foreground()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // TODO: if the server is still downloading, wait until the download is finished;
            // if it is stuck, surface that to the user.
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
541
542 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
543 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
544 if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
545 let server = server.clone();
546 cx.background().spawn(async move {
547 server
548 .request::<request::SignOut>(request::SignOutParams {})
549 .await?;
550 anyhow::Ok(())
551 })
552 } else {
553 Task::ready(Err(anyhow!("copilot hasn't started yet")))
554 }
555 }
556
557 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
558 let start_task = cx
559 .spawn({
560 let http = self.http.clone();
561 let node_runtime = self.node_runtime.clone();
562 move |this, cx| async move {
563 clear_copilot_dir().await;
564 Self::start_language_server(http, node_runtime, this, cx).await
565 }
566 })
567 .shared();
568
569 self.server = CopilotServer::Starting {
570 task: start_task.clone(),
571 };
572
573 cx.notify();
574
575 cx.foreground().spawn(start_task)
576 }
577
578 pub fn register_buffer(&mut self, buffer: &ModelHandle<Buffer>, cx: &mut ModelContext<Self>) {
579 let buffer_id = buffer.id();
580 self.buffers.insert(buffer_id, buffer.downgrade());
581
582 if let CopilotServer::Running(RunningCopilotServer {
583 lsp: server,
584 sign_in_status: status,
585 registered_buffers,
586 ..
587 }) = &mut self.server
588 {
589 if !matches!(status, SignInStatus::Authorized { .. }) {
590 return;
591 }
592
593 registered_buffers.entry(buffer.id()).or_insert_with(|| {
594 let uri: lsp::Url = uri_for_buffer(buffer, cx);
595 let language_id = id_for_language(buffer.read(cx).language());
596 let snapshot = buffer.read(cx).snapshot();
597 server
598 .notify::<lsp::notification::DidOpenTextDocument>(
599 lsp::DidOpenTextDocumentParams {
600 text_document: lsp::TextDocumentItem {
601 uri: uri.clone(),
602 language_id: language_id.clone(),
603 version: 0,
604 text: snapshot.text(),
605 },
606 },
607 )
608 .log_err();
609
610 RegisteredBuffer {
611 id: buffer_id,
612 uri,
613 language_id,
614 snapshot,
615 snapshot_version: 0,
616 pending_buffer_change: Task::ready(Some(())),
617 _subscriptions: [
618 cx.subscribe(buffer, |this, buffer, event, cx| {
619 this.handle_buffer_event(buffer, event, cx).log_err();
620 }),
621 cx.observe_release(buffer, move |this, _buffer, _cx| {
622 this.buffers.remove(&buffer_id);
623 this.unregister_buffer(buffer_id);
624 }),
625 ],
626 }
627 });
628 }
629 }
630
631 fn handle_buffer_event(
632 &mut self,
633 buffer: ModelHandle<Buffer>,
634 event: &language::Event,
635 cx: &mut ModelContext<Self>,
636 ) -> Result<()> {
637 if let Ok(server) = self.server.as_running() {
638 if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.id()) {
639 match event {
640 language::Event::Edited => {
641 let _ = registered_buffer.report_changes(&buffer, cx);
642 }
643 language::Event::Saved => {
644 server
645 .lsp
646 .notify::<lsp::notification::DidSaveTextDocument>(
647 lsp::DidSaveTextDocumentParams {
648 text_document: lsp::TextDocumentIdentifier::new(
649 registered_buffer.uri.clone(),
650 ),
651 text: None,
652 },
653 )?;
654 }
655 language::Event::FileHandleChanged | language::Event::LanguageChanged => {
656 let new_language_id = id_for_language(buffer.read(cx).language());
657 let new_uri = uri_for_buffer(&buffer, cx);
658 if new_uri != registered_buffer.uri
659 || new_language_id != registered_buffer.language_id
660 {
661 let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
662 registered_buffer.language_id = new_language_id;
663 server
664 .lsp
665 .notify::<lsp::notification::DidCloseTextDocument>(
666 lsp::DidCloseTextDocumentParams {
667 text_document: lsp::TextDocumentIdentifier::new(old_uri),
668 },
669 )?;
670 server
671 .lsp
672 .notify::<lsp::notification::DidOpenTextDocument>(
673 lsp::DidOpenTextDocumentParams {
674 text_document: lsp::TextDocumentItem::new(
675 registered_buffer.uri.clone(),
676 registered_buffer.language_id.clone(),
677 registered_buffer.snapshot_version,
678 registered_buffer.snapshot.text(),
679 ),
680 },
681 )?;
682 }
683 }
684 _ => {}
685 }
686 }
687 }
688
689 Ok(())
690 }
691
692 fn unregister_buffer(&mut self, buffer_id: usize) {
693 if let Ok(server) = self.server.as_running() {
694 if let Some(buffer) = server.registered_buffers.remove(&buffer_id) {
695 server
696 .lsp
697 .notify::<lsp::notification::DidCloseTextDocument>(
698 lsp::DidCloseTextDocumentParams {
699 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
700 },
701 )
702 .log_err();
703 }
704 }
705 }
706
707 pub fn completions<T>(
708 &mut self,
709 buffer: &ModelHandle<Buffer>,
710 position: T,
711 cx: &mut ModelContext<Self>,
712 ) -> Task<Result<Vec<Completion>>>
713 where
714 T: ToPointUtf16,
715 {
716 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
717 }
718
719 pub fn completions_cycling<T>(
720 &mut self,
721 buffer: &ModelHandle<Buffer>,
722 position: T,
723 cx: &mut ModelContext<Self>,
724 ) -> Task<Result<Vec<Completion>>>
725 where
726 T: ToPointUtf16,
727 {
728 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
729 }
730
731 pub fn accept_completion(
732 &mut self,
733 completion: &Completion,
734 cx: &mut ModelContext<Self>,
735 ) -> Task<Result<()>> {
736 let server = match self.server.as_authenticated() {
737 Ok(server) => server,
738 Err(error) => return Task::ready(Err(error)),
739 };
740 let request =
741 server
742 .lsp
743 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
744 uuid: completion.uuid.clone(),
745 });
746 cx.background().spawn(async move {
747 request.await?;
748 Ok(())
749 })
750 }
751
752 pub fn discard_completions(
753 &mut self,
754 completions: &[Completion],
755 cx: &mut ModelContext<Self>,
756 ) -> Task<Result<()>> {
757 let server = match self.server.as_authenticated() {
758 Ok(server) => server,
759 Err(error) => return Task::ready(Err(error)),
760 };
761 let request =
762 server
763 .lsp
764 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
765 uuids: completions
766 .iter()
767 .map(|completion| completion.uuid.clone())
768 .collect(),
769 });
770 cx.background().spawn(async move {
771 request.await?;
772 Ok(())
773 })
774 }
775
    /// Sends a completion request of type `R` for `position` in `buffer`.
    ///
    /// Registers the buffer if needed and synchronizes pending edits with the server. It then
    /// issues the request using the tab size and hard-tab settings for the language at
    /// `position`. Returned ranges are clipped to, and anchored in, the snapshot the server was
    /// synchronized with.
    ///
    /// Fails immediately if Copilot is not running and signed in.
    fn request_completions<R, T>(
        &mut self,
        buffer: &ModelHandle<Buffer>,
        position: T,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<Completion>>>
    where
        R: 'static
            + lsp::request::Request<
                Params = request::GetCompletionsParams,
                Result = request::GetCompletionsResult,
            >,
        T: ToPointUtf16,
    {
        self.register_buffer(buffer, cx);

        let server = match self.server.as_authenticated() {
            Ok(server) => server,
            Err(error) => return Task::ready(Err(error)),
        };
        let lsp = server.lsp.clone();
        // `register_buffer` above inserts the buffer whenever the server is authenticated,
        // which `as_authenticated` has just confirmed.
        let registered_buffer = server.registered_buffers.get_mut(&buffer.id()).unwrap();
        let snapshot = registered_buffer.report_changes(buffer, cx);
        let buffer = buffer.read(cx);
        let uri = registered_buffer.uri.clone();
        let settings = cx.global::<Settings>();
        // NOTE(review): `position` is resolved against the buffer as it is now. The server may
        // already be synchronized to a later snapshot when the request is sent, if more edits
        // land in between.
        let position = position.to_point_utf16(buffer);
        let language = buffer.language_at(position);
        let language_name = language.map(|language| language.name());
        let language_name = language_name.as_deref();
        let tab_size = settings.tab_size(language_name);
        let hard_tabs = settings.hard_tabs(language_name);
        let relative_path = buffer
            .file()
            .map(|file| file.path().to_path_buf())
            .unwrap_or_default();

        cx.foreground().spawn(async move {
            // Wait until pending edits are sent; yields the document version and the snapshot
            // the server now holds.
            let (version, snapshot) = snapshot.await?;
            let result = lsp
                .request::<R>(request::GetCompletionsParams {
                    doc: request::GetCompletionsDocument {
                        uri,
                        tab_size: tab_size.into(),
                        indent_size: 1,
                        insert_spaces: !hard_tabs,
                        relative_path: relative_path.to_string_lossy().into(),
                        position: point_to_lsp(position),
                        version: version.try_into().unwrap(),
                    },
                })
                .await?;
            let completions = result
                .completions
                .into_iter()
                .map(|completion| {
                    let start = snapshot
                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
                    let end =
                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
                    Completion {
                        uuid: completion.uuid,
                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
                        text: completion.text,
                    }
                })
                .collect();
            anyhow::Ok(completions)
        })
    }
846
847 pub fn status(&self) -> Status {
848 match &self.server {
849 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
850 CopilotServer::Disabled => Status::Disabled,
851 CopilotServer::Error(error) => Status::Error(error.clone()),
852 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
853 match sign_in_status {
854 SignInStatus::Authorized { .. } => Status::Authorized,
855 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
856 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
857 prompt: prompt.clone(),
858 },
859 SignInStatus::SignedOut => Status::SignedOut,
860 }
861 }
862 }
863 }
864
865 fn update_sign_in_status(
866 &mut self,
867 lsp_status: request::SignInStatus,
868 cx: &mut ModelContext<Self>,
869 ) {
870 self.buffers.retain(|_, buffer| buffer.is_upgradable(cx));
871
872 if let Ok(server) = self.server.as_running() {
873 match lsp_status {
874 request::SignInStatus::Ok { .. }
875 | request::SignInStatus::MaybeOk { .. }
876 | request::SignInStatus::AlreadySignedIn { .. } => {
877 server.sign_in_status = SignInStatus::Authorized;
878 for buffer in self.buffers.values().cloned().collect::<Vec<_>>() {
879 if let Some(buffer) = buffer.upgrade(cx) {
880 self.register_buffer(&buffer, cx);
881 }
882 }
883 }
884 request::SignInStatus::NotAuthorized { .. } => {
885 server.sign_in_status = SignInStatus::Unauthorized;
886 for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
887 self.unregister_buffer(buffer_id);
888 }
889 }
890 request::SignInStatus::NotSignedIn => {
891 server.sign_in_status = SignInStatus::SignedOut;
892 for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
893 self.unregister_buffer(buffer_id);
894 }
895 }
896 }
897
898 cx.notify();
899 }
900 }
901}
902
903fn id_for_language(language: Option<&Arc<Language>>) -> String {
904 let language_name = language.map(|language| language.name());
905 match language_name.as_deref() {
906 Some("Plain Text") => "plaintext".to_string(),
907 Some(language_name) => language_name.to_lowercase(),
908 None => "plaintext".to_string(),
909 }
910}
911
912fn uri_for_buffer(buffer: &ModelHandle<Buffer>, cx: &AppContext) -> lsp::Url {
913 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
914 lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
915 } else {
916 format!("buffer://{}", buffer.id()).parse().unwrap()
917 }
918}
919
920async fn clear_copilot_dir() {
921 remove_matching(&paths::COPILOT_DIR, |_| true).await
922}
923
924async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
925 const SERVER_PATH: &'static str = "dist/agent.js";
926
927 ///Check for the latest copilot language server and download it if we haven't already
928 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
929 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
930
931 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
932
933 fs::create_dir_all(version_dir).await?;
934 let server_path = version_dir.join(SERVER_PATH);
935
936 if fs::metadata(&server_path).await.is_err() {
937 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
938 let dist_dir = version_dir.join("dist");
939 fs::create_dir_all(dist_dir.as_path()).await?;
940
941 let url = &release
942 .assets
943 .get(0)
944 .context("Github release for copilot contained no assets")?
945 .browser_download_url;
946
947 let mut response = http
948 .get(&url, Default::default(), true)
949 .await
950 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
951 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
952 let archive = Archive::new(decompressed_bytes);
953 archive.unpack(dist_dir).await?;
954
955 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
956 }
957
958 Ok(server_path)
959 }
960
961 match fetch_latest(http).await {
962 ok @ Result::Ok(..) => ok,
963 e @ Err(..) => {
964 e.log_err();
965 // Fetch a cached binary, if it exists
966 (|| async move {
967 let mut last_version_dir = None;
968 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
969 while let Some(entry) = entries.next().await {
970 let entry = entry?;
971 if entry.file_type().await?.is_dir() {
972 last_version_dir = Some(entry.path());
973 }
974 }
975 let last_version_dir =
976 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
977 let server_path = last_version_dir.join(SERVER_PATH);
978 if server_path.exists() {
979 Ok(server_path)
980 } else {
981 Err(anyhow!(
982 "missing executable in directory {:?}",
983 last_version_dir
984 ))
985 }
986 })()
987 .await
988 }
989 }
990}
991
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{executor::Deterministic, TestAppContext};

    /// Drives buffer registration against a fake language server and checks these
    /// notifications:
    /// - `didOpen` on registration;
    /// - incremental `didChange` on edits;
    /// - close/reopen when the file changes;
    /// - close on sign-out and reopen on sign-in;
    /// - close when a buffer is dropped.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();
        let (copilot, mut lsp) = Copilot::fake(cx);

        // Buffers without a file are opened under a synthetic `buffer://<id>` URI.
        let buffer_1 = cx.add_model(|cx| Buffer::new(0, "Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.id()).parse().unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.add_model(|cx| Buffer::new(0, "Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.id()).parse().unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is reported as an incremental change and bumps the document version.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        buffer_1
            .update(cx, |buffer, cx| {
                buffer.file_updated(
                    Arc::new(File {
                        abs_path: "/root/child/buffer-1".into(),
                        path: Path::new("child/buffer-1").into(),
                    }),
                    cx,
                )
            })
            .await;
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );

        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    /// Minimal local-file stub for the test above. Only `as_local`, `path` and `abs_path`
    /// return real values; every other method is `todo!()`.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn mtime(&self) -> std::time::SystemTime {
            todo!()
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &AppContext) -> PathBuf {
            todo!()
        }

        fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
            todo!()
        }

        fn is_deleted(&self) -> bool {
            todo!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            todo!()
        }

        fn to_proto(&self) -> rpc::proto::File {
            todo!()
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &AppContext) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &AppContext) -> Task<Result<String>> {
            todo!()
        }

        fn buffer_reloaded(
            &self,
            _: u64,
            _: &clock::Global,
            _: language::RopeFingerprint,
            _: ::fs::LineEnding,
            _: std::time::SystemTime,
            _: &mut AppContext,
        ) {
            todo!()
        }
    }
}