1mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use client::Client;
8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{
10 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
11 Task,
12};
13use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16};
14use lsp::LanguageServer;
15use node_runtime::NodeRuntime;
16use settings::Settings;
17use smol::{fs, io::BufReader, stream::StreamExt};
18use std::{
19 ffi::OsString,
20 ops::Range,
21 path::{Path, PathBuf},
22 sync::Arc,
23};
24use util::{
25 fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
26};
27
28const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
29actions!(copilot_auth, [SignIn, SignOut]);
30
31const COPILOT_NAMESPACE: &'static str = "copilot";
32actions!(
33 copilot,
34 [NextSuggestion, PreviousSuggestion, Toggle, Reinstall]
35);
36
37pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
38 let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), node_runtime, cx));
39 cx.set_global(copilot.clone());
40 cx.add_global_action(|_: &SignIn, cx| {
41 let copilot = Copilot::global(cx).unwrap();
42 copilot
43 .update(cx, |copilot, cx| copilot.sign_in(cx))
44 .detach_and_log_err(cx);
45 });
46 cx.add_global_action(|_: &SignOut, cx| {
47 let copilot = Copilot::global(cx).unwrap();
48 copilot
49 .update(cx, |copilot, cx| copilot.sign_out(cx))
50 .detach_and_log_err(cx);
51 });
52
53 cx.add_global_action(|_: &Reinstall, cx| {
54 let copilot = Copilot::global(cx).unwrap();
55 copilot
56 .update(cx, |copilot, cx| copilot.reinstall(cx))
57 .detach();
58 });
59
60 cx.observe(&copilot, |handle, cx| {
61 let status = handle.read(cx).status();
62 cx.update_global::<collections::CommandPaletteFilter, _, _>(
63 move |filter, _cx| match status {
64 Status::Disabled => {
65 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
66 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
67 }
68 Status::Authorized => {
69 filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
70 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
71 }
72 _ => {
73 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
74 filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
75 }
76 },
77 );
78 })
79 .detach();
80
81 sign_in::init(cx);
82}
83
84enum CopilotServer {
85 Disabled,
86 Starting {
87 task: Shared<Task<()>>,
88 },
89 Error(Arc<str>),
90 Started {
91 server: Arc<LanguageServer>,
92 status: SignInStatus,
93 },
94}
95
96#[derive(Clone, Debug)]
97enum SignInStatus {
98 Authorized {
99 _user: String,
100 },
101 Unauthorized {
102 _user: String,
103 },
104 SigningIn {
105 prompt: Option<request::PromptUserDeviceFlow>,
106 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
107 },
108 SignedOut,
109}
110
111#[derive(Debug, Clone)]
112pub enum Status {
113 Starting {
114 task: Shared<Task<()>>,
115 },
116 Error(Arc<str>),
117 Disabled,
118 SignedOut,
119 SigningIn {
120 prompt: Option<request::PromptUserDeviceFlow>,
121 },
122 Unauthorized,
123 Authorized,
124}
125
126impl Status {
127 pub fn is_authorized(&self) -> bool {
128 matches!(self, Status::Authorized)
129 }
130}
131
132#[derive(Debug, PartialEq, Eq)]
133pub struct Completion {
134 pub range: Range<Anchor>,
135 pub text: String,
136}
137
138pub struct Copilot {
139 http: Arc<dyn HttpClient>,
140 node_runtime: Arc<NodeRuntime>,
141 server: CopilotServer,
142}
143
144impl Entity for Copilot {
145 type Event = ();
146}
147
148impl Copilot {
149 pub fn starting_task(&self) -> Option<Shared<Task<()>>> {
150 match self.server {
151 CopilotServer::Starting { ref task } => Some(task.clone()),
152 _ => None,
153 }
154 }
155
156 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
157 if cx.has_global::<ModelHandle<Self>>() {
158 Some(cx.global::<ModelHandle<Self>>().clone())
159 } else {
160 None
161 }
162 }
163
164 fn start(
165 http: Arc<dyn HttpClient>,
166 node_runtime: Arc<NodeRuntime>,
167 cx: &mut ModelContext<Self>,
168 ) -> Self {
169 cx.observe_global::<Settings, _>({
170 let http = http.clone();
171 let node_runtime = node_runtime.clone();
172 move |this, cx| {
173 if cx.global::<Settings>().enable_copilot_integration {
174 if matches!(this.server, CopilotServer::Disabled) {
175 let start_task = cx
176 .spawn({
177 let http = http.clone();
178 let node_runtime = node_runtime.clone();
179 move |this, cx| {
180 Self::start_language_server(http, node_runtime, this, cx)
181 }
182 })
183 .shared();
184 this.server = CopilotServer::Starting { task: start_task };
185 cx.notify();
186 }
187 } else {
188 this.server = CopilotServer::Disabled;
189 cx.notify();
190 }
191 }
192 })
193 .detach();
194
195 if cx.global::<Settings>().enable_copilot_integration {
196 let start_task = cx
197 .spawn({
198 let http = http.clone();
199 let node_runtime = node_runtime.clone();
200 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
201 })
202 .shared();
203
204 Self {
205 http,
206 node_runtime,
207 server: CopilotServer::Starting { task: start_task },
208 }
209 } else {
210 Self {
211 http,
212 node_runtime,
213 server: CopilotServer::Disabled,
214 }
215 }
216 }
217
218 fn start_language_server(
219 http: Arc<dyn HttpClient>,
220 node_runtime: Arc<NodeRuntime>,
221 this: ModelHandle<Self>,
222 mut cx: AsyncAppContext,
223 ) -> impl Future<Output = ()> {
224 async move {
225 let start_language_server = async {
226 let server_path = get_copilot_lsp(http).await?;
227 let node_path = node_runtime.binary_path().await?;
228 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
229 let server = LanguageServer::new(
230 0,
231 &node_path,
232 arguments,
233 Path::new("/"),
234 None,
235 cx.clone(),
236 )?;
237
238 let server = server.initialize(Default::default()).await?;
239 let status = server
240 .request::<request::CheckStatus>(request::CheckStatusParams {
241 local_checks_only: false,
242 })
243 .await?;
244 anyhow::Ok((server, status))
245 };
246
247 let server = start_language_server.await;
248 this.update(&mut cx, |this, cx| {
249 cx.notify();
250 match server {
251 Ok((server, status)) => {
252 this.server = CopilotServer::Started {
253 server,
254 status: SignInStatus::SignedOut,
255 };
256 this.update_sign_in_status(status, cx);
257 }
258 Err(error) => {
259 this.server = CopilotServer::Error(error.to_string().into());
260 cx.notify()
261 }
262 }
263 })
264 }
265 }
266
267 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
268 if let CopilotServer::Started { server, status } = &mut self.server {
269 let task = match status {
270 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
271 Task::ready(Ok(())).shared()
272 }
273 SignInStatus::SigningIn { task, .. } => {
274 cx.notify();
275 task.clone()
276 }
277 SignInStatus::SignedOut => {
278 let server = server.clone();
279 let task = cx
280 .spawn(|this, mut cx| async move {
281 let sign_in = async {
282 let sign_in = server
283 .request::<request::SignInInitiate>(
284 request::SignInInitiateParams {},
285 )
286 .await?;
287 match sign_in {
288 request::SignInInitiateResult::AlreadySignedIn { user } => {
289 Ok(request::SignInStatus::Ok { user })
290 }
291 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
292 this.update(&mut cx, |this, cx| {
293 if let CopilotServer::Started { status, .. } =
294 &mut this.server
295 {
296 if let SignInStatus::SigningIn {
297 prompt: prompt_flow,
298 ..
299 } = status
300 {
301 *prompt_flow = Some(flow.clone());
302 cx.notify();
303 }
304 }
305 });
306 let response = server
307 .request::<request::SignInConfirm>(
308 request::SignInConfirmParams {
309 user_code: flow.user_code,
310 },
311 )
312 .await?;
313 Ok(response)
314 }
315 }
316 };
317
318 let sign_in = sign_in.await;
319 this.update(&mut cx, |this, cx| match sign_in {
320 Ok(status) => {
321 this.update_sign_in_status(status, cx);
322 Ok(())
323 }
324 Err(error) => {
325 this.update_sign_in_status(
326 request::SignInStatus::NotSignedIn,
327 cx,
328 );
329 Err(Arc::new(error))
330 }
331 })
332 })
333 .shared();
334 *status = SignInStatus::SigningIn {
335 prompt: None,
336 task: task.clone(),
337 };
338 cx.notify();
339 task
340 }
341 };
342
343 cx.foreground()
344 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
345 } else {
346 // If we're downloading, wait until download is finished
347 // If we're in a stuck state, display to the user
348 Task::ready(Err(anyhow!("copilot hasn't started yet")))
349 }
350 }
351
352 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
353 if let CopilotServer::Started { server, status } = &mut self.server {
354 *status = SignInStatus::SignedOut;
355 cx.notify();
356
357 let server = server.clone();
358 cx.background().spawn(async move {
359 server
360 .request::<request::SignOut>(request::SignOutParams {})
361 .await?;
362 anyhow::Ok(())
363 })
364 } else {
365 Task::ready(Err(anyhow!("copilot hasn't started yet")))
366 }
367 }
368
369 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
370 let start_task = cx
371 .spawn({
372 let http = self.http.clone();
373 let node_runtime = self.node_runtime.clone();
374 move |this, cx| async move {
375 clear_copilot_dir().await;
376 Self::start_language_server(http, node_runtime, this, cx).await
377 }
378 })
379 .shared();
380
381 self.server = CopilotServer::Starting {
382 task: start_task.clone(),
383 };
384
385 cx.notify();
386
387 cx.foreground().spawn(start_task)
388 }
389
390 pub fn completion<T>(
391 &self,
392 buffer: &ModelHandle<Buffer>,
393 position: T,
394 cx: &mut ModelContext<Self>,
395 ) -> Task<Result<Option<Completion>>>
396 where
397 T: ToPointUtf16,
398 {
399 let server = match self.authorized_server() {
400 Ok(server) => server,
401 Err(error) => return Task::ready(Err(error)),
402 };
403
404 let buffer = buffer.read(cx).snapshot();
405 let request = server
406 .request::<request::GetCompletions>(build_completion_params(&buffer, position, cx));
407 cx.background().spawn(async move {
408 let result = request.await?;
409 let completion = result
410 .completions
411 .into_iter()
412 .next()
413 .map(|completion| completion_from_lsp(completion, &buffer));
414 anyhow::Ok(completion)
415 })
416 }
417
418 pub fn completions_cycling<T>(
419 &self,
420 buffer: &ModelHandle<Buffer>,
421 position: T,
422 cx: &mut ModelContext<Self>,
423 ) -> Task<Result<Vec<Completion>>>
424 where
425 T: ToPointUtf16,
426 {
427 let server = match self.authorized_server() {
428 Ok(server) => server,
429 Err(error) => return Task::ready(Err(error)),
430 };
431
432 let buffer = buffer.read(cx).snapshot();
433 let request = server.request::<request::GetCompletionsCycling>(build_completion_params(
434 &buffer, position, cx,
435 ));
436 cx.background().spawn(async move {
437 let result = request.await?;
438 let completions = result
439 .completions
440 .into_iter()
441 .map(|completion| completion_from_lsp(completion, &buffer))
442 .collect();
443 anyhow::Ok(completions)
444 })
445 }
446
447 pub fn status(&self) -> Status {
448 match &self.server {
449 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
450 CopilotServer::Disabled => Status::Disabled,
451 CopilotServer::Error(error) => Status::Error(error.clone()),
452 CopilotServer::Started { status, .. } => match status {
453 SignInStatus::Authorized { .. } => Status::Authorized,
454 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
455 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
456 prompt: prompt.clone(),
457 },
458 SignInStatus::SignedOut => Status::SignedOut,
459 },
460 }
461 }
462
463 fn update_sign_in_status(
464 &mut self,
465 lsp_status: request::SignInStatus,
466 cx: &mut ModelContext<Self>,
467 ) {
468 if let CopilotServer::Started { status, .. } = &mut self.server {
469 *status = match lsp_status {
470 request::SignInStatus::Ok { user }
471 | request::SignInStatus::MaybeOk { user }
472 | request::SignInStatus::AlreadySignedIn { user } => {
473 SignInStatus::Authorized { _user: user }
474 }
475 request::SignInStatus::NotAuthorized { user } => {
476 SignInStatus::Unauthorized { _user: user }
477 }
478 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
479 };
480 cx.notify();
481 }
482 }
483
484 fn authorized_server(&self) -> Result<Arc<LanguageServer>> {
485 match &self.server {
486 CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")),
487 CopilotServer::Disabled => Err(anyhow!("copilot is disabled")),
488 CopilotServer::Error(error) => Err(anyhow!(
489 "copilot was not started because of an error: {}",
490 error
491 )),
492 CopilotServer::Started { server, status } => {
493 if matches!(status, SignInStatus::Authorized { .. }) {
494 Ok(server.clone())
495 } else {
496 Err(anyhow!("must sign in before using copilot"))
497 }
498 }
499 }
500 }
501}
502
503fn build_completion_params<T>(
504 buffer: &BufferSnapshot,
505 position: T,
506 cx: &AppContext,
507) -> request::GetCompletionsParams
508where
509 T: ToPointUtf16,
510{
511 let position = position.to_point_utf16(&buffer);
512 let language_name = buffer.language_at(position).map(|language| language.name());
513 let language_name = language_name.as_deref();
514
515 let path;
516 let relative_path;
517 if let Some(file) = buffer.file() {
518 if let Some(file) = file.as_local() {
519 path = file.abs_path(cx);
520 } else {
521 path = file.full_path(cx);
522 }
523 relative_path = file.path().to_path_buf();
524 } else {
525 path = PathBuf::from("/untitled");
526 relative_path = PathBuf::from("untitled");
527 }
528
529 let settings = cx.global::<Settings>();
530 let language_id = match language_name {
531 Some("Plain Text") => "plaintext".to_string(),
532 Some(language_name) => language_name.to_lowercase(),
533 None => "plaintext".to_string(),
534 };
535 request::GetCompletionsParams {
536 doc: request::GetCompletionsDocument {
537 source: buffer.text(),
538 tab_size: settings.tab_size(language_name).into(),
539 indent_size: 1,
540 insert_spaces: !settings.hard_tabs(language_name),
541 uri: lsp::Url::from_file_path(&path).unwrap(),
542 path: path.to_string_lossy().into(),
543 relative_path: relative_path.to_string_lossy().into(),
544 language_id,
545 position: point_to_lsp(position),
546 version: 0,
547 },
548 }
549}
550
551fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion {
552 let start = buffer.clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
553 let end = buffer.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
554 Completion {
555 range: buffer.anchor_before(start)..buffer.anchor_after(end),
556 text: completion.text,
557 }
558}
559
560async fn clear_copilot_dir() {
561 remove_matching(&paths::COPILOT_DIR, |_| true).await
562}
563
564async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
565 const SERVER_PATH: &'static str = "dist/agent.js";
566
567 ///Check for the latest copilot language server and download it if we haven't already
568 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
569 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
570
571 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
572
573 fs::create_dir_all(version_dir).await?;
574 let server_path = version_dir.join(SERVER_PATH);
575
576 if fs::metadata(&server_path).await.is_err() {
577 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
578 let dist_dir = version_dir.join("dist");
579 fs::create_dir_all(dist_dir.as_path()).await?;
580
581 let url = &release
582 .assets
583 .get(0)
584 .context("Github release for copilot contained no assets")?
585 .browser_download_url;
586
587 let mut response = http
588 .get(&url, Default::default(), true)
589 .await
590 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
591 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
592 let archive = Archive::new(decompressed_bytes);
593 archive.unpack(dist_dir).await?;
594
595 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
596 }
597
598 Ok(server_path)
599 }
600
601 match fetch_latest(http).await {
602 ok @ Result::Ok(..) => ok,
603 e @ Err(..) => {
604 e.log_err();
605 // Fetch a cached binary, if it exists
606 (|| async move {
607 let mut last_version_dir = None;
608 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
609 while let Some(entry) = entries.next().await {
610 let entry = entry?;
611 if entry.file_type().await?.is_dir() {
612 last_version_dir = Some(entry.path());
613 }
614 }
615 let last_version_dir =
616 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
617 let server_path = last_version_dir.join(SERVER_PATH);
618 if server_path.exists() {
619 Ok(server_path)
620 } else {
621 Err(anyhow!(
622 "missing executable in directory {:?}",
623 last_version_dir
624 ))
625 }
626 })()
627 .await
628 }
629 }
630}