1pub mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context as _, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use collections::{HashMap, HashSet};
8use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{
10 actions, AppContext, AsyncAppContext, Context, Entity, EntityId, EventEmitter, Model,
11 ModelContext, Task, WeakModel,
12};
13use language::{
14 language_settings::{all_language_settings, language_settings},
15 point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language,
16 LanguageServerName, PointUtf16, ToPointUtf16,
17};
18use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId};
19use node_runtime::NodeRuntime;
20use parking_lot::Mutex;
21use request::StatusNotification;
22use settings::SettingsStore;
23use smol::{fs, io::BufReader, stream::StreamExt};
24use std::{
25 ffi::OsString,
26 mem,
27 ops::Range,
28 path::{Path, PathBuf},
29 sync::Arc,
30};
31use util::{
32 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
33};
34
// todo!(): restore this namespace once command-palette filtering is ported; the filter
// code that uses it is still commented out in `init`.
// const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
// Actions for signing in to and out of Copilot.
actions!(SignIn, SignOut);

// todo!(): restore this namespace once command-palette filtering is ported; the filter
// code that uses it is still commented out in `init`.
// const COPILOT_NAMESPACE: &'static str = "copilot";
// Actions for requesting/cycling suggestions and for reinstalling the language server.
actions!(Suggest, NextSuggestion, PreviousSuggestion, Reinstall);
42
43pub fn init(
44 new_server_id: LanguageServerId,
45 http: Arc<dyn HttpClient>,
46 node_runtime: Arc<dyn NodeRuntime>,
47 cx: &mut AppContext,
48) {
49 let copilot = cx.build_model({
50 let node_runtime = node_runtime.clone();
51 move |cx| Copilot::start(new_server_id, http, node_runtime, cx)
52 });
53 cx.set_global(copilot.clone());
54
55 // TODO
56 // cx.observe(&copilot, |handle, cx| {
57 // let status = handle.read(cx).status();
58 // cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
59 // match status {
60 // Status::Disabled => {
61 // filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
62 // filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
63 // }
64 // Status::Authorized => {
65 // filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
66 // filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
67 // }
68 // _ => {
69 // filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
70 // filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
71 // }
72 // }
73 // });
74 // })
75 // .detach();
76
77 // sign_in::init(cx);
78 // cx.add_global_action(|_: &SignIn, cx| {
79 // if let Some(copilot) = Copilot::global(cx) {
80 // copilot
81 // .update(cx, |copilot, cx| copilot.sign_in(cx))
82 // .detach_and_log_err(cx);
83 // }
84 // });
85 // cx.add_global_action(|_: &SignOut, cx| {
86 // if let Some(copilot) = Copilot::global(cx) {
87 // copilot
88 // .update(cx, |copilot, cx| copilot.sign_out(cx))
89 // .detach_and_log_err(cx);
90 // }
91 // });
92
93 // cx.add_global_action(|_: &Reinstall, cx| {
94 // if let Some(copilot) = Copilot::global(cx) {
95 // copilot
96 // .update(cx, |copilot, cx| copilot.reinstall(cx))
97 // .detach();
98 // }
99 // });
100}
101
102enum CopilotServer {
103 Disabled,
104 Starting { task: Shared<Task<()>> },
105 Error(Arc<str>),
106 Running(RunningCopilotServer),
107}
108
109impl CopilotServer {
110 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
111 let server = self.as_running()?;
112 if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
113 Ok(server)
114 } else {
115 Err(anyhow!("must sign in before using copilot"))
116 }
117 }
118
119 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
120 match self {
121 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
122 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
123 CopilotServer::Error(error) => Err(anyhow!(
124 "copilot was not started because of an error: {}",
125 error
126 )),
127 CopilotServer::Running(server) => Ok(server),
128 }
129 }
130}
131
/// State kept for a Copilot language server that started successfully.
struct RunningCopilotServer {
    /// Name returned by `Copilot::language_server`; set to `"copilot"` when the server starts.
    name: LanguageServerName,
    /// Handle to the underlying language server process.
    lsp: Arc<LanguageServer>,
    /// Current authentication state with the Copilot service.
    sign_in_status: SignInStatus,
    /// Buffers currently opened on the server (via `didOpen`), keyed by buffer entity id.
    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
}
138
139#[derive(Clone, Debug)]
140enum SignInStatus {
141 Authorized,
142 Unauthorized,
143 SigningIn {
144 prompt: Option<request::PromptUserDeviceFlow>,
145 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
146 },
147 SignedOut,
148}
149
150#[derive(Debug, Clone)]
151pub enum Status {
152 Starting {
153 task: Shared<Task<()>>,
154 },
155 Error(Arc<str>),
156 Disabled,
157 SignedOut,
158 SigningIn {
159 prompt: Option<request::PromptUserDeviceFlow>,
160 },
161 Unauthorized,
162 Authorized,
163}
164
165impl Status {
166 pub fn is_authorized(&self) -> bool {
167 matches!(self, Status::Authorized)
168 }
169}
170
/// Server-side bookkeeping for a buffer that has been opened on the Copilot server.
struct RegisteredBuffer {
    /// URI the buffer was opened under (see `uri_for_buffer`).
    uri: lsp::Url,
    /// Language id sent to the server (see `id_for_language`).
    language_id: String,
    /// The buffer contents as last reported to the server.
    snapshot: BufferSnapshot,
    /// Document version sent with `snapshot`: 0 on open, incremented on every `didChange`.
    snapshot_version: i32,
    /// Subscription to the buffer's events and an observer of its release.
    _subscriptions: [gpui::Subscription; 2],
    /// Most recent task spawned by `report_changes`; each new change waits on the previous one.
    pending_buffer_change: Task<Option<()>>,
}
179
impl RegisteredBuffer {
    /// Brings the server's copy of `buffer` up to date and reports the resulting
    /// `(snapshot_version, snapshot)` pair through the returned receiver.
    ///
    /// If the buffer's version equals the last reported snapshot's version, the current pair
    /// is sent right away. Otherwise a task is chained after the previously pending change.
    /// That task computes the edits since the last reported snapshot on the background
    /// executor, sends them in one `DidChangeTextDocument` notification (bumping
    /// `snapshot_version` when there are changes), and then sends the new pair.
    ///
    /// If the server is no longer authenticated, the buffer is no longer registered, or an
    /// entity update fails, the task returns early and drops the sender, so the receiver
    /// resolves with a cancellation error.
    fn report_changes(
        &mut self,
        buffer: &Model<Buffer>,
        cx: &mut ModelContext<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            // Nothing changed since the last report; answer immediately.
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Take the previous pending change and await it first, so changes for this
            // buffer are diffed and sent one after another.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(move |copilot, mut cx| async move {
                prev_pending_change.await;

                // Read the last reported version only now: the previous pending change may
                // have advanced it.
                let old_version = copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot()).ok()?;

                // Convert the edits since `old_version` into LSP content changes on the
                // background executor.
                let content_changes = cx
                    .background_executor()
                    .spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // The range starts at the edit's position in the new text
                                    // and spans the extent of the replaced old text.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .log_err();
                        }
                        // Report the (possibly unchanged) version/snapshot pair to the caller.
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
266
/// A completion returned by the Copilot server.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned id, sent back in accept (`NotifyAccepted`) / reject
    /// (`NotifyRejected`) requests.
    pub uuid: String,
    /// Buffer range the completion applies to, clipped to the snapshot it was requested against.
    pub range: Range<Anchor>,
    /// The suggested text.
    pub text: String,
}
273
/// The global Copilot model. It owns the Copilot language server and keeps the server in
/// sync with registered buffers.
pub struct Copilot {
    /// HTTP client used to download the language server release.
    http: Arc<dyn HttpClient>,
    /// Node runtime whose binary runs the language server script.
    node_runtime: Arc<dyn NodeRuntime>,
    /// Lifecycle state of the language server.
    server: CopilotServer,
    /// Every buffer passed to `register_buffer`, held weakly so the buffers can be
    /// re-registered after a successful sign-in.
    buffers: HashSet<WeakModel<Buffer>>,
    /// Id used for each language server this model starts.
    server_id: LanguageServerId,
    /// App-quit subscription that shuts the language server down.
    _subscription: gpui::Subscription,
}
282
283pub enum Event {
284 CopilotLanguageServerStarted,
285}
286
287impl EventEmitter<Event> for Copilot {}
288
289impl Copilot {
290 pub fn global(cx: &AppContext) -> Option<Model<Self>> {
291 if cx.has_global::<Model<Self>>() {
292 Some(cx.global::<Model<Self>>().clone())
293 } else {
294 None
295 }
296 }
297
298 fn start(
299 new_server_id: LanguageServerId,
300 http: Arc<dyn HttpClient>,
301 node_runtime: Arc<dyn NodeRuntime>,
302 cx: &mut ModelContext<Self>,
303 ) -> Self {
304 let mut this = Self {
305 server_id: new_server_id,
306 http,
307 node_runtime,
308 server: CopilotServer::Disabled,
309 buffers: Default::default(),
310 _subscription: cx.on_app_quit(Self::shutdown_language_server),
311 };
312 this.enable_or_disable_copilot(cx);
313 cx.observe_global::<SettingsStore>(move |this, cx| this.enable_or_disable_copilot(cx))
314 .detach();
315 this
316 }
317
318 fn shutdown_language_server(
319 &mut self,
320 _cx: &mut ModelContext<Self>,
321 ) -> impl Future<Output = ()> {
322 let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
323 CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
324 _ => None,
325 };
326
327 async move {
328 if let Some(shutdown) = shutdown {
329 shutdown.await;
330 }
331 }
332 }
333
334 fn enable_or_disable_copilot(&mut self, cx: &mut ModelContext<Self>) {
335 let server_id = self.server_id;
336 let http = self.http.clone();
337 let node_runtime = self.node_runtime.clone();
338 if all_language_settings(None, cx).copilot_enabled(None, None) {
339 if matches!(self.server, CopilotServer::Disabled) {
340 let start_task = cx
341 .spawn(move |this, cx| {
342 Self::start_language_server(server_id, http, node_runtime, this, cx)
343 })
344 .shared();
345 self.server = CopilotServer::Starting { task: start_task };
346 cx.notify();
347 }
348 } else {
349 self.server = CopilotServer::Disabled;
350 cx.notify();
351 }
352 }
353
354 // #[cfg(any(test, feature = "test-support"))]
355 // pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
356 // use node_runtime::FakeNodeRuntime;
357
358 // let (server, fake_server) =
359 // LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
360 // let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
361 // let node_runtime = FakeNodeRuntime::new();
362 // let this = cx.add_model(|_| Self {
363 // server_id: LanguageServerId(0),
364 // http: http.clone(),
365 // node_runtime,
366 // server: CopilotServer::Running(RunningCopilotServer {
367 // name: LanguageServerName(Arc::from("copilot")),
368 // lsp: Arc::new(server),
369 // sign_in_status: SignInStatus::Authorized,
370 // registered_buffers: Default::default(),
371 // }),
372 // buffers: Default::default(),
373 // });
374 // (this, fake_server)
375 // }
376
377 fn start_language_server(
378 new_server_id: LanguageServerId,
379 http: Arc<dyn HttpClient>,
380 node_runtime: Arc<dyn NodeRuntime>,
381 this: WeakModel<Self>,
382 mut cx: AsyncAppContext,
383 ) -> impl Future<Output = ()> {
384 async move {
385 let start_language_server = async {
386 let server_path = get_copilot_lsp(http).await?;
387 let node_path = node_runtime.binary_path().await?;
388 let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
389 let binary = LanguageServerBinary {
390 path: node_path,
391 arguments,
392 };
393
394 let server = LanguageServer::new(
395 Arc::new(Mutex::new(None)),
396 new_server_id,
397 binary,
398 Path::new("/"),
399 None,
400 cx.clone(),
401 )?;
402
403 server
404 .on_notification::<StatusNotification, _>(
405 |_, _| { /* Silence the notification */ },
406 )
407 .detach();
408
409 let server = server.initialize(Default::default()).await?;
410
411 let status = server
412 .request::<request::CheckStatus>(request::CheckStatusParams {
413 local_checks_only: false,
414 })
415 .await?;
416
417 server
418 .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
419 editor_info: request::EditorInfo {
420 name: "zed".into(),
421 version: env!("CARGO_PKG_VERSION").into(),
422 },
423 editor_plugin_info: request::EditorPluginInfo {
424 name: "zed-copilot".into(),
425 version: "0.0.1".into(),
426 },
427 })
428 .await?;
429
430 anyhow::Ok((server, status))
431 };
432
433 let server = start_language_server.await;
434 this.update(&mut cx, |this, cx| {
435 cx.notify();
436 match server {
437 Ok((server, status)) => {
438 this.server = CopilotServer::Running(RunningCopilotServer {
439 name: LanguageServerName(Arc::from("copilot")),
440 lsp: server,
441 sign_in_status: SignInStatus::SignedOut,
442 registered_buffers: Default::default(),
443 });
444 cx.emit(Event::CopilotLanguageServerStarted);
445 this.update_sign_in_status(status, cx);
446 }
447 Err(error) => {
448 this.server = CopilotServer::Error(error.to_string().into());
449 cx.notify()
450 }
451 }
452 })
453 .ok();
454 }
455 }
456
457 pub fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
458 if let CopilotServer::Running(server) = &mut self.server {
459 let task = match &server.sign_in_status {
460 SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
461 SignInStatus::SigningIn { task, .. } => {
462 cx.notify();
463 task.clone()
464 }
465 SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => {
466 let lsp = server.lsp.clone();
467 let task = cx
468 .spawn(|this, mut cx| async move {
469 let sign_in = async {
470 let sign_in = lsp
471 .request::<request::SignInInitiate>(
472 request::SignInInitiateParams {},
473 )
474 .await?;
475 match sign_in {
476 request::SignInInitiateResult::AlreadySignedIn { user } => {
477 Ok(request::SignInStatus::Ok { user })
478 }
479 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
480 this.update(&mut cx, |this, cx| {
481 if let CopilotServer::Running(RunningCopilotServer {
482 sign_in_status: status,
483 ..
484 }) = &mut this.server
485 {
486 if let SignInStatus::SigningIn {
487 prompt: prompt_flow,
488 ..
489 } = status
490 {
491 *prompt_flow = Some(flow.clone());
492 cx.notify();
493 }
494 }
495 })?;
496 let response = lsp
497 .request::<request::SignInConfirm>(
498 request::SignInConfirmParams {
499 user_code: flow.user_code,
500 },
501 )
502 .await?;
503 Ok(response)
504 }
505 }
506 };
507
508 let sign_in = sign_in.await;
509 this.update(&mut cx, |this, cx| match sign_in {
510 Ok(status) => {
511 this.update_sign_in_status(status, cx);
512 Ok(())
513 }
514 Err(error) => {
515 this.update_sign_in_status(
516 request::SignInStatus::NotSignedIn,
517 cx,
518 );
519 Err(Arc::new(error))
520 }
521 })?
522 })
523 .shared();
524 server.sign_in_status = SignInStatus::SigningIn {
525 prompt: None,
526 task: task.clone(),
527 };
528 cx.notify();
529 task
530 }
531 };
532
533 cx.background_executor()
534 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
535 } else {
536 // If we're downloading, wait until download is finished
537 // If we're in a stuck state, display to the user
538 Task::ready(Err(anyhow!("copilot hasn't started yet")))
539 }
540 }
541
542 #[allow(dead_code)] // todo!()
543 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
544 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
545 if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
546 let server = server.clone();
547 cx.background_executor().spawn(async move {
548 server
549 .request::<request::SignOut>(request::SignOutParams {})
550 .await?;
551 anyhow::Ok(())
552 })
553 } else {
554 Task::ready(Err(anyhow!("copilot hasn't started yet")))
555 }
556 }
557
558 pub fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
559 let start_task = cx
560 .spawn({
561 let http = self.http.clone();
562 let node_runtime = self.node_runtime.clone();
563 let server_id = self.server_id;
564 move |this, cx| async move {
565 clear_copilot_dir().await;
566 Self::start_language_server(server_id, http, node_runtime, this, cx).await
567 }
568 })
569 .shared();
570
571 self.server = CopilotServer::Starting {
572 task: start_task.clone(),
573 };
574
575 cx.notify();
576
577 cx.background_executor().spawn(start_task)
578 }
579
580 pub fn language_server(&self) -> Option<(&LanguageServerName, &Arc<LanguageServer>)> {
581 if let CopilotServer::Running(server) = &self.server {
582 Some((&server.name, &server.lsp))
583 } else {
584 None
585 }
586 }
587
588 pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
589 let weak_buffer = buffer.downgrade();
590 self.buffers.insert(weak_buffer.clone());
591
592 if let CopilotServer::Running(RunningCopilotServer {
593 lsp: server,
594 sign_in_status: status,
595 registered_buffers,
596 ..
597 }) = &mut self.server
598 {
599 if !matches!(status, SignInStatus::Authorized { .. }) {
600 return;
601 }
602
603 registered_buffers
604 .entry(buffer.entity_id())
605 .or_insert_with(|| {
606 let uri: lsp::Url = uri_for_buffer(buffer, cx);
607 let language_id = id_for_language(buffer.read(cx).language());
608 let snapshot = buffer.read(cx).snapshot();
609 server
610 .notify::<lsp::notification::DidOpenTextDocument>(
611 lsp::DidOpenTextDocumentParams {
612 text_document: lsp::TextDocumentItem {
613 uri: uri.clone(),
614 language_id: language_id.clone(),
615 version: 0,
616 text: snapshot.text(),
617 },
618 },
619 )
620 .log_err();
621
622 RegisteredBuffer {
623 uri,
624 language_id,
625 snapshot,
626 snapshot_version: 0,
627 pending_buffer_change: Task::ready(Some(())),
628 _subscriptions: [
629 cx.subscribe(buffer, |this, buffer, event, cx| {
630 this.handle_buffer_event(buffer, event, cx).log_err();
631 }),
632 cx.observe_release(buffer, move |this, _buffer, _cx| {
633 this.buffers.remove(&weak_buffer);
634 this.unregister_buffer(&weak_buffer);
635 }),
636 ],
637 }
638 });
639 }
640 }
641
642 fn handle_buffer_event(
643 &mut self,
644 buffer: Model<Buffer>,
645 event: &language::Event,
646 cx: &mut ModelContext<Self>,
647 ) -> Result<()> {
648 if let Ok(server) = self.server.as_running() {
649 if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
650 {
651 match event {
652 language::Event::Edited => {
653 let _ = registered_buffer.report_changes(&buffer, cx);
654 }
655 language::Event::Saved => {
656 server
657 .lsp
658 .notify::<lsp::notification::DidSaveTextDocument>(
659 lsp::DidSaveTextDocumentParams {
660 text_document: lsp::TextDocumentIdentifier::new(
661 registered_buffer.uri.clone(),
662 ),
663 text: None,
664 },
665 )?;
666 }
667 language::Event::FileHandleChanged | language::Event::LanguageChanged => {
668 let new_language_id = id_for_language(buffer.read(cx).language());
669 let new_uri = uri_for_buffer(&buffer, cx);
670 if new_uri != registered_buffer.uri
671 || new_language_id != registered_buffer.language_id
672 {
673 let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
674 registered_buffer.language_id = new_language_id;
675 server
676 .lsp
677 .notify::<lsp::notification::DidCloseTextDocument>(
678 lsp::DidCloseTextDocumentParams {
679 text_document: lsp::TextDocumentIdentifier::new(old_uri),
680 },
681 )?;
682 server
683 .lsp
684 .notify::<lsp::notification::DidOpenTextDocument>(
685 lsp::DidOpenTextDocumentParams {
686 text_document: lsp::TextDocumentItem::new(
687 registered_buffer.uri.clone(),
688 registered_buffer.language_id.clone(),
689 registered_buffer.snapshot_version,
690 registered_buffer.snapshot.text(),
691 ),
692 },
693 )?;
694 }
695 }
696 _ => {}
697 }
698 }
699 }
700
701 Ok(())
702 }
703
704 fn unregister_buffer(&mut self, buffer: &WeakModel<Buffer>) {
705 if let Ok(server) = self.server.as_running() {
706 if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
707 server
708 .lsp
709 .notify::<lsp::notification::DidCloseTextDocument>(
710 lsp::DidCloseTextDocumentParams {
711 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
712 },
713 )
714 .log_err();
715 }
716 }
717 }
718
719 pub fn completions<T>(
720 &mut self,
721 buffer: &Model<Buffer>,
722 position: T,
723 cx: &mut ModelContext<Self>,
724 ) -> Task<Result<Vec<Completion>>>
725 where
726 T: ToPointUtf16,
727 {
728 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
729 }
730
731 pub fn completions_cycling<T>(
732 &mut self,
733 buffer: &Model<Buffer>,
734 position: T,
735 cx: &mut ModelContext<Self>,
736 ) -> Task<Result<Vec<Completion>>>
737 where
738 T: ToPointUtf16,
739 {
740 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
741 }
742
743 pub fn accept_completion(
744 &mut self,
745 completion: &Completion,
746 cx: &mut ModelContext<Self>,
747 ) -> Task<Result<()>> {
748 let server = match self.server.as_authenticated() {
749 Ok(server) => server,
750 Err(error) => return Task::ready(Err(error)),
751 };
752 let request =
753 server
754 .lsp
755 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
756 uuid: completion.uuid.clone(),
757 });
758 cx.background_executor().spawn(async move {
759 request.await?;
760 Ok(())
761 })
762 }
763
764 pub fn discard_completions(
765 &mut self,
766 completions: &[Completion],
767 cx: &mut ModelContext<Self>,
768 ) -> Task<Result<()>> {
769 let server = match self.server.as_authenticated() {
770 Ok(server) => server,
771 Err(error) => return Task::ready(Err(error)),
772 };
773 let request =
774 server
775 .lsp
776 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
777 uuids: completions
778 .iter()
779 .map(|completion| completion.uuid.clone())
780 .collect(),
781 });
782 cx.background_executor().spawn(async move {
783 request.await?;
784 Ok(())
785 })
786 }
787
788 fn request_completions<R, T>(
789 &mut self,
790 buffer: &Model<Buffer>,
791 position: T,
792 cx: &mut ModelContext<Self>,
793 ) -> Task<Result<Vec<Completion>>>
794 where
795 R: 'static
796 + lsp::request::Request<
797 Params = request::GetCompletionsParams,
798 Result = request::GetCompletionsResult,
799 >,
800 T: ToPointUtf16,
801 {
802 self.register_buffer(buffer, cx);
803
804 let server = match self.server.as_authenticated() {
805 Ok(server) => server,
806 Err(error) => return Task::ready(Err(error)),
807 };
808 let lsp = server.lsp.clone();
809 let registered_buffer = server
810 .registered_buffers
811 .get_mut(&buffer.entity_id())
812 .unwrap();
813 let snapshot = registered_buffer.report_changes(buffer, cx);
814 let buffer = buffer.read(cx);
815 let uri = registered_buffer.uri.clone();
816 let position = position.to_point_utf16(buffer);
817 let settings = language_settings(buffer.language_at(position).as_ref(), buffer.file(), cx);
818 let tab_size = settings.tab_size;
819 let hard_tabs = settings.hard_tabs;
820 let relative_path = buffer
821 .file()
822 .map(|file| file.path().to_path_buf())
823 .unwrap_or_default();
824
825 cx.background_executor().spawn(async move {
826 let (version, snapshot) = snapshot.await?;
827 let result = lsp
828 .request::<R>(request::GetCompletionsParams {
829 doc: request::GetCompletionsDocument {
830 uri,
831 tab_size: tab_size.into(),
832 indent_size: 1,
833 insert_spaces: !hard_tabs,
834 relative_path: relative_path.to_string_lossy().into(),
835 position: point_to_lsp(position),
836 version: version.try_into().unwrap(),
837 },
838 })
839 .await?;
840 let completions = result
841 .completions
842 .into_iter()
843 .map(|completion| {
844 let start = snapshot
845 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
846 let end =
847 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
848 Completion {
849 uuid: completion.uuid,
850 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
851 text: completion.text,
852 }
853 })
854 .collect();
855 anyhow::Ok(completions)
856 })
857 }
858
859 pub fn status(&self) -> Status {
860 match &self.server {
861 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
862 CopilotServer::Disabled => Status::Disabled,
863 CopilotServer::Error(error) => Status::Error(error.clone()),
864 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
865 match sign_in_status {
866 SignInStatus::Authorized { .. } => Status::Authorized,
867 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
868 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
869 prompt: prompt.clone(),
870 },
871 SignInStatus::SignedOut => Status::SignedOut,
872 }
873 }
874 }
875 }
876
877 fn update_sign_in_status(
878 &mut self,
879 lsp_status: request::SignInStatus,
880 cx: &mut ModelContext<Self>,
881 ) {
882 self.buffers.retain(|buffer| buffer.is_upgradable());
883
884 if let Ok(server) = self.server.as_running() {
885 match lsp_status {
886 request::SignInStatus::Ok { .. }
887 | request::SignInStatus::MaybeOk { .. }
888 | request::SignInStatus::AlreadySignedIn { .. } => {
889 server.sign_in_status = SignInStatus::Authorized;
890 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
891 if let Some(buffer) = buffer.upgrade() {
892 self.register_buffer(&buffer, cx);
893 }
894 }
895 }
896 request::SignInStatus::NotAuthorized { .. } => {
897 server.sign_in_status = SignInStatus::Unauthorized;
898 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
899 self.unregister_buffer(&buffer);
900 }
901 }
902 request::SignInStatus::NotSignedIn => {
903 server.sign_in_status = SignInStatus::SignedOut;
904 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
905 self.unregister_buffer(&buffer);
906 }
907 }
908 }
909
910 cx.notify();
911 }
912 }
913}
914
915fn id_for_language(language: Option<&Arc<Language>>) -> String {
916 let language_name = language.map(|language| language.name());
917 match language_name.as_deref() {
918 Some("Plain Text") => "plaintext".to_string(),
919 Some(language_name) => language_name.to_lowercase(),
920 None => "plaintext".to_string(),
921 }
922}
923
924fn uri_for_buffer(buffer: &Model<Buffer>, cx: &AppContext) -> lsp::Url {
925 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
926 lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
927 } else {
928 format!("buffer://{}", buffer.entity_id()).parse().unwrap()
929 }
930}
931
932async fn clear_copilot_dir() {
933 remove_matching(&paths::COPILOT_DIR, |_| true).await
934}
935
936async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
937 const SERVER_PATH: &'static str = "dist/agent.js";
938
939 ///Check for the latest copilot language server and download it if we haven't already
940 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
941 let release = latest_github_release("zed-industries/copilot", false, http.clone()).await?;
942
943 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
944
945 fs::create_dir_all(version_dir).await?;
946 let server_path = version_dir.join(SERVER_PATH);
947
948 if fs::metadata(&server_path).await.is_err() {
949 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
950 let dist_dir = version_dir.join("dist");
951 fs::create_dir_all(dist_dir.as_path()).await?;
952
953 let url = &release
954 .assets
955 .get(0)
956 .context("Github release for copilot contained no assets")?
957 .browser_download_url;
958
959 let mut response = http
960 .get(&url, Default::default(), true)
961 .await
962 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
963 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
964 let archive = Archive::new(decompressed_bytes);
965 archive.unpack(dist_dir).await?;
966
967 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
968 }
969
970 Ok(server_path)
971 }
972
973 match fetch_latest(http).await {
974 ok @ Result::Ok(..) => ok,
975 e @ Err(..) => {
976 e.log_err();
977 // Fetch a cached binary, if it exists
978 (|| async move {
979 let mut last_version_dir = None;
980 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
981 while let Some(entry) = entries.next().await {
982 let entry = entry?;
983 if entry.file_type().await?.is_dir() {
984 last_version_dir = Some(entry.path());
985 }
986 }
987 let last_version_dir =
988 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
989 let server_path = last_version_dir.join(SERVER_PATH);
990 if server_path.exists() {
991 Ok(server_path)
992 } else {
993 Err(anyhow!(
994 "missing executable in directory {:?}",
995 last_version_dir
996 ))
997 }
998 })()
999 .await
1000 }
1001 }
1002}
1003
1004// #[cfg(test)]
1005// mod tests {
1006// use super::*;
1007// use gpui::{executor::Deterministic, TestAppContext};
1008
1009// #[gpui::test(iterations = 10)]
1010// async fn test_buffer_management(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
1011// deterministic.forbid_parking();
1012// let (copilot, mut lsp) = Copilot::fake(cx);
1013
1014// let buffer_1 = cx.add_model(|cx| Buffer::new(0, cx.model_id() as u64, "Hello"));
1015// let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.id()).parse().unwrap();
1016// copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
1017// assert_eq!(
1018// lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1019// .await,
1020// lsp::DidOpenTextDocumentParams {
1021// text_document: lsp::TextDocumentItem::new(
1022// buffer_1_uri.clone(),
1023// "plaintext".into(),
1024// 0,
1025// "Hello".into()
1026// ),
1027// }
1028// );
1029
1030// let buffer_2 = cx.add_model(|cx| Buffer::new(0, cx.model_id() as u64, "Goodbye"));
1031// let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.id()).parse().unwrap();
1032// copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
1033// assert_eq!(
1034// lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1035// .await,
1036// lsp::DidOpenTextDocumentParams {
1037// text_document: lsp::TextDocumentItem::new(
1038// buffer_2_uri.clone(),
1039// "plaintext".into(),
1040// 0,
1041// "Goodbye".into()
1042// ),
1043// }
1044// );
1045
1046// buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
1047// assert_eq!(
1048// lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
1049// .await,
1050// lsp::DidChangeTextDocumentParams {
1051// text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
1052// content_changes: vec![lsp::TextDocumentContentChangeEvent {
1053// range: Some(lsp::Range::new(
1054// lsp::Position::new(0, 5),
1055// lsp::Position::new(0, 5)
1056// )),
1057// range_length: None,
1058// text: " world".into(),
1059// }],
1060// }
1061// );
1062
1063// // Ensure updates to the file are reflected in the LSP.
1064// buffer_1
1065// .update(cx, |buffer, cx| {
1066// buffer.file_updated(
1067// Arc::new(File {
1068// abs_path: "/root/child/buffer-1".into(),
1069// path: Path::new("child/buffer-1").into(),
1070// }),
1071// cx,
1072// )
1073// })
1074// .await;
1075// assert_eq!(
1076// lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1077// .await,
1078// lsp::DidCloseTextDocumentParams {
1079// text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
1080// }
1081// );
1082// let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
1083// assert_eq!(
1084// lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1085// .await,
1086// lsp::DidOpenTextDocumentParams {
1087// text_document: lsp::TextDocumentItem::new(
1088// buffer_1_uri.clone(),
1089// "plaintext".into(),
1090// 1,
1091// "Hello world".into()
1092// ),
1093// }
1094// );
1095
1096// // Ensure all previously-registered buffers are closed when signing out.
1097// lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
1098// Ok(request::SignOutResult {})
1099// });
1100// copilot
1101// .update(cx, |copilot, cx| copilot.sign_out(cx))
1102// .await
1103// .unwrap();
1104// assert_eq!(
1105// lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1106// .await,
1107// lsp::DidCloseTextDocumentParams {
1108// text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
1109// }
1110// );
1111// assert_eq!(
1112// lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1113// .await,
1114// lsp::DidCloseTextDocumentParams {
1115// text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
1116// }
1117// );
1118
1119// // Ensure all previously-registered buffers are re-opened when signing in.
1120// lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
1121// Ok(request::SignInInitiateResult::AlreadySignedIn {
1122// user: "user-1".into(),
1123// })
1124// });
1125// copilot
1126// .update(cx, |copilot, cx| copilot.sign_in(cx))
1127// .await
1128// .unwrap();
1129// assert_eq!(
1130// lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1131// .await,
1132// lsp::DidOpenTextDocumentParams {
1133// text_document: lsp::TextDocumentItem::new(
1134// buffer_2_uri.clone(),
1135// "plaintext".into(),
1136// 0,
1137// "Goodbye".into()
1138// ),
1139// }
1140// );
1141// assert_eq!(
1142// lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1143// .await,
1144// lsp::DidOpenTextDocumentParams {
1145// text_document: lsp::TextDocumentItem::new(
1146// buffer_1_uri.clone(),
1147// "plaintext".into(),
1148// 0,
1149// "Hello world".into()
1150// ),
1151// }
1152// );
1153
1154// // Dropping a buffer causes it to be closed on the LSP side as well.
1155// cx.update(|_| drop(buffer_2));
1156// assert_eq!(
1157// lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1158// .await,
1159// lsp::DidCloseTextDocumentParams {
1160// text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
1161// }
1162// );
1163// }
1164
1165// struct File {
1166// abs_path: PathBuf,
1167// path: Arc<Path>,
1168// }
1169
1170// impl language2::File for File {
1171// fn as_local(&self) -> Option<&dyn language2::LocalFile> {
1172// Some(self)
1173// }
1174
1175// fn mtime(&self) -> std::time::SystemTime {
1176// unimplemented!()
1177// }
1178
1179// fn path(&self) -> &Arc<Path> {
1180// &self.path
1181// }
1182
1183// fn full_path(&self, _: &AppContext) -> PathBuf {
1184// unimplemented!()
1185// }
1186
1187// fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
1188// unimplemented!()
1189// }
1190
1191// fn is_deleted(&self) -> bool {
1192// unimplemented!()
1193// }
1194
1195// fn as_any(&self) -> &dyn std::any::Any {
1196// unimplemented!()
1197// }
1198
1199// fn to_proto(&self) -> rpc::proto::File {
1200// unimplemented!()
1201// }
1202
1203// fn worktree_id(&self) -> usize {
1204// 0
1205// }
1206// }
1207
1208// impl language::LocalFile for File {
1209// fn abs_path(&self, _: &AppContext) -> PathBuf {
1210// self.abs_path.clone()
1211// }
1212
1213// fn load(&self, _: &AppContext) -> Task<Result<String>> {
1214// unimplemented!()
1215// }
1216
1217// fn buffer_reloaded(
1218// &self,
1219// _: u64,
1220// _: &clock::Global,
1221// _: language::RopeFingerprint,
1222// _: language::LineEnding,
1223// _: std::time::SystemTime,
1224// _: &mut AppContext,
1225// ) {
1226// unimplemented!()
1227// }
1228// }
1229// }