github_download.rs

  1use std::{future::Future, path::Path, pin::Pin, task::Poll};
  2
  3use anyhow::{Context, Result};
  4use async_compression::futures::bufread::GzipDecoder;
  5use futures::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, io::BufReader};
  6use sha2::{Digest, Sha256};
  7
  8use crate::{HttpClient, github::AssetKind};
  9
 10#[derive(serde::Deserialize, serde::Serialize, Debug)]
 11pub struct GithubBinaryMetadata {
 12    pub metadata_version: u64,
 13    pub digest: Option<String>,
 14}
 15
 16impl GithubBinaryMetadata {
 17    pub async fn read_from_file(metadata_path: &Path) -> Result<GithubBinaryMetadata> {
 18        let metadata_content = async_fs::read_to_string(metadata_path)
 19            .await
 20            .with_context(|| format!("reading metadata file at {metadata_path:?}"))?;
 21        serde_json::from_str(&metadata_content)
 22            .with_context(|| format!("parsing metadata file at {metadata_path:?}"))
 23    }
 24
 25    pub async fn write_to_file(&self, metadata_path: &Path) -> Result<()> {
 26        let metadata_content = serde_json::to_string(self)
 27            .with_context(|| format!("serializing metadata for {metadata_path:?}"))?;
 28        async_fs::write(metadata_path, metadata_content.as_bytes())
 29            .await
 30            .with_context(|| format!("writing metadata file at {metadata_path:?}"))?;
 31        Ok(())
 32    }
 33}
 34
 35pub async fn download_server_binary(
 36    http_client: &dyn HttpClient,
 37    url: &str,
 38    digest: Option<&str>,
 39    destination_path: &Path,
 40    asset_kind: AssetKind,
 41) -> Result<(), anyhow::Error> {
 42    log::info!("downloading github artifact from {url}");
 43    let mut response = http_client
 44        .get(url, Default::default(), true)
 45        .await
 46        .with_context(|| format!("downloading release from {url}"))?;
 47    let body = response.body_mut();
 48    match digest {
 49        Some(expected_sha_256) => {
 50            let temp_asset_file = tempfile::NamedTempFile::new()
 51                .with_context(|| format!("creating a temporary file for {url}"))?;
 52            let (temp_asset_file, _temp_guard) = temp_asset_file.into_parts();
 53            let mut writer = HashingWriter {
 54                writer: async_fs::File::from(temp_asset_file),
 55                hasher: Sha256::new(),
 56            };
 57            futures::io::copy(&mut BufReader::new(body), &mut writer)
 58                .await
 59                .with_context(|| {
 60                    format!("saving archive contents into the temporary file for {url}",)
 61                })?;
 62            let asset_sha_256 = format!("{:x}", writer.hasher.finalize());
 63
 64            anyhow::ensure!(
 65                asset_sha_256 == expected_sha_256,
 66                "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}",
 67            );
 68            writer
 69                .writer
 70                .seek(std::io::SeekFrom::Start(0))
 71                .await
 72                .with_context(|| format!("seeking temporary file {destination_path:?}",))?;
 73            stream_file_archive(&mut writer.writer, url, destination_path, asset_kind)
 74                .await
 75                .with_context(|| {
 76                    format!("extracting downloaded asset for {url} into {destination_path:?}",)
 77                })?;
 78        }
 79        None => stream_response_archive(body, url, destination_path, asset_kind)
 80            .await
 81            .with_context(|| {
 82                format!("extracting response for asset {url} into {destination_path:?}",)
 83            })?,
 84    }
 85    Ok(())
 86}
 87
 88pub async fn fetch_github_binary_with_digest_check<ValidityCheck, ValidityCheckFuture>(
 89    binary_path: &Path,
 90    metadata_path: &Path,
 91    expected_digest: Option<String>,
 92    url: &str,
 93    asset_kind: AssetKind,
 94    download_destination: &Path,
 95    http_client: &dyn HttpClient,
 96    validity_check: ValidityCheck,
 97) -> Result<()>
 98where
 99    ValidityCheck: FnOnce() -> ValidityCheckFuture,
100    ValidityCheckFuture: Future<Output = Result<()>>,
101{
102    let metadata = GithubBinaryMetadata::read_from_file(metadata_path)
103        .await
104        .ok();
105
106    if let Some(metadata) = metadata {
107        let validity_check_result = validity_check().await;
108
109        if let (Some(actual_digest), Some(expected_digest_ref)) =
110            (&metadata.digest, &expected_digest)
111        {
112            if actual_digest == expected_digest_ref {
113                if validity_check_result.is_ok() {
114                    return Ok(());
115                }
116            } else {
117                log::info!(
118                    "SHA-256 mismatch for {binary_path:?} asset, downloading new asset. Expected: {expected_digest_ref}, Got: {actual_digest}"
119                );
120            }
121        } else if validity_check_result.is_ok() {
122            return Ok(());
123        }
124    }
125
126    download_server_binary(
127        http_client,
128        url,
129        expected_digest.as_deref(),
130        download_destination,
131        asset_kind,
132    )
133    .await?;
134
135    GithubBinaryMetadata::write_to_file(
136        &GithubBinaryMetadata {
137            metadata_version: 1,
138            digest: expected_digest,
139        },
140        metadata_path,
141    )
142    .await?;
143
144    Ok(())
145}
146
147async fn stream_response_archive(
148    response: impl AsyncRead + Unpin,
149    url: &str,
150    destination_path: &Path,
151    asset_kind: AssetKind,
152) -> Result<()> {
153    match asset_kind {
154        AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?,
155        AssetKind::Gz => extract_gz(destination_path, url, response).await?,
156        AssetKind::Zip => {
157            util::archive::extract_zip(destination_path, response).await?;
158        }
159    };
160    Ok(())
161}
162
163async fn stream_file_archive(
164    file_archive: impl AsyncRead + AsyncSeek + Unpin,
165    url: &str,
166    destination_path: &Path,
167    asset_kind: AssetKind,
168) -> Result<()> {
169    match asset_kind {
170        AssetKind::TarGz => extract_tar_gz(destination_path, url, file_archive).await?,
171        AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?,
172        #[cfg(not(windows))]
173        AssetKind::Zip => {
174            util::archive::extract_seekable_zip(destination_path, file_archive).await?;
175        }
176        #[cfg(windows)]
177        AssetKind::Zip => {
178            util::archive::extract_zip(destination_path, file_archive).await?;
179        }
180    };
181    Ok(())
182}
183
184async fn extract_tar_gz(
185    destination_path: &Path,
186    url: &str,
187    from: impl AsyncRead + Unpin,
188) -> Result<(), anyhow::Error> {
189    let decompressed_bytes = GzipDecoder::new(BufReader::new(from));
190    let archive = async_tar::Archive::new(decompressed_bytes);
191    archive
192        .unpack(&destination_path)
193        .await
194        .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
195    Ok(())
196}
197
198async fn extract_gz(
199    destination_path: &Path,
200    url: &str,
201    from: impl AsyncRead + Unpin,
202) -> Result<(), anyhow::Error> {
203    let mut decompressed_bytes = GzipDecoder::new(BufReader::new(from));
204    let mut file = async_fs::File::create(&destination_path)
205        .await
206        .with_context(|| {
207            format!("creating a file {destination_path:?} for a download from {url}")
208        })?;
209    futures::io::copy(&mut decompressed_bytes, &mut file)
210        .await
211        .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
212    Ok(())
213}
214
215struct HashingWriter<W: AsyncWrite + Unpin> {
216    writer: W,
217    hasher: Sha256,
218}
219
220impl<W: AsyncWrite + Unpin> AsyncWrite for HashingWriter<W> {
221    fn poll_write(
222        mut self: Pin<&mut Self>,
223        cx: &mut std::task::Context<'_>,
224        buf: &[u8],
225    ) -> Poll<std::result::Result<usize, std::io::Error>> {
226        match Pin::new(&mut self.writer).poll_write(cx, buf) {
227            Poll::Ready(Ok(n)) => {
228                self.hasher.update(&buf[..n]);
229                Poll::Ready(Ok(n))
230            }
231            other => other,
232        }
233    }
234
235    fn poll_flush(
236        mut self: Pin<&mut Self>,
237        cx: &mut std::task::Context<'_>,
238    ) -> Poll<Result<(), std::io::Error>> {
239        Pin::new(&mut self.writer).poll_flush(cx)
240    }
241
242    fn poll_close(
243        mut self: Pin<&mut Self>,
244        cx: &mut std::task::Context<'_>,
245    ) -> Poll<std::result::Result<(), std::io::Error>> {
246        Pin::new(&mut self.writer).poll_close(cx)
247    }
248}