extensions.rs

  1use std::str::FromStr;
  2
  3use anyhow::Context;
  4use chrono::Utc;
  5use sea_orm::sea_query::IntoCondition;
  6use util::ResultExt;
  7
  8use super::*;
  9
 10impl Database {
 11    pub async fn get_extensions_by_ids(
 12        &self,
 13        ids: &[&str],
 14        constraints: Option<&ExtensionVersionConstraints>,
 15    ) -> Result<Vec<ExtensionMetadata>> {
 16        self.transaction(|tx| async move {
 17            let extensions = extension::Entity::find()
 18                .filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
 19                .all(&*tx)
 20                .await?;
 21
 22            let mut max_versions = self
 23                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
 24                .await?;
 25
 26            Ok(extensions
 27                .into_iter()
 28                .filter_map(|extension| {
 29                    let (version, _) = max_versions.remove(&extension.id)?;
 30                    Some(metadata_from_extension_and_version(extension, version))
 31                })
 32                .collect())
 33        })
 34        .await
 35    }
 36
 37    async fn get_latest_versions_for_extensions(
 38        &self,
 39        extensions: &[extension::Model],
 40        constraints: Option<&ExtensionVersionConstraints>,
 41        tx: &DatabaseTransaction,
 42    ) -> Result<HashMap<ExtensionId, (extension_version::Model, Version)>> {
 43        let mut versions = extension_version::Entity::find()
 44            .filter(
 45                extension_version::Column::ExtensionId
 46                    .is_in(extensions.iter().map(|extension| extension.id)),
 47            )
 48            .stream(tx)
 49            .await?;
 50
 51        let mut max_versions =
 52            HashMap::<ExtensionId, (extension_version::Model, Version)>::default();
 53        while let Some(version) = versions.next().await {
 54            let version = version?;
 55            let Some(extension_version) = Version::from_str(&version.version).log_err() else {
 56                continue;
 57            };
 58
 59            if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id)
 60                && max_extension_version > &extension_version
 61            {
 62                continue;
 63            }
 64
 65            if let Some(constraints) = constraints {
 66                if !constraints
 67                    .schema_versions
 68                    .contains(&version.schema_version)
 69                {
 70                    continue;
 71                }
 72
 73                if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
 74                    if let Some(version) = Version::from_str(wasm_api_version).log_err() {
 75                        if !constraints.wasm_api_versions.contains(&version) {
 76                            continue;
 77                        }
 78                    } else {
 79                        continue;
 80                    }
 81                }
 82            }
 83
 84            max_versions.insert(version.extension_id, (version, extension_version));
 85        }
 86
 87        Ok(max_versions)
 88    }
 89
 90    /// Returns all of the versions for the extension with the given ID.
 91    pub async fn get_extension_versions(
 92        &self,
 93        extension_id: &str,
 94    ) -> Result<Vec<ExtensionMetadata>> {
 95        self.transaction(|tx| async move {
 96            let condition = extension::Column::ExternalId
 97                .eq(extension_id)
 98                .into_condition();
 99
100            self.get_extensions_where(condition, None, &tx).await
101        })
102        .await
103    }
104
105    async fn get_extensions_where(
106        &self,
107        condition: Condition,
108        limit: Option<u64>,
109        tx: &DatabaseTransaction,
110    ) -> Result<Vec<ExtensionMetadata>> {
111        let extensions = extension::Entity::find()
112            .inner_join(extension_version::Entity)
113            .select_also(extension_version::Entity)
114            .filter(condition)
115            .order_by_desc(extension::Column::TotalDownloadCount)
116            .order_by_asc(extension::Column::Name)
117            .limit(limit)
118            .all(tx)
119            .await?;
120
121        Ok(extensions
122            .into_iter()
123            .filter_map(|(extension, version)| {
124                Some(metadata_from_extension_and_version(extension, version?))
125            })
126            .collect())
127    }
128
129    pub async fn get_extension(
130        &self,
131        extension_id: &str,
132        constraints: Option<&ExtensionVersionConstraints>,
133    ) -> Result<Option<ExtensionMetadata>> {
134        self.transaction(|tx| async move {
135            let extension = extension::Entity::find()
136                .filter(extension::Column::ExternalId.eq(extension_id))
137                .one(&*tx)
138                .await?
139                .with_context(|| format!("no such extension: {extension_id}"))?;
140
141            let extensions = [extension];
142            let mut versions = self
143                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
144                .await?;
145            let [extension] = extensions;
146
147            Ok(versions.remove(&extension.id).map(|(max_version, _)| {
148                metadata_from_extension_and_version(extension, max_version)
149            }))
150        })
151        .await
152    }
153
154    pub async fn get_extension_version(
155        &self,
156        extension_id: &str,
157        version: &str,
158    ) -> Result<Option<ExtensionMetadata>> {
159        self.transaction(|tx| async move {
160            let extension = extension::Entity::find()
161                .filter(extension::Column::ExternalId.eq(extension_id))
162                .filter(extension_version::Column::Version.eq(version))
163                .inner_join(extension_version::Entity)
164                .select_also(extension_version::Entity)
165                .one(&*tx)
166                .await?;
167
168            Ok(extension.and_then(|(extension, version)| {
169                Some(metadata_from_extension_and_version(extension, version?))
170            }))
171        })
172        .await
173    }
174
175    pub async fn get_known_extension_versions(&self) -> Result<HashMap<String, Vec<String>>> {
176        self.transaction(|tx| async move {
177            let mut extension_external_ids_by_id = HashMap::default();
178
179            let mut rows = extension::Entity::find().stream(&*tx).await?;
180            while let Some(row) = rows.next().await {
181                let row = row?;
182                extension_external_ids_by_id.insert(row.id, row.external_id);
183            }
184            drop(rows);
185
186            let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
187                HashMap::default();
188            let mut rows = extension_version::Entity::find().stream(&*tx).await?;
189            while let Some(row) = rows.next().await {
190                let row = row?;
191
192                let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
193                    continue;
194                };
195
196                let versions = known_versions_by_extension_id
197                    .entry(extension_id.clone())
198                    .or_default();
199                if let Err(ix) = versions.binary_search(&row.version) {
200                    versions.insert(ix, row.version);
201                }
202            }
203            drop(rows);
204
205            Ok(known_versions_by_extension_id)
206        })
207        .await
208    }
209
210    pub async fn insert_extension_versions(
211        &self,
212        versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
213    ) -> Result<()> {
214        self.transaction(|tx| async move {
215            for (external_id, versions) in versions_by_extension_id {
216                if versions.is_empty() {
217                    continue;
218                }
219
220                let latest_version = versions
221                    .iter()
222                    .max_by_key(|version| &version.version)
223                    .unwrap();
224
225                let insert = extension::Entity::insert(extension::ActiveModel {
226                    name: ActiveValue::Set(latest_version.name.clone()),
227                    external_id: ActiveValue::Set((*external_id).to_owned()),
228                    id: ActiveValue::NotSet,
229                    latest_version: ActiveValue::Set(latest_version.version.to_string()),
230                    total_download_count: ActiveValue::NotSet,
231                })
232                .on_conflict(
233                    OnConflict::columns([extension::Column::ExternalId])
234                        .update_column(extension::Column::ExternalId)
235                        .to_owned(),
236                );
237
238                let extension = if tx.support_returning() {
239                    insert.exec_with_returning(&*tx).await?
240                } else {
241                    // Sqlite
242                    insert.exec_without_returning(&*tx).await?;
243                    extension::Entity::find()
244                        .filter(extension::Column::ExternalId.eq(*external_id))
245                        .one(&*tx)
246                        .await?
247                        .context("failed to insert extension")?
248                };
249
250                extension_version::Entity::insert_many(versions.iter().map(|version| {
251                    extension_version::ActiveModel {
252                        extension_id: ActiveValue::Set(extension.id),
253                        published_at: ActiveValue::Set(version.published_at),
254                        version: ActiveValue::Set(version.version.to_string()),
255                        authors: ActiveValue::Set(version.authors.join(", ")),
256                        repository: ActiveValue::Set(version.repository.clone()),
257                        description: ActiveValue::Set(version.description.clone()),
258                        schema_version: ActiveValue::Set(version.schema_version),
259                        wasm_api_version: ActiveValue::Set(version.wasm_api_version.clone()),
260                        provides_themes: ActiveValue::Set(
261                            version.provides.contains(&ExtensionProvides::Themes),
262                        ),
263                        provides_icon_themes: ActiveValue::Set(
264                            version.provides.contains(&ExtensionProvides::IconThemes),
265                        ),
266                        provides_languages: ActiveValue::Set(
267                            version.provides.contains(&ExtensionProvides::Languages),
268                        ),
269                        provides_grammars: ActiveValue::Set(
270                            version.provides.contains(&ExtensionProvides::Grammars),
271                        ),
272                        provides_language_servers: ActiveValue::Set(
273                            version
274                                .provides
275                                .contains(&ExtensionProvides::LanguageServers),
276                        ),
277                        provides_context_servers: ActiveValue::Set(
278                            version
279                                .provides
280                                .contains(&ExtensionProvides::ContextServers),
281                        ),
282                        provides_agent_servers: ActiveValue::Set(
283                            version.provides.contains(&ExtensionProvides::AgentServers),
284                        ),
285                        provides_slash_commands: ActiveValue::Set(
286                            version.provides.contains(&ExtensionProvides::SlashCommands),
287                        ),
288                        provides_indexed_docs_providers: ActiveValue::Set(
289                            version
290                                .provides
291                                .contains(&ExtensionProvides::IndexedDocsProviders),
292                        ),
293                        provides_snippets: ActiveValue::Set(
294                            version.provides.contains(&ExtensionProvides::Snippets),
295                        ),
296                        provides_debug_adapters: ActiveValue::Set(
297                            version.provides.contains(&ExtensionProvides::DebugAdapters),
298                        ),
299                        download_count: ActiveValue::NotSet,
300                    }
301                }))
302                .on_conflict(OnConflict::new().do_nothing().to_owned())
303                .exec_without_returning(&*tx)
304                .await?;
305
306                if let Ok(db_version) = semver::Version::parse(&extension.latest_version)
307                    && db_version >= latest_version.version
308                {
309                    continue;
310                }
311
312                let mut extension = extension.into_active_model();
313                extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
314                extension.name = ActiveValue::set(latest_version.name.clone());
315                extension::Entity::update(extension).exec(&*tx).await?;
316            }
317
318            Ok(())
319        })
320        .await
321    }
322
323    pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
324        self.transaction(|tx| async move {
325            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
326            enum QueryId {
327                Id,
328            }
329
330            let extension_id: Option<ExtensionId> = extension::Entity::find()
331                .filter(extension::Column::ExternalId.eq(extension))
332                .select_only()
333                .column(extension::Column::Id)
334                .into_values::<_, QueryId>()
335                .one(&*tx)
336                .await?;
337            let Some(extension_id) = extension_id else {
338                return Ok(false);
339            };
340
341            extension_version::Entity::update_many()
342                .col_expr(
343                    extension_version::Column::DownloadCount,
344                    extension_version::Column::DownloadCount.into_expr().add(1),
345                )
346                .filter(
347                    extension_version::Column::ExtensionId
348                        .eq(extension_id)
349                        .and(extension_version::Column::Version.eq(version)),
350                )
351                .exec(&*tx)
352                .await?;
353
354            extension::Entity::update_many()
355                .col_expr(
356                    extension::Column::TotalDownloadCount,
357                    extension::Column::TotalDownloadCount.into_expr().add(1),
358                )
359                .filter(extension::Column::Id.eq(extension_id))
360                .exec(&*tx)
361                .await?;
362
363            Ok(true)
364        })
365        .await
366    }
367}
368
369fn metadata_from_extension_and_version(
370    extension: extension::Model,
371    version: extension_version::Model,
372) -> ExtensionMetadata {
373    let provides = version.provides();
374
375    ExtensionMetadata {
376        id: extension.external_id.into(),
377        manifest: cloud_api_types::ExtensionApiManifest {
378            name: extension.name,
379            version: version.version.into(),
380            authors: version
381                .authors
382                .split(',')
383                .map(|author| author.trim().to_string())
384                .collect::<Vec<_>>(),
385            description: Some(version.description),
386            repository: version.repository,
387            schema_version: Some(version.schema_version),
388            wasm_api_version: version.wasm_api_version,
389            provides,
390        },
391
392        published_at: convert_time_to_chrono(version.published_at),
393        download_count: extension.total_download_count as u64,
394    }
395}
396
397pub fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> chrono::DateTime<Utc> {
398    chrono::DateTime::from_naive_utc_and_offset(
399        #[allow(deprecated)]
400        chrono::NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
401        Utc,
402    )
403}