1pub mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use collections::HashMap;
8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
10use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
11use log::{debug, error};
12use lsp::LanguageServer;
13use node_runtime::NodeRuntime;
14use request::{LogMessage, StatusNotification};
15use settings::Settings;
16use smol::{fs, io::BufReader, stream::StreamExt};
17use std::{
18 ffi::OsString,
19 ops::Range,
20 path::{Path, PathBuf},
21 sync::Arc,
22};
23use util::{
24 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
25};
26
27const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
28actions!(copilot_auth, [SignIn, SignOut]);
29
30const COPILOT_NAMESPACE: &'static str = "copilot";
31actions!(
32 copilot,
33 [Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
34);
35
36pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
37 let copilot = cx.add_model({
38 let node_runtime = node_runtime.clone();
39 move |cx| Copilot::start(http, node_runtime, cx)
40 });
41 cx.set_global(copilot.clone());
42
43 cx.observe(&copilot, |handle, cx| {
44 let status = handle.read(cx).status();
45 cx.update_global::<collections::CommandPaletteFilter, _, _>(
46 move |filter, _cx| match status {
47 Status::Disabled => {
48 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
49 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
50 }
51 Status::Authorized => {
52 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
53 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
54 }
55 _ => {
56 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
57 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
58 }
59 },
60 );
61 })
62 .detach();
63
64 sign_in::init(cx);
65 cx.add_global_action(|_: &SignIn, cx| {
66 if let Some(copilot) = Copilot::global(cx) {
67 copilot
68 .update(cx, |copilot, cx| copilot.sign_in(cx))
69 .detach_and_log_err(cx);
70 }
71 });
72 cx.add_global_action(|_: &SignOut, cx| {
73 if let Some(copilot) = Copilot::global(cx) {
74 copilot
75 .update(cx, |copilot, cx| copilot.sign_out(cx))
76 .detach_and_log_err(cx);
77 }
78 });
79
80 cx.add_global_action(|_: &Reinstall, cx| {
81 if let Some(copilot) = Copilot::global(cx) {
82 copilot
83 .update(cx, |copilot, cx| copilot.reinstall(cx))
84 .detach();
85 }
86 });
87}
88
89enum CopilotServer {
90 Disabled,
91 Starting {
92 task: Shared<Task<()>>,
93 },
94 Error(Arc<str>),
95 Started {
96 server: Arc<LanguageServer>,
97 status: SignInStatus,
98 subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
99 },
100}
101
102#[derive(Clone, Debug)]
103enum SignInStatus {
104 Authorized,
105 Unauthorized,
106 SigningIn {
107 prompt: Option<request::PromptUserDeviceFlow>,
108 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
109 },
110 SignedOut,
111}
112
113#[derive(Debug, Clone)]
114pub enum Status {
115 Starting {
116 task: Shared<Task<()>>,
117 },
118 Error(Arc<str>),
119 Disabled,
120 SignedOut,
121 SigningIn {
122 prompt: Option<request::PromptUserDeviceFlow>,
123 },
124 Unauthorized,
125 Authorized,
126}
127
128impl Status {
129 pub fn is_authorized(&self) -> bool {
130 matches!(self, Status::Authorized)
131 }
132}
133
134#[derive(Debug, PartialEq, Eq)]
135pub struct Completion {
136 pub range: Range<Anchor>,
137 pub text: String,
138}
139
140pub struct Copilot {
141 http: Arc<dyn HttpClient>,
142 node_runtime: Arc<NodeRuntime>,
143 server: CopilotServer,
144}
145
146impl Entity for Copilot {
147 type Event = ();
148}
149
150impl Copilot {
151 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
152 if cx.has_global::<ModelHandle<Self>>() {
153 Some(cx.global::<ModelHandle<Self>>().clone())
154 } else {
155 None
156 }
157 }
158
159 fn start(
160 http: Arc<dyn HttpClient>,
161 node_runtime: Arc<NodeRuntime>,
162 cx: &mut ModelContext<Self>,
163 ) -> Self {
164 cx.observe_global::<Settings, _>({
165 let http = http.clone();
166 let node_runtime = node_runtime.clone();
167 move |this, cx| {
168 if cx.global::<Settings>().features.copilot {
169 if matches!(this.server, CopilotServer::Disabled) {
170 let start_task = cx
171 .spawn({
172 let http = http.clone();
173 let node_runtime = node_runtime.clone();
174 move |this, cx| {
175 Self::start_language_server(http, node_runtime, this, cx)
176 }
177 })
178 .shared();
179 this.server = CopilotServer::Starting { task: start_task };
180 cx.notify();
181 }
182 } else {
183 this.server = CopilotServer::Disabled;
184 cx.notify();
185 }
186 }
187 })
188 .detach();
189
190 if cx.global::<Settings>().features.copilot {
191 let start_task = cx
192 .spawn({
193 let http = http.clone();
194 let node_runtime = node_runtime.clone();
195 move |this, cx| async {
196 Self::start_language_server(http, node_runtime, this, cx).await
197 }
198 })
199 .shared();
200
201 Self {
202 http,
203 node_runtime,
204 server: CopilotServer::Starting { task: start_task },
205 }
206 } else {
207 Self {
208 http,
209 node_runtime,
210 server: CopilotServer::Disabled,
211 }
212 }
213 }
214
215 #[cfg(any(test, feature = "test-support"))]
216 pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
217 let (server, fake_server) =
218 LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
219 let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
220 let this = cx.add_model(|cx| Self {
221 http: http.clone(),
222 node_runtime: NodeRuntime::new(http, cx.background().clone()),
223 server: CopilotServer::Started {
224 server: Arc::new(server),
225 status: SignInStatus::Authorized,
226 subscriptions_by_buffer_id: Default::default(),
227 },
228 });
229 (this, fake_server)
230 }
231
232 fn start_language_server(
233 http: Arc<dyn HttpClient>,
234 node_runtime: Arc<NodeRuntime>,
235 this: ModelHandle<Self>,
236 mut cx: AsyncAppContext,
237 ) -> impl Future<Output = ()> {
238 async move {
239 let start_language_server = async {
240 let server_path = get_copilot_lsp(http).await?;
241 let node_path = node_runtime.binary_path().await?;
242 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
243 let server = LanguageServer::new(
244 0,
245 &node_path,
246 arguments,
247 Path::new("/"),
248 None,
249 cx.clone(),
250 )?;
251
252 let server = server.initialize(Default::default()).await?;
253 let status = server
254 .request::<request::CheckStatus>(request::CheckStatusParams {
255 local_checks_only: false,
256 })
257 .await?;
258
259 server
260 .on_notification::<LogMessage, _>(|params, _cx| {
261 match params.level {
262 // Copilot is pretty agressive about logging
263 0 => debug!("copilot: {}", params.message),
264 1 => debug!("copilot: {}", params.message),
265 _ => error!("copilot: {}", params.message),
266 }
267
268 debug!("copilot metadata: {}", params.metadata_str);
269 debug!("copilot extra: {:?}", params.extra);
270 })
271 .detach();
272
273 server
274 .on_notification::<StatusNotification, _>(
275 |_, _| { /* Silence the notification */ },
276 )
277 .detach();
278
279 anyhow::Ok((server, status))
280 };
281
282 let server = start_language_server.await;
283 this.update(&mut cx, |this, cx| {
284 cx.notify();
285 match server {
286 Ok((server, status)) => {
287 this.server = CopilotServer::Started {
288 server,
289 status: SignInStatus::SignedOut,
290 subscriptions_by_buffer_id: Default::default(),
291 };
292 this.update_sign_in_status(status, cx);
293 }
294 Err(error) => {
295 this.server = CopilotServer::Error(error.to_string().into());
296 cx.notify()
297 }
298 }
299 })
300 }
301 }
302
303 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
304 if let CopilotServer::Started { server, status, .. } = &mut self.server {
305 let task = match status {
306 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
307 Task::ready(Ok(())).shared()
308 }
309 SignInStatus::SigningIn { task, .. } => {
310 cx.notify();
311 task.clone()
312 }
313 SignInStatus::SignedOut => {
314 let server = server.clone();
315 let task = cx
316 .spawn(|this, mut cx| async move {
317 let sign_in = async {
318 let sign_in = server
319 .request::<request::SignInInitiate>(
320 request::SignInInitiateParams {},
321 )
322 .await?;
323 match sign_in {
324 request::SignInInitiateResult::AlreadySignedIn { user } => {
325 Ok(request::SignInStatus::Ok { user })
326 }
327 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
328 this.update(&mut cx, |this, cx| {
329 if let CopilotServer::Started { status, .. } =
330 &mut this.server
331 {
332 if let SignInStatus::SigningIn {
333 prompt: prompt_flow,
334 ..
335 } = status
336 {
337 *prompt_flow = Some(flow.clone());
338 cx.notify();
339 }
340 }
341 });
342 let response = server
343 .request::<request::SignInConfirm>(
344 request::SignInConfirmParams {
345 user_code: flow.user_code,
346 },
347 )
348 .await?;
349 Ok(response)
350 }
351 }
352 };
353
354 let sign_in = sign_in.await;
355 this.update(&mut cx, |this, cx| match sign_in {
356 Ok(status) => {
357 this.update_sign_in_status(status, cx);
358 Ok(())
359 }
360 Err(error) => {
361 this.update_sign_in_status(
362 request::SignInStatus::NotSignedIn,
363 cx,
364 );
365 Err(Arc::new(error))
366 }
367 })
368 })
369 .shared();
370 *status = SignInStatus::SigningIn {
371 prompt: None,
372 task: task.clone(),
373 };
374 cx.notify();
375 task
376 }
377 };
378
379 cx.foreground()
380 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
381 } else {
382 // If we're downloading, wait until download is finished
383 // If we're in a stuck state, display to the user
384 Task::ready(Err(anyhow!("copilot hasn't started yet")))
385 }
386 }
387
388 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
389 if let CopilotServer::Started { server, status, .. } = &mut self.server {
390 *status = SignInStatus::SignedOut;
391 cx.notify();
392
393 let server = server.clone();
394 cx.background().spawn(async move {
395 server
396 .request::<request::SignOut>(request::SignOutParams {})
397 .await?;
398 anyhow::Ok(())
399 })
400 } else {
401 Task::ready(Err(anyhow!("copilot hasn't started yet")))
402 }
403 }
404
405 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
406 let start_task = cx
407 .spawn({
408 let http = self.http.clone();
409 let node_runtime = self.node_runtime.clone();
410 move |this, cx| async move {
411 clear_copilot_dir().await;
412 Self::start_language_server(http, node_runtime, this, cx).await
413 }
414 })
415 .shared();
416
417 self.server = CopilotServer::Starting {
418 task: start_task.clone(),
419 };
420
421 cx.notify();
422
423 cx.foreground().spawn(start_task)
424 }
425
426 pub fn completions<T>(
427 &mut self,
428 buffer: &ModelHandle<Buffer>,
429 position: T,
430 cx: &mut ModelContext<Self>,
431 ) -> Task<Result<Vec<Completion>>>
432 where
433 T: ToPointUtf16,
434 {
435 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
436 }
437
438 pub fn completions_cycling<T>(
439 &mut self,
440 buffer: &ModelHandle<Buffer>,
441 position: T,
442 cx: &mut ModelContext<Self>,
443 ) -> Task<Result<Vec<Completion>>>
444 where
445 T: ToPointUtf16,
446 {
447 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
448 }
449
450 fn request_completions<R, T>(
451 &mut self,
452 buffer: &ModelHandle<Buffer>,
453 position: T,
454 cx: &mut ModelContext<Self>,
455 ) -> Task<Result<Vec<Completion>>>
456 where
457 R: lsp::request::Request<
458 Params = request::GetCompletionsParams,
459 Result = request::GetCompletionsResult,
460 >,
461 T: ToPointUtf16,
462 {
463 let buffer_id = buffer.id();
464 let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
465 let snapshot = buffer.read(cx).snapshot();
466 let server = match &mut self.server {
467 CopilotServer::Starting { .. } => {
468 return Task::ready(Err(anyhow!("copilot is still starting")))
469 }
470 CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
471 CopilotServer::Error(error) => {
472 return Task::ready(Err(anyhow!(
473 "copilot was not started because of an error: {}",
474 error
475 )))
476 }
477 CopilotServer::Started {
478 server,
479 status,
480 subscriptions_by_buffer_id,
481 } => {
482 if matches!(status, SignInStatus::Authorized { .. }) {
483 subscriptions_by_buffer_id
484 .entry(buffer_id)
485 .or_insert_with(|| {
486 server
487 .notify::<lsp::notification::DidOpenTextDocument>(
488 lsp::DidOpenTextDocumentParams {
489 text_document: lsp::TextDocumentItem {
490 uri: uri.clone(),
491 language_id: id_for_language(
492 buffer.read(cx).language(),
493 ),
494 version: 0,
495 text: snapshot.text(),
496 },
497 },
498 )
499 .log_err();
500
501 let uri = uri.clone();
502 cx.observe_release(buffer, move |this, _, _| {
503 if let CopilotServer::Started {
504 server,
505 subscriptions_by_buffer_id,
506 ..
507 } = &mut this.server
508 {
509 server
510 .notify::<lsp::notification::DidCloseTextDocument>(
511 lsp::DidCloseTextDocumentParams {
512 text_document: lsp::TextDocumentIdentifier::new(
513 uri.clone(),
514 ),
515 },
516 )
517 .log_err();
518 subscriptions_by_buffer_id.remove(&buffer_id);
519 }
520 })
521 });
522
523 server.clone()
524 } else {
525 return Task::ready(Err(anyhow!("must sign in before using copilot")));
526 }
527 }
528 };
529
530 let settings = cx.global::<Settings>();
531 let position = position.to_point_utf16(&snapshot);
532 let language = snapshot.language_at(position);
533 let language_name = language.map(|language| language.name());
534 let language_name = language_name.as_deref();
535 let tab_size = settings.tab_size(language_name);
536 let hard_tabs = settings.hard_tabs(language_name);
537 let language_id = id_for_language(language);
538
539 let path;
540 let relative_path;
541 if let Some(file) = snapshot.file() {
542 if let Some(file) = file.as_local() {
543 path = file.abs_path(cx);
544 } else {
545 path = file.full_path(cx);
546 }
547 relative_path = file.path().to_path_buf();
548 } else {
549 path = PathBuf::new();
550 relative_path = PathBuf::new();
551 }
552
553 cx.background().spawn(async move {
554 let result = server
555 .request::<R>(request::GetCompletionsParams {
556 doc: request::GetCompletionsDocument {
557 source: snapshot.text(),
558 tab_size: tab_size.into(),
559 indent_size: 1,
560 insert_spaces: !hard_tabs,
561 uri,
562 path: path.to_string_lossy().into(),
563 relative_path: relative_path.to_string_lossy().into(),
564 language_id,
565 position: point_to_lsp(position),
566 version: 0,
567 },
568 })
569 .await?;
570 let completions = result
571 .completions
572 .into_iter()
573 .map(|completion| {
574 let start = snapshot
575 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
576 let end =
577 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
578 Completion {
579 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
580 text: completion.text,
581 }
582 })
583 .collect();
584 anyhow::Ok(completions)
585 })
586 }
587
588 pub fn status(&self) -> Status {
589 match &self.server {
590 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
591 CopilotServer::Disabled => Status::Disabled,
592 CopilotServer::Error(error) => Status::Error(error.clone()),
593 CopilotServer::Started { status, .. } => match status {
594 SignInStatus::Authorized { .. } => Status::Authorized,
595 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
596 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
597 prompt: prompt.clone(),
598 },
599 SignInStatus::SignedOut => Status::SignedOut,
600 },
601 }
602 }
603
604 fn update_sign_in_status(
605 &mut self,
606 lsp_status: request::SignInStatus,
607 cx: &mut ModelContext<Self>,
608 ) {
609 if let CopilotServer::Started { status, .. } = &mut self.server {
610 *status = match lsp_status {
611 request::SignInStatus::Ok { .. }
612 | request::SignInStatus::MaybeOk { .. }
613 | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized,
614 request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized,
615 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
616 };
617 cx.notify();
618 }
619 }
620}
621
622fn id_for_language(language: Option<&Arc<Language>>) -> String {
623 let language_name = language.map(|language| language.name());
624 match language_name.as_deref() {
625 Some("Plain Text") => "plaintext".to_string(),
626 Some(language_name) => language_name.to_lowercase(),
627 None => "plaintext".to_string(),
628 }
629}
630
631async fn clear_copilot_dir() {
632 remove_matching(&paths::COPILOT_DIR, |_| true).await
633}
634
635async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
636 const SERVER_PATH: &'static str = "dist/agent.js";
637
638 ///Check for the latest copilot language server and download it if we haven't already
639 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
640 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
641
642 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
643
644 fs::create_dir_all(version_dir).await?;
645 let server_path = version_dir.join(SERVER_PATH);
646
647 if fs::metadata(&server_path).await.is_err() {
648 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
649 let dist_dir = version_dir.join("dist");
650 fs::create_dir_all(dist_dir.as_path()).await?;
651
652 let url = &release
653 .assets
654 .get(0)
655 .context("Github release for copilot contained no assets")?
656 .browser_download_url;
657
658 let mut response = http
659 .get(&url, Default::default(), true)
660 .await
661 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
662 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
663 let archive = Archive::new(decompressed_bytes);
664 archive.unpack(dist_dir).await?;
665
666 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
667 }
668
669 Ok(server_path)
670 }
671
672 match fetch_latest(http).await {
673 ok @ Result::Ok(..) => ok,
674 e @ Err(..) => {
675 e.log_err();
676 // Fetch a cached binary, if it exists
677 (|| async move {
678 let mut last_version_dir = None;
679 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
680 while let Some(entry) = entries.next().await {
681 let entry = entry?;
682 if entry.file_type().await?.is_dir() {
683 last_version_dir = Some(entry.path());
684 }
685 }
686 let last_version_dir =
687 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
688 let server_path = last_version_dir.join(SERVER_PATH);
689 if server_path.exists() {
690 Ok(server_path)
691 } else {
692 Err(anyhow!(
693 "missing executable in directory {:?}",
694 last_version_dir
695 ))
696 }
697 })()
698 .await
699 }
700 }
701}