1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use client::Client;
8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{
10 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
11 Task,
12};
13use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
14use lsp::LanguageServer;
15use node_runtime::NodeRuntime;
16use settings::Settings;
17use smol::{fs, io::BufReader, stream::StreamExt};
18use std::{
19 ffi::OsString,
20 path::{Path, PathBuf},
21 sync::Arc,
22};
23use util::{
24 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
25};
26
/// Name of the `copilot_auth` action namespace; `init` uses it as the key it inserts into or
/// removes from the command palette filter.
const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
// Sign-in / sign-out actions, declared in the `copilot_auth` namespace.
actions!(copilot_auth, [SignIn, SignOut]);

/// Name of the `copilot` action namespace; `init` uses it as the key it inserts into or
/// removes from the command palette filter.
const COPILOT_NAMESPACE: &'static str = "copilot";
// Suggestion / toggle / reinstall actions, declared in the `copilot` namespace.
actions!(
    copilot,
    [NextSuggestion, PreviousSuggestion, Toggle, Reinstall]
);
35
36pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
37 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
38 cx.set_global(copilot.clone());
39 cx.add_global_action(|_: &SignIn, cx| {
40 let copilot = Copilot::global(cx).unwrap();
41 copilot
42 .update(cx, |copilot, cx| copilot.sign_in(cx))
43 .detach_and_log_err(cx);
44 });
45 cx.add_global_action(|_: &SignOut, cx| {
46 let copilot = Copilot::global(cx).unwrap();
47 copilot
48 .update(cx, |copilot, cx| copilot.sign_out(cx))
49 .detach_and_log_err(cx);
50 });
51
52 cx.add_global_action(|_: &Reinstall, cx| {
53 let copilot = Copilot::global(cx).unwrap();
54 copilot
55 .update(cx, |copilot, cx| copilot.reinstall(cx))
56 .detach();
57 });
58
59 cx.observe(&copilot, |handle, cx| {
60 let status = handle.read(cx).status();
61 cx.update_global::<collections::CommandPaletteFilter, _, _>(
62 move |filter, _cx| match status {
63 Status::Disabled => {
64 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
65 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
66 }
67 Status::Authorized => {
68 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
69 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
70 }
71 _ => {
72 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
73 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
74 }
75 },
76 );
77 })
78 .detach();
79
80 sign_in::init(cx);
81}
82
/// Lifecycle state of the Copilot language server as tracked by [`Copilot`].
enum CopilotServer {
    /// The `enable_copilot_integration` setting is off; no server is running.
    Disabled,
    /// A start (or reinstall) task has been spawned and has not yet recorded its result.
    Starting {
        /// Shared handle to the task running `Copilot::start_language_server`.
        task: Shared<Task<()>>,
    },
    /// Starting the server failed; holds the error rendered as a string.
    Error(Arc<str>),
    /// The server was launched and initialized.
    Started {
        /// Handle used to send requests to the server.
        server: Arc<LanguageServer>,
        /// Current sign-in state for this server.
        status: SignInStatus,
    },
}
94
/// Internal sign-in state attached to a started server.
///
/// Derived from `request::SignInStatus` in `Copilot::update_sign_in_status`, plus the
/// in-flight `SigningIn` state set by `Copilot::sign_in`.
#[derive(Clone, Debug)]
enum SignInStatus {
    /// The server reported `Ok`, `MaybeOk` or `AlreadySignedIn` for this user.
    Authorized {
        _user: String,
    },
    /// The server reported `NotAuthorized` for this user.
    Unauthorized {
        _user: String,
    },
    /// A sign-in task is in flight.
    SigningIn {
        /// Device-flow prompt, filled in once the server asks for one.
        prompt: Option<request::PromptUserDeviceFlow>,
        /// Shared handle to the running sign-in task so repeated requests can reuse it.
        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
    },
    /// Nobody is signed in (also the initial state right after the server starts).
    SignedOut,
}
109
/// Public, flattened view of the Copilot state, as returned by `Copilot::status`.
#[derive(Debug, Clone)]
pub enum Status {
    /// The language server is still being started.
    Starting {
        /// Shared handle to the start task.
        task: Shared<Task<()>>,
    },
    /// Starting the language server failed; holds the error message.
    Error(Arc<str>),
    /// Copilot integration is disabled.
    Disabled,
    /// The server is running but nobody is signed in.
    SignedOut,
    /// A sign-in is in progress.
    SigningIn {
        /// Device-flow prompt, if the server has provided one yet.
        prompt: Option<request::PromptUserDeviceFlow>,
    },
    /// The server is running and reported the user as not authorized.
    Unauthorized,
    /// The server is running and the user is authorized.
    Authorized,
}
124
impl Status {
    /// Returns `true` only for [`Status::Authorized`].
    pub fn is_authorized(&self) -> bool {
        matches!(self, Status::Authorized)
    }
}
130
/// A single suggestion produced by the Copilot language server.
#[derive(Debug, PartialEq, Eq)]
pub struct Completion {
    /// Buffer anchor built from the position the server returned for this completion.
    pub position: Anchor,
    /// The completion's `display_text` as returned by the server.
    pub text: String,
}
136
/// Model that owns the Copilot language server and its sign-in state.
pub struct Copilot {
    /// HTTP client used when fetching the Copilot language server.
    http: Arc<dyn HttpClient>,
    /// Provides the path of the node binary used to launch the server.
    node_runtime: Arc<NodeRuntime>,
    /// Current lifecycle state of the language server.
    server: CopilotServer,
}
142
// `Copilot` declares no event payload of its own; the code in this file signals changes
// through `cx.notify()` rather than emitting events.
impl Entity for Copilot {
    type Event = ();
}
146
147impl Copilot {
148 pub fn starting_task(&self) -> Option<Shared<Task<()>>> {
149 match self.server {
150 CopilotServer::Starting { ref task } => Some(task.clone()),
151 _ => None,
152 }
153 }
154
155 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
156 if cx.has_global::<ModelHandle<Self>>() {
157 Some(cx.global::<ModelHandle<Self>>().clone())
158 } else {
159 None
160 }
161 }
162
163 fn start(
164 http: Arc<dyn HttpClient>,
165 node_runtime: Arc<NodeRuntime>,
166 cx: &mut ModelContext<Self>,
167 ) -> Self {
168 cx.observe_global::<Settings, _>({
169 let http = http.clone();
170 let node_runtime = node_runtime.clone();
171 move |this, cx| {
172 if cx.global::<Settings>().enable_copilot_integration {
173 if matches!(this.server, CopilotServer::Disabled) {
174 let start_task = cx
175 .spawn({
176 let http = http.clone();
177 let node_runtime = node_runtime.clone();
178 move |this, cx| {
179 Self::start_language_server(http, node_runtime, this, cx)
180 }
181 })
182 .shared();
183 this.server = CopilotServer::Starting { task: start_task };
184 cx.notify();
185 }
186 } else {
187 this.server = CopilotServer::Disabled;
188 cx.notify();
189 }
190 }
191 })
192 .detach();
193
194 if cx.global::<Settings>().enable_copilot_integration {
195 let start_task = cx
196 .spawn({
197 let http = http.clone();
198 let node_runtime = node_runtime.clone();
199 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
200 })
201 .shared();
202
203 Self {
204 http,
205 node_runtime,
206 server: CopilotServer::Starting { task: start_task },
207 }
208 } else {
209 Self {
210 http,
211 node_runtime,
212 server: CopilotServer::Disabled,
213 }
214 }
215 }
216
    /// Returns a future that launches the Copilot language server and records the outcome
    /// on `this`.
    ///
    /// The future resolves the server script via `get_copilot_lsp` and the node binary via
    /// `node_runtime.binary_path()`, creates the server with a `--stdio` argument,
    /// initializes it and sends a `CheckStatus` request. On success `this.server` becomes
    /// `CopilotServer::Started` and the sign-in status is derived from the response; on
    /// failure it becomes `CopilotServer::Error` holding the error's string form.
    fn start_language_server(
        http: Arc<dyn HttpClient>,
        node_runtime: Arc<NodeRuntime>,
        this: ModelHandle<Self>,
        mut cx: AsyncAppContext,
    ) -> impl Future<Output = ()> {
        async move {
            // Grouped into one fallible future so every `?` funnels into a single `Result`.
            let start_language_server = async {
                let server_path = get_copilot_lsp(http).await?;
                let node_path = node_runtime.binary_path().await?;
                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
                let server = LanguageServer::new(
                    0,
                    &node_path,
                    arguments,
                    Path::new("/"),
                    None,
                    cx.clone(),
                )?;

                let server = server.initialize(Default::default()).await?;
                let status = server
                    .request::<request::CheckStatus>(request::CheckStatusParams {
                        local_checks_only: false,
                    })
                    .await?;
                anyhow::Ok((server, status))
            };

            let server = start_language_server.await;
            this.update(&mut cx, |this, cx| {
                cx.notify();
                match server {
                    Ok((server, status)) => {
                        // Start from `SignedOut`, then let `update_sign_in_status` translate
                        // the status the server reported.
                        this.server = CopilotServer::Started {
                            server,
                            status: SignInStatus::SignedOut,
                        };
                        this.update_sign_in_status(status, cx);
                    }
                    Err(error) => {
                        this.server = CopilotServer::Error(error.to_string().into());
                        cx.notify()
                    }
                }
            })
        }
    }
265
    /// Starts (or joins) a sign-in with the Copilot server.
    ///
    /// Behavior depends on the current state:
    /// - `Authorized` / `Unauthorized`: returns a task that is immediately `Ok(())`.
    /// - `SigningIn`: notifies observers and reuses the already running shared task.
    /// - `SignedOut`: spawns a new task, stores it in `SignInStatus::SigningIn` and returns it.
    /// - Server not `Started`: returns a task that is immediately an error.
    ///
    /// Errors from the shared task are re-wrapped with `anyhow!("{:?}", err)`.
    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if let CopilotServer::Started { server, status } = &mut self.server {
            let task = match status {
                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
                    Task::ready(Ok(())).shared()
                }
                SignInStatus::SigningIn { task, .. } => {
                    cx.notify();
                    task.clone()
                }
                SignInStatus::SignedOut => {
                    let server = server.clone();
                    let task = cx
                        .spawn(|this, mut cx| async move {
                            let sign_in = async {
                                let sign_in = server
                                    .request::<request::SignInInitiate>(
                                        request::SignInInitiateParams {},
                                    )
                                    .await?;
                                match sign_in {
                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
                                        Ok(request::SignInStatus::Ok { user })
                                    }
                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
                                        // Publish the prompt on the model, but only if we
                                        // are still in the `SigningIn` state.
                                        this.update(&mut cx, |this, cx| {
                                            if let CopilotServer::Started { status, .. } =
                                                &mut this.server
                                            {
                                                if let SignInStatus::SigningIn {
                                                    prompt: prompt_flow,
                                                    ..
                                                } = status
                                                {
                                                    *prompt_flow = Some(flow.clone());
                                                    cx.notify();
                                                }
                                            }
                                        });
                                        // Send the confirm request with the flow's user code;
                                        // its response is the resulting sign-in status.
                                        let response = server
                                            .request::<request::SignInConfirm>(
                                                request::SignInConfirmParams {
                                                    user_code: flow.user_code,
                                                },
                                            )
                                            .await?;
                                        Ok(response)
                                    }
                                }
                            };

                            let sign_in = sign_in.await;
                            this.update(&mut cx, |this, cx| match sign_in {
                                Ok(status) => {
                                    this.update_sign_in_status(status, cx);
                                    Ok(())
                                }
                                Err(error) => {
                                    // A failed sign-in resets the state to signed out.
                                    this.update_sign_in_status(
                                        request::SignInStatus::NotSignedIn,
                                        cx,
                                    );
                                    Err(Arc::new(error))
                                }
                            })
                        })
                        .shared();
                    // Record the in-flight task so later calls reuse it instead of
                    // starting a second sign-in.
                    *status = SignInStatus::SigningIn {
                        prompt: None,
                        task: task.clone(),
                    };
                    cx.notify();
                    task
                }
            };

            cx.foreground()
                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
        } else {
            // TODO: if the server is still downloading/starting, wait for it to finish;
            // if it is stuck in an error state, surface that to the user.
            Task::ready(Err(anyhow!("copilot hasn't started yet")))
        }
    }
350
351 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
352 if let CopilotServer::Started { server, status } = &mut self.server {
353 *status = SignInStatus::SignedOut;
354 cx.notify();
355
356 let server = server.clone();
357 cx.background().spawn(async move {
358 server
359 .request::<request::SignOut>(request::SignOutParams {})
360 .await?;
361 anyhow::Ok(())
362 })
363 } else {
364 Task::ready(Err(anyhow!("copilot hasn't started yet")))
365 }
366 }
367
368 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
369 let start_task = cx
370 .spawn({
371 let http = self.http.clone();
372 let node_runtime = self.node_runtime.clone();
373 move |this, cx| async move {
374 clear_copilot_dir().await;
375 Self::start_language_server(http, node_runtime, this, cx).await
376 }
377 })
378 .shared();
379
380 self.server = CopilotServer::Starting {
381 task: start_task.clone(),
382 };
383
384 cx.notify();
385
386 cx.foreground().spawn(start_task)
387 }
388
389 pub fn completion<T>(
390 &self,
391 buffer: &ModelHandle<Buffer>,
392 position: T,
393 cx: &mut ModelContext<Self>,
394 ) -> Task<Result<Option<Completion>>>
395 where
396 T: ToPointUtf16,
397 {
398 let server = match self.authorized_server() {
399 Ok(server) => server,
400 Err(error) => return Task::ready(Err(error)),
401 };
402
403 let buffer = buffer.read(cx).snapshot();
404 let request = server
405 .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
406 cx.background().spawn(async move {
407 let result = request.await?;
408 let completion = result
409 .completions
410 .into_iter()
411 .next()
412 .map(|completion| completion_from_lsp(completion, &buffer));
413 anyhow::Ok(completion)
414 })
415 }
416
417 pub fn completions_cycling<T>(
418 &self,
419 buffer: &ModelHandle<Buffer>,
420 position: T,
421 cx: &mut ModelContext<Self>,
422 ) -> Task<Result<Vec<Completion>>>
423 where
424 T: ToPointUtf16,
425 {
426 let server = match self.authorized_server() {
427 Ok(server) => server,
428 Err(error) => return Task::ready(Err(error)),
429 };
430
431 let buffer = buffer.read(cx).snapshot();
432 let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
433 &buffer, position, cx,
434 ));
435 cx.background().spawn(async move {
436 let result = request.await?;
437 let completions = result
438 .completions
439 .into_iter()
440 .map(|completion| completion_from_lsp(completion, &buffer))
441 .collect();
442 anyhow::Ok(completions)
443 })
444 }
445
446 pub fn status(&self) -> Status {
447 match &self.server {
448 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
449 CopilotServer::Disabled => Status::Disabled,
450 CopilotServer::Error(error) => Status::Error(error.clone()),
451 CopilotServer::Started { status, .. } => match status {
452 SignInStatus::Authorized { .. } => Status::Authorized,
453 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
454 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
455 prompt: prompt.clone(),
456 },
457 SignInStatus::SignedOut => Status::SignedOut,
458 },
459 }
460 }
461
462 fn update_sign_in_status(
463 &mut self,
464 lsp_status: request::SignInStatus,
465 cx: &mut ModelContext<Self>,
466 ) {
467 if let CopilotServer::Started { status, .. } = &mut self.server {
468 *status = match lsp_status {
469 request::SignInStatus::Ok { user }
470 | request::SignInStatus::MaybeOk { user }
471 | request::SignInStatus::AlreadySignedIn { user } => {
472 SignInStatus::Authorized { _user: user }
473 }
474 request::SignInStatus::NotAuthorized { user } => {
475 SignInStatus::Unauthorized { _user: user }
476 }
477 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
478 };
479 cx.notify();
480 }
481 }
482
483 fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
484 match &self.server {
485 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
486 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
487 CopilotServer::Error(error) => Err(anyhow!(
488 "copilot was not started because of an error: {}",
489 error
490 )),
491 CopilotServer::Started { server, status } => {
492 if matches!(status, SignInStatus::Authorized { .. }) {
493 Ok(server.clone())
494 } else {
495 Err(anyhow!("must sign in before using copilot"))
496 }
497 }
498 }
499 }
500}
501
502fn build_completion_params<T>(
503 buffer: &BufferSnapshot,
504 position: T,
505 cx: &AppContext,
506) -> request::GetCompletionsParams
507where
508 T: ToPointUtf16,
509{
510 let position = position.to_point_utf16(&buffer);
511 let language_name = buffer.language_at(position).map(|language| language.name());
512 let language_name = language_name.as_deref();
513
514 let path;
515 let relative_path;
516 if let Some(file) = buffer.file() {
517 if let Some(file) = file.as_local() {
518 path = file.abs_path(cx);
519 } else {
520 path = file.full_path(cx);
521 }
522 relative_path = file.path().to_path_buf();
523 } else {
524 path = PathBuf::from("/untitled");
525 relative_path = PathBuf::from("untitled");
526 }
527
528 let settings = cx.global::<Settings>();
529 let language_id = match language_name {
530 Some("Plain Text") => "plaintext".to_string(),
531 Some(language_name) => language_name.to_lowercase(),
532 None => "plaintext".to_string(),
533 };
534 request::GetCompletionsParams {
535 doc: request::GetCompletionsDocument {
536 source: buffer.text(),
537 tab_size: settings.tab_size(language_name).into(),
538 indent_size: 1,
539 insert_spaces: !settings.hard_tabs(language_name),
540 uri: lsp::Url::from_file_path(&path).unwrap(),
541 path: path.to_string_lossy().into(),
542 relative_path: relative_path.to_string_lossy().into(),
543 language_id,
544 position: point_to_lsp(position),
545 version: 0,
546 },
547 }
548}
549
550fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
551 let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
552 Completion {
553 position: buffer.anchor_before(position),
554 text: completion.display_text,
555 }
556}
557
/// Calls `remove_matching` on `paths::COPILOT_DIR` with a predicate that accepts every
/// entry.
async fn clear_copilot_dir() {
    remove_matching(&paths::COPILOT_DIR, |_| true).await
}
561
562async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
563 const SERVER_PATH: &'static str = "dist/agent.js";
564
565 ///Check for the latest copilot language server and download it if we haven't already
566 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
567 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
568
569 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
570
571 fs::create_dir_all(version_dir).await?;
572 let server_path = version_dir.join(SERVER_PATH);
573
574 if fs::metadata(&server_path).await.is_err() {
575 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
576 let dist_dir = version_dir.join("dist");
577 fs::create_dir_all(dist_dir.as_path()).await?;
578
579 let url = &release
580 .assets
581 .get(0)
582 .context("Github release for copilot contained no assets")?
583 .browser_download_url;
584
585 let mut response = http
586 .get(&url, Default::default(), true)
587 .await
588 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
589 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
590 let archive = Archive::new(decompressed_bytes);
591 archive.unpack(dist_dir).await?;
592
593 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
594 }
595
596 Ok(server_path)
597 }
598
599 match fetch_latest(http).await {
600 ok @ Result::Ok(..) => ok,
601 e @ Err(..) => {
602 e.log_err();
603 // Fetch a cached binary, if it exists
604 (|| async move {
605 let mut last_version_dir = None;
606 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
607 while let Some(entry) = entries.next().await {
608 let entry = entry?;
609 if entry.file_type().await?.is_dir() {
610 last_version_dir = Some(entry.path());
611 }
612 }
613 let last_version_dir =
614 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
615 let server_path = last_version_dir.join(SERVER_PATH);
616 if server_path.exists() {
617 Ok(server_path)
618 } else {
619 Err(anyhow!(
620 "missing executable in directory {:?}",
621 last_version_dir
622 ))
623 }
624 })()
625 .await
626 }
627 }
628}