// copilot.rs

  1mod request;
  2mod sign_in;
  3
  4use anyhow::{anyhow, Context, Result};
  5use async_compression::futures::bufread::GzipDecoder;
  6use async_tar::Archive;
  7use client::Client;
  8use collections::HashMap;
  9use futures::{future::Shared, Future, FutureExt, TryFutureExt};
 10use gpui::{
 11    actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
 12    Task,
 13};
 14use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16};
 15use log::{debug, error};
 16use lsp::LanguageServer;
 17use node_runtime::NodeRuntime;
 18use request::{LogMessage, StatusNotification};
 19use settings::Settings;
 20use smol::{fs, io::BufReader, stream::StreamExt};
 21use staff_mode::{not_staff_mode, staff_mode};
 22
 23use std::{
 24    ffi::OsString,
 25    ops::Range,
 26    path::{Path, PathBuf},
 27    sync::Arc,
 28};
 29use util::{
 30    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 31};
 32
 33const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth";
 34actions!(copilot_auth, [SignIn, SignOut]);
 35
 36const COPILOT_NAMESPACE: &'static str = "copilot";
 37actions!(copilot, [NextSuggestion, PreviousSuggestion, Reinstall]);
 38
 39pub fn init(client: Arc<Client>, node_runtime: Arc<NodeRuntime>, cx: &mut MutableAppContext) {
 40    staff_mode(cx, {
 41        move |cx| {
 42            cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
 43                filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
 44                filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
 45            });
 46
 47            let copilot = cx.add_model({
 48                let node_runtime = node_runtime.clone();
 49                let http = client.http_client().clone();
 50                move |cx| Copilot::start(http, node_runtime, cx)
 51            });
 52            cx.set_global(copilot.clone());
 53
 54            observe_namespaces(cx, copilot);
 55
 56            sign_in::init(cx);
 57        }
 58    });
 59    not_staff_mode(cx, |cx| {
 60        cx.update_global::<collections::CommandPaletteFilter, _, _>(|filter, _cx| {
 61            filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 62            filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 63        });
 64    });
 65
 66    cx.add_global_action(|_: &SignIn, cx| {
 67        if let Some(copilot) = Copilot::global(cx) {
 68            copilot
 69                .update(cx, |copilot, cx| copilot.sign_in(cx))
 70                .detach_and_log_err(cx);
 71        }
 72    });
 73    cx.add_global_action(|_: &SignOut, cx| {
 74        if let Some(copilot) = Copilot::global(cx) {
 75            copilot
 76                .update(cx, |copilot, cx| copilot.sign_out(cx))
 77                .detach_and_log_err(cx);
 78        }
 79    });
 80
 81    cx.add_global_action(|_: &Reinstall, cx| {
 82        if let Some(copilot) = Copilot::global(cx) {
 83            copilot
 84                .update(cx, |copilot, cx| copilot.reinstall(cx))
 85                .detach();
 86        }
 87    });
 88}
 89
 90fn observe_namespaces(cx: &mut MutableAppContext, copilot: ModelHandle<Copilot>) {
 91    cx.observe(&copilot, |handle, cx| {
 92        let status = handle.read(cx).status();
 93        cx.update_global::<collections::CommandPaletteFilter, _, _>(
 94            move |filter, _cx| match status {
 95                Status::Disabled => {
 96                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
 97                    filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE);
 98                }
 99                Status::Authorized => {
100                    filter.filtered_namespaces.remove(COPILOT_NAMESPACE);
101                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
102                }
103                _ => {
104                    filter.filtered_namespaces.insert(COPILOT_NAMESPACE);
105                    filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE);
106                }
107            },
108        );
109    })
110    .detach();
111}
112
113enum CopilotServer {
114    Disabled,
115    Starting {
116        task: Shared<Task<()>>,
117    },
118    Error(Arc<str>),
119    Started {
120        server: Arc<LanguageServer>,
121        status: SignInStatus,
122        subscriptions_by_buffer_id: HashMap<usize, gpui::Subscription>,
123    },
124}
125
126#[derive(Clone, Debug)]
127enum SignInStatus {
128    Authorized {
129        _user: String,
130    },
131    Unauthorized {
132        _user: String,
133    },
134    SigningIn {
135        prompt: Option<request::PromptUserDeviceFlow>,
136        task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
137    },
138    SignedOut,
139}
140
141#[derive(Debug, Clone)]
142pub enum Status {
143    Starting {
144        task: Shared<Task<()>>,
145    },
146    Error(Arc<str>),
147    Disabled,
148    SignedOut,
149    SigningIn {
150        prompt: Option<request::PromptUserDeviceFlow>,
151    },
152    Unauthorized,
153    Authorized,
154}
155
156impl Status {
157    pub fn is_authorized(&self) -> bool {
158        matches!(self, Status::Authorized)
159    }
160}
161
162#[derive(Debug, PartialEq, Eq)]
163pub struct Completion {
164    pub range: Range<Anchor>,
165    pub text: String,
166}
167
168pub struct Copilot {
169    http: Arc<dyn HttpClient>,
170    node_runtime: Arc<NodeRuntime>,
171    server: CopilotServer,
172}
173
174impl Entity for Copilot {
175    type Event = ();
176}
177
178impl Copilot {
179    pub fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
180        if cx.has_global::<ModelHandle<Self>>() {
181            Some(cx.global::<ModelHandle<Self>>().clone())
182        } else {
183            None
184        }
185    }
186
187    fn start(
188        http: Arc<dyn HttpClient>,
189        node_runtime: Arc<NodeRuntime>,
190        cx: &mut ModelContext<Self>,
191    ) -> Self {
192        cx.observe_global::<Settings, _>({
193            let http = http.clone();
194            let node_runtime = node_runtime.clone();
195            move |this, cx| {
196                if cx.global::<Settings>().enable_copilot_integration {
197                    if matches!(this.server, CopilotServer::Disabled) {
198                        let start_task = cx
199                            .spawn({
200                                let http = http.clone();
201                                let node_runtime = node_runtime.clone();
202                                move |this, cx| {
203                                    Self::start_language_server(http, node_runtime, this, cx)
204                                }
205                            })
206                            .shared();
207                        this.server = CopilotServer::Starting { task: start_task };
208                        cx.notify();
209                    }
210                } else {
211                    this.server = CopilotServer::Disabled;
212                    cx.notify();
213                }
214            }
215        })
216        .detach();
217
218        if cx.global::<Settings>().enable_copilot_integration {
219            let start_task = cx
220                .spawn({
221                    let http = http.clone();
222                    let node_runtime = node_runtime.clone();
223                    move |this, cx| Self::start_language_server(http, node_runtime, this, cx)
224                })
225                .shared();
226
227            Self {
228                http,
229                node_runtime,
230                server: CopilotServer::Starting { task: start_task },
231            }
232        } else {
233            Self {
234                http,
235                node_runtime,
236                server: CopilotServer::Disabled,
237            }
238        }
239    }
240
241    fn start_language_server(
242        http: Arc<dyn HttpClient>,
243        node_runtime: Arc<NodeRuntime>,
244        this: ModelHandle<Self>,
245        mut cx: AsyncAppContext,
246    ) -> impl Future<Output = ()> {
247        async move {
248            let start_language_server = async {
249                let server_path = get_copilot_lsp(http).await?;
250                let node_path = node_runtime.binary_path().await?;
251                let arguments: &[OsString] = &[server_path.into(), "--stdio".into()];
252                let server = LanguageServer::new(
253                    0,
254                    &node_path,
255                    arguments,
256                    Path::new("/"),
257                    None,
258                    cx.clone(),
259                )?;
260
261                let server = server.initialize(Default::default()).await?;
262                let status = server
263                    .request::<request::CheckStatus>(request::CheckStatusParams {
264                        local_checks_only: false,
265                    })
266                    .await?;
267
268                server
269                    .on_notification::<LogMessage, _>(|params, _cx| {
270                        match params.level {
271                            // Copilot is pretty agressive about logging
272                            0 => debug!("copilot: {}", params.message),
273                            1 => debug!("copilot: {}", params.message),
274                            _ => error!("copilot: {}", params.message),
275                        }
276
277                        debug!("copilot metadata: {}", params.metadata_str);
278                        debug!("copilot extra: {:?}", params.extra);
279                    })
280                    .detach();
281
282                server
283                    .on_notification::<StatusNotification, _>(
284                        |_, _| { /* Silence the notification */ },
285                    )
286                    .detach();
287
288                anyhow::Ok((server, status))
289            };
290
291            let server = start_language_server.await;
292            this.update(&mut cx, |this, cx| {
293                cx.notify();
294                match server {
295                    Ok((server, status)) => {
296                        this.server = CopilotServer::Started {
297                            server,
298                            status: SignInStatus::SignedOut,
299                            subscriptions_by_buffer_id: Default::default(),
300                        };
301                        this.update_sign_in_status(status, cx);
302                    }
303                    Err(error) => {
304                        this.server = CopilotServer::Error(error.to_string().into());
305                        cx.notify()
306                    }
307                }
308            })
309        }
310    }
311
312    fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
313        if let CopilotServer::Started { server, status, .. } = &mut self.server {
314            let task = match status {
315                SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => {
316                    Task::ready(Ok(())).shared()
317                }
318                SignInStatus::SigningIn { task, .. } => {
319                    cx.notify();
320                    task.clone()
321                }
322                SignInStatus::SignedOut => {
323                    let server = server.clone();
324                    let task = cx
325                        .spawn(|this, mut cx| async move {
326                            let sign_in = async {
327                                let sign_in = server
328                                    .request::<request::SignInInitiate>(
329                                        request::SignInInitiateParams {},
330                                    )
331                                    .await?;
332                                match sign_in {
333                                    request::SignInInitiateResult::AlreadySignedIn { user } => {
334                                        Ok(request::SignInStatus::Ok { user })
335                                    }
336                                    request::SignInInitiateResult::PromptUserDeviceFlow(flow) => {
337                                        this.update(&mut cx, |this, cx| {
338                                            if let CopilotServer::Started { status, .. } =
339                                                &mut this.server
340                                            {
341                                                if let SignInStatus::SigningIn {
342                                                    prompt: prompt_flow,
343                                                    ..
344                                                } = status
345                                                {
346                                                    *prompt_flow = Some(flow.clone());
347                                                    cx.notify();
348                                                }
349                                            }
350                                        });
351                                        let response = server
352                                            .request::<request::SignInConfirm>(
353                                                request::SignInConfirmParams {
354                                                    user_code: flow.user_code,
355                                                },
356                                            )
357                                            .await?;
358                                        Ok(response)
359                                    }
360                                }
361                            };
362
363                            let sign_in = sign_in.await;
364                            this.update(&mut cx, |this, cx| match sign_in {
365                                Ok(status) => {
366                                    this.update_sign_in_status(status, cx);
367                                    Ok(())
368                                }
369                                Err(error) => {
370                                    this.update_sign_in_status(
371                                        request::SignInStatus::NotSignedIn,
372                                        cx,
373                                    );
374                                    Err(Arc::new(error))
375                                }
376                            })
377                        })
378                        .shared();
379                    *status = SignInStatus::SigningIn {
380                        prompt: None,
381                        task: task.clone(),
382                    };
383                    cx.notify();
384                    task
385                }
386            };
387
388            cx.foreground()
389                .spawn(task.map_err(|err| anyhow!("{:?}", err)))
390        } else {
391            // If we're downloading, wait until download is finished
392            // If we're in a stuck state, display to the user
393            Task::ready(Err(anyhow!("copilot hasn't started yet")))
394        }
395    }
396
397    fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
398        if let CopilotServer::Started { server, status, .. } = &mut self.server {
399            *status = SignInStatus::SignedOut;
400            cx.notify();
401
402            let server = server.clone();
403            cx.background().spawn(async move {
404                server
405                    .request::<request::SignOut>(request::SignOutParams {})
406                    .await?;
407                anyhow::Ok(())
408            })
409        } else {
410            Task::ready(Err(anyhow!("copilot hasn't started yet")))
411        }
412    }
413
414    fn reinstall(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
415        let start_task = cx
416            .spawn({
417                let http = self.http.clone();
418                let node_runtime = self.node_runtime.clone();
419                move |this, cx| async move {
420                    clear_copilot_dir().await;
421                    Self::start_language_server(http, node_runtime, this, cx).await
422                }
423            })
424            .shared();
425
426        self.server = CopilotServer::Starting {
427            task: start_task.clone(),
428        };
429
430        cx.notify();
431
432        cx.foreground().spawn(start_task)
433    }
434
435    pub fn completions<T>(
436        &mut self,
437        buffer: &ModelHandle<Buffer>,
438        position: T,
439        cx: &mut ModelContext<Self>,
440    ) -> Task<Result<Vec<Completion>>>
441    where
442        T: ToPointUtf16,
443    {
444        self.request_completions::<request::GetCompletions, _>(buffer, position, cx)
445    }
446
447    pub fn completions_cycling<T>(
448        &mut self,
449        buffer: &ModelHandle<Buffer>,
450        position: T,
451        cx: &mut ModelContext<Self>,
452    ) -> Task<Result<Vec<Completion>>>
453    where
454        T: ToPointUtf16,
455    {
456        self.request_completions::<request::GetCompletionsCycling, _>(buffer, position, cx)
457    }
458
459    fn request_completions<R, T>(
460        &mut self,
461        buffer: &ModelHandle<Buffer>,
462        position: T,
463        cx: &mut ModelContext<Self>,
464    ) -> Task<Result<Vec<Completion>>>
465    where
466        R: lsp::request::Request<
467            Params = request::GetCompletionsParams,
468            Result = request::GetCompletionsResult,
469        >,
470        T: ToPointUtf16,
471    {
472        let buffer_id = buffer.id();
473        let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap();
474        let snapshot = buffer.read(cx).snapshot();
475        let server = match &mut self.server {
476            CopilotServer::Starting { .. } => {
477                return Task::ready(Err(anyhow!("copilot is still starting")))
478            }
479            CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))),
480            CopilotServer::Error(error) => {
481                return Task::ready(Err(anyhow!(
482                    "copilot was not started because of an error: {}",
483                    error
484                )))
485            }
486            CopilotServer::Started {
487                server,
488                status,
489                subscriptions_by_buffer_id,
490            } => {
491                if matches!(status, SignInStatus::Authorized { .. }) {
492                    subscriptions_by_buffer_id
493                        .entry(buffer_id)
494                        .or_insert_with(|| {
495                            server
496                                .notify::<lsp::notification::DidOpenTextDocument>(
497                                    lsp::DidOpenTextDocumentParams {
498                                        text_document: lsp::TextDocumentItem {
499                                            uri: uri.clone(),
500                                            language_id: id_for_language(
501                                                buffer.read(cx).language(),
502                                            ),
503                                            version: 0,
504                                            text: snapshot.text(),
505                                        },
506                                    },
507                                )
508                                .log_err();
509
510                            let uri = uri.clone();
511                            cx.observe_release(buffer, move |this, _, _| {
512                                if let CopilotServer::Started {
513                                    server,
514                                    subscriptions_by_buffer_id,
515                                    ..
516                                } = &mut this.server
517                                {
518                                    server
519                                        .notify::<lsp::notification::DidCloseTextDocument>(
520                                            lsp::DidCloseTextDocumentParams {
521                                                text_document: lsp::TextDocumentIdentifier::new(
522                                                    uri.clone(),
523                                                ),
524                                            },
525                                        )
526                                        .log_err();
527                                    subscriptions_by_buffer_id.remove(&buffer_id);
528                                }
529                            })
530                        });
531
532                    server.clone()
533                } else {
534                    return Task::ready(Err(anyhow!("must sign in before using copilot")));
535                }
536            }
537        };
538
539        let settings = cx.global::<Settings>();
540        let position = position.to_point_utf16(&snapshot);
541        let language = snapshot.language_at(position);
542        let language_name = language.map(|language| language.name());
543        let language_name = language_name.as_deref();
544        let tab_size = settings.tab_size(language_name);
545        let hard_tabs = settings.hard_tabs(language_name);
546        let language_id = id_for_language(language);
547
548        let path;
549        let relative_path;
550        if let Some(file) = snapshot.file() {
551            if let Some(file) = file.as_local() {
552                path = file.abs_path(cx);
553            } else {
554                path = file.full_path(cx);
555            }
556            relative_path = file.path().to_path_buf();
557        } else {
558            path = PathBuf::new();
559            relative_path = PathBuf::new();
560        }
561
562        cx.background().spawn(async move {
563            let result = server
564                .request::<R>(request::GetCompletionsParams {
565                    doc: request::GetCompletionsDocument {
566                        source: snapshot.text(),
567                        tab_size: tab_size.into(),
568                        indent_size: 1,
569                        insert_spaces: !hard_tabs,
570                        uri,
571                        path: path.to_string_lossy().into(),
572                        relative_path: relative_path.to_string_lossy().into(),
573                        language_id,
574                        position: point_to_lsp(position),
575                        version: 0,
576                    },
577                })
578                .await?;
579            let completions = result
580                .completions
581                .into_iter()
582                .map(|completion| {
583                    let start = snapshot
584                        .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left);
585                    let end =
586                        snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left);
587                    Completion {
588                        range: snapshot.anchor_before(start)..snapshot.anchor_after(end),
589                        text: completion.text,
590                    }
591                })
592                .collect();
593            anyhow::Ok(completions)
594        })
595    }
596
597    pub fn status(&self) -> Status {
598        match &self.server {
599            CopilotServer::Starting { task } => Status::Starting { task: task.clone() },
600            CopilotServer::Disabled => Status::Disabled,
601            CopilotServer::Error(error) => Status::Error(error.clone()),
602            CopilotServer::Started { status, .. } => match status {
603                SignInStatus::Authorized { .. } => Status::Authorized,
604                SignInStatus::Unauthorized { .. } => Status::Unauthorized,
605                SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
606                    prompt: prompt.clone(),
607                },
608                SignInStatus::SignedOut => Status::SignedOut,
609            },
610        }
611    }
612
613    fn update_sign_in_status(
614        &mut self,
615        lsp_status: request::SignInStatus,
616        cx: &mut ModelContext<Self>,
617    ) {
618        if let CopilotServer::Started { status, .. } = &mut self.server {
619            *status = match lsp_status {
620                request::SignInStatus::Ok { user }
621                | request::SignInStatus::MaybeOk { user }
622                | request::SignInStatus::AlreadySignedIn { user } => {
623                    SignInStatus::Authorized { _user: user }
624                }
625                request::SignInStatus::NotAuthorized { user } => {
626                    SignInStatus::Unauthorized { _user: user }
627                }
628                request::SignInStatus::NotSignedIn => SignInStatus::SignedOut,
629            };
630            cx.notify();
631        }
632    }
633}
634
635fn id_for_language(language: Option<&Arc<Language>>) -> String {
636    let language_name = language.map(|language| language.name());
637    match language_name.as_deref() {
638        Some("Plain Text") => "plaintext".to_string(),
639        Some(language_name) => language_name.to_lowercase(),
640        None => "plaintext".to_string(),
641    }
642}
643
644async fn clear_copilot_dir() {
645    remove_matching(&paths::COPILOT_DIR, |_| true).await
646}
647
648async fn get_copilot_lsp(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
649    const SERVER_PATH: &'static str = "dist/agent.js";
650
651    ///Check for the latest copilot language server and download it if we haven't already
652    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
653        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
654
655        let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name));
656
657        fs::create_dir_all(version_dir).await?;
658        let server_path = version_dir.join(SERVER_PATH);
659
660        if fs::metadata(&server_path).await.is_err() {
661            // Copilot LSP looks for this dist dir specifcially, so lets add it in.
662            let dist_dir = version_dir.join("dist");
663            fs::create_dir_all(dist_dir.as_path()).await?;
664
665            let url = &release
666                .assets
667                .get(0)
668                .context("Github release for copilot contained no assets")?
669                .browser_download_url;
670
671            let mut response = http
672                .get(&url, Default::default(), true)
673                .await
674                .map_err(|err| anyhow!("error downloading copilot release: {}", err))?;
675            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
676            let archive = Archive::new(decompressed_bytes);
677            archive.unpack(dist_dir).await?;
678
679            remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await;
680        }
681
682        Ok(server_path)
683    }
684
685    match fetch_latest(http).await {
686        ok @ Result::Ok(..) => ok,
687        e @ Err(..) => {
688            e.log_err();
689            // Fetch a cached binary, if it exists
690            (|| async move {
691                let mut last_version_dir = None;
692                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
693                while let Some(entry) = entries.next().await {
694                    let entry = entry?;
695                    if entry.file_type().await?.is_dir() {
696                        last_version_dir = Some(entry.path());
697                    }
698                }
699                let last_version_dir =
700                    last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?;
701                let server_path = last_version_dir.join(SERVER_PATH);
702                if server_path.exists() {
703                    Ok(server_path)
704                } else {
705                    Err(anyhow!(
706                        "missing executable in directory {:?}",
707                        last_version_dir
708                    ))
709                }
710            })()
711            .await
712        }
713    }
714}