1use std::{future::Future, path::Path, pin::Pin, task::Poll};
2
3use anyhow::{Context, Result};
4use async_compression::futures::bufread::GzipDecoder;
5use futures::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, io::BufReader};
6use sha2::{Digest, Sha256};
7
8use crate::{HttpClient, github::AssetKind};
9
10#[derive(serde::Deserialize, serde::Serialize, Debug)]
11pub struct GithubBinaryMetadata {
12 pub metadata_version: u64,
13 pub digest: Option<String>,
14}
15
16impl GithubBinaryMetadata {
17 pub async fn read_from_file(metadata_path: &Path) -> Result<GithubBinaryMetadata> {
18 let metadata_content = async_fs::read_to_string(metadata_path)
19 .await
20 .with_context(|| format!("reading metadata file at {metadata_path:?}"))?;
21 serde_json::from_str(&metadata_content)
22 .with_context(|| format!("parsing metadata file at {metadata_path:?}"))
23 }
24
25 pub async fn write_to_file(&self, metadata_path: &Path) -> Result<()> {
26 let metadata_content = serde_json::to_string(self)
27 .with_context(|| format!("serializing metadata for {metadata_path:?}"))?;
28 async_fs::write(metadata_path, metadata_content.as_bytes())
29 .await
30 .with_context(|| format!("writing metadata file at {metadata_path:?}"))?;
31 Ok(())
32 }
33}
34
35pub async fn download_server_binary(
36 http_client: &dyn HttpClient,
37 url: &str,
38 digest: Option<&str>,
39 destination_path: &Path,
40 asset_kind: AssetKind,
41) -> Result<(), anyhow::Error> {
42 log::info!("downloading github artifact from {url}");
43 let mut response = http_client
44 .get(url, Default::default(), true)
45 .await
46 .with_context(|| format!("downloading release from {url}"))?;
47 let body = response.body_mut();
48 match digest {
49 Some(expected_sha_256) => {
50 let temp_asset_file = tempfile::NamedTempFile::new()
51 .with_context(|| format!("creating a temporary file for {url}"))?;
52 let (temp_asset_file, _temp_guard) = temp_asset_file.into_parts();
53 let mut writer = HashingWriter {
54 writer: async_fs::File::from(temp_asset_file),
55 hasher: Sha256::new(),
56 };
57 futures::io::copy(&mut BufReader::new(body), &mut writer)
58 .await
59 .with_context(|| {
60 format!("saving archive contents into the temporary file for {url}",)
61 })?;
62 let asset_sha_256 = format!("{:x}", writer.hasher.finalize());
63
64 anyhow::ensure!(
65 asset_sha_256 == expected_sha_256,
66 "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}",
67 );
68 writer
69 .writer
70 .seek(std::io::SeekFrom::Start(0))
71 .await
72 .with_context(|| format!("seeking temporary file {destination_path:?}",))?;
73 stream_file_archive(&mut writer.writer, url, destination_path, asset_kind)
74 .await
75 .with_context(|| {
76 format!("extracting downloaded asset for {url} into {destination_path:?}",)
77 })?;
78 }
79 None => stream_response_archive(body, url, destination_path, asset_kind)
80 .await
81 .with_context(|| {
82 format!("extracting response for asset {url} into {destination_path:?}",)
83 })?,
84 }
85 Ok(())
86}
87
88pub async fn fetch_github_binary_with_digest_check<ValidityCheck, ValidityCheckFuture>(
89 binary_path: &Path,
90 metadata_path: &Path,
91 expected_digest: Option<String>,
92 url: &str,
93 asset_kind: AssetKind,
94 download_destination: &Path,
95 http_client: &dyn HttpClient,
96 validity_check: ValidityCheck,
97) -> Result<()>
98where
99 ValidityCheck: FnOnce() -> ValidityCheckFuture,
100 ValidityCheckFuture: Future<Output = Result<()>>,
101{
102 let metadata = GithubBinaryMetadata::read_from_file(metadata_path)
103 .await
104 .ok();
105
106 if let Some(metadata) = metadata {
107 let validity_check_result = validity_check().await;
108
109 if let (Some(actual_digest), Some(expected_digest_ref)) =
110 (&metadata.digest, &expected_digest)
111 {
112 if actual_digest == expected_digest_ref {
113 if validity_check_result.is_ok() {
114 return Ok(());
115 }
116 } else {
117 log::info!(
118 "SHA-256 mismatch for {binary_path:?} asset, downloading new asset. Expected: {expected_digest_ref}, Got: {actual_digest}"
119 );
120 }
121 } else if validity_check_result.is_ok() {
122 return Ok(());
123 }
124 }
125
126 download_server_binary(
127 http_client,
128 url,
129 expected_digest.as_deref(),
130 download_destination,
131 asset_kind,
132 )
133 .await?;
134
135 GithubBinaryMetadata::write_to_file(
136 &GithubBinaryMetadata {
137 metadata_version: 1,
138 digest: expected_digest,
139 },
140 metadata_path,
141 )
142 .await?;
143
144 Ok(())
145}
146
147async fn stream_response_archive(
148 response: impl AsyncRead + Unpin,
149 url: &str,
150 destination_path: &Path,
151 asset_kind: AssetKind,
152) -> Result<()> {
153 match asset_kind {
154 AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?,
155 AssetKind::Gz => extract_gz(destination_path, url, response).await?,
156 AssetKind::Zip => {
157 util::archive::extract_zip(destination_path, response).await?;
158 }
159 };
160 Ok(())
161}
162
163async fn stream_file_archive(
164 file_archive: impl AsyncRead + AsyncSeek + Unpin,
165 url: &str,
166 destination_path: &Path,
167 asset_kind: AssetKind,
168) -> Result<()> {
169 match asset_kind {
170 AssetKind::TarGz => extract_tar_gz(destination_path, url, file_archive).await?,
171 AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?,
172 #[cfg(not(windows))]
173 AssetKind::Zip => {
174 util::archive::extract_seekable_zip(destination_path, file_archive).await?;
175 }
176 #[cfg(windows)]
177 AssetKind::Zip => {
178 util::archive::extract_zip(destination_path, file_archive).await?;
179 }
180 };
181 Ok(())
182}
183
184async fn extract_tar_gz(
185 destination_path: &Path,
186 url: &str,
187 from: impl AsyncRead + Unpin,
188) -> Result<(), anyhow::Error> {
189 let decompressed_bytes = GzipDecoder::new(BufReader::new(from));
190 let archive = async_tar::Archive::new(decompressed_bytes);
191 archive
192 .unpack(&destination_path)
193 .await
194 .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
195 Ok(())
196}
197
198async fn extract_gz(
199 destination_path: &Path,
200 url: &str,
201 from: impl AsyncRead + Unpin,
202) -> Result<(), anyhow::Error> {
203 let mut decompressed_bytes = GzipDecoder::new(BufReader::new(from));
204 let mut file = async_fs::File::create(&destination_path)
205 .await
206 .with_context(|| {
207 format!("creating a file {destination_path:?} for a download from {url}")
208 })?;
209 futures::io::copy(&mut decompressed_bytes, &mut file)
210 .await
211 .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
212 Ok(())
213}
214
215struct HashingWriter<W: AsyncWrite + Unpin> {
216 writer: W,
217 hasher: Sha256,
218}
219
220impl<W: AsyncWrite + Unpin> AsyncWrite for HashingWriter<W> {
221 fn poll_write(
222 mut self: Pin<&mut Self>,
223 cx: &mut std::task::Context<'_>,
224 buf: &[u8],
225 ) -> Poll<std::result::Result<usize, std::io::Error>> {
226 match Pin::new(&mut self.writer).poll_write(cx, buf) {
227 Poll::Ready(Ok(n)) => {
228 self.hasher.update(&buf[..n]);
229 Poll::Ready(Ok(n))
230 }
231 other => other,
232 }
233 }
234
235 fn poll_flush(
236 mut self: Pin<&mut Self>,
237 cx: &mut std::task::Context<'_>,
238 ) -> Poll<Result<(), std::io::Error>> {
239 Pin::new(&mut self.writer).poll_flush(cx)
240 }
241
242 fn poll_close(
243 mut self: Pin<&mut Self>,
244 cx: &mut std::task::Context<'_>,
245 ) -> Poll<std::result::Result<(), std::io::Error>> {
246 Pin::new(&mut self.writer).poll_close(cx)
247 }
248}