github_download.rs

  1use std::{path::Path, pin::Pin, task::Poll};
  2
  3use anyhow::{Context, Result};
  4use async_compression::futures::bufread::GzipDecoder;
  5use futures::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, io::BufReader};
  6use http_client::github::AssetKind;
  7use language::LspAdapterDelegate;
  8use sha2::{Digest, Sha256};
  9
 10#[derive(serde::Deserialize, serde::Serialize, Debug)]
 11pub(crate) struct GithubBinaryMetadata {
 12    pub(crate) metadata_version: u64,
 13    pub(crate) digest: Option<String>,
 14}
 15
 16impl GithubBinaryMetadata {
 17    pub(crate) async fn read_from_file(metadata_path: &Path) -> Result<GithubBinaryMetadata> {
 18        let metadata_content = async_fs::read_to_string(metadata_path)
 19            .await
 20            .with_context(|| format!("reading metadata file at {metadata_path:?}"))?;
 21        let metadata: GithubBinaryMetadata = serde_json::from_str(&metadata_content)
 22            .with_context(|| format!("parsing metadata file at {metadata_path:?}"))?;
 23        Ok(metadata)
 24    }
 25
 26    pub(crate) async fn write_to_file(&self, metadata_path: &Path) -> Result<()> {
 27        let metadata_content = serde_json::to_string(self)
 28            .with_context(|| format!("serializing metadata for {metadata_path:?}"))?;
 29        async_fs::write(metadata_path, metadata_content.as_bytes())
 30            .await
 31            .with_context(|| format!("writing metadata file at {metadata_path:?}"))?;
 32        Ok(())
 33    }
 34}
 35
 36pub(crate) async fn download_server_binary(
 37    delegate: &dyn LspAdapterDelegate,
 38    url: &str,
 39    digest: Option<&str>,
 40    destination_path: &Path,
 41    asset_kind: AssetKind,
 42) -> Result<(), anyhow::Error> {
 43    log::info!("downloading github artifact from {url}");
 44    let mut response = delegate
 45        .http_client()
 46        .get(url, Default::default(), true)
 47        .await
 48        .with_context(|| format!("downloading release from {url}"))?;
 49    let body = response.body_mut();
 50    match digest {
 51        Some(expected_sha_256) => {
 52            let temp_asset_file = tempfile::NamedTempFile::new()
 53                .with_context(|| format!("creating a temporary file for {url}"))?;
 54            let (temp_asset_file, _temp_guard) = temp_asset_file.into_parts();
 55            let mut writer = HashingWriter {
 56                writer: async_fs::File::from(temp_asset_file),
 57                hasher: Sha256::new(),
 58            };
 59            futures::io::copy(&mut BufReader::new(body), &mut writer)
 60                .await
 61                .with_context(|| {
 62                    format!("saving archive contents into the temporary file for {url}",)
 63                })?;
 64            let asset_sha_256 = format!("{:x}", writer.hasher.finalize());
 65
 66            // Strip "sha256:" prefix for comparison
 67            let expected_sha_256 = expected_sha_256
 68                .strip_prefix("sha256:")
 69                .unwrap_or(expected_sha_256);
 70
 71            anyhow::ensure!(
 72                asset_sha_256 == expected_sha_256,
 73                "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}",
 74            );
 75            writer
 76                .writer
 77                .seek(std::io::SeekFrom::Start(0))
 78                .await
 79                .with_context(|| format!("seeking temporary file {destination_path:?}",))?;
 80            stream_file_archive(&mut writer.writer, url, destination_path, asset_kind)
 81                .await
 82                .with_context(|| {
 83                    format!("extracting downloaded asset for {url} into {destination_path:?}",)
 84                })?;
 85        }
 86        None => stream_response_archive(body, url, destination_path, asset_kind)
 87            .await
 88            .with_context(|| {
 89                format!("extracting response for asset {url} into {destination_path:?}",)
 90            })?,
 91    }
 92    Ok(())
 93}
 94
 95async fn stream_response_archive(
 96    response: impl AsyncRead + Unpin,
 97    url: &str,
 98    destination_path: &Path,
 99    asset_kind: AssetKind,
100) -> Result<()> {
101    match asset_kind {
102        AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?,
103        AssetKind::Gz => extract_gz(destination_path, url, response).await?,
104        AssetKind::Zip => {
105            util::archive::extract_zip(&destination_path, response).await?;
106        }
107    };
108    Ok(())
109}
110
111async fn stream_file_archive(
112    file_archive: impl AsyncRead + AsyncSeek + Unpin,
113    url: &str,
114    destination_path: &Path,
115    asset_kind: AssetKind,
116) -> Result<()> {
117    match asset_kind {
118        AssetKind::TarGz => extract_tar_gz(destination_path, url, file_archive).await?,
119        AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?,
120        #[cfg(not(windows))]
121        AssetKind::Zip => {
122            util::archive::extract_seekable_zip(&destination_path, file_archive).await?;
123        }
124        #[cfg(windows)]
125        AssetKind::Zip => {
126            util::archive::extract_zip(&destination_path, file_archive).await?;
127        }
128    };
129    Ok(())
130}
131
132async fn extract_tar_gz(
133    destination_path: &Path,
134    url: &str,
135    from: impl AsyncRead + Unpin,
136) -> Result<(), anyhow::Error> {
137    let decompressed_bytes = GzipDecoder::new(BufReader::new(from));
138    let archive = async_tar::Archive::new(decompressed_bytes);
139    archive
140        .unpack(&destination_path)
141        .await
142        .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
143    Ok(())
144}
145
146async fn extract_gz(
147    destination_path: &Path,
148    url: &str,
149    from: impl AsyncRead + Unpin,
150) -> Result<(), anyhow::Error> {
151    let mut decompressed_bytes = GzipDecoder::new(BufReader::new(from));
152    let mut file = smol::fs::File::create(&destination_path)
153        .await
154        .with_context(|| {
155            format!("creating a file {destination_path:?} for a download from {url}")
156        })?;
157    futures::io::copy(&mut decompressed_bytes, &mut file)
158        .await
159        .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
160    Ok(())
161}
162
163struct HashingWriter<W: AsyncWrite + Unpin> {
164    writer: W,
165    hasher: Sha256,
166}
167
168impl<W: AsyncWrite + Unpin> AsyncWrite for HashingWriter<W> {
169    fn poll_write(
170        mut self: Pin<&mut Self>,
171        cx: &mut std::task::Context<'_>,
172        buf: &[u8],
173    ) -> Poll<std::result::Result<usize, std::io::Error>> {
174        match Pin::new(&mut self.writer).poll_write(cx, buf) {
175            Poll::Ready(Ok(n)) => {
176                self.hasher.update(&buf[..n]);
177                Poll::Ready(Ok(n))
178            }
179            other => other,
180        }
181    }
182
183    fn poll_flush(
184        mut self: Pin<&mut Self>,
185        cx: &mut std::task::Context<'_>,
186    ) -> Poll<Result<(), std::io::Error>> {
187        Pin::new(&mut self.writer).poll_flush(cx)
188    }
189
190    fn poll_close(
191        mut self: Pin<&mut Self>,
192        cx: &mut std::task::Context<'_>,
193    ) -> Poll<std::result::Result<(), std::io::Error>> {
194        Pin::new(&mut self.writer).poll_close(cx)
195    }
196}