extensions.rs

  1use std::str::FromStr;
  2
  3use anyhow::Context;
  4use chrono::Utc;
  5use sea_orm::sea_query::IntoCondition;
  6use util::ResultExt;
  7
  8use super::*;
  9
 10impl Database {
 11    pub async fn get_extensions(
 12        &self,
 13        filter: Option<&str>,
 14        provides_filter: Option<&BTreeSet<ExtensionProvides>>,
 15        max_schema_version: i32,
 16        limit: usize,
 17    ) -> Result<Vec<ExtensionMetadata>> {
 18        self.weak_transaction(|tx| async move {
 19            let mut condition = Condition::all()
 20                .add(
 21                    extension::Column::LatestVersion
 22                        .into_expr()
 23                        .eq(extension_version::Column::Version.into_expr()),
 24                )
 25                .add(extension_version::Column::SchemaVersion.lte(max_schema_version));
 26            if let Some(filter) = filter {
 27                let fuzzy_name_filter = Self::fuzzy_like_string(filter);
 28                condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
 29            }
 30
 31            if let Some(provides_filter) = provides_filter {
 32                condition = apply_provides_filter(condition, provides_filter);
 33            }
 34
 35            self.get_extensions_where(condition, Some(limit as u64), &tx)
 36                .await
 37        })
 38        .await
 39    }
 40
 41    pub async fn get_extensions_by_ids(
 42        &self,
 43        ids: &[&str],
 44        constraints: Option<&ExtensionVersionConstraints>,
 45    ) -> Result<Vec<ExtensionMetadata>> {
 46        self.weak_transaction(|tx| async move {
 47            let extensions = extension::Entity::find()
 48                .filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
 49                .all(&*tx)
 50                .await?;
 51
 52            let mut max_versions = self
 53                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
 54                .await?;
 55
 56            Ok(extensions
 57                .into_iter()
 58                .filter_map(|extension| {
 59                    let (version, _) = max_versions.remove(&extension.id)?;
 60                    Some(metadata_from_extension_and_version(extension, version))
 61                })
 62                .collect())
 63        })
 64        .await
 65    }
 66
 67    async fn get_latest_versions_for_extensions(
 68        &self,
 69        extensions: &[extension::Model],
 70        constraints: Option<&ExtensionVersionConstraints>,
 71        tx: &DatabaseTransaction,
 72    ) -> Result<HashMap<ExtensionId, (extension_version::Model, SemanticVersion)>> {
 73        let mut versions = extension_version::Entity::find()
 74            .filter(
 75                extension_version::Column::ExtensionId
 76                    .is_in(extensions.iter().map(|extension| extension.id)),
 77            )
 78            .stream(tx)
 79            .await?;
 80
 81        let mut max_versions =
 82            HashMap::<ExtensionId, (extension_version::Model, SemanticVersion)>::default();
 83        while let Some(version) = versions.next().await {
 84            let version = version?;
 85            let Some(extension_version) = SemanticVersion::from_str(&version.version).log_err()
 86            else {
 87                continue;
 88            };
 89
 90            if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id) {
 91                if max_extension_version > &extension_version {
 92                    continue;
 93                }
 94            }
 95
 96            if let Some(constraints) = constraints {
 97                if !constraints
 98                    .schema_versions
 99                    .contains(&version.schema_version)
100                {
101                    continue;
102                }
103
104                if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
105                    if let Some(version) = SemanticVersion::from_str(wasm_api_version).log_err() {
106                        if !constraints.wasm_api_versions.contains(&version) {
107                            continue;
108                        }
109                    } else {
110                        continue;
111                    }
112                }
113            }
114
115            max_versions.insert(version.extension_id, (version, extension_version));
116        }
117
118        Ok(max_versions)
119    }
120
121    /// Returns all of the versions for the extension with the given ID.
122    pub async fn get_extension_versions(
123        &self,
124        extension_id: &str,
125    ) -> Result<Vec<ExtensionMetadata>> {
126        self.weak_transaction(|tx| async move {
127            let condition = extension::Column::ExternalId
128                .eq(extension_id)
129                .into_condition();
130
131            self.get_extensions_where(condition, None, &tx).await
132        })
133        .await
134    }
135
136    async fn get_extensions_where(
137        &self,
138        condition: Condition,
139        limit: Option<u64>,
140        tx: &DatabaseTransaction,
141    ) -> Result<Vec<ExtensionMetadata>> {
142        let extensions = extension::Entity::find()
143            .inner_join(extension_version::Entity)
144            .select_also(extension_version::Entity)
145            .filter(condition)
146            .order_by_desc(extension::Column::TotalDownloadCount)
147            .order_by_asc(extension::Column::Name)
148            .limit(limit)
149            .all(tx)
150            .await?;
151
152        Ok(extensions
153            .into_iter()
154            .filter_map(|(extension, version)| {
155                Some(metadata_from_extension_and_version(extension, version?))
156            })
157            .collect())
158    }
159
160    pub async fn get_extension(
161        &self,
162        extension_id: &str,
163        constraints: Option<&ExtensionVersionConstraints>,
164    ) -> Result<Option<ExtensionMetadata>> {
165        self.weak_transaction(|tx| async move {
166            let extension = extension::Entity::find()
167                .filter(extension::Column::ExternalId.eq(extension_id))
168                .one(&*tx)
169                .await?
170                .with_context(|| format!("no such extension: {extension_id}"))?;
171
172            let extensions = [extension];
173            let mut versions = self
174                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
175                .await?;
176            let [extension] = extensions;
177
178            Ok(versions.remove(&extension.id).map(|(max_version, _)| {
179                metadata_from_extension_and_version(extension, max_version)
180            }))
181        })
182        .await
183    }
184
185    pub async fn get_extension_version(
186        &self,
187        extension_id: &str,
188        version: &str,
189    ) -> Result<Option<ExtensionMetadata>> {
190        self.weak_transaction(|tx| async move {
191            let extension = extension::Entity::find()
192                .filter(extension::Column::ExternalId.eq(extension_id))
193                .filter(extension_version::Column::Version.eq(version))
194                .inner_join(extension_version::Entity)
195                .select_also(extension_version::Entity)
196                .one(&*tx)
197                .await?;
198
199            Ok(extension.and_then(|(extension, version)| {
200                Some(metadata_from_extension_and_version(extension, version?))
201            }))
202        })
203        .await
204    }
205
206    pub async fn get_known_extension_versions(&self) -> Result<HashMap<String, Vec<String>>> {
207        self.weak_transaction(|tx| async move {
208            let mut extension_external_ids_by_id = HashMap::default();
209
210            let mut rows = extension::Entity::find().stream(&*tx).await?;
211            while let Some(row) = rows.next().await {
212                let row = row?;
213                extension_external_ids_by_id.insert(row.id, row.external_id);
214            }
215            drop(rows);
216
217            let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
218                HashMap::default();
219            let mut rows = extension_version::Entity::find().stream(&*tx).await?;
220            while let Some(row) = rows.next().await {
221                let row = row?;
222
223                let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
224                    continue;
225                };
226
227                let versions = known_versions_by_extension_id
228                    .entry(extension_id.clone())
229                    .or_default();
230                if let Err(ix) = versions.binary_search(&row.version) {
231                    versions.insert(ix, row.version);
232                }
233            }
234            drop(rows);
235
236            Ok(known_versions_by_extension_id)
237        })
238        .await
239    }
240
241    pub async fn insert_extension_versions(
242        &self,
243        versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
244    ) -> Result<()> {
245        self.weak_transaction(|tx| async move {
246            for (external_id, versions) in versions_by_extension_id {
247                if versions.is_empty() {
248                    continue;
249                }
250
251                let latest_version = versions
252                    .iter()
253                    .max_by_key(|version| &version.version)
254                    .unwrap();
255
256                let insert = extension::Entity::insert(extension::ActiveModel {
257                    name: ActiveValue::Set(latest_version.name.clone()),
258                    external_id: ActiveValue::Set(external_id.to_string()),
259                    id: ActiveValue::NotSet,
260                    latest_version: ActiveValue::Set(latest_version.version.to_string()),
261                    total_download_count: ActiveValue::NotSet,
262                })
263                .on_conflict(
264                    OnConflict::columns([extension::Column::ExternalId])
265                        .update_column(extension::Column::ExternalId)
266                        .to_owned(),
267                );
268
269                let extension = if tx.support_returning() {
270                    insert.exec_with_returning(&*tx).await?
271                } else {
272                    // Sqlite
273                    insert.exec_without_returning(&*tx).await?;
274                    extension::Entity::find()
275                        .filter(extension::Column::ExternalId.eq(*external_id))
276                        .one(&*tx)
277                        .await?
278                        .context("failed to insert extension")?
279                };
280
281                extension_version::Entity::insert_many(versions.iter().map(|version| {
282                    extension_version::ActiveModel {
283                        extension_id: ActiveValue::Set(extension.id),
284                        published_at: ActiveValue::Set(version.published_at),
285                        version: ActiveValue::Set(version.version.to_string()),
286                        authors: ActiveValue::Set(version.authors.join(", ")),
287                        repository: ActiveValue::Set(version.repository.clone()),
288                        description: ActiveValue::Set(version.description.clone()),
289                        schema_version: ActiveValue::Set(version.schema_version),
290                        wasm_api_version: ActiveValue::Set(version.wasm_api_version.clone()),
291                        provides_themes: ActiveValue::Set(
292                            version.provides.contains(&ExtensionProvides::Themes),
293                        ),
294                        provides_icon_themes: ActiveValue::Set(
295                            version.provides.contains(&ExtensionProvides::IconThemes),
296                        ),
297                        provides_languages: ActiveValue::Set(
298                            version.provides.contains(&ExtensionProvides::Languages),
299                        ),
300                        provides_grammars: ActiveValue::Set(
301                            version.provides.contains(&ExtensionProvides::Grammars),
302                        ),
303                        provides_language_servers: ActiveValue::Set(
304                            version
305                                .provides
306                                .contains(&ExtensionProvides::LanguageServers),
307                        ),
308                        provides_context_servers: ActiveValue::Set(
309                            version
310                                .provides
311                                .contains(&ExtensionProvides::ContextServers),
312                        ),
313                        provides_slash_commands: ActiveValue::Set(
314                            version.provides.contains(&ExtensionProvides::SlashCommands),
315                        ),
316                        provides_indexed_docs_providers: ActiveValue::Set(
317                            version
318                                .provides
319                                .contains(&ExtensionProvides::IndexedDocsProviders),
320                        ),
321                        provides_snippets: ActiveValue::Set(
322                            version.provides.contains(&ExtensionProvides::Snippets),
323                        ),
324                        download_count: ActiveValue::NotSet,
325                    }
326                }))
327                .on_conflict(OnConflict::new().do_nothing().to_owned())
328                .exec_without_returning(&*tx)
329                .await?;
330
331                if let Ok(db_version) = semver::Version::parse(&extension.latest_version) {
332                    if db_version >= latest_version.version {
333                        continue;
334                    }
335                }
336
337                let mut extension = extension.into_active_model();
338                extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
339                extension.name = ActiveValue::set(latest_version.name.clone());
340                extension::Entity::update(extension).exec(&*tx).await?;
341            }
342
343            Ok(())
344        })
345        .await
346    }
347
348    pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
349        self.weak_transaction(|tx| async move {
350            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
351            enum QueryId {
352                Id,
353            }
354
355            let extension_id: Option<ExtensionId> = extension::Entity::find()
356                .filter(extension::Column::ExternalId.eq(extension))
357                .select_only()
358                .column(extension::Column::Id)
359                .into_values::<_, QueryId>()
360                .one(&*tx)
361                .await?;
362            let Some(extension_id) = extension_id else {
363                return Ok(false);
364            };
365
366            extension_version::Entity::update_many()
367                .col_expr(
368                    extension_version::Column::DownloadCount,
369                    extension_version::Column::DownloadCount.into_expr().add(1),
370                )
371                .filter(
372                    extension_version::Column::ExtensionId
373                        .eq(extension_id)
374                        .and(extension_version::Column::Version.eq(version)),
375                )
376                .exec(&*tx)
377                .await?;
378
379            extension::Entity::update_many()
380                .col_expr(
381                    extension::Column::TotalDownloadCount,
382                    extension::Column::TotalDownloadCount.into_expr().add(1),
383                )
384                .filter(extension::Column::Id.eq(extension_id))
385                .exec(&*tx)
386                .await?;
387
388            Ok(true)
389        })
390        .await
391    }
392}
393
394fn apply_provides_filter(
395    mut condition: Condition,
396    provides_filter: &BTreeSet<ExtensionProvides>,
397) -> Condition {
398    if provides_filter.contains(&ExtensionProvides::Themes) {
399        condition = condition.add(extension_version::Column::ProvidesThemes.eq(true));
400    }
401
402    if provides_filter.contains(&ExtensionProvides::IconThemes) {
403        condition = condition.add(extension_version::Column::ProvidesIconThemes.eq(true));
404    }
405
406    if provides_filter.contains(&ExtensionProvides::Languages) {
407        condition = condition.add(extension_version::Column::ProvidesLanguages.eq(true));
408    }
409
410    if provides_filter.contains(&ExtensionProvides::Grammars) {
411        condition = condition.add(extension_version::Column::ProvidesGrammars.eq(true));
412    }
413
414    if provides_filter.contains(&ExtensionProvides::LanguageServers) {
415        condition = condition.add(extension_version::Column::ProvidesLanguageServers.eq(true));
416    }
417
418    if provides_filter.contains(&ExtensionProvides::ContextServers) {
419        condition = condition.add(extension_version::Column::ProvidesContextServers.eq(true));
420    }
421
422    if provides_filter.contains(&ExtensionProvides::SlashCommands) {
423        condition = condition.add(extension_version::Column::ProvidesSlashCommands.eq(true));
424    }
425
426    if provides_filter.contains(&ExtensionProvides::IndexedDocsProviders) {
427        condition = condition.add(extension_version::Column::ProvidesIndexedDocsProviders.eq(true));
428    }
429
430    if provides_filter.contains(&ExtensionProvides::Snippets) {
431        condition = condition.add(extension_version::Column::ProvidesSnippets.eq(true));
432    }
433
434    condition
435}
436
437fn metadata_from_extension_and_version(
438    extension: extension::Model,
439    version: extension_version::Model,
440) -> ExtensionMetadata {
441    let provides = version.provides();
442
443    ExtensionMetadata {
444        id: extension.external_id.into(),
445        manifest: rpc::ExtensionApiManifest {
446            name: extension.name,
447            version: version.version.into(),
448            authors: version
449                .authors
450                .split(',')
451                .map(|author| author.trim().to_string())
452                .collect::<Vec<_>>(),
453            description: Some(version.description),
454            repository: version.repository,
455            schema_version: Some(version.schema_version),
456            wasm_api_version: version.wasm_api_version,
457            provides,
458        },
459
460        published_at: convert_time_to_chrono(version.published_at),
461        download_count: extension.total_download_count as u64,
462    }
463}
464
465pub fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> chrono::DateTime<Utc> {
466    chrono::DateTime::from_naive_utc_and_offset(
467        #[allow(deprecated)]
468        chrono::NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
469        Utc,
470    )
471}