1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use client::Client;
8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{
10 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
11 Task,
12};
13use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
14use log::{debug, error};
15use lsp::LanguageServer;
16use node_runtime::NodeRuntime;
17use request::{LogMessage, StatusNotification};
18use settings::Settings;
19use smol::{fs, io::BufReader, stream::StreamExt};
20use std::{
21 ffi::OsString,
22 ops::Range,
23 path::{Path, PathBuf},
24 sync::Arc,
25};
26use util::{
27 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
28};
29
30const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
31actions!(copilot_auth, [SignIn, SignOut]);
32
33const COPILOT_NAMESPACE: &'static str = "copilot";
34actions!(
35 copilot,
36 [NextSuggestion, PreviousSuggestion, Toggle, Reinstall]
37);
38
39pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
40 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
41 cx.set_global(copilot.clone());
42 cx.add_global_action(|_: &SignIn, cx| {
43 let copilot = Copilot::global(cx).unwrap();
44 copilot
45 .update(cx, |copilot, cx| copilot.sign_in(cx))
46 .detach_and_log_err(cx);
47 });
48 cx.add_global_action(|_: &SignOut, cx| {
49 let copilot = Copilot::global(cx).unwrap();
50 copilot
51 .update(cx, |copilot, cx| copilot.sign_out(cx))
52 .detach_and_log_err(cx);
53 });
54
55 cx.add_global_action(|_: &Reinstall, cx| {
56 let copilot = Copilot::global(cx).unwrap();
57 copilot
58 .update(cx, |copilot, cx| copilot.reinstall(cx))
59 .detach();
60 });
61
62 cx.observe(&copilot, |handle, cx| {
63 let status = handle.read(cx).status();
64 cx.update_global::<collections::CommandPaletteFilter, _, _>(
65 move |filter, _cx| match status {
66 Status::Disabled => {
67 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
68 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
69 }
70 Status::Authorized => {
71 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
72 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
73 }
74 _ => {
75 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
76 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
77 }
78 },
79 );
80 })
81 .detach();
82
83 sign_in::init(cx);
84}
85
86enum CopilotServer {
87 Disabled,
88 Starting {
89 task: Shared<Task<()>>,
90 },
91 Error(Arc<str>),
92 Started {
93 server: Arc<LanguageServer>,
94 status: SignInStatus,
95 },
96}
97
98#[derive(Clone, Debug)]
99enum SignInStatus {
100 Authorized {
101 _user: String,
102 },
103 Unauthorized {
104 _user: String,
105 },
106 SigningIn {
107 prompt: Option<request::PromptUserDeviceFlow>,
108 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
109 },
110 SignedOut,
111}
112
113#[derive(Debug, Clone)]
114pub enum Status {
115 Starting {
116 task: Shared<Task<()>>,
117 },
118 Error(Arc<str>),
119 Disabled,
120 SignedOut,
121 SigningIn {
122 prompt: Option<request::PromptUserDeviceFlow>,
123 },
124 Unauthorized,
125 Authorized,
126}
127
128impl Status {
129 pub fn is_authorized(&self) -> bool {
130 matches!(self, Status::Authorized)
131 }
132}
133
134#[derive(Debug, PartialEq, Eq)]
135pub struct Completion {
136 pub range: Range<Anchor>,
137 pub text: String,
138}
139
140pub struct Copilot {
141 http: Arc<dyn HttpClient>,
142 node_runtime: Arc<NodeRuntime>,
143 server: CopilotServer,
144}
145
146impl Entity for Copilot {
147 type Event = ();
148}
149
150impl Copilot {
151 pub fn starting_task(&self) -> Option<Shared<Task<()>>> {
152 match self.server {
153 CopilotServer::Starting { ref task } => Some(task.clone()),
154 _ => None,
155 }
156 }
157
158 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
159 if cx.has_global::<ModelHandle<Self>>() {
160 Some(cx.global::<ModelHandle<Self>>().clone())
161 } else {
162 None
163 }
164 }
165
166 fn start(
167 http: Arc<dyn HttpClient>,
168 node_runtime: Arc<NodeRuntime>,
169 cx: &mut ModelContext<Self>,
170 ) -> Self {
171 cx.observe_global::<Settings, _>({
172 let http = http.clone();
173 let node_runtime = node_runtime.clone();
174 move |this, cx| {
175 if cx.global::<Settings>().enable_copilot_integration {
176 if matches!(this.server, CopilotServer::Disabled) {
177 let start_task = cx
178 .spawn({
179 let http = http.clone();
180 let node_runtime = node_runtime.clone();
181 move |this, cx| {
182 Self::start_language_server(http, node_runtime, this, cx)
183 }
184 })
185 .shared();
186 this.server = CopilotServer::Starting { task: start_task };
187 cx.notify();
188 }
189 } else {
190 this.server = CopilotServer::Disabled;
191 cx.notify();
192 }
193 }
194 })
195 .detach();
196
197 if cx.global::<Settings>().enable_copilot_integration {
198 let start_task = cx
199 .spawn({
200 let http = http.clone();
201 let node_runtime = node_runtime.clone();
202 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
203 })
204 .shared();
205
206 Self {
207 http,
208 node_runtime,
209 server: CopilotServer::Starting { task: start_task },
210 }
211 } else {
212 Self {
213 http,
214 node_runtime,
215 server: CopilotServer::Disabled,
216 }
217 }
218 }
219
220 fn start_language_server(
221 http: Arc<dyn HttpClient>,
222 node_runtime: Arc<NodeRuntime>,
223 this: ModelHandle<Self>,
224 mut cx: AsyncAppContext,
225 ) -> impl Future<Output = ()> {
226 async move {
227 let start_language_server = async {
228 let server_path = get_copilot_lsp(http).await?;
229 let node_path = node_runtime.binary_path().await?;
230 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
231 let server = LanguageServer::new(
232 0,
233 &node_path,
234 arguments,
235 Path::new("/"),
236 None,
237 cx.clone(),
238 )?;
239
240 let server = server.initialize(Default::default()).await?;
241 let status = server
242 .request::<request::CheckStatus>(request::CheckStatusParams {
243 local_checks_only: false,
244 })
245 .await?;
246
247 server
248 .on_notification::<LogMessage, _>(|params, _cx| {
249 match params.level {
250 // Copilot is pretty agressive about logging
251 0 => debug!("copilot: {}", params.message),
252 1 => debug!("copilot: {}", params.message),
253 _ => error!("copilot: {}", params.message),
254 }
255
256 debug!("copilot metadata: {}", params.metadata_str);
257 debug!("copilot extra: {:?}", params.extra);
258 })
259 .detach();
260
261 server
262 .on_notification::<StatusNotification, _>(
263 |_, _| { /* Silence the notification */ },
264 )
265 .detach();
266
267 anyhow::Ok((server, status))
268 };
269
270 let server = start_language_server.await;
271 this.update(&mut cx, |this, cx| {
272 cx.notify();
273 match server {
274 Ok((server, status)) => {
275 this.server = CopilotServer::Started {
276 server,
277 status: SignInStatus::SignedOut,
278 };
279 this.update_sign_in_status(status, cx);
280 }
281 Err(error) => {
282 this.server = CopilotServer::Error(error.to_string().into());
283 cx.notify()
284 }
285 }
286 })
287 }
288 }
289
290 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
291 if let CopilotServer::Started { server, status } = &mut self.server {
292 let task = match status {
293 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
294 Task::ready(Ok(())).shared()
295 }
296 SignInStatus::SigningIn { task, .. } => {
297 cx.notify();
298 task.clone()
299 }
300 SignInStatus::SignedOut => {
301 let server = server.clone();
302 let task = cx
303 .spawn(|this, mut cx| async move {
304 let sign_in = async {
305 let sign_in = server
306 .request::<request::SignInInitiate>(
307 request::SignInInitiateParams {},
308 )
309 .await?;
310 match sign_in {
311 request::SignInInitiateResult::AlreadySignedIn { user } => {
312 Ok(request::SignInStatus::Ok { user })
313 }
314 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
315 this.update(&mut cx, |this, cx| {
316 if let CopilotServer::Started { status, .. } =
317 &mut this.server
318 {
319 if let SignInStatus::SigningIn {
320 prompt: prompt_flow,
321 ..
322 } = status
323 {
324 *prompt_flow = Some(flow.clone());
325 cx.notify();
326 }
327 }
328 });
329 let response = server
330 .request::<request::SignInConfirm>(
331 request::SignInConfirmParams {
332 user_code: flow.user_code,
333 },
334 )
335 .await?;
336 Ok(response)
337 }
338 }
339 };
340
341 let sign_in = sign_in.await;
342 this.update(&mut cx, |this, cx| match sign_in {
343 Ok(status) => {
344 this.update_sign_in_status(status, cx);
345 Ok(())
346 }
347 Err(error) => {
348 this.update_sign_in_status(
349 request::SignInStatus::NotSignedIn,
350 cx,
351 );
352 Err(Arc::new(error))
353 }
354 })
355 })
356 .shared();
357 *status = SignInStatus::SigningIn {
358 prompt: None,
359 task: task.clone(),
360 };
361 cx.notify();
362 task
363 }
364 };
365
366 cx.foreground()
367 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
368 } else {
369 // If we're downloading, wait until download is finished
370 // If we're in a stuck state, display to the user
371 Task::ready(Err(anyhow!("copilot hasn't started yet")))
372 }
373 }
374
375 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
376 if let CopilotServer::Started { server, status } = &mut self.server {
377 *status = SignInStatus::SignedOut;
378 cx.notify();
379
380 let server = server.clone();
381 cx.background().spawn(async move {
382 server
383 .request::<request::SignOut>(request::SignOutParams {})
384 .await?;
385 anyhow::Ok(())
386 })
387 } else {
388 Task::ready(Err(anyhow!("copilot hasn't started yet")))
389 }
390 }
391
392 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
393 let start_task = cx
394 .spawn({
395 let http = self.http.clone();
396 let node_runtime = self.node_runtime.clone();
397 move |this, cx| async move {
398 clear_copilot_dir().await;
399 Self::start_language_server(http, node_runtime, this, cx).await
400 }
401 })
402 .shared();
403
404 self.server = CopilotServer::Starting {
405 task: start_task.clone(),
406 };
407
408 cx.notify();
409
410 cx.foreground().spawn(start_task)
411 }
412
413 pub fn completion<T>(
414 &self,
415 buffer: &ModelHandle<Buffer>,
416 position: T,
417 cx: &mut ModelContext<Self>,
418 ) -> Task<Result<Option<Completion>>>
419 where
420 T: ToPointUtf16,
421 {
422 let server = match self.authorized_server() {
423 Ok(server) => server,
424 Err(error) => return Task::ready(Err(error)),
425 };
426
427 let buffer = buffer.read(cx).snapshot();
428 let request = server
429 .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
430 cx.background().spawn(async move {
431 let result = request.await?;
432 let completion = result
433 .completions
434 .into_iter()
435 .next()
436 .map(|completion| completion_from_lsp(completion, &buffer));
437 anyhow::Ok(completion)
438 })
439 }
440
441 pub fn completions_cycling<T>(
442 &self,
443 buffer: &ModelHandle<Buffer>,
444 position: T,
445 cx: &mut ModelContext<Self>,
446 ) -> Task<Result<Vec<Completion>>>
447 where
448 T: ToPointUtf16,
449 {
450 let server = match self.authorized_server() {
451 Ok(server) => server,
452 Err(error) => return Task::ready(Err(error)),
453 };
454
455 let buffer = buffer.read(cx).snapshot();
456 let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
457 &buffer, position, cx,
458 ));
459 cx.background().spawn(async move {
460 let result = request.await?;
461 let completions = result
462 .completions
463 .into_iter()
464 .map(|completion| completion_from_lsp(completion, &buffer))
465 .collect();
466 anyhow::Ok(completions)
467 })
468 }
469
470 pub fn status(&self) -> Status {
471 match &self.server {
472 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
473 CopilotServer::Disabled => Status::Disabled,
474 CopilotServer::Error(error) => Status::Error(error.clone()),
475 CopilotServer::Started { status, .. } => match status {
476 SignInStatus::Authorized { .. } => Status::Authorized,
477 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
478 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
479 prompt: prompt.clone(),
480 },
481 SignInStatus::SignedOut => Status::SignedOut,
482 },
483 }
484 }
485
486 fn update_sign_in_status(
487 &mut self,
488 lsp_status: request::SignInStatus,
489 cx: &mut ModelContext<Self>,
490 ) {
491 if let CopilotServer::Started { status, .. } = &mut self.server {
492 *status = match lsp_status {
493 request::SignInStatus::Ok { user }
494 | request::SignInStatus::MaybeOk { user }
495 | request::SignInStatus::AlreadySignedIn { user } => {
496 SignInStatus::Authorized { _user: user }
497 }
498 request::SignInStatus::NotAuthorized { user } => {
499 SignInStatus::Unauthorized { _user: user }
500 }
501 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
502 };
503 cx.notify();
504 }
505 }
506
507 fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
508 match &self.server {
509 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
510 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
511 CopilotServer::Error(error) => Err(anyhow!(
512 "copilot was not started because of an error: {}",
513 error
514 )),
515 CopilotServer::Started { server, status } => {
516 if matches!(status, SignInStatus::Authorized { .. }) {
517 Ok(server.clone())
518 } else {
519 Err(anyhow!("must sign in before using copilot"))
520 }
521 }
522 }
523 }
524}
525
526fn build_completion_params<T>(
527 buffer: &BufferSnapshot,
528 position: T,
529 cx: &AppContext,
530) -> request::GetCompletionsParams
531where
532 T: ToPointUtf16,
533{
534 let position = position.to_point_utf16(&buffer);
535 let language_name = buffer.language_at(position).map(|language| language.name());
536 let language_name = language_name.as_deref();
537
538 let path;
539 let relative_path;
540 if let Some(file) = buffer.file() {
541 if let Some(file) = file.as_local() {
542 path = file.abs_path(cx);
543 } else {
544 path = file.full_path(cx);
545 }
546 relative_path = file.path().to_path_buf();
547 } else {
548 path = PathBuf::from("/untitled");
549 relative_path = PathBuf::from("untitled");
550 }
551
552 let settings = cx.global::<Settings>();
553 let language_id = match language_name {
554 Some("Plain Text") => "plaintext".to_string(),
555 Some(language_name) => language_name.to_lowercase(),
556 None => "plaintext".to_string(),
557 };
558 request::GetCompletionsParams {
559 doc: request::GetCompletionsDocument {
560 source: buffer.text(),
561 tab_size: settings.tab_size(language_name).into(),
562 indent_size: 1,
563 insert_spaces: !settings.hard_tabs(language_name),
564 uri: lsp::Url::from_file_path(&path).unwrap(),
565 path: path.to_string_lossy().into(),
566 relative_path: relative_path.to_string_lossy().into(),
567 language_id,
568 position: point_to_lsp(position),
569 version: 0,
570 },
571 }
572}
573
574fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
575 let start = buffer.clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
576 let end = buffer.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
577 Completion {
578 range: buffer.anchor_before(start)..buffer.anchor_after(end),
579 text: completion.text,
580 }
581}
582
583async fn clear_copilot_dir() {
584 remove_matching(&paths::COPILOT_DIR, |_| true).await
585}
586
587async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
588 const SERVER_PATH: &'static str = "dist/agent.js";
589
590 ///Check for the latest copilot language server and download it if we haven't already
591 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
592 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
593
594 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
595
596 fs::create_dir_all(version_dir).await?;
597 let server_path = version_dir.join(SERVER_PATH);
598
599 if fs::metadata(&server_path).await.is_err() {
600 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
601 let dist_dir = version_dir.join("dist");
602 fs::create_dir_all(dist_dir.as_path()).await?;
603
604 let url = &release
605 .assets
606 .get(0)
607 .context("Github release for copilot contained no assets")?
608 .browser_download_url;
609
610 let mut response = http
611 .get(&url, Default::default(), true)
612 .await
613 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
614 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
615 let archive = Archive::new(decompressed_bytes);
616 archive.unpack(dist_dir).await?;
617
618 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
619 }
620
621 Ok(server_path)
622 }
623
624 match fetch_latest(http).await {
625 ok @ Result::Ok(..) => ok,
626 e @ Err(..) => {
627 e.log_err();
628 // Fetch a cached binary, if it exists
629 (|| async move {
630 let mut last_version_dir = None;
631 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
632 while let Some(entry) = entries.next().await {
633 let entry = entry?;
634 if entry.file_type().await?.is_dir() {
635 last_version_dir = Some(entry.path());
636 }
637 }
638 let last_version_dir =
639 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
640 let server_path = last_version_dir.join(SERVER_PATH);
641 if server_path.exists() {
642 Ok(server_path)
643 } else {
644 Err(anyhow!(
645 "missing executable in directory {:?}",
646 last_version_dir
647 ))
648 }
649 })()
650 .await
651 }
652 }
653}