1pub mod copilot_chat;
2mod copilot_completion_provider;
3pub mod request;
4mod sign_in;
5
6use ::fs::Fs;
7use anyhow::{anyhow, Context as _, Result};
8use async_compression::futures::bufread::GzipDecoder;
9use async_tar::Archive;
10use collections::{HashMap, HashSet};
11use command_palette_hooks::CommandPaletteFilter;
12use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
13use gpui::{
14 actions, AppContext, AsyncAppContext, Context, Entity, EntityId, EventEmitter, Global, Model,
15 ModelContext, Task, WeakModel,
16};
17use http_client::github::get_release_by_tag_name;
18use http_client::HttpClient;
19use language::{
20 language_settings::{all_language_settings, language_settings, InlineCompletionProvider},
21 point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
22 ToPointUtf16,
23};
24use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
25use node_runtime::NodeRuntime;
26use parking_lot::Mutex;
27use request::StatusNotification;
28use settings::SettingsStore;
29use smol::{fs, io::BufReader, stream::StreamExt};
30use std::{
31 any::TypeId,
32 env,
33 ffi::OsString,
34 mem,
35 ops::Range,
36 path::{Path, PathBuf},
37 sync::Arc,
38};
39use util::{fs::remove_matching, maybe, ResultExt};
40
41pub use crate::copilot_completion_provider::CopilotCompletionProvider;
42pub use crate::sign_in::{initiate_sign_in, CopilotCodeVerification};
43
// Copilot actions. `init` shows or hides them in the command palette depending on the
// current copilot `Status`.
actions!(
    copilot,
    [
        Suggest,
        NextSuggestion,
        PreviousSuggestion,
        Reinstall,
        SignIn,
        SignOut
    ]
);
55
56pub fn init(
57 new_server_id: LanguageServerId,
58 fs: Arc<dyn Fs>,
59 http: Arc<dyn HttpClient>,
60 node_runtime: NodeRuntime,
61 cx: &mut AppContext,
62) {
63 copilot_chat::init(fs, http.clone(), cx);
64
65 let copilot = cx.new_model({
66 let node_runtime = node_runtime.clone();
67 move |cx| Copilot::start(new_server_id, http, node_runtime, cx)
68 });
69 Copilot::set_global(copilot.clone(), cx);
70 cx.observe(&copilot, |handle, cx| {
71 let copilot_action_types = [
72 TypeId::of::<Suggest>(),
73 TypeId::of::<NextSuggestion>(),
74 TypeId::of::<PreviousSuggestion>(),
75 TypeId::of::<Reinstall>(),
76 ];
77 let copilot_auth_action_types = [TypeId::of::<SignOut>()];
78 let copilot_no_auth_action_types = [TypeId::of::<SignIn>()];
79 let status = handle.read(cx).status();
80 let filter = CommandPaletteFilter::global_mut(cx);
81
82 match status {
83 Status::Disabled => {
84 filter.hide_action_types(&copilot_action_types);
85 filter.hide_action_types(&copilot_auth_action_types);
86 filter.hide_action_types(&copilot_no_auth_action_types);
87 }
88 Status::Authorized => {
89 filter.hide_action_types(&copilot_no_auth_action_types);
90 filter.show_action_types(
91 copilot_action_types
92 .iter()
93 .chain(&copilot_auth_action_types),
94 );
95 }
96 _ => {
97 filter.hide_action_types(&copilot_action_types);
98 filter.hide_action_types(&copilot_auth_action_types);
99 filter.show_action_types(copilot_no_auth_action_types.iter());
100 }
101 }
102 })
103 .detach();
104
105 cx.on_action(|_: &SignIn, cx| {
106 if let Some(copilot) = Copilot::global(cx) {
107 copilot
108 .update(cx, |copilot, cx| copilot.sign_in(cx))
109 .detach_and_log_err(cx);
110 }
111 });
112 cx.on_action(|_: &SignOut, cx| {
113 if let Some(copilot) = Copilot::global(cx) {
114 copilot
115 .update(cx, |copilot, cx| copilot.sign_out(cx))
116 .detach_and_log_err(cx);
117 }
118 });
119 cx.on_action(|_: &Reinstall, cx| {
120 if let Some(copilot) = Copilot::global(cx) {
121 copilot
122 .update(cx, |copilot, cx| copilot.reinstall(cx))
123 .detach();
124 }
125 });
126}
127
128enum CopilotServer {
129 Disabled,
130 Starting { task: Shared<Task<()>> },
131 Error(Arc<str>),
132 Running(RunningCopilotServer),
133}
134
135impl CopilotServer {
136 fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> {
137 let server = self.as_running()?;
138 if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) {
139 Ok(server)
140 } else {
141 Err(anyhow!("must sign in before using copilot"))
142 }
143 }
144
145 fn as_running(&mut self) -> Result<&mut RunningCopilotServer> {
146 match self {
147 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
148 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
149 CopilotServer::Error(error) => Err(anyhow!(
150 "copilot was not started because of an error: {}",
151 error
152 )),
153 CopilotServer::Running(server) => Ok(server),
154 }
155 }
156}
157
/// State held while the Copilot language server is running.
struct RunningCopilotServer {
    /// Handle to the Copilot language server process.
    lsp: Arc<LanguageServer>,
    /// Current authentication state of the user with the server.
    sign_in_status: SignInStatus,
    /// Buffers currently opened on the server, keyed by the buffer model's entity id.
    registered_buffers: HashMap<EntityId, RegisteredBuffer>,
}
163
/// Internal sign-in state of a running Copilot server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// The user is signed in and Copilot may be used.
    Authorized,
    /// The server reported that the signed-in account is not authorized.
    Unauthorized,
    /// A sign-in is in progress.
    SigningIn {
        /// Device-flow prompt returned by the server, once it is available.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared task driving the sign-in, so repeated `sign_in` calls join the same attempt.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// The user is not signed in.
    SignedOut,
}
174
/// Public snapshot of Copilot's state, as returned by `Copilot::status`.
#[derive(Debug, Clone)]
pub enum Status {
    /// The language server is starting; `task` completes when the start attempt finishes.
    Starting {
        task: Shared<Task<()>>,
    },
    /// The language server failed to start, with the error message.
    Error(Arc<str>),
    /// Copilot is not enabled as the inline completion provider.
    Disabled,
    /// The server is running but the user is not signed in.
    SignedOut,
    /// A sign-in is in progress; `prompt` holds the device-flow prompt once known.
    SigningIn {
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// The server reported that the signed-in account is not authorized.
    Unauthorized,
    /// The user is signed in and Copilot may be used.
    Authorized,
}

impl Status {
    /// Returns `true` when the status is [`Status::Authorized`].
    pub fn is_authorized(&self) -> bool {
        matches!(self, Status::Authorized)
    }

    /// Returns `true` when the status is [`Status::Disabled`].
    pub fn is_disabled(&self) -> bool {
        matches!(self, Status::Disabled)
    }
}
199
/// A buffer that has been opened on the Copilot language server.
struct RegisteredBuffer {
    /// URI under which the buffer was opened on the server.
    uri: lsp::Url,
    /// LSP language id the buffer was opened with.
    language_id: String,
    /// Buffer contents as last reported to the server.
    snapshot: BufferSnapshot,
    /// Document version last reported to the server (starts at 0 on open).
    snapshot_version: i32,
    /// Subscriptions to the buffer's events and to its release; dropped with this value.
    _subscriptions: [gpui::Subscription; 2],
    /// Most recently scheduled change report; each new report awaits the previous one.
    pending_buffer_change: Task<Option<()>>,
}
208
impl RegisteredBuffer {
    /// Brings the language server's view of this buffer up to date.
    ///
    /// If the buffer's current version equals the last reported snapshot's version, the
    /// returned receiver resolves immediately with `(snapshot_version, snapshot)`.
    ///
    /// Otherwise a new task is chained after any previously pending change. That task:
    /// - computes the edits since the last reported snapshot on the background executor;
    /// - if there are any, bumps `snapshot_version`, stores the new snapshot, and sends a
    ///   `textDocument/didChange` notification;
    /// - resolves the receiver with the resulting version and snapshot.
    ///
    /// The sender is dropped without sending (so the receiver is cancelled) in these cases,
    /// when they happen before the report completes:
    /// - the `Copilot` model or the buffer is released;
    /// - copilot is no longer authenticated;
    /// - the buffer is no longer registered.
    fn report_changes(
        &mut self,
        buffer: &Model<Buffer>,
        cx: &mut ModelContext<Copilot>,
    ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
        let (done_tx, done_rx) = oneshot::channel();

        if buffer.read(cx).version() == self.snapshot.version {
            let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
        } else {
            let buffer = buffer.downgrade();
            let id = buffer.entity_id();
            // Serialize reports: the new task waits for the previous one before diffing.
            let prev_pending_change =
                mem::replace(&mut self.pending_buffer_change, Task::ready(None));
            self.pending_buffer_change = cx.spawn(move |copilot, mut cx| async move {
                prev_pending_change.await;

                // Re-read the last reported version; the previous task may have advanced it.
                let old_version = copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        Some(buffer.snapshot.version.clone())
                    })
                    .ok()??;
                let new_snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot()).ok()?;

                // Convert the buffer edits into LSP incremental change events off the main thread.
                let content_changes = cx
                    .background_executor()
                    .spawn({
                        let new_snapshot = new_snapshot.clone();
                        async move {
                            new_snapshot
                                .edits_since::<(PointUtf16, usize)>(&old_version)
                                .map(|edit| {
                                    // The replaced range starts at the new start and spans the
                                    // old extent.
                                    let edit_start = edit.new.start.0;
                                    let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0);
                                    let new_text = new_snapshot
                                        .text_for_range(edit.new.start.1..edit.new.end.1)
                                        .collect();
                                    lsp::TextDocumentContentChangeEvent {
                                        range: Some(lsp::Range::new(
                                            point_to_lsp(edit_start),
                                            point_to_lsp(edit_end),
                                        )),
                                        range_length: None,
                                        text: new_text,
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                    })
                    .await;

                copilot
                    .update(&mut cx, |copilot, _| {
                        let server = copilot.server.as_authenticated().log_err()?;
                        let buffer = server.registered_buffers.get_mut(&id)?;
                        // Only advance the version when something was actually sent.
                        if !content_changes.is_empty() {
                            buffer.snapshot_version += 1;
                            buffer.snapshot = new_snapshot;
                            server
                                .lsp
                                .notify::<lsp::notification::DidChangeTextDocument>(
                                    lsp::DidChangeTextDocumentParams {
                                        text_document: lsp::VersionedTextDocumentIdentifier::new(
                                            buffer.uri.clone(),
                                            buffer.snapshot_version,
                                        ),
                                        content_changes,
                                    },
                                )
                                .log_err();
                        }
                        let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone()));
                        Some(())
                    })
                    .ok()?;

                Some(())
            });
        }

        done_rx
    }
}
295
/// An inline completion suggested by Copilot.
#[derive(Debug)]
pub struct Completion {
    /// Server-assigned identifier, echoed back when accepting or rejecting the completion.
    pub uuid: String,
    /// Buffer range the completion applies to, anchored in the snapshot it was computed for.
    pub range: Range<Anchor>,
    /// Suggested text.
    pub text: String,
}
302
/// Model owning the Copilot language server and the buffers shared with it.
pub struct Copilot {
    /// HTTP client used to download the Copilot language server.
    http: Arc<dyn HttpClient>,
    /// Node runtime used to launch the language server.
    node_runtime: NodeRuntime,
    /// Current lifecycle state of the language server.
    server: CopilotServer,
    /// Every buffer passed to `register_buffer`, held weakly, so the buffers can be
    /// (re-)registered after signing in.
    buffers: HashSet<WeakModel<Buffer>>,
    /// Id assigned to the language server when it is started.
    server_id: LanguageServerId,
    /// App-quit subscription that shuts the language server down.
    _subscription: gpui::Subscription,
}

/// Events emitted by [`Copilot`].
pub enum Event {
    /// The language server finished starting.
    CopilotLanguageServerStarted,
    /// The user became authorized.
    CopilotAuthSignedIn,
    /// The user was signed out.
    CopilotAuthSignedOut,
}

impl EventEmitter<Event> for Copilot {}

/// Global slot holding the app-wide [`Copilot`] model.
struct GlobalCopilot(Model<Copilot>);

impl Global for GlobalCopilot {}
323
324impl Copilot {
325 pub fn global(cx: &AppContext) -> Option<Model<Self>> {
326 cx.try_global::<GlobalCopilot>()
327 .map(|model| model.0.clone())
328 }
329
330 pub fn set_global(copilot: Model<Self>, cx: &mut AppContext) {
331 cx.set_global(GlobalCopilot(copilot));
332 }
333
334 fn start(
335 new_server_id: LanguageServerId,
336 http: Arc<dyn HttpClient>,
337 node_runtime: NodeRuntime,
338 cx: &mut ModelContext<Self>,
339 ) -> Self {
340 let mut this = Self {
341 server_id: new_server_id,
342 http,
343 node_runtime,
344 server: CopilotServer::Disabled,
345 buffers: Default::default(),
346 _subscription: cx.on_app_quit(Self::shutdown_language_server),
347 };
348 this.enable_or_disable_copilot(cx);
349 cx.observe_global::<SettingsStore>(move |this, cx| this.enable_or_disable_copilot(cx))
350 .detach();
351 this
352 }
353
354 fn shutdown_language_server(
355 &mut self,
356 _cx: &mut ModelContext<Self>,
357 ) -> impl Future<Output = ()> {
358 let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) {
359 CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })),
360 _ => None,
361 };
362
363 async move {
364 if let Some(shutdown) = shutdown {
365 shutdown.await;
366 }
367 }
368 }
369
370 fn enable_or_disable_copilot(&mut self, cx: &mut ModelContext<Self>) {
371 let server_id = self.server_id;
372 let http = self.http.clone();
373 let node_runtime = self.node_runtime.clone();
374 if all_language_settings(None, cx).inline_completions.provider
375 == InlineCompletionProvider::Copilot
376 {
377 if matches!(self.server, CopilotServer::Disabled) {
378 let start_task = cx
379 .spawn(move |this, cx| {
380 Self::start_language_server(server_id, http, node_runtime, this, cx)
381 })
382 .shared();
383 self.server = CopilotServer::Starting { task: start_task };
384 cx.notify();
385 }
386 } else {
387 self.server = CopilotServer::Disabled;
388 cx.notify();
389 }
390 }
391
392 #[cfg(any(test, feature = "test-support"))]
393 pub fn fake(cx: &mut gpui::TestAppContext) -> (Model<Self>, lsp::FakeLanguageServer) {
394 use lsp::FakeLanguageServer;
395 use node_runtime::NodeRuntime;
396
397 let (server, fake_server) = FakeLanguageServer::new(
398 LanguageServerId(0),
399 LanguageServerBinary {
400 path: "path/to/copilot".into(),
401 arguments: vec![],
402 env: None,
403 },
404 "copilot".into(),
405 Default::default(),
406 cx.to_async(),
407 );
408 let http = http_client::FakeHttpClient::create(|_| async { unreachable!() });
409 let node_runtime = NodeRuntime::unavailable();
410 let this = cx.new_model(|cx| Self {
411 server_id: LanguageServerId(0),
412 http: http.clone(),
413 node_runtime,
414 server: CopilotServer::Running(RunningCopilotServer {
415 lsp: Arc::new(server),
416 sign_in_status: SignInStatus::Authorized,
417 registered_buffers: Default::default(),
418 }),
419 _subscription: cx.on_app_quit(Self::shutdown_language_server),
420 buffers: Default::default(),
421 });
422 (this, fake_server)
423 }
424
425 async fn start_language_server(
426 new_server_id: LanguageServerId,
427 http: Arc<dyn HttpClient>,
428 node_runtime: NodeRuntime,
429 this: WeakModel<Self>,
430 mut cx: AsyncAppContext,
431 ) {
432 let start_language_server = async {
433 let server_path = get_copilot_lsp(http).await?;
434 let node_path = node_runtime.binary_path().await?;
435 let arguments: Vec<OsString> = vec![server_path.into(), "--stdio".into()];
436 let binary = LanguageServerBinary {
437 path: node_path,
438 arguments,
439 // TODO: We could set HTTP_PROXY etc here and fix the copilot issue.
440 env: None,
441 };
442
443 let root_path = if cfg!(target_os = "windows") {
444 Path::new("C:/")
445 } else {
446 Path::new("/")
447 };
448
449 let server_name = LanguageServerName("copilot".into());
450 let server = LanguageServer::new(
451 Arc::new(Mutex::new(None)),
452 new_server_id,
453 server_name,
454 binary,
455 root_path,
456 None,
457 cx.clone(),
458 )?;
459
460 server
461 .on_notification::<StatusNotification, _>(|_, _| { /* Silence the notification */ })
462 .detach();
463 let server = cx.update(|cx| server.initialize(None, cx))?.await?;
464
465 let status = server
466 .request::<request::CheckStatus>(request::CheckStatusParams {
467 local_checks_only: false,
468 })
469 .await?;
470
471 server
472 .request::<request::SetEditorInfo>(request::SetEditorInfoParams {
473 editor_info: request::EditorInfo {
474 name: "zed".into(),
475 version: env!("CARGO_PKG_VERSION").into(),
476 },
477 editor_plugin_info: request::EditorPluginInfo {
478 name: "zed-copilot".into(),
479 version: "0.0.1".into(),
480 },
481 })
482 .await?;
483
484 anyhow::Ok((server, status))
485 };
486
487 let server = start_language_server.await;
488 this.update(&mut cx, |this, cx| {
489 cx.notify();
490 match server {
491 Ok((server, status)) => {
492 this.server = CopilotServer::Running(RunningCopilotServer {
493 lsp: server,
494 sign_in_status: SignInStatus::SignedOut,
495 registered_buffers: Default::default(),
496 });
497 cx.emit(Event::CopilotLanguageServerStarted);
498 this.update_sign_in_status(status, cx);
499 }
500 Err(error) => {
501 this.server = CopilotServer::Error(error.to_string().into());
502 cx.notify()
503 }
504 }
505 })
506 .ok();
507 }
508
    /// Starts, or joins, the Copilot sign-in flow.
    ///
    /// - Already authorized: resolves immediately with `Ok(())`.
    /// - A sign-in is already in progress: waits on that same shared sign-in task.
    /// - Signed out or unauthorized: sends `SignInInitiate` to the server.
    ///   - If the server asks for a device flow, the prompt is stored in
    ///     `SignInStatus::SigningIn` and `SignInConfirm` is sent with the user code.
    ///   - The final status is applied via `update_sign_in_status`.
    ///   - On error the status is reset to signed out and the error is returned.
    ///
    /// Returns an error task if the language server is not running.
    pub fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Running(server) = &mut self.server {
            let task = match &server.sign_in_status {
                SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(),
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => {
                    let lsp = server.lsp.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = lsp
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user: Some(user) })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the device-flow prompt so observers (the
                                        // sign-in UI) can read it via `status()`.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Running(RunningCopilotServer {
                                                sign_in_status: status,
                                                ..
                                            }) = &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        })?;
                                        // Ask the server to confirm sign-in for this user code.
                                        let response = lsp
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    // Wrapped in `Arc` so the result can be shared.
                                    Err(Arc::new(error))
                                }
                            })?
                        })
                        .shared();
                    server.sign_in_status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            cx.background_executor()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // If we're downloading, wait until download is finished
            // If we're in a stuck state, display to the user
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
593
594 pub fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
595 self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx);
596 if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server {
597 let server = server.clone();
598 cx.background_executor().spawn(async move {
599 server
600 .request::<request::SignOut>(request::SignOutParams {})
601 .await?;
602 anyhow::Ok(())
603 })
604 } else {
605 Task::ready(Err(anyhow!("copilot hasn't started yet")))
606 }
607 }
608
609 pub fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
610 let start_task = cx
611 .spawn({
612 let http = self.http.clone();
613 let node_runtime = self.node_runtime.clone();
614 let server_id = self.server_id;
615 move |this, cx| async move {
616 clear_copilot_dir().await;
617 Self::start_language_server(server_id, http, node_runtime, this, cx).await
618 }
619 })
620 .shared();
621
622 self.server = CopilotServer::Starting {
623 task: start_task.clone(),
624 };
625
626 cx.notify();
627
628 cx.background_executor().spawn(start_task)
629 }
630
631 pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
632 if let CopilotServer::Running(server) = &self.server {
633 Some(&server.lsp)
634 } else {
635 None
636 }
637 }
638
    /// Starts tracking `buffer` and, when possible, opens it on the language server.
    ///
    /// The buffer is always remembered (weakly) in `self.buffers`, so it can be registered
    /// later, e.g. after signing in. It is only registered with the server while the server
    /// is running and authorized. Registering:
    /// - sends `textDocument/didOpen` with version 0 and the buffer's full text;
    /// - subscribes to the buffer's events and to its release.
    ///
    /// Buffers that are already registered are left untouched.
    pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
        let weak_buffer = buffer.downgrade();
        self.buffers.insert(weak_buffer.clone());

        if let CopilotServer::Running(RunningCopilotServer {
            lsp: server,
            sign_in_status: status,
            registered_buffers,
            ..
        }) = &mut self.server
        {
            if !matches!(status, SignInStatus::Authorized { .. }) {
                return;
            }

            registered_buffers
                .entry(buffer.entity_id())
                .or_insert_with(|| {
                    let uri: lsp::Url = uri_for_buffer(buffer, cx);
                    let language_id = id_for_language(buffer.read(cx).language());
                    let snapshot = buffer.read(cx).snapshot();
                    server
                        .notify::<lsp::notification::DidOpenTextDocument>(
                            lsp::DidOpenTextDocumentParams {
                                text_document: lsp::TextDocumentItem {
                                    uri: uri.clone(),
                                    language_id: language_id.clone(),
                                    version: 0,
                                    text: snapshot.text(),
                                },
                            },
                        )
                        .log_err();

                    RegisteredBuffer {
                        uri,
                        language_id,
                        snapshot,
                        snapshot_version: 0,
                        pending_buffer_change: Task::ready(Some(())),
                        _subscriptions: [
                            // Forward buffer events (edits, saves, renames) to the server.
                            cx.subscribe(buffer, |this, buffer, event, cx| {
                                this.handle_buffer_event(buffer, event, cx).log_err();
                            }),
                            // On release, stop tracking the buffer and close it on the server.
                            cx.observe_release(buffer, move |this, _buffer, _cx| {
                                this.buffers.remove(&weak_buffer);
                                this.unregister_buffer(&weak_buffer);
                            }),
                        ],
                    }
                });
        }
    }
692
693 fn handle_buffer_event(
694 &mut self,
695 buffer: Model<Buffer>,
696 event: &language::BufferEvent,
697 cx: &mut ModelContext<Self>,
698 ) -> Result<()> {
699 if let Ok(server) = self.server.as_running() {
700 if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
701 {
702 match event {
703 language::BufferEvent::Edited => {
704 drop(registered_buffer.report_changes(&buffer, cx));
705 }
706 language::BufferEvent::Saved => {
707 server
708 .lsp
709 .notify::<lsp::notification::DidSaveTextDocument>(
710 lsp::DidSaveTextDocumentParams {
711 text_document: lsp::TextDocumentIdentifier::new(
712 registered_buffer.uri.clone(),
713 ),
714 text: None,
715 },
716 )?;
717 }
718 language::BufferEvent::FileHandleChanged
719 | language::BufferEvent::LanguageChanged => {
720 let new_language_id = id_for_language(buffer.read(cx).language());
721 let new_uri = uri_for_buffer(&buffer, cx);
722 if new_uri != registered_buffer.uri
723 || new_language_id != registered_buffer.language_id
724 {
725 let old_uri = mem::replace(&mut registered_buffer.uri, new_uri);
726 registered_buffer.language_id = new_language_id;
727 server
728 .lsp
729 .notify::<lsp::notification::DidCloseTextDocument>(
730 lsp::DidCloseTextDocumentParams {
731 text_document: lsp::TextDocumentIdentifier::new(old_uri),
732 },
733 )?;
734 server
735 .lsp
736 .notify::<lsp::notification::DidOpenTextDocument>(
737 lsp::DidOpenTextDocumentParams {
738 text_document: lsp::TextDocumentItem::new(
739 registered_buffer.uri.clone(),
740 registered_buffer.language_id.clone(),
741 registered_buffer.snapshot_version,
742 registered_buffer.snapshot.text(),
743 ),
744 },
745 )?;
746 }
747 }
748 _ => {}
749 }
750 }
751 }
752
753 Ok(())
754 }
755
756 fn unregister_buffer(&mut self, buffer: &WeakModel<Buffer>) {
757 if let Ok(server) = self.server.as_running() {
758 if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
759 server
760 .lsp
761 .notify::<lsp::notification::DidCloseTextDocument>(
762 lsp::DidCloseTextDocumentParams {
763 text_document: lsp::TextDocumentIdentifier::new(buffer.uri),
764 },
765 )
766 .log_err();
767 }
768 }
769 }
770
771 pub fn completions<T>(
772 &mut self,
773 buffer: &Model<Buffer>,
774 position: T,
775 cx: &mut ModelContext<Self>,
776 ) -> Task<Result<Vec<Completion>>>
777 where
778 T: ToPointUtf16,
779 {
780 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
781 }
782
783 pub fn completions_cycling<T>(
784 &mut self,
785 buffer: &Model<Buffer>,
786 position: T,
787 cx: &mut ModelContext<Self>,
788 ) -> Task<Result<Vec<Completion>>>
789 where
790 T: ToPointUtf16,
791 {
792 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
793 }
794
795 pub fn accept_completion(
796 &mut self,
797 completion: &Completion,
798 cx: &mut ModelContext<Self>,
799 ) -> Task<Result<()>> {
800 let server = match self.server.as_authenticated() {
801 Ok(server) => server,
802 Err(error) => return Task::ready(Err(error)),
803 };
804 let request =
805 server
806 .lsp
807 .request::<request::NotifyAccepted>(request::NotifyAcceptedParams {
808 uuid: completion.uuid.clone(),
809 });
810 cx.background_executor().spawn(async move {
811 request.await?;
812 Ok(())
813 })
814 }
815
816 pub fn discard_completions(
817 &mut self,
818 completions: &[Completion],
819 cx: &mut ModelContext<Self>,
820 ) -> Task<Result<()>> {
821 let server = match self.server.as_authenticated() {
822 Ok(server) => server,
823 Err(_) => return Task::ready(Ok(())),
824 };
825 let request =
826 server
827 .lsp
828 .request::<request::NotifyRejected>(request::NotifyRejectedParams {
829 uuids: completions
830 .iter()
831 .map(|completion| completion.uuid.clone())
832 .collect(),
833 });
834 cx.background_executor().spawn(async move {
835 request.await?;
836 Ok(())
837 })
838 }
839
840 fn request_completions<R, T>(
841 &mut self,
842 buffer: &Model<Buffer>,
843 position: T,
844 cx: &mut ModelContext<Self>,
845 ) -> Task<Result<Vec<Completion>>>
846 where
847 R: 'static
848 + lsp::request::Request<
849 Params = request::GetCompletionsParams,
850 Result = request::GetCompletionsResult,
851 >,
852 T: ToPointUtf16,
853 {
854 self.register_buffer(buffer, cx);
855
856 let server = match self.server.as_authenticated() {
857 Ok(server) => server,
858 Err(error) => return Task::ready(Err(error)),
859 };
860 let lsp = server.lsp.clone();
861 let registered_buffer = server
862 .registered_buffers
863 .get_mut(&buffer.entity_id())
864 .unwrap();
865 let snapshot = registered_buffer.report_changes(buffer, cx);
866 let buffer = buffer.read(cx);
867 let uri = registered_buffer.uri.clone();
868 let position = position.to_point_utf16(buffer);
869 let settings = language_settings(
870 buffer.language_at(position).map(|l| l.name()),
871 buffer.file(),
872 cx,
873 );
874 let tab_size = settings.tab_size;
875 let hard_tabs = settings.hard_tabs;
876 let relative_path = buffer
877 .file()
878 .map(|file| file.path().to_path_buf())
879 .unwrap_or_default();
880
881 cx.background_executor().spawn(async move {
882 let (version, snapshot) = snapshot.await?;
883 let result = lsp
884 .request::<R>(request::GetCompletionsParams {
885 doc: request::GetCompletionsDocument {
886 uri,
887 tab_size: tab_size.into(),
888 indent_size: 1,
889 insert_spaces: !hard_tabs,
890 relative_path: relative_path.to_string_lossy().into(),
891 position: point_to_lsp(position),
892 version: version.try_into().unwrap(),
893 },
894 })
895 .await?;
896 let completions = result
897 .completions
898 .into_iter()
899 .map(|completion| {
900 let start = snapshot
901 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
902 let end =
903 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
904 Completion {
905 uuid: completion.uuid,
906 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
907 text: completion.text,
908 }
909 })
910 .collect();
911 anyhow::Ok(completions)
912 })
913 }
914
915 pub fn status(&self) -> Status {
916 match &self.server {
917 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
918 CopilotServer::Disabled => Status::Disabled,
919 CopilotServer::Error(error) => Status::Error(error.clone()),
920 CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => {
921 match sign_in_status {
922 SignInStatus::Authorized { .. } => Status::Authorized,
923 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
924 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
925 prompt: prompt.clone(),
926 },
927 SignInStatus::SignedOut => Status::SignedOut,
928 }
929 }
930 }
931 }
932
933 fn update_sign_in_status(
934 &mut self,
935 lsp_status: request::SignInStatus,
936 cx: &mut ModelContext<Self>,
937 ) {
938 self.buffers.retain(|buffer| buffer.is_upgradable());
939
940 if let Ok(server) = self.server.as_running() {
941 match lsp_status {
942 request::SignInStatus::Ok { user: Some(_) }
943 | request::SignInStatus::MaybeOk { .. }
944 | request::SignInStatus::AlreadySignedIn { .. } => {
945 server.sign_in_status = SignInStatus::Authorized;
946 cx.emit(Event::CopilotAuthSignedIn);
947 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
948 if let Some(buffer) = buffer.upgrade() {
949 self.register_buffer(&buffer, cx);
950 }
951 }
952 }
953 request::SignInStatus::NotAuthorized { .. } => {
954 server.sign_in_status = SignInStatus::Unauthorized;
955 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
956 self.unregister_buffer(&buffer);
957 }
958 }
959 request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
960 server.sign_in_status = SignInStatus::SignedOut;
961 cx.emit(Event::CopilotAuthSignedOut);
962 for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
963 self.unregister_buffer(&buffer);
964 }
965 }
966 }
967
968 cx.notify();
969 }
970 }
971}
972
973fn id_for_language(language: Option<&Arc<Language>>) -> String {
974 language
975 .map(|language| language.lsp_id())
976 .unwrap_or_else(|| "plaintext".to_string())
977}
978
979fn uri_for_buffer(buffer: &Model<Buffer>, cx: &AppContext) -> lsp::Url {
980 if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
981 lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
982 } else {
983 format!("buffer://{}", buffer.entity_id()).parse().unwrap()
984 }
985}
986
987async fn clear_copilot_dir() {
988 remove_matching(paths::copilot_dir(), |_| true).await
989}
990
/// Returns the path of the Copilot language server script, downloading it if necessary.
///
/// The `zed-industries/copilot` release tagged `v0.7.0` is used, stored under
/// `<copilot_dir>/copilot-<tag>`. If fetching or unpacking fails, the error is logged and a
/// previously downloaded copy in the Copilot directory is used instead, if one exists.
///
/// # Errors
///
/// Fails when the download fails and no cached copy with the server script is found.
async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
    // Location of the server script inside a version directory.
    const SERVER_PATH: &str = "dist/language-server.js";

    /// Ensures the pinned Copilot release is present locally and returns its server script.
    ///
    /// If the script is missing, the release's first asset is downloaded and unpacked as a
    /// gzip-compressed tarball into `dist/`. Every other entry in the Copilot directory is
    /// then removed.
    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
        let release =
            get_release_by_tag_name("zed-industries/copilot", "v0.7.0", http.clone()).await?;

        let version_dir = &paths::copilot_dir().join(format!("copilot-{}", release.tag_name));

        fs::create_dir_all(version_dir).await?;
        let server_path = version_dir.join(SERVER_PATH);

        if fs::metadata(&server_path).await.is_err() {
            // Copilot LSP looks for this dist dir specifically, so lets add it in.
            let dist_dir = version_dir.join("dist");
            fs::create_dir_all(dist_dir.as_path()).await?;

            let url = &release
                .assets
                .first()
                .context("Github release for copilot contained no assets")?
                .browser_download_url;

            let mut response = http
                .get(url, Default::default(), true)
                .await
                .context("error downloading copilot release")?;
            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
            let archive = Archive::new(decompressed_bytes);
            archive.unpack(dist_dir).await?;

            // Drop any other (older) versions now that this one is installed.
            remove_matching(paths::copilot_dir(), |entry| entry != version_dir).await;
        }

        Ok(server_path)
    }

    match fetch_latest(http).await {
        ok @ Result::Ok(..) => ok,
        e @ Err(..) => {
            e.log_err();
            // Fetch a cached binary, if it exists
            // NOTE(review): `read_dir` enumeration order is not guaranteed, so "last" is
            // whichever directory happens to be enumerated last, not necessarily the newest.
            maybe!(async {
                let mut last_version_dir = None;
                let mut entries = fs::read_dir(paths::copilot_dir()).await?;
                while let Some(entry) = entries.next().await {
                    let entry = entry?;
                    if entry.file_type().await?.is_dir() {
                        last_version_dir = Some(entry.path());
                    }
                }
                let last_version_dir =
                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
                let server_path = last_version_dir.join(SERVER_PATH);
                if server_path.exists() {
                    Ok(server_path)
                } else {
                    Err(anyhow!(
                        "missing executable in directory {:?}",
                        last_version_dir
                    ))
                }
            })
            .await
        }
    }
}
1059
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    /// Walks the buffer lifecycle against a fake, already-authorized Copilot server.
    ///
    /// It checks, in order:
    /// - open notifications;
    /// - incremental change notifications;
    /// - close/reopen when the file handle changes;
    /// - close on sign-out and reopen on sign-in;
    /// - close when a buffer is dropped.
    #[gpui::test(iterations = 10)]
    async fn test_buffer_management(cx: &mut TestAppContext) {
        let (copilot, mut lsp) = Copilot::fake(cx);

        let buffer_1 = cx.new_model(|cx| Buffer::local("Hello", cx));
        let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello".into()
                ),
            }
        );

        let buffer_2 = cx.new_model(|cx| Buffer::local("Goodbye", cx));
        let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64())
            .parse()
            .unwrap();
        copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );

        // An edit is reported incrementally and bumps the document version to 1.
        buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidChangeTextDocument>()
                .await,
            lsp::DidChangeTextDocumentParams {
                text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1),
                content_changes: vec![lsp::TextDocumentContentChangeEvent {
                    range: Some(lsp::Range::new(
                        lsp::Position::new(0, 5),
                        lsp::Position::new(0, 5)
                    )),
                    range_length: None,
                    text: " world".into(),
                }],
            }
        );

        // Ensure updates to the file are reflected in the LSP.
        buffer_1.update(cx, |buffer, cx| {
            buffer.file_updated(
                Arc::new(File {
                    abs_path: "/root/child/buffer-1".into(),
                    path: Path::new("child/buffer-1").into(),
                }),
                cx,
            )
        });
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri),
            }
        );
        let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    1,
                    "Hello world".into()
                ),
            }
        );

        // Ensure all previously-registered buffers are closed when signing out.
        lsp.handle_request::<request::SignOut, _, _>(|_, _| async {
            Ok(request::SignOutResult {})
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_out(cx))
            .await
            .unwrap();
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()),
            }
        );

        // Ensure all previously-registered buffers are re-opened when signing in.
        lsp.handle_request::<request::SignInInitiate, _, _>(|_, _| async {
            Ok(request::SignInInitiateResult::AlreadySignedIn {
                user: "user-1".into(),
            })
        });
        copilot
            .update(cx, |copilot, cx| copilot.sign_in(cx))
            .await
            .unwrap();

        // Re-registration starts over at version 0.
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_1_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Hello world".into()
                ),
            }
        );
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidOpenTextDocument>()
                .await,
            lsp::DidOpenTextDocumentParams {
                text_document: lsp::TextDocumentItem::new(
                    buffer_2_uri.clone(),
                    "plaintext".into(),
                    0,
                    "Goodbye".into()
                ),
            }
        );
        // Dropping a buffer causes it to be closed on the LSP side as well.
        cx.update(|_| drop(buffer_2));
        assert_eq!(
            lsp.receive_notification::<lsp::notification::DidCloseTextDocument>()
                .await,
            lsp::DidCloseTextDocumentParams {
                text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri),
            }
        );
    }

    /// Minimal local file used to give a test buffer a file handle.
    struct File {
        abs_path: PathBuf,
        path: Arc<Path>,
    }

    impl language::File for File {
        fn as_local(&self) -> Option<&dyn language::LocalFile> {
            Some(self)
        }

        fn disk_state(&self) -> language::DiskState {
            language::DiskState::Present {
                mtime: ::fs::MTime::from_seconds_and_nanos(100, 42),
            }
        }

        fn path(&self) -> &Arc<Path> {
            &self.path
        }

        fn full_path(&self, _: &AppContext) -> PathBuf {
            unimplemented!()
        }

        fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            unimplemented!()
        }

        fn to_proto(&self, _: &AppContext) -> rpc::proto::File {
            unimplemented!()
        }

        fn worktree_id(&self, _: &AppContext) -> settings::WorktreeId {
            settings::WorktreeId::from_usize(0)
        }

        fn is_private(&self) -> bool {
            false
        }
    }

    impl language::LocalFile for File {
        fn abs_path(&self, _: &AppContext) -> PathBuf {
            self.abs_path.clone()
        }

        fn load(&self, _: &AppContext) -> Task<Result<String>> {
            unimplemented!()
        }

        fn load_bytes(&self, _cx: &AppContext) -> Task<Result<Vec<u8>>> {
            unimplemented!()
        }
    }
}