1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use client::Client;
8use collections::HashMap;
9use futures::{future::Shared, Future, FutureExt, TryFutureExt};
10use gpui::{
11 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
12 Task,
13};
14use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
15use log::{debug, error};
16use lsp::LanguageServer;
17use node_runtime::NodeRuntime;
18use request::{LogMessage, StatusNotification};
19use settings::Settings;
20use smol::{fs, io::BufReader, stream::StreamExt};
21use staff_mode::staff_mode;
22
23use std::{
24 ffi::OsString,
25 ops::Range,
26 path::{Path, PathBuf},
27 sync::Arc,
28};
29use util::{
30 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
31};
32
33const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
34actions!(copilot_auth, [SignIn, SignOut]);
35
36const COPILOT_NAMESPACE: &'static str = "copilot";
37actions!(copilot, [NextSuggestion, PreviousSuggestion, Reinstall]);
38
39pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
40 staff_mode(cx, {
41 move |cx| {
42 let copilot = cx.add_model({
43 let node_runtime = node_runtime.clone();
44 let http = client.http_client().clone();
45 move |cx| Copilot::start(http, node_runtime, cx)
46 });
47 cx.set_global(copilot.clone());
48
49 observe_namespaces(cx, copilot);
50
51 sign_in::init(cx);
52 }
53 });
54
55 cx.add_global_action(|_: &SignIn, cx| {
56 if let Some(copilot) = Copilot::global(cx) {
57 copilot
58 .update(cx, |copilot, cx| copilot.sign_in(cx))
59 .detach_and_log_err(cx);
60 }
61 });
62 cx.add_global_action(|_: &SignOut, cx| {
63 if let Some(copilot) = Copilot::global(cx) {
64 copilot
65 .update(cx, |copilot, cx| copilot.sign_out(cx))
66 .detach_and_log_err(cx);
67 }
68 });
69
70 cx.add_global_action(|_: &Reinstall, cx| {
71 if let Some(copilot) = Copilot::global(cx) {
72 copilot
73 .update(cx, |copilot, cx| copilot.reinstall(cx))
74 .detach();
75 }
76 });
77}
78
79fn observe_namespaces(cx: &mut MutableAppContext, copilot: ModelHandle<Copilot>) {
80 cx.observe(&copilot, |handle, cx| {
81 let status = handle.read(cx).status();
82 cx.update_global::<collections::CommandPaletteFilter, _, _>(
83 move |filter, _cx| match status {
84 Status::Disabled => {
85 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
86 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
87 }
88 Status::Authorized => {
89 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
90 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
91 }
92 _ => {
93 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
94 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
95 }
96 },
97 );
98 })
99 .detach();
100}
101
102enum CopilotServer {
103 Disabled,
104 Starting {
105 task: Shared<Task<()>>,
106 },
107 Error(Arc<str>),
108 Started {
109 server: Arc<LanguageServer>,
110 status: SignInStatus,
111 subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
112 },
113}
114
115#[derive(Clone, Debug)]
116enum SignInStatus {
117 Authorized {
118 _user: String,
119 },
120 Unauthorized {
121 _user: String,
122 },
123 SigningIn {
124 prompt: Option<request::PromptUserDeviceFlow>,
125 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
126 },
127 SignedOut,
128}
129
130#[derive(Debug, Clone)]
131pub enum Status {
132 Starting {
133 task: Shared<Task<()>>,
134 },
135 Error(Arc<str>),
136 Disabled,
137 SignedOut,
138 SigningIn {
139 prompt: Option<request::PromptUserDeviceFlow>,
140 },
141 Unauthorized,
142 Authorized,
143}
144
145impl Status {
146 pub fn is_authorized(&self) -> bool {
147 matches!(self, Status::Authorized)
148 }
149}
150
151#[derive(Debug, PartialEq, Eq)]
152pub struct Completion {
153 pub range: Range<Anchor>,
154 pub text: String,
155}
156
157pub struct Copilot {
158 http: Arc<dyn HttpClient>,
159 node_runtime: Arc<NodeRuntime>,
160 server: CopilotServer,
161}
162
163impl Entity for Copilot {
164 type Event = ();
165}
166
167impl Copilot {
168 pub fn starting_task(&self) -> Option<Shared<Task<()>>> {
169 match self.server {
170 CopilotServer::Starting { ref task } => Some(task.clone()),
171 _ => None,
172 }
173 }
174
175 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
176 if cx.has_global::<ModelHandle<Self>>() {
177 Some(cx.global::<ModelHandle<Self>>().clone())
178 } else {
179 None
180 }
181 }
182
183 fn start(
184 http: Arc<dyn HttpClient>,
185 node_runtime: Arc<NodeRuntime>,
186 cx: &mut ModelContext<Self>,
187 ) -> Self {
188 cx.observe_global::<Settings, _>({
189 let http = http.clone();
190 let node_runtime = node_runtime.clone();
191 move |this, cx| {
192 if cx.global::<Settings>().enable_copilot_integration {
193 if matches!(this.server, CopilotServer::Disabled) {
194 let start_task = cx
195 .spawn({
196 let http = http.clone();
197 let node_runtime = node_runtime.clone();
198 move |this, cx| {
199 Self::start_language_server(http, node_runtime, this, cx)
200 }
201 })
202 .shared();
203 this.server = CopilotServer::Starting { task: start_task };
204 cx.notify();
205 }
206 } else {
207 this.server = CopilotServer::Disabled;
208 cx.notify();
209 }
210 }
211 })
212 .detach();
213
214 if cx.global::<Settings>().enable_copilot_integration {
215 let start_task = cx
216 .spawn({
217 let http = http.clone();
218 let node_runtime = node_runtime.clone();
219 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
220 })
221 .shared();
222
223 Self {
224 http,
225 node_runtime,
226 server: CopilotServer::Starting { task: start_task },
227 }
228 } else {
229 Self {
230 http,
231 node_runtime,
232 server: CopilotServer::Disabled,
233 }
234 }
235 }
236
237 fn start_language_server(
238 http: Arc<dyn HttpClient>,
239 node_runtime: Arc<NodeRuntime>,
240 this: ModelHandle<Self>,
241 mut cx: AsyncAppContext,
242 ) -> impl Future<Output = ()> {
243 async move {
244 let start_language_server = async {
245 let server_path = get_copilot_lsp(http).await?;
246 let node_path = node_runtime.binary_path().await?;
247 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
248 let server = LanguageServer::new(
249 0,
250 &node_path,
251 arguments,
252 Path::new("/"),
253 None,
254 cx.clone(),
255 )?;
256
257 let server = server.initialize(Default::default()).await?;
258 let status = server
259 .request::<request::CheckStatus>(request::CheckStatusParams {
260 local_checks_only: false,
261 })
262 .await?;
263
264 server
265 .on_notification::<LogMessage, _>(|params, _cx| {
266 match params.level {
267 // Copilot is pretty agressive about logging
268 0 => debug!("copilot: {}", params.message),
269 1 => debug!("copilot: {}", params.message),
270 _ => error!("copilot: {}", params.message),
271 }
272
273 debug!("copilot metadata: {}", params.metadata_str);
274 debug!("copilot extra: {:?}", params.extra);
275 })
276 .detach();
277
278 server
279 .on_notification::<StatusNotification, _>(
280 |_, _| { /* Silence the notification */ },
281 )
282 .detach();
283
284 anyhow::Ok((server, status))
285 };
286
287 let server = start_language_server.await;
288 this.update(&mut cx, |this, cx| {
289 cx.notify();
290 match server {
291 Ok((server, status)) => {
292 this.server = CopilotServer::Started {
293 server,
294 status: SignInStatus::SignedOut,
295 subscriptions_by_buffer_id: Default::default(),
296 };
297 this.update_sign_in_status(status, cx);
298 }
299 Err(error) => {
300 this.server = CopilotServer::Error(error.to_string().into());
301 cx.notify()
302 }
303 }
304 })
305 }
306 }
307
308 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
309 if let CopilotServer::Started { server, status, .. } = &mut self.server {
310 let task = match status {
311 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
312 Task::ready(Ok(())).shared()
313 }
314 SignInStatus::SigningIn { task, .. } => {
315 cx.notify();
316 task.clone()
317 }
318 SignInStatus::SignedOut => {
319 let server = server.clone();
320 let task = cx
321 .spawn(|this, mut cx| async move {
322 let sign_in = async {
323 let sign_in = server
324 .request::<request::SignInInitiate>(
325 request::SignInInitiateParams {},
326 )
327 .await?;
328 match sign_in {
329 request::SignInInitiateResult::AlreadySignedIn { user } => {
330 Ok(request::SignInStatus::Ok { user })
331 }
332 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
333 this.update(&mut cx, |this, cx| {
334 if let CopilotServer::Started { status, .. } =
335 &mut this.server
336 {
337 if let SignInStatus::SigningIn {
338 prompt: prompt_flow,
339 ..
340 } = status
341 {
342 *prompt_flow = Some(flow.clone());
343 cx.notify();
344 }
345 }
346 });
347 let response = server
348 .request::<request::SignInConfirm>(
349 request::SignInConfirmParams {
350 user_code: flow.user_code,
351 },
352 )
353 .await?;
354 Ok(response)
355 }
356 }
357 };
358
359 let sign_in = sign_in.await;
360 this.update(&mut cx, |this, cx| match sign_in {
361 Ok(status) => {
362 this.update_sign_in_status(status, cx);
363 Ok(())
364 }
365 Err(error) => {
366 this.update_sign_in_status(
367 request::SignInStatus::NotSignedIn,
368 cx,
369 );
370 Err(Arc::new(error))
371 }
372 })
373 })
374 .shared();
375 *status = SignInStatus::SigningIn {
376 prompt: None,
377 task: task.clone(),
378 };
379 cx.notify();
380 task
381 }
382 };
383
384 cx.foreground()
385 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
386 } else {
387 // If we're downloading, wait until download is finished
388 // If we're in a stuck state, display to the user
389 Task::ready(Err(anyhow!("copilot hasn't started yet")))
390 }
391 }
392
393 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
394 if let CopilotServer::Started { server, status, .. } = &mut self.server {
395 *status = SignInStatus::SignedOut;
396 cx.notify();
397
398 let server = server.clone();
399 cx.background().spawn(async move {
400 server
401 .request::<request::SignOut>(request::SignOutParams {})
402 .await?;
403 anyhow::Ok(())
404 })
405 } else {
406 Task::ready(Err(anyhow!("copilot hasn't started yet")))
407 }
408 }
409
410 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
411 let start_task = cx
412 .spawn({
413 let http = self.http.clone();
414 let node_runtime = self.node_runtime.clone();
415 move |this, cx| async move {
416 clear_copilot_dir().await;
417 Self::start_language_server(http, node_runtime, this, cx).await
418 }
419 })
420 .shared();
421
422 self.server = CopilotServer::Starting {
423 task: start_task.clone(),
424 };
425
426 cx.notify();
427
428 cx.foreground().spawn(start_task)
429 }
430
431 pub fn completions<T>(
432 &mut self,
433 buffer: &ModelHandle<Buffer>,
434 position: T,
435 cx: &mut ModelContext<Self>,
436 ) -> Task<Result<Vec<Completion>>>
437 where
438 T: ToPointUtf16,
439 {
440 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
441 }
442
443 pub fn completions_cycling<T>(
444 &mut self,
445 buffer: &ModelHandle<Buffer>,
446 position: T,
447 cx: &mut ModelContext<Self>,
448 ) -> Task<Result<Vec<Completion>>>
449 where
450 T: ToPointUtf16,
451 {
452 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
453 }
454
455 fn request_completions<R, T>(
456 &mut self,
457 buffer: &ModelHandle<Buffer>,
458 position: T,
459 cx: &mut ModelContext<Self>,
460 ) -> Task<Result<Vec<Completion>>>
461 where
462 R: lsp::request::Request<
463 Params = request::GetCompletionsParams,
464 Result = request::GetCompletionsResult,
465 >,
466 T: ToPointUtf16,
467 {
468 let buffer_id = buffer.id();
469 let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
470 let snapshot = buffer.read(cx).snapshot();
471 let server = match &mut self.server {
472 CopilotServer::Starting { .. } => {
473 return Task::ready(Err(anyhow!("copilot is still starting")))
474 }
475 CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
476 CopilotServer::Error(error) => {
477 return Task::ready(Err(anyhow!(
478 "copilot was not started because of an error: {}",
479 error
480 )))
481 }
482 CopilotServer::Started {
483 server,
484 status,
485 subscriptions_by_buffer_id,
486 } => {
487 if matches!(status, SignInStatus::Authorized { .. }) {
488 subscriptions_by_buffer_id
489 .entry(buffer_id)
490 .or_insert_with(|| {
491 server
492 .notify::<lsp::notification::DidOpenTextDocument>(
493 lsp::DidOpenTextDocumentParams {
494 text_document: lsp::TextDocumentItem {
495 uri: uri.clone(),
496 language_id: id_for_language(
497 buffer.read(cx).language(),
498 ),
499 version: 0,
500 text: snapshot.text(),
501 },
502 },
503 )
504 .log_err();
505
506 let uri = uri.clone();
507 cx.observe_release(buffer, move |this, _, _| {
508 if let CopilotServer::Started {
509 server,
510 subscriptions_by_buffer_id,
511 ..
512 } = &mut this.server
513 {
514 server
515 .notify::<lsp::notification::DidCloseTextDocument>(
516 lsp::DidCloseTextDocumentParams {
517 text_document: lsp::TextDocumentIdentifier::new(
518 uri.clone(),
519 ),
520 },
521 )
522 .log_err();
523 subscriptions_by_buffer_id.remove(&buffer_id);
524 }
525 })
526 });
527
528 server.clone()
529 } else {
530 return Task::ready(Err(anyhow!("must sign in before using copilot")));
531 }
532 }
533 };
534
535 let settings = cx.global::<Settings>();
536 let position = position.to_point_utf16(&snapshot);
537 let language = snapshot.language_at(position);
538 let language_name = language.map(|language| language.name());
539 let language_name = language_name.as_deref();
540 let tab_size = settings.tab_size(language_name);
541 let hard_tabs = settings.hard_tabs(language_name);
542 let language_id = id_for_language(language);
543
544 let path;
545 let relative_path;
546 if let Some(file) = snapshot.file() {
547 if let Some(file) = file.as_local() {
548 path = file.abs_path(cx);
549 } else {
550 path = file.full_path(cx);
551 }
552 relative_path = file.path().to_path_buf();
553 } else {
554 path = PathBuf::new();
555 relative_path = PathBuf::new();
556 }
557
558 cx.background().spawn(async move {
559 let result = server
560 .request::<R>(request::GetCompletionsParams {
561 doc: request::GetCompletionsDocument {
562 source: snapshot.text(),
563 tab_size: tab_size.into(),
564 indent_size: 1,
565 insert_spaces: !hard_tabs,
566 uri,
567 path: path.to_string_lossy().into(),
568 relative_path: relative_path.to_string_lossy().into(),
569 language_id,
570 position: point_to_lsp(position),
571 version: 0,
572 },
573 })
574 .await?;
575 let completions = result
576 .completions
577 .into_iter()
578 .map(|completion| {
579 let start = snapshot
580 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
581 let end =
582 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
583 Completion {
584 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
585 text: completion.text,
586 }
587 })
588 .collect();
589 anyhow::Ok(completions)
590 })
591 }
592
593 pub fn status(&self) -> Status {
594 match &self.server {
595 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
596 CopilotServer::Disabled => Status::Disabled,
597 CopilotServer::Error(error) => Status::Error(error.clone()),
598 CopilotServer::Started { status, .. } => match status {
599 SignInStatus::Authorized { .. } => Status::Authorized,
600 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
601 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
602 prompt: prompt.clone(),
603 },
604 SignInStatus::SignedOut => Status::SignedOut,
605 },
606 }
607 }
608
609 fn update_sign_in_status(
610 &mut self,
611 lsp_status: request::SignInStatus,
612 cx: &mut ModelContext<Self>,
613 ) {
614 if let CopilotServer::Started { status, .. } = &mut self.server {
615 *status = match lsp_status {
616 request::SignInStatus::Ok { user }
617 | request::SignInStatus::MaybeOk { user }
618 | request::SignInStatus::AlreadySignedIn { user } => {
619 SignInStatus::Authorized { _user: user }
620 }
621 request::SignInStatus::NotAuthorized { user } => {
622 SignInStatus::Unauthorized { _user: user }
623 }
624 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
625 };
626 cx.notify();
627 }
628 }
629}
630
631fn id_for_language(language: Option<&Arc<Language>>) -> String {
632 let language_name = language.map(|language| language.name());
633 match language_name.as_deref() {
634 Some("Plain Text") => "plaintext".to_string(),
635 Some(language_name) => language_name.to_lowercase(),
636 None => "plaintext".to_string(),
637 }
638}
639
640async fn clear_copilot_dir() {
641 remove_matching(&paths::COPILOT_DIR, |_| true).await
642}
643
644async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
645 const SERVER_PATH: &'static str = "dist/agent.js";
646
647 ///Check for the latest copilot language server and download it if we haven't already
648 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
649 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
650
651 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
652
653 fs::create_dir_all(version_dir).await?;
654 let server_path = version_dir.join(SERVER_PATH);
655
656 if fs::metadata(&server_path).await.is_err() {
657 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
658 let dist_dir = version_dir.join("dist");
659 fs::create_dir_all(dist_dir.as_path()).await?;
660
661 let url = &release
662 .assets
663 .get(0)
664 .context("Github release for copilot contained no assets")?
665 .browser_download_url;
666
667 let mut response = http
668 .get(&url, Default::default(), true)
669 .await
670 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
671 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
672 let archive = Archive::new(decompressed_bytes);
673 archive.unpack(dist_dir).await?;
674
675 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
676 }
677
678 Ok(server_path)
679 }
680
681 match fetch_latest(http).await {
682 ok @ Result::Ok(..) => ok,
683 e @ Err(..) => {
684 e.log_err();
685 // Fetch a cached binary, if it exists
686 (|| async move {
687 let mut last_version_dir = None;
688 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
689 while let Some(entry) = entries.next().await {
690 let entry = entry?;
691 if entry.file_type().await?.is_dir() {
692 last_version_dir = Some(entry.path());
693 }
694 }
695 let last_version_dir =
696 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
697 let server_path = last_version_dir.join(SERVER_PATH);
698 if server_path.exists() {
699 Ok(server_path)
700 } else {
701 Err(anyhow!(
702 "missing executable in directory {:?}",
703 last_version_dir
704 ))
705 }
706 })()
707 .await
708 }
709 }
710}