1use std::{
2 path::{Path, PathBuf},
3 pin::Pin,
4 task::Poll,
5};
6
7use anyhow::{Context, Result};
8use async_compression::futures::bufread::{BzDecoder, GzipDecoder};
9use futures::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, io::BufReader};
10use sha2::{Digest, Sha256};
11
12use crate::{HttpClient, github::AssetKind};
13
/// On-disk metadata persisted next to a downloaded GitHub binary, used to
/// decide whether a cached download is still valid.
#[derive(serde::Deserialize, serde::Serialize, Debug)]
pub struct GithubBinaryMetadata {
    // Version of this metadata schema, so future readers can detect stale files.
    pub metadata_version: u64,
    // SHA-256 of the downloaded asset, when the release provided one.
    pub digest: Option<String>,
}
19
20impl GithubBinaryMetadata {
21 pub async fn read_from_file(metadata_path: &Path) -> Result<GithubBinaryMetadata> {
22 let metadata_content = async_fs::read_to_string(metadata_path)
23 .await
24 .with_context(|| format!("reading metadata file at {metadata_path:?}"))?;
25 serde_json::from_str(&metadata_content)
26 .with_context(|| format!("parsing metadata file at {metadata_path:?}"))
27 }
28
29 pub async fn write_to_file(&self, metadata_path: &Path) -> Result<()> {
30 let metadata_content = serde_json::to_string(self)
31 .with_context(|| format!("serializing metadata for {metadata_path:?}"))?;
32 async_fs::write(metadata_path, metadata_content.as_bytes())
33 .await
34 .with_context(|| format!("writing metadata file at {metadata_path:?}"))?;
35 Ok(())
36 }
37}
38
39pub async fn download_server_binary(
40 http_client: &dyn HttpClient,
41 url: &str,
42 digest: Option<&str>,
43 destination_path: &Path,
44 asset_kind: AssetKind,
45) -> Result<(), anyhow::Error> {
46 log::info!("downloading github artifact from {url}");
47 let Some(destination_parent) = destination_path.parent() else {
48 anyhow::bail!("destination path has no parent: {destination_path:?}");
49 };
50
51 let staging_path = staging_path(destination_parent, asset_kind)?;
52 let mut response = http_client
53 .get(url, Default::default(), true)
54 .await
55 .with_context(|| format!("downloading release from {url}"))?;
56 let body = response.body_mut();
57
58 if let Err(err) = extract_to_staging(body, digest, url, &staging_path, asset_kind).await {
59 cleanup_staging_path(&staging_path, asset_kind).await;
60 return Err(err);
61 }
62
63 if let Err(err) = finalize_download(&staging_path, destination_path).await {
64 cleanup_staging_path(&staging_path, asset_kind).await;
65 return Err(err);
66 }
67
68 Ok(())
69}
70
71async fn extract_to_staging(
72 body: impl AsyncRead + Unpin,
73 digest: Option<&str>,
74 url: &str,
75 staging_path: &Path,
76 asset_kind: AssetKind,
77) -> Result<()> {
78 match digest {
79 Some(expected_sha_256) => {
80 let temp_asset_file = tempfile::NamedTempFile::new()
81 .with_context(|| format!("creating a temporary file for {url}"))?;
82 let (temp_asset_file, _temp_guard) = temp_asset_file.into_parts();
83 let mut writer = HashingWriter {
84 writer: async_fs::File::from(temp_asset_file),
85 hasher: Sha256::new(),
86 };
87 futures::io::copy(&mut BufReader::new(body), &mut writer)
88 .await
89 .with_context(|| {
90 format!("saving archive contents into the temporary file for {url}")
91 })?;
92 let asset_sha_256 = format!("{:x}", writer.hasher.finalize());
93
94 anyhow::ensure!(
95 asset_sha_256 == expected_sha_256,
96 "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}",
97 );
98 writer
99 .writer
100 .seek(std::io::SeekFrom::Start(0))
101 .await
102 .with_context(|| format!("seeking temporary file for {url}"))?;
103 stream_file_archive(&mut writer.writer, url, staging_path, asset_kind)
104 .await
105 .with_context(|| {
106 format!("extracting downloaded asset for {url} into {staging_path:?}")
107 })?;
108 }
109 None => {
110 stream_response_archive(body, url, staging_path, asset_kind)
111 .await
112 .with_context(|| {
113 format!("extracting response for asset {url} into {staging_path:?}")
114 })?;
115 }
116 }
117 Ok(())
118}
119
120fn staging_path(parent: &Path, asset_kind: AssetKind) -> Result<PathBuf> {
121 match asset_kind {
122 AssetKind::TarGz | AssetKind::TarBz2 | AssetKind::Zip => {
123 let dir = tempfile::Builder::new()
124 .prefix(".tmp-github-download-")
125 .tempdir_in(parent)
126 .with_context(|| format!("creating staging directory in {parent:?}"))?;
127 Ok(dir.keep())
128 }
129 AssetKind::Gz => {
130 let path = tempfile::Builder::new()
131 .prefix(".tmp-github-download-")
132 .tempfile_in(parent)
133 .with_context(|| format!("creating staging file in {parent:?}"))?
134 .into_temp_path()
135 .keep()
136 .with_context(|| format!("persisting staging file in {parent:?}"))?;
137 Ok(path)
138 }
139 }
140}
141
142async fn cleanup_staging_path(staging_path: &Path, asset_kind: AssetKind) {
143 match asset_kind {
144 AssetKind::TarGz | AssetKind::TarBz2 | AssetKind::Zip => {
145 if let Err(err) = async_fs::remove_dir_all(staging_path).await {
146 log::warn!("failed to remove staging directory {staging_path:?}: {err:?}");
147 }
148 }
149 AssetKind::Gz => {
150 if let Err(err) = async_fs::remove_file(staging_path).await {
151 log::warn!("failed to remove staging file {staging_path:?}: {err:?}");
152 }
153 }
154 }
155}
156
157async fn finalize_download(staging_path: &Path, destination_path: &Path) -> Result<()> {
158 _ = async_fs::remove_dir_all(destination_path).await;
159 async_fs::rename(staging_path, destination_path)
160 .await
161 .with_context(|| format!("renaming {staging_path:?} to {destination_path:?}"))?;
162 Ok(())
163}
164
165async fn stream_response_archive(
166 response: impl AsyncRead + Unpin,
167 url: &str,
168 destination_path: &Path,
169 asset_kind: AssetKind,
170) -> Result<()> {
171 match asset_kind {
172 AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?,
173 AssetKind::TarBz2 => extract_tar_bz2(destination_path, url, response).await?,
174 AssetKind::Gz => extract_gz(destination_path, url, response).await?,
175 AssetKind::Zip => {
176 util::archive::extract_zip(destination_path, response).await?;
177 }
178 };
179 Ok(())
180}
181
/// Extracts an archive from a seekable file (the spooled, digest-verified
/// temp file) into `destination_path`, dispatching on `asset_kind`.
async fn stream_file_archive(
    file_archive: impl AsyncRead + AsyncSeek + Unpin,
    url: &str,
    destination_path: &Path,
    asset_kind: AssetKind,
) -> Result<()> {
    match asset_kind {
        AssetKind::TarGz => extract_tar_gz(destination_path, url, file_archive).await?,
        AssetKind::TarBz2 => extract_tar_bz2(destination_path, url, file_archive).await?,
        AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?,
        // Non-Windows takes advantage of the seekable source for zip extraction.
        #[cfg(not(windows))]
        AssetKind::Zip => {
            util::archive::extract_seekable_zip(destination_path, file_archive).await?;
        }
        // Windows falls back to the streaming zip extractor.
        #[cfg(windows)]
        AssetKind::Zip => {
            util::archive::extract_zip(destination_path, file_archive).await?;
        }
    };
    Ok(())
}
203
204async fn extract_tar_gz(
205 destination_path: &Path,
206 url: &str,
207 from: impl AsyncRead + Unpin,
208) -> Result<(), anyhow::Error> {
209 let decompressed_bytes = GzipDecoder::new(BufReader::new(from));
210 let archive = async_tar::Archive::new(decompressed_bytes);
211 archive
212 .unpack(&destination_path)
213 .await
214 .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
215 Ok(())
216}
217
218async fn extract_tar_bz2(
219 destination_path: &Path,
220 url: &str,
221 from: impl AsyncRead + Unpin,
222) -> Result<(), anyhow::Error> {
223 let decompressed_bytes = BzDecoder::new(BufReader::new(from));
224 let archive = async_tar::Archive::new(decompressed_bytes);
225 archive
226 .unpack(&destination_path)
227 .await
228 .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
229 Ok(())
230}
231
232async fn extract_gz(
233 destination_path: &Path,
234 url: &str,
235 from: impl AsyncRead + Unpin,
236) -> Result<(), anyhow::Error> {
237 let mut decompressed_bytes = GzipDecoder::new(BufReader::new(from));
238 let mut file = async_fs::File::create(&destination_path)
239 .await
240 .with_context(|| {
241 format!("creating a file {destination_path:?} for a download from {url}")
242 })?;
243 futures::io::copy(&mut decompressed_bytes, &mut file)
244 .await
245 .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
246 Ok(())
247}
248
/// An `AsyncWrite` adapter that forwards all writes to `writer` while
/// accumulating a SHA-256 digest of the bytes actually written.
struct HashingWriter<W: AsyncWrite + Unpin> {
    // Destination the bytes are forwarded to.
    writer: W,
    // Running SHA-256 state; call `finalize()` after the copy completes.
    hasher: Sha256,
}
253
254impl<W: AsyncWrite + Unpin> AsyncWrite for HashingWriter<W> {
255 fn poll_write(
256 mut self: Pin<&mut Self>,
257 cx: &mut std::task::Context<'_>,
258 buf: &[u8],
259 ) -> Poll<std::result::Result<usize, std::io::Error>> {
260 match Pin::new(&mut self.writer).poll_write(cx, buf) {
261 Poll::Ready(Ok(n)) => {
262 self.hasher.update(&buf[..n]);
263 Poll::Ready(Ok(n))
264 }
265 other => other,
266 }
267 }
268
269 fn poll_flush(
270 mut self: Pin<&mut Self>,
271 cx: &mut std::task::Context<'_>,
272 ) -> Poll<Result<(), std::io::Error>> {
273 Pin::new(&mut self.writer).poll_flush(cx)
274 }
275
276 fn poll_close(
277 mut self: Pin<&mut Self>,
278 cx: &mut std::task::Context<'_>,
279 ) -> Poll<std::result::Result<(), std::io::Error>> {
280 Pin::new(&mut self.writer).poll_close(cx)
281 }
282}