1pub mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use collections::HashMap;
8use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{
10 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
11};
12use language::{
13 point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
14 ToPointUtf16,
15};
16use log::{debug, error};
17use lsp::{LanguageServer, LanguageServerId};
18use node_runtime::NodeRuntime;
19use request::{LogMessage, StatusNotification};
20use settings::Settings;
21use smol::{fs, io::BufReader, stream::StreamExt};
22use std::{
23 ffi::OsString,
24 mem,
25 ops::Range,
26 path::{Path, PathBuf},
27 sync::Arc,
28};
29use util::{
30 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
31};
32
33const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
34actions!(copilot_auth, [SignIn, SignOut]);
35
36const COPILOT_NAMESPACE: &'static str = "copilot";
37actions!(
38 copilot,
39 [Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
40);
41
42pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
43 let copilot = cx.add_model({
44 let node_runtime = node_runtime.clone();
45 move |cx| Copilot::start(http, node_runtime, cx)
46 });
47 cx.set_global(copilot.clone());
48
49 cx.observe(&copilot, |handle, cx| {
50 let status = handle.read(cx).status();
51 cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
52 match status {
53 Status::Disabled => {
54 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
55 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
56 }
57 Status::Authorized => {
58 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
59 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
60 }
61 _ => {
62 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
63 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
64 }
65 }
66 });
67 })
68 .detach();
69
70 sign_in::init(cx);
71 cx.add_global_action(|_: &SignIn, cx| {
72 if let Some(copilot) = Copilot::global(cx) {
73 copilot
74 .update(cx, |copilot, cx| copilot.sign_in(cx))
75 .detach_and_log_err(cx);
76 }
77 });
78 cx.add_global_action(|_: &SignOut, cx| {
79 if let Some(copilot) = Copilot::global(cx) {
80 copilot
81 .update(cx, |copilot, cx| copilot.sign_out(cx))
82 .detach_and_log_err(cx);
83 }
84 });
85
86 cx.add_global_action(|_: &Reinstall, cx| {
87 if let Some(copilot) = Copilot::global(cx) {
88 copilot
89 .update(cx, |copilot, cx| copilot.reinstall(cx))
90 .detach();
91 }
92 });
93}
94
95enum CopilotServer {
96 Disabled,
97 Starting { task: Shared<Task<()>> },
98 Error(Arc<str>),
99 Running(RunningCopilotServer),
100}
101
102impl CopilotServer {
103 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
104 let server = self.as_running()?;
105 if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
106 Ok(server)
107 } else {
108 Err(anyhow!("must sign in before using copilot"))
109 }
110 }
111
112 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
113 match self {
114 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
115 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
116 CopilotServer::Error(error) => Err(anyhow!(
117 "copilot was not started because of an error: {}",
118 error
119 )),
120 CopilotServer::Running(server) => Ok(server),
121 }
122 }
123}
124
125struct RunningCopilotServer {
126 lsp: Arc<LanguageServer>,
127 sign_in_status: SignInStatus,
128 registered_buffers: HashMap<usize, RegisteredBuffer>,
129}
130
131#[derive(Clone, Debug)]
132enum SignInStatus {
133 Authorized,
134 Unauthorized,
135 SigningIn {
136 prompt: Option<request::PromptUserDeviceFlow>,
137 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
138 },
139 SignedOut,
140}
141
142#[derive(Debug, Clone)]
143pub enum Status {
144 Starting {
145 task: Shared<Task<()>>,
146 },
147 Error(Arc<str>),
148 Disabled,
149 SignedOut,
150 SigningIn {
151 prompt: Option<request::PromptUserDeviceFlow>,
152 },
153 Unauthorized,
154 Authorized,
155}
156
157impl Status {
158 pub fn is_authorized(&self) -> bool {
159 matches!(self, Status::Authorized)
160 }
161}
162
163struct RegisteredBuffer {
164 id: usize,
165 uri: lsp::Url,
166 language_id: String,
167 snapshot: BufferSnapshot,
168 snapshot_version: i32,
169 _subscriptions: [gpui::Subscription; 2],
170 pending_buffer_change: Task<Option<()>>,
171}
172
173impl RegisteredBuffer {
174 fn report_changes(
175 &mut self,
176 buffer: &ModelHandle<Buffer>,
177 cx: &mut ModelContext<Copilot>,
178 ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
179 let id = self.id;
180 let (done_tx, done_rx) = oneshot::channel();
181
182 if buffer.read(cx).version() == self.snapshot.version {
183 let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
184 } else {
185 let buffer = buffer.downgrade();
186 let prev_pending_change =
187 mem::replace(&mut self.pending_buffer_change, Task::ready(None));
188 self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move {
189 prev_pending_change.await;
190
191 let old_version = copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
192 let server = copilot.server.as_authenticated().log_err()?;
193 let buffer = server.registered_buffers.get_mut(&id)?;
194 Some(buffer.snapshot.version.clone())
195 })?;
196 let new_snapshot = buffer
197 .upgrade(&cx)?
198 .read_with(&cx, |buffer, _| buffer.snapshot());
199
200 let content_changes = cx
201 .background()
202 .spawn({
203 let new_snapshot = new_snapshot.clone();
204 async move {
205 new_snapshot
206 .edits_since::<(PointUtf16, usize)>(&old_version)
207 .map(|edit| {
208 let edit_start = edit.new.start.0;
209 let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
210 let new_text = new_snapshot
211 .text_for_range(edit.new.start.1..edit.new.end.1)
212 .collect();
213 lsp::TextDocumentContentChangeEvent {
214 range: Some(lsp::Range::new(
215 point_to_lsp(edit_start),
216 point_to_lsp(edit_end),
217 )),
218 range_length: None,
219 text: new_text,
220 }
221 })
222 .collect::<Vec<_>>()
223 }
224 })
225 .await;
226
227 copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
228 let server = copilot.server.as_authenticated().log_err()?;
229 let buffer = server.registered_buffers.get_mut(&id)?;
230 if !content_changes.is_empty() {
231 buffer.snapshot_version += 1;
232 buffer.snapshot = new_snapshot;
233 server
234 .lsp
235 .notify::<lsp::notification::DidChangeTextDocument>(
236 lsp::DidChangeTextDocumentParams {
237 text_document: lsp::VersionedTextDocumentIdentifier::new(
238 buffer.uri.clone(),
239 buffer.snapshot_version,
240 ),
241 content_changes,
242 },
243 )
244 .log_err();
245 }
246 let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
247 Some(())
248 })?;
249
250 Some(())
251 });
252 }
253
254 done_rx
255 }
256}
257
258#[derive(Debug)]
259pub struct Completion {
260 uuid: String,
261 pub range: Range<Anchor>,
262 pub text: String,
263}
264
265pub struct Copilot {
266 http: Arc<dyn HttpClient>,
267 node_runtime: Arc<NodeRuntime>,
268 server: CopilotServer,
269 buffers: HashMap<usize, WeakModelHandle<Buffer>>,
270}
271
272impl Entity for Copilot {
273 type Event = ();
274}
275
276impl Copilot {
277 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
278 if cx.has_global::<ModelHandle<Self>>() {
279 Some(cx.global::<ModelHandle<Self>>().clone())
280 } else {
281 None
282 }
283 }
284
285 fn start(
286 http: Arc<dyn HttpClient>,
287 node_runtime: Arc<NodeRuntime>,
288 cx: &mut ModelContext<Self>,
289 ) -> Self {
290 cx.observe_global::<Settings, _>({
291 let http = http.clone();
292 let node_runtime = node_runtime.clone();
293 move |this, cx| {
294 if cx.global::<Settings>().features.copilot {
295 if matches!(this.server, CopilotServer::Disabled) {
296 let start_task = cx
297 .spawn({
298 let http = http.clone();
299 let node_runtime = node_runtime.clone();
300 move |this, cx| {
301 Self::start_language_server(http, node_runtime, this, cx)
302 }
303 })
304 .shared();
305 this.server = CopilotServer::Starting { task: start_task };
306 cx.notify();
307 }
308 } else {
309 this.server = CopilotServer::Disabled;
310 cx.notify();
311 }
312 }
313 })
314 .detach();
315
316 if cx.global::<Settings>().features.copilot {
317 let start_task = cx
318 .spawn({
319 let http = http.clone();
320 let node_runtime = node_runtime.clone();
321 move |this, cx| async {
322 Self::start_language_server(http, node_runtime, this, cx).await
323 }
324 })
325 .shared();
326
327 Self {
328 http,
329 node_runtime,
330 server: CopilotServer::Starting { task: start_task },
331 buffers: Default::default(),
332 }
333 } else {
334 Self {
335 http,
336 node_runtime,
337 server: CopilotServer::Disabled,
338 buffers: Default::default(),
339 }
340 }
341 }
342
343 #[cfg(any(test, feature = "test-support"))]
344 pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
345 let (server, fake_server) =
346 LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
347 let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
348 let this = cx.add_model(|cx| Self {
349 http: http.clone(),
350 node_runtime: NodeRuntime::new(http, cx.background().clone()),
351 server: CopilotServer::Running(RunningCopilotServer {
352 lsp: Arc::new(server),
353 sign_in_status: SignInStatus::Authorized,
354 registered_buffers: Default::default(),
355 }),
356 buffers: Default::default(),
357 });
358 (this, fake_server)
359 }
360
361 fn start_language_server(
362 http: Arc<dyn HttpClient>,
363 node_runtime: Arc<NodeRuntime>,
364 this: ModelHandle<Self>,
365 mut cx: AsyncAppContext,
366 ) -> impl Future<Output = ()> {
367 async move {
368 let start_language_server = async {
369 let server_path = get_copilot_lsp(http).await?;
370 let node_path = node_runtime.binary_path().await?;
371 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
372 let server = LanguageServer::new(
373 LanguageServerId(0),
374 &node_path,
375 arguments,
376 Path::new("/"),
377 None,
378 cx.clone(),
379 )?;
380
381 server
382 .on_notification::<LogMessage, _>(|params, _cx| {
383 match params.level {
384 // Copilot is pretty agressive about logging
385 0 => debug!("copilot: {}", params.message),
386 1 => debug!("copilot: {}", params.message),
387 _ => error!("copilot: {}", params.message),
388 }
389
390 debug!("copilot metadata: {}", params.metadata_str);
391 debug!("copilot extra: {:?}", params.extra);
392 })
393 .detach();
394
395 server
396 .on_notification::<StatusNotification, _>(
397 |_, _| { /* Silence the notification */ },
398 )
399 .detach();
400
401 let server = server.initialize(Default::default()).await?;
402
403 let status = server
404 .request::<request::CheckStatus>(request::CheckStatusParams {
405 local_checks_only: false,
406 })
407 .await?;
408
409 server
410 .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
411 editor_info: request::EditorInfo {
412 name: "zed".into(),
413 version: env!("CARGO_PKG_VERSION").into(),
414 },
415 editor_plugin_info: request::EditorPluginInfo {
416 name: "zed-copilot".into(),
417 version: "0.0.1".into(),
418 },
419 })
420 .await?;
421
422 anyhow::Ok((server, status))
423 };
424
425 let server = start_language_server.await;
426 this.update(&mut cx, |this, cx| {
427 cx.notify();
428 match server {
429 Ok((server, status)) => {
430 this.server = CopilotServer::Running(RunningCopilotServer {
431 lsp: server,
432 sign_in_status: SignInStatus::SignedOut,
433 registered_buffers: Default::default(),
434 });
435 this.update_sign_in_status(status, cx);
436 }
437 Err(error) => {
438 this.server = CopilotServer::Error(error.to_string().into());
439 cx.notify()
440 }
441 }
442 })
443 }
444 }
445
446 pub fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
447 if let CopilotServer::Running(server) = &mut self.server {
448 let task = match &server.sign_in_status {
449 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
450 Task::ready(Ok(())).shared()
451 }
452 SignInStatus::SigningIn { task, .. } => {
453 cx.notify();
454 task.clone()
455 }
456 SignInStatus::SignedOut => {
457 let lsp = server.lsp.clone();
458 let task = cx
459 .spawn(|this, mut cx| async move {
460 let sign_in = async {
461 let sign_in = lsp
462 .request::<request::SignInInitiate>(
463 request::SignInInitiateParams {},
464 )
465 .await?;
466 match sign_in {
467 request::SignInInitiateResult::AlreadySignedIn { user } => {
468 Ok(request::SignInStatus::Ok { user })
469 }
470 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
471 this.update(&mut cx, |this, cx| {
472 if let CopilotServer::Running(RunningCopilotServer {
473 sign_in_status: status,
474 ..
475 }) = &mut this.server
476 {
477 if let SignInStatus::SigningIn {
478 prompt: prompt_flow,
479 ..
480 } = status
481 {
482 *prompt_flow = Some(flow.clone());
483 cx.notify();
484 }
485 }
486 });
487 let response = lsp
488 .request::<request::SignInConfirm>(
489 request::SignInConfirmParams {
490 user_code: flow.user_code,
491 },
492 )
493 .await?;
494 Ok(response)
495 }
496 }
497 };
498
499 let sign_in = sign_in.await;
500 this.update(&mut cx, |this, cx| match sign_in {
501 Ok(status) => {
502 this.update_sign_in_status(status, cx);
503 Ok(())
504 }
505 Err(error) => {
506 this.update_sign_in_status(
507 request::SignInStatus::NotSignedIn,
508 cx,
509 );
510 Err(Arc::new(error))
511 }
512 })
513 })
514 .shared();
515 server.sign_in_status = SignInStatus::SigningIn {
516 prompt: None,
517 task: task.clone(),
518 };
519 cx.notify();
520 task
521 }
522 };
523
524 cx.foreground()
525 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
526 } else {
527 // If we're downloading, wait until download is finished
528 // If we're in a stuck state, display to the user
529 Task::ready(Err(anyhow!("copilot hasn't started yet")))
530 }
531 }
532
533 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
534 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
535 if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
536 let server = server.clone();
537 cx.background().spawn(async move {
538 server
539 .request::<request::SignOut>(request::SignOutParams {})
540 .await?;
541 anyhow::Ok(())
542 })
543 } else {
544 Task::ready(Err(anyhow!("copilot hasn't started yet")))
545 }
546 }
547
548 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
549 let start_task = cx
550 .spawn({
551 let http = self.http.clone();
552 let node_runtime = self.node_runtime.clone();
553 move |this, cx| async move {
554 clear_copilot_dir().await;
555 Self::start_language_server(http, node_runtime, this, cx).await
556 }
557 })
558 .shared();
559
560 self.server = CopilotServer::Starting {
561 task: start_task.clone(),
562 };
563
564 cx.notify();
565
566 cx.foreground().spawn(start_task)
567 }
568
569 pub fn register_buffer(&mut self, buffer: &ModelHandle<Buffer>, cx: &mut ModelContext<Self>) {
570 let buffer_id = buffer.id();
571 self.buffers.insert(buffer_id, buffer.downgrade());
572
573 if let CopilotServer::Running(RunningCopilotServer {
574 lsp: server,
575 sign_in_status: status,
576 registered_buffers,
577 ..
578 }) = &mut self.server
579 {
580 if !matches!(status, SignInStatus::Authorized { .. }) {
581 return;
582 }
583
584 registered_buffers.entry(buffer.id()).or_insert_with(|| {
585 let uri: lsp::Url = uri_for_buffer(buffer, cx);
586 let language_id = id_for_language(buffer.read(cx).language());
587 let snapshot = buffer.read(cx).snapshot();
588 server
589 .notify::<lsp::notification::DidOpenTextDocument>(
590 lsp::DidOpenTextDocumentParams {
591 text_document: lsp::TextDocumentItem {
592 uri: uri.clone(),
593 language_id: language_id.clone(),
594 version: 0,
595 text: snapshot.text(),
596 },
597 },
598 )
599 .log_err();
600
601 RegisteredBuffer {
602 id: buffer_id,
603 uri,
604 language_id,
605 snapshot,
606 snapshot_version: 0,
607 pending_buffer_change: Task::ready(Some(())),
608 _subscriptions: [
609 cx.subscribe(buffer, |this, buffer, event, cx| {
610 this.handle_buffer_event(buffer, event, cx).log_err();
611 }),
612 cx.observe_release(buffer, move |this, _buffer, _cx| {
613 this.buffers.remove(&buffer_id);
614 this.unregister_buffer(buffer_id);
615 }),
616 ],
617 }
618 });
619 }
620 }
621
622 fn handle_buffer_event(
623 &mut self,
624 buffer: ModelHandle<Buffer>,
625 event: &language::Event,
626 cx: &mut ModelContext<Self>,
627 ) -> Result<()> {
628 if let Ok(server) = self.server.as_running() {
629 if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.id()) {
630 match event {
631 language::Event::Edited => {
632 let _ = registered_buffer.report_changes(&buffer, cx);
633 }
634 language::Event::Saved => {
635 server
636 .lsp
637 .notify::<lsp::notification::DidSaveTextDocument>(
638 lsp::DidSaveTextDocumentParams {
639 text_document: lsp::TextDocumentIdentifier::new(
640 registered_buffer.uri.clone(),
641 ),
642 text: None,
643 },
644 )?;
645 }
646 language::Event::FileHandleChanged | language::Event::LanguageChanged => {
647 let new_language_id = id_for_language(buffer.read(cx).language());
648 let new_uri = uri_for_buffer(&buffer, cx);
649 if new_uri != registered_buffer.uri
650 || new_language_id != registered_buffer.language_id
651 {
652 let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
653 registered_buffer.language_id = new_language_id;
654 server
655 .lsp
656 .notify::<lsp::notification::DidCloseTextDocument>(
657 lsp::DidCloseTextDocumentParams {
658 text_document: lsp::TextDocumentIdentifier::new(old_uri),
659 },
660 )?;
661 server
662 .lsp
663 .notify::<lsp::notification::DidOpenTextDocument>(
664 lsp::DidOpenTextDocumentParams {
665 text_document: lsp::TextDocumentItem::new(
666 registered_buffer.uri.clone(),
667 registered_buffer.language_id.clone(),
668 registered_buffer.snapshot_version,
669 registered_buffer.snapshot.text(),
670 ),
671 },
672 )?;
673 }
674 }
675 _ => {}
676 }
677 }
678 }
679
680 Ok(())
681 }
682
683 fn unregister_buffer(&mut self, buffer_id: usize) {
684 if let Ok(server) = self.server.as_running() {
685 if let Some(buffer) = server.registered_buffers.remove(&buffer_id) {
686 server
687 .lsp
688 .notify::<lsp::notification::DidCloseTextDocument>(
689 lsp::DidCloseTextDocumentParams {
690 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
691 },
692 )
693 .log_err();
694 }
695 }
696 }
697
698 pub fn completions<T>(
699 &mut self,
700 buffer: &ModelHandle<Buffer>,
701 position: T,
702 cx: &mut ModelContext<Self>,
703 ) -> Task<Result<Vec<Completion>>>
704 where
705 T: ToPointUtf16,
706 {
707 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
708 }
709
710 pub fn completions_cycling<T>(
711 &mut self,
712 buffer: &ModelHandle<Buffer>,
713 position: T,
714 cx: &mut ModelContext<Self>,
715 ) -> Task<Result<Vec<Completion>>>
716 where
717 T: ToPointUtf16,
718 {
719 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
720 }
721
722 pub fn accept_completion(
723 &mut self,
724 completion: &Completion,
725 cx: &mut ModelContext<Self>,
726 ) -> Task<Result<()>> {
727 let server = match self.server.as_authenticated() {
728 Ok(server) => server,
729 Err(error) => return Task::ready(Err(error)),
730 };
731 let request =
732 server
733 .lsp
734 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
735 uuid: completion.uuid.clone(),
736 });
737 cx.background().spawn(async move {
738 request.await?;
739 Ok(())
740 })
741 }
742
743 pub fn discard_completions(
744 &mut self,
745 completions: &[Completion],
746 cx: &mut ModelContext<Self>,
747 ) -> Task<Result<()>> {
748 let server = match self.server.as_authenticated() {
749 Ok(server) => server,
750 Err(error) => return Task::ready(Err(error)),
751 };
752 let request =
753 server
754 .lsp
755 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
756 uuids: completions
757 .iter()
758 .map(|completion| completion.uuid.clone())
759 .collect(),
760 });
761 cx.background().spawn(async move {
762 request.await?;
763 Ok(())
764 })
765 }
766
767 fn request_completions<R, T>(
768 &mut self,
769 buffer: &ModelHandle<Buffer>,
770 position: T,
771 cx: &mut ModelContext<Self>,
772 ) -> Task<Result<Vec<Completion>>>
773 where
774 R: 'static
775 + lsp::request::Request<
776 Params = request::GetCompletionsParams,
777 Result = request::GetCompletionsResult,
778 >,
779 T: ToPointUtf16,
780 {
781 self.register_buffer(buffer, cx);
782
783 let server = match self.server.as_authenticated() {
784 Ok(server) => server,
785 Err(error) => return Task::ready(Err(error)),
786 };
787 let lsp = server.lsp.clone();
788 let registered_buffer = server.registered_buffers.get_mut(&buffer.id()).unwrap();
789 let snapshot = registered_buffer.report_changes(buffer, cx);
790 let buffer = buffer.read(cx);
791 let uri = registered_buffer.uri.clone();
792 let settings = cx.global::<Settings>();
793 let position = position.to_point_utf16(buffer);
794 let language = buffer.language_at(position);
795 let language_name = language.map(|language| language.name());
796 let language_name = language_name.as_deref();
797 let tab_size = settings.tab_size(language_name);
798 let hard_tabs = settings.hard_tabs(language_name);
799 let relative_path = buffer
800 .file()
801 .map(|file| file.path().to_path_buf())
802 .unwrap_or_default();
803
804 cx.foreground().spawn(async move {
805 let (version, snapshot) = snapshot.await?;
806 let result = lsp
807 .request::<R>(request::GetCompletionsParams {
808 doc: request::GetCompletionsDocument {
809 uri,
810 tab_size: tab_size.into(),
811 indent_size: 1,
812 insert_spaces: !hard_tabs,
813 relative_path: relative_path.to_string_lossy().into(),
814 position: point_to_lsp(position),
815 version: version.try_into().unwrap(),
816 },
817 })
818 .await?;
819 let completions = result
820 .completions
821 .into_iter()
822 .map(|completion| {
823 let start = snapshot
824 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
825 let end =
826 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
827 Completion {
828 uuid: completion.uuid,
829 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
830 text: completion.text,
831 }
832 })
833 .collect();
834 anyhow::Ok(completions)
835 })
836 }
837
838 pub fn status(&self) -> Status {
839 match &self.server {
840 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
841 CopilotServer::Disabled => Status::Disabled,
842 CopilotServer::Error(error) => Status::Error(error.clone()),
843 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
844 match sign_in_status {
845 SignInStatus::Authorized { .. } => Status::Authorized,
846 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
847 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
848 prompt: prompt.clone(),
849 },
850 SignInStatus::SignedOut => Status::SignedOut,
851 }
852 }
853 }
854 }
855
856 fn update_sign_in_status(
857 &mut self,
858 lsp_status: request::SignInStatus,
859 cx: &mut ModelContext<Self>,
860 ) {
861 self.buffers.retain(|_, buffer| buffer.is_upgradable(cx));
862
863 if let Ok(server) = self.server.as_running() {
864 match lsp_status {
865 request::SignInStatus::Ok { .. }
866 | request::SignInStatus::MaybeOk { .. }
867 | request::SignInStatus::AlreadySignedIn { .. } => {
868 server.sign_in_status = SignInStatus::Authorized;
869 for buffer in self.buffers.values().cloned().collect::<Vec<_>>() {
870 if let Some(buffer) = buffer.upgrade(cx) {
871 self.register_buffer(&buffer, cx);
872 }
873 }
874 }
875 request::SignInStatus::NotAuthorized { .. } => {
876 server.sign_in_status = SignInStatus::Unauthorized;
877 for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
878 self.unregister_buffer(buffer_id);
879 }
880 }
881 request::SignInStatus::NotSignedIn => {
882 server.sign_in_status = SignInStatus::SignedOut;
883 for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
884 self.unregister_buffer(buffer_id);
885 }
886 }
887 }
888
889 cx.notify();
890 }
891 }
892}
893
894fn id_for_language(language: Option<&Arc<Language>>) -> String {
895 let language_name = language.map(|language| language.name());
896 match language_name.as_deref() {
897 Some("Plain Text") => "plaintext".to_string(),
898 Some(language_name) => language_name.to_lowercase(),
899 None => "plaintext".to_string(),
900 }
901}
902
903fn uri_for_buffer(buffer: &ModelHandle<Buffer>, cx: &AppContext) -> lsp::Url {
904 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
905 lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
906 } else {
907 format!("buffer://{}", buffer.id()).parse().unwrap()
908 }
909}
910
911async fn clear_copilot_dir() {
912 remove_matching(&paths::COPILOT_DIR, |_| true).await
913}
914
915async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
916 const SERVER_PATH: &'static str = "dist/agent.js";
917
918 ///Check for the latest copilot language server and download it if we haven't already
919 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
920 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
921
922 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
923
924 fs::create_dir_all(version_dir).await?;
925 let server_path = version_dir.join(SERVER_PATH);
926
927 if fs::metadata(&server_path).await.is_err() {
928 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
929 let dist_dir = version_dir.join("dist");
930 fs::create_dir_all(dist_dir.as_path()).await?;
931
932 let url = &release
933 .assets
934 .get(0)
935 .context("Github release for copilot contained no assets")?
936 .browser_download_url;
937
938 let mut response = http
939 .get(&url, Default::default(), true)
940 .await
941 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
942 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
943 let archive = Archive::new(decompressed_bytes);
944 archive.unpack(dist_dir).await?;
945
946 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
947 }
948
949 Ok(server_path)
950 }
951
952 match fetch_latest(http).await {
953 ok @ Result::Ok(..) => ok,
954 e @ Err(..) => {
955 e.log_err();
956 // Fetch a cached binary, if it exists
957 (|| async move {
958 let mut last_version_dir = None;
959 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
960 while let Some(entry) = entries.next().await {
961 let entry = entry?;
962 if entry.file_type().await?.is_dir() {
963 last_version_dir = Some(entry.path());
964 }
965 }
966 let last_version_dir =
967 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
968 let server_path = last_version_dir.join(SERVER_PATH);
969 if server_path.exists() {
970 Ok(server_path)
971 } else {
972 Err(anyhow!(
973 "missing executable in directory {:?}",
974 last_version_dir
975 ))
976 }
977 })()
978 .await
979 }
980 }
981}
982
983#[cfg(test)]
984mod tests {
985 use super::*;
986 use gpui::{executor::Deterministic, TestAppContext};
987
988 #[gpui::test(iterations = 10)]
989 async fn test_buffer_management(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
990 deterministic.forbid_parking();
991 let (copilot, mut lsp) = Copilot::fake(cx);
992
993 let buffer_1 = cx.add_model(|cx| Buffer::new(0, "Hello", cx));
994 let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.id()).parse().unwrap();
995 copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
996 assert_eq!(
997 lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
998 .await,
999 lsp::DidOpenTextDocumentParams {
1000 text_document: lsp::TextDocumentItem::new(
1001 buffer_1_uri.clone(),
1002 "plaintext".into(),
1003 0,
1004 "Hello".into()
1005 ),
1006 }
1007 );
1008
1009 let buffer_2 = cx.add_model(|cx| Buffer::new(0, "Goodbye", cx));
1010 let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.id()).parse().unwrap();
1011 copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
1012 assert_eq!(
1013 lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1014 .await,
1015 lsp::DidOpenTextDocumentParams {
1016 text_document: lsp::TextDocumentItem::new(
1017 buffer_2_uri.clone(),
1018 "plaintext".into(),
1019 0,
1020 "Goodbye".into()
1021 ),
1022 }
1023 );
1024
1025 buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
1026 assert_eq!(
1027 lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
1028 .await,
1029 lsp::DidChangeTextDocumentParams {
1030 text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
1031 content_changes: vec![lsp::TextDocumentContentChangeEvent {
1032 range: Some(lsp::Range::new(
1033 lsp::Position::new(0, 5),
1034 lsp::Position::new(0, 5)
1035 )),
1036 range_length: None,
1037 text: " world".into(),
1038 }],
1039 }
1040 );
1041
1042 // Ensure updates to the file are reflected in the LSP.
1043 buffer_1
1044 .update(cx, |buffer, cx| {
1045 buffer.file_updated(
1046 Arc::new(File {
1047 abs_path: "/root/child/buffer-1".into(),
1048 path: Path::new("child/buffer-1").into(),
1049 }),
1050 cx,
1051 )
1052 })
1053 .await;
1054 assert_eq!(
1055 lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1056 .await,
1057 lsp::DidCloseTextDocumentParams {
1058 text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
1059 }
1060 );
1061 let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
1062 assert_eq!(
1063 lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1064 .await,
1065 lsp::DidOpenTextDocumentParams {
1066 text_document: lsp::TextDocumentItem::new(
1067 buffer_1_uri.clone(),
1068 "plaintext".into(),
1069 1,
1070 "Hello world".into()
1071 ),
1072 }
1073 );
1074
1075 // Ensure all previously-registered buffers are closed when signing out.
1076 lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
1077 Ok(request::SignOutResult {})
1078 });
1079 copilot
1080 .update(cx, |copilot, cx| copilot.sign_out(cx))
1081 .await
1082 .unwrap();
1083 assert_eq!(
1084 lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1085 .await,
1086 lsp::DidCloseTextDocumentParams {
1087 text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
1088 }
1089 );
1090 assert_eq!(
1091 lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1092 .await,
1093 lsp::DidCloseTextDocumentParams {
1094 text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
1095 }
1096 );
1097
1098 // Ensure all previously-registered buffers are re-opened when signing in.
1099 lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
1100 Ok(request::SignInInitiateResult::AlreadySignedIn {
1101 user: "user-1".into(),
1102 })
1103 });
1104 copilot
1105 .update(cx, |copilot, cx| copilot.sign_in(cx))
1106 .await
1107 .unwrap();
1108 assert_eq!(
1109 lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1110 .await,
1111 lsp::DidOpenTextDocumentParams {
1112 text_document: lsp::TextDocumentItem::new(
1113 buffer_2_uri.clone(),
1114 "plaintext".into(),
1115 0,
1116 "Goodbye".into()
1117 ),
1118 }
1119 );
1120 assert_eq!(
1121 lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
1122 .await,
1123 lsp::DidOpenTextDocumentParams {
1124 text_document: lsp::TextDocumentItem::new(
1125 buffer_1_uri.clone(),
1126 "plaintext".into(),
1127 0,
1128 "Hello world".into()
1129 ),
1130 }
1131 );
1132
1133 // Dropping a buffer causes it to be closed on the LSP side as well.
1134 cx.update(|_| drop(buffer_2));
1135 assert_eq!(
1136 lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
1137 .await,
1138 lsp::DidCloseTextDocumentParams {
1139 text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
1140 }
1141 );
1142 }
1143
1144 struct File {
1145 abs_path: PathBuf,
1146 path: Arc<Path>,
1147 }
1148
1149 impl language::File for File {
1150 fn as_local(&self) -> Option<&dyn language::LocalFile> {
1151 Some(self)
1152 }
1153
1154 fn mtime(&self) -> std::time::SystemTime {
1155 todo!()
1156 }
1157
1158 fn path(&self) -> &Arc<Path> {
1159 &self.path
1160 }
1161
1162 fn full_path(&self, _: &AppContext) -> PathBuf {
1163 todo!()
1164 }
1165
1166 fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
1167 todo!()
1168 }
1169
1170 fn is_deleted(&self) -> bool {
1171 todo!()
1172 }
1173
1174 fn as_any(&self) -> &dyn std::any::Any {
1175 todo!()
1176 }
1177
1178 fn to_proto(&self) -> rpc::proto::File {
1179 todo!()
1180 }
1181 }
1182
1183 impl language::LocalFile for File {
1184 fn abs_path(&self, _: &AppContext) -> PathBuf {
1185 self.abs_path.clone()
1186 }
1187
1188 fn load(&self, _: &AppContext) -> Task<Result<String>> {
1189 todo!()
1190 }
1191
1192 fn buffer_reloaded(
1193 &self,
1194 _: u64,
1195 _: &clock::Global,
1196 _: language::RopeFingerprint,
1197 _: ::fs::LineEnding,
1198 _: std::time::SystemTime,
1199 _: &mut AppContext,
1200 ) {
1201 todo!()
1202 }
1203 }
1204}