extensions.rs

  1use std::str::FromStr;
  2
  3use chrono::Utc;
  4use sea_orm::sea_query::IntoCondition;
  5use util::ResultExt;
  6
  7use super::*;
  8
  9impl Database {
 10    pub async fn get_extensions(
 11        &self,
 12        filter: Option<&str>,
 13        provides_filter: Option<&BTreeSet<ExtensionProvides>>,
 14        max_schema_version: i32,
 15        limit: usize,
 16    ) -> Result<Vec<ExtensionMetadata>> {
 17        self.transaction(|tx| async move {
 18            let mut condition = Condition::all()
 19                .add(
 20                    extension::Column::LatestVersion
 21                        .into_expr()
 22                        .eq(extension_version::Column::Version.into_expr()),
 23                )
 24                .add(extension_version::Column::SchemaVersion.lte(max_schema_version));
 25            if let Some(filter) = filter {
 26                let fuzzy_name_filter = Self::fuzzy_like_string(filter);
 27                condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
 28            }
 29
 30            if let Some(provides_filter) = provides_filter {
 31                condition = apply_provides_filter(condition, provides_filter);
 32            }
 33
 34            self.get_extensions_where(condition, Some(limit as u64), &tx)
 35                .await
 36        })
 37        .await
 38    }
 39
 40    pub async fn get_extensions_by_ids(
 41        &self,
 42        ids: &[&str],
 43        constraints: Option<&ExtensionVersionConstraints>,
 44    ) -> Result<Vec<ExtensionMetadata>> {
 45        self.transaction(|tx| async move {
 46            let extensions = extension::Entity::find()
 47                .filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
 48                .all(&*tx)
 49                .await?;
 50
 51            let mut max_versions = self
 52                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
 53                .await?;
 54
 55            Ok(extensions
 56                .into_iter()
 57                .filter_map(|extension| {
 58                    let (version, _) = max_versions.remove(&extension.id)?;
 59                    Some(metadata_from_extension_and_version(extension, version))
 60                })
 61                .collect())
 62        })
 63        .await
 64    }
 65
 66    async fn get_latest_versions_for_extensions(
 67        &self,
 68        extensions: &[extension::Model],
 69        constraints: Option<&ExtensionVersionConstraints>,
 70        tx: &DatabaseTransaction,
 71    ) -> Result<HashMap<ExtensionId, (extension_version::Model, SemanticVersion)>> {
 72        let mut versions = extension_version::Entity::find()
 73            .filter(
 74                extension_version::Column::ExtensionId
 75                    .is_in(extensions.iter().map(|extension| extension.id)),
 76            )
 77            .stream(tx)
 78            .await?;
 79
 80        let mut max_versions =
 81            HashMap::<ExtensionId, (extension_version::Model, SemanticVersion)>::default();
 82        while let Some(version) = versions.next().await {
 83            let version = version?;
 84            let Some(extension_version) = SemanticVersion::from_str(&version.version).log_err()
 85            else {
 86                continue;
 87            };
 88
 89            if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id) {
 90                if max_extension_version > &extension_version {
 91                    continue;
 92                }
 93            }
 94
 95            if let Some(constraints) = constraints {
 96                if !constraints
 97                    .schema_versions
 98                    .contains(&version.schema_version)
 99                {
100                    continue;
101                }
102
103                if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
104                    if let Some(version) = SemanticVersion::from_str(wasm_api_version).log_err() {
105                        if !constraints.wasm_api_versions.contains(&version) {
106                            continue;
107                        }
108                    } else {
109                        continue;
110                    }
111                }
112            }
113
114            max_versions.insert(version.extension_id, (version, extension_version));
115        }
116
117        Ok(max_versions)
118    }
119
120    /// Returns all of the versions for the extension with the given ID.
121    pub async fn get_extension_versions(
122        &self,
123        extension_id: &str,
124    ) -> Result<Vec<ExtensionMetadata>> {
125        self.transaction(|tx| async move {
126            let condition = extension::Column::ExternalId
127                .eq(extension_id)
128                .into_condition();
129
130            self.get_extensions_where(condition, None, &tx).await
131        })
132        .await
133    }
134
135    async fn get_extensions_where(
136        &self,
137        condition: Condition,
138        limit: Option<u64>,
139        tx: &DatabaseTransaction,
140    ) -> Result<Vec<ExtensionMetadata>> {
141        let extensions = extension::Entity::find()
142            .inner_join(extension_version::Entity)
143            .select_also(extension_version::Entity)
144            .filter(condition)
145            .order_by_desc(extension::Column::TotalDownloadCount)
146            .order_by_asc(extension::Column::Name)
147            .limit(limit)
148            .all(tx)
149            .await?;
150
151        Ok(extensions
152            .into_iter()
153            .filter_map(|(extension, version)| {
154                Some(metadata_from_extension_and_version(extension, version?))
155            })
156            .collect())
157    }
158
159    pub async fn get_extension(
160        &self,
161        extension_id: &str,
162        constraints: Option<&ExtensionVersionConstraints>,
163    ) -> Result<Option<ExtensionMetadata>> {
164        self.transaction(|tx| async move {
165            let extension = extension::Entity::find()
166                .filter(extension::Column::ExternalId.eq(extension_id))
167                .one(&*tx)
168                .await?
169                .ok_or_else(|| anyhow!("no such extension: {extension_id}"))?;
170
171            let extensions = [extension];
172            let mut versions = self
173                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
174                .await?;
175            let [extension] = extensions;
176
177            Ok(versions.remove(&extension.id).map(|(max_version, _)| {
178                metadata_from_extension_and_version(extension, max_version)
179            }))
180        })
181        .await
182    }
183
184    pub async fn get_extension_version(
185        &self,
186        extension_id: &str,
187        version: &str,
188    ) -> Result<Option<ExtensionMetadata>> {
189        self.transaction(|tx| async move {
190            let extension = extension::Entity::find()
191                .filter(extension::Column::ExternalId.eq(extension_id))
192                .filter(extension_version::Column::Version.eq(version))
193                .inner_join(extension_version::Entity)
194                .select_also(extension_version::Entity)
195                .one(&*tx)
196                .await?;
197
198            Ok(extension.and_then(|(extension, version)| {
199                Some(metadata_from_extension_and_version(extension, version?))
200            }))
201        })
202        .await
203    }
204
205    pub async fn get_known_extension_versions(&self) -> Result<HashMap<String, Vec<String>>> {
206        self.transaction(|tx| async move {
207            let mut extension_external_ids_by_id = HashMap::default();
208
209            let mut rows = extension::Entity::find().stream(&*tx).await?;
210            while let Some(row) = rows.next().await {
211                let row = row?;
212                extension_external_ids_by_id.insert(row.id, row.external_id);
213            }
214            drop(rows);
215
216            let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
217                HashMap::default();
218            let mut rows = extension_version::Entity::find().stream(&*tx).await?;
219            while let Some(row) = rows.next().await {
220                let row = row?;
221
222                let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
223                    continue;
224                };
225
226                let versions = known_versions_by_extension_id
227                    .entry(extension_id.clone())
228                    .or_default();
229                if let Err(ix) = versions.binary_search(&row.version) {
230                    versions.insert(ix, row.version);
231                }
232            }
233            drop(rows);
234
235            Ok(known_versions_by_extension_id)
236        })
237        .await
238    }
239
240    pub async fn insert_extension_versions(
241        &self,
242        versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
243    ) -> Result<()> {
244        self.transaction(|tx| async move {
245            for (external_id, versions) in versions_by_extension_id {
246                if versions.is_empty() {
247                    continue;
248                }
249
250                let latest_version = versions
251                    .iter()
252                    .max_by_key(|version| &version.version)
253                    .unwrap();
254
255                let insert = extension::Entity::insert(extension::ActiveModel {
256                    name: ActiveValue::Set(latest_version.name.clone()),
257                    external_id: ActiveValue::Set(external_id.to_string()),
258                    id: ActiveValue::NotSet,
259                    latest_version: ActiveValue::Set(latest_version.version.to_string()),
260                    total_download_count: ActiveValue::NotSet,
261                })
262                .on_conflict(
263                    OnConflict::columns([extension::Column::ExternalId])
264                        .update_column(extension::Column::ExternalId)
265                        .to_owned(),
266                );
267
268                let extension = if tx.support_returning() {
269                    insert.exec_with_returning(&*tx).await?
270                } else {
271                    // Sqlite
272                    insert.exec_without_returning(&*tx).await?;
273                    extension::Entity::find()
274                        .filter(extension::Column::ExternalId.eq(*external_id))
275                        .one(&*tx)
276                        .await?
277                        .ok_or_else(|| anyhow!("failed to insert extension"))?
278                };
279
280                extension_version::Entity::insert_many(versions.iter().map(|version| {
281                    extension_version::ActiveModel {
282                        extension_id: ActiveValue::Set(extension.id),
283                        published_at: ActiveValue::Set(version.published_at),
284                        version: ActiveValue::Set(version.version.to_string()),
285                        authors: ActiveValue::Set(version.authors.join(", ")),
286                        repository: ActiveValue::Set(version.repository.clone()),
287                        description: ActiveValue::Set(version.description.clone()),
288                        schema_version: ActiveValue::Set(version.schema_version),
289                        wasm_api_version: ActiveValue::Set(version.wasm_api_version.clone()),
290                        provides_themes: ActiveValue::Set(
291                            version.provides.contains(&ExtensionProvides::Themes),
292                        ),
293                        provides_icon_themes: ActiveValue::Set(
294                            version.provides.contains(&ExtensionProvides::IconThemes),
295                        ),
296                        provides_languages: ActiveValue::Set(
297                            version.provides.contains(&ExtensionProvides::Languages),
298                        ),
299                        provides_grammars: ActiveValue::Set(
300                            version.provides.contains(&ExtensionProvides::Grammars),
301                        ),
302                        provides_language_servers: ActiveValue::Set(
303                            version
304                                .provides
305                                .contains(&ExtensionProvides::LanguageServers),
306                        ),
307                        provides_context_servers: ActiveValue::Set(
308                            version
309                                .provides
310                                .contains(&ExtensionProvides::ContextServers),
311                        ),
312                        provides_slash_commands: ActiveValue::Set(
313                            version.provides.contains(&ExtensionProvides::SlashCommands),
314                        ),
315                        provides_indexed_docs_providers: ActiveValue::Set(
316                            version
317                                .provides
318                                .contains(&ExtensionProvides::IndexedDocsProviders),
319                        ),
320                        provides_snippets: ActiveValue::Set(
321                            version.provides.contains(&ExtensionProvides::Snippets),
322                        ),
323                        download_count: ActiveValue::NotSet,
324                    }
325                }))
326                .on_conflict(OnConflict::new().do_nothing().to_owned())
327                .exec_without_returning(&*tx)
328                .await?;
329
330                if let Ok(db_version) = semver::Version::parse(&extension.latest_version) {
331                    if db_version >= latest_version.version {
332                        continue;
333                    }
334                }
335
336                let mut extension = extension.into_active_model();
337                extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
338                extension.name = ActiveValue::set(latest_version.name.clone());
339                extension::Entity::update(extension).exec(&*tx).await?;
340            }
341
342            Ok(())
343        })
344        .await
345    }
346
347    pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
348        self.transaction(|tx| async move {
349            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
350            enum QueryId {
351                Id,
352            }
353
354            let extension_id: Option<ExtensionId> = extension::Entity::find()
355                .filter(extension::Column::ExternalId.eq(extension))
356                .select_only()
357                .column(extension::Column::Id)
358                .into_values::<_, QueryId>()
359                .one(&*tx)
360                .await?;
361            let Some(extension_id) = extension_id else {
362                return Ok(false);
363            };
364
365            extension_version::Entity::update_many()
366                .col_expr(
367                    extension_version::Column::DownloadCount,
368                    extension_version::Column::DownloadCount.into_expr().add(1),
369                )
370                .filter(
371                    extension_version::Column::ExtensionId
372                        .eq(extension_id)
373                        .and(extension_version::Column::Version.eq(version)),
374                )
375                .exec(&*tx)
376                .await?;
377
378            extension::Entity::update_many()
379                .col_expr(
380                    extension::Column::TotalDownloadCount,
381                    extension::Column::TotalDownloadCount.into_expr().add(1),
382                )
383                .filter(extension::Column::Id.eq(extension_id))
384                .exec(&*tx)
385                .await?;
386
387            Ok(true)
388        })
389        .await
390    }
391}
392
393fn apply_provides_filter(
394    mut condition: Condition,
395    provides_filter: &BTreeSet<ExtensionProvides>,
396) -> Condition {
397    if provides_filter.contains(&ExtensionProvides::Themes) {
398        condition = condition.add(extension_version::Column::ProvidesThemes.eq(true));
399    }
400
401    if provides_filter.contains(&ExtensionProvides::IconThemes) {
402        condition = condition.add(extension_version::Column::ProvidesIconThemes.eq(true));
403    }
404
405    if provides_filter.contains(&ExtensionProvides::Languages) {
406        condition = condition.add(extension_version::Column::ProvidesLanguages.eq(true));
407    }
408
409    if provides_filter.contains(&ExtensionProvides::Grammars) {
410        condition = condition.add(extension_version::Column::ProvidesGrammars.eq(true));
411    }
412
413    if provides_filter.contains(&ExtensionProvides::LanguageServers) {
414        condition = condition.add(extension_version::Column::ProvidesLanguageServers.eq(true));
415    }
416
417    if provides_filter.contains(&ExtensionProvides::ContextServers) {
418        condition = condition.add(extension_version::Column::ProvidesContextServers.eq(true));
419    }
420
421    if provides_filter.contains(&ExtensionProvides::SlashCommands) {
422        condition = condition.add(extension_version::Column::ProvidesSlashCommands.eq(true));
423    }
424
425    if provides_filter.contains(&ExtensionProvides::IndexedDocsProviders) {
426        condition = condition.add(extension_version::Column::ProvidesIndexedDocsProviders.eq(true));
427    }
428
429    if provides_filter.contains(&ExtensionProvides::Snippets) {
430        condition = condition.add(extension_version::Column::ProvidesSnippets.eq(true));
431    }
432
433    condition
434}
435
436fn metadata_from_extension_and_version(
437    extension: extension::Model,
438    version: extension_version::Model,
439) -> ExtensionMetadata {
440    let provides = version.provides();
441
442    ExtensionMetadata {
443        id: extension.external_id.into(),
444        manifest: rpc::ExtensionApiManifest {
445            name: extension.name,
446            version: version.version.into(),
447            authors: version
448                .authors
449                .split(',')
450                .map(|author| author.trim().to_string())
451                .collect::<Vec<_>>(),
452            description: Some(version.description),
453            repository: version.repository,
454            schema_version: Some(version.schema_version),
455            wasm_api_version: version.wasm_api_version,
456            provides,
457        },
458
459        published_at: convert_time_to_chrono(version.published_at),
460        download_count: extension.total_download_count as u64,
461    }
462}
463
464pub fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> chrono::DateTime<Utc> {
465    chrono::DateTime::from_naive_utc_and_offset(
466        #[allow(deprecated)]
467        chrono::NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
468        Utc,
469    )
470}