extensions.rs

  1use super::*;
  2
  3impl Database {
  4    pub async fn get_extensions(
  5        &self,
  6        filter: Option<&str>,
  7        limit: usize,
  8    ) -> Result<Vec<ExtensionMetadata>> {
  9        self.transaction(|tx| async move {
 10            let mut condition = Condition::all();
 11            if let Some(filter) = filter {
 12                let fuzzy_name_filter = Self::fuzzy_like_string(filter);
 13                condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
 14            }
 15
 16            let extensions = extension::Entity::find()
 17                .filter(condition)
 18                .order_by_desc(extension::Column::TotalDownloadCount)
 19                .order_by_asc(extension::Column::Name)
 20                .limit(Some(limit as u64))
 21                .filter(
 22                    extension::Column::LatestVersion
 23                        .into_expr()
 24                        .eq(extension_version::Column::Version.into_expr()),
 25                )
 26                .inner_join(extension_version::Entity)
 27                .select_also(extension_version::Entity)
 28                .all(&*tx)
 29                .await?;
 30
 31            Ok(extensions
 32                .into_iter()
 33                .filter_map(|(extension, latest_version)| {
 34                    let version = latest_version?;
 35                    Some(ExtensionMetadata {
 36                        id: extension.external_id,
 37                        name: extension.name,
 38                        version: version.version,
 39                        authors: version
 40                            .authors
 41                            .split(',')
 42                            .map(|author| author.trim().to_string())
 43                            .collect::<Vec<_>>(),
 44                        description: version.description,
 45                        repository: version.repository,
 46                        published_at: version.published_at,
 47                        download_count: extension.total_download_count as u64,
 48                    })
 49                })
 50                .collect())
 51        })
 52        .await
 53    }
 54
 55    pub async fn get_known_extension_versions<'a>(&self) -> Result<HashMap<String, Vec<String>>> {
 56        self.transaction(|tx| async move {
 57            let mut extension_external_ids_by_id = HashMap::default();
 58
 59            let mut rows = extension::Entity::find().stream(&*tx).await?;
 60            while let Some(row) = rows.next().await {
 61                let row = row?;
 62                extension_external_ids_by_id.insert(row.id, row.external_id);
 63            }
 64            drop(rows);
 65
 66            let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
 67                HashMap::default();
 68            let mut rows = extension_version::Entity::find().stream(&*tx).await?;
 69            while let Some(row) = rows.next().await {
 70                let row = row?;
 71
 72                let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
 73                    continue;
 74                };
 75
 76                let versions = known_versions_by_extension_id
 77                    .entry(extension_id.clone())
 78                    .or_default();
 79                if let Err(ix) = versions.binary_search(&row.version) {
 80                    versions.insert(ix, row.version);
 81                }
 82            }
 83            drop(rows);
 84
 85            Ok(known_versions_by_extension_id)
 86        })
 87        .await
 88    }
 89
 90    pub async fn insert_extension_versions(
 91        &self,
 92        versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
 93    ) -> Result<()> {
 94        self.transaction(|tx| async move {
 95            for (external_id, versions) in versions_by_extension_id {
 96                if versions.is_empty() {
 97                    continue;
 98                }
 99
100                let latest_version = versions
101                    .iter()
102                    .max_by_key(|version| &version.version)
103                    .unwrap();
104
105                let insert = extension::Entity::insert(extension::ActiveModel {
106                    name: ActiveValue::Set(latest_version.name.clone()),
107                    external_id: ActiveValue::Set(external_id.to_string()),
108                    id: ActiveValue::NotSet,
109                    latest_version: ActiveValue::Set(latest_version.version.to_string()),
110                    total_download_count: ActiveValue::NotSet,
111                })
112                .on_conflict(
113                    OnConflict::columns([extension::Column::ExternalId])
114                        .update_column(extension::Column::ExternalId)
115                        .to_owned(),
116                );
117
118                let extension = if tx.support_returning() {
119                    insert.exec_with_returning(&*tx).await?
120                } else {
121                    // Sqlite
122                    insert.exec_without_returning(&*tx).await?;
123                    extension::Entity::find()
124                        .filter(extension::Column::ExternalId.eq(*external_id))
125                        .one(&*tx)
126                        .await?
127                        .ok_or_else(|| anyhow!("failed to insert extension"))?
128                };
129
130                extension_version::Entity::insert_many(versions.iter().map(|version| {
131                    extension_version::ActiveModel {
132                        extension_id: ActiveValue::Set(extension.id),
133                        published_at: ActiveValue::Set(version.published_at),
134                        version: ActiveValue::Set(version.version.to_string()),
135                        authors: ActiveValue::Set(version.authors.join(", ")),
136                        repository: ActiveValue::Set(version.repository.clone()),
137                        description: ActiveValue::Set(version.description.clone()),
138                        download_count: ActiveValue::NotSet,
139                    }
140                }))
141                .on_conflict(OnConflict::new().do_nothing().to_owned())
142                .exec_without_returning(&*tx)
143                .await?;
144
145                if let Ok(db_version) = semver::Version::parse(&extension.latest_version) {
146                    if db_version >= latest_version.version {
147                        continue;
148                    }
149                }
150
151                let mut extension = extension.into_active_model();
152                extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
153                extension.name = ActiveValue::set(latest_version.name.clone());
154                extension::Entity::update(extension).exec(&*tx).await?;
155            }
156
157            Ok(())
158        })
159        .await
160    }
161
162    pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
163        self.transaction(|tx| async move {
164            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
165            enum QueryId {
166                Id,
167            }
168
169            let extension_id: Option<ExtensionId> = extension::Entity::find()
170                .filter(extension::Column::ExternalId.eq(extension))
171                .select_only()
172                .column(extension::Column::Id)
173                .into_values::<_, QueryId>()
174                .one(&*tx)
175                .await?;
176            let Some(extension_id) = extension_id else {
177                return Ok(false);
178            };
179
180            extension_version::Entity::update_many()
181                .col_expr(
182                    extension_version::Column::DownloadCount,
183                    extension_version::Column::DownloadCount.into_expr().add(1),
184                )
185                .filter(
186                    extension_version::Column::ExtensionId
187                        .eq(extension_id)
188                        .and(extension_version::Column::Version.eq(version)),
189                )
190                .exec(&*tx)
191                .await?;
192
193            extension::Entity::update_many()
194                .col_expr(
195                    extension::Column::TotalDownloadCount,
196                    extension::Column::TotalDownloadCount.into_expr().add(1),
197                )
198                .filter(extension::Column::Id.eq(extension_id))
199                .exec(&*tx)
200                .await?;
201
202            Ok(true)
203        })
204        .await
205    }
206}