oci.rs

  1use std::{path::PathBuf, pin::Pin, sync::Arc};
  2
  3use fs::Fs;
  4use futures::{AsyncRead, AsyncReadExt, io::BufReader};
  5use http::Request;
  6use http_client::{AsyncBody, HttpClient};
  7use serde::{Deserialize, Serialize};
  8
  9use crate::devcontainer_api::DevContainerError;
 10
 11#[derive(Debug, Serialize, Deserialize)]
 12#[serde(rename_all = "camelCase")]
 13pub(crate) struct TokenResponse {
 14    pub(crate) token: String,
 15}
 16
 17#[derive(Debug, Deserialize)]
 18#[serde(rename_all = "camelCase")]
 19pub(crate) struct DockerManifestsResponse {
 20    pub(crate) layers: Vec<ManifestLayer>,
 21}
 22
 23#[derive(Debug, Deserialize)]
 24#[serde(rename_all = "camelCase")]
 25pub(crate) struct ManifestLayer {
 26    pub(crate) digest: String,
 27}
 28
 29/// Gets a bearer token for pulling from a container registry repository.
 30///
 31/// This uses the registry's `/token` endpoint directly, which works for
 32/// `ghcr.io` and other registries that follow the same convention.  For
 33/// registries that require a full `WWW-Authenticate` negotiation flow this
 34/// would need to be extended.
 35pub(crate) async fn get_oci_token(
 36    registry: &str,
 37    repository_path: &str,
 38    client: &Arc<dyn HttpClient>,
 39) -> Result<TokenResponse, String> {
 40    let url = format!(
 41        "https://{registry}/token?service={registry}&scope=repository:{repository_path}:pull",
 42    );
 43    log::debug!("Fetching OCI token from: {}", url);
 44    get_deserialized_response("", &url, client)
 45        .await
 46        .map_err(|e| {
 47            log::error!("OCI token request failed for {}: {e}", url);
 48            e
 49        })
 50}
 51
 52pub(crate) async fn get_latest_oci_manifest(
 53    token: &str,
 54    registry: &str,
 55    repository_path: &str,
 56    client: &Arc<dyn HttpClient>,
 57    id: Option<&str>,
 58) -> Result<DockerManifestsResponse, String> {
 59    get_oci_manifest(registry, repository_path, token, client, "latest", id).await
 60}
 61
 62pub(crate) async fn get_oci_manifest(
 63    registry: &str,
 64    repository_path: &str,
 65    token: &str,
 66    client: &Arc<dyn HttpClient>,
 67    version: &str,
 68    id: Option<&str>,
 69) -> Result<DockerManifestsResponse, String> {
 70    let url = match id {
 71        Some(id) => format!("https://{registry}/v2/{repository_path}/{id}/manifests/{version}"),
 72        None => format!("https://{registry}/v2/{repository_path}/manifests/{version}"),
 73    };
 74
 75    get_deserialized_response(token, &url, client).await
 76}
 77
 78pub(crate) async fn get_deserializable_oci_blob<T>(
 79    token: &str,
 80    registry: &str,
 81    repository_path: &str,
 82    blob_digest: &str,
 83    client: &Arc<dyn HttpClient>,
 84) -> Result<T, String>
 85where
 86    T: for<'a> Deserialize<'a>,
 87{
 88    let url = format!("https://{registry}/v2/{repository_path}/blobs/{blob_digest}");
 89    get_deserialized_response(token, &url, client).await
 90}
 91
 92pub(crate) async fn download_oci_tarball(
 93    token: &str,
 94    registry: &str,
 95    repository_path: &str,
 96    blob_digest: &str,
 97    accept_header: &str,
 98    dest_dir: &PathBuf,
 99    client: &Arc<dyn HttpClient>,
100    fs: &Arc<dyn Fs>,
101    id: Option<&str>,
102) -> Result<(), DevContainerError> {
103    let url = match id {
104        Some(id) => format!("https://{registry}/v2/{repository_path}/{id}/blobs/{blob_digest}"),
105        None => format!("https://{registry}/v2/{repository_path}/blobs/{blob_digest}"),
106    };
107
108    let request = Request::get(&url)
109        .header("Authorization", format!("Bearer {}", token))
110        .header("Accept", accept_header)
111        .body(AsyncBody::default())
112        .map_err(|e| {
113            log::error!("Failed to create blob request: {e}");
114            DevContainerError::ResourceFetchFailed
115        })?;
116
117    let mut response = client.send(request).await.map_err(|e| {
118        log::error!("Failed to download feature blob: {e}");
119        DevContainerError::ResourceFetchFailed
120    })?;
121    let status = response.status();
122
123    let body = BufReader::new(response.body_mut());
124
125    if !status.is_success() {
126        let body_text = String::from_utf8_lossy(body.buffer());
127        log::error!(
128            "Feature blob download returned HTTP {}: {}",
129            status.as_u16(),
130            body_text,
131        );
132        return Err(DevContainerError::ResourceFetchFailed);
133    }
134
135    futures::pin_mut!(body);
136    let body: Pin<&mut (dyn AsyncRead + Send)> = body;
137    let archive = async_tar::Archive::new(body);
138    fs.extract_tar_file(dest_dir, archive).await.map_err(|e| {
139        log::error!("Failed to extract feature tarball: {e}");
140        DevContainerError::FilesystemError
141    })?;
142
143    Ok(())
144}
145
146pub(crate) async fn get_deserialized_response<T>(
147    token: &str,
148    url: &str,
149    client: &Arc<dyn HttpClient>,
150) -> Result<T, String>
151where
152    T: for<'de> Deserialize<'de>,
153{
154    let request = match Request::get(url)
155        .header("Authorization", format!("Bearer {}", token))
156        .header("Accept", "application/vnd.oci.image.manifest.v1+json")
157        .body(AsyncBody::default())
158    {
159        Ok(request) => request,
160        Err(e) => return Err(format!("Failed to create request: {}", e)),
161    };
162    let response = match client.send(request).await {
163        Ok(response) => response,
164        Err(e) => {
165            return Err(format!("Failed to send request to {}: {}", url, e));
166        }
167    };
168
169    let status = response.status();
170    let mut output = String::new();
171
172    if let Err(e) = response.into_body().read_to_string(&mut output).await {
173        return Err(format!("Failed to read response body from {}: {}", url, e));
174    };
175
176    if !status.is_success() {
177        return Err(format!(
178            "OCI request to {} returned HTTP {}: {}",
179            url,
180            status.as_u16(),
181            &output[..output.len().min(500)],
182        ));
183    }
184
185    match serde_json_lenient::from_str(&output) {
186        Ok(response) => Ok(response),
187        Err(e) => Err(format!(
188            "Failed to deserialize response from {}: {} (body: {})",
189            url,
190            e,
191            &output[..output.len().min(500)],
192        )),
193    }
194}
195
#[cfg(test)]
mod test {
    use std::{path::PathBuf, sync::Arc};

    use fs::{FakeFs, Fs};
    use gpui::TestAppContext;
    use http_client::{FakeHttpClient, anyhow};
    use serde::Deserialize;

    use crate::devcontainer_api::DevContainerError;
    use crate::oci::{
        TokenResponse, download_oci_tarball, get_deserializable_oci_blob,
        get_deserialized_response, get_latest_oci_manifest, get_oci_token,
    };

    /// Contents of `.devcontainer/devcontainer.json` inside the archive produced by
    /// `build_test_tarball`. It is shared so the extraction test compares against the exact
    /// bytes that were packed.
    const DEVCONTAINER_JSON: &str = concat!(
        "// For format details, see https://aka.ms/devcontainer.json. For config options, see the\n",
        "// README at: https://github.com/devcontainers/templates/tree/main/src/alpine\n",
        "{\n",
        "\t\"name\": \"Alpine\",\n",
        "\t// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile\n",
        "\t\"image\": \"mcr.microsoft.com/devcontainers/base:alpine-${templateOption:imageVariant}\"\n",
        "}\n",
    );

    /// Builds an in-memory tar archive shaped like a devcontainer template layer.
    async fn build_test_tarball() -> Vec<u8> {
        let dependabot_yml = concat!(
            "version: 2\n",
            "updates:\n",
            " - package-ecosystem: \"devcontainers\"\n",
            "   directory: \"/\"\n",
            "   schedule:\n",
            "     interval: weekly\n",
        );

        let buffer = futures::io::Cursor::new(Vec::new());
        let mut builder = async_tar::Builder::new(buffer);

        let files: &[(&str, &[u8], u32)] = &[
            (
                ".devcontainer/devcontainer.json",
                DEVCONTAINER_JSON.as_bytes(),
                0o644,
            ),
            (".github/dependabot.yml", dependabot_yml.as_bytes(), 0o644),
            ("NOTES.md", b"Some notes", 0o644),
            ("README.md", b"# Alpine\n", 0o644),
        ];

        for (path, data, mode) in files {
            let mut header = async_tar::Header::new_gnu();
            header.set_size(data.len() as u64);
            header.set_mode(*mode);
            header.set_entry_type(async_tar::EntryType::Regular);
            header.set_cksum();
            builder.append_data(&mut header, path, *data).await.unwrap();
        }

        let buffer = builder.into_inner().await.unwrap();
        buffer.into_inner()
    }

    fn test_oci_registry() -> &'static str {
        "ghcr.io"
    }

    fn test_oci_repository() -> &'static str {
        "repository"
    }

    #[gpui::test]
    async fn test_get_deserialized_response(_cx: &mut TestAppContext) {
        let client = FakeHttpClient::create(|_request| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body("{ \"token\": \"thisisatoken\" }".into())
                .unwrap())
        });

        let response =
            get_deserialized_response::<TokenResponse>("", "https://ghcr.io/token", &client)
                .await
                .expect("token response should deserialize");
        assert_eq!(response.token, "thisisatoken");
    }

    #[gpui::test]
    async fn test_get_deserialized_response_reports_http_errors() {
        let client = FakeHttpClient::create(|_request| async move {
            Ok(http_client::Response::builder()
                .status(404)
                .body("not found".into())
                .unwrap())
        });

        let error =
            get_deserialized_response::<TokenResponse>("", "https://ghcr.io/token", &client)
                .await
                .expect_err("a 404 response should be reported as an error");
        assert!(error.contains("HTTP 404"), "unexpected error: {error}");
        assert!(error.contains("not found"), "unexpected error: {error}");
    }

    #[gpui::test]
    async fn test_get_oci_token() {
        let client = FakeHttpClient::create(|request| async move {
            let host = request.uri().host();
            if host != Some(test_oci_registry()) {
                return Err(anyhow!("Unexpected host: {}", host.unwrap_or_default()));
            }
            let path = request.uri().path();
            if path != "/token" {
                return Err(anyhow!("Unexpected path: {}", path));
            }
            let query = request.uri().query();
            let expected_query = format!(
                "service={}&scope=repository:{}:pull",
                test_oci_registry(),
                test_oci_repository()
            );
            if query != Some(expected_query.as_str()) {
                return Err(anyhow!("Unexpected query: {}", query.unwrap_or_default()));
            }
            Ok(http_client::Response::builder()
                .status(200)
                .body("{ \"token\": \"thisisatoken\" }".into())
                .unwrap())
        });

        let response = get_oci_token(test_oci_registry(), test_oci_repository(), &client)
            .await
            .expect("token request should succeed");
        assert_eq!(response.token, "thisisatoken");
    }

    #[gpui::test]
    async fn test_get_latest_manifests() {
        let client = FakeHttpClient::create(|request| async move {
            let host = request.uri().host();
            if host != Some(test_oci_registry()) {
                return Err(anyhow!("Unexpected host: {}", host.unwrap_or_default()));
            }
            let path = request.uri().path();
            if path != format!("/v2/{}/manifests/latest", test_oci_repository()) {
                return Err(anyhow!("Unexpected path: {}", path));
            }
            Ok(http_client::Response::builder()
                .status(200)
                .body("{
                    \"schemaVersion\": 2,
                    \"mediaType\": \"application/vnd.oci.image.manifest.v1+json\",
                    \"config\": {
                        \"mediaType\": \"application/vnd.devcontainers\",
                        \"digest\": \"sha256:44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a\",
                        \"size\": 2
                    },
                    \"layers\": [
                        {
                            \"mediaType\": \"application/vnd.devcontainers.collection.layer.v1+json\",
                            \"digest\": \"sha256:035e9c9fd9bd61f6d3965fa4bf11f3ddfd2490a8cf324f152c13cc3724d67d09\",
                            \"size\": 65235,
                            \"annotations\": {
                                \"org.opencontainers.image.title\": \"devcontainer-collection.json\"
                            }
                        }
                    ],
                    \"annotations\": {
                        \"com.github.package.type\": \"devcontainer_collection\"
                    }
                }".into())
                .unwrap())
        });

        let manifest = get_latest_oci_manifest(
            "",
            test_oci_registry(),
            test_oci_repository(),
            &client,
            None,
        )
        .await
        .expect("manifest request should succeed");

        assert_eq!(manifest.layers.len(), 1);
        assert_eq!(
            manifest.layers[0].digest,
            "sha256:035e9c9fd9bd61f6d3965fa4bf11f3ddfd2490a8cf324f152c13cc3724d67d09"
        );
    }

    #[gpui::test]
    async fn test_get_oci_blob() {
        #[derive(Debug, Deserialize)]
        struct DeserializableTestStruct {
            foo: String,
        }

        let client = FakeHttpClient::create(|request| async move {
            let host = request.uri().host();
            if host != Some(test_oci_registry()) {
                return Err(anyhow!("Unexpected host: {}", host.unwrap_or_default()));
            }
            let path = request.uri().path();
            if path != format!("/v2/{}/blobs/blobdigest", test_oci_repository()) {
                return Err(anyhow!("Unexpected path: {}", path));
            }
            Ok(http_client::Response::builder()
                .status(200)
                .body(
                    r#"
                    {
                        "foo": "bar"
                    }
                    "#
                    .into(),
                )
                .unwrap())
        });

        let blob: DeserializableTestStruct = get_deserializable_oci_blob(
            "",
            test_oci_registry(),
            test_oci_repository(),
            "blobdigest",
            &client,
        )
        .await
        .expect("blob request should succeed");

        assert_eq!(blob.foo, "bar");
    }

    #[gpui::test]
    async fn test_download_oci_tarball(cx: &mut TestAppContext) {
        cx.executor().allow_parking();
        let fs: Arc<dyn Fs> = FakeFs::new(cx.executor());

        let destination_dir = PathBuf::from("/tmp/extracted");
        fs.create_dir(&destination_dir).await.unwrap();

        let tarball = Arc::new(build_test_tarball().await);

        let client = FakeHttpClient::create(move |request| {
            let tarball = tarball.clone();
            async move {
                let host = request.uri().host();
                if host != Some(test_oci_registry()) {
                    return Err(anyhow!("Unexpected host: {}", host.unwrap_or_default()));
                }
                let path = request.uri().path();
                if path != format!("/v2/{}/blobs/blobdigest", test_oci_repository()) {
                    return Err(anyhow!("Unexpected path: {}", path));
                }
                Ok(http_client::Response::builder()
                    .status(200)
                    .body(tarball.to_vec().into())
                    .unwrap())
            }
        });

        let response = download_oci_tarball(
            "",
            test_oci_registry(),
            test_oci_repository(),
            "blobdigest",
            "header",
            &destination_dir,
            &client,
            &fs,
            None,
        )
        .await;
        assert!(response.is_ok());

        assert_eq!(
            fs.load(&destination_dir.join(".devcontainer/devcontainer.json"))
                .await
                .unwrap(),
            DEVCONTAINER_JSON
        );
    }

    #[gpui::test]
    async fn test_download_oci_tarball_reports_http_errors(cx: &mut TestAppContext) {
        cx.executor().allow_parking();
        let fs: Arc<dyn Fs> = FakeFs::new(cx.executor());

        let destination_dir = PathBuf::from("/tmp/extracted");
        fs.create_dir(&destination_dir).await.unwrap();

        let client = FakeHttpClient::create(|_request| async move {
            Ok(http_client::Response::builder()
                .status(401)
                .body("unauthorized".into())
                .unwrap())
        });

        let response = download_oci_tarball(
            "",
            test_oci_registry(),
            test_oci_repository(),
            "blobdigest",
            "header",
            &destination_dir,
            &client,
            &fs,
            None,
        )
        .await;
        assert!(matches!(
            response,
            Err(DevContainerError::ResourceFetchFailed)
        ));
    }
}