1pub mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use collections::HashMap;
8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
10use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
11use log::{debug, error};
12use lsp::LanguageServer;
13use node_runtime::NodeRuntime;
14use request::{LogMessage, StatusNotification};
15use settings::Settings;
16use smol::{fs, io::BufReader, stream::StreamExt};
17use std::{
18 ffi::OsString,
19 ops::Range,
20 path::{Path, PathBuf},
21 sync::Arc,
22};
23use util::{
24 channel::ReleaseChannel, fs::remove_matching, github::latest_github_release, http::HttpClient,
25 paths, ResultExt,
26};
27
28const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
29actions!(copilot_auth, [SignIn, SignOut]);
30
31const COPILOT_NAMESPACE: &'static str = "copilot";
32actions!(copilot, [NextSuggestion, PreviousSuggestion, Reinstall]);
33
34pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
35 // Disable Copilot for stable releases.
36 if *cx.global::<ReleaseChannel>() == ReleaseChannel::Stable {
37 cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
38 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
39 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
40 });
41 return;
42 }
43
44 let copilot = cx.add_model({
45 let node_runtime = node_runtime.clone();
46 move |cx| Copilot::start(http, node_runtime, cx)
47 });
48 cx.set_global(copilot.clone());
49
50 cx.observe(&copilot, |handle, cx| {
51 let status = handle.read(cx).status();
52 cx.update_global::<collections::CommandPaletteFilter, _, _>(
53 move |filter, _cx| match status {
54 Status::Disabled => {
55 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
56 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
57 }
58 Status::Authorized => {
59 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
60 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
61 }
62 _ => {
63 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
64 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
65 }
66 },
67 );
68 })
69 .detach();
70
71 sign_in::init(cx);
72 cx.add_global_action(|_: &SignIn, cx| {
73 if let Some(copilot) = Copilot::global(cx) {
74 copilot
75 .update(cx, |copilot, cx| copilot.sign_in(cx))
76 .detach_and_log_err(cx);
77 }
78 });
79 cx.add_global_action(|_: &SignOut, cx| {
80 if let Some(copilot) = Copilot::global(cx) {
81 copilot
82 .update(cx, |copilot, cx| copilot.sign_out(cx))
83 .detach_and_log_err(cx);
84 }
85 });
86
87 cx.add_global_action(|_: &Reinstall, cx| {
88 if let Some(copilot) = Copilot::global(cx) {
89 copilot
90 .update(cx, |copilot, cx| copilot.reinstall(cx))
91 .detach();
92 }
93 });
94}
95
96enum CopilotServer {
97 Disabled,
98 Starting {
99 task: Shared<Task<()>>,
100 },
101 Error(Arc<str>),
102 Started {
103 server: Arc<LanguageServer>,
104 status: SignInStatus,
105 subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
106 },
107}
108
109#[derive(Clone, Debug)]
110enum SignInStatus {
111 Authorized,
112 Unauthorized,
113 SigningIn {
114 prompt: Option<request::PromptUserDeviceFlow>,
115 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
116 },
117 SignedOut,
118}
119
120#[derive(Debug, Clone)]
121pub enum Status {
122 Starting {
123 task: Shared<Task<()>>,
124 },
125 Error(Arc<str>),
126 Disabled,
127 SignedOut,
128 SigningIn {
129 prompt: Option<request::PromptUserDeviceFlow>,
130 },
131 Unauthorized,
132 Authorized,
133}
134
135impl Status {
136 pub fn is_authorized(&self) -> bool {
137 matches!(self, Status::Authorized)
138 }
139}
140
141#[derive(Debug, PartialEq, Eq)]
142pub struct Completion {
143 pub range: Range<Anchor>,
144 pub text: String,
145}
146
147pub struct Copilot {
148 http: Arc<dyn HttpClient>,
149 node_runtime: Arc<NodeRuntime>,
150 server: CopilotServer,
151}
152
153impl Entity for Copilot {
154 type Event = ();
155}
156
157impl Copilot {
158 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
159 if cx.has_global::<ModelHandle<Self>>() {
160 Some(cx.global::<ModelHandle<Self>>().clone())
161 } else {
162 None
163 }
164 }
165
166 fn start(
167 http: Arc<dyn HttpClient>,
168 node_runtime: Arc<NodeRuntime>,
169 cx: &mut ModelContext<Self>,
170 ) -> Self {
171 cx.observe_global::<Settings, _>({
172 let http = http.clone();
173 let node_runtime = node_runtime.clone();
174 move |this, cx| {
175 if cx.global::<Settings>().enable_copilot_integration {
176 if matches!(this.server, CopilotServer::Disabled) {
177 let start_task = cx
178 .spawn({
179 let http = http.clone();
180 let node_runtime = node_runtime.clone();
181 move |this, cx| {
182 Self::start_language_server(http, node_runtime, this, cx)
183 }
184 })
185 .shared();
186 this.server = CopilotServer::Starting { task: start_task };
187 cx.notify();
188 }
189 } else {
190 this.server = CopilotServer::Disabled;
191 cx.notify();
192 }
193 }
194 })
195 .detach();
196
197 if cx.global::<Settings>().enable_copilot_integration {
198 let start_task = cx
199 .spawn({
200 let http = http.clone();
201 let node_runtime = node_runtime.clone();
202 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
203 })
204 .shared();
205
206 Self {
207 http,
208 node_runtime,
209 server: CopilotServer::Starting { task: start_task },
210 }
211 } else {
212 Self {
213 http,
214 node_runtime,
215 server: CopilotServer::Disabled,
216 }
217 }
218 }
219
220 #[cfg(any(test, feature = "test-support"))]
221 pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
222 let (server, fake_server) =
223 LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
224 let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
225 let this = cx.add_model(|cx| Self {
226 http: http.clone(),
227 node_runtime: NodeRuntime::new(http, cx.background().clone()),
228 server: CopilotServer::Started {
229 server: Arc::new(server),
230 status: SignInStatus::Authorized,
231 subscriptions_by_buffer_id: Default::default(),
232 },
233 });
234 (this, fake_server)
235 }
236
237 fn start_language_server(
238 http: Arc<dyn HttpClient>,
239 node_runtime: Arc<NodeRuntime>,
240 this: ModelHandle<Self>,
241 mut cx: AsyncAppContext,
242 ) -> impl Future<Output = ()> {
243 async move {
244 let start_language_server = async {
245 let server_path = get_copilot_lsp(http).await?;
246 let node_path = node_runtime.binary_path().await?;
247 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
248 let server = LanguageServer::new(
249 0,
250 &node_path,
251 arguments,
252 Path::new("/"),
253 None,
254 cx.clone(),
255 )?;
256
257 let server = server.initialize(Default::default()).await?;
258 let status = server
259 .request::<request::CheckStatus>(request::CheckStatusParams {
260 local_checks_only: false,
261 })
262 .await?;
263
264 server
265 .on_notification::<LogMessage, _>(|params, _cx| {
266 match params.level {
267 // Copilot is pretty agressive about logging
268 0 => debug!("copilot: {}", params.message),
269 1 => debug!("copilot: {}", params.message),
270 _ => error!("copilot: {}", params.message),
271 }
272
273 debug!("copilot metadata: {}", params.metadata_str);
274 debug!("copilot extra: {:?}", params.extra);
275 })
276 .detach();
277
278 server
279 .on_notification::<StatusNotification, _>(
280 |_, _| { /* Silence the notification */ },
281 )
282 .detach();
283
284 anyhow::Ok((server, status))
285 };
286
287 let server = start_language_server.await;
288 this.update(&mut cx, |this, cx| {
289 cx.notify();
290 match server {
291 Ok((server, status)) => {
292 this.server = CopilotServer::Started {
293 server,
294 status: SignInStatus::SignedOut,
295 subscriptions_by_buffer_id: Default::default(),
296 };
297 this.update_sign_in_status(status, cx);
298 }
299 Err(error) => {
300 this.server = CopilotServer::Error(error.to_string().into());
301 cx.notify()
302 }
303 }
304 })
305 }
306 }
307
308 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
309 if let CopilotServer::Started { server, status, .. } = &mut self.server {
310 let task = match status {
311 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
312 Task::ready(Ok(())).shared()
313 }
314 SignInStatus::SigningIn { task, .. } => {
315 cx.notify();
316 task.clone()
317 }
318 SignInStatus::SignedOut => {
319 let server = server.clone();
320 let task = cx
321 .spawn(|this, mut cx| async move {
322 let sign_in = async {
323 let sign_in = server
324 .request::<request::SignInInitiate>(
325 request::SignInInitiateParams {},
326 )
327 .await?;
328 match sign_in {
329 request::SignInInitiateResult::AlreadySignedIn { user } => {
330 Ok(request::SignInStatus::Ok { user })
331 }
332 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
333 this.update(&mut cx, |this, cx| {
334 if let CopilotServer::Started { status, .. } =
335 &mut this.server
336 {
337 if let SignInStatus::SigningIn {
338 prompt: prompt_flow,
339 ..
340 } = status
341 {
342 *prompt_flow = Some(flow.clone());
343 cx.notify();
344 }
345 }
346 });
347 let response = server
348 .request::<request::SignInConfirm>(
349 request::SignInConfirmParams {
350 user_code: flow.user_code,
351 },
352 )
353 .await?;
354 Ok(response)
355 }
356 }
357 };
358
359 let sign_in = sign_in.await;
360 this.update(&mut cx, |this, cx| match sign_in {
361 Ok(status) => {
362 this.update_sign_in_status(status, cx);
363 Ok(())
364 }
365 Err(error) => {
366 this.update_sign_in_status(
367 request::SignInStatus::NotSignedIn,
368 cx,
369 );
370 Err(Arc::new(error))
371 }
372 })
373 })
374 .shared();
375 *status = SignInStatus::SigningIn {
376 prompt: None,
377 task: task.clone(),
378 };
379 cx.notify();
380 task
381 }
382 };
383
384 cx.foreground()
385 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
386 } else {
387 // If we're downloading, wait until download is finished
388 // If we're in a stuck state, display to the user
389 Task::ready(Err(anyhow!("copilot hasn't started yet")))
390 }
391 }
392
393 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
394 if let CopilotServer::Started { server, status, .. } = &mut self.server {
395 *status = SignInStatus::SignedOut;
396 cx.notify();
397
398 let server = server.clone();
399 cx.background().spawn(async move {
400 server
401 .request::<request::SignOut>(request::SignOutParams {})
402 .await?;
403 anyhow::Ok(())
404 })
405 } else {
406 Task::ready(Err(anyhow!("copilot hasn't started yet")))
407 }
408 }
409
410 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
411 let start_task = cx
412 .spawn({
413 let http = self.http.clone();
414 let node_runtime = self.node_runtime.clone();
415 move |this, cx| async move {
416 clear_copilot_dir().await;
417 Self::start_language_server(http, node_runtime, this, cx).await
418 }
419 })
420 .shared();
421
422 self.server = CopilotServer::Starting {
423 task: start_task.clone(),
424 };
425
426 cx.notify();
427
428 cx.foreground().spawn(start_task)
429 }
430
431 pub fn completions<T>(
432 &mut self,
433 buffer: &ModelHandle<Buffer>,
434 position: T,
435 cx: &mut ModelContext<Self>,
436 ) -> Task<Result<Vec<Completion>>>
437 where
438 T: ToPointUtf16,
439 {
440 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
441 }
442
443 pub fn completions_cycling<T>(
444 &mut self,
445 buffer: &ModelHandle<Buffer>,
446 position: T,
447 cx: &mut ModelContext<Self>,
448 ) -> Task<Result<Vec<Completion>>>
449 where
450 T: ToPointUtf16,
451 {
452 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
453 }
454
455 fn request_completions<R, T>(
456 &mut self,
457 buffer: &ModelHandle<Buffer>,
458 position: T,
459 cx: &mut ModelContext<Self>,
460 ) -> Task<Result<Vec<Completion>>>
461 where
462 R: lsp::request::Request<
463 Params = request::GetCompletionsParams,
464 Result = request::GetCompletionsResult,
465 >,
466 T: ToPointUtf16,
467 {
468 let buffer_id = buffer.id();
469 let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
470 let snapshot = buffer.read(cx).snapshot();
471 let server = match &mut self.server {
472 CopilotServer::Starting { .. } => {
473 return Task::ready(Err(anyhow!("copilot is still starting")))
474 }
475 CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
476 CopilotServer::Error(error) => {
477 return Task::ready(Err(anyhow!(
478 "copilot was not started because of an error: {}",
479 error
480 )))
481 }
482 CopilotServer::Started {
483 server,
484 status,
485 subscriptions_by_buffer_id,
486 } => {
487 if matches!(status, SignInStatus::Authorized { .. }) {
488 subscriptions_by_buffer_id
489 .entry(buffer_id)
490 .or_insert_with(|| {
491 server
492 .notify::<lsp::notification::DidOpenTextDocument>(
493 lsp::DidOpenTextDocumentParams {
494 text_document: lsp::TextDocumentItem {
495 uri: uri.clone(),
496 language_id: id_for_language(
497 buffer.read(cx).language(),
498 ),
499 version: 0,
500 text: snapshot.text(),
501 },
502 },
503 )
504 .log_err();
505
506 let uri = uri.clone();
507 cx.observe_release(buffer, move |this, _, _| {
508 if let CopilotServer::Started {
509 server,
510 subscriptions_by_buffer_id,
511 ..
512 } = &mut this.server
513 {
514 server
515 .notify::<lsp::notification::DidCloseTextDocument>(
516 lsp::DidCloseTextDocumentParams {
517 text_document: lsp::TextDocumentIdentifier::new(
518 uri.clone(),
519 ),
520 },
521 )
522 .log_err();
523 subscriptions_by_buffer_id.remove(&buffer_id);
524 }
525 })
526 });
527
528 server.clone()
529 } else {
530 return Task::ready(Err(anyhow!("must sign in before using copilot")));
531 }
532 }
533 };
534
535 let settings = cx.global::<Settings>();
536 let position = position.to_point_utf16(&snapshot);
537 let language = snapshot.language_at(position);
538 let language_name = language.map(|language| language.name());
539 let language_name = language_name.as_deref();
540 let tab_size = settings.tab_size(language_name);
541 let hard_tabs = settings.hard_tabs(language_name);
542 let language_id = id_for_language(language);
543
544 let path;
545 let relative_path;
546 if let Some(file) = snapshot.file() {
547 if let Some(file) = file.as_local() {
548 path = file.abs_path(cx);
549 } else {
550 path = file.full_path(cx);
551 }
552 relative_path = file.path().to_path_buf();
553 } else {
554 path = PathBuf::new();
555 relative_path = PathBuf::new();
556 }
557
558 cx.background().spawn(async move {
559 let result = server
560 .request::<R>(request::GetCompletionsParams {
561 doc: request::GetCompletionsDocument {
562 source: snapshot.text(),
563 tab_size: tab_size.into(),
564 indent_size: 1,
565 insert_spaces: !hard_tabs,
566 uri,
567 path: path.to_string_lossy().into(),
568 relative_path: relative_path.to_string_lossy().into(),
569 language_id,
570 position: point_to_lsp(position),
571 version: 0,
572 },
573 })
574 .await?;
575 let completions = result
576 .completions
577 .into_iter()
578 .map(|completion| {
579 let start = snapshot
580 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
581 let end =
582 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
583 Completion {
584 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
585 text: completion.text,
586 }
587 })
588 .collect();
589 anyhow::Ok(completions)
590 })
591 }
592
593 pub fn status(&self) -> Status {
594 match &self.server {
595 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
596 CopilotServer::Disabled => Status::Disabled,
597 CopilotServer::Error(error) => Status::Error(error.clone()),
598 CopilotServer::Started { status, .. } => match status {
599 SignInStatus::Authorized { .. } => Status::Authorized,
600 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
601 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
602 prompt: prompt.clone(),
603 },
604 SignInStatus::SignedOut => Status::SignedOut,
605 },
606 }
607 }
608
609 fn update_sign_in_status(
610 &mut self,
611 lsp_status: request::SignInStatus,
612 cx: &mut ModelContext<Self>,
613 ) {
614 if let CopilotServer::Started { status, .. } = &mut self.server {
615 *status = match lsp_status {
616 request::SignInStatus::Ok { .. }
617 | request::SignInStatus::MaybeOk { .. }
618 | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized,
619 request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized,
620 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
621 };
622 cx.notify();
623 }
624 }
625}
626
627fn id_for_language(language: Option<&Arc<Language>>) -> String {
628 let language_name = language.map(|language| language.name());
629 match language_name.as_deref() {
630 Some("Plain Text") => "plaintext".to_string(),
631 Some(language_name) => language_name.to_lowercase(),
632 None => "plaintext".to_string(),
633 }
634}
635
636async fn clear_copilot_dir() {
637 remove_matching(&paths::COPILOT_DIR, |_| true).await
638}
639
640async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
641 const SERVER_PATH: &'static str = "dist/agent.js";
642
643 ///Check for the latest copilot language server and download it if we haven't already
644 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
645 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
646
647 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
648
649 fs::create_dir_all(version_dir).await?;
650 let server_path = version_dir.join(SERVER_PATH);
651
652 if fs::metadata(&server_path).await.is_err() {
653 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
654 let dist_dir = version_dir.join("dist");
655 fs::create_dir_all(dist_dir.as_path()).await?;
656
657 let url = &release
658 .assets
659 .get(0)
660 .context("Github release for copilot contained no assets")?
661 .browser_download_url;
662
663 let mut response = http
664 .get(&url, Default::default(), true)
665 .await
666 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
667 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
668 let archive = Archive::new(decompressed_bytes);
669 archive.unpack(dist_dir).await?;
670
671 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
672 }
673
674 Ok(server_path)
675 }
676
677 match fetch_latest(http).await {
678 ok @ Result::Ok(..) => ok,
679 e @ Err(..) => {
680 e.log_err();
681 // Fetch a cached binary, if it exists
682 (|| async move {
683 let mut last_version_dir = None;
684 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
685 while let Some(entry) = entries.next().await {
686 let entry = entry?;
687 if entry.file_type().await?.is_dir() {
688 last_version_dir = Some(entry.path());
689 }
690 }
691 let last_version_dir =
692 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
693 let server_path = last_version_dir.join(SERVER_PATH);
694 if server_path.exists() {
695 Ok(server_path)
696 } else {
697 Err(anyhow!(
698 "missing executable in directory {:?}",
699 last_version_dir
700 ))
701 }
702 })()
703 .await
704 }
705 }
706}