extensions.rs

  1use std::str::FromStr;
  2
  3use chrono::Utc;
  4use sea_orm::sea_query::IntoCondition;
  5use util::ResultExt;
  6
  7use super::*;
  8
  9impl Database {
 10    pub async fn get_extensions(
 11        &self,
 12        filter: Option<&str>,
 13        max_schema_version: i32,
 14        limit: usize,
 15    ) -> Result<Vec<ExtensionMetadata>> {
 16        self.transaction(|tx| async move {
 17            let mut condition = Condition::all()
 18                .add(
 19                    extension::Column::LatestVersion
 20                        .into_expr()
 21                        .eq(extension_version::Column::Version.into_expr()),
 22                )
 23                .add(extension_version::Column::SchemaVersion.lte(max_schema_version));
 24            if let Some(filter) = filter {
 25                let fuzzy_name_filter = Self::fuzzy_like_string(filter);
 26                condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
 27            }
 28
 29            self.get_extensions_where(condition, Some(limit as u64), &tx)
 30                .await
 31        })
 32        .await
 33    }
 34
 35    pub async fn get_extensions_by_ids(
 36        &self,
 37        ids: &[&str],
 38        constraints: Option<&ExtensionVersionConstraints>,
 39    ) -> Result<Vec<ExtensionMetadata>> {
 40        self.transaction(|tx| async move {
 41            let extensions = extension::Entity::find()
 42                .filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
 43                .all(&*tx)
 44                .await?;
 45
 46            let mut max_versions = self
 47                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
 48                .await?;
 49
 50            Ok(extensions
 51                .into_iter()
 52                .filter_map(|extension| {
 53                    let (version, _) = max_versions.remove(&extension.id)?;
 54                    Some(metadata_from_extension_and_version(extension, version))
 55                })
 56                .collect())
 57        })
 58        .await
 59    }
 60
 61    async fn get_latest_versions_for_extensions(
 62        &self,
 63        extensions: &[extension::Model],
 64        constraints: Option<&ExtensionVersionConstraints>,
 65        tx: &DatabaseTransaction,
 66    ) -> Result<HashMap<ExtensionId, (extension_version::Model, SemanticVersion)>> {
 67        let mut versions = extension_version::Entity::find()
 68            .filter(
 69                extension_version::Column::ExtensionId
 70                    .is_in(extensions.iter().map(|extension| extension.id)),
 71            )
 72            .stream(tx)
 73            .await?;
 74
 75        let mut max_versions =
 76            HashMap::<ExtensionId, (extension_version::Model, SemanticVersion)>::default();
 77        while let Some(version) = versions.next().await {
 78            let version = version?;
 79            let Some(extension_version) = SemanticVersion::from_str(&version.version).log_err()
 80            else {
 81                continue;
 82            };
 83
 84            if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id) {
 85                if max_extension_version > &extension_version {
 86                    continue;
 87                }
 88            }
 89
 90            if let Some(constraints) = constraints {
 91                if !constraints
 92                    .schema_versions
 93                    .contains(&version.schema_version)
 94                {
 95                    continue;
 96                }
 97
 98                if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
 99                    if let Some(version) = SemanticVersion::from_str(wasm_api_version).log_err() {
100                        if !constraints.wasm_api_versions.contains(&version) {
101                            continue;
102                        }
103                    } else {
104                        continue;
105                    }
106                }
107            }
108
109            max_versions.insert(version.extension_id, (version, extension_version));
110        }
111
112        Ok(max_versions)
113    }
114
115    /// Returns all of the versions for the extension with the given ID.
116    pub async fn get_extension_versions(
117        &self,
118        extension_id: &str,
119    ) -> Result<Vec<ExtensionMetadata>> {
120        self.transaction(|tx| async move {
121            let condition = extension::Column::ExternalId
122                .eq(extension_id)
123                .into_condition();
124
125            self.get_extensions_where(condition, None, &tx).await
126        })
127        .await
128    }
129
130    async fn get_extensions_where(
131        &self,
132        condition: Condition,
133        limit: Option<u64>,
134        tx: &DatabaseTransaction,
135    ) -> Result<Vec<ExtensionMetadata>> {
136        let extensions = extension::Entity::find()
137            .inner_join(extension_version::Entity)
138            .select_also(extension_version::Entity)
139            .filter(condition)
140            .order_by_desc(extension::Column::TotalDownloadCount)
141            .order_by_asc(extension::Column::Name)
142            .limit(limit)
143            .all(tx)
144            .await?;
145
146        Ok(extensions
147            .into_iter()
148            .filter_map(|(extension, version)| {
149                Some(metadata_from_extension_and_version(extension, version?))
150            })
151            .collect())
152    }
153
154    pub async fn get_extension(
155        &self,
156        extension_id: &str,
157        constraints: Option<&ExtensionVersionConstraints>,
158    ) -> Result<Option<ExtensionMetadata>> {
159        self.transaction(|tx| async move {
160            let extension = extension::Entity::find()
161                .filter(extension::Column::ExternalId.eq(extension_id))
162                .one(&*tx)
163                .await?
164                .ok_or_else(|| anyhow!("no such extension: {extension_id}"))?;
165
166            let extensions = [extension];
167            let mut versions = self
168                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
169                .await?;
170            let [extension] = extensions;
171
172            Ok(versions.remove(&extension.id).map(|(max_version, _)| {
173                metadata_from_extension_and_version(extension, max_version)
174            }))
175        })
176        .await
177    }
178
179    pub async fn get_extension_version(
180        &self,
181        extension_id: &str,
182        version: &str,
183    ) -> Result<Option<ExtensionMetadata>> {
184        self.transaction(|tx| async move {
185            let extension = extension::Entity::find()
186                .filter(extension::Column::ExternalId.eq(extension_id))
187                .filter(extension_version::Column::Version.eq(version))
188                .inner_join(extension_version::Entity)
189                .select_also(extension_version::Entity)
190                .one(&*tx)
191                .await?;
192
193            Ok(extension.and_then(|(extension, version)| {
194                Some(metadata_from_extension_and_version(extension, version?))
195            }))
196        })
197        .await
198    }
199
200    pub async fn get_known_extension_versions<'a>(&self) -> Result<HashMap<String, Vec<String>>> {
201        self.transaction(|tx| async move {
202            let mut extension_external_ids_by_id = HashMap::default();
203
204            let mut rows = extension::Entity::find().stream(&*tx).await?;
205            while let Some(row) = rows.next().await {
206                let row = row?;
207                extension_external_ids_by_id.insert(row.id, row.external_id);
208            }
209            drop(rows);
210
211            let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
212                HashMap::default();
213            let mut rows = extension_version::Entity::find().stream(&*tx).await?;
214            while let Some(row) = rows.next().await {
215                let row = row?;
216
217                let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
218                    continue;
219                };
220
221                let versions = known_versions_by_extension_id
222                    .entry(extension_id.clone())
223                    .or_default();
224                if let Err(ix) = versions.binary_search(&row.version) {
225                    versions.insert(ix, row.version);
226                }
227            }
228            drop(rows);
229
230            Ok(known_versions_by_extension_id)
231        })
232        .await
233    }
234
235    pub async fn insert_extension_versions(
236        &self,
237        versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
238    ) -> Result<()> {
239        self.transaction(|tx| async move {
240            for (external_id, versions) in versions_by_extension_id {
241                if versions.is_empty() {
242                    continue;
243                }
244
245                let latest_version = versions
246                    .iter()
247                    .max_by_key(|version| &version.version)
248                    .unwrap();
249
250                let insert = extension::Entity::insert(extension::ActiveModel {
251                    name: ActiveValue::Set(latest_version.name.clone()),
252                    external_id: ActiveValue::Set(external_id.to_string()),
253                    id: ActiveValue::NotSet,
254                    latest_version: ActiveValue::Set(latest_version.version.to_string()),
255                    total_download_count: ActiveValue::NotSet,
256                })
257                .on_conflict(
258                    OnConflict::columns([extension::Column::ExternalId])
259                        .update_column(extension::Column::ExternalId)
260                        .to_owned(),
261                );
262
263                let extension = if tx.support_returning() {
264                    insert.exec_with_returning(&*tx).await?
265                } else {
266                    // Sqlite
267                    insert.exec_without_returning(&*tx).await?;
268                    extension::Entity::find()
269                        .filter(extension::Column::ExternalId.eq(*external_id))
270                        .one(&*tx)
271                        .await?
272                        .ok_or_else(|| anyhow!("failed to insert extension"))?
273                };
274
275                extension_version::Entity::insert_many(versions.iter().map(|version| {
276                    extension_version::ActiveModel {
277                        extension_id: ActiveValue::Set(extension.id),
278                        published_at: ActiveValue::Set(version.published_at),
279                        version: ActiveValue::Set(version.version.to_string()),
280                        authors: ActiveValue::Set(version.authors.join(", ")),
281                        repository: ActiveValue::Set(version.repository.clone()),
282                        description: ActiveValue::Set(version.description.clone()),
283                        schema_version: ActiveValue::Set(version.schema_version),
284                        wasm_api_version: ActiveValue::Set(version.wasm_api_version.clone()),
285                        download_count: ActiveValue::NotSet,
286                    }
287                }))
288                .on_conflict(OnConflict::new().do_nothing().to_owned())
289                .exec_without_returning(&*tx)
290                .await?;
291
292                if let Ok(db_version) = semver::Version::parse(&extension.latest_version) {
293                    if db_version >= latest_version.version {
294                        continue;
295                    }
296                }
297
298                let mut extension = extension.into_active_model();
299                extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
300                extension.name = ActiveValue::set(latest_version.name.clone());
301                extension::Entity::update(extension).exec(&*tx).await?;
302            }
303
304            Ok(())
305        })
306        .await
307    }
308
309    pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
310        self.transaction(|tx| async move {
311            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
312            enum QueryId {
313                Id,
314            }
315
316            let extension_id: Option<ExtensionId> = extension::Entity::find()
317                .filter(extension::Column::ExternalId.eq(extension))
318                .select_only()
319                .column(extension::Column::Id)
320                .into_values::<_, QueryId>()
321                .one(&*tx)
322                .await?;
323            let Some(extension_id) = extension_id else {
324                return Ok(false);
325            };
326
327            extension_version::Entity::update_many()
328                .col_expr(
329                    extension_version::Column::DownloadCount,
330                    extension_version::Column::DownloadCount.into_expr().add(1),
331                )
332                .filter(
333                    extension_version::Column::ExtensionId
334                        .eq(extension_id)
335                        .and(extension_version::Column::Version.eq(version)),
336                )
337                .exec(&*tx)
338                .await?;
339
340            extension::Entity::update_many()
341                .col_expr(
342                    extension::Column::TotalDownloadCount,
343                    extension::Column::TotalDownloadCount.into_expr().add(1),
344                )
345                .filter(extension::Column::Id.eq(extension_id))
346                .exec(&*tx)
347                .await?;
348
349            Ok(true)
350        })
351        .await
352    }
353}
354
355fn metadata_from_extension_and_version(
356    extension: extension::Model,
357    version: extension_version::Model,
358) -> ExtensionMetadata {
359    ExtensionMetadata {
360        id: extension.external_id.into(),
361        manifest: rpc::ExtensionApiManifest {
362            name: extension.name,
363            version: version.version.into(),
364            authors: version
365                .authors
366                .split(',')
367                .map(|author| author.trim().to_string())
368                .collect::<Vec<_>>(),
369            description: Some(version.description),
370            repository: version.repository,
371            schema_version: Some(version.schema_version),
372            wasm_api_version: version.wasm_api_version,
373        },
374
375        published_at: convert_time_to_chrono(version.published_at),
376        download_count: extension.total_download_count as u64,
377    }
378}
379
380pub fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> chrono::DateTime<Utc> {
381    chrono::DateTime::from_naive_utc_and_offset(
382        #[allow(deprecated)]
383        chrono::NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
384        Utc,
385    )
386}