extensions.rs

  1use super::*;
  2
  3impl Database {
  4    pub async fn get_extensions(
  5        &self,
  6        filter: Option<&str>,
  7        max_schema_version: i32,
  8        limit: usize,
  9    ) -> Result<Vec<ExtensionMetadata>> {
 10        self.transaction(|tx| async move {
 11            let mut condition = Condition::all().add(
 12                extension::Column::LatestVersion
 13                    .into_expr()
 14                    .eq(extension_version::Column::Version.into_expr()),
 15            );
 16            if let Some(filter) = filter {
 17                let fuzzy_name_filter = Self::fuzzy_like_string(filter);
 18                condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
 19            }
 20
 21            let extensions = extension::Entity::find()
 22                .inner_join(extension_version::Entity)
 23                .select_also(extension_version::Entity)
 24                .filter(condition)
 25                .filter(extension_version::Column::SchemaVersion.lte(max_schema_version))
 26                .order_by_desc(extension::Column::TotalDownloadCount)
 27                .order_by_asc(extension::Column::Name)
 28                .limit(Some(limit as u64))
 29                .all(&*tx)
 30                .await?;
 31
 32            Ok(extensions
 33                .into_iter()
 34                .filter_map(|(extension, latest_version)| {
 35                    let version = latest_version?;
 36                    Some(ExtensionMetadata {
 37                        id: extension.external_id,
 38                        name: extension.name,
 39                        version: version.version,
 40                        authors: version
 41                            .authors
 42                            .split(',')
 43                            .map(|author| author.trim().to_string())
 44                            .collect::<Vec<_>>(),
 45                        description: version.description,
 46                        repository: version.repository,
 47                        published_at: version.published_at,
 48                        download_count: extension.total_download_count as u64,
 49                    })
 50                })
 51                .collect())
 52        })
 53        .await
 54    }
 55
 56    pub async fn get_extension(&self, extension_id: &str) -> Result<Option<ExtensionMetadata>> {
 57        self.transaction(|tx| async move {
 58            let extension = extension::Entity::find()
 59                .filter(extension::Column::ExternalId.eq(extension_id))
 60                .filter(
 61                    extension::Column::LatestVersion
 62                        .into_expr()
 63                        .eq(extension_version::Column::Version.into_expr()),
 64                )
 65                .inner_join(extension_version::Entity)
 66                .select_also(extension_version::Entity)
 67                .one(&*tx)
 68                .await?;
 69
 70            Ok(extension.and_then(|(extension, latest_version)| {
 71                let version = latest_version?;
 72                Some(ExtensionMetadata {
 73                    id: extension.external_id,
 74                    name: extension.name,
 75                    version: version.version,
 76                    authors: version
 77                        .authors
 78                        .split(',')
 79                        .map(|author| author.trim().to_string())
 80                        .collect::<Vec<_>>(),
 81                    description: version.description,
 82                    repository: version.repository,
 83                    published_at: version.published_at,
 84                    download_count: extension.total_download_count as u64,
 85                })
 86            }))
 87        })
 88        .await
 89    }
 90
 91    pub async fn get_known_extension_versions<'a>(&self) -> Result<HashMap<String, Vec<String>>> {
 92        self.transaction(|tx| async move {
 93            let mut extension_external_ids_by_id = HashMap::default();
 94
 95            let mut rows = extension::Entity::find().stream(&*tx).await?;
 96            while let Some(row) = rows.next().await {
 97                let row = row?;
 98                extension_external_ids_by_id.insert(row.id, row.external_id);
 99            }
100            drop(rows);
101
102            let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
103                HashMap::default();
104            let mut rows = extension_version::Entity::find().stream(&*tx).await?;
105            while let Some(row) = rows.next().await {
106                let row = row?;
107
108                let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
109                    continue;
110                };
111
112                let versions = known_versions_by_extension_id
113                    .entry(extension_id.clone())
114                    .or_default();
115                if let Err(ix) = versions.binary_search(&row.version) {
116                    versions.insert(ix, row.version);
117                }
118            }
119            drop(rows);
120
121            Ok(known_versions_by_extension_id)
122        })
123        .await
124    }
125
126    pub async fn insert_extension_versions(
127        &self,
128        versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
129    ) -> Result<()> {
130        self.transaction(|tx| async move {
131            for (external_id, versions) in versions_by_extension_id {
132                if versions.is_empty() {
133                    continue;
134                }
135
136                let latest_version = versions
137                    .iter()
138                    .max_by_key(|version| &version.version)
139                    .unwrap();
140
141                let insert = extension::Entity::insert(extension::ActiveModel {
142                    name: ActiveValue::Set(latest_version.name.clone()),
143                    external_id: ActiveValue::Set(external_id.to_string()),
144                    id: ActiveValue::NotSet,
145                    latest_version: ActiveValue::Set(latest_version.version.to_string()),
146                    total_download_count: ActiveValue::NotSet,
147                })
148                .on_conflict(
149                    OnConflict::columns([extension::Column::ExternalId])
150                        .update_column(extension::Column::ExternalId)
151                        .to_owned(),
152                );
153
154                let extension = if tx.support_returning() {
155                    insert.exec_with_returning(&*tx).await?
156                } else {
157                    // Sqlite
158                    insert.exec_without_returning(&*tx).await?;
159                    extension::Entity::find()
160                        .filter(extension::Column::ExternalId.eq(*external_id))
161                        .one(&*tx)
162                        .await?
163                        .ok_or_else(|| anyhow!("failed to insert extension"))?
164                };
165
166                extension_version::Entity::insert_many(versions.iter().map(|version| {
167                    extension_version::ActiveModel {
168                        extension_id: ActiveValue::Set(extension.id),
169                        published_at: ActiveValue::Set(version.published_at),
170                        version: ActiveValue::Set(version.version.to_string()),
171                        authors: ActiveValue::Set(version.authors.join(", ")),
172                        repository: ActiveValue::Set(version.repository.clone()),
173                        description: ActiveValue::Set(version.description.clone()),
174                        schema_version: ActiveValue::Set(version.schema_version),
175                        download_count: ActiveValue::NotSet,
176                    }
177                }))
178                .on_conflict(OnConflict::new().do_nothing().to_owned())
179                .exec_without_returning(&*tx)
180                .await?;
181
182                if let Ok(db_version) = semver::Version::parse(&extension.latest_version) {
183                    if db_version >= latest_version.version {
184                        continue;
185                    }
186                }
187
188                let mut extension = extension.into_active_model();
189                extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
190                extension.name = ActiveValue::set(latest_version.name.clone());
191                extension::Entity::update(extension).exec(&*tx).await?;
192            }
193
194            Ok(())
195        })
196        .await
197    }
198
199    pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
200        self.transaction(|tx| async move {
201            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
202            enum QueryId {
203                Id,
204            }
205
206            let extension_id: Option<ExtensionId> = extension::Entity::find()
207                .filter(extension::Column::ExternalId.eq(extension))
208                .select_only()
209                .column(extension::Column::Id)
210                .into_values::<_, QueryId>()
211                .one(&*tx)
212                .await?;
213            let Some(extension_id) = extension_id else {
214                return Ok(false);
215            };
216
217            extension_version::Entity::update_many()
218                .col_expr(
219                    extension_version::Column::DownloadCount,
220                    extension_version::Column::DownloadCount.into_expr().add(1),
221                )
222                .filter(
223                    extension_version::Column::ExtensionId
224                        .eq(extension_id)
225                        .and(extension_version::Column::Version.eq(version)),
226                )
227                .exec(&*tx)
228                .await?;
229
230            extension::Entity::update_many()
231                .col_expr(
232                    extension::Column::TotalDownloadCount,
233                    extension::Column::TotalDownloadCount.into_expr().add(1),
234                )
235                .filter(extension::Column::Id.eq(extension_id))
236                .exec(&*tx)
237                .await?;
238
239            Ok(true)
240        })
241        .await
242    }
243}