1mod copilot_completion_provider;
2pub mod request;
3mod sign_in;
4
5use anyhow::{anyhow, Context as _, Result};
6use async_compression::futures::bufread::GzipDecoder;
7use async_tar::Archive;
8use collections::{HashMap, HashSet};
9use command_palette_hooks::CommandPaletteFilter;
10use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
11use gpui::{
12 actions, AppContext, AsyncAppContext, Context, Entity, EntityId, EventEmitter, Global, Model,
13 ModelContext, Task, WeakModel,
14};
15use language::{
16 language_settings::{all_language_settings, language_settings, InlineCompletionProvider},
17 point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
18 ToPointUtf16,
19};
20use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId};
21use node_runtime::NodeRuntime;
22use parking_lot::Mutex;
23use request::StatusNotification;
24use settings::SettingsStore;
25use smol::{fs, io::BufReader, stream::StreamExt};
26use std::{
27 any::TypeId,
28 ffi::OsString,
29 mem,
30 ops::Range,
31 path::{Path, PathBuf},
32 sync::Arc,
33};
34use util::{
35 fs::remove_matching, github::latest_github_release, http::HttpClient, maybe, paths, ResultExt,
36};
37
38pub use copilot_completion_provider::CopilotCompletionProvider;
39pub use sign_in::CopilotCodeVerification;
40
41actions!(
42 copilot,
43 [
44 Suggest,
45 NextSuggestion,
46 PreviousSuggestion,
47 Reinstall,
48 SignIn,
49 SignOut
50 ]
51);
52
53pub fn init(
54 new_server_id: LanguageServerId,
55 http: Arc<dyn HttpClient>,
56 node_runtime: Arc<dyn NodeRuntime>,
57 cx: &mut AppContext,
58) {
59 let copilot = cx.new_model({
60 let node_runtime = node_runtime.clone();
61 move |cx| Copilot::start(new_server_id, http, node_runtime, cx)
62 });
63 Copilot::set_global(copilot.clone(), cx);
64 cx.observe(&copilot, |handle, cx| {
65 let copilot_action_types = [
66 TypeId::of::<Suggest>(),
67 TypeId::of::<NextSuggestion>(),
68 TypeId::of::<PreviousSuggestion>(),
69 TypeId::of::<Reinstall>(),
70 ];
71 let copilot_auth_action_types = [TypeId::of::<SignOut>()];
72 let copilot_no_auth_action_types = [TypeId::of::<SignIn>()];
73 let status = handle.read(cx).status();
74 let filter = CommandPaletteFilter::global_mut(cx);
75
76 match status {
77 Status::Disabled => {
78 filter.hide_action_types(&copilot_action_types);
79 filter.hide_action_types(&copilot_auth_action_types);
80 filter.hide_action_types(&copilot_no_auth_action_types);
81 }
82 Status::Authorized => {
83 filter.hide_action_types(&copilot_no_auth_action_types);
84 filter.show_action_types(
85 copilot_action_types
86 .iter()
87 .chain(&copilot_auth_action_types),
88 );
89 }
90 _ => {
91 filter.hide_action_types(&copilot_action_types);
92 filter.hide_action_types(&copilot_auth_action_types);
93 filter.show_action_types(copilot_no_auth_action_types.iter());
94 }
95 }
96 })
97 .detach();
98
99 cx.on_action(|_: &SignIn, cx| {
100 if let Some(copilot) = Copilot::global(cx) {
101 copilot
102 .update(cx, |copilot, cx| copilot.sign_in(cx))
103 .detach_and_log_err(cx);
104 }
105 });
106 cx.on_action(|_: &SignOut, cx| {
107 if let Some(copilot) = Copilot::global(cx) {
108 copilot
109 .update(cx, |copilot, cx| copilot.sign_out(cx))
110 .detach_and_log_err(cx);
111 }
112 });
113 cx.on_action(|_: &Reinstall, cx| {
114 if let Some(copilot) = Copilot::global(cx) {
115 copilot
116 .update(cx, |copilot, cx| copilot.reinstall(cx))
117 .detach();
118 }
119 });
120}
121
/// Lifecycle state of the Copilot language server owned by [`Copilot`].
enum CopilotServer {
    /// Copilot is turned off: the inline completion provider setting is not
    /// Copilot, or the server was taken down on app quit.
    Disabled,
    /// The server is being located/downloaded and launched; `task` completes when
    /// startup has finished (successfully or not).
    Starting { task: Shared<Task<()>> },
    /// Startup failed; holds the error message.
    Error(Arc<str>),
    /// The server process is up and has completed its handshake.
    Running(RunningCopilotServer),
}
128
129impl CopilotServer {
130 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
131 let server = self.as_running()?;
132 if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
133 Ok(server)
134 } else {
135 Err(anyhow!("must sign in before using copilot"))
136 }
137 }
138
139 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
140 match self {
141 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
142 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
143 CopilotServer::Error(error) => Err(anyhow!(
144 "copilot was not started because of an error: {}",
145 error
146 )),
147 CopilotServer::Running(server) => Ok(server),
148 }
149 }
150}
151
/// A started Copilot language server plus the per-buffer state synced to it.
struct RunningCopilotServer {
    /// Handle to the language server.
    lsp: Arc<LanguageServer>,
    /// Current authentication state; buffers are only opened on the server while
    /// this is `Authorized`.
    sign_in_status: SignInStatus,
    /// Buffers currently opened on the server, keyed by buffer entity id.
    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
}
157
/// Internal sign-in state of a running Copilot server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Signed in; completions may be requested.
    Authorized,
    /// The server reported the account as not authorized.
    Unauthorized,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt to show the user, once the server has provided one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// The in-flight sign-in task; shared so repeated sign-in attempts join it.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Not signed in.
    SignedOut,
}
168
/// Public, flattened view of Copilot's state (server lifecycle plus sign-in),
/// as returned by `Copilot::status`.
#[derive(Debug, Clone)]
pub enum Status {
    /// The server is starting; `task` completes when startup finishes.
    Starting {
        task: Shared<Task<()>>,
    },
    /// The server failed to start; holds the error message.
    Error(Arc<str>),
    /// Copilot is turned off.
    Disabled,
    /// The server is running but the user is not signed in.
    SignedOut,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt to display, once available.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// The server is running but reported the account as not authorized.
    Unauthorized,
    /// The server is running and the user is signed in.
    Authorized,
}
183
184impl Status {
185 pub fn is_authorized(&self) -> bool {
186 matches!(self, Status::Authorized)
187 }
188}
189
/// Bookkeeping for a buffer that has been opened on the Copilot server.
struct RegisteredBuffer {
    /// URI the document is currently open under on the server.
    uri: lsp::Url,
    /// LSP language id the document was opened with.
    language_id: String,
    /// Last snapshot whose contents were reported to the server.
    snapshot: BufferSnapshot,
    /// LSP document version of `snapshot`; starts at 0 and is incremented for each
    /// reported change.
    snapshot_version: i32,
    /// Buffer event and release subscriptions, alive as long as the buffer is registered.
    _subscriptions: [gpui::Subscription; 2],
    /// Most recently queued change-report task; later reports wait for it so
    /// changes reach the server in order.
    pending_buffer_change: Task<Option<()>>,
}
198
impl RegisteredBuffer {
    /// Reports edits made to `buffer` since the last reported snapshot to the
    /// server as an incremental `DidChangeTextDocument` notification.
    ///
    /// Returns a receiver for the `(version, snapshot)` pair the server is in sync
    /// with once reporting has finished:
    /// - If the buffer version equals the last reported snapshot's version, the
    ///   receiver is ready immediately with the current pair.
    /// - Otherwise a task is queued behind any previously pending report, so
    ///   changes are sent in order. If, when it runs, the server is no longer
    ///   authenticated, the buffer is no longer registered, or the buffer model is
    ///   gone, the task exits early and drops the sender, so the receiver resolves
    ///   to `Canceled`.
    fn report_changes(
        &mut self,
        buffer: &Model<Buffer>,
        cx: &mut ModelContext<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(move |copilot, mut cx| async move {
                // Serialize with the previous report so notifications go out in order.
                prev_pending_change.await;

                // Re-read the registered snapshot's version: an earlier report may have
                // advanced it since this task was queued.
                let old_version = copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot()).ok()?;

                // Compute the diff off the main thread. Content changes are applied by the
                // server one after another, so each range starts at the edit's position in
                // the new text and spans the extent of the text it replaces (the old range).
                let content_changes = cx
                    .background_executor()
                    .spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        // Only bump the version and notify when the text actually changed.
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .log_err();
                        }
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
285
/// An inline completion returned by Copilot, anchored in the buffer snapshot
/// the request was made against.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned id, echoed back when the completion is accepted or rejected.
    pub uuid: String,
    /// Buffer range the completion replaces.
    pub range: Range<Anchor>,
    /// Replacement text.
    pub text: String,
}
292
/// Model that owns the Copilot language server and keeps buffers in sync with it.
pub struct Copilot {
    /// HTTP client used to download the server.
    http: Arc<dyn HttpClient>,
    /// Node runtime used to run the server script.
    node_runtime: Arc<dyn NodeRuntime>,
    /// Current server lifecycle state.
    server: CopilotServer,
    /// Every buffer Copilot has been asked about; (re-)registered with the server
    /// whenever sign-in succeeds.
    buffers: HashSet<WeakModel<Buffer>>,
    /// Language server id used when (re)starting the server.
    server_id: LanguageServerId,
    /// App-quit subscription that shuts the language server down.
    _subscription: gpui::Subscription,
}
301
/// Events emitted by the [`Copilot`] model.
pub enum Event {
    /// The language server finished starting and its handshake succeeded.
    CopilotLanguageServerStarted,
}

impl EventEmitter<Event> for Copilot {}
307
/// App-global slot holding the [`Copilot`] model (see `Copilot::global` / `Copilot::set_global`).
struct GlobalCopilot(Model<Copilot>);

impl Global for GlobalCopilot {}
311
312impl Copilot {
313 pub fn global(cx: &AppContext) -> Option<Model<Self>> {
314 cx.try_global::<GlobalCopilot>()
315 .map(|model| model.0.clone())
316 }
317
318 pub fn set_global(copilot: Model<Self>, cx: &mut AppContext) {
319 cx.set_global(GlobalCopilot(copilot));
320 }
321
322 fn start(
323 new_server_id: LanguageServerId,
324 http: Arc<dyn HttpClient>,
325 node_runtime: Arc<dyn NodeRuntime>,
326 cx: &mut ModelContext<Self>,
327 ) -> Self {
328 let mut this = Self {
329 server_id: new_server_id,
330 http,
331 node_runtime,
332 server: CopilotServer::Disabled,
333 buffers: Default::default(),
334 _subscription: cx.on_app_quit(Self::shutdown_language_server),
335 };
336 this.enable_or_disable_copilot(cx);
337 cx.observe_global::<SettingsStore>(move |this, cx| this.enable_or_disable_copilot(cx))
338 .detach();
339 this
340 }
341
342 fn shutdown_language_server(
343 &mut self,
344 _cx: &mut ModelContext<Self>,
345 ) -> impl Future<Output = ()> {
346 let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
347 CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
348 _ => None,
349 };
350
351 async move {
352 if let Some(shutdown) = shutdown {
353 shutdown.await;
354 }
355 }
356 }
357
358 fn enable_or_disable_copilot(&mut self, cx: &mut ModelContext<Self>) {
359 let server_id = self.server_id;
360 let http = self.http.clone();
361 let node_runtime = self.node_runtime.clone();
362 if all_language_settings(None, cx).inline_completions.provider
363 == InlineCompletionProvider::Copilot
364 {
365 if matches!(self.server, CopilotServer::Disabled) {
366 let start_task = cx
367 .spawn(move |this, cx| {
368 Self::start_language_server(server_id, http, node_runtime, this, cx)
369 })
370 .shared();
371 self.server = CopilotServer::Starting { task: start_task };
372 cx.notify();
373 }
374 } else {
375 self.server = CopilotServer::Disabled;
376 cx.notify();
377 }
378 }
379
380 #[cfg(any(test, feature = "test-support"))]
381 pub fn fake(cx: &mut gpui::TestAppContext) -> (Model<Self>, lsp::FakeLanguageServer) {
382 use lsp::FakeLanguageServer;
383 use node_runtime::FakeNodeRuntime;
384
385 let (server, fake_server) = FakeLanguageServer::new(
386 LanguageServerId(0),
387 LanguageServerBinary {
388 path: "path/to/copilot".into(),
389 arguments: vec![],
390 env: None,
391 },
392 "copilot".into(),
393 Default::default(),
394 cx.to_async(),
395 );
396 let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
397 let node_runtime = FakeNodeRuntime::new();
398 let this = cx.new_model(|cx| Self {
399 server_id: LanguageServerId(0),
400 http: http.clone(),
401 node_runtime,
402 server: CopilotServer::Running(RunningCopilotServer {
403 lsp: Arc::new(server),
404 sign_in_status: SignInStatus::Authorized,
405 registered_buffers: Default::default(),
406 }),
407 _subscription: cx.on_app_quit(Self::shutdown_language_server),
408 buffers: Default::default(),
409 });
410 (this, fake_server)
411 }
412
    /// Locates (downloading if needed) the Copilot server, launches it under Node
    /// over stdio, and performs the initial handshake: `initialize`, `CheckStatus`,
    /// then `SetEditorInfo`.
    ///
    /// On success the model becomes `CopilotServer::Running` (initially
    /// `SignedOut`), emits `Event::CopilotLanguageServerStarted`, and applies the
    /// sign-in status returned by `CheckStatus`. On any failure the model becomes
    /// `CopilotServer::Error` with the error message.
    fn start_language_server(
        new_server_id: LanguageServerId,
        http: Arc<dyn HttpClient>,
        node_runtime: Arc<dyn NodeRuntime>,
        this: WeakModel<Self>,
        mut cx: AsyncAppContext,
    ) -> impl Future<Output = ()> {
        async move {
            let start_language_server = async {
                let server_path = get_copilot_lsp(http).await?;
                let node_path = node_runtime.binary_path().await?;
                // Run the server script with Node, speaking LSP over stdio.
                let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
                let binary = LanguageServerBinary {
                    path: node_path,
                    arguments,
                    // TODO: We could set HTTP_PROXY etc here and fix the copilot issue.
                    env: None,
                };

                let server = LanguageServer::new(
                    Arc::new(Mutex::new(None)),
                    new_server_id,
                    binary,
                    Path::new("/"),
                    None,
                    cx.clone(),
                )?;

                server
                    .on_notification::<StatusNotification, _>(
                        |_, _| { /* Silence the notification */ },
                    )
                    .detach();
                let server = cx.update(|cx| server.initialize(None, cx))?.await?;

                let status = server
                    .request::<request::CheckStatus>(request::CheckStatusParams {
                        local_checks_only: false,
                    })
                    .await?;

                server
                    .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
                        editor_info: request::EditorInfo {
                            name: "zed".into(),
                            version: env!("CARGO_PKG_VERSION").into(),
                        },
                        editor_plugin_info: request::EditorPluginInfo {
                            name: "zed-copilot".into(),
                            version: "0.0.1".into(),
                        },
                    })
                    .await?;

                anyhow::Ok((server, status))
            };

            let server = start_language_server.await;
            // The update result is ignored; presumably it only fails if the model
            // has been released in the meantime.
            this.update(&mut cx, |this, cx| {
                cx.notify();
                match server {
                    Ok((server, status)) => {
                        this.server = CopilotServer::Running(RunningCopilotServer {
                            lsp: server,
                            sign_in_status: SignInStatus::SignedOut,
                            registered_buffers: Default::default(),
                        });
                        cx.emit(Event::CopilotLanguageServerStarted);
                        this.update_sign_in_status(status, cx);
                    }
                    Err(error) => {
                        this.server = CopilotServer::Error(error.to_string().into());
                        // NOTE(review): redundant with the notify at the top of this closure.
                        cx.notify()
                    }
                }
            })
            .ok();
        }
    }
492
    /// Starts, or joins, the sign-in flow.
    ///
    /// - Already authorized: resolves immediately.
    /// - A sign-in is already in progress: returns that in-flight task.
    /// - Signed out / unauthorized: sends `SignInInitiate`. If the server asks for
    ///   the device flow, the prompt is stored on the `SigningIn` status (so the UI
    ///   can show the user code) and `SignInConfirm` is awaited.
    ///
    /// The outcome is applied through `update_sign_in_status`; on failure the status
    /// is reset to signed out.
    ///
    /// # Errors
    /// Fails immediately if the server is not running, or with the sign-in error.
    pub fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user: Some(user) })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the device-flow prompt on the in-progress status
                                        // so observers can display it.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        })?;
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    Err(Arc::new(error))
                                }
                            })?
                        })
                        .shared();
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            // The shared task's error is an `Arc<anyhow::Error>`; flatten it into a
            // fresh `anyhow::Error` for the caller.
            cx.background_executor()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // TODO: if the server is still downloading, wait until the download is
            // finished; if it is stuck, surface that to the user.
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
577
578 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
579 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
580 if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
581 let server = server.clone();
582 cx.background_executor().spawn(async move {
583 server
584 .request::<request::SignOut>(request::SignOutParams {})
585 .await?;
586 anyhow::Ok(())
587 })
588 } else {
589 Task::ready(Err(anyhow!("copilot hasn't started yet")))
590 }
591 }
592
593 pub fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
594 let start_task = cx
595 .spawn({
596 let http = self.http.clone();
597 let node_runtime = self.node_runtime.clone();
598 let server_id = self.server_id;
599 move |this, cx| async move {
600 clear_copilot_dir().await;
601 Self::start_language_server(server_id, http, node_runtime, this, cx).await
602 }
603 })
604 .shared();
605
606 self.server = CopilotServer::Starting {
607 task: start_task.clone(),
608 };
609
610 cx.notify();
611
612 cx.background_executor().spawn(start_task)
613 }
614
615 pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
616 if let CopilotServer::Running(server) = &self.server {
617 Some(&server.lsp)
618 } else {
619 None
620 }
621 }
622
    /// Tracks `buffer` and, when Copilot is running and authorized, opens it on the
    /// server (`DidOpenTextDocument`, version 0) the first time it is seen.
    ///
    /// The buffer is tracked even when it can't be opened yet, so it gets registered
    /// once sign-in succeeds. Registration subscribes to the buffer's events
    /// (forwarded to `handle_buffer_event`) and to its release, which untracks and
    /// unregisters it.
    pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
        let weak_buffer = buffer.downgrade();
        self.buffers.insert(weak_buffer.clone());

        if let CopilotServer::Running(RunningCopilotServer {
            lsp: server,
            sign_in_status: status,
            registered_buffers,
            ..
        }) = &mut self.server
        {
            if !matches!(status, SignInStatus::Authorized { .. }) {
                return;
            }

            // Already-registered buffers are left untouched.
            registered_buffers
                .entry(buffer.entity_id())
                .or_insert_with(|| {
                    let uri: lsp::Url = uri_for_buffer(buffer, cx);
                    let language_id = id_for_language(buffer.read(cx).language());
                    let snapshot = buffer.read(cx).snapshot();
                    server
                        .notify::<lsp::notification::DidOpenTextDocument>(
                            lsp::DidOpenTextDocumentParams {
                                text_document: lsp::TextDocumentItem {
                                    uri: uri.clone(),
                                    language_id: language_id.clone(),
                                    version: 0,
                                    text: snapshot.text(),
                                },
                            },
                        )
                        .log_err();

                    RegisteredBuffer {
                        uri,
                        language_id,
                        snapshot,
                        snapshot_version: 0,
                        pending_buffer_change: Task::ready(Some(())),
                        _subscriptions: [
                            cx.subscribe(buffer, |this, buffer, event, cx| {
                                this.handle_buffer_event(buffer, event, cx).log_err();
                            }),
                            cx.observe_release(buffer, move |this, _buffer, _cx| {
                                this.buffers.remove(&weak_buffer);
                                this.unregister_buffer(&weak_buffer);
                            }),
                        ],
                    }
                });
        }
    }
676
    /// Forwards a buffer event to the server for a registered buffer:
    /// - `Edited`: queues an incremental change report.
    /// - `Saved`: sends `DidSaveTextDocument`.
    /// - `FileHandleChanged` / `LanguageChanged`: if the URI or language id changed,
    ///   closes the document under its old URI and reopens it with the new values.
    ///
    /// Events for unregistered buffers, or while the server is not running, are ignored.
    ///
    /// # Errors
    /// Returns the error if sending a notification to the server fails.
    fn handle_buffer_event(
        &mut self,
        buffer: Model<Buffer>,
        event: &language::Event,
        cx: &mut ModelContext<Self>,
    ) -> Result<()> {
        if let Ok(server) = self.server.as_running() {
            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
            {
                match event {
                    language::Event::Edited => {
                        // The receiver isn't needed; the report is kept alive in
                        // `pending_buffer_change`.
                        let _ = registered_buffer.report_changes(&buffer, cx);
                    }
                    language::Event::Saved => {
                        server
                            .lsp
                            .notify::<lsp::notification::DidSaveTextDocument>(
                                lsp::DidSaveTextDocumentParams {
                                    text_document: lsp::TextDocumentIdentifier::new(
                                        registered_buffer.uri.clone(),
                                    ),
                                    text: None,
                                },
                            )?;
                    }
                    language::Event::FileHandleChanged | language::Event::LanguageChanged => {
                        let new_language_id = id_for_language(buffer.read(cx).language());
                        let new_uri = uri_for_buffer(&buffer, cx);
                        if new_uri != registered_buffer.uri
                            || new_language_id != registered_buffer.language_id
                        {
                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
                            registered_buffer.language_id = new_language_id;
                            server
                                .lsp
                                .notify::<lsp::notification::DidCloseTextDocument>(
                                    lsp::DidCloseTextDocumentParams {
                                        text_document: lsp::TextDocumentIdentifier::new(old_uri),
                                    },
                                )?;
                            // Reopen with the last snapshot reported to the server and
                            // its version.
                            server
                                .lsp
                                .notify::<lsp::notification::DidOpenTextDocument>(
                                    lsp::DidOpenTextDocumentParams {
                                        text_document: lsp::TextDocumentItem::new(
                                            registered_buffer.uri.clone(),
                                            registered_buffer.language_id.clone(),
                                            registered_buffer.snapshot_version,
                                            registered_buffer.snapshot.text(),
                                        ),
                                    },
                                )?;
                        }
                    }
                    _ => {}
                }
            }
        }

        Ok(())
    }
738
739 fn unregister_buffer(&mut self, buffer: &WeakModel<Buffer>) {
740 if let Ok(server) = self.server.as_running() {
741 if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
742 server
743 .lsp
744 .notify::<lsp::notification::DidCloseTextDocument>(
745 lsp::DidCloseTextDocumentParams {
746 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
747 },
748 )
749 .log_err();
750 }
751 }
752 }
753
754 pub fn completions<T>(
755 &mut self,
756 buffer: &Model<Buffer>,
757 position: T,
758 cx: &mut ModelContext<Self>,
759 ) -> Task<Result<Vec<Completion>>>
760 where
761 T: ToPointUtf16,
762 {
763 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
764 }
765
766 pub fn completions_cycling<T>(
767 &mut self,
768 buffer: &Model<Buffer>,
769 position: T,
770 cx: &mut ModelContext<Self>,
771 ) -> Task<Result<Vec<Completion>>>
772 where
773 T: ToPointUtf16,
774 {
775 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
776 }
777
778 pub fn accept_completion(
779 &mut self,
780 completion: &Completion,
781 cx: &mut ModelContext<Self>,
782 ) -> Task<Result<()>> {
783 let server = match self.server.as_authenticated() {
784 Ok(server) => server,
785 Err(error) => return Task::ready(Err(error)),
786 };
787 let request =
788 server
789 .lsp
790 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
791 uuid: completion.uuid.clone(),
792 });
793 cx.background_executor().spawn(async move {
794 request.await?;
795 Ok(())
796 })
797 }
798
799 pub fn discard_completions(
800 &mut self,
801 completions: &[Completion],
802 cx: &mut ModelContext<Self>,
803 ) -> Task<Result<()>> {
804 let server = match self.server.as_authenticated() {
805 Ok(server) => server,
806 Err(_) => return Task::ready(Ok(())),
807 };
808 let request =
809 server
810 .lsp
811 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
812 uuids: completions
813 .iter()
814 .map(|completion| completion.uuid.clone())
815 .collect(),
816 });
817 cx.background_executor().spawn(async move {
818 request.await?;
819 Ok(())
820 })
821 }
822
823 fn request_completions<R, T>(
824 &mut self,
825 buffer: &Model<Buffer>,
826 position: T,
827 cx: &mut ModelContext<Self>,
828 ) -> Task<Result<Vec<Completion>>>
829 where
830 R: 'static
831 + lsp::request::Request<
832 Params = request::GetCompletionsParams,
833 Result = request::GetCompletionsResult,
834 >,
835 T: ToPointUtf16,
836 {
837 self.register_buffer(buffer, cx);
838
839 let server = match self.server.as_authenticated() {
840 Ok(server) => server,
841 Err(error) => return Task::ready(Err(error)),
842 };
843 let lsp = server.lsp.clone();
844 let registered_buffer = server
845 .registered_buffers
846 .get_mut(&buffer.entity_id())
847 .unwrap();
848 let snapshot = registered_buffer.report_changes(buffer, cx);
849 let buffer = buffer.read(cx);
850 let uri = registered_buffer.uri.clone();
851 let position = position.to_point_utf16(buffer);
852 let settings = language_settings(buffer.language_at(position).as_ref(), buffer.file(), cx);
853 let tab_size = settings.tab_size;
854 let hard_tabs = settings.hard_tabs;
855 let relative_path = buffer
856 .file()
857 .map(|file| file.path().to_path_buf())
858 .unwrap_or_default();
859
860 cx.background_executor().spawn(async move {
861 let (version, snapshot) = snapshot.await?;
862 let result = lsp
863 .request::<R>(request::GetCompletionsParams {
864 doc: request::GetCompletionsDocument {
865 uri,
866 tab_size: tab_size.into(),
867 indent_size: 1,
868 insert_spaces: !hard_tabs,
869 relative_path: relative_path.to_string_lossy().into(),
870 position: point_to_lsp(position),
871 version: version.try_into().unwrap(),
872 },
873 })
874 .await?;
875 let completions = result
876 .completions
877 .into_iter()
878 .map(|completion| {
879 let start = snapshot
880 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
881 let end =
882 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
883 Completion {
884 uuid: completion.uuid,
885 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
886 text: completion.text,
887 }
888 })
889 .collect();
890 anyhow::Ok(completions)
891 })
892 }
893
894 pub fn status(&self) -> Status {
895 match &self.server {
896 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
897 CopilotServer::Disabled => Status::Disabled,
898 CopilotServer::Error(error) => Status::Error(error.clone()),
899 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
900 match sign_in_status {
901 SignInStatus::Authorized { .. } => Status::Authorized,
902 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
903 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
904 prompt: prompt.clone(),
905 },
906 SignInStatus::SignedOut => Status::SignedOut,
907 }
908 }
909 }
910 }
911
912 fn update_sign_in_status(
913 &mut self,
914 lsp_status: request::SignInStatus,
915 cx: &mut ModelContext<Self>,
916 ) {
917 self.buffers.retain(|buffer| buffer.is_upgradable());
918
919 if let Ok(server) = self.server.as_running() {
920 match lsp_status {
921 request::SignInStatus::Ok { user: Some(_) }
922 | request::SignInStatus::MaybeOk { .. }
923 | request::SignInStatus::AlreadySignedIn { .. } => {
924 server.sign_in_status = SignInStatus::Authorized;
925 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
926 if let Some(buffer) = buffer.upgrade() {
927 self.register_buffer(&buffer, cx);
928 }
929 }
930 }
931 request::SignInStatus::NotAuthorized { .. } => {
932 server.sign_in_status = SignInStatus::Unauthorized;
933 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
934 self.unregister_buffer(&buffer);
935 }
936 }
937 request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
938 server.sign_in_status = SignInStatus::SignedOut;
939 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
940 self.unregister_buffer(&buffer);
941 }
942 }
943 }
944
945 cx.notify();
946 }
947 }
948}
949
950fn id_for_language(language: Option<&Arc<Language>>) -> String {
951 language
952 .map(|language| language.lsp_id())
953 .unwrap_or_else(|| "plaintext".to_string())
954}
955
956fn uri_for_buffer(buffer: &Model<Buffer>, cx: &AppContext) -> lsp::Url {
957 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
958 lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
959 } else {
960 format!("buffer://{}", buffer.entity_id()).parse().unwrap()
961 }
962}
963
964async fn clear_copilot_dir() {
965 remove_matching(&paths::COPILOT_DIR, |_| true).await
966}
967
968async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
969 const SERVER_PATH: &str = "dist/agent.js";
970
971 ///Check for the latest copilot language server and download it if we haven't already
972 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
973 let release =
974 latest_github_release("zed-industries/copilot", true, false, http.clone()).await?;
975
976 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.tag_name));
977
978 fs::create_dir_all(version_dir).await?;
979 let server_path = version_dir.join(SERVER_PATH);
980
981 if fs::metadata(&server_path).await.is_err() {
982 // Copilot LSP looks for this dist dir specifically, so lets add it in.
983 let dist_dir = version_dir.join("dist");
984 fs::create_dir_all(dist_dir.as_path()).await?;
985
986 let url = &release
987 .assets
988 .get(0)
989 .context("Github release for copilot contained no assets")?
990 .browser_download_url;
991
992 let mut response = http
993 .get(url, Default::default(), true)
994 .await
995 .context("error downloading copilot release")?;
996 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
997 let archive = Archive::new(decompressed_bytes);
998 archive.unpack(dist_dir).await?;
999
1000 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
1001 }
1002
1003 Ok(server_path)
1004 }
1005
1006 match fetch_latest(http).await {
1007 ok @ Result::Ok(..) => ok,
1008 e @ Err(..) => {
1009 e.log_err();
1010 // Fetch a cached binary, if it exists
1011 maybe!(async {
1012 let mut last_version_dir = None;
1013 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
1014 while let Some(entry) = entries.next().await {
1015 let entry = entry?;
1016 if entry.file_type().await?.is_dir() {
1017 last_version_dir = Some(entry.path());
1018 }
1019 }
1020 let last_version_dir =
1021 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
1022 let server_path = last_version_dir.join(SERVER_PATH);
1023 if server_path.exists() {
1024 Ok(server_path)
1025 } else {
1026 Err(anyhow!(
1027 "missing executable in directory {:?}",
1028 last_version_dir
1029 ))
1030 }
1031 })
1032 .await
1033 }
1034 }
1035}
1036
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::BufferId;

    /// Checks, via the notifications the fake server receives, that buffers are:
    /// opened on registration, incrementally updated on edit, closed and reopened
    /// when their file changes, closed on sign-out, reopened on sign-in, and closed
    /// when dropped.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(cx: &mut TestAppContext) {
        let (copilot, mut lsp) = Copilot::fake(cx);

        // A file-less buffer is opened under a `buffer://<id>` URI at version 0.
        let buffer_1 = cx.new_model(|cx| Buffer::local("Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.new_model(|cx| Buffer::local("Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is reported as an incremental change and bumps the version to 1.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        buffer_1.update(cx, |buffer, cx| {
            buffer.file_updated(
                Arc::new(File {
                    abs_path: "/root/child/buffer-1".into(),
                    path: Path::new("child/buffer-1").into(),
                }),
                cx,
            )
        });
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        // Reopened under its file URL, keeping the last reported version and text.
        let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();

        // Re-registration starts over at version 0.
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    /// Minimal local file used to give a test buffer a path; only the methods
    /// exercised by the test do anything, the rest are `unimplemented!()`.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn mtime(&self) -> Option<std::time::SystemTime> {
            unimplemented!()
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &AppContext) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn is_deleted(&self) -> bool {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            unimplemented!()
        }

        fn to_proto(&self) -> rpc::proto::File {
            unimplemented!()
        }

        fn worktree_id(&self) -> usize {
            0
        }

        fn is_private(&self) -> bool {
            false
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &AppContext) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &AppContext) -> Task<Result<String>> {
            unimplemented!()
        }

        fn buffer_reloaded(
            &self,
            _: BufferId,
            _: &clock::Global,
            _: language::LineEnding,
            _: Option<std::time::SystemTime>,
            _: &mut AppContext,
        ) {
            unimplemented!()
        }
    }
}