github_download.rs

  1use std::{path::Path, pin::Pin, task::Poll};
  2
  3use anyhow::{Context, Result};
  4use async_compression::futures::bufread::GzipDecoder;
  5use futures::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, io::BufReader};
  6use http_client::github::AssetKind;
  7use language::LspAdapterDelegate;
  8use sha2::{Digest, Sha256};
  9
 10#[derive(serde::Deserialize, serde::Serialize, Debug)]
 11pub(crate) struct GithubBinaryMetadata {
 12    pub(crate) metadata_version: u64,
 13    pub(crate) digest: Option<String>,
 14}
 15
 16impl GithubBinaryMetadata {
 17    pub(crate) async fn read_from_file(metadata_path: &Path) -> Result<GithubBinaryMetadata> {
 18        let metadata_content = async_fs::read_to_string(metadata_path)
 19            .await
 20            .with_context(|| format!("reading metadata file at {metadata_path:?}"))?;
 21        serde_json::from_str(&metadata_content)
 22            .with_context(|| format!("parsing metadata file at {metadata_path:?}"))
 23    }
 24
 25    pub(crate) async fn write_to_file(&self, metadata_path: &Path) -> Result<()> {
 26        let metadata_content = serde_json::to_string(self)
 27            .with_context(|| format!("serializing metadata for {metadata_path:?}"))?;
 28        async_fs::write(metadata_path, metadata_content.as_bytes())
 29            .await
 30            .with_context(|| format!("writing metadata file at {metadata_path:?}"))?;
 31        Ok(())
 32    }
 33}
 34
 35pub(crate) async fn download_server_binary(
 36    delegate: &dyn LspAdapterDelegate,
 37    url: &str,
 38    digest: Option<&str>,
 39    destination_path: &Path,
 40    asset_kind: AssetKind,
 41) -> Result<(), anyhow::Error> {
 42    log::info!("downloading github artifact from {url}");
 43    let mut response = delegate
 44        .http_client()
 45        .get(url, Default::default(), true)
 46        .await
 47        .with_context(|| format!("downloading release from {url}"))?;
 48    let body = response.body_mut();
 49    match digest {
 50        Some(expected_sha_256) => {
 51            let temp_asset_file = tempfile::NamedTempFile::new()
 52                .with_context(|| format!("creating a temporary file for {url}"))?;
 53            let (temp_asset_file, _temp_guard) = temp_asset_file.into_parts();
 54            let mut writer = HashingWriter {
 55                writer: async_fs::File::from(temp_asset_file),
 56                hasher: Sha256::new(),
 57            };
 58            futures::io::copy(&mut BufReader::new(body), &mut writer)
 59                .await
 60                .with_context(|| {
 61                    format!("saving archive contents into the temporary file for {url}",)
 62                })?;
 63            let asset_sha_256 = format!("{:x}", writer.hasher.finalize());
 64
 65            anyhow::ensure!(
 66                asset_sha_256 == expected_sha_256,
 67                "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}",
 68            );
 69            writer
 70                .writer
 71                .seek(std::io::SeekFrom::Start(0))
 72                .await
 73                .with_context(|| format!("seeking temporary file {destination_path:?}",))?;
 74            stream_file_archive(&mut writer.writer, url, destination_path, asset_kind)
 75                .await
 76                .with_context(|| {
 77                    format!("extracting downloaded asset for {url} into {destination_path:?}",)
 78                })?;
 79        }
 80        None => stream_response_archive(body, url, destination_path, asset_kind)
 81            .await
 82            .with_context(|| {
 83                format!("extracting response for asset {url} into {destination_path:?}",)
 84            })?,
 85    }
 86    Ok(())
 87}
 88
 89async fn stream_response_archive(
 90    response: impl AsyncRead + Unpin,
 91    url: &str,
 92    destination_path: &Path,
 93    asset_kind: AssetKind,
 94) -> Result<()> {
 95    match asset_kind {
 96        AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?,
 97        AssetKind::Gz => extract_gz(destination_path, url, response).await?,
 98        AssetKind::Zip => {
 99            util::archive::extract_zip(destination_path, response).await?;
100        }
101    };
102    Ok(())
103}
104
105async fn stream_file_archive(
106    file_archive: impl AsyncRead + AsyncSeek + Unpin,
107    url: &str,
108    destination_path: &Path,
109    asset_kind: AssetKind,
110) -> Result<()> {
111    match asset_kind {
112        AssetKind::TarGz => extract_tar_gz(destination_path, url, file_archive).await?,
113        AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?,
114        #[cfg(not(windows))]
115        AssetKind::Zip => {
116            util::archive::extract_seekable_zip(destination_path, file_archive).await?;
117        }
118        #[cfg(windows)]
119        AssetKind::Zip => {
120            util::archive::extract_zip(destination_path, file_archive).await?;
121        }
122    };
123    Ok(())
124}
125
126async fn extract_tar_gz(
127    destination_path: &Path,
128    url: &str,
129    from: impl AsyncRead + Unpin,
130) -> Result<(), anyhow::Error> {
131    let decompressed_bytes = GzipDecoder::new(BufReader::new(from));
132    let archive = async_tar::Archive::new(decompressed_bytes);
133    archive
134        .unpack(&destination_path)
135        .await
136        .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
137    Ok(())
138}
139
140async fn extract_gz(
141    destination_path: &Path,
142    url: &str,
143    from: impl AsyncRead + Unpin,
144) -> Result<(), anyhow::Error> {
145    let mut decompressed_bytes = GzipDecoder::new(BufReader::new(from));
146    let mut file = smol::fs::File::create(&destination_path)
147        .await
148        .with_context(|| {
149            format!("creating a file {destination_path:?} for a download from {url}")
150        })?;
151    futures::io::copy(&mut decompressed_bytes, &mut file)
152        .await
153        .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
154    Ok(())
155}
156
157struct HashingWriter<W: AsyncWrite + Unpin> {
158    writer: W,
159    hasher: Sha256,
160}
161
162impl<W: AsyncWrite + Unpin> AsyncWrite for HashingWriter<W> {
163    fn poll_write(
164        mut self: Pin<&mut Self>,
165        cx: &mut std::task::Context<'_>,
166        buf: &[u8],
167    ) -> Poll<std::result::Result<usize, std::io::Error>> {
168        match Pin::new(&mut self.writer).poll_write(cx, buf) {
169            Poll::Ready(Ok(n)) => {
170                self.hasher.update(&buf[..n]);
171                Poll::Ready(Ok(n))
172            }
173            other => other,
174        }
175    }
176
177    fn poll_flush(
178        mut self: Pin<&mut Self>,
179        cx: &mut std::task::Context<'_>,
180    ) -> Poll<Result<(), std::io::Error>> {
181        Pin::new(&mut self.writer).poll_flush(cx)
182    }
183
184    fn poll_close(
185        mut self: Pin<&mut Self>,
186        cx: &mut std::task::Context<'_>,
187    ) -> Poll<std::result::Result<(), std::io::Error>> {
188        Pin::new(&mut self.writer).poll_close(cx)
189    }
190}