1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use client::Client;
8use collections::HashMap;
9use futures::{future::Shared, Future, FutureExt, TryFutureExt};
10use gpui::{
11 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
12 Task,
13};
14use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
15use log::{debug, error};
16use lsp::LanguageServer;
17use node_runtime::NodeRuntime;
18use request::{LogMessage, StatusNotification};
19use settings::Settings;
20use smol::{fs, io::BufReader, stream::StreamExt};
21use std::{
22 ffi::OsString,
23 ops::Range,
24 path::{Path, PathBuf},
25 sync::Arc,
26};
27use util::{
28 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
29};
30
/// Action namespace for Copilot authentication actions (`SignIn`, `SignOut`).
/// `init` uses it as the key when hiding/showing these actions in the command palette.
const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
// Authentication actions, registered as global actions in `init`.
actions!(copilot_auth, [SignIn, SignOut]);

/// Action namespace for the remaining Copilot actions.
/// `init` hides this namespace in the command palette unless Copilot is authorized.
const COPILOT_NAMESPACE: &'static str = "copilot";
// Suggestion/maintenance actions; of these, only `Reinstall` is handled in this file.
actions!(
    copilot,
    [NextSuggestion, PreviousSuggestion, Toggle, Reinstall]
);
39
40pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
41 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
42 cx.set_global(copilot.clone());
43 cx.add_global_action(|_: &SignIn, cx| {
44 let copilot = Copilot::global(cx).unwrap();
45 copilot
46 .update(cx, |copilot, cx| copilot.sign_in(cx))
47 .detach_and_log_err(cx);
48 });
49 cx.add_global_action(|_: &SignOut, cx| {
50 let copilot = Copilot::global(cx).unwrap();
51 copilot
52 .update(cx, |copilot, cx| copilot.sign_out(cx))
53 .detach_and_log_err(cx);
54 });
55
56 cx.add_global_action(|_: &Reinstall, cx| {
57 let copilot = Copilot::global(cx).unwrap();
58 copilot
59 .update(cx, |copilot, cx| copilot.reinstall(cx))
60 .detach();
61 });
62
63 cx.observe(&copilot, |handle, cx| {
64 let status = handle.read(cx).status();
65 cx.update_global::<collections::CommandPaletteFilter, _, _>(
66 move |filter, _cx| match status {
67 Status::Disabled => {
68 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
69 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
70 }
71 Status::Authorized => {
72 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
73 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
74 }
75 _ => {
76 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
77 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
78 }
79 },
80 );
81 })
82 .detach();
83
84 sign_in::init(cx);
85}
86
/// Lifecycle state of the Copilot language server owned by `Copilot`.
enum CopilotServer {
    /// The integration is turned off (`enable_copilot_integration` is false in settings).
    Disabled,
    /// The language server is being fetched and launched. `task` completes once
    /// `Copilot::start_language_server` has recorded the outcome.
    Starting {
        task: Shared<Task<()>>,
    },
    /// Startup failed; holds the error message.
    Error(Arc<str>),
    /// The language server is running.
    Started {
        server: Arc<LanguageServer>,
        /// Sign-in state as last reported by (or requested from) the server.
        status: SignInStatus,
        /// Keyed by buffer model id: subscriptions that send `didClose` to the server
        /// when the corresponding buffer is released.
        subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
    },
}
99
/// Sign-in state tracked while the language server is `Started`.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// Signed in and allowed to use Copilot. The user string reported by the
    /// server is kept but currently unused.
    Authorized {
        _user: String,
    },
    /// Signed in, but the server reported the user as not authorized.
    Unauthorized {
        _user: String,
    },
    /// A sign-in is in progress. `prompt` is filled in once the server returns a
    /// device-flow prompt; `task` resolves when the sign-in attempt finishes.
    SigningIn {
        prompt: Option<request::PromptUserDeviceFlow>,
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Not signed in (also the initial state right after the server starts).
    SignedOut,
}
114
/// Public, cloneable view of the combined server and sign-in state, as returned by
/// `Copilot::status`.
#[derive(Debug, Clone)]
pub enum Status {
    /// The language server is starting; `task` completes when startup has finished.
    Starting {
        task: Shared<Task<()>>,
    },
    /// Startup failed with the given message.
    Error(Arc<str>),
    /// The integration is disabled in settings.
    Disabled,
    /// The server is running but no user is signed in.
    SignedOut,
    /// A sign-in is in progress; `prompt` is the device-flow prompt, once available.
    SigningIn {
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// Signed in, but not authorized to use Copilot.
    Unauthorized,
    /// Signed in and authorized.
    Authorized,
}
129
130impl Status {
131 pub fn is_authorized(&self) -> bool {
132 matches!(self, Status::Authorized)
133 }
134}
135
/// A single Copilot suggestion for a buffer.
#[derive(Debug, PartialEq, Eq)]
pub struct Completion {
    /// Buffer range the suggestion applies to, anchored in the snapshot the
    /// request was made against.
    pub range: Range<Anchor>,
    /// Suggested text.
    pub text: String,
}
141
/// Model that owns the Copilot language server and its sign-in state.
pub struct Copilot {
    /// HTTP client used to download the language server release.
    http: Arc<dyn HttpClient>,
    /// Node runtime whose binary is used to run the language server.
    node_runtime: Arc<NodeRuntime>,
    /// Current server lifecycle state.
    server: CopilotServer,
}
147
// Copilot emits no typed events; observers react to `cx.notify()` instead.
impl Entity for Copilot {
    type Event = ();
}
151
152impl Copilot {
153 pub fn starting_task(&self) -> Option<Shared<Task<()>>> {
154 match self.server {
155 CopilotServer::Starting { ref task } => Some(task.clone()),
156 _ => None,
157 }
158 }
159
160 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
161 if cx.has_global::<ModelHandle<Self>>() {
162 Some(cx.global::<ModelHandle<Self>>().clone())
163 } else {
164 None
165 }
166 }
167
168 fn start(
169 http: Arc<dyn HttpClient>,
170 node_runtime: Arc<NodeRuntime>,
171 cx: &mut ModelContext<Self>,
172 ) -> Self {
173 cx.observe_global::<Settings, _>({
174 let http = http.clone();
175 let node_runtime = node_runtime.clone();
176 move |this, cx| {
177 if cx.global::<Settings>().enable_copilot_integration {
178 if matches!(this.server, CopilotServer::Disabled) {
179 let start_task = cx
180 .spawn({
181 let http = http.clone();
182 let node_runtime = node_runtime.clone();
183 move |this, cx| {
184 Self::start_language_server(http, node_runtime, this, cx)
185 }
186 })
187 .shared();
188 this.server = CopilotServer::Starting { task: start_task };
189 cx.notify();
190 }
191 } else {
192 this.server = CopilotServer::Disabled;
193 cx.notify();
194 }
195 }
196 })
197 .detach();
198
199 if cx.global::<Settings>().enable_copilot_integration {
200 let start_task = cx
201 .spawn({
202 let http = http.clone();
203 let node_runtime = node_runtime.clone();
204 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
205 })
206 .shared();
207
208 Self {
209 http,
210 node_runtime,
211 server: CopilotServer::Starting { task: start_task },
212 }
213 } else {
214 Self {
215 http,
216 node_runtime,
217 server: CopilotServer::Disabled,
218 }
219 }
220 }
221
222 fn start_language_server(
223 http: Arc<dyn HttpClient>,
224 node_runtime: Arc<NodeRuntime>,
225 this: ModelHandle<Self>,
226 mut cx: AsyncAppContext,
227 ) -> impl Future<Output = ()> {
228 async move {
229 let start_language_server = async {
230 let server_path = get_copilot_lsp(http).await?;
231 let node_path = node_runtime.binary_path().await?;
232 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
233 let server = LanguageServer::new(
234 0,
235 &node_path,
236 arguments,
237 Path::new("/"),
238 None,
239 cx.clone(),
240 )?;
241
242 let server = server.initialize(Default::default()).await?;
243 let status = server
244 .request::<request::CheckStatus>(request::CheckStatusParams {
245 local_checks_only: false,
246 })
247 .await?;
248
249 server
250 .on_notification::<LogMessage, _>(|params, _cx| {
251 match params.level {
252 // Copilot is pretty agressive about logging
253 0 => debug!("copilot: {}", params.message),
254 1 => debug!("copilot: {}", params.message),
255 _ => error!("copilot: {}", params.message),
256 }
257
258 debug!("copilot metadata: {}", params.metadata_str);
259 debug!("copilot extra: {:?}", params.extra);
260 })
261 .detach();
262
263 server
264 .on_notification::<StatusNotification, _>(
265 |_, _| { /* Silence the notification */ },
266 )
267 .detach();
268
269 anyhow::Ok((server, status))
270 };
271
272 let server = start_language_server.await;
273 this.update(&mut cx, |this, cx| {
274 cx.notify();
275 match server {
276 Ok((server, status)) => {
277 this.server = CopilotServer::Started {
278 server,
279 status: SignInStatus::SignedOut,
280 subscriptions_by_buffer_id: Default::default(),
281 };
282 this.update_sign_in_status(status, cx);
283 }
284 Err(error) => {
285 this.server = CopilotServer::Error(error.to_string().into());
286 cx.notify()
287 }
288 }
289 })
290 }
291 }
292
293 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
294 if let CopilotServer::Started { server, status, .. } = &mut self.server {
295 let task = match status {
296 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
297 Task::ready(Ok(())).shared()
298 }
299 SignInStatus::SigningIn { task, .. } => {
300 cx.notify();
301 task.clone()
302 }
303 SignInStatus::SignedOut => {
304 let server = server.clone();
305 let task = cx
306 .spawn(|this, mut cx| async move {
307 let sign_in = async {
308 let sign_in = server
309 .request::<request::SignInInitiate>(
310 request::SignInInitiateParams {},
311 )
312 .await?;
313 match sign_in {
314 request::SignInInitiateResult::AlreadySignedIn { user } => {
315 Ok(request::SignInStatus::Ok { user })
316 }
317 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
318 this.update(&mut cx, |this, cx| {
319 if let CopilotServer::Started { status, .. } =
320 &mut this.server
321 {
322 if let SignInStatus::SigningIn {
323 prompt: prompt_flow,
324 ..
325 } = status
326 {
327 *prompt_flow = Some(flow.clone());
328 cx.notify();
329 }
330 }
331 });
332 let response = server
333 .request::<request::SignInConfirm>(
334 request::SignInConfirmParams {
335 user_code: flow.user_code,
336 },
337 )
338 .await?;
339 Ok(response)
340 }
341 }
342 };
343
344 let sign_in = sign_in.await;
345 this.update(&mut cx, |this, cx| match sign_in {
346 Ok(status) => {
347 this.update_sign_in_status(status, cx);
348 Ok(())
349 }
350 Err(error) => {
351 this.update_sign_in_status(
352 request::SignInStatus::NotSignedIn,
353 cx,
354 );
355 Err(Arc::new(error))
356 }
357 })
358 })
359 .shared();
360 *status = SignInStatus::SigningIn {
361 prompt: None,
362 task: task.clone(),
363 };
364 cx.notify();
365 task
366 }
367 };
368
369 cx.foreground()
370 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
371 } else {
372 // If we're downloading, wait until download is finished
373 // If we're in a stuck state, display to the user
374 Task::ready(Err(anyhow!("copilot hasn't started yet")))
375 }
376 }
377
378 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
379 if let CopilotServer::Started { server, status, .. } = &mut self.server {
380 *status = SignInStatus::SignedOut;
381 cx.notify();
382
383 let server = server.clone();
384 cx.background().spawn(async move {
385 server
386 .request::<request::SignOut>(request::SignOutParams {})
387 .await?;
388 anyhow::Ok(())
389 })
390 } else {
391 Task::ready(Err(anyhow!("copilot hasn't started yet")))
392 }
393 }
394
395 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
396 let start_task = cx
397 .spawn({
398 let http = self.http.clone();
399 let node_runtime = self.node_runtime.clone();
400 move |this, cx| async move {
401 clear_copilot_dir().await;
402 Self::start_language_server(http, node_runtime, this, cx).await
403 }
404 })
405 .shared();
406
407 self.server = CopilotServer::Starting {
408 task: start_task.clone(),
409 };
410
411 cx.notify();
412
413 cx.foreground().spawn(start_task)
414 }
415
416 pub fn completions<T>(
417 &mut self,
418 buffer: &ModelHandle<Buffer>,
419 position: T,
420 cx: &mut ModelContext<Self>,
421 ) -> Task<Result<Vec<Completion>>>
422 where
423 T: ToPointUtf16,
424 {
425 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
426 }
427
428 pub fn completions_cycling<T>(
429 &mut self,
430 buffer: &ModelHandle<Buffer>,
431 position: T,
432 cx: &mut ModelContext<Self>,
433 ) -> Task<Result<Vec<Completion>>>
434 where
435 T: ToPointUtf16,
436 {
437 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
438 }
439
440 fn request_completions<R, T>(
441 &mut self,
442 buffer: &ModelHandle<Buffer>,
443 position: T,
444 cx: &mut ModelContext<Self>,
445 ) -> Task<Result<Vec<Completion>>>
446 where
447 R: lsp::request::Request<
448 Params = request::GetCompletionsParams,
449 Result = request::GetCompletionsResult,
450 >,
451 T: ToPointUtf16,
452 {
453 let buffer_id = buffer.id();
454 let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
455 let snapshot = buffer.read(cx).snapshot();
456 let server = match &mut self.server {
457 CopilotServer::Starting { .. } => {
458 return Task::ready(Err(anyhow!("copilot is still starting")))
459 }
460 CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
461 CopilotServer::Error(error) => {
462 return Task::ready(Err(anyhow!(
463 "copilot was not started because of an error: {}",
464 error
465 )))
466 }
467 CopilotServer::Started {
468 server,
469 status,
470 subscriptions_by_buffer_id,
471 } => {
472 if matches!(status, SignInStatus::Authorized { .. }) {
473 subscriptions_by_buffer_id
474 .entry(buffer_id)
475 .or_insert_with(|| {
476 server
477 .notify::<lsp::notification::DidOpenTextDocument>(
478 lsp::DidOpenTextDocumentParams {
479 text_document: lsp::TextDocumentItem {
480 uri: uri.clone(),
481 language_id: id_for_language(
482 buffer.read(cx).language(),
483 ),
484 version: 0,
485 text: snapshot.text(),
486 },
487 },
488 )
489 .log_err();
490
491 let uri = uri.clone();
492 cx.observe_release(buffer, move |this, _, _| {
493 if let CopilotServer::Started {
494 server,
495 subscriptions_by_buffer_id,
496 ..
497 } = &mut this.server
498 {
499 server
500 .notify::<lsp::notification::DidCloseTextDocument>(
501 lsp::DidCloseTextDocumentParams {
502 text_document: lsp::TextDocumentIdentifier::new(
503 uri.clone(),
504 ),
505 },
506 )
507 .log_err();
508 subscriptions_by_buffer_id.remove(&buffer_id);
509 }
510 })
511 });
512
513 server.clone()
514 } else {
515 return Task::ready(Err(anyhow!("must sign in before using copilot")));
516 }
517 }
518 };
519
520 let settings = cx.global::<Settings>();
521 let position = position.to_point_utf16(&snapshot);
522 let language = snapshot.language_at(position);
523 let language_name = language.map(|language| language.name());
524 let language_name = language_name.as_deref();
525
526 let path;
527 let relative_path;
528 if let Some(file) = snapshot.file() {
529 if let Some(file) = file.as_local() {
530 path = file.abs_path(cx);
531 } else {
532 path = file.full_path(cx);
533 }
534 relative_path = file.path().to_path_buf();
535 } else {
536 path = PathBuf::new();
537 relative_path = PathBuf::new();
538 }
539
540 let params = request::GetCompletionsParams {
541 doc: request::GetCompletionsDocument {
542 source: snapshot.text(),
543 tab_size: settings.tab_size(language_name).into(),
544 indent_size: 1,
545 insert_spaces: !settings.hard_tabs(language_name),
546 uri,
547 path: path.to_string_lossy().into(),
548 relative_path: relative_path.to_string_lossy().into(),
549 language_id: id_for_language(language),
550 position: point_to_lsp(position),
551 version: 0,
552 },
553 };
554 cx.background().spawn(async move {
555 let result = server.request::<R>(params).await?;
556 let completions = result
557 .completions
558 .into_iter()
559 .map(|completion| {
560 let start = snapshot
561 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
562 let end =
563 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
564 Completion {
565 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
566 text: completion.text,
567 }
568 })
569 .collect();
570 anyhow::Ok(completions)
571 })
572 }
573
574 pub fn status(&self) -> Status {
575 match &self.server {
576 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
577 CopilotServer::Disabled => Status::Disabled,
578 CopilotServer::Error(error) => Status::Error(error.clone()),
579 CopilotServer::Started { status, .. } => match status {
580 SignInStatus::Authorized { .. } => Status::Authorized,
581 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
582 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
583 prompt: prompt.clone(),
584 },
585 SignInStatus::SignedOut => Status::SignedOut,
586 },
587 }
588 }
589
590 fn update_sign_in_status(
591 &mut self,
592 lsp_status: request::SignInStatus,
593 cx: &mut ModelContext<Self>,
594 ) {
595 if let CopilotServer::Started { status, .. } = &mut self.server {
596 *status = match lsp_status {
597 request::SignInStatus::Ok { user }
598 | request::SignInStatus::MaybeOk { user }
599 | request::SignInStatus::AlreadySignedIn { user } => {
600 SignInStatus::Authorized { _user: user }
601 }
602 request::SignInStatus::NotAuthorized { user } => {
603 SignInStatus::Unauthorized { _user: user }
604 }
605 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
606 };
607 cx.notify();
608 }
609 }
610}
611
612fn id_for_language(language: Option<&Arc<Language>>) -> String {
613 let language_name = language.map(|language| language.name());
614 match language_name.as_deref() {
615 Some("Plain Text") => "plaintext".to_string(),
616 Some(language_name) => language_name.to_lowercase(),
617 None => "plaintext".to_string(),
618 }
619}
620
621async fn clear_copilot_dir() {
622 remove_matching(&paths::COPILOT_DIR, |_| true).await
623}
624
625async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
626 const SERVER_PATH: &'static str = "dist/agent.js";
627
628 ///Check for the latest copilot language server and download it if we haven't already
629 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
630 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
631
632 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
633
634 fs::create_dir_all(version_dir).await?;
635 let server_path = version_dir.join(SERVER_PATH);
636
637 if fs::metadata(&server_path).await.is_err() {
638 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
639 let dist_dir = version_dir.join("dist");
640 fs::create_dir_all(dist_dir.as_path()).await?;
641
642 let url = &release
643 .assets
644 .get(0)
645 .context("Github release for copilot contained no assets")?
646 .browser_download_url;
647
648 let mut response = http
649 .get(&url, Default::default(), true)
650 .await
651 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
652 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
653 let archive = Archive::new(decompressed_bytes);
654 archive.unpack(dist_dir).await?;
655
656 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
657 }
658
659 Ok(server_path)
660 }
661
662 match fetch_latest(http).await {
663 ok @ Result::Ok(..) => ok,
664 e @ Err(..) => {
665 e.log_err();
666 // Fetch a cached binary, if it exists
667 (|| async move {
668 let mut last_version_dir = None;
669 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
670 while let Some(entry) = entries.next().await {
671 let entry = entry?;
672 if entry.file_type().await?.is_dir() {
673 last_version_dir = Some(entry.path());
674 }
675 }
676 let last_version_dir =
677 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
678 let server_path = last_version_dir.join(SERVER_PATH);
679 if server_path.exists() {
680 Ok(server_path)
681 } else {
682 Err(anyhow!(
683 "missing executable in directory {:?}",
684 last_version_dir
685 ))
686 }
687 })()
688 .await
689 }
690 }
691}