copilot.rs

  1mod request;
  2
  3use anyhow::{anyhow, Result};
  4use async_compression::futures::bufread::GzipDecoder;
  5use client::Client;
  6use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext};
  7use lsp::LanguageServer;
  8use smol::{fs, io::BufReader, stream::StreamExt};
  9use std::{
 10    env::consts,
 11    path::{Path, PathBuf},
 12    sync::Arc,
 13};
 14use util::{
 15    fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
 16};
 17
 18actions!(copilot, [SignIn]);
 19
 20pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
 21    let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
 22    cx.set_global(copilot);
 23}
 24
 25enum CopilotServer {
 26    Downloading,
 27    Error(String),
 28    Started {
 29        server: Arc<LanguageServer>,
 30        status: SignInStatus,
 31    },
 32}
 33
 34enum SignInStatus {
 35    Authorized,
 36    Unauthorized,
 37    SignedOut,
 38}
 39
 40struct Copilot {
 41    server: CopilotServer,
 42}
 43
 44impl Entity for Copilot {
 45    type Event = ();
 46}
 47
 48impl Copilot {
 49    fn global(cx: &AppContext) -> Option<ModelHandle<Self>> {
 50        if cx.has_global::<ModelHandle<Self>>() {
 51            Some(cx.global::<ModelHandle<Self>>().clone())
 52        } else {
 53            None
 54        }
 55    }
 56
 57    fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
 58        let copilot = Self {
 59            server: CopilotServer::Downloading,
 60        };
 61        cx.spawn(|this, mut cx| async move {
 62            let start_language_server = async {
 63                let server_path = get_lsp_binary(http).await?;
 64                let server =
 65                    LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
 66                let server = server.initialize(Default::default()).await?;
 67                let status = server
 68                    .request::<request::CheckStatus>(request::CheckStatusParams {
 69                        local_checks_only: false,
 70                    })
 71                    .await?;
 72                let status = match status.status.as_str() {
 73                    "OK" | "MaybeOk" => SignInStatus::Authorized,
 74                    "NotAuthorized" => SignInStatus::Unauthorized,
 75                    _ => SignInStatus::SignedOut,
 76                };
 77                anyhow::Ok((server, status))
 78            };
 79
 80            let server = start_language_server.await;
 81            this.update(&mut cx, |this, cx| {
 82                cx.notify();
 83                match server {
 84                    Ok((server, status)) => {
 85                        this.server = CopilotServer::Started { server, status };
 86                        Ok(())
 87                    }
 88                    Err(error) => {
 89                        this.server = CopilotServer::Error(error.to_string());
 90                        Err(error)
 91                    }
 92                }
 93            })
 94        })
 95        .detach_and_log_err(cx);
 96        copilot
 97    }
 98}
 99
100async fn get_lsp_binary(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
101    ///Check for the latest copilot language server and download it if we haven't already
102    async fn fetch_latest(http: Arc<dyn HttpClient>) -> anyhow::Result<PathBuf> {
103        let release = latest_github_release("zed-industries/copilot", http.clone()).await?;
104        let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH);
105        let asset = release
106            .assets
107            .iter()
108            .find(|asset| asset.name == asset_name)
109            .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?;
110
111        fs::create_dir_all(&*paths::COPILOT_DIR).await?;
112        let destination_path =
113            paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH));
114
115        if fs::metadata(&destination_path).await.is_err() {
116            let mut response = http
117                .get(&asset.browser_download_url, Default::default(), true)
118                .await
119                .map_err(|err| anyhow!("error downloading release: {}", err))?;
120            let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
121            let mut file = fs::File::create(&destination_path).await?;
122            futures::io::copy(decompressed_bytes, &mut file).await?;
123            fs::set_permissions(
124                &destination_path,
125                <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755),
126            )
127            .await?;
128
129            remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await;
130        }
131
132        Ok(destination_path)
133    }
134
135    match fetch_latest(http).await {
136        ok @ Result::Ok(..) => ok,
137        e @ Err(..) => {
138            e.log_err();
139            // Fetch a cached binary, if it exists
140            (|| async move {
141                let mut last = None;
142                let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?;
143                while let Some(entry) = entries.next().await {
144                    last = Some(entry?.path());
145                }
146                last.ok_or_else(|| anyhow!("no cached binary"))
147            })()
148            .await
149        }
150    }
151}