1pub mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use collections::{HashMap, HashSet};
8use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{
10 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
11};
12use language::{
13 language_settings::{all_language_settings, language_settings},
14 point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
15 ToPointUtf16,
16};
17use log::{debug, error};
18use lsp::{LanguageServer, LanguageServerId};
19use node_runtime::NodeRuntime;
20use request::{LogMessage, StatusNotification};
21use settings::SettingsStore;
22use smol::{fs, io::BufReader, stream::StreamExt};
23use std::{
24 ffi::OsString,
25 mem,
26 ops::Range,
27 path::{Path, PathBuf},
28 pin::Pin,
29 sync::Arc,
30};
31use util::{
32 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
33};
34
35const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
36actions!(copilot_auth, [SignIn, SignOut]);
37
38const COPILOT_NAMESPACE: &'static str = "copilot";
39actions!(
40 copilot,
41 [Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
42);
43
44pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
45 let copilot = cx.add_model({
46 let node_runtime = node_runtime.clone();
47 move |cx| Copilot::start(http, node_runtime, cx)
48 });
49 cx.set_global(copilot.clone());
50
51 cx.observe(&copilot, |handle, cx| {
52 let status = handle.read(cx).status();
53 cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
54 match status {
55 Status::Disabled => {
56 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
57 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
58 }
59 Status::Authorized => {
60 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
61 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
62 }
63 _ => {
64 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
65 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
66 }
67 }
68 });
69 })
70 .detach();
71
72 sign_in::init(cx);
73 cx.add_global_action(|_: &SignIn, cx| {
74 if let Some(copilot) = Copilot::global(cx) {
75 copilot
76 .update(cx, |copilot, cx| copilot.sign_in(cx))
77 .detach_and_log_err(cx);
78 }
79 });
80 cx.add_global_action(|_: &SignOut, cx| {
81 if let Some(copilot) = Copilot::global(cx) {
82 copilot
83 .update(cx, |copilot, cx| copilot.sign_out(cx))
84 .detach_and_log_err(cx);
85 }
86 });
87
88 cx.add_global_action(|_: &Reinstall, cx| {
89 if let Some(copilot) = Copilot::global(cx) {
90 copilot
91 .update(cx, |copilot, cx| copilot.reinstall(cx))
92 .detach();
93 }
94 });
95}
96
97enum CopilotServer {
98 Disabled,
99 Starting { task: Shared<Task<()>> },
100 Error(Arc<str>),
101 Running(RunningCopilotServer),
102}
103
104impl CopilotServer {
105 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
106 let server = self.as_running()?;
107 if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
108 Ok(server)
109 } else {
110 Err(anyhow!("must sign in before using copilot"))
111 }
112 }
113
114 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
115 match self {
116 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
117 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
118 CopilotServer::Error(error) => Err(anyhow!(
119 "copilot was not started because of an error: {}",
120 error
121 )),
122 CopilotServer::Running(server) => Ok(server),
123 }
124 }
125}
126
/// A launched copilot language server plus the per-session state tracked for it.
struct RunningCopilotServer {
    /// Connection to the copilot language server.
    lsp: Arc<LanguageServer>,
    /// Current sign-in state, as last reported by (or requested from) the server.
    sign_in_status: SignInStatus,
    /// Buffers currently opened on the server, keyed by the buffer model's id.
    registered_buffers: HashMap<usize, RegisteredBuffer>,
}
132
/// Sign-in state of a running copilot server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// The server reported an `Ok`, `MaybeOk` or `AlreadySignedIn` sign-in status.
    Authorized,
    /// The server reported a `NotAuthorized` sign-in status.
    Unauthorized,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt to show the user, once the server has provided one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Resolves when the sign-in flow finishes; shared so repeated sign-in
        /// requests can join the same flow.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Not signed in: the state right after the server starts, after signing out,
    /// or after a failed sign-in.
    SignedOut,
}
143
144#[derive(Debug, Clone)]
145pub enum Status {
146 Starting {
147 task: Shared<Task<()>>,
148 },
149 Error(Arc<str>),
150 Disabled,
151 SignedOut,
152 SigningIn {
153 prompt: Option<request::PromptUserDeviceFlow>,
154 },
155 Unauthorized,
156 Authorized,
157}
158
159impl Status {
160 pub fn is_authorized(&self) -> bool {
161 matches!(self, Status::Authorized)
162 }
163}
164
/// A buffer that has been opened on the copilot server, together with the
/// last snapshot of its contents that was reported to the server.
struct RegisteredBuffer {
    /// URI under which the buffer is open on the server.
    uri: lsp::Url,
    /// Language identifier sent when the buffer was (re)opened.
    language_id: String,
    /// Buffer contents as last reported to the server.
    snapshot: BufferSnapshot,
    /// LSP document version of `snapshot`; incremented for every reported change.
    snapshot_version: i32,
    /// Buffer event / release subscriptions; never read, presumably stored only so
    /// they stay active for as long as the buffer is registered.
    _subscriptions: [gpui::Subscription; 2],
    /// Most recently spawned change-reporting task. Each new report awaits the
    /// previous one so changes reach the server in order.
    pending_buffer_change: Task<Option<()>>,
}

impl RegisteredBuffer {
    /// Sends any edits made to `buffer` since the last reported snapshot to the
    /// server as a `textDocument/didChange` notification.
    ///
    /// Returns a receiver yielding the `(version, snapshot)` the server is known to
    /// have once reporting completes. If the buffer's version equals the reported
    /// snapshot's version it resolves immediately; otherwise the work is chained
    /// after any still-pending report. The sender is dropped without sending (the
    /// receiver then yields `Canceled`) if the copilot model or the buffer is
    /// released, copilot is no longer authenticated, or the buffer was unregistered
    /// in the meantime.
    fn report_changes(
        &mut self,
        buffer: &ModelHandle<Buffer>,
        cx: &mut ModelContext<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.id();
            // Chain onto the previous report so edits are diffed and sent in order.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move {
                prev_pending_change.await;

                // Re-read the last reported version: the previous report may have
                // advanced it while we were waiting.
                let old_version = copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
                    let server = copilot.server.as_authenticated().log_err()?;
                    let buffer = server.registered_buffers.get_mut(&id)?;
                    Some(buffer.snapshot.version.clone())
                })?;
                let new_snapshot = buffer
                    .upgrade(&cx)?
                    .read_with(&cx, |buffer, _| buffer.snapshot());

                // Convert the edits into LSP content changes off the main thread.
                let content_changes = cx
                    .background()
                    .spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // Each change is applied by the server on top of the
                                    // previous ones, so start at the edit's new position
                                    // and extend by the size of the replaced (old) range.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
                    let server = copilot.server.as_authenticated().log_err()?;
                    let buffer = server.registered_buffers.get_mut(&id)?;
                    // Only bump the version / replace the snapshot when there is
                    // something to report.
                    if !content_changes.is_empty() {
                        buffer.snapshot_version += 1;
                        buffer.snapshot = new_snapshot;
                        server
                            .lsp
                            .notify::<lsp::notification::DidChangeTextDocument>(
                                lsp::DidChangeTextDocumentParams {
                                    text_document: lsp::VersionedTextDocumentIdentifier::new(
                                        buffer.uri.clone(),
                                        buffer.snapshot_version,
                                    ),
                                    content_changes,
                                },
                            )
                            .log_err();
                    }
                    let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                    Some(())
                })?;

                Some(())
            });
        }

        done_rx
    }
}
258
/// A completion suggested by copilot, anchored in the buffer it was requested for.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned id, used when reporting acceptance or rejection.
    pub uuid: String,
    /// Buffer range the completion text replaces.
    pub range: Range<Anchor>,
    /// Text to insert.
    pub text: String,
}
265
/// Model that owns the copilot language server and mirrors buffers to it.
pub struct Copilot {
    /// HTTP client used to download the copilot server.
    http: Arc<dyn HttpClient>,
    /// Node runtime used to run the copilot server.
    node_runtime: Arc<NodeRuntime>,
    /// Current lifecycle state of the server.
    server: CopilotServer,
    /// Every buffer copilot was asked to track, registered on the server or not,
    /// so they can be (re)opened whenever the user signs in.
    buffers: HashSet<WeakModelHandle<Buffer>>,
}
272
273impl Entity for Copilot {
274 type Event = ();
275
276 fn app_will_quit(
277 &mut self,
278 _: &mut AppContext,
279 ) -> Option<Pin<Box<dyn 'static + Future<Output = ()>>>> {
280 match mem::replace(&mut self.server, CopilotServer::Disabled) {
281 CopilotServer::Running(server) => Some(Box::pin(async move {
282 if let Some(shutdown) = server.lsp.shutdown() {
283 shutdown.await;
284 }
285 })),
286 _ => None,
287 }
288 }
289}
290
291impl Copilot {
292 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
293 if cx.has_global::<ModelHandle<Self>>() {
294 Some(cx.global::<ModelHandle<Self>>().clone())
295 } else {
296 None
297 }
298 }
299
300 fn start(
301 http: Arc<dyn HttpClient>,
302 node_runtime: Arc<NodeRuntime>,
303 cx: &mut ModelContext<Self>,
304 ) -> Self {
305 let mut this = Self {
306 http,
307 node_runtime,
308 server: CopilotServer::Disabled,
309 buffers: Default::default(),
310 };
311 this.enable_or_disable_copilot(cx);
312 cx.observe_global::<SettingsStore, _>(move |this, cx| this.enable_or_disable_copilot(cx))
313 .detach();
314 this
315 }
316
317 fn enable_or_disable_copilot(&mut self, cx: &mut ModelContext<Copilot>) {
318 let http = self.http.clone();
319 let node_runtime = self.node_runtime.clone();
320 if all_language_settings(None, cx).copilot_enabled(None, None) {
321 if matches!(self.server, CopilotServer::Disabled) {
322 let start_task = cx
323 .spawn({
324 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
325 })
326 .shared();
327 self.server = CopilotServer::Starting { task: start_task };
328 cx.notify();
329 }
330 } else {
331 self.server = CopilotServer::Disabled;
332 cx.notify();
333 }
334 }
335
336 #[cfg(any(test, feature = "test-support"))]
337 pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
338 let (server, fake_server) =
339 LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
340 let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
341 let this = cx.add_model(|cx| Self {
342 http: http.clone(),
343 node_runtime: NodeRuntime::new(http, cx.background().clone()),
344 server: CopilotServer::Running(RunningCopilotServer {
345 lsp: Arc::new(server),
346 sign_in_status: SignInStatus::Authorized,
347 registered_buffers: Default::default(),
348 }),
349 buffers: Default::default(),
350 });
351 (this, fake_server)
352 }
353
354 fn start_language_server(
355 http: Arc<dyn HttpClient>,
356 node_runtime: Arc<NodeRuntime>,
357 this: ModelHandle<Self>,
358 mut cx: AsyncAppContext,
359 ) -> impl Future<Output = ()> {
360 async move {
361 let start_language_server = async {
362 let server_path = get_copilot_lsp(http).await?;
363 let node_path = node_runtime.binary_path().await?;
364 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
365 let server = LanguageServer::new(
366 LanguageServerId(0),
367 &node_path,
368 arguments,
369 Path::new("/"),
370 None,
371 cx.clone(),
372 )?;
373
374 server
375 .on_notification::<LogMessage, _>(|params, _cx| {
376 match params.level {
377 // Copilot is pretty aggressive about logging
378 0 => debug!("copilot: {}", params.message),
379 1 => debug!("copilot: {}", params.message),
380 _ => error!("copilot: {}", params.message),
381 }
382
383 debug!("copilot metadata: {}", params.metadata_str);
384 debug!("copilot extra: {:?}", params.extra);
385 })
386 .detach();
387
388 server
389 .on_notification::<StatusNotification, _>(
390 |_, _| { /* Silence the notification */ },
391 )
392 .detach();
393
394 let server = server.initialize(Default::default()).await?;
395
396 let status = server
397 .request::<request::CheckStatus>(request::CheckStatusParams {
398 local_checks_only: false,
399 })
400 .await?;
401
402 server
403 .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
404 editor_info: request::EditorInfo {
405 name: "zed".into(),
406 version: env!("CARGO_PKG_VERSION").into(),
407 },
408 editor_plugin_info: request::EditorPluginInfo {
409 name: "zed-copilot".into(),
410 version: "0.0.1".into(),
411 },
412 })
413 .await?;
414
415 anyhow::Ok((server, status))
416 };
417
418 let server = start_language_server.await;
419 this.update(&mut cx, |this, cx| {
420 cx.notify();
421 match server {
422 Ok((server, status)) => {
423 this.server = CopilotServer::Running(RunningCopilotServer {
424 lsp: server,
425 sign_in_status: SignInStatus::SignedOut,
426 registered_buffers: Default::default(),
427 });
428 this.update_sign_in_status(status, cx);
429 }
430 Err(error) => {
431 this.server = CopilotServer::Error(error.to_string().into());
432 cx.notify()
433 }
434 }
435 })
436 }
437 }
438
439 pub fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
440 if let CopilotServer::Running(server) = &mut self.server {
441 let task = match &server.sign_in_status {
442 SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
443 SignInStatus::SigningIn { task, .. } => {
444 cx.notify();
445 task.clone()
446 }
447 SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => {
448 let lsp = server.lsp.clone();
449 let task = cx
450 .spawn(|this, mut cx| async move {
451 let sign_in = async {
452 let sign_in = lsp
453 .request::<request::SignInInitiate>(
454 request::SignInInitiateParams {},
455 )
456 .await?;
457 match sign_in {
458 request::SignInInitiateResult::AlreadySignedIn { user } => {
459 Ok(request::SignInStatus::Ok { user })
460 }
461 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
462 this.update(&mut cx, |this, cx| {
463 if let CopilotServer::Running(RunningCopilotServer {
464 sign_in_status: status,
465 ..
466 }) = &mut this.server
467 {
468 if let SignInStatus::SigningIn {
469 prompt: prompt_flow,
470 ..
471 } = status
472 {
473 *prompt_flow = Some(flow.clone());
474 cx.notify();
475 }
476 }
477 });
478 let response = lsp
479 .request::<request::SignInConfirm>(
480 request::SignInConfirmParams {
481 user_code: flow.user_code,
482 },
483 )
484 .await?;
485 Ok(response)
486 }
487 }
488 };
489
490 let sign_in = sign_in.await;
491 this.update(&mut cx, |this, cx| match sign_in {
492 Ok(status) => {
493 this.update_sign_in_status(status, cx);
494 Ok(())
495 }
496 Err(error) => {
497 this.update_sign_in_status(
498 request::SignInStatus::NotSignedIn,
499 cx,
500 );
501 Err(Arc::new(error))
502 }
503 })
504 })
505 .shared();
506 server.sign_in_status = SignInStatus::SigningIn {
507 prompt: None,
508 task: task.clone(),
509 };
510 cx.notify();
511 task
512 }
513 };
514
515 cx.foreground()
516 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
517 } else {
518 // If we're downloading, wait until download is finished
519 // If we're in a stuck state, display to the user
520 Task::ready(Err(anyhow!("copilot hasn't started yet")))
521 }
522 }
523
524 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
525 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
526 if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
527 let server = server.clone();
528 cx.background().spawn(async move {
529 server
530 .request::<request::SignOut>(request::SignOutParams {})
531 .await?;
532 anyhow::Ok(())
533 })
534 } else {
535 Task::ready(Err(anyhow!("copilot hasn't started yet")))
536 }
537 }
538
539 pub fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
540 let start_task = cx
541 .spawn({
542 let http = self.http.clone();
543 let node_runtime = self.node_runtime.clone();
544 move |this, cx| async move {
545 clear_copilot_dir().await;
546 Self::start_language_server(http, node_runtime, this, cx).await
547 }
548 })
549 .shared();
550
551 self.server = CopilotServer::Starting {
552 task: start_task.clone(),
553 };
554
555 cx.notify();
556
557 cx.foreground().spawn(start_task)
558 }
559
560 pub fn register_buffer(&mut self, buffer: &ModelHandle<Buffer>, cx: &mut ModelContext<Self>) {
561 let weak_buffer = buffer.downgrade();
562 self.buffers.insert(weak_buffer.clone());
563
564 if let CopilotServer::Running(RunningCopilotServer {
565 lsp: server,
566 sign_in_status: status,
567 registered_buffers,
568 ..
569 }) = &mut self.server
570 {
571 if !matches!(status, SignInStatus::Authorized { .. }) {
572 return;
573 }
574
575 registered_buffers.entry(buffer.id()).or_insert_with(|| {
576 let uri: lsp::Url = uri_for_buffer(buffer, cx);
577 let language_id = id_for_language(buffer.read(cx).language());
578 let snapshot = buffer.read(cx).snapshot();
579 server
580 .notify::<lsp::notification::DidOpenTextDocument>(
581 lsp::DidOpenTextDocumentParams {
582 text_document: lsp::TextDocumentItem {
583 uri: uri.clone(),
584 language_id: language_id.clone(),
585 version: 0,
586 text: snapshot.text(),
587 },
588 },
589 )
590 .log_err();
591
592 RegisteredBuffer {
593 uri,
594 language_id,
595 snapshot,
596 snapshot_version: 0,
597 pending_buffer_change: Task::ready(Some(())),
598 _subscriptions: [
599 cx.subscribe(buffer, |this, buffer, event, cx| {
600 this.handle_buffer_event(buffer, event, cx).log_err();
601 }),
602 cx.observe_release(buffer, move |this, _buffer, _cx| {
603 this.buffers.remove(&weak_buffer);
604 this.unregister_buffer(&weak_buffer);
605 }),
606 ],
607 }
608 });
609 }
610 }
611
612 fn handle_buffer_event(
613 &mut self,
614 buffer: ModelHandle<Buffer>,
615 event: &language::Event,
616 cx: &mut ModelContext<Self>,
617 ) -> Result<()> {
618 if let Ok(server) = self.server.as_running() {
619 if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.id()) {
620 match event {
621 language::Event::Edited => {
622 let _ = registered_buffer.report_changes(&buffer, cx);
623 }
624 language::Event::Saved => {
625 server
626 .lsp
627 .notify::<lsp::notification::DidSaveTextDocument>(
628 lsp::DidSaveTextDocumentParams {
629 text_document: lsp::TextDocumentIdentifier::new(
630 registered_buffer.uri.clone(),
631 ),
632 text: None,
633 },
634 )?;
635 }
636 language::Event::FileHandleChanged | language::Event::LanguageChanged => {
637 let new_language_id = id_for_language(buffer.read(cx).language());
638 let new_uri = uri_for_buffer(&buffer, cx);
639 if new_uri != registered_buffer.uri
640 || new_language_id != registered_buffer.language_id
641 {
642 let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
643 registered_buffer.language_id = new_language_id;
644 server
645 .lsp
646 .notify::<lsp::notification::DidCloseTextDocument>(
647 lsp::DidCloseTextDocumentParams {
648 text_document: lsp::TextDocumentIdentifier::new(old_uri),
649 },
650 )?;
651 server
652 .lsp
653 .notify::<lsp::notification::DidOpenTextDocument>(
654 lsp::DidOpenTextDocumentParams {
655 text_document: lsp::TextDocumentItem::new(
656 registered_buffer.uri.clone(),
657 registered_buffer.language_id.clone(),
658 registered_buffer.snapshot_version,
659 registered_buffer.snapshot.text(),
660 ),
661 },
662 )?;
663 }
664 }
665 _ => {}
666 }
667 }
668 }
669
670 Ok(())
671 }
672
673 fn unregister_buffer(&mut self, buffer: &WeakModelHandle<Buffer>) {
674 if let Ok(server) = self.server.as_running() {
675 if let Some(buffer) = server.registered_buffers.remove(&buffer.id()) {
676 server
677 .lsp
678 .notify::<lsp::notification::DidCloseTextDocument>(
679 lsp::DidCloseTextDocumentParams {
680 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
681 },
682 )
683 .log_err();
684 }
685 }
686 }
687
688 pub fn completions<T>(
689 &mut self,
690 buffer: &ModelHandle<Buffer>,
691 position: T,
692 cx: &mut ModelContext<Self>,
693 ) -> Task<Result<Vec<Completion>>>
694 where
695 T: ToPointUtf16,
696 {
697 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
698 }
699
700 pub fn completions_cycling<T>(
701 &mut self,
702 buffer: &ModelHandle<Buffer>,
703 position: T,
704 cx: &mut ModelContext<Self>,
705 ) -> Task<Result<Vec<Completion>>>
706 where
707 T: ToPointUtf16,
708 {
709 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
710 }
711
712 pub fn accept_completion(
713 &mut self,
714 completion: &Completion,
715 cx: &mut ModelContext<Self>,
716 ) -> Task<Result<()>> {
717 let server = match self.server.as_authenticated() {
718 Ok(server) => server,
719 Err(error) => return Task::ready(Err(error)),
720 };
721 let request =
722 server
723 .lsp
724 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
725 uuid: completion.uuid.clone(),
726 });
727 cx.background().spawn(async move {
728 request.await?;
729 Ok(())
730 })
731 }
732
733 pub fn discard_completions(
734 &mut self,
735 completions: &[Completion],
736 cx: &mut ModelContext<Self>,
737 ) -> Task<Result<()>> {
738 let server = match self.server.as_authenticated() {
739 Ok(server) => server,
740 Err(error) => return Task::ready(Err(error)),
741 };
742 let request =
743 server
744 .lsp
745 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
746 uuids: completions
747 .iter()
748 .map(|completion| completion.uuid.clone())
749 .collect(),
750 });
751 cx.background().spawn(async move {
752 request.await?;
753 Ok(())
754 })
755 }
756
757 fn request_completions<R, T>(
758 &mut self,
759 buffer: &ModelHandle<Buffer>,
760 position: T,
761 cx: &mut ModelContext<Self>,
762 ) -> Task<Result<Vec<Completion>>>
763 where
764 R: 'static
765 + lsp::request::Request<
766 Params = request::GetCompletionsParams,
767 Result = request::GetCompletionsResult,
768 >,
769 T: ToPointUtf16,
770 {
771 self.register_buffer(buffer, cx);
772
773 let server = match self.server.as_authenticated() {
774 Ok(server) => server,
775 Err(error) => return Task::ready(Err(error)),
776 };
777 let lsp = server.lsp.clone();
778 let registered_buffer = server.registered_buffers.get_mut(&buffer.id()).unwrap();
779 let snapshot = registered_buffer.report_changes(buffer, cx);
780 let buffer = buffer.read(cx);
781 let uri = registered_buffer.uri.clone();
782 let position = position.to_point_utf16(buffer);
783 let settings = language_settings(buffer.language_at(position).as_ref(), buffer.file(), cx);
784 let tab_size = settings.tab_size;
785 let hard_tabs = settings.hard_tabs;
786 let relative_path = buffer
787 .file()
788 .map(|file| file.path().to_path_buf())
789 .unwrap_or_default();
790
791 cx.foreground().spawn(async move {
792 let (version, snapshot) = snapshot.await?;
793 let result = lsp
794 .request::<R>(request::GetCompletionsParams {
795 doc: request::GetCompletionsDocument {
796 uri,
797 tab_size: tab_size.into(),
798 indent_size: 1,
799 insert_spaces: !hard_tabs,
800 relative_path: relative_path.to_string_lossy().into(),
801 position: point_to_lsp(position),
802 version: version.try_into().unwrap(),
803 },
804 })
805 .await?;
806 let completions = result
807 .completions
808 .into_iter()
809 .map(|completion| {
810 let start = snapshot
811 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
812 let end =
813 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
814 Completion {
815 uuid: completion.uuid,
816 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
817 text: completion.text,
818 }
819 })
820 .collect();
821 anyhow::Ok(completions)
822 })
823 }
824
825 pub fn status(&self) -> Status {
826 match &self.server {
827 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
828 CopilotServer::Disabled => Status::Disabled,
829 CopilotServer::Error(error) => Status::Error(error.clone()),
830 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
831 match sign_in_status {
832 SignInStatus::Authorized { .. } => Status::Authorized,
833 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
834 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
835 prompt: prompt.clone(),
836 },
837 SignInStatus::SignedOut => Status::SignedOut,
838 }
839 }
840 }
841 }
842
843 fn update_sign_in_status(
844 &mut self,
845 lsp_status: request::SignInStatus,
846 cx: &mut ModelContext<Self>,
847 ) {
848 self.buffers.retain(|buffer| buffer.is_upgradable(cx));
849
850 if let Ok(server) = self.server.as_running() {
851 match lsp_status {
852 request::SignInStatus::Ok { .. }
853 | request::SignInStatus::MaybeOk { .. }
854 | request::SignInStatus::AlreadySignedIn { .. } => {
855 server.sign_in_status = SignInStatus::Authorized;
856 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
857 if let Some(buffer) = buffer.upgrade(cx) {
858 self.register_buffer(&buffer, cx);
859 }
860 }
861 }
862 request::SignInStatus::NotAuthorized { .. } => {
863 server.sign_in_status = SignInStatus::Unauthorized;
864 for buffer in self.buffers.iter().copied().collect::<Vec<_>>() {
865 self.unregister_buffer(&buffer);
866 }
867 }
868 request::SignInStatus::NotSignedIn => {
869 server.sign_in_status = SignInStatus::SignedOut;
870 for buffer in self.buffers.iter().copied().collect::<Vec<_>>() {
871 self.unregister_buffer(&buffer);
872 }
873 }
874 }
875
876 cx.notify();
877 }
878 }
879}
880
881fn id_for_language(language: Option<&Arc<Language>>) -> String {
882 let language_name = language.map(|language| language.name());
883 match language_name.as_deref() {
884 Some("Plain Text") => "plaintext".to_string(),
885 Some(language_name) => language_name.to_lowercase(),
886 None => "plaintext".to_string(),
887 }
888}
889
890fn uri_for_buffer(buffer: &ModelHandle<Buffer>, cx: &AppContext) -> lsp::Url {
891 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
892 lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
893 } else {
894 format!("buffer://{}", buffer.id()).parse().unwrap()
895 }
896}
897
898async fn clear_copilot_dir() {
899 remove_matching(&paths::COPILOT_DIR, |_| true).await
900}
901
902async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
903 const SERVER_PATH: &'static str = "dist/agent.js";
904
905 ///Check for the latest copilot language server and download it if we haven't already
906 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
907 let release = latest_github_release("zed-industries/copilot", false, http.clone()).await?;
908
909 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
910
911 fs::create_dir_all(version_dir).await?;
912 let server_path = version_dir.join(SERVER_PATH);
913
914 if fs::metadata(&server_path).await.is_err() {
915 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
916 let dist_dir = version_dir.join("dist");
917 fs::create_dir_all(dist_dir.as_path()).await?;
918
919 let url = &release
920 .assets
921 .get(0)
922 .context("Github release for copilot contained no assets")?
923 .browser_download_url;
924
925 let mut response = http
926 .get(&url, Default::default(), true)
927 .await
928 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
929 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
930 let archive = Archive::new(decompressed_bytes);
931 archive.unpack(dist_dir).await?;
932
933 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
934 }
935
936 Ok(server_path)
937 }
938
939 match fetch_latest(http).await {
940 ok @ Result::Ok(..) => ok,
941 e @ Err(..) => {
942 e.log_err();
943 // Fetch a cached binary, if it exists
944 (|| async move {
945 let mut last_version_dir = None;
946 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
947 while let Some(entry) = entries.next().await {
948 let entry = entry?;
949 if entry.file_type().await?.is_dir() {
950 last_version_dir = Some(entry.path());
951 }
952 }
953 let last_version_dir =
954 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
955 let server_path = last_version_dir.join(SERVER_PATH);
956 if server_path.exists() {
957 Ok(server_path)
958 } else {
959 Err(anyhow!(
960 "missing executable in directory {:?}",
961 last_version_dir
962 ))
963 }
964 })()
965 .await
966 }
967 }
968}
969
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{executor::Deterministic, TestAppContext};

    /// Exercises the lifecycle of buffers on the (fake) copilot server:
    /// - opened on registration;
    /// - edits reported incrementally;
    /// - reopened when the file changes;
    /// - closed on sign-out and reopened on sign-in;
    /// - closed when the buffer is released.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();
        // `Copilot::fake` starts out running and already authorized.
        let (copilot, mut lsp) = Copilot::fake(cx);

        // Registering a file-less buffer opens it under a synthetic `buffer://` URI.
        let buffer_1 = cx.add_model(|cx| Buffer::new(0, "Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.id()).parse().unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.add_model(|cx| Buffer::new(0, "Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.id()).parse().unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is reported as an incremental change and bumps the version to 1.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        // Giving the buffer a local file changes its URI: close the old document,
        // then reopen under the file URI, keeping the current version (1).
        buffer_1
            .update(cx, |buffer, cx| {
                buffer.file_updated(
                    Arc::new(File {
                        abs_path: "/root/child/buffer-1".into(),
                        path: Path::new("child/buffer-1").into(),
                    }),
                    cx,
                )
            })
            .await;
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        // NOTE(review): the expected order (buffer_2, then buffer_1) follows the
        // iteration order of `Copilot::buffers`, a HashSet; presumably that order
        // is deterministic in tests. Confirm before relying on it elsewhere.
        lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        // Reopened buffers start over at version 0 with their current contents.
        lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );

        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    /// Minimal local file used to give a test buffer a file-backed URI. Only
    /// `as_local`, `path`, `worktree_id` and `abs_path` are implemented.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn mtime(&self) -> std::time::SystemTime {
            unimplemented!()
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &AppContext) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn is_deleted(&self) -> bool {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            unimplemented!()
        }

        fn to_proto(&self) -> rpc::proto::File {
            unimplemented!()
        }

        fn worktree_id(&self) -> usize {
            0
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &AppContext) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &AppContext) -> Task<Result<String>> {
            unimplemented!()
        }

        fn buffer_reloaded(
            &self,
            _: u64,
            _: &clock::Global,
            _: language::RopeFingerprint,
            _: ::fs::LineEnding,
            _: std::time::SystemTime,
            _: &mut AppContext,
        ) {
            unimplemented!()
        }
    }
}