1pub mod request;
2mod sign_in;
3
4use anyhow::{anyhow, Context, Result};
5use async_compression::futures::bufread::GzipDecoder;
6use async_tar::Archive;
7use collections::HashMap;
8use futures::{future::Shared, Future, FutureExt, TryFutureExt};
9use gpui::{
10 actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
11 Task,
12};
13use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
14use log::{debug, error};
15use lsp::LanguageServer;
16use node_runtime::NodeRuntime;
17use request::{LogMessage, StatusNotification};
18use settings::Settings;
19use smol::{fs, io::BufReader, stream::StreamExt};
20use std::{
21 ffi::OsString,
22 ops::Range,
23 path::{Path, PathBuf},
24 sync::Arc,
25};
26use util::{
27 channel::ReleaseChannel, fs::remove_matching, github::latest_github_release, http::HttpClient,
28 paths, ResultExt,
29};
30
31const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
32actions!(copilot_auth, [SignIn, SignOut]);
33
34const COPILOT_NAMESPACE: &'static str = "copilot";
35actions!(copilot, [NextSuggestion, PreviousSuggestion, Reinstall]);
36
37pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
38 // Disable Copilot for stable releases.
39 if *cx.global::<ReleaseChannel>() == ReleaseChannel::Stable {
40 cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
41 filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
42 filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
43 });
44 return;
45 }
46
47 let copilot = cx.add_model({
48 let node_runtime = node_runtime.clone();
49 move |cx| Copilot::start(http, node_runtime, cx)
50 });
51 cx.set_global(copilot.clone());
52
53 sign_in::init(cx);
54 cx.add_global_action(|_: &SignIn, cx| {
55 if let Some(copilot) = Copilot::global(cx) {
56 copilot
57 .update(cx, |copilot, cx| copilot.sign_in(cx))
58 .detach_and_log_err(cx);
59 }
60 });
61 cx.add_global_action(|_: &SignOut, cx| {
62 if let Some(copilot) = Copilot::global(cx) {
63 copilot
64 .update(cx, |copilot, cx| copilot.sign_out(cx))
65 .detach_and_log_err(cx);
66 }
67 });
68
69 cx.add_global_action(|_: &Reinstall, cx| {
70 if let Some(copilot) = Copilot::global(cx) {
71 copilot
72 .update(cx, |copilot, cx| copilot.reinstall(cx))
73 .detach();
74 }
75 });
76}
77
78enum CopilotServer {
79 Disabled,
80 Starting {
81 task: Shared<Task<()>>,
82 },
83 Error(Arc<str>),
84 Started {
85 server: Arc<LanguageServer>,
86 status: SignInStatus,
87 subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
88 },
89}
90
91#[derive(Clone, Debug)]
92enum SignInStatus {
93 Authorized,
94 Unauthorized,
95 SigningIn {
96 prompt: Option<request::PromptUserDeviceFlow>,
97 task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
98 },
99 SignedOut,
100}
101
102#[derive(Debug, Clone)]
103pub enum Status {
104 Starting {
105 task: Shared<Task<()>>,
106 },
107 Error(Arc<str>),
108 Disabled,
109 SignedOut,
110 SigningIn {
111 prompt: Option<request::PromptUserDeviceFlow>,
112 },
113 Unauthorized,
114 Authorized,
115}
116
117impl Status {
118 pub fn is_authorized(&self) -> bool {
119 matches!(self, Status::Authorized)
120 }
121}
122
123#[derive(Debug, PartialEq, Eq)]
124pub struct Completion {
125 pub range: Range<Anchor>,
126 pub text: String,
127}
128
129pub struct Copilot {
130 http: Arc<dyn HttpClient>,
131 node_runtime: Arc<NodeRuntime>,
132 server: CopilotServer,
133}
134
135impl Entity for Copilot {
136 type Event = ();
137}
138
139impl Copilot {
140 pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
141 if cx.has_global::<ModelHandle<Self>>() {
142 Some(cx.global::<ModelHandle<Self>>().clone())
143 } else {
144 None
145 }
146 }
147
148 fn start(
149 http: Arc<dyn HttpClient>,
150 node_runtime: Arc<NodeRuntime>,
151 cx: &mut ModelContext<Self>,
152 ) -> Self {
153 cx.observe_global::<Settings, _>({
154 let http = http.clone();
155 let node_runtime = node_runtime.clone();
156 move |this, cx| {
157 if cx.global::<Settings>().enable_copilot_integration {
158 if matches!(this.server, CopilotServer::Disabled) {
159 let start_task = cx
160 .spawn({
161 let http = http.clone();
162 let node_runtime = node_runtime.clone();
163 move |this, cx| {
164 Self::start_language_server(http, node_runtime, this, cx)
165 }
166 })
167 .shared();
168 this.server = CopilotServer::Starting { task: start_task };
169 cx.notify();
170 }
171 } else {
172 this.server = CopilotServer::Disabled;
173 cx.notify();
174 }
175 }
176 })
177 .detach();
178
179 if cx.global::<Settings>().enable_copilot_integration {
180 let start_task = cx
181 .spawn({
182 let http = http.clone();
183 let node_runtime = node_runtime.clone();
184 move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
185 })
186 .shared();
187
188 Self {
189 http,
190 node_runtime,
191 server: CopilotServer::Starting { task: start_task },
192 }
193 } else {
194 Self {
195 http,
196 node_runtime,
197 server: CopilotServer::Disabled,
198 }
199 }
200 }
201
202 #[cfg(any(test, feature = "test-support"))]
203 pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
204 let (server, fake_server) =
205 LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
206 let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
207 let this = cx.add_model(|cx| Self {
208 http: http.clone(),
209 node_runtime: NodeRuntime::new(http, cx.background().clone()),
210 server: CopilotServer::Started {
211 server: Arc::new(server),
212 status: SignInStatus::Authorized,
213 subscriptions_by_buffer_id: Default::default(),
214 },
215 });
216 (this, fake_server)
217 }
218
219 fn start_language_server(
220 http: Arc<dyn HttpClient>,
221 node_runtime: Arc<NodeRuntime>,
222 this: ModelHandle<Self>,
223 mut cx: AsyncAppContext,
224 ) -> impl Future<Output = ()> {
225 async move {
226 let start_language_server = async {
227 let server_path = get_copilot_lsp(http).await?;
228 let node_path = node_runtime.binary_path().await?;
229 let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
230 let server = LanguageServer::new(
231 0,
232 &node_path,
233 arguments,
234 Path::new("/"),
235 None,
236 cx.clone(),
237 )?;
238
239 let server = server.initialize(Default::default()).await?;
240 let status = server
241 .request::<request::CheckStatus>(request::CheckStatusParams {
242 local_checks_only: false,
243 })
244 .await?;
245
246 server
247 .on_notification::<LogMessage, _>(|params, _cx| {
248 match params.level {
249 // Copilot is pretty agressive about logging
250 0 => debug!("copilot: {}", params.message),
251 1 => debug!("copilot: {}", params.message),
252 _ => error!("copilot: {}", params.message),
253 }
254
255 debug!("copilot metadata: {}", params.metadata_str);
256 debug!("copilot extra: {:?}", params.extra);
257 })
258 .detach();
259
260 server
261 .on_notification::<StatusNotification, _>(
262 |_, _| { /* Silence the notification */ },
263 )
264 .detach();
265
266 anyhow::Ok((server, status))
267 };
268
269 let server = start_language_server.await;
270 this.update(&mut cx, |this, cx| {
271 cx.notify();
272 match server {
273 Ok((server, status)) => {
274 this.server = CopilotServer::Started {
275 server,
276 status: SignInStatus::SignedOut,
277 subscriptions_by_buffer_id: Default::default(),
278 };
279 this.update_sign_in_status(status, cx);
280 }
281 Err(error) => {
282 this.server = CopilotServer::Error(error.to_string().into());
283 cx.notify()
284 }
285 }
286 })
287 }
288 }
289
290 fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
291 if let CopilotServer::Started { server, status, .. } = &mut self.server {
292 let task = match status {
293 SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
294 Task::ready(Ok(())).shared()
295 }
296 SignInStatus::SigningIn { task, .. } => {
297 cx.notify();
298 task.clone()
299 }
300 SignInStatus::SignedOut => {
301 let server = server.clone();
302 let task = cx
303 .spawn(|this, mut cx| async move {
304 let sign_in = async {
305 let sign_in = server
306 .request::<request::SignInInitiate>(
307 request::SignInInitiateParams {},
308 )
309 .await?;
310 match sign_in {
311 request::SignInInitiateResult::AlreadySignedIn { user } => {
312 Ok(request::SignInStatus::Ok { user })
313 }
314 request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
315 this.update(&mut cx, |this, cx| {
316 if let CopilotServer::Started { status, .. } =
317 &mut this.server
318 {
319 if let SignInStatus::SigningIn {
320 prompt: prompt_flow,
321 ..
322 } = status
323 {
324 *prompt_flow = Some(flow.clone());
325 cx.notify();
326 }
327 }
328 });
329 let response = server
330 .request::<request::SignInConfirm>(
331 request::SignInConfirmParams {
332 user_code: flow.user_code,
333 },
334 )
335 .await?;
336 Ok(response)
337 }
338 }
339 };
340
341 let sign_in = sign_in.await;
342 this.update(&mut cx, |this, cx| match sign_in {
343 Ok(status) => {
344 this.update_sign_in_status(status, cx);
345 Ok(())
346 }
347 Err(error) => {
348 this.update_sign_in_status(
349 request::SignInStatus::NotSignedIn,
350 cx,
351 );
352 Err(Arc::new(error))
353 }
354 })
355 })
356 .shared();
357 *status = SignInStatus::SigningIn {
358 prompt: None,
359 task: task.clone(),
360 };
361 cx.notify();
362 task
363 }
364 };
365
366 cx.foreground()
367 .spawn(task.map_err(|err| anyhow!("{:?}", err)))
368 } else {
369 // If we're downloading, wait until download is finished
370 // If we're in a stuck state, display to the user
371 Task::ready(Err(anyhow!("copilot hasn't started yet")))
372 }
373 }
374
375 fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
376 if let CopilotServer::Started { server, status, .. } = &mut self.server {
377 *status = SignInStatus::SignedOut;
378 cx.notify();
379
380 let server = server.clone();
381 cx.background().spawn(async move {
382 server
383 .request::<request::SignOut>(request::SignOutParams {})
384 .await?;
385 anyhow::Ok(())
386 })
387 } else {
388 Task::ready(Err(anyhow!("copilot hasn't started yet")))
389 }
390 }
391
392 fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
393 let start_task = cx
394 .spawn({
395 let http = self.http.clone();
396 let node_runtime = self.node_runtime.clone();
397 move |this, cx| async move {
398 clear_copilot_dir().await;
399 Self::start_language_server(http, node_runtime, this, cx).await
400 }
401 })
402 .shared();
403
404 self.server = CopilotServer::Starting {
405 task: start_task.clone(),
406 };
407
408 cx.notify();
409
410 cx.foreground().spawn(start_task)
411 }
412
413 pub fn completions<T>(
414 &mut self,
415 buffer: &ModelHandle<Buffer>,
416 position: T,
417 cx: &mut ModelContext<Self>,
418 ) -> Task<Result<Vec<Completion>>>
419 where
420 T: ToPointUtf16,
421 {
422 self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
423 }
424
425 pub fn completions_cycling<T>(
426 &mut self,
427 buffer: &ModelHandle<Buffer>,
428 position: T,
429 cx: &mut ModelContext<Self>,
430 ) -> Task<Result<Vec<Completion>>>
431 where
432 T: ToPointUtf16,
433 {
434 self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
435 }
436
437 fn request_completions<R, T>(
438 &mut self,
439 buffer: &ModelHandle<Buffer>,
440 position: T,
441 cx: &mut ModelContext<Self>,
442 ) -> Task<Result<Vec<Completion>>>
443 where
444 R: lsp::request::Request<
445 Params = request::GetCompletionsParams,
446 Result = request::GetCompletionsResult,
447 >,
448 T: ToPointUtf16,
449 {
450 let buffer_id = buffer.id();
451 let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
452 let snapshot = buffer.read(cx).snapshot();
453 let server = match &mut self.server {
454 CopilotServer::Starting { .. } => {
455 return Task::ready(Err(anyhow!("copilot is still starting")))
456 }
457 CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
458 CopilotServer::Error(error) => {
459 return Task::ready(Err(anyhow!(
460 "copilot was not started because of an error: {}",
461 error
462 )))
463 }
464 CopilotServer::Started {
465 server,
466 status,
467 subscriptions_by_buffer_id,
468 } => {
469 if matches!(status, SignInStatus::Authorized { .. }) {
470 subscriptions_by_buffer_id
471 .entry(buffer_id)
472 .or_insert_with(|| {
473 server
474 .notify::<lsp::notification::DidOpenTextDocument>(
475 lsp::DidOpenTextDocumentParams {
476 text_document: lsp::TextDocumentItem {
477 uri: uri.clone(),
478 language_id: id_for_language(
479 buffer.read(cx).language(),
480 ),
481 version: 0,
482 text: snapshot.text(),
483 },
484 },
485 )
486 .log_err();
487
488 let uri = uri.clone();
489 cx.observe_release(buffer, move |this, _, _| {
490 if let CopilotServer::Started {
491 server,
492 subscriptions_by_buffer_id,
493 ..
494 } = &mut this.server
495 {
496 server
497 .notify::<lsp::notification::DidCloseTextDocument>(
498 lsp::DidCloseTextDocumentParams {
499 text_document: lsp::TextDocumentIdentifier::new(
500 uri.clone(),
501 ),
502 },
503 )
504 .log_err();
505 subscriptions_by_buffer_id.remove(&buffer_id);
506 }
507 })
508 });
509
510 server.clone()
511 } else {
512 return Task::ready(Err(anyhow!("must sign in before using copilot")));
513 }
514 }
515 };
516
517 let settings = cx.global::<Settings>();
518 let position = position.to_point_utf16(&snapshot);
519 let language = snapshot.language_at(position);
520 let language_name = language.map(|language| language.name());
521 let language_name = language_name.as_deref();
522 let tab_size = settings.tab_size(language_name);
523 let hard_tabs = settings.hard_tabs(language_name);
524 let language_id = id_for_language(language);
525
526 let path;
527 let relative_path;
528 if let Some(file) = snapshot.file() {
529 if let Some(file) = file.as_local() {
530 path = file.abs_path(cx);
531 } else {
532 path = file.full_path(cx);
533 }
534 relative_path = file.path().to_path_buf();
535 } else {
536 path = PathBuf::new();
537 relative_path = PathBuf::new();
538 }
539
540 cx.background().spawn(async move {
541 let result = server
542 .request::<R>(request::GetCompletionsParams {
543 doc: request::GetCompletionsDocument {
544 source: snapshot.text(),
545 tab_size: tab_size.into(),
546 indent_size: 1,
547 insert_spaces: !hard_tabs,
548 uri,
549 path: path.to_string_lossy().into(),
550 relative_path: relative_path.to_string_lossy().into(),
551 language_id,
552 position: point_to_lsp(position),
553 version: 0,
554 },
555 })
556 .await?;
557 let completions = result
558 .completions
559 .into_iter()
560 .map(|completion| {
561 let start = snapshot
562 .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
563 let end =
564 snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
565 Completion {
566 range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
567 text: completion.text,
568 }
569 })
570 .collect();
571 anyhow::Ok(completions)
572 })
573 }
574
575 pub fn status(&self) -> Status {
576 match &self.server {
577 CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
578 CopilotServer::Disabled => Status::Disabled,
579 CopilotServer::Error(error) => Status::Error(error.clone()),
580 CopilotServer::Started { status, .. } => match status {
581 SignInStatus::Authorized { .. } => Status::Authorized,
582 SignInStatus::Unauthorized { .. } => Status::Unauthorized,
583 SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
584 prompt: prompt.clone(),
585 },
586 SignInStatus::SignedOut => Status::SignedOut,
587 },
588 }
589 }
590
591 fn update_sign_in_status(
592 &mut self,
593 lsp_status: request::SignInStatus,
594 cx: &mut ModelContext<Self>,
595 ) {
596 if let CopilotServer::Started { status, .. } = &mut self.server {
597 *status = match lsp_status {
598 request::SignInStatus::Ok { .. }
599 | request::SignInStatus::MaybeOk { .. }
600 | request::SignInStatus::AlreadySignedIn { .. } => SignInStatus::Authorized,
601 request::SignInStatus::NotAuthorized { .. } => SignInStatus::Unauthorized,
602 request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
603 };
604 cx.notify();
605 }
606 }
607}
608
609fn id_for_language(language: Option<&Arc<Language>>) -> String {
610 let language_name = language.map(|language| language.name());
611 match language_name.as_deref() {
612 Some("Plain Text") => "plaintext".to_string(),
613 Some(language_name) => language_name.to_lowercase(),
614 None => "plaintext".to_string(),
615 }
616}
617
618async fn clear_copilot_dir() {
619 remove_matching(&paths::COPILOT_DIR, |_| true).await
620}
621
622async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
623 const SERVER_PATH: &'static str = "dist/agent.js";
624
625 ///Check for the latest copilot language server and download it if we haven't already
626 async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
627 let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
628
629 let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
630
631 fs::create_dir_all(version_dir).await?;
632 let server_path = version_dir.join(SERVER_PATH);
633
634 if fs::metadata(&server_path).await.is_err() {
635 // Copilot LSP looks for this dist dir specifcially, so lets add it in.
636 let dist_dir = version_dir.join("dist");
637 fs::create_dir_all(dist_dir.as_path()).await?;
638
639 let url = &release
640 .assets
641 .get(0)
642 .context("Github release for copilot contained no assets")?
643 .browser_download_url;
644
645 let mut response = http
646 .get(&url, Default::default(), true)
647 .await
648 .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
649 let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
650 let archive = Archive::new(decompressed_bytes);
651 archive.unpack(dist_dir).await?;
652
653 remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
654 }
655
656 Ok(server_path)
657 }
658
659 match fetch_latest(http).await {
660 ok @ Result::Ok(..) => ok,
661 e @ Err(..) => {
662 e.log_err();
663 // Fetch a cached binary, if it exists
664 (|| async move {
665 let mut last_version_dir = None;
666 let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
667 while let Some(entry) = entries.next().await {
668 let entry = entry?;
669 if entry.file_type().await?.is_dir() {
670 last_version_dir = Some(entry.path());
671 }
672 }
673 let last_version_dir =
674 last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
675 let server_path = last_version_dir.join(SERVER_PATH);
676 if server_path.exists() {
677 Ok(server_path)
678 } else {
679 Err(anyhow!(
680 "missing executable in directory {:?}",
681 last_version_dir
682 ))
683 }
684 })()
685 .await
686 }
687 }
688}