1mod copilot_completion_provider;
2pub mod request;
3mod sign_in;
4
5use anyhow::{anyhow, Context as _, Result};
6use async_compression::futures::bufread::GzipDecoder;
7use async_tar::Archive;
8use collections::{HashMap, HashSet};
9use command_palette_hooks::CommandPaletteFilter;
10use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
11use gpui::{
12 actions, AppContext, AsyncAppContext, Context, Entity, EntityId, EventEmitter, Global, Model,
13 ModelContext, Task, WeakModel,
14};
15use http::github::latest_github_release;
16use http::HttpClient;
17use language::{
18 language_settings::{all_language_settings, language_settings, InlineCompletionProvider},
19 point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
20 ToPointUtf16,
21};
22use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId};
23use node_runtime::NodeRuntime;
24use parking_lot::Mutex;
25use request::StatusNotification;
26use settings::SettingsStore;
27use smol::{fs, io::BufReader, stream::StreamExt};
28use std::{
29 any::TypeId,
30 ffi::OsString,
31 mem,
32 ops::Range,
33 path::{Path, PathBuf},
34 sync::Arc,
35};
36use util::{fs::remove_matching, maybe, paths, ResultExt};
37
38pub use copilot_completion_provider::CopilotCompletionProvider;
39pub use sign_in::CopilotCodeVerification;
40
// Copilot actions. Their visibility in the command palette is toggled in `init` based on the
// `Status` of the global `Copilot` model.
actions!(
    copilot,
    [
        Suggest,
        NextSuggestion,
        PreviousSuggestion,
        Reinstall,
        SignIn,
        SignOut
    ]
);
52
53pub fn init(
54 new_server_id: LanguageServerId,
55 http: Arc<dyn HttpClient>,
56 node_runtime: Arc<dyn NodeRuntime>,
57 cx: &mut AppContext,
58) {
59 let copilot = cx.new_model({
60 let node_runtime = node_runtime.clone();
61 move |cx| Copilot::start(new_server_id, http, node_runtime, cx)
62 });
63 Copilot::set_global(copilot.clone(), cx);
64 cx.observe(&copilot, |handle, cx| {
65 let copilot_action_types = [
66 TypeId::of::<Suggest>(),
67 TypeId::of::<NextSuggestion>(),
68 TypeId::of::<PreviousSuggestion>(),
69 TypeId::of::<Reinstall>(),
70 ];
71 let copilot_auth_action_types = [TypeId::of::<SignOut>()];
72 let copilot_no_auth_action_types = [TypeId::of::<SignIn>()];
73 let status = handle.read(cx).status();
74 let filter = CommandPaletteFilter::global_mut(cx);
75
76 match status {
77 Status::Disabled => {
78 filter.hide_action_types(&copilot_action_types);
79 filter.hide_action_types(&copilot_auth_action_types);
80 filter.hide_action_types(&copilot_no_auth_action_types);
81 }
82 Status::Authorized => {
83 filter.hide_action_types(&copilot_no_auth_action_types);
84 filter.show_action_types(
85 copilot_action_types
86 .iter()
87 .chain(&copilot_auth_action_types),
88 );
89 }
90 _ => {
91 filter.hide_action_types(&copilot_action_types);
92 filter.hide_action_types(&copilot_auth_action_types);
93 filter.show_action_types(copilot_no_auth_action_types.iter());
94 }
95 }
96 })
97 .detach();
98
99 cx.on_action(|_: &SignIn, cx| {
100 if let Some(copilot) = Copilot::global(cx) {
101 copilot
102 .update(cx, |copilot, cx| copilot.sign_in(cx))
103 .detach_and_log_err(cx);
104 }
105 });
106 cx.on_action(|_: &SignOut, cx| {
107 if let Some(copilot) = Copilot::global(cx) {
108 copilot
109 .update(cx, |copilot, cx| copilot.sign_out(cx))
110 .detach_and_log_err(cx);
111 }
112 });
113 cx.on_action(|_: &Reinstall, cx| {
114 if let Some(copilot) = Copilot::global(cx) {
115 copilot
116 .update(cx, |copilot, cx| copilot.reinstall(cx))
117 .detach();
118 }
119 });
120}
121
/// Lifecycle state of the Copilot language server owned by [`Copilot`].
enum CopilotServer {
    /// No server is running. This is the state when Copilot is not the selected inline
    /// completion provider, and after the server has been taken for shutdown on app quit.
    Disabled,
    /// The server is being fetched and started; `task` completes when the startup attempt
    /// has finished (leaving the state `Running` or `Error`).
    Starting { task: Shared<Task<()>> },
    /// Startup failed; holds the error message.
    Error(Arc<str>),
    /// The server is running.
    Running(RunningCopilotServer),
}
128
129impl CopilotServer {
130 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
131 let server = self.as_running()?;
132 if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
133 Ok(server)
134 } else {
135 Err(anyhow!("must sign in before using copilot"))
136 }
137 }
138
139 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
140 match self {
141 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
142 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
143 CopilotServer::Error(error) => Err(anyhow!(
144 "copilot was not started because of an error: {}",
145 error
146 )),
147 CopilotServer::Running(server) => Ok(server),
148 }
149 }
150}
151
/// State kept while the Copilot language server is running.
struct RunningCopilotServer {
    /// Handle to the running language server.
    lsp: Arc<LanguageServer>,
    /// Current authentication state of this server session.
    sign_in_status: SignInStatus,
    /// Buffers that have been opened on the server (`textDocument/didOpen`), keyed by the
    /// buffer model's entity id.
    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
}
157
/// Authentication state tracked for a running Copilot server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Signed in and authorized; buffers are registered and completions can be requested.
    Authorized,
    /// The server reported that the user is not authorized to use Copilot.
    Unauthorized,
    /// A sign-in attempt is in progress.
    SigningIn {
        /// Device-flow prompt to show the user, once the server has provided one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared task that completes when the sign-in attempt finishes.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// No user is signed in.
    SignedOut,
}
168
/// Public summary of Copilot's state, produced by [`Copilot::status`] from the server
/// lifecycle state and, when running, the sign-in state.
#[derive(Debug, Clone)]
pub enum Status {
    /// The language server is starting; `task` completes when the startup attempt finishes.
    Starting {
        task: Shared<Task<()>>,
    },
    /// The language server failed to start; holds the error message.
    Error(Arc<str>),
    /// Copilot is disabled; no language server is running.
    Disabled,
    /// The server is running but no user is signed in.
    SignedOut,
    /// A sign-in attempt is in progress.
    SigningIn {
        /// Device-flow prompt to show the user, once available.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// The server reported that the user is not authorized to use Copilot.
    Unauthorized,
    /// Signed in and authorized.
    Authorized,
}
183
184impl Status {
185 pub fn is_authorized(&self) -> bool {
186 matches!(self, Status::Authorized)
187 }
188}
189
/// Bookkeeping for a buffer that has been opened on the Copilot language server.
struct RegisteredBuffer {
    /// URI under which the buffer is open on the server.
    uri: lsp::Url,
    /// LSP language id the buffer was opened with.
    language_id: String,
    /// Last snapshot whose contents were sent to the server.
    snapshot: BufferSnapshot,
    /// LSP document version matching `snapshot`. It starts at 0 and is incremented for each
    /// `textDocument/didChange` sent by `report_changes`.
    snapshot_version: i32,
    /// Keeps alive the buffer event subscription and the buffer release observer.
    _subscriptions: [gpui::Subscription; 2],
    /// Most recent change-reporting task. Each new report awaits this one first, so changes
    /// are sent to the server in order.
    pending_buffer_change: Task<Option<()>>,
}
198
impl RegisteredBuffer {
    /// Brings the server's copy of `buffer` up to date. The returned receiver resolves with
    /// the LSP document version and snapshot the server then holds.
    ///
    /// If the buffer version equals the last reported snapshot's version, the receiver
    /// resolves immediately. Otherwise a task is spawned that:
    /// 1. awaits the previously pending report, so changes are sent in order;
    /// 2. computes the edits since the last reported snapshot on the background executor;
    /// 3. if there are any, bumps `snapshot_version`, stores the new snapshot and sends
    ///    `textDocument/didChange`.
    ///
    /// The sender is dropped without a value in three cases: Copilot is no longer
    /// authenticated, the buffer was unregistered, or the buffer or Copilot model was
    /// released in the meantime.
    fn report_changes(
        &mut self,
        buffer: &Model<Buffer>,
        cx: &mut ModelContext<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            // Nothing new to report.
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Chain onto the previous report so notifications are never reordered.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(move |copilot, mut cx| async move {
                prev_pending_change.await;

                // Re-read the last reported version: the previous report may have advanced it.
                let old_version = copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot()).ok()?;

                let content_changes = cx
                    .background_executor()
                    .spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // The range is the edit's new start plus its old extent.
                                    // Presumably this is because LSP applies content changes
                                    // sequentially, so each range is relative to the document
                                    // after the earlier changes.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .log_err();
                        }
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
285
/// An inline completion returned by the Copilot server.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned id, sent back when the completion is accepted or rejected.
    pub uuid: String,
    /// Buffer range the completion applies to, from the server's range clipped to the
    /// snapshot the server saw.
    pub range: Range<Anchor>,
    /// Suggested text.
    pub text: String,
}
292
/// Model that owns the Copilot language server and tracks the buffers registered with it.
pub struct Copilot {
    /// Client used to download the language server.
    http: Arc<dyn HttpClient>,
    /// Provides the `node` binary used to run the language server.
    node_runtime: Arc<dyn NodeRuntime>,
    /// Current lifecycle state of the language server.
    server: CopilotServer,
    /// Weak handles to buffers passed to `register_buffer`. They are re-registered or
    /// unregistered when the sign-in status changes.
    buffers: HashSet<WeakModel<Buffer>>,
    /// Id used when (re)starting the language server.
    server_id: LanguageServerId,
    /// Keeps the app-quit hook (`shutdown_language_server`) registered.
    _subscription: gpui::Subscription,
}
301
/// Events emitted by [`Copilot`].
pub enum Event {
    /// A language server finished starting and is now running.
    CopilotLanguageServerStarted,
}

impl EventEmitter<Event> for Copilot {}
307
/// Wrapper that stores the [`Copilot`] model as a gpui global.
struct GlobalCopilot(Model<Copilot>);

impl Global for GlobalCopilot {}
311
312impl Copilot {
313 pub fn global(cx: &AppContext) -> Option<Model<Self>> {
314 cx.try_global::<GlobalCopilot>()
315 .map(|model| model.0.clone())
316 }
317
318 pub fn set_global(copilot: Model<Self>, cx: &mut AppContext) {
319 cx.set_global(GlobalCopilot(copilot));
320 }
321
322 fn start(
323 new_server_id: LanguageServerId,
324 http: Arc<dyn HttpClient>,
325 node_runtime: Arc<dyn NodeRuntime>,
326 cx: &mut ModelContext<Self>,
327 ) -> Self {
328 let mut this = Self {
329 server_id: new_server_id,
330 http,
331 node_runtime,
332 server: CopilotServer::Disabled,
333 buffers: Default::default(),
334 _subscription: cx.on_app_quit(Self::shutdown_language_server),
335 };
336 this.enable_or_disable_copilot(cx);
337 cx.observe_global::<SettingsStore>(move |this, cx| this.enable_or_disable_copilot(cx))
338 .detach();
339 this
340 }
341
342 fn shutdown_language_server(
343 &mut self,
344 _cx: &mut ModelContext<Self>,
345 ) -> impl Future<Output = ()> {
346 let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
347 CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
348 _ => None,
349 };
350
351 async move {
352 if let Some(shutdown) = shutdown {
353 shutdown.await;
354 }
355 }
356 }
357
358 fn enable_or_disable_copilot(&mut self, cx: &mut ModelContext<Self>) {
359 let server_id = self.server_id;
360 let http = self.http.clone();
361 let node_runtime = self.node_runtime.clone();
362 if all_language_settings(None, cx).inline_completions.provider
363 == InlineCompletionProvider::Copilot
364 {
365 if matches!(self.server, CopilotServer::Disabled) {
366 let start_task = cx
367 .spawn(move |this, cx| {
368 Self::start_language_server(server_id, http, node_runtime, this, cx)
369 })
370 .shared();
371 self.server = CopilotServer::Starting { task: start_task };
372 cx.notify();
373 }
374 } else {
375 self.server = CopilotServer::Disabled;
376 cx.notify();
377 }
378 }
379
380 #[cfg(any(test, feature = "test-support"))]
381 pub fn fake(cx: &mut gpui::TestAppContext) -> (Model<Self>, lsp::FakeLanguageServer) {
382 use lsp::FakeLanguageServer;
383 use node_runtime::FakeNodeRuntime;
384
385 let (server, fake_server) = FakeLanguageServer::new(
386 LanguageServerId(0),
387 LanguageServerBinary {
388 path: "path/to/copilot".into(),
389 arguments: vec![],
390 env: None,
391 },
392 "copilot".into(),
393 Default::default(),
394 cx.to_async(),
395 );
396 let http = http::FakeHttpClient::create(|_| async { unreachable!() });
397 let node_runtime = FakeNodeRuntime::new();
398 let this = cx.new_model(|cx| Self {
399 server_id: LanguageServerId(0),
400 http: http.clone(),
401 node_runtime,
402 server: CopilotServer::Running(RunningCopilotServer {
403 lsp: Arc::new(server),
404 sign_in_status: SignInStatus::Authorized,
405 registered_buffers: Default::default(),
406 }),
407 _subscription: cx.on_app_quit(Self::shutdown_language_server),
408 buffers: Default::default(),
409 });
410 (this, fake_server)
411 }
412
    /// Fetches (or reuses a cached) Copilot language server and launches it with Node. It
    /// then initializes it, queries its status (`CheckStatus`) and sends editor info
    /// (`SetEditorInfo`).
    ///
    /// When startup finishes, observers are notified and the model is updated:
    /// * success: the server becomes `Running` (initially `SignedOut`),
    ///   `Event::CopilotLanguageServerStarted` is emitted, and the sign-in status is updated
    ///   from the queried status;
    /// * failure: the server becomes `Error` with the error message.
    ///
    /// If the model has already been released, the result is discarded.
    fn start_language_server(
        new_server_id: LanguageServerId,
        http: Arc<dyn HttpClient>,
        node_runtime: Arc<dyn NodeRuntime>,
        this: WeakModel<Self>,
        mut cx: AsyncAppContext,
    ) -> impl Future<Output = ()> {
        async move {
            let start_language_server = async {
                let server_path = get_copilot_lsp(http).await?;
                let node_path = node_runtime.binary_path().await?;
                // Run `node <server script> --stdio`.
                let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
                let binary = LanguageServerBinary {
                    path: node_path,
                    arguments,
                    // TODO: We could set HTTP_PROXY etc here and fix the copilot issue.
                    env: None,
                };

                // A fixed filesystem root is used as the server's root path; no project
                // directory is passed here.
                let root_path = if cfg!(target_os = "windows") {
                    Path::new("C:/")
                } else {
                    Path::new("/")
                };

                let server = LanguageServer::new(
                    Arc::new(Mutex::new(None)),
                    new_server_id,
                    binary,
                    root_path,
                    None,
                    cx.clone(),
                )?;

                server
                    .on_notification::<StatusNotification, _>(
                        |_, _| { /* Silence the notification */ },
                    )
                    .detach();
                let server = cx.update(|cx| server.initialize(None, cx))?.await?;

                let status = server
                    .request::<request::CheckStatus>(request::CheckStatusParams {
                        local_checks_only: false,
                    })
                    .await?;

                server
                    .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
                        editor_info: request::EditorInfo {
                            name: "zed".into(),
                            version: env!("CARGO_PKG_VERSION").into(),
                        },
                        editor_plugin_info: request::EditorPluginInfo {
                            name: "zed-copilot".into(),
                            version: "0.0.1".into(),
                        },
                    })
                    .await?;

                anyhow::Ok((server, status))
            };

            let server = start_language_server.await;
            this.update(&mut cx, |this, cx| {
                cx.notify();
                match server {
                    Ok((server, status)) => {
                        this.server = CopilotServer::Running(RunningCopilotServer {
                            lsp: server,
                            sign_in_status: SignInStatus::SignedOut,
                            registered_buffers: Default::default(),
                        });
                        cx.emit(Event::CopilotLanguageServerStarted);
                        this.update_sign_in_status(status, cx);
                    }
                    Err(error) => {
                        this.server = CopilotServer::Error(error.to_string().into());
                        cx.notify()
                    }
                }
            })
            .ok();
        }
    }
498
    /// Starts a sign-in, or joins the one already in progress.
    ///
    /// * Already authorized: resolves immediately with `Ok(())`.
    /// * A sign-in is in progress: notifies observers and returns the in-flight task.
    /// * Signed out or unauthorized: asks the server to initiate sign-in.
    ///   - If the server reports the user is already signed in, that status is used.
    ///   - Otherwise the device-flow prompt is stored in the `SigningIn` state so it can be
    ///     shown, and the server is asked to confirm with the user code.
    ///   - The resulting status is applied via `update_sign_in_status`. On any error the
    ///     status is reset to "not signed in".
    ///
    /// Returns an error immediately if the language server isn't running.
    pub fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user: Some(user) })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Expose the device-flow prompt while still signing in.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        })?;
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    Err(Arc::new(error))
                                }
                            })?
                        })
                        .shared();
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            cx.background_executor()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // TODO(review): not implemented yet. While the server is downloading/starting we
            // could wait for it to finish, and a stuck state could be surfaced to the user.
            // For now this simply fails.
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
583
584 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
585 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
586 if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
587 let server = server.clone();
588 cx.background_executor().spawn(async move {
589 server
590 .request::<request::SignOut>(request::SignOutParams {})
591 .await?;
592 anyhow::Ok(())
593 })
594 } else {
595 Task::ready(Err(anyhow!("copilot hasn't started yet")))
596 }
597 }
598
599 pub fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
600 let start_task = cx
601 .spawn({
602 let http = self.http.clone();
603 let node_runtime = self.node_runtime.clone();
604 let server_id = self.server_id;
605 move |this, cx| async move {
606 clear_copilot_dir().await;
607 Self::start_language_server(server_id, http, node_runtime, this, cx).await
608 }
609 })
610 .shared();
611
612 self.server = CopilotServer::Starting {
613 task: start_task.clone(),
614 };
615
616 cx.notify();
617
618 cx.background_executor().spawn(start_task)
619 }
620
621 pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
622 if let CopilotServer::Running(server) = &self.server {
623 Some(&server.lsp)
624 } else {
625 None
626 }
627 }
628
    /// Remembers `buffer`. If the server is running and the user is authorized, also opens
    /// the buffer on the server (`textDocument/didOpen`, version 0) unless it is already
    /// registered.
    ///
    /// Newly registered buffers get a subscription that forwards their events to
    /// `handle_buffer_event`. They also get a release observer that forgets and unregisters
    /// them.
    pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
        let weak_buffer = buffer.downgrade();
        // Tracked even when not signed in, so it can be registered after sign-in.
        self.buffers.insert(weak_buffer.clone());

        if let CopilotServer::Running(RunningCopilotServer {
            lsp: server,
            sign_in_status: status,
            registered_buffers,
            ..
        }) = &mut self.server
        {
            if !matches!(status, SignInStatus::Authorized { .. }) {
                return;
            }

            registered_buffers
                .entry(buffer.entity_id())
                .or_insert_with(|| {
                    let uri: lsp::Url = uri_for_buffer(buffer, cx);
                    let language_id = id_for_language(buffer.read(cx).language());
                    let snapshot = buffer.read(cx).snapshot();
                    server
                        .notify::<lsp::notification::DidOpenTextDocument>(
                            lsp::DidOpenTextDocumentParams {
                                text_document: lsp::TextDocumentItem {
                                    uri: uri.clone(),
                                    language_id: language_id.clone(),
                                    version: 0,
                                    text: snapshot.text(),
                                },
                            },
                        )
                        .log_err();

                    RegisteredBuffer {
                        uri,
                        language_id,
                        snapshot,
                        snapshot_version: 0,
                        pending_buffer_change: Task::ready(Some(())),
                        _subscriptions: [
                            cx.subscribe(buffer, |this, buffer, event, cx| {
                                this.handle_buffer_event(buffer, event, cx).log_err();
                            }),
                            cx.observe_release(buffer, move |this, _buffer, _cx| {
                                this.buffers.remove(&weak_buffer);
                                this.unregister_buffer(&weak_buffer);
                            }),
                        ],
                    }
                });
        }
    }
682
    /// Forwards buffer events for a registered buffer to the server:
    /// * `Edited`: report changes (`didChange`) via `report_changes`;
    /// * `Saved`: send `didSave`;
    /// * `FileHandleChanged` or `LanguageChanged`: if the URI or language id changed, close
    ///   the old document and re-open it under the new URI/language id, using the last
    ///   reported snapshot and version.
    ///
    /// Events for unregistered buffers, or while the server isn't running, are ignored.
    /// Returns an error if sending a `didSave`/`didClose`/`didOpen` notification fails.
    fn handle_buffer_event(
        &mut self,
        buffer: Model<Buffer>,
        event: &language::Event,
        cx: &mut ModelContext<Self>,
    ) -> Result<()> {
        if let Ok(server) = self.server.as_running() {
            if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
            {
                match event {
                    language::Event::Edited => {
                        let _ = registered_buffer.report_changes(&buffer, cx);
                    }
                    language::Event::Saved => {
                        server
                            .lsp
                            .notify::<lsp::notification::DidSaveTextDocument>(
                                lsp::DidSaveTextDocumentParams {
                                    text_document: lsp::TextDocumentIdentifier::new(
                                        registered_buffer.uri.clone(),
                                    ),
                                    text: None,
                                },
                            )?;
                    }
                    language::Event::FileHandleChanged | language::Event::LanguageChanged => {
                        let new_language_id = id_for_language(buffer.read(cx).language());
                        let new_uri = uri_for_buffer(&buffer, cx);
                        if new_uri != registered_buffer.uri
                            || new_language_id != registered_buffer.language_id
                        {
                            // Re-open the document under its new identity.
                            let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
                            registered_buffer.language_id = new_language_id;
                            server
                                .lsp
                                .notify::<lsp::notification::DidCloseTextDocument>(
                                    lsp::DidCloseTextDocumentParams {
                                        text_document: lsp::TextDocumentIdentifier::new(old_uri),
                                    },
                                )?;
                            server
                                .lsp
                                .notify::<lsp::notification::DidOpenTextDocument>(
                                    lsp::DidOpenTextDocumentParams {
                                        text_document: lsp::TextDocumentItem::new(
                                            registered_buffer.uri.clone(),
                                            registered_buffer.language_id.clone(),
                                            registered_buffer.snapshot_version,
                                            registered_buffer.snapshot.text(),
                                        ),
                                    },
                                )?;
                        }
                    }
                    _ => {}
                }
            }
        }

        Ok(())
    }
744
745 fn unregister_buffer(&mut self, buffer: &WeakModel<Buffer>) {
746 if let Ok(server) = self.server.as_running() {
747 if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
748 server
749 .lsp
750 .notify::<lsp::notification::DidCloseTextDocument>(
751 lsp::DidCloseTextDocumentParams {
752 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
753 },
754 )
755 .log_err();
756 }
757 }
758 }
759
760 pub fn completions<T>(
761 &mut self,
762 buffer: &Model<Buffer>,
763 position: T,
764 cx: &mut ModelContext<Self>,
765 ) -> Task<Result<Vec<Completion>>>
766 where
767 T: ToPointUtf16,
768 {
769 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
770 }
771
772 pub fn completions_cycling<T>(
773 &mut self,
774 buffer: &Model<Buffer>,
775 position: T,
776 cx: &mut ModelContext<Self>,
777 ) -> Task<Result<Vec<Completion>>>
778 where
779 T: ToPointUtf16,
780 {
781 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
782 }
783
784 pub fn accept_completion(
785 &mut self,
786 completion: &Completion,
787 cx: &mut ModelContext<Self>,
788 ) -> Task<Result<()>> {
789 let server = match self.server.as_authenticated() {
790 Ok(server) => server,
791 Err(error) => return Task::ready(Err(error)),
792 };
793 let request =
794 server
795 .lsp
796 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
797 uuid: completion.uuid.clone(),
798 });
799 cx.background_executor().spawn(async move {
800 request.await?;
801 Ok(())
802 })
803 }
804
805 pub fn discard_completions(
806 &mut self,
807 completions: &[Completion],
808 cx: &mut ModelContext<Self>,
809 ) -> Task<Result<()>> {
810 let server = match self.server.as_authenticated() {
811 Ok(server) => server,
812 Err(_) => return Task::ready(Ok(())),
813 };
814 let request =
815 server
816 .lsp
817 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
818 uuids: completions
819 .iter()
820 .map(|completion| completion.uuid.clone())
821 .collect(),
822 });
823 cx.background_executor().spawn(async move {
824 request.await?;
825 Ok(())
826 })
827 }
828
829 fn request_completions<R, T>(
830 &mut self,
831 buffer: &Model<Buffer>,
832 position: T,
833 cx: &mut ModelContext<Self>,
834 ) -> Task<Result<Vec<Completion>>>
835 where
836 R: 'static
837 + lsp::request::Request<
838 Params = request::GetCompletionsParams,
839 Result = request::GetCompletionsResult,
840 >,
841 T: ToPointUtf16,
842 {
843 self.register_buffer(buffer, cx);
844
845 let server = match self.server.as_authenticated() {
846 Ok(server) => server,
847 Err(error) => return Task::ready(Err(error)),
848 };
849 let lsp = server.lsp.clone();
850 let registered_buffer = server
851 .registered_buffers
852 .get_mut(&buffer.entity_id())
853 .unwrap();
854 let snapshot = registered_buffer.report_changes(buffer, cx);
855 let buffer = buffer.read(cx);
856 let uri = registered_buffer.uri.clone();
857 let position = position.to_point_utf16(buffer);
858 let settings = language_settings(buffer.language_at(position).as_ref(), buffer.file(), cx);
859 let tab_size = settings.tab_size;
860 let hard_tabs = settings.hard_tabs;
861 let relative_path = buffer
862 .file()
863 .map(|file| file.path().to_path_buf())
864 .unwrap_or_default();
865
866 cx.background_executor().spawn(async move {
867 let (version, snapshot) = snapshot.await?;
868 let result = lsp
869 .request::<R>(request::GetCompletionsParams {
870 doc: request::GetCompletionsDocument {
871 uri,
872 tab_size: tab_size.into(),
873 indent_size: 1,
874 insert_spaces: !hard_tabs,
875 relative_path: relative_path.to_string_lossy().into(),
876 position: point_to_lsp(position),
877 version: version.try_into().unwrap(),
878 },
879 })
880 .await?;
881 let completions = result
882 .completions
883 .into_iter()
884 .map(|completion| {
885 let start = snapshot
886 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
887 let end =
888 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
889 Completion {
890 uuid: completion.uuid,
891 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
892 text: completion.text,
893 }
894 })
895 .collect();
896 anyhow::Ok(completions)
897 })
898 }
899
900 pub fn status(&self) -> Status {
901 match &self.server {
902 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
903 CopilotServer::Disabled => Status::Disabled,
904 CopilotServer::Error(error) => Status::Error(error.clone()),
905 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
906 match sign_in_status {
907 SignInStatus::Authorized { .. } => Status::Authorized,
908 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
909 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
910 prompt: prompt.clone(),
911 },
912 SignInStatus::SignedOut => Status::SignedOut,
913 }
914 }
915 }
916 }
917
918 fn update_sign_in_status(
919 &mut self,
920 lsp_status: request::SignInStatus,
921 cx: &mut ModelContext<Self>,
922 ) {
923 self.buffers.retain(|buffer| buffer.is_upgradable());
924
925 if let Ok(server) = self.server.as_running() {
926 match lsp_status {
927 request::SignInStatus::Ok { user: Some(_) }
928 | request::SignInStatus::MaybeOk { .. }
929 | request::SignInStatus::AlreadySignedIn { .. } => {
930 server.sign_in_status = SignInStatus::Authorized;
931 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
932 if let Some(buffer) = buffer.upgrade() {
933 self.register_buffer(&buffer, cx);
934 }
935 }
936 }
937 request::SignInStatus::NotAuthorized { .. } => {
938 server.sign_in_status = SignInStatus::Unauthorized;
939 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
940 self.unregister_buffer(&buffer);
941 }
942 }
943 request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
944 server.sign_in_status = SignInStatus::SignedOut;
945 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
946 self.unregister_buffer(&buffer);
947 }
948 }
949 }
950
951 cx.notify();
952 }
953 }
954}
955
956fn id_for_language(language: Option<&Arc<Language>>) -> String {
957 language
958 .map(|language| language.lsp_id())
959 .unwrap_or_else(|| "plaintext".to_string())
960}
961
962fn uri_for_buffer(buffer: &Model<Buffer>, cx: &AppContext) -> lsp::Url {
963 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
964 lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
965 } else {
966 format!("buffer://{}", buffer.entity_id()).parse().unwrap()
967 }
968}
969
970async fn clear_copilot_dir() {
971 remove_matching(&paths::COPILOT_DIR, |_| true).await
972}
973
974async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
975 const SERVER_PATH: &str = "dist/agent.js";
976
977 ///Check for the latest copilot language server and download it if we haven't already
978 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
979 let release =
980 latest_github_release("zed-industries/copilot", true, false, http.clone()).await?;
981
982 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.tag_name));
983
984 fs::create_dir_all(version_dir).await?;
985 let server_path = version_dir.join(SERVER_PATH);
986
987 if fs::metadata(&server_path).await.is_err() {
988 // Copilot LSP looks for this dist dir specifically, so lets add it in.
989 let dist_dir = version_dir.join("dist");
990 fs::create_dir_all(dist_dir.as_path()).await?;
991
992 let url = &release
993 .assets
994 .get(0)
995 .context("Github release for copilot contained no assets")?
996 .browser_download_url;
997
998 let mut response = http
999 .get(url, Default::default(), true)
1000 .await
1001 .context("error downloading copilot release")?;
1002 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
1003 let archive = Archive::new(decompressed_bytes);
1004 archive.unpack(dist_dir).await?;
1005
1006 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
1007 }
1008
1009 Ok(server_path)
1010 }
1011
1012 match fetch_latest(http).await {
1013 ok @ Result::Ok(..) => ok,
1014 e @ Err(..) => {
1015 e.log_err();
1016 // Fetch a cached binary, if it exists
1017 maybe!(async {
1018 let mut last_version_dir = None;
1019 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
1020 while let Some(entry) = entries.next().await {
1021 let entry = entry?;
1022 if entry.file_type().await?.is_dir() {
1023 last_version_dir = Some(entry.path());
1024 }
1025 }
1026 let last_version_dir =
1027 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
1028 let server_path = last_version_dir.join(SERVER_PATH);
1029 if server_path.exists() {
1030 Ok(server_path)
1031 } else {
1032 Err(anyhow!(
1033 "missing executable in directory {:?}",
1034 last_version_dir
1035 ))
1036 }
1037 })
1038 .await
1039 }
1040 }
1041}
1042
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    /// Checks the open/change/close notifications the fake server receives as buffers are
    /// registered, edited, given a file path, signed out, signed back in, and dropped.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(cx: &mut TestAppContext) {
        let (copilot, mut lsp) = Copilot::fake(cx);

        // Registering a file-less buffer opens it under a synthetic `buffer://` URI.
        let buffer_1 = cx.new_model(|cx| Buffer::local("Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.new_model(|cx| Buffer::local("Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "Goodbye".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is reported as an incremental change with version 1.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        buffer_1.update(cx, |buffer, cx| {
            buffer.file_updated(
                Arc::new(File {
                    abs_path: "/root/child/buffer-1".into(),
                    path: Path::new("child/buffer-1").into(),
                }),
                cx,
            )
        });
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();

        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    /// Minimal local `language::File` that gives a test buffer an absolute path. Methods the
    /// test does not need are `unimplemented!()`.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn mtime(&self) -> Option<std::time::SystemTime> {
            unimplemented!()
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &AppContext) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn is_deleted(&self) -> bool {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            unimplemented!()
        }

        fn to_proto(&self) -> rpc::proto::File {
            unimplemented!()
        }

        fn worktree_id(&self) -> usize {
            0
        }

        fn is_private(&self) -> bool {
            false
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &AppContext) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &AppContext) -> Task<Result<String>> {
            unimplemented!()
        }
    }
}