extensions.rs

  1use std::str::FromStr;
  2
  3use anyhow::Context;
  4use chrono::Utc;
  5use sea_orm::sea_query::IntoCondition;
  6use util::ResultExt;
  7
  8use super::*;
  9
 10impl Database {
 11    pub async fn get_extensions(
 12        &self,
 13        filter: Option<&str>,
 14        provides_filter: Option<&BTreeSet<ExtensionProvides>>,
 15        max_schema_version: i32,
 16        limit: usize,
 17    ) -> Result<Vec<ExtensionMetadata>> {
 18        self.transaction(|tx| async move {
 19            let mut condition = Condition::all()
 20                .add(
 21                    extension::Column::LatestVersion
 22                        .into_expr()
 23                        .eq(extension_version::Column::Version.into_expr()),
 24                )
 25                .add(extension_version::Column::SchemaVersion.lte(max_schema_version));
 26            if let Some(filter) = filter {
 27                let fuzzy_name_filter = Self::fuzzy_like_string(filter);
 28                condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
 29            }
 30
 31            if let Some(provides_filter) = provides_filter {
 32                condition = apply_provides_filter(condition, provides_filter);
 33            }
 34
 35            self.get_extensions_where(condition, Some(limit as u64), &tx)
 36                .await
 37        })
 38        .await
 39    }
 40
 41    pub async fn get_extensions_by_ids(
 42        &self,
 43        ids: &[&str],
 44        constraints: Option<&ExtensionVersionConstraints>,
 45    ) -> Result<Vec<ExtensionMetadata>> {
 46        self.transaction(|tx| async move {
 47            let extensions = extension::Entity::find()
 48                .filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
 49                .all(&*tx)
 50                .await?;
 51
 52            let mut max_versions = self
 53                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
 54                .await?;
 55
 56            Ok(extensions
 57                .into_iter()
 58                .filter_map(|extension| {
 59                    let (version, _) = max_versions.remove(&extension.id)?;
 60                    Some(metadata_from_extension_and_version(extension, version))
 61                })
 62                .collect())
 63        })
 64        .await
 65    }
 66
 67    async fn get_latest_versions_for_extensions(
 68        &self,
 69        extensions: &[extension::Model],
 70        constraints: Option<&ExtensionVersionConstraints>,
 71        tx: &DatabaseTransaction,
 72    ) -> Result<HashMap<ExtensionId, (extension_version::Model, Version)>> {
 73        let mut versions = extension_version::Entity::find()
 74            .filter(
 75                extension_version::Column::ExtensionId
 76                    .is_in(extensions.iter().map(|extension| extension.id)),
 77            )
 78            .stream(tx)
 79            .await?;
 80
 81        let mut max_versions =
 82            HashMap::<ExtensionId, (extension_version::Model, Version)>::default();
 83        while let Some(version) = versions.next().await {
 84            let version = version?;
 85            let Some(extension_version) = Version::from_str(&version.version).log_err() else {
 86                continue;
 87            };
 88
 89            if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id)
 90                && max_extension_version > &extension_version
 91            {
 92                continue;
 93            }
 94
 95            if let Some(constraints) = constraints {
 96                if !constraints
 97                    .schema_versions
 98                    .contains(&version.schema_version)
 99                {
100                    continue;
101                }
102
103                if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
104                    if let Some(version) = Version::from_str(wasm_api_version).log_err() {
105                        if !constraints.wasm_api_versions.contains(&version) {
106                            continue;
107                        }
108                    } else {
109                        continue;
110                    }
111                }
112            }
113
114            max_versions.insert(version.extension_id, (version, extension_version));
115        }
116
117        Ok(max_versions)
118    }
119
120    /// Returns all of the versions for the extension with the given ID.
121    pub async fn get_extension_versions(
122        &self,
123        extension_id: &str,
124    ) -> Result<Vec<ExtensionMetadata>> {
125        self.transaction(|tx| async move {
126            let condition = extension::Column::ExternalId
127                .eq(extension_id)
128                .into_condition();
129
130            self.get_extensions_where(condition, None, &tx).await
131        })
132        .await
133    }
134
135    async fn get_extensions_where(
136        &self,
137        condition: Condition,
138        limit: Option<u64>,
139        tx: &DatabaseTransaction,
140    ) -> Result<Vec<ExtensionMetadata>> {
141        let extensions = extension::Entity::find()
142            .inner_join(extension_version::Entity)
143            .select_also(extension_version::Entity)
144            .filter(condition)
145            .order_by_desc(extension::Column::TotalDownloadCount)
146            .order_by_asc(extension::Column::Name)
147            .limit(limit)
148            .all(tx)
149            .await?;
150
151        Ok(extensions
152            .into_iter()
153            .filter_map(|(extension, version)| {
154                Some(metadata_from_extension_and_version(extension, version?))
155            })
156            .collect())
157    }
158
159    pub async fn get_extension(
160        &self,
161        extension_id: &str,
162        constraints: Option<&ExtensionVersionConstraints>,
163    ) -> Result<Option<ExtensionMetadata>> {
164        self.transaction(|tx| async move {
165            let extension = extension::Entity::find()
166                .filter(extension::Column::ExternalId.eq(extension_id))
167                .one(&*tx)
168                .await?
169                .with_context(|| format!("no such extension: {extension_id}"))?;
170
171            let extensions = [extension];
172            let mut versions = self
173                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
174                .await?;
175            let [extension] = extensions;
176
177            Ok(versions.remove(&extension.id).map(|(max_version, _)| {
178                metadata_from_extension_and_version(extension, max_version)
179            }))
180        })
181        .await
182    }
183
184    pub async fn get_extension_version(
185        &self,
186        extension_id: &str,
187        version: &str,
188    ) -> Result<Option<ExtensionMetadata>> {
189        self.transaction(|tx| async move {
190            let extension = extension::Entity::find()
191                .filter(extension::Column::ExternalId.eq(extension_id))
192                .filter(extension_version::Column::Version.eq(version))
193                .inner_join(extension_version::Entity)
194                .select_also(extension_version::Entity)
195                .one(&*tx)
196                .await?;
197
198            Ok(extension.and_then(|(extension, version)| {
199                Some(metadata_from_extension_and_version(extension, version?))
200            }))
201        })
202        .await
203    }
204
205    pub async fn get_known_extension_versions(&self) -> Result<HashMap<String, Vec<String>>> {
206        self.transaction(|tx| async move {
207            let mut extension_external_ids_by_id = HashMap::default();
208
209            let mut rows = extension::Entity::find().stream(&*tx).await?;
210            while let Some(row) = rows.next().await {
211                let row = row?;
212                extension_external_ids_by_id.insert(row.id, row.external_id);
213            }
214            drop(rows);
215
216            let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
217                HashMap::default();
218            let mut rows = extension_version::Entity::find().stream(&*tx).await?;
219            while let Some(row) = rows.next().await {
220                let row = row?;
221
222                let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
223                    continue;
224                };
225
226                let versions = known_versions_by_extension_id
227                    .entry(extension_id.clone())
228                    .or_default();
229                if let Err(ix) = versions.binary_search(&row.version) {
230                    versions.insert(ix, row.version);
231                }
232            }
233            drop(rows);
234
235            Ok(known_versions_by_extension_id)
236        })
237        .await
238    }
239
240    pub async fn insert_extension_versions(
241        &self,
242        versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
243    ) -> Result<()> {
244        self.transaction(|tx| async move {
245            for (external_id, versions) in versions_by_extension_id {
246                if versions.is_empty() {
247                    continue;
248                }
249
250                let latest_version = versions
251                    .iter()
252                    .max_by_key(|version| &version.version)
253                    .unwrap();
254
255                let insert = extension::Entity::insert(extension::ActiveModel {
256                    name: ActiveValue::Set(latest_version.name.clone()),
257                    external_id: ActiveValue::Set((*external_id).to_owned()),
258                    id: ActiveValue::NotSet,
259                    latest_version: ActiveValue::Set(latest_version.version.to_string()),
260                    total_download_count: ActiveValue::NotSet,
261                })
262                .on_conflict(
263                    OnConflict::columns([extension::Column::ExternalId])
264                        .update_column(extension::Column::ExternalId)
265                        .to_owned(),
266                );
267
268                let extension = if tx.support_returning() {
269                    insert.exec_with_returning(&*tx).await?
270                } else {
271                    // Sqlite
272                    insert.exec_without_returning(&*tx).await?;
273                    extension::Entity::find()
274                        .filter(extension::Column::ExternalId.eq(*external_id))
275                        .one(&*tx)
276                        .await?
277                        .context("failed to insert extension")?
278                };
279
280                extension_version::Entity::insert_many(versions.iter().map(|version| {
281                    extension_version::ActiveModel {
282                        extension_id: ActiveValue::Set(extension.id),
283                        published_at: ActiveValue::Set(version.published_at),
284                        version: ActiveValue::Set(version.version.to_string()),
285                        authors: ActiveValue::Set(version.authors.join(", ")),
286                        repository: ActiveValue::Set(version.repository.clone()),
287                        description: ActiveValue::Set(version.description.clone()),
288                        schema_version: ActiveValue::Set(version.schema_version),
289                        wasm_api_version: ActiveValue::Set(version.wasm_api_version.clone()),
290                        provides_themes: ActiveValue::Set(
291                            version.provides.contains(&ExtensionProvides::Themes),
292                        ),
293                        provides_icon_themes: ActiveValue::Set(
294                            version.provides.contains(&ExtensionProvides::IconThemes),
295                        ),
296                        provides_languages: ActiveValue::Set(
297                            version.provides.contains(&ExtensionProvides::Languages),
298                        ),
299                        provides_grammars: ActiveValue::Set(
300                            version.provides.contains(&ExtensionProvides::Grammars),
301                        ),
302                        provides_language_servers: ActiveValue::Set(
303                            version
304                                .provides
305                                .contains(&ExtensionProvides::LanguageServers),
306                        ),
307                        provides_context_servers: ActiveValue::Set(
308                            version
309                                .provides
310                                .contains(&ExtensionProvides::ContextServers),
311                        ),
312                        provides_agent_servers: ActiveValue::Set(
313                            version.provides.contains(&ExtensionProvides::AgentServers),
314                        ),
315                        provides_slash_commands: ActiveValue::Set(
316                            version.provides.contains(&ExtensionProvides::SlashCommands),
317                        ),
318                        provides_indexed_docs_providers: ActiveValue::Set(
319                            version
320                                .provides
321                                .contains(&ExtensionProvides::IndexedDocsProviders),
322                        ),
323                        provides_snippets: ActiveValue::Set(
324                            version.provides.contains(&ExtensionProvides::Snippets),
325                        ),
326                        provides_debug_adapters: ActiveValue::Set(
327                            version.provides.contains(&ExtensionProvides::DebugAdapters),
328                        ),
329                        download_count: ActiveValue::NotSet,
330                    }
331                }))
332                .on_conflict(OnConflict::new().do_nothing().to_owned())
333                .exec_without_returning(&*tx)
334                .await?;
335
336                if let Ok(db_version) = semver::Version::parse(&extension.latest_version)
337                    && db_version >= latest_version.version
338                {
339                    continue;
340                }
341
342                let mut extension = extension.into_active_model();
343                extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
344                extension.name = ActiveValue::set(latest_version.name.clone());
345                extension::Entity::update(extension).exec(&*tx).await?;
346            }
347
348            Ok(())
349        })
350        .await
351    }
352
353    pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
354        self.transaction(|tx| async move {
355            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
356            enum QueryId {
357                Id,
358            }
359
360            let extension_id: Option<ExtensionId> = extension::Entity::find()
361                .filter(extension::Column::ExternalId.eq(extension))
362                .select_only()
363                .column(extension::Column::Id)
364                .into_values::<_, QueryId>()
365                .one(&*tx)
366                .await?;
367            let Some(extension_id) = extension_id else {
368                return Ok(false);
369            };
370
371            extension_version::Entity::update_many()
372                .col_expr(
373                    extension_version::Column::DownloadCount,
374                    extension_version::Column::DownloadCount.into_expr().add(1),
375                )
376                .filter(
377                    extension_version::Column::ExtensionId
378                        .eq(extension_id)
379                        .and(extension_version::Column::Version.eq(version)),
380                )
381                .exec(&*tx)
382                .await?;
383
384            extension::Entity::update_many()
385                .col_expr(
386                    extension::Column::TotalDownloadCount,
387                    extension::Column::TotalDownloadCount.into_expr().add(1),
388                )
389                .filter(extension::Column::Id.eq(extension_id))
390                .exec(&*tx)
391                .await?;
392
393            Ok(true)
394        })
395        .await
396    }
397}
398
399fn apply_provides_filter(
400    mut condition: Condition,
401    provides_filter: &BTreeSet<ExtensionProvides>,
402) -> Condition {
403    if provides_filter.contains(&ExtensionProvides::Themes) {
404        condition = condition.add(extension_version::Column::ProvidesThemes.eq(true));
405    }
406
407    if provides_filter.contains(&ExtensionProvides::IconThemes) {
408        condition = condition.add(extension_version::Column::ProvidesIconThemes.eq(true));
409    }
410
411    if provides_filter.contains(&ExtensionProvides::Languages) {
412        condition = condition.add(extension_version::Column::ProvidesLanguages.eq(true));
413    }
414
415    if provides_filter.contains(&ExtensionProvides::Grammars) {
416        condition = condition.add(extension_version::Column::ProvidesGrammars.eq(true));
417    }
418
419    if provides_filter.contains(&ExtensionProvides::LanguageServers) {
420        condition = condition.add(extension_version::Column::ProvidesLanguageServers.eq(true));
421    }
422
423    if provides_filter.contains(&ExtensionProvides::ContextServers) {
424        condition = condition.add(extension_version::Column::ProvidesContextServers.eq(true));
425    }
426
427    if provides_filter.contains(&ExtensionProvides::AgentServers) {
428        condition = condition.add(extension_version::Column::ProvidesAgentServers.eq(true));
429    }
430
431    if provides_filter.contains(&ExtensionProvides::SlashCommands) {
432        condition = condition.add(extension_version::Column::ProvidesSlashCommands.eq(true));
433    }
434
435    if provides_filter.contains(&ExtensionProvides::IndexedDocsProviders) {
436        condition = condition.add(extension_version::Column::ProvidesIndexedDocsProviders.eq(true));
437    }
438
439    if provides_filter.contains(&ExtensionProvides::Snippets) {
440        condition = condition.add(extension_version::Column::ProvidesSnippets.eq(true));
441    }
442
443    if provides_filter.contains(&ExtensionProvides::DebugAdapters) {
444        condition = condition.add(extension_version::Column::ProvidesDebugAdapters.eq(true));
445    }
446
447    condition
448}
449
450fn metadata_from_extension_and_version(
451    extension: extension::Model,
452    version: extension_version::Model,
453) -> ExtensionMetadata {
454    let provides = version.provides();
455
456    ExtensionMetadata {
457        id: extension.external_id.into(),
458        manifest: rpc::ExtensionApiManifest {
459            name: extension.name,
460            version: version.version.into(),
461            authors: version
462                .authors
463                .split(',')
464                .map(|author| author.trim().to_string())
465                .collect::<Vec<_>>(),
466            description: Some(version.description),
467            repository: version.repository,
468            schema_version: Some(version.schema_version),
469            wasm_api_version: version.wasm_api_version,
470            provides,
471        },
472
473        published_at: convert_time_to_chrono(version.published_at),
474        download_count: extension.total_download_count as u64,
475    }
476}
477
478pub fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> chrono::DateTime<Utc> {
479    chrono::DateTime::from_naive_utc_and_offset(
480        #[allow(deprecated)]
481        chrono::NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
482        Utc,
483    )
484}