diff --git a/crates/http_client/src/github_download.rs b/crates/http_client/src/github_download.rs index 02dee08b215e547d632caaf5f94b0872aa6aa20d..642bbf11c11ce8816a1506c3c4989dce434552d8 100644 --- a/crates/http_client/src/github_download.rs +++ b/crates/http_client/src/github_download.rs @@ -1,4 +1,8 @@ -use std::{path::Path, pin::Pin, task::Poll}; +use std::{ + path::{Path, PathBuf}, + pin::Pin, + task::Poll, +}; use anyhow::{Context, Result}; use async_compression::futures::bufread::GzipDecoder; @@ -40,11 +44,37 @@ pub async fn download_server_binary( asset_kind: AssetKind, ) -> Result<(), anyhow::Error> { log::info!("downloading github artifact from {url}"); + let Some(destination_parent) = destination_path.parent() else { + anyhow::bail!("destination path has no parent: {destination_path:?}"); + }; + + let staging_path = staging_path(destination_parent, asset_kind)?; let mut response = http_client .get(url, Default::default(), true) .await .with_context(|| format!("downloading release from {url}"))?; let body = response.body_mut(); + + if let Err(err) = extract_to_staging(body, digest, url, &staging_path, asset_kind).await { + cleanup_staging_path(&staging_path, asset_kind).await; + return Err(err); + } + + if let Err(err) = finalize_download(&staging_path, destination_path).await { + cleanup_staging_path(&staging_path, asset_kind).await; + return Err(err); + } + + Ok(()) +} + +async fn extract_to_staging( + body: impl AsyncRead + Unpin, + digest: Option<&str>, + url: &str, + staging_path: &Path, + asset_kind: AssetKind, +) -> Result<()> { match digest { Some(expected_sha_256) => { let temp_asset_file = tempfile::NamedTempFile::new() @@ -57,7 +87,7 @@ pub async fn download_server_binary( futures::io::copy(&mut BufReader::new(body), &mut writer) .await .with_context(|| { - format!("saving archive contents into the temporary file for {url}",) + format!("saving archive contents into the temporary file for {url}") })?; let asset_sha_256 = format!("{:x}", writer.hasher.finalize()); @@ -69,22 +99,68 @@ pub async fn download_server_binary( .writer .seek(std::io::SeekFrom::Start(0)) .await - .with_context(|| format!("seeking temporary file {destination_path:?}",))?; - stream_file_archive(&mut writer.writer, url, destination_path, asset_kind) + .with_context(|| format!("seeking temporary file for {url}"))?; + stream_file_archive(&mut writer.writer, url, staging_path, asset_kind) .await .with_context(|| { - format!("extracting downloaded asset for {url} into {destination_path:?}",) + format!("extracting downloaded asset for {url} into {staging_path:?}") + })?; + } + None => { + stream_response_archive(body, url, staging_path, asset_kind) + .await + .with_context(|| { + format!("extracting response for asset {url} into {staging_path:?}") })?; } - None => stream_response_archive(body, url, destination_path, asset_kind) - .await - .with_context(|| { - format!("extracting response for asset {url} into {destination_path:?}",) - })?, } Ok(()) } +fn staging_path(parent: &Path, asset_kind: AssetKind) -> Result { + match asset_kind { + AssetKind::TarGz | AssetKind::Zip => { + let dir = tempfile::Builder::new() + .prefix(".tmp-github-download-") + .tempdir_in(parent) + .with_context(|| format!("creating staging directory in {parent:?}"))?; + Ok(dir.keep()) + } + AssetKind::Gz => { + let path = tempfile::Builder::new() + .prefix(".tmp-github-download-") + .tempfile_in(parent) + .with_context(|| format!("creating staging file in {parent:?}"))? + .into_temp_path() + .keep() + .with_context(|| format!("persisting staging file in {parent:?}"))?; + Ok(path) + } + } +} + +async fn cleanup_staging_path(staging_path: &Path, asset_kind: AssetKind) { + match asset_kind { + AssetKind::TarGz | AssetKind::Zip => { + if let Err(err) = async_fs::remove_dir_all(staging_path).await { + log::warn!("failed to remove staging directory {staging_path:?}: {err:?}"); + } + } + AssetKind::Gz => { + if let Err(err) = async_fs::remove_file(staging_path).await { + log::warn!("failed to remove staging file {staging_path:?}: {err:?}"); + } + } + } +} + +async fn finalize_download(staging_path: &Path, destination_path: &Path) -> Result<()> { + async_fs::rename(staging_path, destination_path) + .await + .with_context(|| format!("renaming {staging_path:?} to {destination_path:?}"))?; + Ok(()) +} + async fn stream_response_archive( response: impl AsyncRead + Unpin, url: &str,