1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use client::Client;
8use collections::HashMap;
9use futures::{future::Shared, Future, FutureExt, TryFutureExt};
10use gpui::{
11 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
12 Task,
13};
14use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
15use log::{debug, error};
16use lsp::LanguageServer;
17use node_runtime::NodeRuntime;
18use request::{LogMessage, StatusNotification};
19use settings::Settings;
20use smol::{fs, io::BufReader, stream::StreamExt};
21use std::{
22 ffi::OsString,
23 ops::Range,
24 path::{Path, PathBuf},
25 sync::Arc,
26};
27use util::{
28 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
29};
30
/// Namespace of the sign-in / sign-out actions. `init` hides this namespace from the
/// command palette only while copilot is disabled.
/// NOTE(review): presumably must stay equal to the `copilot_auth` namespace passed to
/// `actions!` below so the palette filter matches these actions — confirm.
const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
actions!(copilot_auth, [SignIn, SignOut]);

/// Namespace of the suggestion / toggle / reinstall actions. `init` hides this
/// namespace from the command palette unless copilot is authorized.
/// NOTE(review): presumably must stay equal to the `copilot` namespace passed to
/// `actions!` below — confirm.
const COPILOT_NAMESPACE: &'static str = "copilot";
actions!(
    copilot,
    [NextSuggestion, PreviousSuggestion, Toggle, Reinstall]
);
39
40pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
41 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
42 cx.set_global(copilot.clone());
43 cx.add_global_action(|_: &SignIn, cx| {
44 let copilot = Copilot::global(cx).unwrap();
45 copilot
46 .update(cx, |copilot, cx| copilot.sign_in(cx))
47 .detach_and_log_err(cx);
48 });
49 cx.add_global_action(|_: &SignOut, cx| {
50 let copilot = Copilot::global(cx).unwrap();
51 copilot
52 .update(cx, |copilot, cx| copilot.sign_out(cx))
53 .detach_and_log_err(cx);
54 });
55
56 cx.add_global_action(|_: &Reinstall, cx| {
57 let copilot = Copilot::global(cx).unwrap();
58 copilot
59 .update(cx, |copilot, cx| copilot.reinstall(cx))
60 .detach();
61 });
62
63 cx.observe(&copilot, |handle, cx| {
64 let status = handle.read(cx).status();
65 cx.update_global::<collections::CommandPaletteFilter, _, _>(
66 move |filter, _cx| match status {
67 Status::Disabled => {
68 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
69 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
70 }
71 Status::Authorized => {
72 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
73 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
74 }
75 _ => {
76 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
77 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
78 }
79 },
80 );
81 })
82 .detach();
83
84 sign_in::init(cx);
85}
86
/// Lifecycle state of the Copilot language server owned by [`Copilot`].
enum CopilotServer {
    /// The `enable_copilot_integration` setting is off; no server handle is held.
    Disabled,
    /// The server is being fetched and launched by `task`.
    Starting {
        task: Shared<Task<()>>,
    },
    /// Starting the server failed; holds the error message.
    Error(Arc<str>),
    /// The server is running.
    Started {
        server: Arc<LanguageServer>,
        /// Sign-in state as last reported by (or requested from) the server.
        status: SignInStatus,
        /// Release-observation subscriptions for buffers already announced to the server
        /// via `didOpen`, keyed by buffer id; the entry is removed when `didClose` is sent.
        subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
    },
}
99
/// Sign-in state of a running Copilot language server.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// The server reported a signed-in, authorized user.
    Authorized {
        // Not read anywhere yet (hence the underscore); kept from the server response.
        _user: String,
    },
    /// The server reported a signed-in user who is not authorized.
    Unauthorized {
        // Not read anywhere yet (hence the underscore); kept from the server response.
        _user: String,
    },
    /// A device-flow sign-in is in progress.
    SigningIn {
        /// Device-flow prompt for the user; `None` until the server has returned one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// The in-flight sign-in task; repeated `sign_in` calls reuse it.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// No user is signed in.
    SignedOut,
}
114
/// Public, flattened view of the copilot state, as returned by [`Copilot::status`].
#[derive(Debug, Clone)]
pub enum Status {
    /// The language server is starting; `task` resolves once the startup attempt ends.
    Starting {
        task: Shared<Task<()>>,
    },
    /// The language server failed to start; contains the error message.
    Error(Arc<str>),
    /// Copilot integration is turned off in settings.
    Disabled,
    /// The server is running but no user is signed in.
    SignedOut,
    /// A sign-in is in progress; `prompt` is the device-flow prompt once available.
    SigningIn {
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// A user is signed in but the server reported them as not authorized.
    Unauthorized,
    /// A user is signed in and authorized; completions may be requested.
    Authorized,
}
129
130impl Status {
131 pub fn is_authorized(&self) -> bool {
132 matches!(self, Status::Authorized)
133 }
134}
135
/// A single Copilot suggestion for a buffer.
#[derive(Debug, PartialEq, Eq)]
pub struct Completion {
    /// Buffer range the suggestion applies to, anchored in the buffer snapshot the
    /// request was made against (start anchored before, end anchored after).
    pub range: Range<Anchor>,
    /// Suggested text for `range` (presumably replaces it — confirm with the editor side).
    pub text: String,
}
141
/// Model that owns the Copilot language server and tracks its sign-in state.
pub struct Copilot {
    /// Used to fetch the language server from GitHub releases.
    http: Arc<dyn HttpClient>,
    /// Supplies the Node binary that runs the language server.
    node_runtime: Arc<NodeRuntime>,
    /// Current lifecycle state of the language server.
    server: CopilotServer,
}
147
// The model emits no custom events (`Event = ()`); state changes are signalled
// with `cx.notify()` instead.
impl Entity for Copilot {
    type Event = ();
}
151
152impl Copilot {
153 pub fn starting_task(&self) -> Option<Shared<Task<()>>> {
154 match self.server {
155 CopilotServer::Starting { ref task } => Some(task.clone()),
156 _ => None,
157 }
158 }
159
160 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
161 if cx.has_global::<ModelHandle<Self>>() {
162 Some(cx.global::<ModelHandle<Self>>().clone())
163 } else {
164 None
165 }
166 }
167
168 fn start(
169 http: Arc<dyn HttpClient>,
170 node_runtime: Arc<NodeRuntime>,
171 cx: &mut ModelContext<Self>,
172 ) -> Self {
173 cx.observe_global::<Settings, _>({
174 let http = http.clone();
175 let node_runtime = node_runtime.clone();
176 move |this, cx| {
177 if cx.global::<Settings>().enable_copilot_integration {
178 if matches!(this.server, CopilotServer::Disabled) {
179 let start_task = cx
180 .spawn({
181 let http = http.clone();
182 let node_runtime = node_runtime.clone();
183 move |this, cx| {
184 Self::start_language_server(http, node_runtime, this, cx)
185 }
186 })
187 .shared();
188 this.server = CopilotServer::Starting { task: start_task };
189 cx.notify();
190 }
191 } else {
192 this.server = CopilotServer::Disabled;
193 cx.notify();
194 }
195 }
196 })
197 .detach();
198
199 if cx.global::<Settings>().enable_copilot_integration {
200 let start_task = cx
201 .spawn({
202 let http = http.clone();
203 let node_runtime = node_runtime.clone();
204 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
205 })
206 .shared();
207
208 Self {
209 http,
210 node_runtime,
211 server: CopilotServer::Starting { task: start_task },
212 }
213 } else {
214 Self {
215 http,
216 node_runtime,
217 server: CopilotServer::Disabled,
218 }
219 }
220 }
221
222 fn start_language_server(
223 http: Arc<dyn HttpClient>,
224 node_runtime: Arc<NodeRuntime>,
225 this: ModelHandle<Self>,
226 mut cx: AsyncAppContext,
227 ) -> impl Future<Output = ()> {
228 async move {
229 let start_language_server = async {
230 let server_path = get_copilot_lsp(http).await?;
231 let node_path = node_runtime.binary_path().await?;
232 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
233 let server = LanguageServer::new(
234 0,
235 &node_path,
236 arguments,
237 Path::new("/"),
238 None,
239 cx.clone(),
240 )?;
241
242 let server = server.initialize(Default::default()).await?;
243 let status = server
244 .request::<request::CheckStatus>(request::CheckStatusParams {
245 local_checks_only: false,
246 })
247 .await?;
248
249 server
250 .on_notification::<LogMessage, _>(|params, _cx| {
251 match params.level {
252 // Copilot is pretty agressive about logging
253 0 => debug!("copilot: {}", params.message),
254 1 => debug!("copilot: {}", params.message),
255 _ => error!("copilot: {}", params.message),
256 }
257
258 debug!("copilot metadata: {}", params.metadata_str);
259 debug!("copilot extra: {:?}", params.extra);
260 })
261 .detach();
262
263 server
264 .on_notification::<StatusNotification, _>(
265 |_, _| { /* Silence the notification */ },
266 )
267 .detach();
268
269 anyhow::Ok((server, status))
270 };
271
272 let server = start_language_server.await;
273 this.update(&mut cx, |this, cx| {
274 cx.notify();
275 match server {
276 Ok((server, status)) => {
277 this.server = CopilotServer::Started {
278 server,
279 status: SignInStatus::SignedOut,
280 subscriptions_by_buffer_id: Default::default(),
281 };
282 this.update_sign_in_status(status, cx);
283 }
284 Err(error) => {
285 this.server = CopilotServer::Error(error.to_string().into());
286 cx.notify()
287 }
288 }
289 })
290 }
291 }
292
293 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
294 if let CopilotServer::Started { server, status, .. } = &mut self.server {
295 let task = match status {
296 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
297 Task::ready(Ok(())).shared()
298 }
299 SignInStatus::SigningIn { task, .. } => {
300 cx.notify();
301 task.clone()
302 }
303 SignInStatus::SignedOut => {
304 let server = server.clone();
305 let task = cx
306 .spawn(|this, mut cx| async move {
307 let sign_in = async {
308 let sign_in = server
309 .request::<request::SignInInitiate>(
310 request::SignInInitiateParams {},
311 )
312 .await?;
313 match sign_in {
314 request::SignInInitiateResult::AlreadySignedIn { user } => {
315 Ok(request::SignInStatus::Ok { user })
316 }
317 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
318 this.update(&mut cx, |this, cx| {
319 if let CopilotServer::Started { status, .. } =
320 &mut this.server
321 {
322 if let SignInStatus::SigningIn {
323 prompt: prompt_flow,
324 ..
325 } = status
326 {
327 *prompt_flow = Some(flow.clone());
328 cx.notify();
329 }
330 }
331 });
332 let response = server
333 .request::<request::SignInConfirm>(
334 request::SignInConfirmParams {
335 user_code: flow.user_code,
336 },
337 )
338 .await?;
339 Ok(response)
340 }
341 }
342 };
343
344 let sign_in = sign_in.await;
345 this.update(&mut cx, |this, cx| match sign_in {
346 Ok(status) => {
347 this.update_sign_in_status(status, cx);
348 Ok(())
349 }
350 Err(error) => {
351 this.update_sign_in_status(
352 request::SignInStatus::NotSignedIn,
353 cx,
354 );
355 Err(Arc::new(error))
356 }
357 })
358 })
359 .shared();
360 *status = SignInStatus::SigningIn {
361 prompt: None,
362 task: task.clone(),
363 };
364 cx.notify();
365 task
366 }
367 };
368
369 cx.foreground()
370 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
371 } else {
372 // If we're downloading, wait until download is finished
373 // If we're in a stuck state, display to the user
374 Task::ready(Err(anyhow!("copilot hasn't started yet")))
375 }
376 }
377
378 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
379 if let CopilotServer::Started { server, status, .. } = &mut self.server {
380 *status = SignInStatus::SignedOut;
381 cx.notify();
382
383 let server = server.clone();
384 cx.background().spawn(async move {
385 server
386 .request::<request::SignOut>(request::SignOutParams {})
387 .await?;
388 anyhow::Ok(())
389 })
390 } else {
391 Task::ready(Err(anyhow!("copilot hasn't started yet")))
392 }
393 }
394
395 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
396 let start_task = cx
397 .spawn({
398 let http = self.http.clone();
399 let node_runtime = self.node_runtime.clone();
400 move |this, cx| async move {
401 clear_copilot_dir().await;
402 Self::start_language_server(http, node_runtime, this, cx).await
403 }
404 })
405 .shared();
406
407 self.server = CopilotServer::Starting {
408 task: start_task.clone(),
409 };
410
411 cx.notify();
412
413 cx.foreground().spawn(start_task)
414 }
415
416 pub fn completions<T>(
417 &mut self,
418 buffer: &ModelHandle<Buffer>,
419 position: T,
420 cx: &mut ModelContext<Self>,
421 ) -> Task<Result<Vec<Completion>>>
422 where
423 T: ToPointUtf16,
424 {
425 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
426 }
427
428 pub fn completions_cycling<T>(
429 &mut self,
430 buffer: &ModelHandle<Buffer>,
431 position: T,
432 cx: &mut ModelContext<Self>,
433 ) -> Task<Result<Vec<Completion>>>
434 where
435 T: ToPointUtf16,
436 {
437 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
438 }
439
440 fn request_completions<R, T>(
441 &mut self,
442 buffer: &ModelHandle<Buffer>,
443 position: T,
444 cx: &mut ModelContext<Self>,
445 ) -> Task<Result<Vec<Completion>>>
446 where
447 R: lsp::request::Request<
448 Params = request::GetCompletionsParams,
449 Result = request::GetCompletionsResult,
450 >,
451 T: ToPointUtf16,
452 {
453 let buffer_id = buffer.id();
454 let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
455 let snapshot = buffer.read(cx).snapshot();
456 let server = match &mut self.server {
457 CopilotServer::Starting { .. } => {
458 return Task::ready(Err(anyhow!("copilot is still starting")))
459 }
460 CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
461 CopilotServer::Error(error) => {
462 return Task::ready(Err(anyhow!(
463 "copilot was not started because of an error: {}",
464 error
465 )))
466 }
467 CopilotServer::Started {
468 server,
469 status,
470 subscriptions_by_buffer_id,
471 } => {
472 if matches!(status, SignInStatus::Authorized { .. }) {
473 subscriptions_by_buffer_id
474 .entry(buffer_id)
475 .or_insert_with(|| {
476 server
477 .notify::<lsp::notification::DidOpenTextDocument>(
478 lsp::DidOpenTextDocumentParams {
479 text_document: lsp::TextDocumentItem {
480 uri: uri.clone(),
481 language_id: id_for_language(
482 buffer.read(cx).language(),
483 ),
484 version: 0,
485 text: snapshot.text(),
486 },
487 },
488 )
489 .log_err();
490
491 let uri = uri.clone();
492 cx.observe_release(buffer, move |this, _, _| {
493 if let CopilotServer::Started {
494 server,
495 subscriptions_by_buffer_id,
496 ..
497 } = &mut this.server
498 {
499 server
500 .notify::<lsp::notification::DidCloseTextDocument>(
501 lsp::DidCloseTextDocumentParams {
502 text_document: lsp::TextDocumentIdentifier::new(
503 uri.clone(),
504 ),
505 },
506 )
507 .log_err();
508 subscriptions_by_buffer_id.remove(&buffer_id);
509 }
510 })
511 });
512
513 server.clone()
514 } else {
515 return Task::ready(Err(anyhow!("must sign in before using copilot")));
516 }
517 }
518 };
519
520 let settings = cx.global::<Settings>();
521 let position = position.to_point_utf16(&snapshot);
522 let language = snapshot.language_at(position);
523 let language_name = language.map(|language| language.name());
524 let language_name = language_name.as_deref();
525 let tab_size = settings.tab_size(language_name);
526 let hard_tabs = settings.hard_tabs(language_name);
527 let language_id = id_for_language(language);
528
529 let path;
530 let relative_path;
531 if let Some(file) = snapshot.file() {
532 if let Some(file) = file.as_local() {
533 path = file.abs_path(cx);
534 } else {
535 path = file.full_path(cx);
536 }
537 relative_path = file.path().to_path_buf();
538 } else {
539 path = PathBuf::new();
540 relative_path = PathBuf::new();
541 }
542
543 cx.background().spawn(async move {
544 let result = server
545 .request::<R>(request::GetCompletionsParams {
546 doc: request::GetCompletionsDocument {
547 source: snapshot.text(),
548 tab_size: tab_size.into(),
549 indent_size: 1,
550 insert_spaces: !hard_tabs,
551 uri,
552 path: path.to_string_lossy().into(),
553 relative_path: relative_path.to_string_lossy().into(),
554 language_id,
555 position: point_to_lsp(position),
556 version: 0,
557 },
558 })
559 .await?;
560 let completions = result
561 .completions
562 .into_iter()
563 .map(|completion| {
564 let start = snapshot
565 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
566 let end =
567 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
568 Completion {
569 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
570 text: completion.text,
571 }
572 })
573 .collect();
574 anyhow::Ok(completions)
575 })
576 }
577
578 pub fn status(&self) -> Status {
579 match &self.server {
580 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
581 CopilotServer::Disabled => Status::Disabled,
582 CopilotServer::Error(error) => Status::Error(error.clone()),
583 CopilotServer::Started { status, .. } => match status {
584 SignInStatus::Authorized { .. } => Status::Authorized,
585 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
586 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
587 prompt: prompt.clone(),
588 },
589 SignInStatus::SignedOut => Status::SignedOut,
590 },
591 }
592 }
593
594 fn update_sign_in_status(
595 &mut self,
596 lsp_status: request::SignInStatus,
597 cx: &mut ModelContext<Self>,
598 ) {
599 if let CopilotServer::Started { status, .. } = &mut self.server {
600 *status = match lsp_status {
601 request::SignInStatus::Ok { user }
602 | request::SignInStatus::MaybeOk { user }
603 | request::SignInStatus::AlreadySignedIn { user } => {
604 SignInStatus::Authorized { _user: user }
605 }
606 request::SignInStatus::NotAuthorized { user } => {
607 SignInStatus::Unauthorized { _user: user }
608 }
609 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
610 };
611 cx.notify();
612 }
613 }
614}
615
616fn id_for_language(language: Option<&Arc<Language>>) -> String {
617 let language_name = language.map(|language| language.name());
618 match language_name.as_deref() {
619 Some("Plain Text") => "plaintext".to_string(),
620 Some(language_name) => language_name.to_lowercase(),
621 None => "plaintext".to_string(),
622 }
623}
624
625async fn clear_copilot_dir() {
626 remove_matching(&paths::COPILOT_DIR, |_| true).await
627}
628
629async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
630 const SERVER_PATH: &'static str = "dist/agent.js";
631
632 ///Check for the latest copilot language server and download it if we haven't already
633 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
634 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
635
636 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
637
638 fs::create_dir_all(version_dir).await?;
639 let server_path = version_dir.join(SERVER_PATH);
640
641 if fs::metadata(&server_path).await.is_err() {
642 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
643 let dist_dir = version_dir.join("dist");
644 fs::create_dir_all(dist_dir.as_path()).await?;
645
646 let url = &release
647 .assets
648 .get(0)
649 .context("Github release for copilot contained no assets")?
650 .browser_download_url;
651
652 let mut response = http
653 .get(&url, Default::default(), true)
654 .await
655 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
656 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
657 let archive = Archive::new(decompressed_bytes);
658 archive.unpack(dist_dir).await?;
659
660 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
661 }
662
663 Ok(server_path)
664 }
665
666 match fetch_latest(http).await {
667 ok @ Result::Ok(..) => ok,
668 e @ Err(..) => {
669 e.log_err();
670 // Fetch a cached binary, if it exists
671 (|| async move {
672 let mut last_version_dir = None;
673 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
674 while let Some(entry) = entries.next().await {
675 let entry = entry?;
676 if entry.file_type().await?.is_dir() {
677 last_version_dir = Some(entry.path());
678 }
679 }
680 let last_version_dir =
681 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
682 let server_path = last_version_dir.join(SERVER_PATH);
683 if server_path.exists() {
684 Ok(server_path)
685 } else {
686 Err(anyhow!(
687 "missing executable in directory {:?}",
688 last_version_dir
689 ))
690 }
691 })()
692 .await
693 }
694 }
695}