1pub mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use collections::HashMap;
8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
10use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
11use log::{debug, error};
12use lsp::LanguageServer;
13use node_runtime::NodeRuntime;
14use request::{LogMessage, StatusNotification};
15use settings::Settings;
16use smol::{fs, io::BufReader, stream::StreamExt};
17use std::{
18 ffi::OsString,
19 ops::Range,
20 path::{Path, PathBuf},
21 sync::Arc,
22};
23use util::{
24 channel::ReleaseChannel, fs::remove_matching, github::latest_github_release, http::HttpClient,
25 paths, ResultExt,
26};
27
28const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
29actions!(copilot_auth, [SignIn, SignOut]);
30
31const COPILOT_NAMESPACE: &'static str = "copilot";
32actions!(
33 copilot,
34 [Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
35);
36
37pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
38 // Disable Copilot for stable releases.
39 if *cx.global::<ReleaseChannel>() == ReleaseChannel::Stable {
40 cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
41 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
42 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
43 });
44 return;
45 }
46
47 let copilot = cx.add_model({
48 let node_runtime = node_runtime.clone();
49 move |cx| Copilot::start(http, node_runtime, cx)
50 });
51 cx.set_global(copilot.clone());
52
53 cx.observe(&copilot, |handle, cx| {
54 let status = handle.read(cx).status();
55 cx.update_global::<collections::CommandPaletteFilter, _, _>(
56 move |filter, _cx| match status {
57 Status::Disabled => {
58 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
59 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
60 }
61 Status::Authorized => {
62 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
63 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
64 }
65 _ => {
66 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
67 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
68 }
69 },
70 );
71 })
72 .detach();
73
74 sign_in::init(cx);
75 cx.add_global_action(|_: &SignIn, cx| {
76 if let Some(copilot) = Copilot::global(cx) {
77 copilot
78 .update(cx, |copilot, cx| copilot.sign_in(cx))
79 .detach_and_log_err(cx);
80 }
81 });
82 cx.add_global_action(|_: &SignOut, cx| {
83 if let Some(copilot) = Copilot::global(cx) {
84 copilot
85 .update(cx, |copilot, cx| copilot.sign_out(cx))
86 .detach_and_log_err(cx);
87 }
88 });
89
90 cx.add_global_action(|_: &Reinstall, cx| {
91 if let Some(copilot) = Copilot::global(cx) {
92 copilot
93 .update(cx, |copilot, cx| copilot.reinstall(cx))
94 .detach();
95 }
96 });
97}
98
99enum CopilotServer {
100 Disabled,
101 Starting {
102 task: Shared<Task<()>>,
103 },
104 Error(Arc<str>),
105 Started {
106 server: Arc<LanguageServer>,
107 status: SignInStatus,
108 subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
109 },
110}
111
112#[derive(Clone, Debug)]
113enum SignInStatus {
114 Authorized,
115 Unauthorized,
116 SigningIn {
117 prompt: Option<request::PromptUserDeviceFlow>,
118 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
119 },
120 SignedOut,
121}
122
123#[derive(Debug, Clone)]
124pub enum Status {
125 Starting {
126 task: Shared<Task<()>>,
127 },
128 Error(Arc<str>),
129 Disabled,
130 SignedOut,
131 SigningIn {
132 prompt: Option<request::PromptUserDeviceFlow>,
133 },
134 Unauthorized,
135 Authorized,
136}
137
138impl Status {
139 pub fn is_authorized(&self) -> bool {
140 matches!(self, Status::Authorized)
141 }
142}
143
144#[derive(Debug, PartialEq, Eq)]
145pub struct Completion {
146 pub range: Range<Anchor>,
147 pub text: String,
148}
149
150pub struct Copilot {
151 http: Arc<dyn HttpClient>,
152 node_runtime: Arc<NodeRuntime>,
153 server: CopilotServer,
154}
155
156impl Entity for Copilot {
157 type Event = ();
158}
159
160impl Copilot {
161 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
162 if cx.has_global::<ModelHandle<Self>>() {
163 Some(cx.global::<ModelHandle<Self>>().clone())
164 } else {
165 None
166 }
167 }
168
169 fn start(
170 http: Arc<dyn HttpClient>,
171 node_runtime: Arc<NodeRuntime>,
172 cx: &mut ModelContext<Self>,
173 ) -> Self {
174 cx.observe_global::<Settings, _>({
175 let http = http.clone();
176 let node_runtime = node_runtime.clone();
177 move |this, cx| {
178 if cx.global::<Settings>().features.copilot {
179 if matches!(this.server, CopilotServer::Disabled) {
180 let start_task = cx
181 .spawn({
182 let http = http.clone();
183 let node_runtime = node_runtime.clone();
184 move |this, cx| {
185 Self::start_language_server(http, node_runtime, this, cx)
186 }
187 })
188 .shared();
189 this.server = CopilotServer::Starting { task: start_task };
190 cx.notify();
191 }
192 } else {
193 this.server = CopilotServer::Disabled;
194 cx.notify();
195 }
196 }
197 })
198 .detach();
199
200 if cx.global::<Settings>().features.copilot {
201 let start_task = cx
202 .spawn({
203 let http = http.clone();
204 let node_runtime = node_runtime.clone();
205 move |this, cx| async {
206 Self::start_language_server(http, node_runtime, this, cx).await
207 }
208 })
209 .shared();
210
211 Self {
212 http,
213 node_runtime,
214 server: CopilotServer::Starting { task: start_task },
215 }
216 } else {
217 Self {
218 http,
219 node_runtime,
220 server: CopilotServer::Disabled,
221 }
222 }
223 }
224
225 #[cfg(any(test, feature = "test-support"))]
226 pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
227 let (server, fake_server) =
228 LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
229 let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
230 let this = cx.add_model(|cx| Self {
231 http: http.clone(),
232 node_runtime: NodeRuntime::new(http, cx.background().clone()),
233 server: CopilotServer::Started {
234 server: Arc::new(server),
235 status: SignInStatus::Authorized,
236 subscriptions_by_buffer_id: Default::default(),
237 },
238 });
239 (this, fake_server)
240 }
241
242 fn start_language_server(
243 http: Arc<dyn HttpClient>,
244 node_runtime: Arc<NodeRuntime>,
245 this: ModelHandle<Self>,
246 mut cx: AsyncAppContext,
247 ) -> impl Future<Output = ()> {
248 async move {
249 let start_language_server = async {
250 let server_path = get_copilot_lsp(http).await?;
251 let node_path = node_runtime.binary_path().await?;
252 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
253 let server = LanguageServer::new(
254 0,
255 &node_path,
256 arguments,
257 Path::new("/"),
258 None,
259 cx.clone(),
260 )?;
261
262 let server = server.initialize(Default::default()).await?;
263 let status = server
264 .request::<request::CheckStatus>(request::CheckStatusParams {
265 local_checks_only: false,
266 })
267 .await?;
268
269 server
270 .on_notification::<LogMessage, _>(|params, _cx| {
271 match params.level {
272 // Copilot is pretty agressive about logging
273 0 => debug!("copilot: {}", params.message),
274 1 => debug!("copilot: {}", params.message),
275 _ => error!("copilot: {}", params.message),
276 }
277
278 debug!("copilot metadata: {}", params.metadata_str);
279 debug!("copilot extra: {:?}", params.extra);
280 })
281 .detach();
282
283 server
284 .on_notification::<StatusNotification, _>(
285 |_, _| { /* Silence the notification */ },
286 )
287 .detach();
288
289 anyhow::Ok((server, status))
290 };
291
292 let server = start_language_server.await;
293 this.update(&mut cx, |this, cx| {
294 cx.notify();
295 match server {
296 Ok((server, status)) => {
297 this.server = CopilotServer::Started {
298 server,
299 status: SignInStatus::SignedOut,
300 subscriptions_by_buffer_id: Default::default(),
301 };
302 this.update_sign_in_status(status, cx);
303 }
304 Err(error) => {
305 this.server = CopilotServer::Error(error.to_string().into());
306 cx.notify()
307 }
308 }
309 })
310 }
311 }
312
313 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
314 if let CopilotServer::Started { server, status, .. } = &mut self.server {
315 let task = match status {
316 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
317 Task::ready(Ok(())).shared()
318 }
319 SignInStatus::SigningIn { task, .. } => {
320 cx.notify();
321 task.clone()
322 }
323 SignInStatus::SignedOut => {
324 let server = server.clone();
325 let task = cx
326 .spawn(|this, mut cx| async move {
327 let sign_in = async {
328 let sign_in = server
329 .request::<request::SignInInitiate>(
330 request::SignInInitiateParams {},
331 )
332 .await?;
333 match sign_in {
334 request::SignInInitiateResult::AlreadySignedIn { user } => {
335 Ok(request::SignInStatus::Ok { user })
336 }
337 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
338 this.update(&mut cx, |this, cx| {
339 if let CopilotServer::Started { status, .. } =
340 &mut this.server
341 {
342 if let SignInStatus::SigningIn {
343 prompt: prompt_flow,
344 ..
345 } = status
346 {
347 *prompt_flow = Some(flow.clone());
348 cx.notify();
349 }
350 }
351 });
352 let response = server
353 .request::<request::SignInConfirm>(
354 request::SignInConfirmParams {
355 user_code: flow.user_code,
356 },
357 )
358 .await?;
359 Ok(response)
360 }
361 }
362 };
363
364 let sign_in = sign_in.await;
365 this.update(&mut cx, |this, cx| match sign_in {
366 Ok(status) => {
367 this.update_sign_in_status(status, cx);
368 Ok(())
369 }
370 Err(error) => {
371 this.update_sign_in_status(
372 request::SignInStatus::NotSignedIn,
373 cx,
374 );
375 Err(Arc::new(error))
376 }
377 })
378 })
379 .shared();
380 *status = SignInStatus::SigningIn {
381 prompt: None,
382 task: task.clone(),
383 };
384 cx.notify();
385 task
386 }
387 };
388
389 cx.foreground()
390 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
391 } else {
392 // If we're downloading, wait until download is finished
393 // If we're in a stuck state, display to the user
394 Task::ready(Err(anyhow!("copilot hasn't started yet")))
395 }
396 }
397
398 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
399 if let CopilotServer::Started { server, status, .. } = &mut self.server {
400 *status = SignInStatus::SignedOut;
401 cx.notify();
402
403 let server = server.clone();
404 cx.background().spawn(async move {
405 server
406 .request::<request::SignOut>(request::SignOutParams {})
407 .await?;
408 anyhow::Ok(())
409 })
410 } else {
411 Task::ready(Err(anyhow!("copilot hasn't started yet")))
412 }
413 }
414
415 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
416 let start_task = cx
417 .spawn({
418 let http = self.http.clone();
419 let node_runtime = self.node_runtime.clone();
420 move |this, cx| async move {
421 clear_copilot_dir().await;
422 Self::start_language_server(http, node_runtime, this, cx).await
423 }
424 })
425 .shared();
426
427 self.server = CopilotServer::Starting {
428 task: start_task.clone(),
429 };
430
431 cx.notify();
432
433 cx.foreground().spawn(start_task)
434 }
435
436 pub fn completions<T>(
437 &mut self,
438 buffer: &ModelHandle<Buffer>,
439 position: T,
440 cx: &mut ModelContext<Self>,
441 ) -> Task<Result<Vec<Completion>>>
442 where
443 T: ToPointUtf16,
444 {
445 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
446 }
447
448 pub fn completions_cycling<T>(
449 &mut self,
450 buffer: &ModelHandle<Buffer>,
451 position: T,
452 cx: &mut ModelContext<Self>,
453 ) -> Task<Result<Vec<Completion>>>
454 where
455 T: ToPointUtf16,
456 {
457 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
458 }
459
460 fn request_completions<R, T>(
461 &mut self,
462 buffer: &ModelHandle<Buffer>,
463 position: T,
464 cx: &mut ModelContext<Self>,
465 ) -> Task<Result<Vec<Completion>>>
466 where
467 R: lsp::request::Request<
468 Params = request::GetCompletionsParams,
469 Result = request::GetCompletionsResult,
470 >,
471 T: ToPointUtf16,
472 {
473 let buffer_id = buffer.id();
474 let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
475 let snapshot = buffer.read(cx).snapshot();
476 let server = match &mut self.server {
477 CopilotServer::Starting { .. } => {
478 return Task::ready(Err(anyhow!("copilot is still starting")))
479 }
480 CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
481 CopilotServer::Error(error) => {
482 return Task::ready(Err(anyhow!(
483 "copilot was not started because of an error: {}",
484 error
485 )))
486 }
487 CopilotServer::Started {
488 server,
489 status,
490 subscriptions_by_buffer_id,
491 } => {
492 if matches!(status, SignInStatus::Authorized { .. }) {
493 subscriptions_by_buffer_id
494 .entry(buffer_id)
495 .or_insert_with(|| {
496 server
497 .notify::<lsp::notification::DidOpenTextDocument>(
498 lsp::DidOpenTextDocumentParams {
499 text_document: lsp::TextDocumentItem {
500 uri: uri.clone(),
501 language_id: id_for_language(
502 buffer.read(cx).language(),
503 ),
504 version: 0,
505 text: snapshot.text(),
506 },
507 },
508 )
509 .log_err();
510
511 let uri = uri.clone();
512 cx.observe_release(buffer, move |this, _, _| {
513 if let CopilotServer::Started {
514 server,
515 subscriptions_by_buffer_id,
516 ..
517 } = &mut this.server
518 {
519 server
520 .notify::<lsp::notification::DidCloseTextDocument>(
521 lsp::DidCloseTextDocumentParams {
522 text_document: lsp::TextDocumentIdentifier::new(
523 uri.clone(),
524 ),
525 },
526 )
527 .log_err();
528 subscriptions_by_buffer_id.remove(&buffer_id);
529 }
530 })
531 });
532
533 server.clone()
534 } else {
535 return Task::ready(Err(anyhow!("must sign in before using copilot")));
536 }
537 }
538 };
539
540 let settings = cx.global::<Settings>();
541 let position = position.to_point_utf16(&snapshot);
542 let language = snapshot.language_at(position);
543 let language_name = language.map(|language| language.name());
544 let language_name = language_name.as_deref();
545 let tab_size = settings.tab_size(language_name);
546 let hard_tabs = settings.hard_tabs(language_name);
547 let language_id = id_for_language(language);
548
549 let path;
550 let relative_path;
551 if let Some(file) = snapshot.file() {
552 if let Some(file) = file.as_local() {
553 path = file.abs_path(cx);
554 } else {
555 path = file.full_path(cx);
556 }
557 relative_path = file.path().to_path_buf();
558 } else {
559 path = PathBuf::new();
560 relative_path = PathBuf::new();
561 }
562
563 cx.background().spawn(async move {
564 let result = server
565 .request::<R>(request::GetCompletionsParams {
566 doc: request::GetCompletionsDocument {
567 source: snapshot.text(),
568 tab_size: tab_size.into(),
569 indent_size: 1,
570 insert_spaces: !hard_tabs,
571 uri,
572 path: path.to_string_lossy().into(),
573 relative_path: relative_path.to_string_lossy().into(),
574 language_id,
575 position: point_to_lsp(position),
576 version: 0,
577 },
578 })
579 .await?;
580 let completions = result
581 .completions
582 .into_iter()
583 .map(|completion| {
584 let start = snapshot
585 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
586 let end =
587 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
588 Completion {
589 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
590 text: completion.text,
591 }
592 })
593 .collect();
594 anyhow::Ok(completions)
595 })
596 }
597
598 pub fn status(&self) -> Status {
599 match &self.server {
600 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
601 CopilotServer::Disabled => Status::Disabled,
602 CopilotServer::Error(error) => Status::Error(error.clone()),
603 CopilotServer::Started { status, .. } => match status {
604 SignInStatus::Authorized { .. } => Status::Authorized,
605 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
606 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
607 prompt: prompt.clone(),
608 },
609 SignInStatus::SignedOut => Status::SignedOut,
610 },
611 }
612 }
613
614 fn update_sign_in_status(
615 &mut self,
616 lsp_status: request::SignInStatus,
617 cx: &mut ModelContext<Self>,
618 ) {
619 if let CopilotServer::Started { status, .. } = &mut self.server {
620 *status = match lsp_status {
621 request::SignInStatus::Ok { .. }
622 | request::SignInStatus::MaybeOk { .. }
623 | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized,
624 request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized,
625 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
626 };
627 cx.notify();
628 }
629 }
630}
631
632fn id_for_language(language: Option<&Arc<Language>>) -> String {
633 let language_name = language.map(|language| language.name());
634 match language_name.as_deref() {
635 Some("Plain Text") => "plaintext".to_string(),
636 Some(language_name) => language_name.to_lowercase(),
637 None => "plaintext".to_string(),
638 }
639}
640
641async fn clear_copilot_dir() {
642 remove_matching(&paths::COPILOT_DIR, |_| true).await
643}
644
645async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
646 const SERVER_PATH: &'static str = "dist/agent.js";
647
648 ///Check for the latest copilot language server and download it if we haven't already
649 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
650 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
651
652 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
653
654 fs::create_dir_all(version_dir).await?;
655 let server_path = version_dir.join(SERVER_PATH);
656
657 if fs::metadata(&server_path).await.is_err() {
658 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
659 let dist_dir = version_dir.join("dist");
660 fs::create_dir_all(dist_dir.as_path()).await?;
661
662 let url = &release
663 .assets
664 .get(0)
665 .context("Github release for copilot contained no assets")?
666 .browser_download_url;
667
668 let mut response = http
669 .get(&url, Default::default(), true)
670 .await
671 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
672 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
673 let archive = Archive::new(decompressed_bytes);
674 archive.unpack(dist_dir).await?;
675
676 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
677 }
678
679 Ok(server_path)
680 }
681
682 match fetch_latest(http).await {
683 ok @ Result::Ok(..) => ok,
684 e @ Err(..) => {
685 e.log_err();
686 // Fetch a cached binary, if it exists
687 (|| async move {
688 let mut last_version_dir = None;
689 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
690 while let Some(entry) = entries.next().await {
691 let entry = entry?;
692 if entry.file_type().await?.is_dir() {
693 last_version_dir = Some(entry.path());
694 }
695 }
696 let last_version_dir =
697 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
698 let server_path = last_version_dir.join(SERVER_PATH);
699 if server_path.exists() {
700 Ok(server_path)
701 } else {
702 Err(anyhow!(
703 "missing executable in directory {:?}",
704 last_version_dir
705 ))
706 }
707 })()
708 .await
709 }
710 }
711}