1pub mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context as _, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use collections::{HashMap, HashSet};
8use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{
10 actions, AppContext, AsyncAppContext, Context, Entity, EntityId, EventEmitter, Model,
11 ModelContext, Task, WeakModel,
12};
13use language::{
14 language_settings::{all_language_settings, language_settings},
15 point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language,
16 LanguageServerName, PointUtf16, ToPointUtf16,
17};
18use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId};
19use node_runtime::NodeRuntime;
20use parking_lot::Mutex;
21use request::StatusNotification;
22use settings::SettingsStore;
23use smol::{fs, io::BufReader, stream::StreamExt};
24use std::{
25 any::TypeId,
26 ffi::OsString,
27 mem,
28 ops::Range,
29 path::{Path, PathBuf},
30 sync::Arc,
31};
32use util::{
33 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
34};
35
// Declares the action types used by the Copilot integration. `init` registers handlers for
// `SignIn`, `SignOut` and `Reinstall`, and hides or shows all six actions in the command
// palette filter depending on the current Copilot status.
actions!(
    Suggest,
    NextSuggestion,
    PreviousSuggestion,
    Reinstall,
    SignIn,
    SignOut
);
44
45pub fn init(
46 new_server_id: LanguageServerId,
47 http: Arc<dyn HttpClient>,
48 node_runtime: Arc<dyn NodeRuntime>,
49 cx: &mut AppContext,
50) {
51 let copilot = cx.build_model({
52 let node_runtime = node_runtime.clone();
53 move |cx| Copilot::start(new_server_id, http, node_runtime, cx)
54 });
55 cx.set_global(copilot.clone());
56 cx.observe(&copilot, |handle, cx| {
57 let copilot_action_types = [
58 TypeId::of::<Suggest>(),
59 TypeId::of::<NextSuggestion>(),
60 TypeId::of::<PreviousSuggestion>(),
61 TypeId::of::<Reinstall>(),
62 ];
63 let copilot_auth_action_types = [TypeId::of::<SignIn>(), TypeId::of::<SignOut>()];
64
65 let status = handle.read(cx).status();
66 let filter = cx.default_global::<collections::CommandPaletteFilter>();
67
68 match status {
69 Status::Disabled => {
70 filter.hidden_action_types.extend(copilot_action_types);
71 filter.hidden_action_types.extend(copilot_auth_action_types);
72 }
73 Status::Authorized => {
74 for type_id in copilot_action_types
75 .iter()
76 .chain(&copilot_auth_action_types)
77 {
78 filter.hidden_action_types.remove(type_id);
79 }
80 }
81 _ => {
82 filter.hidden_action_types.extend(copilot_action_types);
83 for type_id in &copilot_auth_action_types {
84 filter.hidden_action_types.remove(type_id);
85 }
86 }
87 }
88 })
89 .detach();
90
91 sign_in::init(cx);
92 cx.on_action(|_: &SignIn, cx| {
93 if let Some(copilot) = Copilot::global(cx) {
94 copilot
95 .update(cx, |copilot, cx| copilot.sign_in(cx))
96 .detach_and_log_err(cx);
97 }
98 });
99 cx.on_action(|_: &SignOut, cx| {
100 if let Some(copilot) = Copilot::global(cx) {
101 copilot
102 .update(cx, |copilot, cx| copilot.sign_out(cx))
103 .detach_and_log_err(cx);
104 }
105 });
106 cx.on_action(|_: &Reinstall, cx| {
107 if let Some(copilot) = Copilot::global(cx) {
108 copilot
109 .update(cx, |copilot, cx| copilot.reinstall(cx))
110 .detach();
111 }
112 });
113}
114
/// Lifecycle state of the Copilot language server as tracked by [`Copilot`].
enum CopilotServer {
    /// No server is tracked; this is the initial state and the state after shutdown or when
    /// the settings disable Copilot.
    Disabled,
    /// A start (or reinstall) task has been spawned; `task` is the shared handle to it.
    Starting { task: Shared<Task<()>> },
    /// Starting the server failed; holds the error rendered as a string.
    Error(Arc<str>),
    /// The server was started successfully.
    Running(RunningCopilotServer),
}
121
122impl CopilotServer {
123 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
124 let server = self.as_running()?;
125 if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
126 Ok(server)
127 } else {
128 Err(anyhow!("must sign in before using copilot"))
129 }
130 }
131
132 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
133 match self {
134 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
135 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
136 CopilotServer::Error(error) => Err(anyhow!(
137 "copilot was not started because of an error: {}",
138 error
139 )),
140 CopilotServer::Running(server) => Ok(server),
141 }
142 }
143}
144
/// State kept for a successfully started Copilot language server.
struct RunningCopilotServer {
    /// Name reported through `Copilot::language_server`.
    name: LanguageServerName,
    /// Handle used to send requests and notifications to the server.
    lsp: Arc<LanguageServer>,
    /// Current authentication state for this server.
    sign_in_status: SignInStatus,
    /// Buffers that have been opened on the server, keyed by the buffer's entity id.
    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
}
151
/// Internal authentication state of a running Copilot server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Signed in and allowed to use Copilot.
    Authorized,
    /// The server reported the account as not authorized.
    Unauthorized,
    /// A sign-in attempt is in flight.
    SigningIn {
        /// Device-flow prompt, filled in once the server asks for it.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared handle to the in-flight sign-in task, so repeated sign-in calls reuse it.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Not signed in.
    SignedOut,
}
162
/// Public snapshot of the Copilot state, as produced by `Copilot::status`.
///
/// Combines the server lifecycle state with the sign-in state of a running server.
#[derive(Debug, Clone)]
pub enum Status {
    /// The language server is being started; `task` is the shared start task.
    Starting {
        task: Shared<Task<()>>,
    },
    /// The language server failed to start; holds the error message.
    Error(Arc<str>),
    /// No server is tracked.
    Disabled,
    /// The server is running but nobody is signed in.
    SignedOut,
    /// A sign-in attempt is in flight; `prompt` is set once a device-flow prompt is available.
    SigningIn {
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// The server is running but the account is not authorized.
    Unauthorized,
    /// The server is running and the account is authorized.
    Authorized,
}
177
178impl Status {
179 pub fn is_authorized(&self) -> bool {
180 matches!(self, Status::Authorized)
181 }
182}
183
/// Bookkeeping for a buffer that has been opened on the Copilot language server.
struct RegisteredBuffer {
    /// URI under which the document was opened on the server.
    uri: lsp::Url,
    /// Language id sent to the server when the document was opened.
    language_id: String,
    /// Buffer contents as last reported to the server.
    snapshot: BufferSnapshot,
    /// Document version last reported to the server; starts at 0 and is incremented for
    /// every change notification that is sent.
    snapshot_version: i32,
    /// Keeps the buffer event subscription and the release observer alive.
    _subscriptions: [gpui::Subscription; 2],
    /// Most recently spawned change-report task; each new report awaits the previous one.
    pending_buffer_change: Task<Option<()>>,
}
192
impl RegisteredBuffer {
    /// Brings the language server's view of `buffer` up to date.
    ///
    /// Returns a receiver that resolves to the document version and snapshot the server has
    /// been told about. If the buffer has not changed since the stored snapshot, the receiver
    /// is fulfilled immediately. Otherwise a task is spawned that computes the edits since the
    /// stored snapshot and sends them as a `DidChangeTextDocument` notification.
    ///
    /// The sender is dropped without sending (so the receiver yields a cancellation error) if
    /// the task bails out early, e.g. because the server is no longer authenticated or the
    /// buffer is no longer registered.
    fn report_changes(
        &mut self,
        buffer: &Model<Buffer>,
        cx: &mut ModelContext<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            // Nothing changed since the last report; answer with the stored state.
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Chain onto the previously pending report so that reports run one after another.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(move |copilot, mut cx| async move {
                prev_pending_change.await;

                // Re-read the stored version after the previous report finished, since that
                // report may have advanced the snapshot.
                let old_version = copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot()).ok()?;

                // Compute the change events on the background executor.
                let content_changes = cx
                    .background_executor()
                    .spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // The replaced range starts at the edit's new start point
                                    // and spans the extent of the old range.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        // Only bump the version and notify the server when there is something
                        // to report.
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .log_err();
                        }
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
279
/// A single completion returned by the Copilot language server.
#[derive(Debug)]
pub struct Completion {
    /// Identifier assigned by the server; sent back when the completion is accepted or
    /// rejected.
    pub uuid: String,
    /// Range of the buffer the completion applies to, anchored in the snapshot the request
    /// was made against.
    pub range: Range<Anchor>,
    /// Text of the completion.
    pub text: String,
}
286
/// Model that owns the Copilot language server and the set of buffers known to it.
pub struct Copilot {
    /// HTTP client used to download the Copilot language server.
    http: Arc<dyn HttpClient>,
    /// Provides the path of the node binary used to run the server.
    node_runtime: Arc<dyn NodeRuntime>,
    /// Current lifecycle state of the language server.
    server: CopilotServer,
    /// Every buffer passed to `register_buffer`, including those not (yet) opened on the
    /// server; used to re-register buffers after signing in.
    buffers: HashSet<WeakModel<Buffer>>,
    /// Id passed to the language server when it is created.
    server_id: LanguageServerId,
    /// Keeps the app-quit hook that shuts the language server down.
    _subscription: gpui::Subscription,
}
295
/// Events emitted by the [`Copilot`] model.
pub enum Event {
    /// Emitted after the language server has started and been stored as running.
    CopilotLanguageServerStarted,
}

// Allows `Copilot` to emit `Event`s from its model context.
impl EventEmitter<Event> for Copilot {}
301
302impl Copilot {
303 pub fn global(cx: &AppContext) -> Option<Model<Self>> {
304 if cx.has_global::<Model<Self>>() {
305 Some(cx.global::<Model<Self>>().clone())
306 } else {
307 None
308 }
309 }
310
311 fn start(
312 new_server_id: LanguageServerId,
313 http: Arc<dyn HttpClient>,
314 node_runtime: Arc<dyn NodeRuntime>,
315 cx: &mut ModelContext<Self>,
316 ) -> Self {
317 let mut this = Self {
318 server_id: new_server_id,
319 http,
320 node_runtime,
321 server: CopilotServer::Disabled,
322 buffers: Default::default(),
323 _subscription: cx.on_app_quit(Self::shutdown_language_server),
324 };
325 this.enable_or_disable_copilot(cx);
326 cx.observe_global::<SettingsStore>(move |this, cx| this.enable_or_disable_copilot(cx))
327 .detach();
328 this
329 }
330
331 fn shutdown_language_server(
332 &mut self,
333 _cx: &mut ModelContext<Self>,
334 ) -> impl Future<Output = ()> {
335 let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
336 CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
337 _ => None,
338 };
339
340 async move {
341 if let Some(shutdown) = shutdown {
342 shutdown.await;
343 }
344 }
345 }
346
347 fn enable_or_disable_copilot(&mut self, cx: &mut ModelContext<Self>) {
348 let server_id = self.server_id;
349 let http = self.http.clone();
350 let node_runtime = self.node_runtime.clone();
351 if all_language_settings(None, cx).copilot_enabled(None, None) {
352 if matches!(self.server, CopilotServer::Disabled) {
353 let start_task = cx
354 .spawn(move |this, cx| {
355 Self::start_language_server(server_id, http, node_runtime, this, cx)
356 })
357 .shared();
358 self.server = CopilotServer::Starting { task: start_task };
359 cx.notify();
360 }
361 } else {
362 self.server = CopilotServer::Disabled;
363 cx.notify();
364 }
365 }
366
367 #[cfg(any(test, feature = "test-support"))]
368 pub fn fake(cx: &mut gpui::TestAppContext) -> (Model<Self>, lsp::FakeLanguageServer) {
369 use node_runtime::FakeNodeRuntime;
370
371 let (server, fake_server) =
372 LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
373 let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
374 let node_runtime = FakeNodeRuntime::new();
375 let this = cx.build_model(|cx| Self {
376 server_id: LanguageServerId(0),
377 http: http.clone(),
378 node_runtime,
379 server: CopilotServer::Running(RunningCopilotServer {
380 name: LanguageServerName(Arc::from("copilot")),
381 lsp: Arc::new(server),
382 sign_in_status: SignInStatus::Authorized,
383 registered_buffers: Default::default(),
384 }),
385 _subscription: cx.on_app_quit(Self::shutdown_language_server),
386 buffers: Default::default(),
387 });
388 (this, fake_server)
389 }
390
    /// Launches the Copilot language server and records the outcome on `this`.
    ///
    /// Resolves the server script via `get_copilot_lsp` and the node binary via
    /// `node_runtime`, initializes the server, queries its sign-in status and sends the
    /// editor info. On success the model becomes `CopilotServer::Running`, emits
    /// `Event::CopilotLanguageServerStarted` and applies the reported sign-in status; on
    /// failure it becomes `CopilotServer::Error` with the error message. If the model has
    /// been dropped in the meantime the result is discarded.
    fn start_language_server(
        new_server_id: LanguageServerId,
        http: Arc<dyn HttpClient>,
        node_runtime: Arc<dyn NodeRuntime>,
        this: WeakModel<Self>,
        mut cx: AsyncAppContext,
    ) -> impl Future<Output = ()> {
        async move {
            let start_language_server = async {
                let server_path = get_copilot_lsp(http).await?;
                let node_path = node_runtime.binary_path().await?;
                // The server script is run with node and talks over stdio.
                let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
                let binary = LanguageServerBinary {
                    path: node_path,
                    arguments,
                };

                let server = LanguageServer::new(
                    Arc::new(Mutex::new(None)),
                    new_server_id,
                    binary,
                    Path::new("/"),
                    None,
                    cx.clone(),
                )?;

                server
                    .on_notification::<StatusNotification, _>(
                        |_, _| { /* Silence the notification */ },
                    )
                    .detach();

                let server = server.initialize(Default::default()).await?;

                let status = server
                    .request::<request::CheckStatus>(request::CheckStatusParams {
                        local_checks_only: false,
                    })
                    .await?;

                server
                    .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
                        editor_info: request::EditorInfo {
                            name: "zed".into(),
                            version: env!("CARGO_PKG_VERSION").into(),
                        },
                        editor_plugin_info: request::EditorPluginInfo {
                            name: "zed-copilot".into(),
                            version: "0.0.1".into(),
                        },
                    })
                    .await?;

                anyhow::Ok((server, status))
            };

            let server = start_language_server.await;
            this.update(&mut cx, |this, cx| {
                cx.notify();
                match server {
                    Ok((server, status)) => {
                        // Start out signed out, then apply the status the server reported.
                        this.server = CopilotServer::Running(RunningCopilotServer {
                            name: LanguageServerName(Arc::from("copilot")),
                            lsp: server,
                            sign_in_status: SignInStatus::SignedOut,
                            registered_buffers: Default::default(),
                        });
                        cx.emit(Event::CopilotLanguageServerStarted);
                        this.update_sign_in_status(status, cx);
                    }
                    Err(error) => {
                        this.server = CopilotServer::Error(error.to_string().into());
                        cx.notify()
                    }
                }
            })
            .ok();
        }
    }
470
    /// Signs in to Copilot through the running language server.
    ///
    /// - If already authorized, resolves immediately with `Ok(())`.
    /// - If a sign-in is already in flight, returns a task that waits for that same attempt.
    /// - Otherwise starts a new attempt: sends `SignInInitiate`, and if the server asks for
    ///   the device flow, stores the prompt on the `SigningIn` status and then sends
    ///   `SignInConfirm`. The resulting status is applied via `update_sign_in_status`; on
    ///   error the status is reset to not-signed-in and the error is returned.
    ///
    /// Returns an error task if the server is not running.
    pub fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the prompt on the in-flight status so that
                                        // observers can read it from `Status::SigningIn`.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        })?;
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    Err(Arc::new(error))
                                }
                            })?
                        })
                        .shared();
                    // Record the attempt so that concurrent callers share the same task.
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            // The shared task yields `Arc<anyhow::Error>`; convert it back to `anyhow::Error`.
            cx.background_executor()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // If we're downloading, wait until download is finished
            // If we're in a stuck state, display to the user
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
555
556 #[allow(dead_code)] // todo!()
557 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
558 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
559 if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
560 let server = server.clone();
561 cx.background_executor().spawn(async move {
562 server
563 .request::<request::SignOut>(request::SignOutParams {})
564 .await?;
565 anyhow::Ok(())
566 })
567 } else {
568 Task::ready(Err(anyhow!("copilot hasn't started yet")))
569 }
570 }
571
572 pub fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
573 let start_task = cx
574 .spawn({
575 let http = self.http.clone();
576 let node_runtime = self.node_runtime.clone();
577 let server_id = self.server_id;
578 move |this, cx| async move {
579 clear_copilot_dir().await;
580 Self::start_language_server(server_id, http, node_runtime, this, cx).await
581 }
582 })
583 .shared();
584
585 self.server = CopilotServer::Starting {
586 task: start_task.clone(),
587 };
588
589 cx.notify();
590
591 cx.background_executor().spawn(start_task)
592 }
593
594 pub fn language_server(&self) -> Option<(&LanguageServerName, &Arc<LanguageServer>)> {
595 if let CopilotServer::Running(server) = &self.server {
596 Some((&server.name, &server.lsp))
597 } else {
598 None
599 }
600 }
601
    /// Starts tracking `buffer` and, when possible, opens it on the language server.
    ///
    /// The buffer is always remembered in `self.buffers`. It is only opened on the server
    /// (via `DidOpenTextDocument`, version 0) when the server is running and authorized and
    /// the buffer has not been registered already.
    pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
        let weak_buffer = buffer.downgrade();
        self.buffers.insert(weak_buffer.clone());

        if let CopilotServer::Running(RunningCopilotServer {
            lsp: server,
            sign_in_status: status,
            registered_buffers,
            ..
        }) = &mut self.server
        {
            if !matches!(status, SignInStatus::Authorized { .. }) {
                return;
            }

            // `or_insert_with` makes registration idempotent: the document is only opened the
            // first time this buffer is seen.
            registered_buffers
                .entry(buffer.entity_id())
                .or_insert_with(|| {
                    let uri: lsp::Url = uri_for_buffer(buffer, cx);
                    let language_id = id_for_language(buffer.read(cx).language());
                    let snapshot = buffer.read(cx).snapshot();
                    server
                        .notify::<lsp::notification::DidOpenTextDocument>(
                            lsp::DidOpenTextDocumentParams {
                                text_document: lsp::TextDocumentItem {
                                    uri: uri.clone(),
                                    language_id: language_id.clone(),
                                    version: 0,
                                    text: snapshot.text(),
                                },
                            },
                        )
                        .log_err();

                    RegisteredBuffer {
                        uri,
                        language_id,
                        snapshot,
                        snapshot_version: 0,
                        pending_buffer_change: Task::ready(Some(())),
                        _subscriptions: [
                            // Forward buffer events (edits, saves, renames, ...) to the server.
                            cx.subscribe(buffer, |this, buffer, event, cx| {
                                this.handle_buffer_event(buffer, event, cx).log_err();
                            }),
                            // Stop tracking the buffer and close it on the server once it is
                            // released.
                            cx.observe_release(buffer, move |this, _buffer, _cx| {
                                this.buffers.remove(&weak_buffer);
                                this.unregister_buffer(&weak_buffer);
                            }),
                        ],
                    }
                });
        }
    }
655
    /// Forwards an event of a registered buffer to the language server.
    ///
    /// - `Edited`: reports the changes (the resulting receiver is ignored).
    /// - `Saved`: sends `DidSaveTextDocument`.
    /// - `FileHandleChanged` / `LanguageChanged`: if the URI or language id changed, closes
    ///   the document under its old URI and re-opens it with the new URI and language id,
    ///   using the last reported version and text.
    ///
    /// Does nothing when the server is not running or the buffer is not registered. Returns
    /// an error if sending a save, close or open notification fails.
    fn handle_buffer_event(
        &mut self,
        buffer: Model<Buffer>,
        event: &language::Event,
        cx: &mut ModelContext<Self>,
    ) -> Result<()> {
        if let Ok(server) = self.server.as_running() {
            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
            {
                match event {
                    language::Event::Edited => {
                        let _ = registered_buffer.report_changes(&buffer, cx);
                    }
                    language::Event::Saved => {
                        server
                            .lsp
                            .notify::<lsp::notification::DidSaveTextDocument>(
                                lsp::DidSaveTextDocumentParams {
                                    text_document: lsp::TextDocumentIdentifier::new(
                                        registered_buffer.uri.clone(),
                                    ),
                                    text: None,
                                },
                            )?;
                    }
                    language::Event::FileHandleChanged | language::Event::LanguageChanged => {
                        let new_language_id = id_for_language(buffer.read(cx).language());
                        let new_uri = uri_for_buffer(&buffer, cx);
                        if new_uri != registered_buffer.uri
                            || new_language_id != registered_buffer.language_id
                        {
                            // Update the bookkeeping first, keeping the old URI around so the
                            // document can be closed under the name the server knows.
                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
                            registered_buffer.language_id = new_language_id;
                            server
                                .lsp
                                .notify::<lsp::notification::DidCloseTextDocument>(
                                    lsp::DidCloseTextDocumentParams {
                                        text_document: lsp::TextDocumentIdentifier::new(old_uri),
                                    },
                                )?;
                            server
                                .lsp
                                .notify::<lsp::notification::DidOpenTextDocument>(
                                    lsp::DidOpenTextDocumentParams {
                                        text_document: lsp::TextDocumentItem::new(
                                            registered_buffer.uri.clone(),
                                            registered_buffer.language_id.clone(),
                                            registered_buffer.snapshot_version,
                                            registered_buffer.snapshot.text(),
                                        ),
                                    },
                                )?;
                        }
                    }
                    _ => {}
                }
            }
        }

        Ok(())
    }
717
718 fn unregister_buffer(&mut self, buffer: &WeakModel<Buffer>) {
719 if let Ok(server) = self.server.as_running() {
720 if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
721 server
722 .lsp
723 .notify::<lsp::notification::DidCloseTextDocument>(
724 lsp::DidCloseTextDocumentParams {
725 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
726 },
727 )
728 .log_err();
729 }
730 }
731 }
732
    /// Requests completions for `buffer` at `position` using the `GetCompletions` request.
    ///
    /// Registers the buffer if necessary; the returned task fails if Copilot is not running
    /// and authorized.
    pub fn completions<T>(
        &mut self,
        buffer: &Model<Buffer>,
        position: T,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<Completion>>>
    where
        T: ToPointUtf16,
    {
        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
    }
744
    /// Requests completions for `buffer` at `position` using the `GetCompletionsCycling`
    /// request.
    ///
    /// Registers the buffer if necessary; the returned task fails if Copilot is not running
    /// and authorized.
    pub fn completions_cycling<T>(
        &mut self,
        buffer: &Model<Buffer>,
        position: T,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<Completion>>>
    where
        T: ToPointUtf16,
    {
        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
    }
756
757 pub fn accept_completion(
758 &mut self,
759 completion: &Completion,
760 cx: &mut ModelContext<Self>,
761 ) -> Task<Result<()>> {
762 let server = match self.server.as_authenticated() {
763 Ok(server) => server,
764 Err(error) => return Task::ready(Err(error)),
765 };
766 let request =
767 server
768 .lsp
769 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
770 uuid: completion.uuid.clone(),
771 });
772 cx.background_executor().spawn(async move {
773 request.await?;
774 Ok(())
775 })
776 }
777
778 pub fn discard_completions(
779 &mut self,
780 completions: &[Completion],
781 cx: &mut ModelContext<Self>,
782 ) -> Task<Result<()>> {
783 let server = match self.server.as_authenticated() {
784 Ok(server) => server,
785 Err(error) => return Task::ready(Err(error)),
786 };
787 let request =
788 server
789 .lsp
790 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
791 uuids: completions
792 .iter()
793 .map(|completion| completion.uuid.clone())
794 .collect(),
795 });
796 cx.background_executor().spawn(async move {
797 request.await?;
798 Ok(())
799 })
800 }
801
    /// Shared implementation of `completions` and `completions_cycling`.
    ///
    /// Registers the buffer, reports any pending changes to the server, then sends request
    /// `R` for the given position on the background executor and converts the returned
    /// ranges into anchors of the snapshot the server was told about.
    ///
    /// The returned task fails if Copilot is not running and authorized, if the change
    /// report is cancelled, or if the request fails.
    fn request_completions<R, T>(
        &mut self,
        buffer: &Model<Buffer>,
        position: T,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<Completion>>>
    where
        R: 'static
            + lsp::request::Request<
                Params = request::GetCompletionsParams,
                Result = request::GetCompletionsResult,
            >,
        T: ToPointUtf16,
    {
        self.register_buffer(buffer, cx);

        let server = match self.server.as_authenticated() {
            Ok(server) => server,
            Err(error) => return Task::ready(Err(error)),
        };
        let lsp = server.lsp.clone();
        // `register_buffer` above inserts the buffer whenever the server is running and
        // authorized, which `as_authenticated` has just confirmed; the lookup panics otherwise.
        let registered_buffer = server
            .registered_buffers
            .get_mut(&buffer.entity_id())
            .unwrap();
        let snapshot = registered_buffer.report_changes(buffer, cx);
        let buffer = buffer.read(cx);
        let uri = registered_buffer.uri.clone();
        let position = position.to_point_utf16(buffer);
        let settings = language_settings(buffer.language_at(position).as_ref(), buffer.file(), cx);
        let tab_size = settings.tab_size;
        let hard_tabs = settings.hard_tabs;
        let relative_path = buffer
            .file()
            .map(|file| file.path().to_path_buf())
            .unwrap_or_default();

        cx.background_executor().spawn(async move {
            // Wait until the server has been told about the latest buffer contents.
            let (version, snapshot) = snapshot.await?;
            let result = lsp
                .request::<R>(request::GetCompletionsParams {
                    doc: request::GetCompletionsDocument {
                        uri,
                        tab_size: tab_size.into(),
                        indent_size: 1,
                        insert_spaces: !hard_tabs,
                        relative_path: relative_path.to_string_lossy().into(),
                        position: point_to_lsp(position),
                        // Panics if the version does not fit the target integer type.
                        version: version.try_into().unwrap(),
                    },
                })
                .await?;
            let completions = result
                .completions
                .into_iter()
                .map(|completion| {
                    // Clip the server-provided points to valid positions in the snapshot.
                    let start = snapshot
                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
                    let end =
                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
                    Completion {
                        uuid: completion.uuid,
                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
                        text: completion.text,
                    }
                })
                .collect();
            anyhow::Ok(completions)
        })
    }
872
873 pub fn status(&self) -> Status {
874 match &self.server {
875 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
876 CopilotServer::Disabled => Status::Disabled,
877 CopilotServer::Error(error) => Status::Error(error.clone()),
878 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
879 match sign_in_status {
880 SignInStatus::Authorized { .. } => Status::Authorized,
881 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
882 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
883 prompt: prompt.clone(),
884 },
885 SignInStatus::SignedOut => Status::SignedOut,
886 }
887 }
888 }
889 }
890
891 fn update_sign_in_status(
892 &mut self,
893 lsp_status: request::SignInStatus,
894 cx: &mut ModelContext<Self>,
895 ) {
896 self.buffers.retain(|buffer| buffer.is_upgradable());
897
898 if let Ok(server) = self.server.as_running() {
899 match lsp_status {
900 request::SignInStatus::Ok { .. }
901 | request::SignInStatus::MaybeOk { .. }
902 | request::SignInStatus::AlreadySignedIn { .. } => {
903 server.sign_in_status = SignInStatus::Authorized;
904 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
905 if let Some(buffer) = buffer.upgrade() {
906 self.register_buffer(&buffer, cx);
907 }
908 }
909 }
910 request::SignInStatus::NotAuthorized { .. } => {
911 server.sign_in_status = SignInStatus::Unauthorized;
912 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
913 self.unregister_buffer(&buffer);
914 }
915 }
916 request::SignInStatus::NotSignedIn => {
917 server.sign_in_status = SignInStatus::SignedOut;
918 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
919 self.unregister_buffer(&buffer);
920 }
921 }
922 }
923
924 cx.notify();
925 }
926 }
927}
928
929fn id_for_language(language: Option<&Arc<Language>>) -> String {
930 let language_name = language.map(|language| language.name());
931 match language_name.as_deref() {
932 Some("Plain Text") => "plaintext".to_string(),
933 Some(language_name) => language_name.to_lowercase(),
934 None => "plaintext".to_string(),
935 }
936}
937
938fn uri_for_buffer(buffer: &Model<Buffer>, cx: &AppContext) -> lsp::Url {
939 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
940 lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
941 } else {
942 format!("buffer://{}", buffer.entity_id()).parse().unwrap()
943 }
944}
945
/// Removes every entry inside `paths::COPILOT_DIR` (the predicate matches everything).
async fn clear_copilot_dir() {
    remove_matching(&paths::COPILOT_DIR, |_| true).await
}
949
950async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
951 const SERVER_PATH: &'static str = "dist/agent.js";
952
953 ///Check for the latest copilot language server and download it if we haven't already
954 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
955 let release = latest_github_release("zed-industries/copilot", false, http.clone()).await?;
956
957 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
958
959 fs::create_dir_all(version_dir).await?;
960 let server_path = version_dir.join(SERVER_PATH);
961
962 if fs::metadata(&server_path).await.is_err() {
963 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
964 let dist_dir = version_dir.join("dist");
965 fs::create_dir_all(dist_dir.as_path()).await?;
966
967 let url = &release
968 .assets
969 .get(0)
970 .context("Github release for copilot contained no assets")?
971 .browser_download_url;
972
973 let mut response = http
974 .get(&url, Default::default(), true)
975 .await
976 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
977 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
978 let archive = Archive::new(decompressed_bytes);
979 archive.unpack(dist_dir).await?;
980
981 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
982 }
983
984 Ok(server_path)
985 }
986
987 match fetch_latest(http).await {
988 ok @ Result::Ok(..) => ok,
989 e @ Err(..) => {
990 e.log_err();
991 // Fetch a cached binary, if it exists
992 (|| async move {
993 let mut last_version_dir = None;
994 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
995 while let Some(entry) = entries.next().await {
996 let entry = entry?;
997 if entry.file_type().await?.is_dir() {
998 last_version_dir = Some(entry.path());
999 }
1000 }
1001 let last_version_dir =
1002 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
1003 let server_path = last_version_dir.join(SERVER_PATH);
1004 if server_path.exists() {
1005 Ok(server_path)
1006 } else {
1007 Err(anyhow!(
1008 "missing executable in directory {:?}",
1009 last_version_dir
1010 ))
1011 }
1012 })()
1013 .await
1014 }
1015 }
1016}
1017
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    /// Exercises the buffer lifecycle against a fake language server: registering buffers,
    /// reporting edits, re-opening on file changes, closing on sign-out, re-opening on
    /// sign-in, and closing when a buffer is dropped.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(cx: &mut TestAppContext) {
        let (copilot, mut lsp) = Copilot::fake(cx);

        // Registering a buffer opens it on the server under a `buffer://` URI.
        let buffer_1 = cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), "Hello"));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), "Goodbye"));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // Editing a registered buffer sends an incremental change with a bumped version.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        buffer_1.update(cx, |buffer, cx| {
            buffer.file_updated(
                Arc::new(File {
                    abs_path: "/root/child/buffer-1".into(),
                    path: Path::new("child/buffer-1").into(),
                }),
                cx,
            )
        });
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        // todo!() po: these notifications now happen in reverse order?
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();

        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    /// Minimal local file stand-in for the test above; only the methods the test path
    /// exercises are implemented.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn mtime(&self) -> std::time::SystemTime {
            unimplemented!()
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &AppContext) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn is_deleted(&self) -> bool {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            unimplemented!()
        }

        fn to_proto(&self) -> rpc::proto::File {
            unimplemented!()
        }

        fn worktree_id(&self) -> usize {
            0
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &AppContext) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &AppContext) -> Task<Result<String>> {
            unimplemented!()
        }

        fn buffer_reloaded(
            &self,
            _: u64,
            _: &clock::Global,
            _: language::RopeFingerprint,
            _: language::LineEnding,
            _: std::time::SystemTime,
            _: &mut AppContext,
        ) {
            unimplemented!()
        }
    }
}