github_download.rs

  1use std::{
  2    path::{Path, PathBuf},
  3    pin::Pin,
  4    task::Poll,
  5};
  6
  7use anyhow::{Context, Result};
  8use async_compression::futures::bufread::{BzDecoder, GzipDecoder};
  9use futures::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, io::BufReader};
 10use sha2::{Digest, Sha256};
 11
 12use crate::{HttpClient, github::AssetKind};
 13
 14#[derive(serde::Deserialize, serde::Serialize, Debug)]
 15pub struct GithubBinaryMetadata {
 16    pub metadata_version: u64,
 17    pub digest: Option<String>,
 18}
 19
 20impl GithubBinaryMetadata {
 21    pub async fn read_from_file(metadata_path: &Path) -> Result<GithubBinaryMetadata> {
 22        let metadata_content = async_fs::read_to_string(metadata_path)
 23            .await
 24            .with_context(|| format!("reading metadata file at {metadata_path:?}"))?;
 25        serde_json::from_str(&metadata_content)
 26            .with_context(|| format!("parsing metadata file at {metadata_path:?}"))
 27    }
 28
 29    pub async fn write_to_file(&self, metadata_path: &Path) -> Result<()> {
 30        let metadata_content = serde_json::to_string(self)
 31            .with_context(|| format!("serializing metadata for {metadata_path:?}"))?;
 32        async_fs::write(metadata_path, metadata_content.as_bytes())
 33            .await
 34            .with_context(|| format!("writing metadata file at {metadata_path:?}"))?;
 35        Ok(())
 36    }
 37}
 38
 39pub async fn download_server_binary(
 40    http_client: &dyn HttpClient,
 41    url: &str,
 42    digest: Option<&str>,
 43    destination_path: &Path,
 44    asset_kind: AssetKind,
 45) -> Result<(), anyhow::Error> {
 46    log::info!("downloading github artifact from {url}");
 47    let Some(destination_parent) = destination_path.parent() else {
 48        anyhow::bail!("destination path has no parent: {destination_path:?}");
 49    };
 50
 51    let staging_path = staging_path(destination_parent, asset_kind)?;
 52    let mut response = http_client
 53        .get(url, Default::default(), true)
 54        .await
 55        .with_context(|| format!("downloading release from {url}"))?;
 56    let body = response.body_mut();
 57
 58    if let Err(err) = extract_to_staging(body, digest, url, &staging_path, asset_kind).await {
 59        cleanup_staging_path(&staging_path, asset_kind).await;
 60        return Err(err);
 61    }
 62
 63    if let Err(err) = finalize_download(&staging_path, destination_path).await {
 64        cleanup_staging_path(&staging_path, asset_kind).await;
 65        return Err(err);
 66    }
 67
 68    Ok(())
 69}
 70
 71async fn extract_to_staging(
 72    body: impl AsyncRead + Unpin,
 73    digest: Option<&str>,
 74    url: &str,
 75    staging_path: &Path,
 76    asset_kind: AssetKind,
 77) -> Result<()> {
 78    match digest {
 79        Some(expected_sha_256) => {
 80            let temp_asset_file = tempfile::NamedTempFile::new()
 81                .with_context(|| format!("creating a temporary file for {url}"))?;
 82            let (temp_asset_file, _temp_guard) = temp_asset_file.into_parts();
 83            let mut writer = HashingWriter {
 84                writer: async_fs::File::from(temp_asset_file),
 85                hasher: Sha256::new(),
 86            };
 87            futures::io::copy(&mut BufReader::new(body), &mut writer)
 88                .await
 89                .with_context(|| {
 90                    format!("saving archive contents into the temporary file for {url}")
 91                })?;
 92            let asset_sha_256 = format!("{:x}", writer.hasher.finalize());
 93
 94            anyhow::ensure!(
 95                asset_sha_256 == expected_sha_256,
 96                "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}",
 97            );
 98            writer
 99                .writer
100                .seek(std::io::SeekFrom::Start(0))
101                .await
102                .with_context(|| format!("seeking temporary file for {url}"))?;
103            stream_file_archive(&mut writer.writer, url, staging_path, asset_kind)
104                .await
105                .with_context(|| {
106                    format!("extracting downloaded asset for {url} into {staging_path:?}")
107                })?;
108        }
109        None => {
110            stream_response_archive(body, url, staging_path, asset_kind)
111                .await
112                .with_context(|| {
113                    format!("extracting response for asset {url} into {staging_path:?}")
114                })?;
115        }
116    }
117    Ok(())
118}
119
120fn staging_path(parent: &Path, asset_kind: AssetKind) -> Result<PathBuf> {
121    match asset_kind {
122        AssetKind::TarGz | AssetKind::TarBz2 | AssetKind::Zip => {
123            let dir = tempfile::Builder::new()
124                .prefix(".tmp-github-download-")
125                .tempdir_in(parent)
126                .with_context(|| format!("creating staging directory in {parent:?}"))?;
127            Ok(dir.keep())
128        }
129        AssetKind::Gz => {
130            let path = tempfile::Builder::new()
131                .prefix(".tmp-github-download-")
132                .tempfile_in(parent)
133                .with_context(|| format!("creating staging file in {parent:?}"))?
134                .into_temp_path()
135                .keep()
136                .with_context(|| format!("persisting staging file in {parent:?}"))?;
137            Ok(path)
138        }
139    }
140}
141
142async fn cleanup_staging_path(staging_path: &Path, asset_kind: AssetKind) {
143    match asset_kind {
144        AssetKind::TarGz | AssetKind::TarBz2 | AssetKind::Zip => {
145            if let Err(err) = async_fs::remove_dir_all(staging_path).await {
146                log::warn!("failed to remove staging directory {staging_path:?}: {err:?}");
147            }
148        }
149        AssetKind::Gz => {
150            if let Err(err) = async_fs::remove_file(staging_path).await {
151                log::warn!("failed to remove staging file {staging_path:?}: {err:?}");
152            }
153        }
154    }
155}
156
157async fn finalize_download(staging_path: &Path, destination_path: &Path) -> Result<()> {
158    _ = async_fs::remove_dir_all(destination_path).await;
159    async_fs::rename(staging_path, destination_path)
160        .await
161        .with_context(|| format!("renaming {staging_path:?} to {destination_path:?}"))?;
162    Ok(())
163}
164
165async fn stream_response_archive(
166    response: impl AsyncRead + Unpin,
167    url: &str,
168    destination_path: &Path,
169    asset_kind: AssetKind,
170) -> Result<()> {
171    match asset_kind {
172        AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?,
173        AssetKind::TarBz2 => extract_tar_bz2(destination_path, url, response).await?,
174        AssetKind::Gz => extract_gz(destination_path, url, response).await?,
175        AssetKind::Zip => {
176            util::archive::extract_zip(destination_path, response).await?;
177        }
178    };
179    Ok(())
180}
181
182async fn stream_file_archive(
183    file_archive: impl AsyncRead + AsyncSeek + Unpin,
184    url: &str,
185    destination_path: &Path,
186    asset_kind: AssetKind,
187) -> Result<()> {
188    match asset_kind {
189        AssetKind::TarGz => extract_tar_gz(destination_path, url, file_archive).await?,
190        AssetKind::TarBz2 => extract_tar_bz2(destination_path, url, file_archive).await?,
191        AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?,
192        #[cfg(not(windows))]
193        AssetKind::Zip => {
194            util::archive::extract_seekable_zip(destination_path, file_archive).await?;
195        }
196        #[cfg(windows)]
197        AssetKind::Zip => {
198            util::archive::extract_zip(destination_path, file_archive).await?;
199        }
200    };
201    Ok(())
202}
203
204async fn extract_tar_gz(
205    destination_path: &Path,
206    url: &str,
207    from: impl AsyncRead + Unpin,
208) -> Result<(), anyhow::Error> {
209    let decompressed_bytes = GzipDecoder::new(BufReader::new(from));
210    let archive = async_tar::Archive::new(decompressed_bytes);
211    archive
212        .unpack(&destination_path)
213        .await
214        .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
215    Ok(())
216}
217
218async fn extract_tar_bz2(
219    destination_path: &Path,
220    url: &str,
221    from: impl AsyncRead + Unpin,
222) -> Result<(), anyhow::Error> {
223    let decompressed_bytes = BzDecoder::new(BufReader::new(from));
224    let archive = async_tar::Archive::new(decompressed_bytes);
225    archive
226        .unpack(&destination_path)
227        .await
228        .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
229    Ok(())
230}
231
232async fn extract_gz(
233    destination_path: &Path,
234    url: &str,
235    from: impl AsyncRead + Unpin,
236) -> Result<(), anyhow::Error> {
237    let mut decompressed_bytes = GzipDecoder::new(BufReader::new(from));
238    let mut file = async_fs::File::create(&destination_path)
239        .await
240        .with_context(|| {
241            format!("creating a file {destination_path:?} for a download from {url}")
242        })?;
243    futures::io::copy(&mut decompressed_bytes, &mut file)
244        .await
245        .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
246    Ok(())
247}
248
249struct HashingWriter<W: AsyncWrite + Unpin> {
250    writer: W,
251    hasher: Sha256,
252}
253
254impl<W: AsyncWrite + Unpin> AsyncWrite for HashingWriter<W> {
255    fn poll_write(
256        mut self: Pin<&mut Self>,
257        cx: &mut std::task::Context<'_>,
258        buf: &[u8],
259    ) -> Poll<std::result::Result<usize, std::io::Error>> {
260        match Pin::new(&mut self.writer).poll_write(cx, buf) {
261            Poll::Ready(Ok(n)) => {
262                self.hasher.update(&buf[..n]);
263                Poll::Ready(Ok(n))
264            }
265            other => other,
266        }
267    }
268
269    fn poll_flush(
270        mut self: Pin<&mut Self>,
271        cx: &mut std::task::Context<'_>,
272    ) -> Poll<Result<(), std::io::Error>> {
273        Pin::new(&mut self.writer).poll_flush(cx)
274    }
275
276    fn poll_close(
277        mut self: Pin<&mut Self>,
278        cx: &mut std::task::Context<'_>,
279    ) -> Poll<std::result::Result<(), std::io::Error>> {
280        Pin::new(&mut self.writer).poll_close(cx)
281    }
282}