extensions.rs

  1use super::*;
  2
  3impl Database {
  4    pub async fn get_extensions(
  5        &self,
  6        filter: Option<&str>,
  7        limit: usize,
  8    ) -> Result<Vec<ExtensionMetadata>> {
  9        self.transaction(|tx| async move {
 10            let mut condition = Condition::all();
 11            if let Some(filter) = filter {
 12                let fuzzy_name_filter = Self::fuzzy_like_string(filter);
 13                condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
 14            }
 15
 16            let extensions = extension::Entity::find()
 17                .filter(condition)
 18                .order_by_desc(extension::Column::TotalDownloadCount)
 19                .order_by_asc(extension::Column::Id)
 20                .limit(Some(limit as u64))
 21                .filter(
 22                    extension::Column::LatestVersion
 23                        .into_expr()
 24                        .eq(extension_version::Column::Version.into_expr()),
 25                )
 26                .inner_join(extension_version::Entity)
 27                .select_also(extension_version::Entity)
 28                .all(&*tx)
 29                .await?;
 30
 31            Ok(extensions
 32                .into_iter()
 33                .filter_map(|(extension, latest_version)| {
 34                    let version = latest_version?;
 35                    Some(ExtensionMetadata {
 36                        id: extension.external_id,
 37                        name: extension.name,
 38                        version: version.version,
 39                        authors: version
 40                            .authors
 41                            .split(',')
 42                            .map(|author| author.trim().to_string())
 43                            .collect::<Vec<_>>(),
 44                        repository: version.repository,
 45                        published_at: version.published_at,
 46                        download_count: extension.total_download_count as u64,
 47                    })
 48                })
 49                .collect())
 50        })
 51        .await
 52    }
 53
 54    pub async fn get_known_extension_versions<'a>(&self) -> Result<HashMap<String, Vec<String>>> {
 55        self.transaction(|tx| async move {
 56            let mut extension_external_ids_by_id = HashMap::default();
 57
 58            let mut rows = extension::Entity::find().stream(&*tx).await?;
 59            while let Some(row) = rows.next().await {
 60                let row = row?;
 61                extension_external_ids_by_id.insert(row.id, row.external_id);
 62            }
 63            drop(rows);
 64
 65            let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
 66                HashMap::default();
 67            let mut rows = extension_version::Entity::find().stream(&*tx).await?;
 68            while let Some(row) = rows.next().await {
 69                let row = row?;
 70
 71                let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
 72                    continue;
 73                };
 74
 75                let versions = known_versions_by_extension_id
 76                    .entry(extension_id.clone())
 77                    .or_default();
 78                if let Err(ix) = versions.binary_search(&row.version) {
 79                    versions.insert(ix, row.version);
 80                }
 81            }
 82            drop(rows);
 83
 84            Ok(known_versions_by_extension_id)
 85        })
 86        .await
 87    }
 88
 89    pub async fn insert_extension_versions(
 90        &self,
 91        versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
 92    ) -> Result<()> {
 93        self.transaction(|tx| async move {
 94            for (external_id, versions) in versions_by_extension_id {
 95                if versions.is_empty() {
 96                    continue;
 97                }
 98
 99                let latest_version = versions
100                    .iter()
101                    .max_by_key(|version| &version.version)
102                    .unwrap();
103
104                let insert = extension::Entity::insert(extension::ActiveModel {
105                    name: ActiveValue::Set(latest_version.name.clone()),
106                    external_id: ActiveValue::Set(external_id.to_string()),
107                    id: ActiveValue::NotSet,
108                    latest_version: ActiveValue::Set(latest_version.version.to_string()),
109                    total_download_count: ActiveValue::NotSet,
110                })
111                .on_conflict(
112                    OnConflict::columns([extension::Column::ExternalId])
113                        .update_column(extension::Column::ExternalId)
114                        .to_owned(),
115                );
116
117                let extension = if tx.support_returning() {
118                    insert.exec_with_returning(&*tx).await?
119                } else {
120                    // Sqlite
121                    insert.exec_without_returning(&*tx).await?;
122                    extension::Entity::find()
123                        .filter(extension::Column::ExternalId.eq(*external_id))
124                        .one(&*tx)
125                        .await?
126                        .ok_or_else(|| anyhow!("failed to insert extension"))?
127                };
128
129                extension_version::Entity::insert_many(versions.iter().map(|version| {
130                    extension_version::ActiveModel {
131                        extension_id: ActiveValue::Set(extension.id),
132                        published_at: ActiveValue::Set(version.published_at),
133                        version: ActiveValue::Set(version.version.to_string()),
134                        authors: ActiveValue::Set(version.authors.join(", ")),
135                        repository: ActiveValue::Set(version.repository.clone()),
136                        description: ActiveValue::Set(version.description.clone()),
137                        download_count: ActiveValue::NotSet,
138                    }
139                }))
140                .on_conflict(OnConflict::new().do_nothing().to_owned())
141                .exec_without_returning(&*tx)
142                .await?;
143
144                if let Ok(db_version) = semver::Version::parse(&extension.latest_version) {
145                    if db_version >= latest_version.version {
146                        continue;
147                    }
148                }
149
150                let mut extension = extension.into_active_model();
151                extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
152                extension.name = ActiveValue::set(latest_version.name.clone());
153                extension::Entity::update(extension).exec(&*tx).await?;
154            }
155
156            Ok(())
157        })
158        .await
159    }
160
161    pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
162        self.transaction(|tx| async move {
163            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
164            enum QueryId {
165                Id,
166            }
167
168            let extension_id: Option<ExtensionId> = extension::Entity::find()
169                .filter(extension::Column::ExternalId.eq(extension))
170                .select_only()
171                .column(extension::Column::Id)
172                .into_values::<_, QueryId>()
173                .one(&*tx)
174                .await?;
175            let Some(extension_id) = extension_id else {
176                return Ok(false);
177            };
178
179            extension_version::Entity::update_many()
180                .col_expr(
181                    extension_version::Column::DownloadCount,
182                    extension_version::Column::DownloadCount.into_expr().add(1),
183                )
184                .filter(
185                    extension_version::Column::ExtensionId
186                        .eq(extension_id)
187                        .and(extension_version::Column::Version.eq(version)),
188                )
189                .exec(&*tx)
190                .await?;
191
192            extension::Entity::update_many()
193                .col_expr(
194                    extension::Column::TotalDownloadCount,
195                    extension::Column::TotalDownloadCount.into_expr().add(1),
196                )
197                .filter(extension::Column::Id.eq(extension_id))
198                .exec(&*tx)
199                .await?;
200
201            Ok(true)
202        })
203        .await
204    }
205}