1use std::{
2 path::{Path, PathBuf},
3 pin::Pin,
4 task::Poll,
5};
6
7use anyhow::{Context, Result};
8use async_compression::futures::bufread::{BzDecoder, GzipDecoder};
9use futures::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, io::BufReader};
10use sha2::{Digest, Sha256};
11
12use crate::{HttpClient, github::AssetKind};
13
/// On-disk metadata persisted alongside a binary downloaded from a GitHub
/// release, used to decide whether a cached download is still valid.
#[derive(serde::Deserialize, serde::Serialize, Debug)]
pub struct GithubBinaryMetadata {
    // Version of this metadata schema itself, so readers can detect
    // incompatible older files.
    pub metadata_version: u64,
    // Hex SHA-256 digest of the downloaded asset, when one was recorded.
    pub digest: Option<String>,
}
19
20impl GithubBinaryMetadata {
21 pub async fn read_from_file(metadata_path: &Path) -> Result<GithubBinaryMetadata> {
22 let metadata_content = async_fs::read_to_string(metadata_path)
23 .await
24 .with_context(|| format!("reading metadata file at {metadata_path:?}"))?;
25 serde_json::from_str(&metadata_content)
26 .with_context(|| format!("parsing metadata file at {metadata_path:?}"))
27 }
28
29 pub async fn write_to_file(&self, metadata_path: &Path) -> Result<()> {
30 let metadata_content = serde_json::to_string(self)
31 .with_context(|| format!("serializing metadata for {metadata_path:?}"))?;
32 async_fs::write(metadata_path, metadata_content.as_bytes())
33 .await
34 .with_context(|| format!("writing metadata file at {metadata_path:?}"))?;
35 Ok(())
36 }
37}
38
39pub async fn download_server_binary(
40 http_client: &dyn HttpClient,
41 url: &str,
42 digest: Option<&str>,
43 destination_path: &Path,
44 asset_kind: AssetKind,
45) -> Result<(), anyhow::Error> {
46 log::info!("downloading github artifact from {url}");
47 let Some(destination_parent) = destination_path.parent() else {
48 anyhow::bail!("destination path has no parent: {destination_path:?}");
49 };
50
51 let staging_path = staging_path(destination_parent, asset_kind)?;
52 let mut response = http_client
53 .get(url, Default::default(), true)
54 .await
55 .with_context(|| format!("downloading release from {url}"))?;
56 let body = response.body_mut();
57
58 if let Err(err) = extract_to_staging(body, digest, url, &staging_path, asset_kind).await {
59 cleanup_staging_path(&staging_path, asset_kind).await;
60 return Err(err);
61 }
62
63 if let Err(err) = finalize_download(&staging_path, destination_path).await {
64 cleanup_staging_path(&staging_path, asset_kind).await;
65 return Err(err);
66 }
67
68 Ok(())
69}
70
/// Streams `body` into `staging_path`, verifying the SHA-256 digest first
/// when one is provided.
///
/// With a digest, the response is spooled to a temporary file through a
/// `HashingWriter` so the hash covers exactly the downloaded bytes; only
/// after the digest matches is the archive extracted (rewinding the temp
/// file to read it back). Without a digest, the response is extracted
/// directly as it streams in.
async fn extract_to_staging(
    body: impl AsyncRead + Unpin,
    digest: Option<&str>,
    url: &str,
    staging_path: &Path,
    asset_kind: AssetKind,
) -> Result<()> {
    match digest {
        Some(expected_sha_256) => {
            let temp_asset_file = tempfile::NamedTempFile::new()
                .with_context(|| format!("creating a temporary file for {url}"))?;
            // Split the handle from the path guard: the guard deletes the
            // file when dropped at the end of this scope, while the raw
            // handle is wrapped for async I/O below.
            let (temp_asset_file, _temp_guard) = temp_asset_file.into_parts();
            let mut writer = HashingWriter {
                writer: async_fs::File::from(temp_asset_file),
                hasher: Sha256::new(),
            };
            futures::io::copy(&mut BufReader::new(body), &mut writer)
                .await
                .with_context(|| {
                    format!("saving archive contents into the temporary file for {url}")
                })?;
            let asset_sha_256 = format!("{:x}", writer.hasher.finalize());

            // Bail out before extracting anything if the digest mismatches.
            anyhow::ensure!(
                asset_sha_256 == expected_sha_256,
                "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}",
            );
            // The temp file's cursor now sits at the end of the written
            // data; rewind so extraction reads from the beginning.
            writer
                .writer
                .seek(std::io::SeekFrom::Start(0))
                .await
                .with_context(|| format!("seeking temporary file for {url}"))?;
            stream_file_archive(&mut writer.writer, url, staging_path, asset_kind)
                .await
                .with_context(|| {
                    format!("extracting downloaded asset for {url} into {staging_path:?}")
                })?;
        }
        None => {
            // No digest to verify, so extract straight from the network
            // stream without spooling to disk first.
            stream_response_archive(body, url, staging_path, asset_kind)
                .await
                .with_context(|| {
                    format!("extracting response for asset {url} into {staging_path:?}")
                })?;
        }
    }
    Ok(())
}
119
120fn staging_path(parent: &Path, asset_kind: AssetKind) -> Result<PathBuf> {
121 match asset_kind {
122 AssetKind::TarGz | AssetKind::TarBz2 | AssetKind::Zip => {
123 let dir = tempfile::Builder::new()
124 .prefix(".tmp-github-download-")
125 .tempdir_in(parent)
126 .with_context(|| format!("creating staging directory in {parent:?}"))?;
127 Ok(dir.keep())
128 }
129 AssetKind::Gz => {
130 let path = tempfile::Builder::new()
131 .prefix(".tmp-github-download-")
132 .tempfile_in(parent)
133 .with_context(|| format!("creating staging file in {parent:?}"))?
134 .into_temp_path()
135 .keep()
136 .with_context(|| format!("persisting staging file in {parent:?}"))?;
137 Ok(path)
138 }
139 }
140}
141
142async fn cleanup_staging_path(staging_path: &Path, asset_kind: AssetKind) {
143 match asset_kind {
144 AssetKind::TarGz | AssetKind::TarBz2 | AssetKind::Zip => {
145 if let Err(err) = async_fs::remove_dir_all(staging_path).await {
146 log::warn!("failed to remove staging directory {staging_path:?}: {err:?}");
147 }
148 }
149 AssetKind::Gz => {
150 if let Err(err) = async_fs::remove_file(staging_path).await {
151 log::warn!("failed to remove staging file {staging_path:?}: {err:?}");
152 }
153 }
154 }
155}
156
157async fn finalize_download(staging_path: &Path, destination_path: &Path) -> Result<()> {
158 _ = async_fs::remove_dir_all(destination_path).await;
159 async_fs::rename(staging_path, destination_path)
160 .await
161 .with_context(|| format!("renaming {staging_path:?} to {destination_path:?}"))?;
162 Ok(())
163}
164
165async fn stream_response_archive(
166 response: impl AsyncRead + Unpin,
167 url: &str,
168 destination_path: &Path,
169 asset_kind: AssetKind,
170) -> Result<()> {
171 match asset_kind {
172 AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?,
173 AssetKind::TarBz2 => extract_tar_bz2(destination_path, url, response).await?,
174 AssetKind::Gz => extract_gz(destination_path, url, response).await?,
175 AssetKind::Zip => {
176 util::archive::extract_zip(destination_path, response).await?;
177 }
178 };
179 Ok(())
180}
181
182async fn stream_file_archive(
183 file_archive: impl AsyncRead + AsyncSeek + Unpin,
184 url: &str,
185 destination_path: &Path,
186 asset_kind: AssetKind,
187) -> Result<()> {
188 match asset_kind {
189 AssetKind::TarGz => extract_tar_gz(destination_path, url, file_archive).await?,
190 AssetKind::TarBz2 => extract_tar_bz2(destination_path, url, file_archive).await?,
191 AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?,
192 #[cfg(not(windows))]
193 AssetKind::Zip => {
194 util::archive::extract_seekable_zip(destination_path, file_archive).await?;
195 }
196 #[cfg(windows)]
197 AssetKind::Zip => {
198 util::archive::extract_zip(destination_path, file_archive).await?;
199 }
200 };
201 Ok(())
202}
203
204async fn extract_tar_gz(
205 destination_path: &Path,
206 url: &str,
207 from: impl AsyncRead + Unpin,
208) -> Result<(), anyhow::Error> {
209 let decompressed_bytes = GzipDecoder::new(BufReader::new(from));
210 unpack_tar_archive(destination_path, url, decompressed_bytes).await?;
211 Ok(())
212}
213
214async fn extract_tar_bz2(
215 destination_path: &Path,
216 url: &str,
217 from: impl AsyncRead + Unpin,
218) -> Result<(), anyhow::Error> {
219 let decompressed_bytes = BzDecoder::new(BufReader::new(from));
220 unpack_tar_archive(destination_path, url, decompressed_bytes).await?;
221 Ok(())
222}
223
224async fn unpack_tar_archive(
225 destination_path: &Path,
226 url: &str,
227 archive_bytes: impl AsyncRead + Unpin,
228) -> Result<(), anyhow::Error> {
229 // We don't need to set the modified time. It's irrelevant to downloaded
230 // archive verification, and some filesystems return errors when asked to
231 // apply it after extraction.
232 let archive = async_tar::ArchiveBuilder::new(archive_bytes)
233 .set_preserve_mtime(false)
234 .build();
235 archive
236 .unpack(&destination_path)
237 .await
238 .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
239 Ok(())
240}
241
242async fn extract_gz(
243 destination_path: &Path,
244 url: &str,
245 from: impl AsyncRead + Unpin,
246) -> Result<(), anyhow::Error> {
247 let mut decompressed_bytes = GzipDecoder::new(BufReader::new(from));
248 let mut file = async_fs::File::create(&destination_path)
249 .await
250 .with_context(|| {
251 format!("creating a file {destination_path:?} for a download from {url}")
252 })?;
253 futures::io::copy(&mut decompressed_bytes, &mut file)
254 .await
255 .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
256 Ok(())
257}
258
/// An `AsyncWrite` adapter that forwards bytes to the inner `writer` while
/// feeding the same bytes into a SHA-256 hasher, so a digest of exactly what
/// was written can be produced afterwards.
struct HashingWriter<W: AsyncWrite + Unpin> {
    writer: W,
    hasher: Sha256,
}
263
264impl<W: AsyncWrite + Unpin> AsyncWrite for HashingWriter<W> {
265 fn poll_write(
266 mut self: Pin<&mut Self>,
267 cx: &mut std::task::Context<'_>,
268 buf: &[u8],
269 ) -> Poll<std::result::Result<usize, std::io::Error>> {
270 match Pin::new(&mut self.writer).poll_write(cx, buf) {
271 Poll::Ready(Ok(n)) => {
272 self.hasher.update(&buf[..n]);
273 Poll::Ready(Ok(n))
274 }
275 other => other,
276 }
277 }
278
279 fn poll_flush(
280 mut self: Pin<&mut Self>,
281 cx: &mut std::task::Context<'_>,
282 ) -> Poll<Result<(), std::io::Error>> {
283 Pin::new(&mut self.writer).poll_flush(cx)
284 }
285
286 fn poll_close(
287 mut self: Pin<&mut Self>,
288 cx: &mut std::task::Context<'_>,
289 ) -> Poll<std::result::Result<(), std::io::Error>> {
290 Pin::new(&mut self.writer).poll_close(cx)
291 }
292}