1use std::{path::PathBuf, pin::Pin, sync::Arc};
2
3use fs::Fs;
4use futures::{AsyncRead, AsyncReadExt, io::BufReader};
5use http::Request;
6use http_client::{AsyncBody, HttpClient};
7use serde::{Deserialize, Serialize};
8
9use crate::devcontainer_api::DevContainerError;
10
11#[derive(Debug, Serialize, Deserialize)]
12#[serde(rename_all = "camelCase")]
13pub(crate) struct TokenResponse {
14 pub(crate) token: String,
15}
16
17#[derive(Debug, Deserialize)]
18#[serde(rename_all = "camelCase")]
19pub(crate) struct DockerManifestsResponse {
20 pub(crate) layers: Vec<ManifestLayer>,
21}
22
23#[derive(Debug, Deserialize)]
24#[serde(rename_all = "camelCase")]
25pub(crate) struct ManifestLayer {
26 pub(crate) digest: String,
27}
28
29/// Gets a bearer token for pulling from a container registry repository.
30///
31/// This uses the registry's `/token` endpoint directly, which works for
32/// `ghcr.io` and other registries that follow the same convention. For
33/// registries that require a full `WWW-Authenticate` negotiation flow this
34/// would need to be extended.
35pub(crate) async fn get_oci_token(
36 registry: &str,
37 repository_path: &str,
38 client: &Arc<dyn HttpClient>,
39) -> Result<TokenResponse, String> {
40 let url = format!(
41 "https://{registry}/token?service={registry}&scope=repository:{repository_path}:pull",
42 );
43 log::debug!("Fetching OCI token from: {}", url);
44 get_deserialized_response("", &url, client)
45 .await
46 .map_err(|e| {
47 log::error!("OCI token request failed for {}: {e}", url);
48 e
49 })
50}
51
52pub(crate) async fn get_latest_oci_manifest(
53 token: &str,
54 registry: &str,
55 repository_path: &str,
56 client: &Arc<dyn HttpClient>,
57 id: Option<&str>,
58) -> Result<DockerManifestsResponse, String> {
59 get_oci_manifest(registry, repository_path, token, client, "latest", id).await
60}
61
62pub(crate) async fn get_oci_manifest(
63 registry: &str,
64 repository_path: &str,
65 token: &str,
66 client: &Arc<dyn HttpClient>,
67 version: &str,
68 id: Option<&str>,
69) -> Result<DockerManifestsResponse, String> {
70 let url = match id {
71 Some(id) => format!("https://{registry}/v2/{repository_path}/{id}/manifests/{version}"),
72 None => format!("https://{registry}/v2/{repository_path}/manifests/{version}"),
73 };
74
75 get_deserialized_response(token, &url, client).await
76}
77
78pub(crate) async fn get_deserializable_oci_blob<T>(
79 token: &str,
80 registry: &str,
81 repository_path: &str,
82 blob_digest: &str,
83 client: &Arc<dyn HttpClient>,
84) -> Result<T, String>
85where
86 T: for<'a> Deserialize<'a>,
87{
88 let url = format!("https://{registry}/v2/{repository_path}/blobs/{blob_digest}");
89 get_deserialized_response(token, &url, client).await
90}
91
92pub(crate) async fn download_oci_tarball(
93 token: &str,
94 registry: &str,
95 repository_path: &str,
96 blob_digest: &str,
97 accept_header: &str,
98 dest_dir: &PathBuf,
99 client: &Arc<dyn HttpClient>,
100 fs: &Arc<dyn Fs>,
101 id: Option<&str>,
102) -> Result<(), DevContainerError> {
103 let url = match id {
104 Some(id) => format!("https://{registry}/v2/{repository_path}/{id}/blobs/{blob_digest}"),
105 None => format!("https://{registry}/v2/{repository_path}/blobs/{blob_digest}"),
106 };
107
108 let request = Request::get(&url)
109 .header("Authorization", format!("Bearer {}", token))
110 .header("Accept", accept_header)
111 .body(AsyncBody::default())
112 .map_err(|e| {
113 log::error!("Failed to create blob request: {e}");
114 DevContainerError::ResourceFetchFailed
115 })?;
116
117 let mut response = client.send(request).await.map_err(|e| {
118 log::error!("Failed to download feature blob: {e}");
119 DevContainerError::ResourceFetchFailed
120 })?;
121 let status = response.status();
122
123 let body = BufReader::new(response.body_mut());
124
125 if !status.is_success() {
126 let body_text = String::from_utf8_lossy(body.buffer());
127 log::error!(
128 "Feature blob download returned HTTP {}: {}",
129 status.as_u16(),
130 body_text,
131 );
132 return Err(DevContainerError::ResourceFetchFailed);
133 }
134
135 futures::pin_mut!(body);
136 let body: Pin<&mut (dyn AsyncRead + Send)> = body;
137 let archive = async_tar::Archive::new(body);
138 fs.extract_tar_file(dest_dir, archive).await.map_err(|e| {
139 log::error!("Failed to extract feature tarball: {e}");
140 DevContainerError::FilesystemError
141 })?;
142
143 Ok(())
144}
145
146pub(crate) async fn get_deserialized_response<T>(
147 token: &str,
148 url: &str,
149 client: &Arc<dyn HttpClient>,
150) -> Result<T, String>
151where
152 T: for<'de> Deserialize<'de>,
153{
154 let request = match Request::get(url)
155 .header("Authorization", format!("Bearer {}", token))
156 .header("Accept", "application/vnd.oci.image.manifest.v1+json")
157 .body(AsyncBody::default())
158 {
159 Ok(request) => request,
160 Err(e) => return Err(format!("Failed to create request: {}", e)),
161 };
162 let response = match client.send(request).await {
163 Ok(response) => response,
164 Err(e) => {
165 return Err(format!("Failed to send request to {}: {}", url, e));
166 }
167 };
168
169 let status = response.status();
170 let mut output = String::new();
171
172 if let Err(e) = response.into_body().read_to_string(&mut output).await {
173 return Err(format!("Failed to read response body from {}: {}", url, e));
174 };
175
176 if !status.is_success() {
177 return Err(format!(
178 "OCI request to {} returned HTTP {}: {}",
179 url,
180 status.as_u16(),
181 &output[..output.len().min(500)],
182 ));
183 }
184
185 match serde_json_lenient::from_str(&output) {
186 Ok(response) => Ok(response),
187 Err(e) => Err(format!(
188 "Failed to deserialize response from {}: {} (body: {})",
189 url,
190 e,
191 &output[..output.len().min(500)],
192 )),
193 }
194}
195
#[cfg(test)]
mod test {
    use std::{path::PathBuf, sync::Arc};

    use fs::{FakeFs, Fs};
    use gpui::TestAppContext;
    use http_client::{FakeHttpClient, anyhow};
    use serde::Deserialize;

    use crate::oci::{
        TokenResponse, download_oci_tarball, get_deserializable_oci_blob,
        get_deserialized_response, get_latest_oci_manifest, get_oci_token,
    };

    /// Contents of `.devcontainer/devcontainer.json` in the fixture tarball.
    ///
    /// The tarball builder and the extraction assertion both use this const,
    /// so the two can never drift apart.
    const DEVCONTAINER_JSON: &str = concat!(
        "// For format details, see https://aka.ms/devcontainer.json. For config options, see the\n",
        "// README at: https://github.com/devcontainers/templates/tree/main/src/alpine\n",
        "{\n",
        "\t\"name\": \"Alpine\",\n",
        "\t// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile\n",
        "\t\"image\": \"mcr.microsoft.com/devcontainers/base:alpine-${templateOption:imageVariant}\"\n",
        "}\n",
    );

    /// Contents of `.github/dependabot.yml` in the fixture tarball.
    const DEPENDABOT_YML: &str = concat!(
        "version: 2\n",
        "updates:\n",
        " - package-ecosystem: \"devcontainers\"\n",
        " directory: \"/\"\n",
        " schedule:\n",
        " interval: weekly\n",
    );

    /// Builds an in-memory tar archive laid out like a small devcontainer template.
    async fn build_test_tarball() -> Vec<u8> {
        let files: &[(&str, &[u8], u32)] = &[
            (
                ".devcontainer/devcontainer.json",
                DEVCONTAINER_JSON.as_bytes(),
                0o644,
            ),
            (".github/dependabot.yml", DEPENDABOT_YML.as_bytes(), 0o644),
            ("NOTES.md", b"Some notes", 0o644),
            ("README.md", b"# Alpine\n", 0o644),
        ];

        let mut builder = async_tar::Builder::new(futures::io::Cursor::new(Vec::new()));
        for (path, data, mode) in files {
            let mut header = async_tar::Header::new_gnu();
            header.set_size(data.len() as u64);
            header.set_mode(*mode);
            header.set_entry_type(async_tar::EntryType::Regular);
            header.set_cksum();
            builder.append_data(&mut header, path, *data).await.unwrap();
        }

        builder.into_inner().await.unwrap().into_inner()
    }

    fn test_oci_registry() -> &'static str {
        "ghcr.io"
    }

    fn test_oci_repository() -> &'static str {
        "repository"
    }

    #[gpui::test]
    async fn test_get_deserialized_response(_cx: &mut TestAppContext) {
        let client = FakeHttpClient::create(|_request| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body(r#"{ "token": "thisisatoken" }"#.into())
                .unwrap())
        });

        let response =
            get_deserialized_response::<TokenResponse>("", "https://ghcr.io/token", &client)
                .await
                .expect("token response should deserialize");
        assert_eq!(response.token, "thisisatoken");
    }

    #[gpui::test]
    async fn test_get_oci_token() {
        let client = FakeHttpClient::create(|request| async move {
            let host = request.uri().host();
            if host != Some(test_oci_registry()) {
                return Err(anyhow!("Unexpected host: {}", host.unwrap_or_default()));
            }
            let path = request.uri().path();
            if path != "/token" {
                return Err(anyhow!("Unexpected path: {}", path));
            }
            let expected_query = format!(
                "service=ghcr.io&scope=repository:{}:pull",
                test_oci_repository()
            );
            let query = request.uri().query();
            if query != Some(expected_query.as_str()) {
                return Err(anyhow!("Unexpected query: {}", query.unwrap_or_default()));
            }
            Ok(http_client::Response::builder()
                .status(200)
                .body(r#"{ "token": "thisisatoken" }"#.into())
                .unwrap())
        });

        let response = get_oci_token(test_oci_registry(), test_oci_repository(), &client)
            .await
            .expect("token request should succeed");
        assert_eq!(response.token, "thisisatoken");
    }

    #[gpui::test]
    async fn test_get_latest_manifests() {
        let client = FakeHttpClient::create(|request| async move {
            let host = request.uri().host();
            if host != Some(test_oci_registry()) {
                return Err(anyhow!("Unexpected host: {}", host.unwrap_or_default()));
            }
            let path = request.uri().path();
            if path != format!("/v2/{}/manifests/latest", test_oci_repository()) {
                return Err(anyhow!("Unexpected path: {}", path));
            }
            Ok(http_client::Response::builder()
                .status(200)
                .body(
                    r#"{
    "schemaVersion": 2,
    "mediaType": "application/vnd.oci.image.manifest.v1+json",
    "config": {
        "mediaType": "application/vnd.devcontainers",
        "digest": "sha256:44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a",
        "size": 2
    },
    "layers": [
        {
            "mediaType": "application/vnd.devcontainers.collection.layer.v1+json",
            "digest": "sha256:035e9c9fd9bd61f6d3965fa4bf11f3ddfd2490a8cf324f152c13cc3724d67d09",
            "size": 65235,
            "annotations": {
                "org.opencontainers.image.title": "devcontainer-collection.json"
            }
        }
    ],
    "annotations": {
        "com.github.package.type": "devcontainer_collection"
    }
}"#
                    .into(),
                )
                .unwrap())
        });

        let response = get_latest_oci_manifest(
            "",
            test_oci_registry(),
            test_oci_repository(),
            &client,
            None,
        )
        .await
        .expect("manifest request should succeed");

        assert_eq!(response.layers.len(), 1);
        assert_eq!(
            response.layers[0].digest,
            "sha256:035e9c9fd9bd61f6d3965fa4bf11f3ddfd2490a8cf324f152c13cc3724d67d09"
        );
    }

    #[gpui::test]
    async fn test_get_oci_blob() {
        #[derive(Debug, Deserialize)]
        struct DeserializableTestStruct {
            foo: String,
        }

        let client = FakeHttpClient::create(|request| async move {
            let host = request.uri().host();
            if host != Some(test_oci_registry()) {
                return Err(anyhow!("Unexpected host: {}", host.unwrap_or_default()));
            }
            let path = request.uri().path();
            if path != format!("/v2/{}/blobs/blobdigest", test_oci_repository()) {
                return Err(anyhow!("Unexpected path: {}", path));
            }
            Ok(http_client::Response::builder()
                .status(200)
                .body(
                    r#"
                    {
                        "foo": "bar"
                    }
                    "#
                    .into(),
                )
                .unwrap())
        });

        let response: DeserializableTestStruct = get_deserializable_oci_blob(
            "",
            test_oci_registry(),
            test_oci_repository(),
            "blobdigest",
            &client,
        )
        .await
        .expect("blob request should succeed");

        assert_eq!(response.foo, "bar");
    }

    #[gpui::test]
    async fn test_download_oci_tarball(cx: &mut TestAppContext) {
        cx.executor().allow_parking();
        let fs: Arc<dyn Fs> = FakeFs::new(cx.executor());

        let destination_dir = PathBuf::from("/tmp/extracted");
        fs.create_dir(&destination_dir).await.unwrap();

        let tarball = Arc::new(build_test_tarball().await);

        let client = FakeHttpClient::create(move |request| {
            let tarball = tarball.clone();
            async move {
                let host = request.uri().host();
                if host != Some(test_oci_registry()) {
                    return Err(anyhow!("Unexpected host: {}", host.unwrap_or_default()));
                }
                let path = request.uri().path();
                if path != format!("/v2/{}/blobs/blobdigest", test_oci_repository()) {
                    return Err(anyhow!("Unexpected path: {}", path));
                }
                Ok(http_client::Response::builder()
                    .status(200)
                    .body(tarball.to_vec().into())
                    .unwrap())
            }
        });

        let response = download_oci_tarball(
            "",
            test_oci_registry(),
            test_oci_repository(),
            "blobdigest",
            "header",
            &destination_dir,
            &client,
            &fs,
            None,
        )
        .await;
        assert!(response.is_ok());

        let extracted = fs
            .load(&destination_dir.join(".devcontainer/devcontainer.json"))
            .await
            .unwrap();
        assert_eq!(extracted, DEVCONTAINER_JSON);
    }
}