1pub mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use client::{ClickhouseEvent, Client};
8use collections::HashMap;
9use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
10use gpui::{
11 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
12};
13use language::{
14 point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
15 ToPointUtf16,
16};
17use log::{debug, error};
18use lsp::{LanguageServer, LanguageServerId};
19use node_runtime::NodeRuntime;
20use request::{LogMessage, StatusNotification};
21use settings::Settings;
22use smol::{fs, io::BufReader, stream::StreamExt};
23use std::{
24 ffi::OsString,
25 mem,
26 ops::Range,
27 path::{Path, PathBuf},
28 pin::Pin,
29 sync::Arc,
30};
31use util::{
32 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
33};
34
35const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
36actions!(copilot_auth, [SignIn, SignOut]);
37
38const COPILOT_NAMESPACE: &'static str = "copilot";
39actions!(
40 copilot,
41 [Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
42);
43
44pub fn init(client: &Client, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
45 let copilot = cx.add_model({
46 let node_runtime = node_runtime.clone();
47 move |cx| Copilot::start(client.http_client(), node_runtime, cx)
48 });
49 cx.set_global(copilot.clone());
50
51 let telemetry_settings = cx.global::<Settings>().telemetry();
52 let telemetry = client.telemetry();
53
54 cx.subscribe(&copilot, move |_, event, _| match event {
55 Event::CompletionAccepted { uuid, file_type } => {
56 let event = ClickhouseEvent::Copilot {
57 suggestion_id: uuid.clone(),
58 suggestion_accepted: true,
59 file_extension: file_type.clone().map(|a| a.to_string()),
60 };
61 telemetry.report_clickhouse_event(event, telemetry_settings);
62 }
63 Event::CompletionsDiscarded { uuids, file_type } => {
64 for uuid in uuids {
65 let event = ClickhouseEvent::Copilot {
66 suggestion_id: uuid.clone(),
67 suggestion_accepted: false,
68 file_extension: file_type.clone().map(|a| a.to_string()),
69 };
70 telemetry.report_clickhouse_event(event, telemetry_settings);
71 }
72 }
73 })
74 .detach();
75
76 cx.observe(&copilot, |handle, cx| {
77 let status = handle.read(cx).status();
78 cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
79 match status {
80 Status::Disabled => {
81 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
82 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
83 }
84 Status::Authorized => {
85 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
86 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
87 }
88 _ => {
89 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
90 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
91 }
92 }
93 });
94 })
95 .detach();
96
97 sign_in::init(cx);
98 cx.add_global_action(|_: &SignIn, cx| {
99 if let Some(copilot) = Copilot::global(cx) {
100 copilot
101 .update(cx, |copilot, cx| copilot.sign_in(cx))
102 .detach_and_log_err(cx);
103 }
104 });
105 cx.add_global_action(|_: &SignOut, cx| {
106 if let Some(copilot) = Copilot::global(cx) {
107 copilot
108 .update(cx, |copilot, cx| copilot.sign_out(cx))
109 .detach_and_log_err(cx);
110 }
111 });
112
113 cx.add_global_action(|_: &Reinstall, cx| {
114 if let Some(copilot) = Copilot::global(cx) {
115 copilot
116 .update(cx, |copilot, cx| copilot.reinstall(cx))
117 .detach();
118 }
119 });
120}
121
122enum CopilotServer {
123 Disabled,
124 Starting { task: Shared<Task<()>> },
125 Error(Arc<str>),
126 Running(RunningCopilotServer),
127}
128
129impl CopilotServer {
130 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
131 let server = self.as_running()?;
132 if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
133 Ok(server)
134 } else {
135 Err(anyhow!("must sign in before using copilot"))
136 }
137 }
138
139 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
140 match self {
141 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
142 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
143 CopilotServer::Error(error) => Err(anyhow!(
144 "copilot was not started because of an error: {}",
145 error
146 )),
147 CopilotServer::Running(server) => Ok(server),
148 }
149 }
150}
151
/// State kept while the Copilot language server process is running.
struct RunningCopilotServer {
    /// Handle to the language server.
    lsp: Arc<LanguageServer>,
    /// Where the user currently is in the sign-in flow.
    sign_in_status: SignInStatus,
    /// Buffers currently opened on the server, keyed by the buffer's remote id.
    registered_buffers: HashMap<u64, RegisteredBuffer>,
}
157
/// Internal sign-in state of a running server (mirrored publicly by [`Status`]).
#[derive(Clone, Debug)]
enum SignInStatus {
    /// The server reported `Ok`, `MaybeOk` or `AlreadySignedIn`.
    Authorized,
    /// The server reported `NotAuthorized`.
    Unauthorized,
    /// A sign-in flow is in progress.
    SigningIn {
        /// Device-flow prompt for the user, filled in once the server provides it.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// The in-flight sign-in; shared so repeated `sign_in` calls await the same flow.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// The user is not signed in (initial state, after sign-out, or after a failed sign-in).
    SignedOut,
}
168
169#[derive(Debug, Clone)]
170pub enum Status {
171 Starting {
172 task: Shared<Task<()>>,
173 },
174 Error(Arc<str>),
175 Disabled,
176 SignedOut,
177 SigningIn {
178 prompt: Option<request::PromptUserDeviceFlow>,
179 },
180 Unauthorized,
181 Authorized,
182}
183
184impl Status {
185 pub fn is_authorized(&self) -> bool {
186 matches!(self, Status::Authorized)
187 }
188}
189
/// A buffer opened on the Copilot server, plus the last snapshot of it that the
/// server has been told about.
struct RegisteredBuffer {
    /// The buffer's remote id (also its key in `registered_buffers`).
    id: u64,
    /// URI the document was opened under (`file://` for local files, else `buffer://<id>`).
    uri: lsp::Url,
    /// Language id sent in `DidOpenTextDocument`.
    language_id: String,
    /// Buffer contents as last reported to the server.
    snapshot: BufferSnapshot,
    /// LSP document version matching `snapshot`; incremented for each reported change.
    snapshot_version: i32,
    /// Buffer event + release subscriptions, kept alive while the buffer is registered.
    _subscriptions: [gpui::Subscription; 2],
    /// Latest change-reporting task; each new one awaits its predecessor so changes are sent in order.
    pending_buffer_change: Task<Option<()>>,
}
199
impl RegisteredBuffer {
    /// Sends any edits made to `buffer` since the last reported snapshot to the
    /// server as an incremental `DidChangeTextDocument`.
    ///
    /// Returns a receiver that resolves to the `(document version, snapshot)` pair the
    /// server knows about once reporting is done. If the buffer hasn't changed, it
    /// resolves immediately. If the buffer or Copilot goes away (or Copilot is no longer
    /// authenticated) before reporting finishes, the sender is dropped and the receiver
    /// resolves to `Canceled`.
    fn report_changes(
        &mut self,
        buffer: &ModelHandle<Buffer>,
        cx: &mut ModelContext<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let id = self.id;
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            // Chain onto the previous report so change notifications reach the server
            // in the order the edits happened.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move {
                prev_pending_change.await;

                // Re-read the last reported version now that earlier reports have finished.
                let old_version = copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
                    let server = copilot.server.as_authenticated().log_err()?;
                    let buffer = server.registered_buffers.get_mut(&id)?;
                    Some(buffer.snapshot.version.clone())
                })?;
                let new_snapshot = buffer
                    .upgrade(&cx)?
                    .read_with(&cx, |buffer, _| buffer.snapshot());

                // Diff on the background executor. Each change starts at the edit's position
                // in the new text and spans the old text's extent, which is correct when the
                // changes are applied one after another in order.
                let content_changes = cx
                    .background()
                    .spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| {
                    let server = copilot.server.as_authenticated().log_err()?;
                    let buffer = server.registered_buffers.get_mut(&id)?;
                    // Only bump the document version (and notify) when text actually changed.
                    if !content_changes.is_empty() {
                        buffer.snapshot_version += 1;
                        buffer.snapshot = new_snapshot;
                        server
                            .lsp
                            .notify::<lsp::notification::DidChangeTextDocument>(
                                lsp::DidChangeTextDocumentParams {
                                    text_document: lsp::VersionedTextDocumentIdentifier::new(
                                        buffer.uri.clone(),
                                        buffer.snapshot_version,
                                    ),
                                    content_changes,
                                },
                            )
                            .log_err();
                    }
                    let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                    Some(())
                })?;

                Some(())
            });
        }

        done_rx
    }
}
284
/// A single suggestion returned by Copilot.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned id, sent back in `NotifyAccepted` / `NotifyRejected`.
    uuid: String,
    /// Buffer range the suggestion applies to (clipped to the reported snapshot).
    pub range: Range<Anchor>,
    /// Suggested text, as returned by the server.
    pub text: String,
}
291
/// Model that owns the Copilot language server and keeps buffers synchronized with it.
pub struct Copilot {
    /// HTTP client used to download the language server.
    http: Arc<dyn HttpClient>,
    /// Node runtime used to run the language server.
    node_runtime: Arc<NodeRuntime>,
    /// Current lifecycle state of the language server.
    server: CopilotServer,
    /// Every buffer passed to `register_buffer`, keyed by remote id, so buffers can be
    /// (re-)opened on the server after signing in.
    buffers: HashMap<u64, WeakModelHandle<Buffer>>,
}
298
/// Events emitted by [`Copilot`]; `init` subscribes to these to report telemetry.
pub enum Event {
    /// A suggestion was accepted.
    CompletionAccepted {
        uuid: String,
        file_type: Option<Arc<str>>,
    },
    /// Suggestions were discarded without being accepted.
    CompletionsDiscarded {
        uuids: Vec<String>,
        file_type: Option<Arc<str>>,
    },
}
309
310impl Entity for Copilot {
311 type Event = Event;
312
313 fn app_will_quit(
314 &mut self,
315 _: &mut AppContext,
316 ) -> Option<Pin<Box<dyn 'static + Future<Output = ()>>>> {
317 match mem::replace(&mut self.server, CopilotServer::Disabled) {
318 CopilotServer::Running(server) => Some(Box::pin(async move {
319 if let Some(shutdown) = server.lsp.shutdown() {
320 shutdown.await;
321 }
322 })),
323 _ => None,
324 }
325 }
326}
327
328impl Copilot {
329 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
330 if cx.has_global::<ModelHandle<Self>>() {
331 Some(cx.global::<ModelHandle<Self>>().clone())
332 } else {
333 None
334 }
335 }
336
337 fn start(
338 http: Arc<dyn HttpClient>,
339 node_runtime: Arc<NodeRuntime>,
340 cx: &mut ModelContext<Self>,
341 ) -> Self {
342 cx.observe_global::<Settings, _>({
343 let http = http.clone();
344 let node_runtime = node_runtime.clone();
345 move |this, cx| {
346 if cx.global::<Settings>().features.copilot {
347 if matches!(this.server, CopilotServer::Disabled) {
348 let start_task = cx
349 .spawn({
350 let http = http.clone();
351 let node_runtime = node_runtime.clone();
352 move |this, cx| {
353 Self::start_language_server(http, node_runtime, this, cx)
354 }
355 })
356 .shared();
357 this.server = CopilotServer::Starting { task: start_task };
358 cx.notify();
359 }
360 } else {
361 this.server = CopilotServer::Disabled;
362 cx.notify();
363 }
364 }
365 })
366 .detach();
367
368 if cx.global::<Settings>().features.copilot {
369 let start_task = cx
370 .spawn({
371 let http = http.clone();
372 let node_runtime = node_runtime.clone();
373 move |this, cx| async {
374 Self::start_language_server(http, node_runtime, this, cx).await
375 }
376 })
377 .shared();
378
379 Self {
380 http,
381 node_runtime,
382 server: CopilotServer::Starting { task: start_task },
383 buffers: Default::default(),
384 }
385 } else {
386 Self {
387 http,
388 node_runtime,
389 server: CopilotServer::Disabled,
390 buffers: Default::default(),
391 }
392 }
393 }
394
395 #[cfg(any(test, feature = "test-support"))]
396 pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
397 let (server, fake_server) =
398 LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
399 let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
400 let this = cx.add_model(|cx| Self {
401 http: http.clone(),
402 node_runtime: NodeRuntime::new(http, cx.background().clone()),
403 server: CopilotServer::Running(RunningCopilotServer {
404 lsp: Arc::new(server),
405 sign_in_status: SignInStatus::Authorized,
406 registered_buffers: Default::default(),
407 }),
408 buffers: Default::default(),
409 });
410 (this, fake_server)
411 }
412
413 fn start_language_server(
414 http: Arc<dyn HttpClient>,
415 node_runtime: Arc<NodeRuntime>,
416 this: ModelHandle<Self>,
417 mut cx: AsyncAppContext,
418 ) -> impl Future<Output = ()> {
419 async move {
420 let start_language_server = async {
421 let server_path = get_copilot_lsp(http).await?;
422 let node_path = node_runtime.binary_path().await?;
423 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
424 let server = LanguageServer::new(
425 LanguageServerId(0),
426 &node_path,
427 arguments,
428 Path::new("/"),
429 None,
430 cx.clone(),
431 )?;
432
433 server
434 .on_notification::<LogMessage, _>(|params, _cx| {
435 match params.level {
436 // Copilot is pretty agressive about logging
437 0 => debug!("copilot: {}", params.message),
438 1 => debug!("copilot: {}", params.message),
439 _ => error!("copilot: {}", params.message),
440 }
441
442 debug!("copilot metadata: {}", params.metadata_str);
443 debug!("copilot extra: {:?}", params.extra);
444 })
445 .detach();
446
447 server
448 .on_notification::<StatusNotification, _>(
449 |_, _| { /* Silence the notification */ },
450 )
451 .detach();
452
453 let server = server.initialize(Default::default()).await?;
454
455 let status = server
456 .request::<request::CheckStatus>(request::CheckStatusParams {
457 local_checks_only: false,
458 })
459 .await?;
460
461 server
462 .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
463 editor_info: request::EditorInfo {
464 name: "zed".into(),
465 version: env!("CARGO_PKG_VERSION").into(),
466 },
467 editor_plugin_info: request::EditorPluginInfo {
468 name: "zed-copilot".into(),
469 version: "0.0.1".into(),
470 },
471 })
472 .await?;
473
474 anyhow::Ok((server, status))
475 };
476
477 let server = start_language_server.await;
478 this.update(&mut cx, |this, cx| {
479 cx.notify();
480 match server {
481 Ok((server, status)) => {
482 this.server = CopilotServer::Running(RunningCopilotServer {
483 lsp: server,
484 sign_in_status: SignInStatus::SignedOut,
485 registered_buffers: Default::default(),
486 });
487 this.update_sign_in_status(status, cx);
488 }
489 Err(error) => {
490 this.server = CopilotServer::Error(error.to_string().into());
491 cx.notify()
492 }
493 }
494 })
495 }
496 }
497
498 pub fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
499 if let CopilotServer::Running(server) = &mut self.server {
500 let task = match &server.sign_in_status {
501 SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
502 SignInStatus::SigningIn { task, .. } => {
503 cx.notify();
504 task.clone()
505 }
506 SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => {
507 let lsp = server.lsp.clone();
508 let task = cx
509 .spawn(|this, mut cx| async move {
510 let sign_in = async {
511 let sign_in = lsp
512 .request::<request::SignInInitiate>(
513 request::SignInInitiateParams {},
514 )
515 .await?;
516 match sign_in {
517 request::SignInInitiateResult::AlreadySignedIn { user } => {
518 Ok(request::SignInStatus::Ok { user })
519 }
520 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
521 this.update(&mut cx, |this, cx| {
522 if let CopilotServer::Running(RunningCopilotServer {
523 sign_in_status: status,
524 ..
525 }) = &mut this.server
526 {
527 if let SignInStatus::SigningIn {
528 prompt: prompt_flow,
529 ..
530 } = status
531 {
532 *prompt_flow = Some(flow.clone());
533 cx.notify();
534 }
535 }
536 });
537 let response = lsp
538 .request::<request::SignInConfirm>(
539 request::SignInConfirmParams {
540 user_code: flow.user_code,
541 },
542 )
543 .await?;
544 Ok(response)
545 }
546 }
547 };
548
549 let sign_in = sign_in.await;
550 this.update(&mut cx, |this, cx| match sign_in {
551 Ok(status) => {
552 this.update_sign_in_status(status, cx);
553 Ok(())
554 }
555 Err(error) => {
556 this.update_sign_in_status(
557 request::SignInStatus::NotSignedIn,
558 cx,
559 );
560 Err(Arc::new(error))
561 }
562 })
563 })
564 .shared();
565 server.sign_in_status = SignInStatus::SigningIn {
566 prompt: None,
567 task: task.clone(),
568 };
569 cx.notify();
570 task
571 }
572 };
573
574 cx.foreground()
575 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
576 } else {
577 // If we're downloading, wait until download is finished
578 // If we're in a stuck state, display to the user
579 Task::ready(Err(anyhow!("copilot hasn't started yet")))
580 }
581 }
582
583 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
584 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
585 if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
586 let server = server.clone();
587 cx.background().spawn(async move {
588 server
589 .request::<request::SignOut>(request::SignOutParams {})
590 .await?;
591 anyhow::Ok(())
592 })
593 } else {
594 Task::ready(Err(anyhow!("copilot hasn't started yet")))
595 }
596 }
597
598 pub fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
599 let start_task = cx
600 .spawn({
601 let http = self.http.clone();
602 let node_runtime = self.node_runtime.clone();
603 move |this, cx| async move {
604 clear_copilot_dir().await;
605 Self::start_language_server(http, node_runtime, this, cx).await
606 }
607 })
608 .shared();
609
610 self.server = CopilotServer::Starting {
611 task: start_task.clone(),
612 };
613
614 cx.notify();
615
616 cx.foreground().spawn(start_task)
617 }
618
619 pub fn register_buffer(&mut self, buffer: &ModelHandle<Buffer>, cx: &mut ModelContext<Self>) {
620 let buffer_id = buffer.read(cx).remote_id();
621 self.buffers.insert(buffer_id, buffer.downgrade());
622
623 if let CopilotServer::Running(RunningCopilotServer {
624 lsp: server,
625 sign_in_status: status,
626 registered_buffers,
627 ..
628 }) = &mut self.server
629 {
630 if !matches!(status, SignInStatus::Authorized { .. }) {
631 return;
632 }
633
634 let buffer_id = buffer.read(cx).remote_id();
635 registered_buffers.entry(buffer_id).or_insert_with(|| {
636 let uri: lsp::Url = uri_for_buffer(buffer, cx);
637 let language_id = id_for_language(buffer.read(cx).language());
638 let snapshot = buffer.read(cx).snapshot();
639 server
640 .notify::<lsp::notification::DidOpenTextDocument>(
641 lsp::DidOpenTextDocumentParams {
642 text_document: lsp::TextDocumentItem {
643 uri: uri.clone(),
644 language_id: language_id.clone(),
645 version: 0,
646 text: snapshot.text(),
647 },
648 },
649 )
650 .log_err();
651
652 RegisteredBuffer {
653 id: buffer_id,
654 uri,
655 language_id,
656 snapshot,
657 snapshot_version: 0,
658 pending_buffer_change: Task::ready(Some(())),
659 _subscriptions: [
660 cx.subscribe(buffer, |this, buffer, event, cx| {
661 this.handle_buffer_event(buffer, event, cx).log_err();
662 }),
663 cx.observe_release(buffer, move |this, _buffer, _cx| {
664 this.buffers.remove(&buffer_id);
665 this.unregister_buffer(buffer_id);
666 }),
667 ],
668 }
669 });
670 }
671 }
672
673 fn handle_buffer_event(
674 &mut self,
675 buffer: ModelHandle<Buffer>,
676 event: &language::Event,
677 cx: &mut ModelContext<Self>,
678 ) -> Result<()> {
679 if let Ok(server) = self.server.as_running() {
680 let buffer_id = buffer.read(cx).remote_id();
681 if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer_id) {
682 match event {
683 language::Event::Edited => {
684 let _ = registered_buffer.report_changes(&buffer, cx);
685 }
686 language::Event::Saved => {
687 server
688 .lsp
689 .notify::<lsp::notification::DidSaveTextDocument>(
690 lsp::DidSaveTextDocumentParams {
691 text_document: lsp::TextDocumentIdentifier::new(
692 registered_buffer.uri.clone(),
693 ),
694 text: None,
695 },
696 )?;
697 }
698 language::Event::FileHandleChanged | language::Event::LanguageChanged => {
699 let new_language_id = id_for_language(buffer.read(cx).language());
700 let new_uri = uri_for_buffer(&buffer, cx);
701 if new_uri != registered_buffer.uri
702 || new_language_id != registered_buffer.language_id
703 {
704 let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
705 registered_buffer.language_id = new_language_id;
706 server
707 .lsp
708 .notify::<lsp::notification::DidCloseTextDocument>(
709 lsp::DidCloseTextDocumentParams {
710 text_document: lsp::TextDocumentIdentifier::new(old_uri),
711 },
712 )?;
713 server
714 .lsp
715 .notify::<lsp::notification::DidOpenTextDocument>(
716 lsp::DidOpenTextDocumentParams {
717 text_document: lsp::TextDocumentItem::new(
718 registered_buffer.uri.clone(),
719 registered_buffer.language_id.clone(),
720 registered_buffer.snapshot_version,
721 registered_buffer.snapshot.text(),
722 ),
723 },
724 )?;
725 }
726 }
727 _ => {}
728 }
729 }
730 }
731
732 Ok(())
733 }
734
735 fn unregister_buffer(&mut self, buffer_id: u64) {
736 if let Ok(server) = self.server.as_running() {
737 if let Some(buffer) = server.registered_buffers.remove(&buffer_id) {
738 server
739 .lsp
740 .notify::<lsp::notification::DidCloseTextDocument>(
741 lsp::DidCloseTextDocumentParams {
742 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
743 },
744 )
745 .log_err();
746 }
747 }
748 }
749
750 pub fn completions<T>(
751 &mut self,
752 buffer: &ModelHandle<Buffer>,
753 position: T,
754 cx: &mut ModelContext<Self>,
755 ) -> Task<Result<Vec<Completion>>>
756 where
757 T: ToPointUtf16,
758 {
759 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
760 }
761
762 pub fn completions_cycling<T>(
763 &mut self,
764 buffer: &ModelHandle<Buffer>,
765 position: T,
766 cx: &mut ModelContext<Self>,
767 ) -> Task<Result<Vec<Completion>>>
768 where
769 T: ToPointUtf16,
770 {
771 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
772 }
773
774 pub fn accept_completion(
775 &mut self,
776 completion: &Completion,
777 file_type: Option<Arc<str>>,
778 cx: &mut ModelContext<Self>,
779 ) -> Task<Result<()>> {
780 let server = match self.server.as_authenticated() {
781 Ok(server) => server,
782 Err(error) => return Task::ready(Err(error)),
783 };
784
785 cx.emit(Event::CompletionAccepted {
786 uuid: completion.uuid.clone(),
787 file_type,
788 });
789
790 let request =
791 server
792 .lsp
793 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
794 uuid: completion.uuid.clone(),
795 });
796
797 cx.background().spawn(async move {
798 request.await?;
799 Ok(())
800 })
801 }
802
803 pub fn discard_completions(
804 &mut self,
805 completions: &[Completion],
806 file_type: Option<Arc<str>>,
807 cx: &mut ModelContext<Self>,
808 ) -> Task<Result<()>> {
809 let server = match self.server.as_authenticated() {
810 Ok(server) => server,
811 Err(error) => return Task::ready(Err(error)),
812 };
813
814 cx.emit(Event::CompletionsDiscarded {
815 uuids: completions
816 .iter()
817 .map(|completion| completion.uuid.clone())
818 .collect(),
819 file_type: file_type.clone(),
820 });
821
822 let request =
823 server
824 .lsp
825 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
826 uuids: completions
827 .iter()
828 .map(|completion| completion.uuid.clone())
829 .collect(),
830 });
831
832 cx.background().spawn(async move {
833 request.await?;
834 Ok(())
835 })
836 }
837
838 fn request_completions<R, T>(
839 &mut self,
840 buffer: &ModelHandle<Buffer>,
841 position: T,
842 cx: &mut ModelContext<Self>,
843 ) -> Task<Result<Vec<Completion>>>
844 where
845 R: 'static
846 + lsp::request::Request<
847 Params = request::GetCompletionsParams,
848 Result = request::GetCompletionsResult,
849 >,
850 T: ToPointUtf16,
851 {
852 self.register_buffer(buffer, cx);
853
854 let server = match self.server.as_authenticated() {
855 Ok(server) => server,
856 Err(error) => return Task::ready(Err(error)),
857 };
858 let lsp = server.lsp.clone();
859 let buffer_id = buffer.read(cx).remote_id();
860 let registered_buffer = server.registered_buffers.get_mut(&buffer_id).unwrap();
861 let snapshot = registered_buffer.report_changes(buffer, cx);
862 let buffer = buffer.read(cx);
863 let uri = registered_buffer.uri.clone();
864 let settings = cx.global::<Settings>();
865 let position = position.to_point_utf16(buffer);
866 let language = buffer.language_at(position);
867 let language_name = language.map(|language| language.name());
868 let language_name = language_name.as_deref();
869 let tab_size = settings.tab_size(language_name);
870 let hard_tabs = settings.hard_tabs(language_name);
871 let relative_path = buffer
872 .file()
873 .map(|file| file.path().to_path_buf())
874 .unwrap_or_default();
875
876 cx.foreground().spawn(async move {
877 let (version, snapshot) = snapshot.await?;
878 let result = lsp
879 .request::<R>(request::GetCompletionsParams {
880 doc: request::GetCompletionsDocument {
881 uri,
882 tab_size: tab_size.into(),
883 indent_size: 1,
884 insert_spaces: !hard_tabs,
885 relative_path: relative_path.to_string_lossy().into(),
886 position: point_to_lsp(position),
887 version: version.try_into().unwrap(),
888 },
889 })
890 .await?;
891 let completions = result
892 .completions
893 .into_iter()
894 .map(|completion| {
895 let start = snapshot
896 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
897 let end =
898 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
899 Completion {
900 uuid: completion.uuid,
901 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
902 text: completion.text,
903 }
904 })
905 .collect();
906 anyhow::Ok(completions)
907 })
908 }
909
910 pub fn status(&self) -> Status {
911 match &self.server {
912 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
913 CopilotServer::Disabled => Status::Disabled,
914 CopilotServer::Error(error) => Status::Error(error.clone()),
915 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
916 match sign_in_status {
917 SignInStatus::Authorized { .. } => Status::Authorized,
918 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
919 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
920 prompt: prompt.clone(),
921 },
922 SignInStatus::SignedOut => Status::SignedOut,
923 }
924 }
925 }
926 }
927
928 fn update_sign_in_status(
929 &mut self,
930 lsp_status: request::SignInStatus,
931 cx: &mut ModelContext<Self>,
932 ) {
933 self.buffers.retain(|_, buffer| buffer.is_upgradable(cx));
934
935 if let Ok(server) = self.server.as_running() {
936 match lsp_status {
937 request::SignInStatus::Ok { .. }
938 | request::SignInStatus::MaybeOk { .. }
939 | request::SignInStatus::AlreadySignedIn { .. } => {
940 server.sign_in_status = SignInStatus::Authorized;
941 for buffer in self.buffers.values().cloned().collect::<Vec<_>>() {
942 if let Some(buffer) = buffer.upgrade(cx) {
943 self.register_buffer(&buffer, cx);
944 }
945 }
946 }
947 request::SignInStatus::NotAuthorized { .. } => {
948 server.sign_in_status = SignInStatus::Unauthorized;
949 for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
950 self.unregister_buffer(buffer_id);
951 }
952 }
953 request::SignInStatus::NotSignedIn => {
954 server.sign_in_status = SignInStatus::SignedOut;
955 for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
956 self.unregister_buffer(buffer_id);
957 }
958 }
959 }
960
961 cx.notify();
962 }
963 }
964}
965
966fn id_for_language(language: Option<&Arc<Language>>) -> String {
967 let language_name = language.map(|language| language.name());
968 match language_name.as_deref() {
969 Some("Plain Text") => "plaintext".to_string(),
970 Some(language_name) => language_name.to_lowercase(),
971 None => "plaintext".to_string(),
972 }
973}
974
975fn uri_for_buffer(buffer: &ModelHandle<Buffer>, cx: &AppContext) -> lsp::Url {
976 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
977 lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
978 } else {
979 format!("buffer://{}", buffer.read(cx).remote_id())
980 .parse()
981 .unwrap()
982 }
983}
984
985async fn clear_copilot_dir() {
986 remove_matching(&paths::COPILOT_DIR, |_| true).await
987}
988
989async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
990 const SERVER_PATH: &'static str = "dist/agent.js";
991
992 ///Check for the latest copilot language server and download it if we haven't already
993 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
994 let release = latest_github_release("zed-industries/copilot", false, http.clone()).await?;
995
996 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
997
998 fs::create_dir_all(version_dir).await?;
999 let server_path = version_dir.join(SERVER_PATH);
1000
1001 if fs::metadata(&server_path).await.is_err() {
1002 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
1003 let dist_dir = version_dir.join("dist");
1004 fs::create_dir_all(dist_dir.as_path()).await?;
1005
1006 let url = &release
1007 .assets
1008 .get(0)
1009 .context("Github release for copilot contained no assets")?
1010 .browser_download_url;
1011
1012 let mut response = http
1013 .get(&url, Default::default(), true)
1014 .await
1015 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
1016 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
1017 let archive = Archive::new(decompressed_bytes);
1018 archive.unpack(dist_dir).await?;
1019
1020 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
1021 }
1022
1023 Ok(server_path)
1024 }
1025
1026 match fetch_latest(http).await {
1027 ok @ Result::Ok(..) => ok,
1028 e @ Err(..) => {
1029 e.log_err();
1030 // Fetch a cached binary, if it exists
1031 (|| async move {
1032 let mut last_version_dir = None;
1033 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
1034 while let Some(entry) = entries.next().await {
1035 let entry = entry?;
1036 if entry.file_type().await?.is_dir() {
1037 last_version_dir = Some(entry.path());
1038 }
1039 }
1040 let last_version_dir =
1041 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
1042 let server_path = last_version_dir.join(SERVER_PATH);
1043 if server_path.exists() {
1044 Ok(server_path)
1045 } else {
1046 Err(anyhow!(
1047 "missing executable in directory {:?}",
1048 last_version_dir
1049 ))
1050 }
1051 })()
1052 .await
1053 }
1054 }
1055}
1056
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{executor::Deterministic, TestAppContext};

    /// Exercises how buffers are mirrored on the (fake) Copilot server: opening on
    /// registration, incremental change notifications, close+reopen when the file
    /// changes, closing everything on sign-out, reopening on sign-in, and closing a
    /// buffer when it is dropped.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();
        let (copilot, mut lsp) = Copilot::fake(cx);

        // Buffers without a file are opened under a synthetic `buffer://<id>` URI.
        let buffer_1 = cx.add_model(|cx| Buffer::new(0, "Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.id()).parse().unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.add_model(|cx| Buffer::new(0, "Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.id()).parse().unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is reported as an incremental change and bumps the version to 1.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        buffer_1
            .update(cx, |buffer, cx| {
                buffer.file_updated(
                    Arc::new(File {
                        abs_path: "/root/child/buffer-1".into(),
                        path: Path::new("child/buffer-1").into(),
                    }),
                    cx,
                )
            })
            .await;
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        // The document is reopened under its file URI, keeping the current version (1).
        let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        // Re-registration starts again at version 0 with the current text.
        lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );

        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    /// Minimal local file used to give a buffer a path; only the methods the test
    /// relies on (`as_local`, `path`, `abs_path`) are implemented.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn mtime(&self) -> std::time::SystemTime {
            unimplemented!()
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &AppContext) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn is_deleted(&self) -> bool {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            unimplemented!()
        }

        fn to_proto(&self) -> rpc::proto::File {
            unimplemented!()
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &AppContext) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &AppContext) -> Task<Result<String>> {
            unimplemented!()
        }

        fn buffer_reloaded(
            &self,
            _: u64,
            _: &clock::Global,
            _: language::RopeFingerprint,
            _: ::fs::LineEnding,
            _: std::time::SystemTime,
            _: &mut AppContext,
        ) {
            unimplemented!()
        }
    }
}