1pub mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use collections::HashMap;
8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{
10 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
11};
12use language::{
13 point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
14 ToPointUtf16,
15};
16use log::{debug, error};
17use lsp::LanguageServer;
18use node_runtime::NodeRuntime;
19use request::{LogMessage, StatusNotification};
20use settings::Settings;
21use smol::{fs, io::BufReader, stream::StreamExt};
22use std::{
23 ffi::OsString,
24 mem,
25 ops::Range,
26 path::{Path, PathBuf},
27 sync::Arc,
28};
29use util::{
30 channel::ReleaseChannel, fs::remove_matching, github::latest_github_release, http::HttpClient,
31 paths, ResultExt,
32};
33
34const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
35actions!(copilot_auth, [SignIn, SignOut]);
36
37const COPILOT_NAMESPACE: &'static str = "copilot";
38actions!(
39 copilot,
40 [Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
41);
42
43pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
44 // Disable Copilot for stable releases.
45 if *cx.global::<ReleaseChannel>() == ReleaseChannel::Stable {
46 cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
47 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
48 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
49 });
50 return;
51 }
52
53 let copilot = cx.add_model({
54 let node_runtime = node_runtime.clone();
55 move |cx| Copilot::start(http, node_runtime, cx)
56 });
57 cx.set_global(copilot.clone());
58
59 cx.observe(&copilot, |handle, cx| {
60 let status = handle.read(cx).status();
61 cx.update_global::<collections::CommandPaletteFilter, _, _>(
62 move |filter, _cx| match status {
63 Status::Disabled => {
64 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
65 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
66 }
67 Status::Authorized => {
68 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
69 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
70 }
71 _ => {
72 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
73 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
74 }
75 },
76 );
77 })
78 .detach();
79
80 sign_in::init(cx);
81 cx.add_global_action(|_: &SignIn, cx| {
82 if let Some(copilot) = Copilot::global(cx) {
83 copilot
84 .update(cx, |copilot, cx| copilot.sign_in(cx))
85 .detach_and_log_err(cx);
86 }
87 });
88 cx.add_global_action(|_: &SignOut, cx| {
89 if let Some(copilot) = Copilot::global(cx) {
90 copilot
91 .update(cx, |copilot, cx| copilot.sign_out(cx))
92 .detach_and_log_err(cx);
93 }
94 });
95
96 cx.add_global_action(|_: &Reinstall, cx| {
97 if let Some(copilot) = Copilot::global(cx) {
98 copilot
99 .update(cx, |copilot, cx| copilot.reinstall(cx))
100 .detach();
101 }
102 });
103}
104
/// Lifecycle state of the Copilot language server.
enum CopilotServer {
    /// The `copilot` feature is turned off in settings; no server is kept.
    Disabled,
    /// The server is being downloaded and launched. `task` resolves once the
    /// startup attempt has finished, whether it succeeded or failed.
    Starting { task: Shared<Task<()>> },
    /// Startup failed; holds the rendered error message.
    Error(Arc<str>),
    /// The server is up and can handle requests.
    Running(RunningCopilotServer),
}
111
112impl CopilotServer {
113 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
114 let server = self.as_running()?;
115 if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
116 Ok(server)
117 } else {
118 Err(anyhow!("must sign in before using copilot"))
119 }
120 }
121
122 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
123 match self {
124 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
125 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
126 CopilotServer::Error(error) => Err(anyhow!(
127 "copilot was not started because of an error: {}",
128 error
129 )),
130 CopilotServer::Running(server) => Ok(server),
131 }
132 }
133}
134
/// State that only exists while the language server is running.
struct RunningCopilotServer {
    /// Handle to the running Copilot language server.
    lsp: Arc<LanguageServer>,
    /// The user's sign-in state as last reported by the server.
    sign_in_status: SignInStatus,
    /// Buffers currently opened on the server, keyed by buffer model id.
    registered_buffers: HashMap<usize, RegisteredBuffer>,
}
140
/// Sign-in state tracked for a running server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Signed in. Tracked buffers are opened on the server and completions can
    /// be requested.
    Authorized,
    /// The server reported the user as not authorized.
    Unauthorized,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt to show the user, once the server has provided one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// The in-flight sign-in. It is shared so that repeated sign-in requests
        /// await the same flow.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Not signed in.
    SignedOut,
}
151
/// Public summary of Copilot's state, as returned by `Copilot::status`.
#[derive(Debug, Clone)]
pub enum Status {
    /// The server is starting; `task` resolves when the startup attempt ends.
    Starting {
        task: Shared<Task<()>>,
    },
    /// The server failed to start; holds the error message.
    Error(Arc<str>),
    /// Copilot is turned off in settings.
    Disabled,
    /// The server is running but the user is not signed in.
    SignedOut,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt to show the user, if the server has sent one yet.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// The server reported the user as not authorized.
    Unauthorized,
    /// Signed in and ready to provide completions.
    Authorized,
}
166
167impl Status {
168 pub fn is_authorized(&self) -> bool {
169 matches!(self, Status::Authorized)
170 }
171}
172
/// Bookkeeping for a buffer that has been opened on the Copilot server.
struct RegisteredBuffer {
    /// URI the buffer is opened under: a file URL for local files, otherwise
    /// `buffer://<id>`.
    uri: lsp::Url,
    /// Language identifier reported to the server.
    language_id: String,
    /// Buffer contents as last reported to the server.
    snapshot: BufferSnapshot,
    /// Version sent with the last change notification. It is 0 when the buffer
    /// is first opened.
    snapshot_version: i32,
    /// Buffer-event and buffer-release subscriptions, kept alive for as long as
    /// the buffer stays registered.
    _subscriptions: [gpui::Subscription; 2],
}
180
181impl RegisteredBuffer {
182 fn report_changes(
183 &mut self,
184 buffer: &ModelHandle<Buffer>,
185 server: &LanguageServer,
186 cx: &AppContext,
187 ) -> Result<()> {
188 let buffer = buffer.read(cx);
189 let new_snapshot = buffer.snapshot();
190 let content_changes = buffer
191 .edits_since::<(PointUtf16, usize)>(self.snapshot.version())
192 .map(|edit| {
193 let edit_start = edit.new.start.0;
194 let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
195 let new_text = new_snapshot
196 .text_for_range(edit.new.start.1..edit.new.end.1)
197 .collect();
198 lsp::TextDocumentContentChangeEvent {
199 range: Some(lsp::Range::new(
200 point_to_lsp(edit_start),
201 point_to_lsp(edit_end),
202 )),
203 range_length: None,
204 text: new_text,
205 }
206 })
207 .collect::<Vec<_>>();
208
209 if !content_changes.is_empty() {
210 self.snapshot_version += 1;
211 self.snapshot = new_snapshot;
212 server.notify::<lsp::notification::DidChangeTextDocument>(
213 lsp::DidChangeTextDocumentParams {
214 text_document: lsp::VersionedTextDocumentIdentifier::new(
215 self.uri.clone(),
216 self.snapshot_version,
217 ),
218 content_changes,
219 },
220 )?;
221 }
222
223 Ok(())
224 }
225}
226
/// A completion suggested by Copilot.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned id, echoed back when the completion is accepted or rejected.
    uuid: String,
    /// Buffer range the completion applies to.
    pub range: Range<Anchor>,
    /// Text of the completion.
    pub text: String,
}
233
/// Global model that owns the Copilot language server and its sign-in state.
pub struct Copilot {
    /// HTTP client used to download the language server.
    http: Arc<dyn HttpClient>,
    /// Node runtime used to launch the language server.
    node_runtime: Arc<NodeRuntime>,
    /// Current lifecycle state of the server.
    server: CopilotServer,
    /// Every buffer passed to `register_buffer`, keyed by id.
    ///
    /// This includes buffers not yet opened on the server (for example because
    /// the user isn't authorized). Those are opened when sign-in succeeds.
    buffers: HashMap<usize, WeakModelHandle<Buffer>>,
}
240
// Copilot emits no events of its own; observers react to `cx.notify()` instead.
impl Entity for Copilot {
    type Event = ();
}
244
245impl Copilot {
246 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
247 if cx.has_global::<ModelHandle<Self>>() {
248 Some(cx.global::<ModelHandle<Self>>().clone())
249 } else {
250 None
251 }
252 }
253
254 fn start(
255 http: Arc<dyn HttpClient>,
256 node_runtime: Arc<NodeRuntime>,
257 cx: &mut ModelContext<Self>,
258 ) -> Self {
259 cx.observe_global::<Settings, _>({
260 let http = http.clone();
261 let node_runtime = node_runtime.clone();
262 move |this, cx| {
263 if cx.global::<Settings>().features.copilot {
264 if matches!(this.server, CopilotServer::Disabled) {
265 let start_task = cx
266 .spawn({
267 let http = http.clone();
268 let node_runtime = node_runtime.clone();
269 move |this, cx| {
270 Self::start_language_server(http, node_runtime, this, cx)
271 }
272 })
273 .shared();
274 this.server = CopilotServer::Starting { task: start_task };
275 cx.notify();
276 }
277 } else {
278 this.server = CopilotServer::Disabled;
279 cx.notify();
280 }
281 }
282 })
283 .detach();
284
285 if cx.global::<Settings>().features.copilot {
286 let start_task = cx
287 .spawn({
288 let http = http.clone();
289 let node_runtime = node_runtime.clone();
290 move |this, cx| async {
291 Self::start_language_server(http, node_runtime, this, cx).await
292 }
293 })
294 .shared();
295
296 Self {
297 http,
298 node_runtime,
299 server: CopilotServer::Starting { task: start_task },
300 buffers: Default::default(),
301 }
302 } else {
303 Self {
304 http,
305 node_runtime,
306 server: CopilotServer::Disabled,
307 buffers: Default::default(),
308 }
309 }
310 }
311
312 #[cfg(any(test, feature = "test-support"))]
313 pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
314 let (server, fake_server) =
315 LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
316 let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
317 let this = cx.add_model(|cx| Self {
318 http: http.clone(),
319 node_runtime: NodeRuntime::new(http, cx.background().clone()),
320 server: CopilotServer::Running(RunningCopilotServer {
321 lsp: Arc::new(server),
322 sign_in_status: SignInStatus::Authorized,
323 registered_buffers: Default::default(),
324 }),
325 buffers: Default::default(),
326 });
327 (this, fake_server)
328 }
329
330 fn start_language_server(
331 http: Arc<dyn HttpClient>,
332 node_runtime: Arc<NodeRuntime>,
333 this: ModelHandle<Self>,
334 mut cx: AsyncAppContext,
335 ) -> impl Future<Output = ()> {
336 async move {
337 let start_language_server = async {
338 let server_path = get_copilot_lsp(http).await?;
339 let node_path = node_runtime.binary_path().await?;
340 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
341 let server = LanguageServer::new(
342 0,
343 &node_path,
344 arguments,
345 Path::new("/"),
346 None,
347 cx.clone(),
348 )?;
349
350 let server = server.initialize(Default::default()).await?;
351 let status = server
352 .request::<request::CheckStatus>(request::CheckStatusParams {
353 local_checks_only: false,
354 })
355 .await?;
356
357 server
358 .on_notification::<LogMessage, _>(|params, _cx| {
359 match params.level {
360 // Copilot is pretty agressive about logging
361 0 => debug!("copilot: {}", params.message),
362 1 => debug!("copilot: {}", params.message),
363 _ => error!("copilot: {}", params.message),
364 }
365
366 debug!("copilot metadata: {}", params.metadata_str);
367 debug!("copilot extra: {:?}", params.extra);
368 })
369 .detach();
370
371 server
372 .on_notification::<StatusNotification, _>(
373 |_, _| { /* Silence the notification */ },
374 )
375 .detach();
376
377 server
378 .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
379 editor_info: request::EditorInfo {
380 name: "zed".into(),
381 version: env!("CARGO_PKG_VERSION").into(),
382 },
383 editor_plugin_info: request::EditorPluginInfo {
384 name: "zed-copilot".into(),
385 version: "0.0.1".into(),
386 },
387 })
388 .await?;
389
390 anyhow::Ok((server, status))
391 };
392
393 let server = start_language_server.await;
394 this.update(&mut cx, |this, cx| {
395 cx.notify();
396 match server {
397 Ok((server, status)) => {
398 this.server = CopilotServer::Running(RunningCopilotServer {
399 lsp: server,
400 sign_in_status: SignInStatus::SignedOut,
401 registered_buffers: Default::default(),
402 });
403 this.update_sign_in_status(status, cx);
404 }
405 Err(error) => {
406 this.server = CopilotServer::Error(error.to_string().into());
407 cx.notify()
408 }
409 }
410 })
411 }
412 }
413
414 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
415 if let CopilotServer::Running(server) = &mut self.server {
416 let task = match &server.sign_in_status {
417 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
418 Task::ready(Ok(())).shared()
419 }
420 SignInStatus::SigningIn { task, .. } => {
421 cx.notify();
422 task.clone()
423 }
424 SignInStatus::SignedOut => {
425 let lsp = server.lsp.clone();
426 let task = cx
427 .spawn(|this, mut cx| async move {
428 let sign_in = async {
429 let sign_in = lsp
430 .request::<request::SignInInitiate>(
431 request::SignInInitiateParams {},
432 )
433 .await?;
434 match sign_in {
435 request::SignInInitiateResult::AlreadySignedIn { user } => {
436 Ok(request::SignInStatus::Ok { user })
437 }
438 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
439 this.update(&mut cx, |this, cx| {
440 if let CopilotServer::Running(RunningCopilotServer {
441 sign_in_status: status,
442 ..
443 }) = &mut this.server
444 {
445 if let SignInStatus::SigningIn {
446 prompt: prompt_flow,
447 ..
448 } = status
449 {
450 *prompt_flow = Some(flow.clone());
451 cx.notify();
452 }
453 }
454 });
455 let response = lsp
456 .request::<request::SignInConfirm>(
457 request::SignInConfirmParams {
458 user_code: flow.user_code,
459 },
460 )
461 .await?;
462 Ok(response)
463 }
464 }
465 };
466
467 let sign_in = sign_in.await;
468 this.update(&mut cx, |this, cx| match sign_in {
469 Ok(status) => {
470 this.update_sign_in_status(status, cx);
471 Ok(())
472 }
473 Err(error) => {
474 this.update_sign_in_status(
475 request::SignInStatus::NotSignedIn,
476 cx,
477 );
478 Err(Arc::new(error))
479 }
480 })
481 })
482 .shared();
483 server.sign_in_status = SignInStatus::SigningIn {
484 prompt: None,
485 task: task.clone(),
486 };
487 cx.notify();
488 task
489 }
490 };
491
492 cx.foreground()
493 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
494 } else {
495 // If we're downloading, wait until download is finished
496 // If we're in a stuck state, display to the user
497 Task::ready(Err(anyhow!("copilot hasn't started yet")))
498 }
499 }
500
501 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
502 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
503 if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
504 let server = server.clone();
505 cx.background().spawn(async move {
506 server
507 .request::<request::SignOut>(request::SignOutParams {})
508 .await?;
509 anyhow::Ok(())
510 })
511 } else {
512 Task::ready(Err(anyhow!("copilot hasn't started yet")))
513 }
514 }
515
516 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
517 let start_task = cx
518 .spawn({
519 let http = self.http.clone();
520 let node_runtime = self.node_runtime.clone();
521 move |this, cx| async move {
522 clear_copilot_dir().await;
523 Self::start_language_server(http, node_runtime, this, cx).await
524 }
525 })
526 .shared();
527
528 self.server = CopilotServer::Starting {
529 task: start_task.clone(),
530 };
531
532 cx.notify();
533
534 cx.foreground().spawn(start_task)
535 }
536
537 pub fn register_buffer(&mut self, buffer: &ModelHandle<Buffer>, cx: &mut ModelContext<Self>) {
538 let buffer_id = buffer.id();
539 self.buffers.insert(buffer_id, buffer.downgrade());
540
541 if let CopilotServer::Running(RunningCopilotServer {
542 lsp: server,
543 sign_in_status: status,
544 registered_buffers,
545 ..
546 }) = &mut self.server
547 {
548 if !matches!(status, SignInStatus::Authorized { .. }) {
549 return;
550 }
551
552 registered_buffers.entry(buffer.id()).or_insert_with(|| {
553 let uri: lsp::Url = uri_for_buffer(buffer, cx);
554 let language_id = id_for_language(buffer.read(cx).language());
555 let snapshot = buffer.read(cx).snapshot();
556 server
557 .notify::<lsp::notification::DidOpenTextDocument>(
558 lsp::DidOpenTextDocumentParams {
559 text_document: lsp::TextDocumentItem {
560 uri: uri.clone(),
561 language_id: language_id.clone(),
562 version: 0,
563 text: snapshot.text(),
564 },
565 },
566 )
567 .log_err();
568
569 RegisteredBuffer {
570 uri,
571 language_id,
572 snapshot,
573 snapshot_version: 0,
574 _subscriptions: [
575 cx.subscribe(buffer, |this, buffer, event, cx| {
576 this.handle_buffer_event(buffer, event, cx).log_err();
577 }),
578 cx.observe_release(buffer, move |this, _buffer, _cx| {
579 this.buffers.remove(&buffer_id);
580 this.unregister_buffer(buffer_id);
581 }),
582 ],
583 }
584 });
585 }
586 }
587
588 fn handle_buffer_event(
589 &mut self,
590 buffer: ModelHandle<Buffer>,
591 event: &language::Event,
592 cx: &mut ModelContext<Self>,
593 ) -> Result<()> {
594 if let Ok(server) = self.server.as_running() {
595 if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.id()) {
596 match event {
597 language::Event::Edited => {
598 registered_buffer.report_changes(&buffer, &server.lsp, cx)?;
599 }
600 language::Event::Saved => {
601 server
602 .lsp
603 .notify::<lsp::notification::DidSaveTextDocument>(
604 lsp::DidSaveTextDocumentParams {
605 text_document: lsp::TextDocumentIdentifier::new(
606 registered_buffer.uri.clone(),
607 ),
608 text: None,
609 },
610 )?;
611 }
612 language::Event::FileHandleChanged | language::Event::LanguageChanged => {
613 let new_language_id = id_for_language(buffer.read(cx).language());
614 let new_uri = uri_for_buffer(&buffer, cx);
615 if new_uri != registered_buffer.uri
616 || new_language_id != registered_buffer.language_id
617 {
618 let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
619 registered_buffer.language_id = new_language_id;
620 server
621 .lsp
622 .notify::<lsp::notification::DidCloseTextDocument>(
623 lsp::DidCloseTextDocumentParams {
624 text_document: lsp::TextDocumentIdentifier::new(old_uri),
625 },
626 )?;
627 server
628 .lsp
629 .notify::<lsp::notification::DidOpenTextDocument>(
630 lsp::DidOpenTextDocumentParams {
631 text_document: lsp::TextDocumentItem::new(
632 registered_buffer.uri.clone(),
633 registered_buffer.language_id.clone(),
634 registered_buffer.snapshot_version,
635 registered_buffer.snapshot.text(),
636 ),
637 },
638 )?;
639 }
640 }
641 _ => {}
642 }
643 }
644 }
645
646 Ok(())
647 }
648
649 fn unregister_buffer(&mut self, buffer_id: usize) {
650 if let Ok(server) = self.server.as_running() {
651 if let Some(buffer) = server.registered_buffers.remove(&buffer_id) {
652 server
653 .lsp
654 .notify::<lsp::notification::DidCloseTextDocument>(
655 lsp::DidCloseTextDocumentParams {
656 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
657 },
658 )
659 .log_err();
660 }
661 }
662 }
663
664 pub fn completions<T>(
665 &mut self,
666 buffer: &ModelHandle<Buffer>,
667 position: T,
668 cx: &mut ModelContext<Self>,
669 ) -> Task<Result<Vec<Completion>>>
670 where
671 T: ToPointUtf16,
672 {
673 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
674 }
675
676 pub fn completions_cycling<T>(
677 &mut self,
678 buffer: &ModelHandle<Buffer>,
679 position: T,
680 cx: &mut ModelContext<Self>,
681 ) -> Task<Result<Vec<Completion>>>
682 where
683 T: ToPointUtf16,
684 {
685 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
686 }
687
688 pub fn accept_completion(
689 &mut self,
690 completion: &Completion,
691 cx: &mut ModelContext<Self>,
692 ) -> Task<Result<()>> {
693 let server = match self.server.as_authenticated() {
694 Ok(server) => server,
695 Err(error) => return Task::ready(Err(error)),
696 };
697 let request =
698 server
699 .lsp
700 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
701 uuid: completion.uuid.clone(),
702 });
703 cx.background().spawn(async move {
704 request.await?;
705 Ok(())
706 })
707 }
708
709 pub fn discard_completions(
710 &mut self,
711 completions: &[Completion],
712 cx: &mut ModelContext<Self>,
713 ) -> Task<Result<()>> {
714 let server = match self.server.as_authenticated() {
715 Ok(server) => server,
716 Err(error) => return Task::ready(Err(error)),
717 };
718 let request =
719 server
720 .lsp
721 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
722 uuids: completions
723 .iter()
724 .map(|completion| completion.uuid.clone())
725 .collect(),
726 });
727 cx.background().spawn(async move {
728 request.await?;
729 Ok(())
730 })
731 }
732
733 fn request_completions<R, T>(
734 &mut self,
735 buffer: &ModelHandle<Buffer>,
736 position: T,
737 cx: &mut ModelContext<Self>,
738 ) -> Task<Result<Vec<Completion>>>
739 where
740 R: 'static
741 + lsp::request::Request<
742 Params = request::GetCompletionsParams,
743 Result = request::GetCompletionsResult,
744 >,
745 T: ToPointUtf16,
746 {
747 self.register_buffer(buffer, cx);
748
749 let server = match self.server.as_authenticated() {
750 Ok(server) => server,
751 Err(error) => return Task::ready(Err(error)),
752 };
753 let registered_buffer = server.registered_buffers.get_mut(&buffer.id()).unwrap();
754 if let Err(error) = registered_buffer.report_changes(buffer, &server.lsp, cx) {
755 return Task::ready(Err(error));
756 }
757
758 let uri = registered_buffer.uri.clone();
759 let snapshot = registered_buffer.snapshot.clone();
760 let version = registered_buffer.snapshot_version;
761 let settings = cx.global::<Settings>();
762 let position = position.to_point_utf16(&snapshot);
763 let language = snapshot.language_at(position);
764 let language_name = language.map(|language| language.name());
765 let language_name = language_name.as_deref();
766 let tab_size = settings.tab_size(language_name);
767 let hard_tabs = settings.hard_tabs(language_name);
768 let relative_path = snapshot
769 .file()
770 .map(|file| file.path().to_path_buf())
771 .unwrap_or_default();
772 let request = server.lsp.request::<R>(request::GetCompletionsParams {
773 doc: request::GetCompletionsDocument {
774 uri,
775 tab_size: tab_size.into(),
776 indent_size: 1,
777 insert_spaces: !hard_tabs,
778 relative_path: relative_path.to_string_lossy().into(),
779 position: point_to_lsp(position),
780 version: version.try_into().unwrap(),
781 },
782 });
783 cx.background().spawn(async move {
784 let result = request.await?;
785 let completions = result
786 .completions
787 .into_iter()
788 .map(|completion| {
789 let start = snapshot
790 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
791 let end =
792 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
793 Completion {
794 uuid: completion.uuid,
795 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
796 text: completion.text,
797 }
798 })
799 .collect();
800 anyhow::Ok(completions)
801 })
802 }
803
804 pub fn status(&self) -> Status {
805 match &self.server {
806 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
807 CopilotServer::Disabled => Status::Disabled,
808 CopilotServer::Error(error) => Status::Error(error.clone()),
809 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
810 match sign_in_status {
811 SignInStatus::Authorized { .. } => Status::Authorized,
812 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
813 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
814 prompt: prompt.clone(),
815 },
816 SignInStatus::SignedOut => Status::SignedOut,
817 }
818 }
819 }
820 }
821
822 fn update_sign_in_status(
823 &mut self,
824 lsp_status: request::SignInStatus,
825 cx: &mut ModelContext<Self>,
826 ) {
827 self.buffers.retain(|_, buffer| buffer.is_upgradable(cx));
828
829 if let Ok(server) = self.server.as_running() {
830 match lsp_status {
831 request::SignInStatus::Ok { .. }
832 | request::SignInStatus::MaybeOk { .. }
833 | request::SignInStatus::AlreadySignedIn { .. } => {
834 server.sign_in_status = SignInStatus::Authorized;
835 for buffer in self.buffers.values().cloned().collect::<Vec<_>>() {
836 if let Some(buffer) = buffer.upgrade(cx) {
837 self.register_buffer(&buffer, cx);
838 }
839 }
840 }
841 request::SignInStatus::NotAuthorized { .. } => {
842 server.sign_in_status = SignInStatus::Unauthorized;
843 for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
844 self.unregister_buffer(buffer_id);
845 }
846 }
847 request::SignInStatus::NotSignedIn => {
848 server.sign_in_status = SignInStatus::SignedOut;
849 for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
850 self.unregister_buffer(buffer_id);
851 }
852 }
853 }
854
855 cx.notify();
856 }
857 }
858}
859
860fn id_for_language(language: Option<&Arc<Language>>) -> String {
861 let language_name = language.map(|language| language.name());
862 match language_name.as_deref() {
863 Some("Plain Text") => "plaintext".to_string(),
864 Some(language_name) => language_name.to_lowercase(),
865 None => "plaintext".to_string(),
866 }
867}
868
869fn uri_for_buffer(buffer: &ModelHandle<Buffer>, cx: &AppContext) -> lsp::Url {
870 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
871 lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
872 } else {
873 format!("buffer://{}", buffer.id()).parse().unwrap()
874 }
875}
876
/// Removes every entry under `paths::COPILOT_DIR`; the predicate accepts all
/// entries.
///
/// This is used by `Copilot::reinstall` before starting the server again, so
/// the server files are fetched anew.
async fn clear_copilot_dir() {
    remove_matching(&paths::COPILOT_DIR, |_| true).await
}
880
881async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
882 const SERVER_PATH: &'static str = "dist/agent.js";
883
884 ///Check for the latest copilot language server and download it if we haven't already
885 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
886 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
887
888 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
889
890 fs::create_dir_all(version_dir).await?;
891 let server_path = version_dir.join(SERVER_PATH);
892
893 if fs::metadata(&server_path).await.is_err() {
894 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
895 let dist_dir = version_dir.join("dist");
896 fs::create_dir_all(dist_dir.as_path()).await?;
897
898 let url = &release
899 .assets
900 .get(0)
901 .context("Github release for copilot contained no assets")?
902 .browser_download_url;
903
904 let mut response = http
905 .get(&url, Default::default(), true)
906 .await
907 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
908 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
909 let archive = Archive::new(decompressed_bytes);
910 archive.unpack(dist_dir).await?;
911
912 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
913 }
914
915 Ok(server_path)
916 }
917
918 match fetch_latest(http).await {
919 ok @ Result::Ok(..) => ok,
920 e @ Err(..) => {
921 e.log_err();
922 // Fetch a cached binary, if it exists
923 (|| async move {
924 let mut last_version_dir = None;
925 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
926 while let Some(entry) = entries.next().await {
927 let entry = entry?;
928 if entry.file_type().await?.is_dir() {
929 last_version_dir = Some(entry.path());
930 }
931 }
932 let last_version_dir =
933 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
934 let server_path = last_version_dir.join(SERVER_PATH);
935 if server_path.exists() {
936 Ok(server_path)
937 } else {
938 Err(anyhow!(
939 "missing executable in directory {:?}",
940 last_version_dir
941 ))
942 }
943 })()
944 .await
945 }
946 }
947}
948
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{executor::Deterministic, TestAppContext};

    /// Exercises the buffer lifecycle against a fake, already-authorized server.
    ///
    /// It checks, in order:
    /// 1. Registering a buffer opens it.
    /// 2. Edits produce change notifications.
    /// 3. A file change closes and reopens the buffer under its new URI.
    /// 4. Signing out closes every buffer.
    /// 5. Signing in reopens them.
    /// 6. Releasing a buffer closes it.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();
        let (copilot, mut lsp) = Copilot::fake(cx);

        // A buffer without a file is opened under a synthetic `buffer://` URI, at version 0.
        let buffer_1 = cx.add_model(|cx| Buffer::new(0, "Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.id()).parse().unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.add_model(|cx| Buffer::new(0, "Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.id()).parse().unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is reported as an incremental change and bumps the version to 1.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        // The buffer is closed under its old URI and reopened under the file URL,
        // keeping its current version (1).
        buffer_1
            .update(cx, |buffer, cx| {
                buffer.file_updated(
                    Arc::new(File {
                        abs_path: "/root/child/buffer-1".into(),
                        path: Path::new("child/buffer-1").into(),
                    }),
                    cx,
                )
            })
            .await;
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        // The expected close order below follows the iteration order of `Copilot::buffers`.
        lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        // Reopened buffers start again at version 0 with their current text.
        lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );

        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    /// Minimal local file used to give a test buffer an on-disk path.
    ///
    /// Only the path-related methods are implemented; the rest are `todo!()`
    /// because the test never calls them.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn mtime(&self) -> std::time::SystemTime {
            todo!()
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &AppContext) -> PathBuf {
            todo!()
        }

        fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
            todo!()
        }

        fn is_deleted(&self) -> bool {
            todo!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            todo!()
        }

        fn to_proto(&self) -> rpc::proto::File {
            todo!()
        }
    }

    impl language::LocalFile for File {
        // Used by `uri_for_buffer` to build the buffer's file URL.
        fn abs_path(&self, _: &AppContext) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &AppContext) -> Task<Result<String>> {
            todo!()
        }

        fn buffer_reloaded(
            &self,
            _: u64,
            _: &clock::Global,
            _: language::RopeFingerprint,
            _: ::fs::LineEnding,
            _: std::time::SystemTime,
            _: &mut AppContext,
        ) {
            todo!()
        }
    }
}