1pub mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use collections::HashMap;
8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{
10 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
11 Task,
12};
13use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
14use log::{debug, error};
15use lsp::LanguageServer;
16use node_runtime::NodeRuntime;
17use request::{LogMessage, StatusNotification};
18use settings::Settings;
19use smol::{fs, io::BufReader, stream::StreamExt};
20use std::{
21 ffi::OsString,
22 ops::Range,
23 path::{Path, PathBuf},
24 sync::Arc,
25};
26use util::{
27 channel::ReleaseChannel, fs::remove_matching, github::latest_github_release, http::HttpClient,
28 paths, ResultExt,
29};
30
31const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
32actions!(copilot_auth, [SignIn, SignOut]);
33
34const COPILOT_NAMESPACE: &'static str = "copilot";
35actions!(copilot, [NextSuggestion, PreviousSuggestion, Reinstall]);
36
37pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
38 // Disable Copilot for stable releases.
39 if *cx.global::<ReleaseChannel>() == ReleaseChannel::Stable {
40 cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
41 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
42 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
43 });
44 return;
45 }
46
47 let copilot = cx.add_model({
48 let node_runtime = node_runtime.clone();
49 move |cx| Copilot::start(http, node_runtime, cx)
50 });
51 cx.set_global(copilot.clone());
52
53 cx.observe(&copilot, |handle, cx| {
54 let status = handle.read(cx).status();
55 cx.update_global::<collections::CommandPaletteFilter, _, _>(
56 move |filter, _cx| match status {
57 Status::Disabled => {
58 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
59 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
60 }
61 Status::Authorized => {
62 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
63 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
64 }
65 _ => {
66 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
67 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
68 }
69 },
70 );
71 })
72 .detach();
73
74 sign_in::init(cx);
75 cx.add_global_action(|_: &SignIn, cx| {
76 if let Some(copilot) = Copilot::global(cx) {
77 copilot
78 .update(cx, |copilot, cx| copilot.sign_in(cx))
79 .detach_and_log_err(cx);
80 }
81 });
82 cx.add_global_action(|_: &SignOut, cx| {
83 if let Some(copilot) = Copilot::global(cx) {
84 copilot
85 .update(cx, |copilot, cx| copilot.sign_out(cx))
86 .detach_and_log_err(cx);
87 }
88 });
89
90 cx.add_global_action(|_: &Reinstall, cx| {
91 if let Some(copilot) = Copilot::global(cx) {
92 copilot
93 .update(cx, |copilot, cx| copilot.reinstall(cx))
94 .detach();
95 }
96 });
97}
98
99enum CopilotServer {
100 Disabled,
101 Starting {
102 task: Shared<Task<()>>,
103 },
104 Error(Arc<str>),
105 Started {
106 server: Arc<LanguageServer>,
107 status: SignInStatus,
108 subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
109 },
110}
111
112#[derive(Clone, Debug)]
113enum SignInStatus {
114 Authorized,
115 Unauthorized,
116 SigningIn {
117 prompt: Option<request::PromptUserDeviceFlow>,
118 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
119 },
120 SignedOut,
121}
122
123#[derive(Debug, Clone)]
124pub enum Status {
125 Starting {
126 task: Shared<Task<()>>,
127 },
128 Error(Arc<str>),
129 Disabled,
130 SignedOut,
131 SigningIn {
132 prompt: Option<request::PromptUserDeviceFlow>,
133 },
134 Unauthorized,
135 Authorized,
136}
137
138impl Status {
139 pub fn is_authorized(&self) -> bool {
140 matches!(self, Status::Authorized)
141 }
142}
143
144#[derive(Debug, PartialEq, Eq)]
145pub struct Completion {
146 pub range: Range<Anchor>,
147 pub text: String,
148}
149
150pub struct Copilot {
151 http: Arc<dyn HttpClient>,
152 node_runtime: Arc<NodeRuntime>,
153 server: CopilotServer,
154}
155
156impl Entity for Copilot {
157 type Event = ();
158}
159
160impl Copilot {
161 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
162 if cx.has_global::<ModelHandle<Self>>() {
163 Some(cx.global::<ModelHandle<Self>>().clone())
164 } else {
165 None
166 }
167 }
168
169 fn start(
170 http: Arc<dyn HttpClient>,
171 node_runtime: Arc<NodeRuntime>,
172 cx: &mut ModelContext<Self>,
173 ) -> Self {
174 cx.observe_global::<Settings, _>({
175 let http = http.clone();
176 let node_runtime = node_runtime.clone();
177 move |this, cx| {
178 if cx.global::<Settings>().enable_copilot_integration {
179 if matches!(this.server, CopilotServer::Disabled) {
180 let start_task = cx
181 .spawn({
182 let http = http.clone();
183 let node_runtime = node_runtime.clone();
184 move |this, cx| {
185 Self::start_language_server(http, node_runtime, this, cx)
186 }
187 })
188 .shared();
189 this.server = CopilotServer::Starting { task: start_task };
190 cx.notify();
191 }
192 } else {
193 this.server = CopilotServer::Disabled;
194 cx.notify();
195 }
196 }
197 })
198 .detach();
199
200 if cx.global::<Settings>().enable_copilot_integration {
201 let start_task = cx
202 .spawn({
203 let http = http.clone();
204 let node_runtime = node_runtime.clone();
205 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
206 })
207 .shared();
208
209 Self {
210 http,
211 node_runtime,
212 server: CopilotServer::Starting { task: start_task },
213 }
214 } else {
215 Self {
216 http,
217 node_runtime,
218 server: CopilotServer::Disabled,
219 }
220 }
221 }
222
223 #[cfg(any(test, feature = "test-support"))]
224 pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
225 let (server, fake_server) =
226 LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
227 let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
228 let this = cx.add_model(|cx| Self {
229 http: http.clone(),
230 node_runtime: NodeRuntime::new(http, cx.background().clone()),
231 server: CopilotServer::Started {
232 server: Arc::new(server),
233 status: SignInStatus::Authorized,
234 subscriptions_by_buffer_id: Default::default(),
235 },
236 });
237 (this, fake_server)
238 }
239
240 fn start_language_server(
241 http: Arc<dyn HttpClient>,
242 node_runtime: Arc<NodeRuntime>,
243 this: ModelHandle<Self>,
244 mut cx: AsyncAppContext,
245 ) -> impl Future<Output = ()> {
246 async move {
247 let start_language_server = async {
248 let server_path = get_copilot_lsp(http).await?;
249 let node_path = node_runtime.binary_path().await?;
250 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
251 let server = LanguageServer::new(
252 0,
253 &node_path,
254 arguments,
255 Path::new("/"),
256 None,
257 cx.clone(),
258 )?;
259
260 let server = server.initialize(Default::default()).await?;
261 let status = server
262 .request::<request::CheckStatus>(request::CheckStatusParams {
263 local_checks_only: false,
264 })
265 .await?;
266
267 server
268 .on_notification::<LogMessage, _>(|params, _cx| {
269 match params.level {
270 // Copilot is pretty agressive about logging
271 0 => debug!("copilot: {}", params.message),
272 1 => debug!("copilot: {}", params.message),
273 _ => error!("copilot: {}", params.message),
274 }
275
276 debug!("copilot metadata: {}", params.metadata_str);
277 debug!("copilot extra: {:?}", params.extra);
278 })
279 .detach();
280
281 server
282 .on_notification::<StatusNotification, _>(
283 |_, _| { /* Silence the notification */ },
284 )
285 .detach();
286
287 anyhow::Ok((server, status))
288 };
289
290 let server = start_language_server.await;
291 this.update(&mut cx, |this, cx| {
292 cx.notify();
293 match server {
294 Ok((server, status)) => {
295 this.server = CopilotServer::Started {
296 server,
297 status: SignInStatus::SignedOut,
298 subscriptions_by_buffer_id: Default::default(),
299 };
300 this.update_sign_in_status(status, cx);
301 }
302 Err(error) => {
303 this.server = CopilotServer::Error(error.to_string().into());
304 cx.notify()
305 }
306 }
307 })
308 }
309 }
310
311 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
312 if let CopilotServer::Started { server, status, .. } = &mut self.server {
313 let task = match status {
314 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
315 Task::ready(Ok(())).shared()
316 }
317 SignInStatus::SigningIn { task, .. } => {
318 cx.notify();
319 task.clone()
320 }
321 SignInStatus::SignedOut => {
322 let server = server.clone();
323 let task = cx
324 .spawn(|this, mut cx| async move {
325 let sign_in = async {
326 let sign_in = server
327 .request::<request::SignInInitiate>(
328 request::SignInInitiateParams {},
329 )
330 .await?;
331 match sign_in {
332 request::SignInInitiateResult::AlreadySignedIn { user } => {
333 Ok(request::SignInStatus::Ok { user })
334 }
335 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
336 this.update(&mut cx, |this, cx| {
337 if let CopilotServer::Started { status, .. } =
338 &mut this.server
339 {
340 if let SignInStatus::SigningIn {
341 prompt: prompt_flow,
342 ..
343 } = status
344 {
345 *prompt_flow = Some(flow.clone());
346 cx.notify();
347 }
348 }
349 });
350 let response = server
351 .request::<request::SignInConfirm>(
352 request::SignInConfirmParams {
353 user_code: flow.user_code,
354 },
355 )
356 .await?;
357 Ok(response)
358 }
359 }
360 };
361
362 let sign_in = sign_in.await;
363 this.update(&mut cx, |this, cx| match sign_in {
364 Ok(status) => {
365 this.update_sign_in_status(status, cx);
366 Ok(())
367 }
368 Err(error) => {
369 this.update_sign_in_status(
370 request::SignInStatus::NotSignedIn,
371 cx,
372 );
373 Err(Arc::new(error))
374 }
375 })
376 })
377 .shared();
378 *status = SignInStatus::SigningIn {
379 prompt: None,
380 task: task.clone(),
381 };
382 cx.notify();
383 task
384 }
385 };
386
387 cx.foreground()
388 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
389 } else {
390 // If we're downloading, wait until download is finished
391 // If we're in a stuck state, display to the user
392 Task::ready(Err(anyhow!("copilot hasn't started yet")))
393 }
394 }
395
396 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
397 if let CopilotServer::Started { server, status, .. } = &mut self.server {
398 *status = SignInStatus::SignedOut;
399 cx.notify();
400
401 let server = server.clone();
402 cx.background().spawn(async move {
403 server
404 .request::<request::SignOut>(request::SignOutParams {})
405 .await?;
406 anyhow::Ok(())
407 })
408 } else {
409 Task::ready(Err(anyhow!("copilot hasn't started yet")))
410 }
411 }
412
413 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
414 let start_task = cx
415 .spawn({
416 let http = self.http.clone();
417 let node_runtime = self.node_runtime.clone();
418 move |this, cx| async move {
419 clear_copilot_dir().await;
420 Self::start_language_server(http, node_runtime, this, cx).await
421 }
422 })
423 .shared();
424
425 self.server = CopilotServer::Starting {
426 task: start_task.clone(),
427 };
428
429 cx.notify();
430
431 cx.foreground().spawn(start_task)
432 }
433
434 pub fn completions<T>(
435 &mut self,
436 buffer: &ModelHandle<Buffer>,
437 position: T,
438 cx: &mut ModelContext<Self>,
439 ) -> Task<Result<Vec<Completion>>>
440 where
441 T: ToPointUtf16,
442 {
443 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
444 }
445
446 pub fn completions_cycling<T>(
447 &mut self,
448 buffer: &ModelHandle<Buffer>,
449 position: T,
450 cx: &mut ModelContext<Self>,
451 ) -> Task<Result<Vec<Completion>>>
452 where
453 T: ToPointUtf16,
454 {
455 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
456 }
457
458 fn request_completions<R, T>(
459 &mut self,
460 buffer: &ModelHandle<Buffer>,
461 position: T,
462 cx: &mut ModelContext<Self>,
463 ) -> Task<Result<Vec<Completion>>>
464 where
465 R: lsp::request::Request<
466 Params = request::GetCompletionsParams,
467 Result = request::GetCompletionsResult,
468 >,
469 T: ToPointUtf16,
470 {
471 let buffer_id = buffer.id();
472 let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
473 let snapshot = buffer.read(cx).snapshot();
474 let server = match &mut self.server {
475 CopilotServer::Starting { .. } => {
476 return Task::ready(Err(anyhow!("copilot is still starting")))
477 }
478 CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
479 CopilotServer::Error(error) => {
480 return Task::ready(Err(anyhow!(
481 "copilot was not started because of an error: {}",
482 error
483 )))
484 }
485 CopilotServer::Started {
486 server,
487 status,
488 subscriptions_by_buffer_id,
489 } => {
490 if matches!(status, SignInStatus::Authorized { .. }) {
491 subscriptions_by_buffer_id
492 .entry(buffer_id)
493 .or_insert_with(|| {
494 server
495 .notify::<lsp::notification::DidOpenTextDocument>(
496 lsp::DidOpenTextDocumentParams {
497 text_document: lsp::TextDocumentItem {
498 uri: uri.clone(),
499 language_id: id_for_language(
500 buffer.read(cx).language(),
501 ),
502 version: 0,
503 text: snapshot.text(),
504 },
505 },
506 )
507 .log_err();
508
509 let uri = uri.clone();
510 cx.observe_release(buffer, move |this, _, _| {
511 if let CopilotServer::Started {
512 server,
513 subscriptions_by_buffer_id,
514 ..
515 } = &mut this.server
516 {
517 server
518 .notify::<lsp::notification::DidCloseTextDocument>(
519 lsp::DidCloseTextDocumentParams {
520 text_document: lsp::TextDocumentIdentifier::new(
521 uri.clone(),
522 ),
523 },
524 )
525 .log_err();
526 subscriptions_by_buffer_id.remove(&buffer_id);
527 }
528 })
529 });
530
531 server.clone()
532 } else {
533 return Task::ready(Err(anyhow!("must sign in before using copilot")));
534 }
535 }
536 };
537
538 let settings = cx.global::<Settings>();
539 let position = position.to_point_utf16(&snapshot);
540 let language = snapshot.language_at(position);
541 let language_name = language.map(|language| language.name());
542 let language_name = language_name.as_deref();
543 let tab_size = settings.tab_size(language_name);
544 let hard_tabs = settings.hard_tabs(language_name);
545 let language_id = id_for_language(language);
546
547 let path;
548 let relative_path;
549 if let Some(file) = snapshot.file() {
550 if let Some(file) = file.as_local() {
551 path = file.abs_path(cx);
552 } else {
553 path = file.full_path(cx);
554 }
555 relative_path = file.path().to_path_buf();
556 } else {
557 path = PathBuf::new();
558 relative_path = PathBuf::new();
559 }
560
561 cx.background().spawn(async move {
562 let result = server
563 .request::<R>(request::GetCompletionsParams {
564 doc: request::GetCompletionsDocument {
565 source: snapshot.text(),
566 tab_size: tab_size.into(),
567 indent_size: 1,
568 insert_spaces: !hard_tabs,
569 uri,
570 path: path.to_string_lossy().into(),
571 relative_path: relative_path.to_string_lossy().into(),
572 language_id,
573 position: point_to_lsp(position),
574 version: 0,
575 },
576 })
577 .await?;
578 let completions = result
579 .completions
580 .into_iter()
581 .map(|completion| {
582 let start = snapshot
583 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
584 let end =
585 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
586 Completion {
587 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
588 text: completion.text,
589 }
590 })
591 .collect();
592 anyhow::Ok(completions)
593 })
594 }
595
596 pub fn status(&self) -> Status {
597 match &self.server {
598 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
599 CopilotServer::Disabled => Status::Disabled,
600 CopilotServer::Error(error) => Status::Error(error.clone()),
601 CopilotServer::Started { status, .. } => match status {
602 SignInStatus::Authorized { .. } => Status::Authorized,
603 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
604 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
605 prompt: prompt.clone(),
606 },
607 SignInStatus::SignedOut => Status::SignedOut,
608 },
609 }
610 }
611
612 fn update_sign_in_status(
613 &mut self,
614 lsp_status: request::SignInStatus,
615 cx: &mut ModelContext<Self>,
616 ) {
617 if let CopilotServer::Started { status, .. } = &mut self.server {
618 *status = match lsp_status {
619 request::SignInStatus::Ok { .. }
620 | request::SignInStatus::MaybeOk { .. }
621 | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized,
622 request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized,
623 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
624 };
625 cx.notify();
626 }
627 }
628}
629
630fn id_for_language(language: Option<&Arc<Language>>) -> String {
631 let language_name = language.map(|language| language.name());
632 match language_name.as_deref() {
633 Some("Plain Text") => "plaintext".to_string(),
634 Some(language_name) => language_name.to_lowercase(),
635 None => "plaintext".to_string(),
636 }
637}
638
639async fn clear_copilot_dir() {
640 remove_matching(&paths::COPILOT_DIR, |_| true).await
641}
642
643async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
644 const SERVER_PATH: &'static str = "dist/agent.js";
645
646 ///Check for the latest copilot language server and download it if we haven't already
647 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
648 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
649
650 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
651
652 fs::create_dir_all(version_dir).await?;
653 let server_path = version_dir.join(SERVER_PATH);
654
655 if fs::metadata(&server_path).await.is_err() {
656 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
657 let dist_dir = version_dir.join("dist");
658 fs::create_dir_all(dist_dir.as_path()).await?;
659
660 let url = &release
661 .assets
662 .get(0)
663 .context("Github release for copilot contained no assets")?
664 .browser_download_url;
665
666 let mut response = http
667 .get(&url, Default::default(), true)
668 .await
669 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
670 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
671 let archive = Archive::new(decompressed_bytes);
672 archive.unpack(dist_dir).await?;
673
674 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
675 }
676
677 Ok(server_path)
678 }
679
680 match fetch_latest(http).await {
681 ok @ Result::Ok(..) => ok,
682 e @ Err(..) => {
683 e.log_err();
684 // Fetch a cached binary, if it exists
685 (|| async move {
686 let mut last_version_dir = None;
687 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
688 while let Some(entry) = entries.next().await {
689 let entry = entry?;
690 if entry.file_type().await?.is_dir() {
691 last_version_dir = Some(entry.path());
692 }
693 }
694 let last_version_dir =
695 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
696 let server_path = last_version_dir.join(SERVER_PATH);
697 if server_path.exists() {
698 Ok(server_path)
699 } else {
700 Err(anyhow!(
701 "missing executable in directory {:?}",
702 last_version_dir
703 ))
704 }
705 })()
706 .await
707 }
708 }
709}