1pub mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use client::Client;
8use collections::HashMap;
9use futures::{future::Shared, Future, FutureExt, TryFutureExt};
10use gpui::{
11 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
12 Task,
13};
14use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
15use log::{debug, error};
16use lsp::LanguageServer;
17use node_runtime::NodeRuntime;
18use request::{LogMessage, StatusNotification};
19use settings::Settings;
20use smol::{fs, io::BufReader, stream::StreamExt};
21use staff_mode::{not_staff_mode, staff_mode};
22
23use std::{
24 ffi::OsString,
25 ops::Range,
26 path::{Path, PathBuf},
27 sync::Arc,
28};
29use util::{
30 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
31};
32
33const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
34actions!(copilot_auth, [SignIn, SignOut]);
35
36const COPILOT_NAMESPACE: &'static str = "copilot";
37actions!(copilot, [NextSuggestion, PreviousSuggestion, Reinstall]);
38
39pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
40 staff_mode(cx, {
41 move |cx| {
42 cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
43 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
44 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
45 });
46
47 let copilot = cx.add_model({
48 let node_runtime = node_runtime.clone();
49 let http = client.http_client().clone();
50 move |cx| Copilot::start(http, node_runtime, cx)
51 });
52 cx.set_global(copilot.clone());
53
54 observe_namespaces(cx, copilot);
55
56 sign_in::init(cx);
57 }
58 });
59 not_staff_mode(cx, |cx| {
60 cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
61 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
62 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
63 });
64 });
65
66 cx.add_global_action(|_: &SignIn, cx| {
67 if let Some(copilot) = Copilot::global(cx) {
68 copilot
69 .update(cx, |copilot, cx| copilot.sign_in(cx))
70 .detach_and_log_err(cx);
71 }
72 });
73 cx.add_global_action(|_: &SignOut, cx| {
74 if let Some(copilot) = Copilot::global(cx) {
75 copilot
76 .update(cx, |copilot, cx| copilot.sign_out(cx))
77 .detach_and_log_err(cx);
78 }
79 });
80
81 cx.add_global_action(|_: &Reinstall, cx| {
82 if let Some(copilot) = Copilot::global(cx) {
83 copilot
84 .update(cx, |copilot, cx| copilot.reinstall(cx))
85 .detach();
86 }
87 });
88}
89
90fn observe_namespaces(cx: &mut MutableAppContext, copilot: ModelHandle<Copilot>) {
91 cx.observe(&copilot, |handle, cx| {
92 let status = handle.read(cx).status();
93 cx.update_global::<collections::CommandPaletteFilter, _, _>(
94 move |filter, _cx| match status {
95 Status::Disabled => {
96 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
97 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
98 }
99 Status::Authorized => {
100 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
101 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
102 }
103 _ => {
104 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
105 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
106 }
107 },
108 );
109 })
110 .detach();
111}
112
113enum CopilotServer {
114 Disabled,
115 Starting {
116 task: Shared<Task<()>>,
117 },
118 Error(Arc<str>),
119 Started {
120 server: Arc<LanguageServer>,
121 status: SignInStatus,
122 subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
123 },
124}
125
126#[derive(Clone, Debug)]
127enum SignInStatus {
128 Authorized,
129 Unauthorized,
130 SigningIn {
131 prompt: Option<request::PromptUserDeviceFlow>,
132 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
133 },
134 SignedOut,
135}
136
137#[derive(Debug, Clone)]
138pub enum Status {
139 Starting {
140 task: Shared<Task<()>>,
141 },
142 Error(Arc<str>),
143 Disabled,
144 SignedOut,
145 SigningIn {
146 prompt: Option<request::PromptUserDeviceFlow>,
147 },
148 Unauthorized,
149 Authorized,
150}
151
152impl Status {
153 pub fn is_authorized(&self) -> bool {
154 matches!(self, Status::Authorized)
155 }
156}
157
/// A single Copilot suggestion, anchored into the buffer it was requested for.
#[derive(Debug, PartialEq, Eq)]
pub struct Completion {
    /// Buffer range the suggestion applies to. The endpoints are clipped to the
    /// snapshot the request was made against.
    pub range: Range<Anchor>,
    /// Suggested text. Presumably it replaces `range`; confirm with callers.
    pub text: String,
}
163
/// Model that owns the Copilot language server and tracks its startup and
/// sign-in state. In staff mode, `init` registers it as a global
/// `ModelHandle<Copilot>`.
pub struct Copilot {
    // Used to download the Copilot server release (see `get_copilot_lsp`).
    http: Arc<dyn HttpClient>,
    // Supplies the Node binary the server is launched with.
    node_runtime: Arc<NodeRuntime>,
    // Current lifecycle state of the server.
    server: CopilotServer,
}
169
// `Copilot` emits no custom events; observers are woken through `cx.notify()`.
impl Entity for Copilot {
    type Event = ();
}
173
174impl Copilot {
175 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
176 if cx.has_global::<ModelHandle<Self>>() {
177 Some(cx.global::<ModelHandle<Self>>().clone())
178 } else {
179 None
180 }
181 }
182
183 fn start(
184 http: Arc<dyn HttpClient>,
185 node_runtime: Arc<NodeRuntime>,
186 cx: &mut ModelContext<Self>,
187 ) -> Self {
188 cx.observe_global::<Settings, _>({
189 let http = http.clone();
190 let node_runtime = node_runtime.clone();
191 move |this, cx| {
192 if cx.global::<Settings>().enable_copilot_integration {
193 if matches!(this.server, CopilotServer::Disabled) {
194 let start_task = cx
195 .spawn({
196 let http = http.clone();
197 let node_runtime = node_runtime.clone();
198 move |this, cx| {
199 Self::start_language_server(http, node_runtime, this, cx)
200 }
201 })
202 .shared();
203 this.server = CopilotServer::Starting { task: start_task };
204 cx.notify();
205 }
206 } else {
207 this.server = CopilotServer::Disabled;
208 cx.notify();
209 }
210 }
211 })
212 .detach();
213
214 if cx.global::<Settings>().enable_copilot_integration {
215 let start_task = cx
216 .spawn({
217 let http = http.clone();
218 let node_runtime = node_runtime.clone();
219 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
220 })
221 .shared();
222
223 Self {
224 http,
225 node_runtime,
226 server: CopilotServer::Starting { task: start_task },
227 }
228 } else {
229 Self {
230 http,
231 node_runtime,
232 server: CopilotServer::Disabled,
233 }
234 }
235 }
236
237 #[cfg(any(test, feature = "test-support"))]
238 pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
239 let (server, fake_server) =
240 LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
241 let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
242 let this = cx.add_model(|cx| Self {
243 http: http.clone(),
244 node_runtime: NodeRuntime::new(http, cx.background().clone()),
245 server: CopilotServer::Started {
246 server: Arc::new(server),
247 status: SignInStatus::Authorized,
248 subscriptions_by_buffer_id: Default::default(),
249 },
250 });
251 (this, fake_server)
252 }
253
254 fn start_language_server(
255 http: Arc<dyn HttpClient>,
256 node_runtime: Arc<NodeRuntime>,
257 this: ModelHandle<Self>,
258 mut cx: AsyncAppContext,
259 ) -> impl Future<Output = ()> {
260 async move {
261 let start_language_server = async {
262 let server_path = get_copilot_lsp(http).await?;
263 let node_path = node_runtime.binary_path().await?;
264 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
265 let server = LanguageServer::new(
266 0,
267 &node_path,
268 arguments,
269 Path::new("/"),
270 None,
271 cx.clone(),
272 )?;
273
274 let server = server.initialize(Default::default()).await?;
275 let status = server
276 .request::<request::CheckStatus>(request::CheckStatusParams {
277 local_checks_only: false,
278 })
279 .await?;
280
281 server
282 .on_notification::<LogMessage, _>(|params, _cx| {
283 match params.level {
284 // Copilot is pretty agressive about logging
285 0 => debug!("copilot: {}", params.message),
286 1 => debug!("copilot: {}", params.message),
287 _ => error!("copilot: {}", params.message),
288 }
289
290 debug!("copilot metadata: {}", params.metadata_str);
291 debug!("copilot extra: {:?}", params.extra);
292 })
293 .detach();
294
295 server
296 .on_notification::<StatusNotification, _>(
297 |_, _| { /* Silence the notification */ },
298 )
299 .detach();
300
301 anyhow::Ok((server, status))
302 };
303
304 let server = start_language_server.await;
305 this.update(&mut cx, |this, cx| {
306 cx.notify();
307 match server {
308 Ok((server, status)) => {
309 this.server = CopilotServer::Started {
310 server,
311 status: SignInStatus::SignedOut,
312 subscriptions_by_buffer_id: Default::default(),
313 };
314 this.update_sign_in_status(status, cx);
315 }
316 Err(error) => {
317 this.server = CopilotServer::Error(error.to_string().into());
318 cx.notify()
319 }
320 }
321 })
322 }
323 }
324
325 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
326 if let CopilotServer::Started { server, status, .. } = &mut self.server {
327 let task = match status {
328 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
329 Task::ready(Ok(())).shared()
330 }
331 SignInStatus::SigningIn { task, .. } => {
332 cx.notify();
333 task.clone()
334 }
335 SignInStatus::SignedOut => {
336 let server = server.clone();
337 let task = cx
338 .spawn(|this, mut cx| async move {
339 let sign_in = async {
340 let sign_in = server
341 .request::<request::SignInInitiate>(
342 request::SignInInitiateParams {},
343 )
344 .await?;
345 match sign_in {
346 request::SignInInitiateResult::AlreadySignedIn { user } => {
347 Ok(request::SignInStatus::Ok { user })
348 }
349 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
350 this.update(&mut cx, |this, cx| {
351 if let CopilotServer::Started { status, .. } =
352 &mut this.server
353 {
354 if let SignInStatus::SigningIn {
355 prompt: prompt_flow,
356 ..
357 } = status
358 {
359 *prompt_flow = Some(flow.clone());
360 cx.notify();
361 }
362 }
363 });
364 let response = server
365 .request::<request::SignInConfirm>(
366 request::SignInConfirmParams {
367 user_code: flow.user_code,
368 },
369 )
370 .await?;
371 Ok(response)
372 }
373 }
374 };
375
376 let sign_in = sign_in.await;
377 this.update(&mut cx, |this, cx| match sign_in {
378 Ok(status) => {
379 this.update_sign_in_status(status, cx);
380 Ok(())
381 }
382 Err(error) => {
383 this.update_sign_in_status(
384 request::SignInStatus::NotSignedIn,
385 cx,
386 );
387 Err(Arc::new(error))
388 }
389 })
390 })
391 .shared();
392 *status = SignInStatus::SigningIn {
393 prompt: None,
394 task: task.clone(),
395 };
396 cx.notify();
397 task
398 }
399 };
400
401 cx.foreground()
402 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
403 } else {
404 // If we're downloading, wait until download is finished
405 // If we're in a stuck state, display to the user
406 Task::ready(Err(anyhow!("copilot hasn't started yet")))
407 }
408 }
409
410 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
411 if let CopilotServer::Started { server, status, .. } = &mut self.server {
412 *status = SignInStatus::SignedOut;
413 cx.notify();
414
415 let server = server.clone();
416 cx.background().spawn(async move {
417 server
418 .request::<request::SignOut>(request::SignOutParams {})
419 .await?;
420 anyhow::Ok(())
421 })
422 } else {
423 Task::ready(Err(anyhow!("copilot hasn't started yet")))
424 }
425 }
426
427 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
428 let start_task = cx
429 .spawn({
430 let http = self.http.clone();
431 let node_runtime = self.node_runtime.clone();
432 move |this, cx| async move {
433 clear_copilot_dir().await;
434 Self::start_language_server(http, node_runtime, this, cx).await
435 }
436 })
437 .shared();
438
439 self.server = CopilotServer::Starting {
440 task: start_task.clone(),
441 };
442
443 cx.notify();
444
445 cx.foreground().spawn(start_task)
446 }
447
448 pub fn completions<T>(
449 &mut self,
450 buffer: &ModelHandle<Buffer>,
451 position: T,
452 cx: &mut ModelContext<Self>,
453 ) -> Task<Result<Vec<Completion>>>
454 where
455 T: ToPointUtf16,
456 {
457 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
458 }
459
460 pub fn completions_cycling<T>(
461 &mut self,
462 buffer: &ModelHandle<Buffer>,
463 position: T,
464 cx: &mut ModelContext<Self>,
465 ) -> Task<Result<Vec<Completion>>>
466 where
467 T: ToPointUtf16,
468 {
469 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
470 }
471
472 fn request_completions<R, T>(
473 &mut self,
474 buffer: &ModelHandle<Buffer>,
475 position: T,
476 cx: &mut ModelContext<Self>,
477 ) -> Task<Result<Vec<Completion>>>
478 where
479 R: lsp::request::Request<
480 Params = request::GetCompletionsParams,
481 Result = request::GetCompletionsResult,
482 >,
483 T: ToPointUtf16,
484 {
485 let buffer_id = buffer.id();
486 let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
487 let snapshot = buffer.read(cx).snapshot();
488 let server = match &mut self.server {
489 CopilotServer::Starting { .. } => {
490 return Task::ready(Err(anyhow!("copilot is still starting")))
491 }
492 CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
493 CopilotServer::Error(error) => {
494 return Task::ready(Err(anyhow!(
495 "copilot was not started because of an error: {}",
496 error
497 )))
498 }
499 CopilotServer::Started {
500 server,
501 status,
502 subscriptions_by_buffer_id,
503 } => {
504 if matches!(status, SignInStatus::Authorized { .. }) {
505 subscriptions_by_buffer_id
506 .entry(buffer_id)
507 .or_insert_with(|| {
508 server
509 .notify::<lsp::notification::DidOpenTextDocument>(
510 lsp::DidOpenTextDocumentParams {
511 text_document: lsp::TextDocumentItem {
512 uri: uri.clone(),
513 language_id: id_for_language(
514 buffer.read(cx).language(),
515 ),
516 version: 0,
517 text: snapshot.text(),
518 },
519 },
520 )
521 .log_err();
522
523 let uri = uri.clone();
524 cx.observe_release(buffer, move |this, _, _| {
525 if let CopilotServer::Started {
526 server,
527 subscriptions_by_buffer_id,
528 ..
529 } = &mut this.server
530 {
531 server
532 .notify::<lsp::notification::DidCloseTextDocument>(
533 lsp::DidCloseTextDocumentParams {
534 text_document: lsp::TextDocumentIdentifier::new(
535 uri.clone(),
536 ),
537 },
538 )
539 .log_err();
540 subscriptions_by_buffer_id.remove(&buffer_id);
541 }
542 })
543 });
544
545 server.clone()
546 } else {
547 return Task::ready(Err(anyhow!("must sign in before using copilot")));
548 }
549 }
550 };
551
552 let settings = cx.global::<Settings>();
553 let position = position.to_point_utf16(&snapshot);
554 let language = snapshot.language_at(position);
555 let language_name = language.map(|language| language.name());
556 let language_name = language_name.as_deref();
557 let tab_size = settings.tab_size(language_name);
558 let hard_tabs = settings.hard_tabs(language_name);
559 let language_id = id_for_language(language);
560
561 let path;
562 let relative_path;
563 if let Some(file) = snapshot.file() {
564 if let Some(file) = file.as_local() {
565 path = file.abs_path(cx);
566 } else {
567 path = file.full_path(cx);
568 }
569 relative_path = file.path().to_path_buf();
570 } else {
571 path = PathBuf::new();
572 relative_path = PathBuf::new();
573 }
574
575 cx.background().spawn(async move {
576 let result = server
577 .request::<R>(request::GetCompletionsParams {
578 doc: request::GetCompletionsDocument {
579 source: snapshot.text(),
580 tab_size: tab_size.into(),
581 indent_size: 1,
582 insert_spaces: !hard_tabs,
583 uri,
584 path: path.to_string_lossy().into(),
585 relative_path: relative_path.to_string_lossy().into(),
586 language_id,
587 position: point_to_lsp(position),
588 version: 0,
589 },
590 })
591 .await?;
592 let completions = result
593 .completions
594 .into_iter()
595 .map(|completion| {
596 let start = snapshot
597 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
598 let end =
599 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
600 Completion {
601 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
602 text: completion.text,
603 }
604 })
605 .collect();
606 anyhow::Ok(completions)
607 })
608 }
609
610 pub fn status(&self) -> Status {
611 match &self.server {
612 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
613 CopilotServer::Disabled => Status::Disabled,
614 CopilotServer::Error(error) => Status::Error(error.clone()),
615 CopilotServer::Started { status, .. } => match status {
616 SignInStatus::Authorized { .. } => Status::Authorized,
617 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
618 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
619 prompt: prompt.clone(),
620 },
621 SignInStatus::SignedOut => Status::SignedOut,
622 },
623 }
624 }
625
626 fn update_sign_in_status(
627 &mut self,
628 lsp_status: request::SignInStatus,
629 cx: &mut ModelContext<Self>,
630 ) {
631 if let CopilotServer::Started { status, .. } = &mut self.server {
632 *status = match lsp_status {
633 request::SignInStatus::Ok { .. }
634 | request::SignInStatus::MaybeOk { .. }
635 | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized,
636 request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized,
637 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
638 };
639 cx.notify();
640 }
641 }
642}
643
644fn id_for_language(language: Option<&Arc<Language>>) -> String {
645 let language_name = language.map(|language| language.name());
646 match language_name.as_deref() {
647 Some("Plain Text") => "plaintext".to_string(),
648 Some(language_name) => language_name.to_lowercase(),
649 None => "plaintext".to_string(),
650 }
651}
652
/// Removes the contents of `paths::COPILOT_DIR`; the predicate accepts every
/// entry. `Copilot::reinstall` uses this before restarting the server so a fresh
/// copy is fetched.
async fn clear_copilot_dir() {
    remove_matching(&paths::COPILOT_DIR, |_| true).await
}
656
657async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
658 const SERVER_PATH: &'static str = "dist/agent.js";
659
660 ///Check for the latest copilot language server and download it if we haven't already
661 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
662 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
663
664 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
665
666 fs::create_dir_all(version_dir).await?;
667 let server_path = version_dir.join(SERVER_PATH);
668
669 if fs::metadata(&server_path).await.is_err() {
670 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
671 let dist_dir = version_dir.join("dist");
672 fs::create_dir_all(dist_dir.as_path()).await?;
673
674 let url = &release
675 .assets
676 .get(0)
677 .context("Github release for copilot contained no assets")?
678 .browser_download_url;
679
680 let mut response = http
681 .get(&url, Default::default(), true)
682 .await
683 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
684 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
685 let archive = Archive::new(decompressed_bytes);
686 archive.unpack(dist_dir).await?;
687
688 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
689 }
690
691 Ok(server_path)
692 }
693
694 match fetch_latest(http).await {
695 ok @ Result::Ok(..) => ok,
696 e @ Err(..) => {
697 e.log_err();
698 // Fetch a cached binary, if it exists
699 (|| async move {
700 let mut last_version_dir = None;
701 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
702 while let Some(entry) = entries.next().await {
703 let entry = entry?;
704 if entry.file_type().await?.is_dir() {
705 last_version_dir = Some(entry.path());
706 }
707 }
708 let last_version_dir =
709 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
710 let server_path = last_version_dir.join(SERVER_PATH);
711 if server_path.exists() {
712 Ok(server_path)
713 } else {
714 Err(anyhow!(
715 "missing executable in directory {:?}",
716 last_version_dir
717 ))
718 }
719 })()
720 .await
721 }
722 }
723}