1use crate::db::ExtensionVersionConstraints;
2use crate::{AppState, Error, Result, db::NewExtensionVersion};
3use anyhow::Context as _;
4use aws_sdk_s3::presigning::PresigningConfig;
5use axum::{
6 Extension, Json, Router,
7 extract::{Path, Query, RawQuery},
8 http::StatusCode,
9 response::Redirect,
10 routing::get,
11};
12use cloud_api_types::{ExtensionApiManifest, GetExtensionsResponse};
13use collections::HashMap;
14use semver::Version as SemanticVersion;
15use serde::Deserialize;
16use std::{sync::Arc, time::Duration};
17use time::PrimitiveDateTime;
18use util::{ResultExt, maybe};
19
20pub fn router() -> Router {
21 Router::new()
22 .route("/extensions", get(get_extensions))
23 .route("/extensions/updates", get(get_extension_updates))
24 .route("/extensions/:extension_id", get(get_extension_versions))
25 .route(
26 "/extensions/:extension_id/download",
27 get(download_latest_extension),
28 )
29 .route(
30 "/extensions/:extension_id/:version/download",
31 get(download_extension),
32 )
33}
34
35const UPSTREAM_EXTENSIONS_URL: &str = "https://cloud.zed.dev/extensions";
36
37async fn get_extensions(RawQuery(query): RawQuery) -> Result<Json<GetExtensionsResponse>> {
38 let upstream_url = match query {
39 Some(query) => format!("{UPSTREAM_EXTENSIONS_URL}?{query}"),
40 None => UPSTREAM_EXTENSIONS_URL.to_string(),
41 };
42
43 let response = reqwest::get(&upstream_url).await.map_err(|error| {
44 tracing::error!(
45 ?error,
46 "failed to proxy request to upstream extensions service"
47 );
48 Error::http(
49 StatusCode::BAD_GATEWAY,
50 "upstream extensions service unavailable".into(),
51 )
52 })?;
53
54 let status = response.status();
55 if !status.is_success() {
56 let upstream_status =
57 StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
58 let body = response.text().await.unwrap_or_default();
59 tracing::error!(
60 status = status.as_u16(),
61 body,
62 "upstream extensions service returned an error"
63 );
64 return Err(Error::http(upstream_status, body));
65 }
66
67 let body: GetExtensionsResponse = response.json().await.map_err(|error| {
68 tracing::error!(
69 ?error,
70 "failed to parse response from upstream extensions service"
71 );
72 Error::http(
73 StatusCode::BAD_GATEWAY,
74 "failed to parse upstream response".into(),
75 )
76 })?;
77
78 Ok(Json(body))
79}
80
81#[derive(Debug, Deserialize)]
82struct GetExtensionUpdatesParams {
83 ids: String,
84 min_schema_version: i32,
85 max_schema_version: i32,
86 min_wasm_api_version: semver::Version,
87 max_wasm_api_version: semver::Version,
88}
89
90async fn get_extension_updates(
91 Extension(app): Extension<Arc<AppState>>,
92 Query(params): Query<GetExtensionUpdatesParams>,
93) -> Result<Json<GetExtensionsResponse>> {
94 let constraints = ExtensionVersionConstraints {
95 schema_versions: params.min_schema_version..=params.max_schema_version,
96 wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
97 };
98
99 let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
100
101 let extensions = app
102 .db
103 .get_extensions_by_ids(&extension_ids, Some(&constraints))
104 .await?;
105
106 Ok(Json(GetExtensionsResponse { data: extensions }))
107}
108
109#[derive(Debug, Deserialize)]
110struct GetExtensionVersionsParams {
111 extension_id: String,
112}
113
114async fn get_extension_versions(
115 Extension(app): Extension<Arc<AppState>>,
116 Path(params): Path<GetExtensionVersionsParams>,
117) -> Result<Json<GetExtensionsResponse>> {
118 let extension_versions = app.db.get_extension_versions(¶ms.extension_id).await?;
119
120 Ok(Json(GetExtensionsResponse {
121 data: extension_versions,
122 }))
123}
124
125#[derive(Debug, Deserialize)]
126struct DownloadLatestExtensionPathParams {
127 extension_id: String,
128}
129
130#[derive(Debug, Deserialize)]
131struct DownloadLatestExtensionQueryParams {
132 min_schema_version: Option<i32>,
133 max_schema_version: Option<i32>,
134 min_wasm_api_version: Option<SemanticVersion>,
135 max_wasm_api_version: Option<SemanticVersion>,
136}
137
138async fn download_latest_extension(
139 Extension(app): Extension<Arc<AppState>>,
140 Path(params): Path<DownloadLatestExtensionPathParams>,
141 Query(query): Query<DownloadLatestExtensionQueryParams>,
142) -> Result<Redirect> {
143 let constraints = maybe!({
144 let min_schema_version = query.min_schema_version?;
145 let max_schema_version = query.max_schema_version?;
146 let min_wasm_api_version = query.min_wasm_api_version?;
147 let max_wasm_api_version = query.max_wasm_api_version?;
148
149 Some(ExtensionVersionConstraints {
150 schema_versions: min_schema_version..=max_schema_version,
151 wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
152 })
153 });
154
155 let extension = app
156 .db
157 .get_extension(¶ms.extension_id, constraints.as_ref())
158 .await?
159 .context("unknown extension")?;
160 download_extension(
161 Extension(app),
162 Path(DownloadExtensionParams {
163 extension_id: params.extension_id,
164 version: extension.manifest.version.to_string(),
165 }),
166 )
167 .await
168}
169
170#[derive(Debug, Deserialize)]
171struct DownloadExtensionParams {
172 extension_id: String,
173 version: String,
174}
175
176async fn download_extension(
177 Extension(app): Extension<Arc<AppState>>,
178 Path(params): Path<DownloadExtensionParams>,
179) -> Result<Redirect> {
180 let Some((blob_store_client, bucket)) = app
181 .blob_store_client
182 .clone()
183 .zip(app.config.blob_store_bucket.clone())
184 else {
185 Err(Error::http(
186 StatusCode::NOT_IMPLEMENTED,
187 "not supported".into(),
188 ))?
189 };
190
191 let DownloadExtensionParams {
192 extension_id,
193 version,
194 } = params;
195
196 let version_exists = app
197 .db
198 .record_extension_download(&extension_id, &version)
199 .await?;
200
201 if !version_exists {
202 Err(Error::http(
203 StatusCode::NOT_FOUND,
204 "unknown extension version".into(),
205 ))?;
206 }
207
208 let url = blob_store_client
209 .get_object()
210 .bucket(bucket)
211 .key(format!(
212 "extensions/{extension_id}/{version}/archive.tar.gz"
213 ))
214 .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
215 .await
216 .context("creating presigned extension download url")?;
217
218 Ok(Redirect::temporary(url.uri()))
219}
220
/// Delay between successive blob-store scans by the background task (5 minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(300);
/// Validity window of a presigned extension download URL (3 minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(180);
223
224pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
225 let Some(blob_store_client) = app_state.blob_store_client.clone() else {
226 log::info!("no blob store client");
227 return;
228 };
229 let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
230 log::info!("no blob store bucket");
231 return;
232 };
233
234 let executor = app_state.executor.clone();
235 executor.spawn_detached({
236 let executor = executor.clone();
237 async move {
238 loop {
239 fetch_extensions_from_blob_store(
240 &blob_store_client,
241 &blob_store_bucket,
242 &app_state,
243 )
244 .await
245 .log_err();
246 executor.sleep(EXTENSION_FETCH_INTERVAL).await;
247 }
248 }
249 });
250}
251
252async fn fetch_extensions_from_blob_store(
253 blob_store_client: &aws_sdk_s3::Client,
254 blob_store_bucket: &String,
255 app_state: &Arc<AppState>,
256) -> anyhow::Result<()> {
257 log::info!("fetching extensions from blob store");
258
259 let mut next_marker = None;
260 let mut published_versions = HashMap::<String, Vec<String>>::default();
261
262 loop {
263 let list = blob_store_client
264 .list_objects()
265 .bucket(blob_store_bucket)
266 .prefix("extensions/")
267 .set_marker(next_marker.clone())
268 .send()
269 .await?;
270 let objects = list.contents.unwrap_or_default();
271 log::info!("fetched {} object(s) from blob store", objects.len());
272
273 for object in &objects {
274 let Some(key) = object.key.as_ref() else {
275 continue;
276 };
277 let mut parts = key.split('/');
278 let Some(_) = parts.next().filter(|part| *part == "extensions") else {
279 continue;
280 };
281 let Some(extension_id) = parts.next() else {
282 continue;
283 };
284 let Some(version) = parts.next() else {
285 continue;
286 };
287 if parts.next() == Some("manifest.json") {
288 published_versions
289 .entry(extension_id.to_owned())
290 .or_default()
291 .push(version.to_owned());
292 }
293 }
294
295 if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
296 next_marker.clone_from(&last_object.key);
297 } else {
298 break;
299 }
300 }
301
302 log::info!("found {} published extensions", published_versions.len());
303
304 let known_versions = app_state.db.get_known_extension_versions().await?;
305
306 let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
307 let empty = Vec::new();
308 for (extension_id, published_versions) in &published_versions {
309 let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
310
311 for published_version in published_versions {
312 if known_versions
313 .binary_search_by_key(&published_version, |known_version| known_version)
314 .is_err()
315 && let Some(extension) = fetch_extension_manifest(
316 blob_store_client,
317 blob_store_bucket,
318 extension_id,
319 published_version,
320 )
321 .await
322 .log_err()
323 {
324 new_versions
325 .entry(extension_id)
326 .or_default()
327 .push(extension);
328 }
329 }
330 }
331
332 app_state
333 .db
334 .insert_extension_versions(&new_versions)
335 .await?;
336
337 log::info!(
338 "fetched {} new extensions from blob store",
339 new_versions.values().map(|v| v.len()).sum::<usize>()
340 );
341
342 Ok(())
343}
344
345async fn fetch_extension_manifest(
346 blob_store_client: &aws_sdk_s3::Client,
347 blob_store_bucket: &String,
348 extension_id: &str,
349 version: &str,
350) -> anyhow::Result<NewExtensionVersion> {
351 let object = blob_store_client
352 .get_object()
353 .bucket(blob_store_bucket)
354 .key(format!("extensions/{extension_id}/{version}/manifest.json"))
355 .send()
356 .await?;
357 let manifest_bytes = object
358 .body
359 .collect()
360 .await
361 .map(|data| data.into_bytes())
362 .with_context(|| {
363 format!("failed to download manifest for extension {extension_id} version {version}")
364 })?
365 .to_vec();
366 let manifest =
367 serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
368 format!(
369 "invalid manifest for extension {extension_id} version {version}: {}",
370 String::from_utf8_lossy(&manifest_bytes)
371 )
372 })?;
373 let published_at = object.last_modified.with_context(|| {
374 format!("missing last modified timestamp for extension {extension_id} version {version}")
375 })?;
376 let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
377 let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
378 let version = semver::Version::parse(&manifest.version).with_context(|| {
379 format!("invalid version for extension {extension_id} version {version}")
380 })?;
381 Ok(NewExtensionVersion {
382 name: manifest.name,
383 version,
384 description: manifest.description.unwrap_or_default(),
385 authors: manifest.authors,
386 repository: manifest.repository,
387 schema_version: manifest.schema_version.unwrap_or(0),
388 wasm_api_version: manifest.wasm_api_version,
389 provides: manifest.provides,
390 published_at,
391 })
392}