1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use client::Client;
8use collections::HashMap;
9use futures::{future::Shared, Future, FutureExt, TryFutureExt};
10use gpui::{
11 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
12 Task,
13};
14use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
15use log::{debug, error};
16use lsp::LanguageServer;
17use node_runtime::NodeRuntime;
18use request::{LogMessage, StatusNotification};
19use settings::Settings;
20use smol::{fs, io::BufReader, stream::StreamExt};
21use staff_mode::{not_staff_mode, staff_mode};
22
23use std::{
24 ffi::OsString,
25 ops::Range,
26 path::{Path, PathBuf},
27 sync::Arc,
28};
29use util::{
30 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
31};
32
33const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
34actions!(copilot_auth, [SignIn, SignOut]);
35
36const COPILOT_NAMESPACE: &'static str = "copilot";
37actions!(copilot, [NextSuggestion, PreviousSuggestion, Reinstall]);
38
39pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
40 staff_mode(cx, {
41 move |cx| {
42 cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
43 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
44 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
45 });
46
47 let copilot = cx.add_model({
48 let node_runtime = node_runtime.clone();
49 let http = client.http_client().clone();
50 move |cx| Copilot::start(http, node_runtime, cx)
51 });
52 cx.set_global(copilot.clone());
53
54 observe_namespaces(cx, copilot);
55
56 sign_in::init(cx);
57 }
58 });
59 not_staff_mode(cx, |cx| {
60 cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
61 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
62 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
63 });
64 });
65
66 cx.add_global_action(|_: &SignIn, cx| {
67 if let Some(copilot) = Copilot::global(cx) {
68 copilot
69 .update(cx, |copilot, cx| copilot.sign_in(cx))
70 .detach_and_log_err(cx);
71 }
72 });
73 cx.add_global_action(|_: &SignOut, cx| {
74 if let Some(copilot) = Copilot::global(cx) {
75 copilot
76 .update(cx, |copilot, cx| copilot.sign_out(cx))
77 .detach_and_log_err(cx);
78 }
79 });
80
81 cx.add_global_action(|_: &Reinstall, cx| {
82 if let Some(copilot) = Copilot::global(cx) {
83 copilot
84 .update(cx, |copilot, cx| copilot.reinstall(cx))
85 .detach();
86 }
87 });
88}
89
90fn observe_namespaces(cx: &mut MutableAppContext, copilot: ModelHandle<Copilot>) {
91 cx.observe(&copilot, |handle, cx| {
92 let status = handle.read(cx).status();
93 cx.update_global::<collections::CommandPaletteFilter, _, _>(
94 move |filter, _cx| match status {
95 Status::Disabled => {
96 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
97 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
98 }
99 Status::Authorized => {
100 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
101 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
102 }
103 _ => {
104 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
105 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
106 }
107 },
108 );
109 })
110 .detach();
111}
112
/// Internal lifecycle state of the Copilot language server owned by [`Copilot`].
enum CopilotServer {
    /// Used when `enable_copilot_integration` is off in `Settings`.
    Disabled,
    /// A start (or reinstall) task has been spawned and has not replaced this state yet.
    Starting {
        // Shared so the same task can be cloned into `Status::Starting` and awaited by
        // the task returned from `reinstall`.
        task: Shared<Task<()>>,
    },
    /// Starting the server failed; holds the error rendered via `to_string()`.
    Error(Arc<str>),
    /// The server was initialized and answered the initial `CheckStatus` request.
    Started {
        server: Arc<LanguageServer>,
        // Local view of the sign-in state, updated through `update_sign_in_status`,
        // `sign_in` and `sign_out`.
        status: SignInStatus,
        // Keyed by buffer id: the release observer registered when `DidOpenTextDocument`
        // was sent for that buffer; the entry is removed when the buffer is released.
        subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
    },
}
125
/// Sign-in state tracked while the server is in `CopilotServer::Started`.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// The server reported `Ok`, `MaybeOk` or `AlreadySignedIn` for this user.
    Authorized {
        // Stored but not read anywhere in the code visible here (hence the underscore).
        _user: String,
    },
    /// The server reported `NotAuthorized` for this user.
    Unauthorized {
        // Stored but not read anywhere in the code visible here (hence the underscore).
        _user: String,
    },
    /// A sign-in flow started by `Copilot::sign_in` is in flight.
    SigningIn {
        // `None` until the server answers with a device-flow prompt.
        prompt: Option<request::PromptUserDeviceFlow>,
        // Shared so repeated `sign_in` calls await the same flow instead of starting another.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Nobody is signed in (also the state after `sign_out` or a failed sign-in).
    SignedOut,
}
140
/// Public snapshot of the Copilot state, produced by [`Copilot::status`] by flattening
/// the private `CopilotServer` and `SignInStatus` enums.
#[derive(Debug, Clone)]
pub enum Status {
    /// The language server is being started; `task` is a clone of the shared start task.
    Starting {
        task: Shared<Task<()>>,
    },
    /// Starting the language server failed with this message.
    Error(Arc<str>),
    /// The Copilot integration is disabled.
    Disabled,
    /// The server is running and nobody is signed in.
    SignedOut,
    /// A sign-in flow is in progress.
    SigningIn {
        // `None` until the server has supplied a device-flow prompt.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// The server is running and reported the user as not authorized.
    Unauthorized,
    /// The server is running and the user is authorized; completions can be requested.
    Authorized,
}
155
156impl Status {
157 pub fn is_authorized(&self) -> bool {
158 matches!(self, Status::Authorized)
159 }
160}
161
/// A single suggestion returned by the Copilot server, expressed in buffer terms.
#[derive(Debug, PartialEq, Eq)]
pub struct Completion {
    /// Buffer range the suggestion applies to. `request_completions` builds it from the
    /// server-reported range, clipped to the snapshot, as `anchor_before(start)..anchor_after(end)`.
    pub range: Range<Anchor>,
    /// Suggested text as reported by the server.
    pub text: String,
}
167
/// Model that owns the Copilot language server and its sign-in state.
pub struct Copilot {
    // Kept so `reinstall` can download the server again.
    http: Arc<dyn HttpClient>,
    // Kept so `reinstall` can relaunch the server.
    node_runtime: Arc<NodeRuntime>,
    // Current lifecycle state; see `CopilotServer`.
    server: CopilotServer,
}
173
/// Makes `Copilot` usable as a gpui model.
impl Entity for Copilot {
    // The event type is unit; state changes in this file are signalled with `cx.notify()`.
    type Event = ();
}
177
178impl Copilot {
179 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
180 if cx.has_global::<ModelHandle<Self>>() {
181 Some(cx.global::<ModelHandle<Self>>().clone())
182 } else {
183 None
184 }
185 }
186
187 fn start(
188 http: Arc<dyn HttpClient>,
189 node_runtime: Arc<NodeRuntime>,
190 cx: &mut ModelContext<Self>,
191 ) -> Self {
192 cx.observe_global::<Settings, _>({
193 let http = http.clone();
194 let node_runtime = node_runtime.clone();
195 move |this, cx| {
196 if cx.global::<Settings>().enable_copilot_integration {
197 if matches!(this.server, CopilotServer::Disabled) {
198 let start_task = cx
199 .spawn({
200 let http = http.clone();
201 let node_runtime = node_runtime.clone();
202 move |this, cx| {
203 Self::start_language_server(http, node_runtime, this, cx)
204 }
205 })
206 .shared();
207 this.server = CopilotServer::Starting { task: start_task };
208 cx.notify();
209 }
210 } else {
211 this.server = CopilotServer::Disabled;
212 cx.notify();
213 }
214 }
215 })
216 .detach();
217
218 if cx.global::<Settings>().enable_copilot_integration {
219 let start_task = cx
220 .spawn({
221 let http = http.clone();
222 let node_runtime = node_runtime.clone();
223 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
224 })
225 .shared();
226
227 Self {
228 http,
229 node_runtime,
230 server: CopilotServer::Starting { task: start_task },
231 }
232 } else {
233 Self {
234 http,
235 node_runtime,
236 server: CopilotServer::Disabled,
237 }
238 }
239 }
240
241 fn start_language_server(
242 http: Arc<dyn HttpClient>,
243 node_runtime: Arc<NodeRuntime>,
244 this: ModelHandle<Self>,
245 mut cx: AsyncAppContext,
246 ) -> impl Future<Output = ()> {
247 async move {
248 let start_language_server = async {
249 let server_path = get_copilot_lsp(http).await?;
250 let node_path = node_runtime.binary_path().await?;
251 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
252 let server = LanguageServer::new(
253 0,
254 &node_path,
255 arguments,
256 Path::new("/"),
257 None,
258 cx.clone(),
259 )?;
260
261 let server = server.initialize(Default::default()).await?;
262 let status = server
263 .request::<request::CheckStatus>(request::CheckStatusParams {
264 local_checks_only: false,
265 })
266 .await?;
267
268 server
269 .on_notification::<LogMessage, _>(|params, _cx| {
270 match params.level {
271 // Copilot is pretty agressive about logging
272 0 => debug!("copilot: {}", params.message),
273 1 => debug!("copilot: {}", params.message),
274 _ => error!("copilot: {}", params.message),
275 }
276
277 debug!("copilot metadata: {}", params.metadata_str);
278 debug!("copilot extra: {:?}", params.extra);
279 })
280 .detach();
281
282 server
283 .on_notification::<StatusNotification, _>(
284 |_, _| { /* Silence the notification */ },
285 )
286 .detach();
287
288 anyhow::Ok((server, status))
289 };
290
291 let server = start_language_server.await;
292 this.update(&mut cx, |this, cx| {
293 cx.notify();
294 match server {
295 Ok((server, status)) => {
296 this.server = CopilotServer::Started {
297 server,
298 status: SignInStatus::SignedOut,
299 subscriptions_by_buffer_id: Default::default(),
300 };
301 this.update_sign_in_status(status, cx);
302 }
303 Err(error) => {
304 this.server = CopilotServer::Error(error.to_string().into());
305 cx.notify()
306 }
307 }
308 })
309 }
310 }
311
312 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
313 if let CopilotServer::Started { server, status, .. } = &mut self.server {
314 let task = match status {
315 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
316 Task::ready(Ok(())).shared()
317 }
318 SignInStatus::SigningIn { task, .. } => {
319 cx.notify();
320 task.clone()
321 }
322 SignInStatus::SignedOut => {
323 let server = server.clone();
324 let task = cx
325 .spawn(|this, mut cx| async move {
326 let sign_in = async {
327 let sign_in = server
328 .request::<request::SignInInitiate>(
329 request::SignInInitiateParams {},
330 )
331 .await?;
332 match sign_in {
333 request::SignInInitiateResult::AlreadySignedIn { user } => {
334 Ok(request::SignInStatus::Ok { user })
335 }
336 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
337 this.update(&mut cx, |this, cx| {
338 if let CopilotServer::Started { status, .. } =
339 &mut this.server
340 {
341 if let SignInStatus::SigningIn {
342 prompt: prompt_flow,
343 ..
344 } = status
345 {
346 *prompt_flow = Some(flow.clone());
347 cx.notify();
348 }
349 }
350 });
351 let response = server
352 .request::<request::SignInConfirm>(
353 request::SignInConfirmParams {
354 user_code: flow.user_code,
355 },
356 )
357 .await?;
358 Ok(response)
359 }
360 }
361 };
362
363 let sign_in = sign_in.await;
364 this.update(&mut cx, |this, cx| match sign_in {
365 Ok(status) => {
366 this.update_sign_in_status(status, cx);
367 Ok(())
368 }
369 Err(error) => {
370 this.update_sign_in_status(
371 request::SignInStatus::NotSignedIn,
372 cx,
373 );
374 Err(Arc::new(error))
375 }
376 })
377 })
378 .shared();
379 *status = SignInStatus::SigningIn {
380 prompt: None,
381 task: task.clone(),
382 };
383 cx.notify();
384 task
385 }
386 };
387
388 cx.foreground()
389 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
390 } else {
391 // If we're downloading, wait until download is finished
392 // If we're in a stuck state, display to the user
393 Task::ready(Err(anyhow!("copilot hasn't started yet")))
394 }
395 }
396
397 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
398 if let CopilotServer::Started { server, status, .. } = &mut self.server {
399 *status = SignInStatus::SignedOut;
400 cx.notify();
401
402 let server = server.clone();
403 cx.background().spawn(async move {
404 server
405 .request::<request::SignOut>(request::SignOutParams {})
406 .await?;
407 anyhow::Ok(())
408 })
409 } else {
410 Task::ready(Err(anyhow!("copilot hasn't started yet")))
411 }
412 }
413
414 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
415 let start_task = cx
416 .spawn({
417 let http = self.http.clone();
418 let node_runtime = self.node_runtime.clone();
419 move |this, cx| async move {
420 clear_copilot_dir().await;
421 Self::start_language_server(http, node_runtime, this, cx).await
422 }
423 })
424 .shared();
425
426 self.server = CopilotServer::Starting {
427 task: start_task.clone(),
428 };
429
430 cx.notify();
431
432 cx.foreground().spawn(start_task)
433 }
434
435 pub fn completions<T>(
436 &mut self,
437 buffer: &ModelHandle<Buffer>,
438 position: T,
439 cx: &mut ModelContext<Self>,
440 ) -> Task<Result<Vec<Completion>>>
441 where
442 T: ToPointUtf16,
443 {
444 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
445 }
446
447 pub fn completions_cycling<T>(
448 &mut self,
449 buffer: &ModelHandle<Buffer>,
450 position: T,
451 cx: &mut ModelContext<Self>,
452 ) -> Task<Result<Vec<Completion>>>
453 where
454 T: ToPointUtf16,
455 {
456 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
457 }
458
459 fn request_completions<R, T>(
460 &mut self,
461 buffer: &ModelHandle<Buffer>,
462 position: T,
463 cx: &mut ModelContext<Self>,
464 ) -> Task<Result<Vec<Completion>>>
465 where
466 R: lsp::request::Request<
467 Params = request::GetCompletionsParams,
468 Result = request::GetCompletionsResult,
469 >,
470 T: ToPointUtf16,
471 {
472 let buffer_id = buffer.id();
473 let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
474 let snapshot = buffer.read(cx).snapshot();
475 let server = match &mut self.server {
476 CopilotServer::Starting { .. } => {
477 return Task::ready(Err(anyhow!("copilot is still starting")))
478 }
479 CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
480 CopilotServer::Error(error) => {
481 return Task::ready(Err(anyhow!(
482 "copilot was not started because of an error: {}",
483 error
484 )))
485 }
486 CopilotServer::Started {
487 server,
488 status,
489 subscriptions_by_buffer_id,
490 } => {
491 if matches!(status, SignInStatus::Authorized { .. }) {
492 subscriptions_by_buffer_id
493 .entry(buffer_id)
494 .or_insert_with(|| {
495 server
496 .notify::<lsp::notification::DidOpenTextDocument>(
497 lsp::DidOpenTextDocumentParams {
498 text_document: lsp::TextDocumentItem {
499 uri: uri.clone(),
500 language_id: id_for_language(
501 buffer.read(cx).language(),
502 ),
503 version: 0,
504 text: snapshot.text(),
505 },
506 },
507 )
508 .log_err();
509
510 let uri = uri.clone();
511 cx.observe_release(buffer, move |this, _, _| {
512 if let CopilotServer::Started {
513 server,
514 subscriptions_by_buffer_id,
515 ..
516 } = &mut this.server
517 {
518 server
519 .notify::<lsp::notification::DidCloseTextDocument>(
520 lsp::DidCloseTextDocumentParams {
521 text_document: lsp::TextDocumentIdentifier::new(
522 uri.clone(),
523 ),
524 },
525 )
526 .log_err();
527 subscriptions_by_buffer_id.remove(&buffer_id);
528 }
529 })
530 });
531
532 server.clone()
533 } else {
534 return Task::ready(Err(anyhow!("must sign in before using copilot")));
535 }
536 }
537 };
538
539 let settings = cx.global::<Settings>();
540 let position = position.to_point_utf16(&snapshot);
541 let language = snapshot.language_at(position);
542 let language_name = language.map(|language| language.name());
543 let language_name = language_name.as_deref();
544 let tab_size = settings.tab_size(language_name);
545 let hard_tabs = settings.hard_tabs(language_name);
546 let language_id = id_for_language(language);
547
548 let path;
549 let relative_path;
550 if let Some(file) = snapshot.file() {
551 if let Some(file) = file.as_local() {
552 path = file.abs_path(cx);
553 } else {
554 path = file.full_path(cx);
555 }
556 relative_path = file.path().to_path_buf();
557 } else {
558 path = PathBuf::new();
559 relative_path = PathBuf::new();
560 }
561
562 cx.background().spawn(async move {
563 let result = server
564 .request::<R>(request::GetCompletionsParams {
565 doc: request::GetCompletionsDocument {
566 source: snapshot.text(),
567 tab_size: tab_size.into(),
568 indent_size: 1,
569 insert_spaces: !hard_tabs,
570 uri,
571 path: path.to_string_lossy().into(),
572 relative_path: relative_path.to_string_lossy().into(),
573 language_id,
574 position: point_to_lsp(position),
575 version: 0,
576 },
577 })
578 .await?;
579 let completions = result
580 .completions
581 .into_iter()
582 .map(|completion| {
583 let start = snapshot
584 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
585 let end =
586 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
587 Completion {
588 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
589 text: completion.text,
590 }
591 })
592 .collect();
593 anyhow::Ok(completions)
594 })
595 }
596
597 pub fn status(&self) -> Status {
598 match &self.server {
599 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
600 CopilotServer::Disabled => Status::Disabled,
601 CopilotServer::Error(error) => Status::Error(error.clone()),
602 CopilotServer::Started { status, .. } => match status {
603 SignInStatus::Authorized { .. } => Status::Authorized,
604 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
605 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
606 prompt: prompt.clone(),
607 },
608 SignInStatus::SignedOut => Status::SignedOut,
609 },
610 }
611 }
612
613 fn update_sign_in_status(
614 &mut self,
615 lsp_status: request::SignInStatus,
616 cx: &mut ModelContext<Self>,
617 ) {
618 if let CopilotServer::Started { status, .. } = &mut self.server {
619 *status = match lsp_status {
620 request::SignInStatus::Ok { user }
621 | request::SignInStatus::MaybeOk { user }
622 | request::SignInStatus::AlreadySignedIn { user } => {
623 SignInStatus::Authorized { _user: user }
624 }
625 request::SignInStatus::NotAuthorized { user } => {
626 SignInStatus::Unauthorized { _user: user }
627 }
628 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
629 };
630 cx.notify();
631 }
632 }
633}
634
635fn id_for_language(language: Option<&Arc<Language>>) -> String {
636 let language_name = language.map(|language| language.name());
637 match language_name.as_deref() {
638 Some("Plain Text") => "plaintext".to_string(),
639 Some(language_name) => language_name.to_lowercase(),
640 None => "plaintext".to_string(),
641 }
642}
643
/// Calls `remove_matching` on `paths::COPILOT_DIR` with a predicate that accepts every
/// entry. Used by `Copilot::reinstall` before the language server is started again.
async fn clear_copilot_dir() {
    remove_matching(&paths::COPILOT_DIR, |_| true).await
}
647
648async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
649 const SERVER_PATH: &'static str = "dist/agent.js";
650
651 ///Check for the latest copilot language server and download it if we haven't already
652 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
653 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
654
655 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
656
657 fs::create_dir_all(version_dir).await?;
658 let server_path = version_dir.join(SERVER_PATH);
659
660 if fs::metadata(&server_path).await.is_err() {
661 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
662 let dist_dir = version_dir.join("dist");
663 fs::create_dir_all(dist_dir.as_path()).await?;
664
665 let url = &release
666 .assets
667 .get(0)
668 .context("Github release for copilot contained no assets")?
669 .browser_download_url;
670
671 let mut response = http
672 .get(&url, Default::default(), true)
673 .await
674 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
675 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
676 let archive = Archive::new(decompressed_bytes);
677 archive.unpack(dist_dir).await?;
678
679 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
680 }
681
682 Ok(server_path)
683 }
684
685 match fetch_latest(http).await {
686 ok @ Result::Ok(..) => ok,
687 e @ Err(..) => {
688 e.log_err();
689 // Fetch a cached binary, if it exists
690 (|| async move {
691 let mut last_version_dir = None;
692 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
693 while let Some(entry) = entries.next().await {
694 let entry = entry?;
695 if entry.file_type().await?.is_dir() {
696 last_version_dir = Some(entry.path());
697 }
698 }
699 let last_version_dir =
700 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
701 let server_path = last_version_dir.join(SERVER_PATH);
702 if server_path.exists() {
703 Ok(server_path)
704 } else {
705 Err(anyhow!(
706 "missing executable in directory {:?}",
707 last_version_dir
708 ))
709 }
710 })()
711 .await
712 }
713 }
714}