1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use client::Client;
8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{
10 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
11 Task,
12};
13use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
14use lsp::LanguageServer;
15use node_runtime::NodeRuntime;
16use settings::Settings;
17use smol::{fs, io::BufReader, stream::StreamExt};
18use std::{
19 ffi::OsString,
20 path::{Path, PathBuf},
21 sync::Arc,
22};
23use util::{
24 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
25};
26
27const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
28actions!(copilot_auth, [SignIn, SignOut]);
29
30const COPILOT_NAMESPACE: &'static str = "copilot";
31actions!(
32 copilot,
33 [NextSuggestion, PreviousSuggestion, Toggle, Reinstall]
34);
35
36pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
37 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
38 cx.set_global(copilot.clone());
39 cx.add_global_action(|_: &SignIn, cx| {
40 let copilot = Copilot::global(cx).unwrap();
41 copilot
42 .update(cx, |copilot, cx| copilot.sign_in(cx))
43 .detach_and_log_err(cx);
44 });
45 cx.add_global_action(|_: &SignOut, cx| {
46 let copilot = Copilot::global(cx).unwrap();
47 copilot
48 .update(cx, |copilot, cx| copilot.sign_out(cx))
49 .detach_and_log_err(cx);
50 });
51
52 cx.add_global_action(|_: &Reinstall, cx| {
53 let copilot = Copilot::global(cx).unwrap();
54 copilot
55 .update(cx, |copilot, cx| copilot.reinstall(cx))
56 .detach();
57 });
58
59 cx.observe(&copilot, |handle, cx| {
60 let status = handle.read(cx).status();
61 cx.update_global::<collections::CommandPaletteFilter, _, _>(
62 move |filter, _cx| match status {
63 Status::Disabled => {
64 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
65 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
66 }
67 Status::Authorized => {
68 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
69 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
70 }
71 _ => {
72 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
73 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
74 }
75 },
76 );
77 })
78 .detach();
79
80 sign_in::init(cx);
81}
82
83enum CopilotServer {
84 Disabled,
85 Starting {
86 task: Shared<Task<()>>,
87 },
88 Error(Arc<str>),
89 Started {
90 server: Arc<LanguageServer>,
91 status: SignInStatus,
92 },
93}
94
95#[derive(Clone, Debug)]
96enum SignInStatus {
97 Authorized {
98 _user: String,
99 },
100 Unauthorized {
101 _user: String,
102 },
103 SigningIn {
104 prompt: Option<request::PromptUserDeviceFlow>,
105 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
106 },
107 SignedOut,
108}
109
110#[derive(Debug, Clone)]
111pub enum Status {
112 Starting {
113 task: Shared<Task<()>>,
114 },
115 Error(Arc<str>),
116 Disabled,
117 SignedOut,
118 SigningIn {
119 prompt: Option<request::PromptUserDeviceFlow>,
120 },
121 Unauthorized,
122 Authorized,
123}
124
125impl Status {
126 pub fn is_authorized(&self) -> bool {
127 matches!(self, Status::Authorized)
128 }
129}
130
131#[derive(Debug, PartialEq, Eq)]
132pub struct Completion {
133 pub position: Anchor,
134 pub text: String,
135}
136
137pub struct Copilot {
138 http: Arc<dyn HttpClient>,
139 node_runtime: Arc<NodeRuntime>,
140 server: CopilotServer,
141}
142
143impl Entity for Copilot {
144 type Event = ();
145}
146
147impl Copilot {
148 pub fn starting_task(&self) -> Option<Shared<Task<()>>> {
149 match self.server {
150 CopilotServer::Starting { ref task } => Some(task.clone()),
151 _ => None,
152 }
153 }
154
155 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
156 if cx.has_global::<ModelHandle<Self>>() {
157 Some(cx.global::<ModelHandle<Self>>().clone())
158 } else {
159 None
160 }
161 }
162
163 fn start(
164 http: Arc<dyn HttpClient>,
165 node_runtime: Arc<NodeRuntime>,
166 cx: &mut ModelContext<Self>,
167 ) -> Self {
168 cx.observe_global::<Settings, _>({
169 let http = http.clone();
170 let node_runtime = node_runtime.clone();
171 move |this, cx| {
172 if cx.global::<Settings>().enable_copilot_integration {
173 if matches!(this.server, CopilotServer::Disabled) {
174 let start_task = cx
175 .spawn({
176 let http = http.clone();
177 let node_runtime = node_runtime.clone();
178 move |this, cx| {
179 Self::start_language_server(http, node_runtime, this, cx)
180 }
181 })
182 .shared();
183 this.server = CopilotServer::Starting { task: start_task };
184 cx.notify();
185 }
186 } else {
187 this.server = CopilotServer::Disabled;
188 cx.notify();
189 }
190 }
191 })
192 .detach();
193
194 if cx.global::<Settings>().enable_copilot_integration {
195 let start_task = cx
196 .spawn({
197 let http = http.clone();
198 let node_runtime = node_runtime.clone();
199 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
200 })
201 .shared();
202
203 Self {
204 http,
205 node_runtime,
206 server: CopilotServer::Starting { task: start_task },
207 }
208 } else {
209 Self {
210 http,
211 node_runtime,
212 server: CopilotServer::Disabled,
213 }
214 }
215 }
216
217 fn start_language_server(
218 http: Arc<dyn HttpClient>,
219 node_runtime: Arc<NodeRuntime>,
220 this: ModelHandle<Self>,
221 mut cx: AsyncAppContext,
222 ) -> impl Future<Output = ()> {
223 async move {
224 let start_language_server = async {
225 let server_path = get_copilot_lsp(http).await?;
226 let node_path = node_runtime.binary_path().await?;
227 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
228 let server =
229 LanguageServer::new(0, &node_path, arguments, Path::new("/"), cx.clone())?;
230
231 let server = server.initialize(Default::default()).await?;
232 let status = server
233 .request::<request::CheckStatus>(request::CheckStatusParams {
234 local_checks_only: false,
235 })
236 .await?;
237 anyhow::Ok((server, status))
238 };
239
240 let server = start_language_server.await;
241 this.update(&mut cx, |this, cx| {
242 cx.notify();
243 match server {
244 Ok((server, status)) => {
245 this.server = CopilotServer::Started {
246 server,
247 status: SignInStatus::SignedOut,
248 };
249 this.update_sign_in_status(status, cx);
250 }
251 Err(error) => {
252 this.server = CopilotServer::Error(error.to_string().into());
253 cx.notify()
254 }
255 }
256 })
257 }
258 }
259
260 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
261 if let CopilotServer::Started { server, status } = &mut self.server {
262 let task = match status {
263 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
264 Task::ready(Ok(())).shared()
265 }
266 SignInStatus::SigningIn { task, .. } => {
267 cx.notify();
268 task.clone()
269 }
270 SignInStatus::SignedOut => {
271 let server = server.clone();
272 let task = cx
273 .spawn(|this, mut cx| async move {
274 let sign_in = async {
275 let sign_in = server
276 .request::<request::SignInInitiate>(
277 request::SignInInitiateParams {},
278 )
279 .await?;
280 match sign_in {
281 request::SignInInitiateResult::AlreadySignedIn { user } => {
282 Ok(request::SignInStatus::Ok { user })
283 }
284 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
285 this.update(&mut cx, |this, cx| {
286 if let CopilotServer::Started { status, .. } =
287 &mut this.server
288 {
289 if let SignInStatus::SigningIn {
290 prompt: prompt_flow,
291 ..
292 } = status
293 {
294 *prompt_flow = Some(flow.clone());
295 cx.notify();
296 }
297 }
298 });
299 let response = server
300 .request::<request::SignInConfirm>(
301 request::SignInConfirmParams {
302 user_code: flow.user_code,
303 },
304 )
305 .await?;
306 Ok(response)
307 }
308 }
309 };
310
311 let sign_in = sign_in.await;
312 this.update(&mut cx, |this, cx| match sign_in {
313 Ok(status) => {
314 this.update_sign_in_status(status, cx);
315 Ok(())
316 }
317 Err(error) => {
318 this.update_sign_in_status(
319 request::SignInStatus::NotSignedIn,
320 cx,
321 );
322 Err(Arc::new(error))
323 }
324 })
325 })
326 .shared();
327 *status = SignInStatus::SigningIn {
328 prompt: None,
329 task: task.clone(),
330 };
331 cx.notify();
332 task
333 }
334 };
335
336 cx.foreground()
337 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
338 } else {
339 // If we're downloading, wait until download is finished
340 // If we're in a stuck state, display to the user
341 Task::ready(Err(anyhow!("copilot hasn't started yet")))
342 }
343 }
344
345 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
346 if let CopilotServer::Started { server, status } = &mut self.server {
347 *status = SignInStatus::SignedOut;
348 cx.notify();
349
350 let server = server.clone();
351 cx.background().spawn(async move {
352 server
353 .request::<request::SignOut>(request::SignOutParams {})
354 .await?;
355 anyhow::Ok(())
356 })
357 } else {
358 Task::ready(Err(anyhow!("copilot hasn't started yet")))
359 }
360 }
361
362 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
363 let start_task = cx
364 .spawn({
365 let http = self.http.clone();
366 let node_runtime = self.node_runtime.clone();
367 move |this, cx| async move {
368 clear_copilot_dir().await;
369 Self::start_language_server(http, node_runtime, this, cx).await
370 }
371 })
372 .shared();
373
374 self.server = CopilotServer::Starting {
375 task: start_task.clone(),
376 };
377
378 cx.notify();
379
380 cx.foreground().spawn(start_task)
381 }
382
383 pub fn completion<T>(
384 &self,
385 buffer: &ModelHandle<Buffer>,
386 position: T,
387 cx: &mut ModelContext<Self>,
388 ) -> Task<Result<Option<Completion>>>
389 where
390 T: ToPointUtf16,
391 {
392 let server = match self.authorized_server() {
393 Ok(server) => server,
394 Err(error) => return Task::ready(Err(error)),
395 };
396
397 let buffer = buffer.read(cx).snapshot();
398 let request = server
399 .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
400 cx.background().spawn(async move {
401 let result = request.await?;
402 let completion = result
403 .completions
404 .into_iter()
405 .next()
406 .map(|completion| completion_from_lsp(completion, &buffer));
407 anyhow::Ok(completion)
408 })
409 }
410
411 pub fn completions_cycling<T>(
412 &self,
413 buffer: &ModelHandle<Buffer>,
414 position: T,
415 cx: &mut ModelContext<Self>,
416 ) -> Task<Result<Vec<Completion>>>
417 where
418 T: ToPointUtf16,
419 {
420 let server = match self.authorized_server() {
421 Ok(server) => server,
422 Err(error) => return Task::ready(Err(error)),
423 };
424
425 let buffer = buffer.read(cx).snapshot();
426 let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
427 &buffer, position, cx,
428 ));
429 cx.background().spawn(async move {
430 let result = request.await?;
431 let completions = result
432 .completions
433 .into_iter()
434 .map(|completion| completion_from_lsp(completion, &buffer))
435 .collect();
436 anyhow::Ok(completions)
437 })
438 }
439
440 pub fn status(&self) -> Status {
441 match &self.server {
442 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
443 CopilotServer::Disabled => Status::Disabled,
444 CopilotServer::Error(error) => Status::Error(error.clone()),
445 CopilotServer::Started { status, .. } => match status {
446 SignInStatus::Authorized { .. } => Status::Authorized,
447 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
448 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
449 prompt: prompt.clone(),
450 },
451 SignInStatus::SignedOut => Status::SignedOut,
452 },
453 }
454 }
455
456 fn update_sign_in_status(
457 &mut self,
458 lsp_status: request::SignInStatus,
459 cx: &mut ModelContext<Self>,
460 ) {
461 if let CopilotServer::Started { status, .. } = &mut self.server {
462 *status = match lsp_status {
463 request::SignInStatus::Ok { user }
464 | request::SignInStatus::MaybeOk { user }
465 | request::SignInStatus::AlreadySignedIn { user } => {
466 SignInStatus::Authorized { _user: user }
467 }
468 request::SignInStatus::NotAuthorized { user } => {
469 SignInStatus::Unauthorized { _user: user }
470 }
471 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
472 };
473 cx.notify();
474 }
475 }
476
477 fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
478 match &self.server {
479 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
480 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
481 CopilotServer::Error(error) => Err(anyhow!(
482 "copilot was not started because of an error: {}",
483 error
484 )),
485 CopilotServer::Started { server, status } => {
486 if matches!(status, SignInStatus::Authorized { .. }) {
487 Ok(server.clone())
488 } else {
489 Err(anyhow!("must sign in before using copilot"))
490 }
491 }
492 }
493 }
494}
495
496fn build_completion_params<T>(
497 buffer: &BufferSnapshot,
498 position: T,
499 cx: &AppContext,
500) -> request::GetCompletionsParams
501where
502 T: ToPointUtf16,
503{
504 let position = position.to_point_utf16(&buffer);
505 let language_name = buffer.language_at(position).map(|language| language.name());
506 let language_name = language_name.as_deref();
507
508 let path;
509 let relative_path;
510 if let Some(file) = buffer.file() {
511 if let Some(file) = file.as_local() {
512 path = file.abs_path(cx);
513 } else {
514 path = file.full_path(cx);
515 }
516 relative_path = file.path().to_path_buf();
517 } else {
518 path = PathBuf::from("/untitled");
519 relative_path = PathBuf::from("untitled");
520 }
521
522 let settings = cx.global::<Settings>();
523 let language_id = match language_name {
524 Some("Plain Text") => "plaintext".to_string(),
525 Some(language_name) => language_name.to_lowercase(),
526 None => "plaintext".to_string(),
527 };
528 request::GetCompletionsParams {
529 doc: request::GetCompletionsDocument {
530 source: buffer.text(),
531 tab_size: settings.tab_size(language_name).into(),
532 indent_size: 1,
533 insert_spaces: !settings.hard_tabs(language_name),
534 uri: lsp::Url::from_file_path(&path).unwrap(),
535 path: path.to_string_lossy().into(),
536 relative_path: relative_path.to_string_lossy().into(),
537 language_id,
538 position: point_to_lsp(position),
539 version: 0,
540 },
541 }
542}
543
544fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
545 let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left);
546 Completion {
547 position: buffer.anchor_before(position),
548 text: completion.display_text,
549 }
550}
551
552async fn clear_copilot_dir() {
553 remove_matching(&paths::COPILOT_DIR, |_| true).await
554}
555
556async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
557 const SERVER_PATH: &'static str = "dist/agent.js";
558
559 ///Check for the latest copilot language server and download it if we haven't already
560 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
561 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
562
563 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
564
565 fs::create_dir_all(version_dir).await?;
566 let server_path = version_dir.join(SERVER_PATH);
567
568 if fs::metadata(&server_path).await.is_err() {
569 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
570 let dist_dir = version_dir.join("dist");
571 fs::create_dir_all(dist_dir.as_path()).await?;
572
573 let url = &release
574 .assets
575 .get(0)
576 .context("Github release for copilot contained no assets")?
577 .browser_download_url;
578
579 let mut response = http
580 .get(&url, Default::default(), true)
581 .await
582 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
583 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
584 let archive = Archive::new(decompressed_bytes);
585 archive.unpack(dist_dir).await?;
586
587 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
588 }
589
590 Ok(server_path)
591 }
592
593 match fetch_latest(http).await {
594 ok @ Result::Ok(..) => ok,
595 e @ Err(..) => {
596 e.log_err();
597 // Fetch a cached binary, if it exists
598 (|| async move {
599 let mut last_version_dir = None;
600 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
601 while let Some(entry) = entries.next().await {
602 let entry = entry?;
603 if entry.file_type().await?.is_dir() {
604 last_version_dir = Some(entry.path());
605 }
606 }
607 let last_version_dir =
608 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
609 let server_path = last_version_dir.join(SERVER_PATH);
610 if server_path.exists() {
611 Ok(server_path)
612 } else {
613 Err(anyhow!(
614 "missing executable in directory {:?}",
615 last_version_dir
616 ))
617 }
618 })()
619 .await
620 }
621 }
622}