1use std::str::FromStr;
  2
  3use anyhow::Context;
  4use chrono::Utc;
  5use sea_orm::sea_query::IntoCondition;
  6use util::ResultExt;
  7
  8use super::*;
  9
 10impl Database {
 11    pub async fn get_extensions(
 12        &self,
 13        filter: Option<&str>,
 14        provides_filter: Option<&BTreeSet<ExtensionProvides>>,
 15        max_schema_version: i32,
 16        limit: usize,
 17    ) -> Result<Vec<ExtensionMetadata>> {
 18        self.transaction(|tx| async move {
 19            let mut condition = Condition::all()
 20                .add(
 21                    extension::Column::LatestVersion
 22                        .into_expr()
 23                        .eq(extension_version::Column::Version.into_expr()),
 24                )
 25                .add(extension_version::Column::SchemaVersion.lte(max_schema_version));
 26            if let Some(filter) = filter {
 27                let fuzzy_name_filter = Self::fuzzy_like_string(filter);
 28                condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
 29            }
 30
 31            if let Some(provides_filter) = provides_filter {
 32                condition = apply_provides_filter(condition, provides_filter);
 33            }
 34
 35            self.get_extensions_where(condition, Some(limit as u64), &tx)
 36                .await
 37        })
 38        .await
 39    }
 40
 41    pub async fn get_extensions_by_ids(
 42        &self,
 43        ids: &[&str],
 44        constraints: Option<&ExtensionVersionConstraints>,
 45    ) -> Result<Vec<ExtensionMetadata>> {
 46        self.transaction(|tx| async move {
 47            let extensions = extension::Entity::find()
 48                .filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
 49                .all(&*tx)
 50                .await?;
 51
 52            let mut max_versions = self
 53                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
 54                .await?;
 55
 56            Ok(extensions
 57                .into_iter()
 58                .filter_map(|extension| {
 59                    let (version, _) = max_versions.remove(&extension.id)?;
 60                    Some(metadata_from_extension_and_version(extension, version))
 61                })
 62                .collect())
 63        })
 64        .await
 65    }
 66
 67    async fn get_latest_versions_for_extensions(
 68        &self,
 69        extensions: &[extension::Model],
 70        constraints: Option<&ExtensionVersionConstraints>,
 71        tx: &DatabaseTransaction,
 72    ) -> Result<HashMap<ExtensionId, (extension_version::Model, SemanticVersion)>> {
 73        let mut versions = extension_version::Entity::find()
 74            .filter(
 75                extension_version::Column::ExtensionId
 76                    .is_in(extensions.iter().map(|extension| extension.id)),
 77            )
 78            .stream(tx)
 79            .await?;
 80
 81        let mut max_versions =
 82            HashMap::<ExtensionId, (extension_version::Model, SemanticVersion)>::default();
 83        while let Some(version) = versions.next().await {
 84            let version = version?;
 85            let Some(extension_version) = SemanticVersion::from_str(&version.version).log_err()
 86            else {
 87                continue;
 88            };
 89
 90            if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id)
 91                && max_extension_version > &extension_version
 92            {
 93                continue;
 94            }
 95
 96            if let Some(constraints) = constraints {
 97                if !constraints
 98                    .schema_versions
 99                    .contains(&version.schema_version)
100                {
101                    continue;
102                }
103
104                if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
105                    if let Some(version) = SemanticVersion::from_str(wasm_api_version).log_err() {
106                        if !constraints.wasm_api_versions.contains(&version) {
107                            continue;
108                        }
109                    } else {
110                        continue;
111                    }
112                }
113            }
114
115            max_versions.insert(version.extension_id, (version, extension_version));
116        }
117
118        Ok(max_versions)
119    }
120
121    /// Returns all of the versions for the extension with the given ID.
122    pub async fn get_extension_versions(
123        &self,
124        extension_id: &str,
125    ) -> Result<Vec<ExtensionMetadata>> {
126        self.transaction(|tx| async move {
127            let condition = extension::Column::ExternalId
128                .eq(extension_id)
129                .into_condition();
130
131            self.get_extensions_where(condition, None, &tx).await
132        })
133        .await
134    }
135
136    async fn get_extensions_where(
137        &self,
138        condition: Condition,
139        limit: Option<u64>,
140        tx: &DatabaseTransaction,
141    ) -> Result<Vec<ExtensionMetadata>> {
142        let extensions = extension::Entity::find()
143            .inner_join(extension_version::Entity)
144            .select_also(extension_version::Entity)
145            .filter(condition)
146            .order_by_desc(extension::Column::TotalDownloadCount)
147            .order_by_asc(extension::Column::Name)
148            .limit(limit)
149            .all(tx)
150            .await?;
151
152        Ok(extensions
153            .into_iter()
154            .filter_map(|(extension, version)| {
155                Some(metadata_from_extension_and_version(extension, version?))
156            })
157            .collect())
158    }
159
160    pub async fn get_extension(
161        &self,
162        extension_id: &str,
163        constraints: Option<&ExtensionVersionConstraints>,
164    ) -> Result<Option<ExtensionMetadata>> {
165        self.transaction(|tx| async move {
166            let extension = extension::Entity::find()
167                .filter(extension::Column::ExternalId.eq(extension_id))
168                .one(&*tx)
169                .await?
170                .with_context(|| format!("no such extension: {extension_id}"))?;
171
172            let extensions = [extension];
173            let mut versions = self
174                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
175                .await?;
176            let [extension] = extensions;
177
178            Ok(versions.remove(&extension.id).map(|(max_version, _)| {
179                metadata_from_extension_and_version(extension, max_version)
180            }))
181        })
182        .await
183    }
184
185    pub async fn get_extension_version(
186        &self,
187        extension_id: &str,
188        version: &str,
189    ) -> Result<Option<ExtensionMetadata>> {
190        self.transaction(|tx| async move {
191            let extension = extension::Entity::find()
192                .filter(extension::Column::ExternalId.eq(extension_id))
193                .filter(extension_version::Column::Version.eq(version))
194                .inner_join(extension_version::Entity)
195                .select_also(extension_version::Entity)
196                .one(&*tx)
197                .await?;
198
199            Ok(extension.and_then(|(extension, version)| {
200                Some(metadata_from_extension_and_version(extension, version?))
201            }))
202        })
203        .await
204    }
205
206    pub async fn get_known_extension_versions(&self) -> Result<HashMap<String, Vec<String>>> {
207        self.transaction(|tx| async move {
208            let mut extension_external_ids_by_id = HashMap::default();
209
210            let mut rows = extension::Entity::find().stream(&*tx).await?;
211            while let Some(row) = rows.next().await {
212                let row = row?;
213                extension_external_ids_by_id.insert(row.id, row.external_id);
214            }
215            drop(rows);
216
217            let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
218                HashMap::default();
219            let mut rows = extension_version::Entity::find().stream(&*tx).await?;
220            while let Some(row) = rows.next().await {
221                let row = row?;
222
223                let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
224                    continue;
225                };
226
227                let versions = known_versions_by_extension_id
228                    .entry(extension_id.clone())
229                    .or_default();
230                if let Err(ix) = versions.binary_search(&row.version) {
231                    versions.insert(ix, row.version);
232                }
233            }
234            drop(rows);
235
236            Ok(known_versions_by_extension_id)
237        })
238        .await
239    }
240
241    pub async fn insert_extension_versions(
242        &self,
243        versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
244    ) -> Result<()> {
245        self.transaction(|tx| async move {
246            for (external_id, versions) in versions_by_extension_id {
247                if versions.is_empty() {
248                    continue;
249                }
250
251                let latest_version = versions
252                    .iter()
253                    .max_by_key(|version| &version.version)
254                    .unwrap();
255
256                let insert = extension::Entity::insert(extension::ActiveModel {
257                    name: ActiveValue::Set(latest_version.name.clone()),
258                    external_id: ActiveValue::Set((*external_id).to_owned()),
259                    id: ActiveValue::NotSet,
260                    latest_version: ActiveValue::Set(latest_version.version.to_string()),
261                    total_download_count: ActiveValue::NotSet,
262                })
263                .on_conflict(
264                    OnConflict::columns([extension::Column::ExternalId])
265                        .update_column(extension::Column::ExternalId)
266                        .to_owned(),
267                );
268
269                let extension = if tx.support_returning() {
270                    insert.exec_with_returning(&*tx).await?
271                } else {
272                    // Sqlite
273                    insert.exec_without_returning(&*tx).await?;
274                    extension::Entity::find()
275                        .filter(extension::Column::ExternalId.eq(*external_id))
276                        .one(&*tx)
277                        .await?
278                        .context("failed to insert extension")?
279                };
280
281                extension_version::Entity::insert_many(versions.iter().map(|version| {
282                    extension_version::ActiveModel {
283                        extension_id: ActiveValue::Set(extension.id),
284                        published_at: ActiveValue::Set(version.published_at),
285                        version: ActiveValue::Set(version.version.to_string()),
286                        authors: ActiveValue::Set(version.authors.join(", ")),
287                        repository: ActiveValue::Set(version.repository.clone()),
288                        description: ActiveValue::Set(version.description.clone()),
289                        schema_version: ActiveValue::Set(version.schema_version),
290                        wasm_api_version: ActiveValue::Set(version.wasm_api_version.clone()),
291                        provides_themes: ActiveValue::Set(
292                            version.provides.contains(&ExtensionProvides::Themes),
293                        ),
294                        provides_icon_themes: ActiveValue::Set(
295                            version.provides.contains(&ExtensionProvides::IconThemes),
296                        ),
297                        provides_languages: ActiveValue::Set(
298                            version.provides.contains(&ExtensionProvides::Languages),
299                        ),
300                        provides_grammars: ActiveValue::Set(
301                            version.provides.contains(&ExtensionProvides::Grammars),
302                        ),
303                        provides_language_servers: ActiveValue::Set(
304                            version
305                                .provides
306                                .contains(&ExtensionProvides::LanguageServers),
307                        ),
308                        provides_context_servers: ActiveValue::Set(
309                            version
310                                .provides
311                                .contains(&ExtensionProvides::ContextServers),
312                        ),
313                        provides_agent_servers: ActiveValue::Set(
314                            version.provides.contains(&ExtensionProvides::AgentServers),
315                        ),
316                        provides_slash_commands: ActiveValue::Set(
317                            version.provides.contains(&ExtensionProvides::SlashCommands),
318                        ),
319                        provides_indexed_docs_providers: ActiveValue::Set(
320                            version
321                                .provides
322                                .contains(&ExtensionProvides::IndexedDocsProviders),
323                        ),
324                        provides_snippets: ActiveValue::Set(
325                            version.provides.contains(&ExtensionProvides::Snippets),
326                        ),
327                        provides_debug_adapters: ActiveValue::Set(
328                            version.provides.contains(&ExtensionProvides::DebugAdapters),
329                        ),
330                        download_count: ActiveValue::NotSet,
331                    }
332                }))
333                .on_conflict(OnConflict::new().do_nothing().to_owned())
334                .exec_without_returning(&*tx)
335                .await?;
336
337                if let Ok(db_version) = semver::Version::parse(&extension.latest_version)
338                    && db_version >= latest_version.version
339                {
340                    continue;
341                }
342
343                let mut extension = extension.into_active_model();
344                extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
345                extension.name = ActiveValue::set(latest_version.name.clone());
346                extension::Entity::update(extension).exec(&*tx).await?;
347            }
348
349            Ok(())
350        })
351        .await
352    }
353
354    pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
355        self.transaction(|tx| async move {
356            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
357            enum QueryId {
358                Id,
359            }
360
361            let extension_id: Option<ExtensionId> = extension::Entity::find()
362                .filter(extension::Column::ExternalId.eq(extension))
363                .select_only()
364                .column(extension::Column::Id)
365                .into_values::<_, QueryId>()
366                .one(&*tx)
367                .await?;
368            let Some(extension_id) = extension_id else {
369                return Ok(false);
370            };
371
372            extension_version::Entity::update_many()
373                .col_expr(
374                    extension_version::Column::DownloadCount,
375                    extension_version::Column::DownloadCount.into_expr().add(1),
376                )
377                .filter(
378                    extension_version::Column::ExtensionId
379                        .eq(extension_id)
380                        .and(extension_version::Column::Version.eq(version)),
381                )
382                .exec(&*tx)
383                .await?;
384
385            extension::Entity::update_many()
386                .col_expr(
387                    extension::Column::TotalDownloadCount,
388                    extension::Column::TotalDownloadCount.into_expr().add(1),
389                )
390                .filter(extension::Column::Id.eq(extension_id))
391                .exec(&*tx)
392                .await?;
393
394            Ok(true)
395        })
396        .await
397    }
398}
399
400fn apply_provides_filter(
401    mut condition: Condition,
402    provides_filter: &BTreeSet<ExtensionProvides>,
403) -> Condition {
404    if provides_filter.contains(&ExtensionProvides::Themes) {
405        condition = condition.add(extension_version::Column::ProvidesThemes.eq(true));
406    }
407
408    if provides_filter.contains(&ExtensionProvides::IconThemes) {
409        condition = condition.add(extension_version::Column::ProvidesIconThemes.eq(true));
410    }
411
412    if provides_filter.contains(&ExtensionProvides::Languages) {
413        condition = condition.add(extension_version::Column::ProvidesLanguages.eq(true));
414    }
415
416    if provides_filter.contains(&ExtensionProvides::Grammars) {
417        condition = condition.add(extension_version::Column::ProvidesGrammars.eq(true));
418    }
419
420    if provides_filter.contains(&ExtensionProvides::LanguageServers) {
421        condition = condition.add(extension_version::Column::ProvidesLanguageServers.eq(true));
422    }
423
424    if provides_filter.contains(&ExtensionProvides::ContextServers) {
425        condition = condition.add(extension_version::Column::ProvidesContextServers.eq(true));
426    }
427
428    if provides_filter.contains(&ExtensionProvides::AgentServers) {
429        condition = condition.add(extension_version::Column::ProvidesAgentServers.eq(true));
430    }
431
432    if provides_filter.contains(&ExtensionProvides::SlashCommands) {
433        condition = condition.add(extension_version::Column::ProvidesSlashCommands.eq(true));
434    }
435
436    if provides_filter.contains(&ExtensionProvides::IndexedDocsProviders) {
437        condition = condition.add(extension_version::Column::ProvidesIndexedDocsProviders.eq(true));
438    }
439
440    if provides_filter.contains(&ExtensionProvides::Snippets) {
441        condition = condition.add(extension_version::Column::ProvidesSnippets.eq(true));
442    }
443
444    if provides_filter.contains(&ExtensionProvides::DebugAdapters) {
445        condition = condition.add(extension_version::Column::ProvidesDebugAdapters.eq(true));
446    }
447
448    condition
449}
450
451fn metadata_from_extension_and_version(
452    extension: extension::Model,
453    version: extension_version::Model,
454) -> ExtensionMetadata {
455    let provides = version.provides();
456
457    ExtensionMetadata {
458        id: extension.external_id.into(),
459        manifest: rpc::ExtensionApiManifest {
460            name: extension.name,
461            version: version.version.into(),
462            authors: version
463                .authors
464                .split(',')
465                .map(|author| author.trim().to_string())
466                .collect::<Vec<_>>(),
467            description: Some(version.description),
468            repository: version.repository,
469            schema_version: Some(version.schema_version),
470            wasm_api_version: version.wasm_api_version,
471            provides,
472        },
473
474        published_at: convert_time_to_chrono(version.published_at),
475        download_count: extension.total_download_count as u64,
476    }
477}
478
479pub fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> chrono::DateTime<Utc> {
480    chrono::DateTime::from_naive_utc_and_offset(
481        #[allow(deprecated)]
482        chrono::NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
483        Utc,
484    )
485}