1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use client::Client;
8use collections::HashMap;
9use futures::{future::Shared, Future, FutureExt, TryFutureExt};
10use gpui::{
11 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
12 Task,
13};
14use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
15use log::{debug, error};
16use lsp::LanguageServer;
17use node_runtime::NodeRuntime;
18use request::{LogMessage, StatusNotification};
19use settings::Settings;
20use smol::{fs, io::BufReader, stream::StreamExt};
21use staff_mode::{not_staff_mode, staff_mode};
22
23use std::{
24 ffi::OsString,
25 ops::Range,
26 path::{Path, PathBuf},
27 sync::Arc,
28};
29use util::{
30 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
31};
32
33const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
34actions!(copilot_auth, [SignIn, SignOut]);
35
36const COPILOT_NAMESPACE: &'static str = "copilot";
37actions!(copilot, [NextSuggestion, PreviousSuggestion, Reinstall]);
38
39pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
40 staff_mode(cx, {
41 move |cx| {
42 cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
43 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
44 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
45 });
46
47 let copilot = cx.add_model({
48 let node_runtime = node_runtime.clone();
49 let http = client.http_client().clone();
50 move |cx| Copilot::start(http, node_runtime, cx)
51 });
52 cx.set_global(copilot.clone());
53
54 observe_namespaces(cx, copilot);
55
56 sign_in::init(cx);
57 }
58 });
59 not_staff_mode(cx, |cx| {
60 cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
61 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
62 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
63 });
64 });
65
66 cx.add_global_action(|_: &SignIn, cx| {
67 if let Some(copilot) = Copilot::global(cx) {
68 copilot
69 .update(cx, |copilot, cx| copilot.sign_in(cx))
70 .detach_and_log_err(cx);
71 }
72 });
73 cx.add_global_action(|_: &SignOut, cx| {
74 if let Some(copilot) = Copilot::global(cx) {
75 copilot
76 .update(cx, |copilot, cx| copilot.sign_out(cx))
77 .detach_and_log_err(cx);
78 }
79 });
80
81 cx.add_global_action(|_: &Reinstall, cx| {
82 if let Some(copilot) = Copilot::global(cx) {
83 copilot
84 .update(cx, |copilot, cx| copilot.reinstall(cx))
85 .detach();
86 }
87 });
88}
89
90fn observe_namespaces(cx: &mut MutableAppContext, copilot: ModelHandle<Copilot>) {
91 cx.observe(&copilot, |handle, cx| {
92 let status = handle.read(cx).status();
93 cx.update_global::<collections::CommandPaletteFilter, _, _>(
94 move |filter, _cx| match status {
95 Status::Disabled => {
96 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
97 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
98 }
99 Status::Authorized => {
100 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
101 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
102 }
103 _ => {
104 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
105 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
106 }
107 },
108 );
109 })
110 .detach();
111}
112
113enum CopilotServer {
114 Disabled,
115 Starting {
116 task: Shared<Task<()>>,
117 },
118 Error(Arc<str>),
119 Started {
120 server: Arc<LanguageServer>,
121 status: SignInStatus,
122 subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
123 },
124}
125
126#[derive(Clone, Debug)]
127enum SignInStatus {
128 Authorized {
129 _user: String,
130 },
131 Unauthorized {
132 _user: String,
133 },
134 SigningIn {
135 prompt: Option<request::PromptUserDeviceFlow>,
136 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
137 },
138 SignedOut,
139}
140
141#[derive(Debug, Clone)]
142pub enum Status {
143 Starting {
144 task: Shared<Task<()>>,
145 },
146 Error(Arc<str>),
147 Disabled,
148 SignedOut,
149 SigningIn {
150 prompt: Option<request::PromptUserDeviceFlow>,
151 },
152 Unauthorized,
153 Authorized,
154}
155
156impl Status {
157 pub fn is_authorized(&self) -> bool {
158 matches!(self, Status::Authorized)
159 }
160}
161
162#[derive(Debug, PartialEq, Eq)]
163pub struct Completion {
164 pub range: Range<Anchor>,
165 pub text: String,
166}
167
168pub struct Copilot {
169 http: Arc<dyn HttpClient>,
170 node_runtime: Arc<NodeRuntime>,
171 server: CopilotServer,
172}
173
174impl Entity for Copilot {
175 type Event = ();
176}
177
178impl Copilot {
179 pub fn starting_task(&self) -> Option<Shared<Task<()>>> {
180 match self.server {
181 CopilotServer::Starting { ref task } => Some(task.clone()),
182 _ => None,
183 }
184 }
185
186 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
187 if cx.has_global::<ModelHandle<Self>>() {
188 Some(cx.global::<ModelHandle<Self>>().clone())
189 } else {
190 None
191 }
192 }
193
194 fn start(
195 http: Arc<dyn HttpClient>,
196 node_runtime: Arc<NodeRuntime>,
197 cx: &mut ModelContext<Self>,
198 ) -> Self {
199 cx.observe_global::<Settings, _>({
200 let http = http.clone();
201 let node_runtime = node_runtime.clone();
202 move |this, cx| {
203 if cx.global::<Settings>().enable_copilot_integration {
204 if matches!(this.server, CopilotServer::Disabled) {
205 let start_task = cx
206 .spawn({
207 let http = http.clone();
208 let node_runtime = node_runtime.clone();
209 move |this, cx| {
210 Self::start_language_server(http, node_runtime, this, cx)
211 }
212 })
213 .shared();
214 this.server = CopilotServer::Starting { task: start_task };
215 cx.notify();
216 }
217 } else {
218 this.server = CopilotServer::Disabled;
219 cx.notify();
220 }
221 }
222 })
223 .detach();
224
225 if cx.global::<Settings>().enable_copilot_integration {
226 let start_task = cx
227 .spawn({
228 let http = http.clone();
229 let node_runtime = node_runtime.clone();
230 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
231 })
232 .shared();
233
234 Self {
235 http,
236 node_runtime,
237 server: CopilotServer::Starting { task: start_task },
238 }
239 } else {
240 Self {
241 http,
242 node_runtime,
243 server: CopilotServer::Disabled,
244 }
245 }
246 }
247
248 fn start_language_server(
249 http: Arc<dyn HttpClient>,
250 node_runtime: Arc<NodeRuntime>,
251 this: ModelHandle<Self>,
252 mut cx: AsyncAppContext,
253 ) -> impl Future<Output = ()> {
254 async move {
255 let start_language_server = async {
256 let server_path = get_copilot_lsp(http).await?;
257 let node_path = node_runtime.binary_path().await?;
258 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
259 let server = LanguageServer::new(
260 0,
261 &node_path,
262 arguments,
263 Path::new("/"),
264 None,
265 cx.clone(),
266 )?;
267
268 let server = server.initialize(Default::default()).await?;
269 let status = server
270 .request::<request::CheckStatus>(request::CheckStatusParams {
271 local_checks_only: false,
272 })
273 .await?;
274
275 server
276 .on_notification::<LogMessage, _>(|params, _cx| {
277 match params.level {
278 // Copilot is pretty agressive about logging
279 0 => debug!("copilot: {}", params.message),
280 1 => debug!("copilot: {}", params.message),
281 _ => error!("copilot: {}", params.message),
282 }
283
284 debug!("copilot metadata: {}", params.metadata_str);
285 debug!("copilot extra: {:?}", params.extra);
286 })
287 .detach();
288
289 server
290 .on_notification::<StatusNotification, _>(
291 |_, _| { /* Silence the notification */ },
292 )
293 .detach();
294
295 anyhow::Ok((server, status))
296 };
297
298 let server = start_language_server.await;
299 this.update(&mut cx, |this, cx| {
300 cx.notify();
301 match server {
302 Ok((server, status)) => {
303 this.server = CopilotServer::Started {
304 server,
305 status: SignInStatus::SignedOut,
306 subscriptions_by_buffer_id: Default::default(),
307 };
308 this.update_sign_in_status(status, cx);
309 }
310 Err(error) => {
311 this.server = CopilotServer::Error(error.to_string().into());
312 cx.notify()
313 }
314 }
315 })
316 }
317 }
318
319 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
320 if let CopilotServer::Started { server, status, .. } = &mut self.server {
321 let task = match status {
322 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
323 Task::ready(Ok(())).shared()
324 }
325 SignInStatus::SigningIn { task, .. } => {
326 cx.notify();
327 task.clone()
328 }
329 SignInStatus::SignedOut => {
330 let server = server.clone();
331 let task = cx
332 .spawn(|this, mut cx| async move {
333 let sign_in = async {
334 let sign_in = server
335 .request::<request::SignInInitiate>(
336 request::SignInInitiateParams {},
337 )
338 .await?;
339 match sign_in {
340 request::SignInInitiateResult::AlreadySignedIn { user } => {
341 Ok(request::SignInStatus::Ok { user })
342 }
343 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
344 this.update(&mut cx, |this, cx| {
345 if let CopilotServer::Started { status, .. } =
346 &mut this.server
347 {
348 if let SignInStatus::SigningIn {
349 prompt: prompt_flow,
350 ..
351 } = status
352 {
353 *prompt_flow = Some(flow.clone());
354 cx.notify();
355 }
356 }
357 });
358 let response = server
359 .request::<request::SignInConfirm>(
360 request::SignInConfirmParams {
361 user_code: flow.user_code,
362 },
363 )
364 .await?;
365 Ok(response)
366 }
367 }
368 };
369
370 let sign_in = sign_in.await;
371 this.update(&mut cx, |this, cx| match sign_in {
372 Ok(status) => {
373 this.update_sign_in_status(status, cx);
374 Ok(())
375 }
376 Err(error) => {
377 this.update_sign_in_status(
378 request::SignInStatus::NotSignedIn,
379 cx,
380 );
381 Err(Arc::new(error))
382 }
383 })
384 })
385 .shared();
386 *status = SignInStatus::SigningIn {
387 prompt: None,
388 task: task.clone(),
389 };
390 cx.notify();
391 task
392 }
393 };
394
395 cx.foreground()
396 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
397 } else {
398 // If we're downloading, wait until download is finished
399 // If we're in a stuck state, display to the user
400 Task::ready(Err(anyhow!("copilot hasn't started yet")))
401 }
402 }
403
404 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
405 if let CopilotServer::Started { server, status, .. } = &mut self.server {
406 *status = SignInStatus::SignedOut;
407 cx.notify();
408
409 let server = server.clone();
410 cx.background().spawn(async move {
411 server
412 .request::<request::SignOut>(request::SignOutParams {})
413 .await?;
414 anyhow::Ok(())
415 })
416 } else {
417 Task::ready(Err(anyhow!("copilot hasn't started yet")))
418 }
419 }
420
421 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
422 let start_task = cx
423 .spawn({
424 let http = self.http.clone();
425 let node_runtime = self.node_runtime.clone();
426 move |this, cx| async move {
427 clear_copilot_dir().await;
428 Self::start_language_server(http, node_runtime, this, cx).await
429 }
430 })
431 .shared();
432
433 self.server = CopilotServer::Starting {
434 task: start_task.clone(),
435 };
436
437 cx.notify();
438
439 cx.foreground().spawn(start_task)
440 }
441
442 pub fn completions<T>(
443 &mut self,
444 buffer: &ModelHandle<Buffer>,
445 position: T,
446 cx: &mut ModelContext<Self>,
447 ) -> Task<Result<Vec<Completion>>>
448 where
449 T: ToPointUtf16,
450 {
451 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
452 }
453
454 pub fn completions_cycling<T>(
455 &mut self,
456 buffer: &ModelHandle<Buffer>,
457 position: T,
458 cx: &mut ModelContext<Self>,
459 ) -> Task<Result<Vec<Completion>>>
460 where
461 T: ToPointUtf16,
462 {
463 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
464 }
465
466 fn request_completions<R, T>(
467 &mut self,
468 buffer: &ModelHandle<Buffer>,
469 position: T,
470 cx: &mut ModelContext<Self>,
471 ) -> Task<Result<Vec<Completion>>>
472 where
473 R: lsp::request::Request<
474 Params = request::GetCompletionsParams,
475 Result = request::GetCompletionsResult,
476 >,
477 T: ToPointUtf16,
478 {
479 let buffer_id = buffer.id();
480 let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
481 let snapshot = buffer.read(cx).snapshot();
482 let server = match &mut self.server {
483 CopilotServer::Starting { .. } => {
484 return Task::ready(Err(anyhow!("copilot is still starting")))
485 }
486 CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
487 CopilotServer::Error(error) => {
488 return Task::ready(Err(anyhow!(
489 "copilot was not started because of an error: {}",
490 error
491 )))
492 }
493 CopilotServer::Started {
494 server,
495 status,
496 subscriptions_by_buffer_id,
497 } => {
498 if matches!(status, SignInStatus::Authorized { .. }) {
499 subscriptions_by_buffer_id
500 .entry(buffer_id)
501 .or_insert_with(|| {
502 server
503 .notify::<lsp::notification::DidOpenTextDocument>(
504 lsp::DidOpenTextDocumentParams {
505 text_document: lsp::TextDocumentItem {
506 uri: uri.clone(),
507 language_id: id_for_language(
508 buffer.read(cx).language(),
509 ),
510 version: 0,
511 text: snapshot.text(),
512 },
513 },
514 )
515 .log_err();
516
517 let uri = uri.clone();
518 cx.observe_release(buffer, move |this, _, _| {
519 if let CopilotServer::Started {
520 server,
521 subscriptions_by_buffer_id,
522 ..
523 } = &mut this.server
524 {
525 server
526 .notify::<lsp::notification::DidCloseTextDocument>(
527 lsp::DidCloseTextDocumentParams {
528 text_document: lsp::TextDocumentIdentifier::new(
529 uri.clone(),
530 ),
531 },
532 )
533 .log_err();
534 subscriptions_by_buffer_id.remove(&buffer_id);
535 }
536 })
537 });
538
539 server.clone()
540 } else {
541 return Task::ready(Err(anyhow!("must sign in before using copilot")));
542 }
543 }
544 };
545
546 let settings = cx.global::<Settings>();
547 let position = position.to_point_utf16(&snapshot);
548 let language = snapshot.language_at(position);
549 let language_name = language.map(|language| language.name());
550 let language_name = language_name.as_deref();
551 let tab_size = settings.tab_size(language_name);
552 let hard_tabs = settings.hard_tabs(language_name);
553 let language_id = id_for_language(language);
554
555 let path;
556 let relative_path;
557 if let Some(file) = snapshot.file() {
558 if let Some(file) = file.as_local() {
559 path = file.abs_path(cx);
560 } else {
561 path = file.full_path(cx);
562 }
563 relative_path = file.path().to_path_buf();
564 } else {
565 path = PathBuf::new();
566 relative_path = PathBuf::new();
567 }
568
569 cx.background().spawn(async move {
570 let result = server
571 .request::<R>(request::GetCompletionsParams {
572 doc: request::GetCompletionsDocument {
573 source: snapshot.text(),
574 tab_size: tab_size.into(),
575 indent_size: 1,
576 insert_spaces: !hard_tabs,
577 uri,
578 path: path.to_string_lossy().into(),
579 relative_path: relative_path.to_string_lossy().into(),
580 language_id,
581 position: point_to_lsp(position),
582 version: 0,
583 },
584 })
585 .await?;
586 let completions = result
587 .completions
588 .into_iter()
589 .map(|completion| {
590 let start = snapshot
591 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
592 let end =
593 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
594 Completion {
595 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
596 text: completion.text,
597 }
598 })
599 .collect();
600 anyhow::Ok(completions)
601 })
602 }
603
604 pub fn status(&self) -> Status {
605 match &self.server {
606 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
607 CopilotServer::Disabled => Status::Disabled,
608 CopilotServer::Error(error) => Status::Error(error.clone()),
609 CopilotServer::Started { status, .. } => match status {
610 SignInStatus::Authorized { .. } => Status::Authorized,
611 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
612 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
613 prompt: prompt.clone(),
614 },
615 SignInStatus::SignedOut => Status::SignedOut,
616 },
617 }
618 }
619
620 fn update_sign_in_status(
621 &mut self,
622 lsp_status: request::SignInStatus,
623 cx: &mut ModelContext<Self>,
624 ) {
625 if let CopilotServer::Started { status, .. } = &mut self.server {
626 *status = match lsp_status {
627 request::SignInStatus::Ok { user }
628 | request::SignInStatus::MaybeOk { user }
629 | request::SignInStatus::AlreadySignedIn { user } => {
630 SignInStatus::Authorized { _user: user }
631 }
632 request::SignInStatus::NotAuthorized { user } => {
633 SignInStatus::Unauthorized { _user: user }
634 }
635 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
636 };
637 cx.notify();
638 }
639 }
640}
641
642fn id_for_language(language: Option<&Arc<Language>>) -> String {
643 let language_name = language.map(|language| language.name());
644 match language_name.as_deref() {
645 Some("Plain Text") => "plaintext".to_string(),
646 Some(language_name) => language_name.to_lowercase(),
647 None => "plaintext".to_string(),
648 }
649}
650
651async fn clear_copilot_dir() {
652 remove_matching(&paths::COPILOT_DIR, |_| true).await
653}
654
655async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
656 const SERVER_PATH: &'static str = "dist/agent.js";
657
658 ///Check for the latest copilot language server and download it if we haven't already
659 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
660 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
661
662 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
663
664 fs::create_dir_all(version_dir).await?;
665 let server_path = version_dir.join(SERVER_PATH);
666
667 if fs::metadata(&server_path).await.is_err() {
668 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
669 let dist_dir = version_dir.join("dist");
670 fs::create_dir_all(dist_dir.as_path()).await?;
671
672 let url = &release
673 .assets
674 .get(0)
675 .context("Github release for copilot contained no assets")?
676 .browser_download_url;
677
678 let mut response = http
679 .get(&url, Default::default(), true)
680 .await
681 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
682 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
683 let archive = Archive::new(decompressed_bytes);
684 archive.unpack(dist_dir).await?;
685
686 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
687 }
688
689 Ok(server_path)
690 }
691
692 match fetch_latest(http).await {
693 ok @ Result::Ok(..) => ok,
694 e @ Err(..) => {
695 e.log_err();
696 // Fetch a cached binary, if it exists
697 (|| async move {
698 let mut last_version_dir = None;
699 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
700 while let Some(entry) = entries.next().await {
701 let entry = entry?;
702 if entry.file_type().await?.is_dir() {
703 last_version_dir = Some(entry.path());
704 }
705 }
706 let last_version_dir =
707 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
708 let server_path = last_version_dir.join(SERVER_PATH);
709 if server_path.exists() {
710 Ok(server_path)
711 } else {
712 Err(anyhow!(
713 "missing executable in directory {:?}",
714 last_version_dir
715 ))
716 }
717 })()
718 .await
719 }
720 }
721}