extensions.rs

  1use std::str::FromStr;
  2
  3use chrono::Utc;
  4use sea_orm::sea_query::IntoCondition;
  5use util::ResultExt;
  6
  7use super::*;
  8
  9impl Database {
 10    pub async fn get_extensions(
 11        &self,
 12        filter: Option<&str>,
 13        max_schema_version: i32,
 14        limit: usize,
 15    ) -> Result<Vec<ExtensionMetadata>> {
 16        self.transaction(|tx| async move {
 17            let mut condition = Condition::all()
 18                .add(
 19                    extension::Column::LatestVersion
 20                        .into_expr()
 21                        .eq(extension_version::Column::Version.into_expr()),
 22                )
 23                .add(extension_version::Column::SchemaVersion.lte(max_schema_version));
 24            if let Some(filter) = filter {
 25                let fuzzy_name_filter = Self::fuzzy_like_string(filter);
 26                condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
 27            }
 28
 29            self.get_extensions_where(condition, Some(limit as u64), &tx)
 30                .await
 31        })
 32        .await
 33    }
 34
 35    pub async fn get_extensions_by_ids(
 36        &self,
 37        ids: &[&str],
 38        constraints: Option<&ExtensionVersionConstraints>,
 39    ) -> Result<Vec<ExtensionMetadata>> {
 40        self.transaction(|tx| async move {
 41            let extensions = extension::Entity::find()
 42                .filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
 43                .all(&*tx)
 44                .await?;
 45
 46            let mut max_versions = self
 47                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
 48                .await?;
 49
 50            Ok(extensions
 51                .into_iter()
 52                .filter_map(|extension| {
 53                    let (version, _) = max_versions.remove(&extension.id)?;
 54                    Some(metadata_from_extension_and_version(extension, version))
 55                })
 56                .collect())
 57        })
 58        .await
 59    }
 60
 61    async fn get_latest_versions_for_extensions(
 62        &self,
 63        extensions: &[extension::Model],
 64        constraints: Option<&ExtensionVersionConstraints>,
 65        tx: &DatabaseTransaction,
 66    ) -> Result<HashMap<ExtensionId, (extension_version::Model, SemanticVersion)>> {
 67        let mut versions = extension_version::Entity::find()
 68            .filter(
 69                extension_version::Column::ExtensionId
 70                    .is_in(extensions.iter().map(|extension| extension.id)),
 71            )
 72            .stream(tx)
 73            .await?;
 74
 75        let mut max_versions =
 76            HashMap::<ExtensionId, (extension_version::Model, SemanticVersion)>::default();
 77        while let Some(version) = versions.next().await {
 78            let version = version?;
 79            let Some(extension_version) = SemanticVersion::from_str(&version.version).log_err()
 80            else {
 81                continue;
 82            };
 83
 84            if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id) {
 85                if max_extension_version > &extension_version {
 86                    continue;
 87                }
 88            }
 89
 90            if let Some(constraints) = constraints {
 91                if !constraints
 92                    .schema_versions
 93                    .contains(&version.schema_version)
 94                {
 95                    continue;
 96                }
 97
 98                if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
 99                    if let Some(version) = SemanticVersion::from_str(wasm_api_version).log_err() {
100                        if !constraints.wasm_api_versions.contains(&version) {
101                            continue;
102                        }
103                    } else {
104                        continue;
105                    }
106                }
107            }
108
109            max_versions.insert(version.extension_id, (version, extension_version));
110        }
111
112        Ok(max_versions)
113    }
114
115    /// Returns all of the versions for the extension with the given ID.
116    pub async fn get_extension_versions(
117        &self,
118        extension_id: &str,
119    ) -> Result<Vec<ExtensionMetadata>> {
120        self.transaction(|tx| async move {
121            let condition = extension::Column::ExternalId
122                .eq(extension_id)
123                .into_condition();
124
125            self.get_extensions_where(condition, None, &tx).await
126        })
127        .await
128    }
129
130    async fn get_extensions_where(
131        &self,
132        condition: Condition,
133        limit: Option<u64>,
134        tx: &DatabaseTransaction,
135    ) -> Result<Vec<ExtensionMetadata>> {
136        let extensions = extension::Entity::find()
137            .inner_join(extension_version::Entity)
138            .select_also(extension_version::Entity)
139            .filter(condition)
140            .order_by_desc(extension::Column::TotalDownloadCount)
141            .order_by_asc(extension::Column::Name)
142            .limit(limit)
143            .all(tx)
144            .await?;
145
146        Ok(extensions
147            .into_iter()
148            .filter_map(|(extension, version)| {
149                Some(metadata_from_extension_and_version(extension, version?))
150            })
151            .collect())
152    }
153
154    pub async fn get_extension(
155        &self,
156        extension_id: &str,
157        constraints: Option<&ExtensionVersionConstraints>,
158    ) -> Result<Option<ExtensionMetadata>> {
159        self.transaction(|tx| async move {
160            let extension = extension::Entity::find()
161                .filter(extension::Column::ExternalId.eq(extension_id))
162                .one(&*tx)
163                .await?
164                .ok_or_else(|| anyhow!("no such extension: {extension_id}"))?;
165
166            let extensions = [extension];
167            let mut versions = self
168                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
169                .await?;
170            let [extension] = extensions;
171
172            Ok(versions.remove(&extension.id).map(|(max_version, _)| {
173                metadata_from_extension_and_version(extension, max_version)
174            }))
175        })
176        .await
177    }
178
179    pub async fn get_extension_version(
180        &self,
181        extension_id: &str,
182        version: &str,
183    ) -> Result<Option<ExtensionMetadata>> {
184        self.transaction(|tx| async move {
185            let extension = extension::Entity::find()
186                .filter(extension::Column::ExternalId.eq(extension_id))
187                .filter(extension_version::Column::Version.eq(version))
188                .inner_join(extension_version::Entity)
189                .select_also(extension_version::Entity)
190                .one(&*tx)
191                .await?;
192
193            Ok(extension.and_then(|(extension, version)| {
194                Some(metadata_from_extension_and_version(extension, version?))
195            }))
196        })
197        .await
198    }
199
200    pub async fn get_known_extension_versions<'a>(&self) -> Result<HashMap<String, Vec<String>>> {
201        self.transaction(|tx| async move {
202            let mut extension_external_ids_by_id = HashMap::default();
203
204            let mut rows = extension::Entity::find().stream(&*tx).await?;
205            while let Some(row) = rows.next().await {
206                let row = row?;
207                extension_external_ids_by_id.insert(row.id, row.external_id);
208            }
209            drop(rows);
210
211            let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
212                HashMap::default();
213            let mut rows = extension_version::Entity::find().stream(&*tx).await?;
214            while let Some(row) = rows.next().await {
215                let row = row?;
216
217                let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
218                    continue;
219                };
220
221                let versions = known_versions_by_extension_id
222                    .entry(extension_id.clone())
223                    .or_default();
224                if let Err(ix) = versions.binary_search(&row.version) {
225                    versions.insert(ix, row.version);
226                }
227            }
228            drop(rows);
229
230            Ok(known_versions_by_extension_id)
231        })
232        .await
233    }
234
235    pub async fn insert_extension_versions(
236        &self,
237        versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
238    ) -> Result<()> {
239        self.transaction(|tx| async move {
240            for (external_id, versions) in versions_by_extension_id {
241                if versions.is_empty() {
242                    continue;
243                }
244
245                let latest_version = versions
246                    .iter()
247                    .max_by_key(|version| &version.version)
248                    .unwrap();
249
250                let insert = extension::Entity::insert(extension::ActiveModel {
251                    name: ActiveValue::Set(latest_version.name.clone()),
252                    external_id: ActiveValue::Set(external_id.to_string()),
253                    id: ActiveValue::NotSet,
254                    latest_version: ActiveValue::Set(latest_version.version.to_string()),
255                    total_download_count: ActiveValue::NotSet,
256                })
257                .on_conflict(
258                    OnConflict::columns([extension::Column::ExternalId])
259                        .update_column(extension::Column::ExternalId)
260                        .to_owned(),
261                );
262
263                let extension = if tx.support_returning() {
264                    insert.exec_with_returning(&*tx).await?
265                } else {
266                    // Sqlite
267                    insert.exec_without_returning(&*tx).await?;
268                    extension::Entity::find()
269                        .filter(extension::Column::ExternalId.eq(*external_id))
270                        .one(&*tx)
271                        .await?
272                        .ok_or_else(|| anyhow!("failed to insert extension"))?
273                };
274
275                extension_version::Entity::insert_many(versions.iter().map(|version| {
276                    extension_version::ActiveModel {
277                        extension_id: ActiveValue::Set(extension.id),
278                        published_at: ActiveValue::Set(version.published_at),
279                        version: ActiveValue::Set(version.version.to_string()),
280                        authors: ActiveValue::Set(version.authors.join(", ")),
281                        repository: ActiveValue::Set(version.repository.clone()),
282                        description: ActiveValue::Set(version.description.clone()),
283                        schema_version: ActiveValue::Set(version.schema_version),
284                        wasm_api_version: ActiveValue::Set(version.wasm_api_version.clone()),
285                        provides_themes: ActiveValue::Set(
286                            version.provides.contains(&ExtensionProvides::Themes),
287                        ),
288                        provides_icon_themes: ActiveValue::Set(
289                            version.provides.contains(&ExtensionProvides::IconThemes),
290                        ),
291                        provides_languages: ActiveValue::Set(
292                            version.provides.contains(&ExtensionProvides::Languages),
293                        ),
294                        provides_grammars: ActiveValue::Set(
295                            version.provides.contains(&ExtensionProvides::Grammars),
296                        ),
297                        provides_language_servers: ActiveValue::Set(
298                            version
299                                .provides
300                                .contains(&ExtensionProvides::LanguageServers),
301                        ),
302                        provides_context_servers: ActiveValue::Set(
303                            version
304                                .provides
305                                .contains(&ExtensionProvides::ContextServers),
306                        ),
307                        provides_slash_commands: ActiveValue::Set(
308                            version.provides.contains(&ExtensionProvides::SlashCommands),
309                        ),
310                        provides_indexed_docs_providers: ActiveValue::Set(
311                            version
312                                .provides
313                                .contains(&ExtensionProvides::IndexedDocsProviders),
314                        ),
315                        provides_snippets: ActiveValue::Set(
316                            version.provides.contains(&ExtensionProvides::Snippets),
317                        ),
318                        download_count: ActiveValue::NotSet,
319                    }
320                }))
321                .on_conflict(OnConflict::new().do_nothing().to_owned())
322                .exec_without_returning(&*tx)
323                .await?;
324
325                if let Ok(db_version) = semver::Version::parse(&extension.latest_version) {
326                    if db_version >= latest_version.version {
327                        continue;
328                    }
329                }
330
331                let mut extension = extension.into_active_model();
332                extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
333                extension.name = ActiveValue::set(latest_version.name.clone());
334                extension::Entity::update(extension).exec(&*tx).await?;
335            }
336
337            Ok(())
338        })
339        .await
340    }
341
342    pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
343        self.transaction(|tx| async move {
344            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
345            enum QueryId {
346                Id,
347            }
348
349            let extension_id: Option<ExtensionId> = extension::Entity::find()
350                .filter(extension::Column::ExternalId.eq(extension))
351                .select_only()
352                .column(extension::Column::Id)
353                .into_values::<_, QueryId>()
354                .one(&*tx)
355                .await?;
356            let Some(extension_id) = extension_id else {
357                return Ok(false);
358            };
359
360            extension_version::Entity::update_many()
361                .col_expr(
362                    extension_version::Column::DownloadCount,
363                    extension_version::Column::DownloadCount.into_expr().add(1),
364                )
365                .filter(
366                    extension_version::Column::ExtensionId
367                        .eq(extension_id)
368                        .and(extension_version::Column::Version.eq(version)),
369                )
370                .exec(&*tx)
371                .await?;
372
373            extension::Entity::update_many()
374                .col_expr(
375                    extension::Column::TotalDownloadCount,
376                    extension::Column::TotalDownloadCount.into_expr().add(1),
377                )
378                .filter(extension::Column::Id.eq(extension_id))
379                .exec(&*tx)
380                .await?;
381
382            Ok(true)
383        })
384        .await
385    }
386}
387
388fn metadata_from_extension_and_version(
389    extension: extension::Model,
390    version: extension_version::Model,
391) -> ExtensionMetadata {
392    let provides = version.provides();
393
394    ExtensionMetadata {
395        id: extension.external_id.into(),
396        manifest: rpc::ExtensionApiManifest {
397            name: extension.name,
398            version: version.version.into(),
399            authors: version
400                .authors
401                .split(',')
402                .map(|author| author.trim().to_string())
403                .collect::<Vec<_>>(),
404            description: Some(version.description),
405            repository: version.repository,
406            schema_version: Some(version.schema_version),
407            wasm_api_version: version.wasm_api_version,
408            provides,
409        },
410
411        published_at: convert_time_to_chrono(version.published_at),
412        download_count: extension.total_download_count as u64,
413    }
414}
415
416pub fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> chrono::DateTime<Utc> {
417    chrono::DateTime::from_naive_utc_and_offset(
418        #[allow(deprecated)]
419        chrono::NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
420        Utc,
421    )
422}