1use std::str::FromStr;
2
3use chrono::Utc;
4use sea_orm::sea_query::IntoCondition;
5use util::ResultExt;
6
7use super::*;
8
9impl Database {
10 pub async fn get_extensions(
11 &self,
12 filter: Option<&str>,
13 max_schema_version: i32,
14 limit: usize,
15 ) -> Result<Vec<ExtensionMetadata>> {
16 self.transaction(|tx| async move {
17 let mut condition = Condition::all()
18 .add(
19 extension::Column::LatestVersion
20 .into_expr()
21 .eq(extension_version::Column::Version.into_expr()),
22 )
23 .add(extension_version::Column::SchemaVersion.lte(max_schema_version));
24 if let Some(filter) = filter {
25 let fuzzy_name_filter = Self::fuzzy_like_string(filter);
26 condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
27 }
28
29 self.get_extensions_where(condition, Some(limit as u64), &tx)
30 .await
31 })
32 .await
33 }
34
35 pub async fn get_extensions_by_ids(
36 &self,
37 ids: &[&str],
38 constraints: Option<&ExtensionVersionConstraints>,
39 ) -> Result<Vec<ExtensionMetadata>> {
40 self.transaction(|tx| async move {
41 let extensions = extension::Entity::find()
42 .filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
43 .all(&*tx)
44 .await?;
45
46 let mut max_versions = self
47 .get_latest_versions_for_extensions(&extensions, constraints, &tx)
48 .await?;
49
50 Ok(extensions
51 .into_iter()
52 .filter_map(|extension| {
53 let (version, _) = max_versions.remove(&extension.id)?;
54 Some(metadata_from_extension_and_version(extension, version))
55 })
56 .collect())
57 })
58 .await
59 }
60
61 async fn get_latest_versions_for_extensions(
62 &self,
63 extensions: &[extension::Model],
64 constraints: Option<&ExtensionVersionConstraints>,
65 tx: &DatabaseTransaction,
66 ) -> Result<HashMap<ExtensionId, (extension_version::Model, SemanticVersion)>> {
67 let mut versions = extension_version::Entity::find()
68 .filter(
69 extension_version::Column::ExtensionId
70 .is_in(extensions.iter().map(|extension| extension.id)),
71 )
72 .stream(tx)
73 .await?;
74
75 let mut max_versions =
76 HashMap::<ExtensionId, (extension_version::Model, SemanticVersion)>::default();
77 while let Some(version) = versions.next().await {
78 let version = version?;
79 let Some(extension_version) = SemanticVersion::from_str(&version.version).log_err()
80 else {
81 continue;
82 };
83
84 if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id) {
85 if max_extension_version > &extension_version {
86 continue;
87 }
88 }
89
90 if let Some(constraints) = constraints {
91 if !constraints
92 .schema_versions
93 .contains(&version.schema_version)
94 {
95 continue;
96 }
97
98 if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
99 if let Some(version) = SemanticVersion::from_str(wasm_api_version).log_err() {
100 if !constraints.wasm_api_versions.contains(&version) {
101 continue;
102 }
103 } else {
104 continue;
105 }
106 }
107 }
108
109 max_versions.insert(version.extension_id, (version, extension_version));
110 }
111
112 Ok(max_versions)
113 }
114
115 /// Returns all of the versions for the extension with the given ID.
116 pub async fn get_extension_versions(
117 &self,
118 extension_id: &str,
119 ) -> Result<Vec<ExtensionMetadata>> {
120 self.transaction(|tx| async move {
121 let condition = extension::Column::ExternalId
122 .eq(extension_id)
123 .into_condition();
124
125 self.get_extensions_where(condition, None, &tx).await
126 })
127 .await
128 }
129
130 async fn get_extensions_where(
131 &self,
132 condition: Condition,
133 limit: Option<u64>,
134 tx: &DatabaseTransaction,
135 ) -> Result<Vec<ExtensionMetadata>> {
136 let extensions = extension::Entity::find()
137 .inner_join(extension_version::Entity)
138 .select_also(extension_version::Entity)
139 .filter(condition)
140 .order_by_desc(extension::Column::TotalDownloadCount)
141 .order_by_asc(extension::Column::Name)
142 .limit(limit)
143 .all(tx)
144 .await?;
145
146 Ok(extensions
147 .into_iter()
148 .filter_map(|(extension, version)| {
149 Some(metadata_from_extension_and_version(extension, version?))
150 })
151 .collect())
152 }
153
154 pub async fn get_extension(
155 &self,
156 extension_id: &str,
157 constraints: Option<&ExtensionVersionConstraints>,
158 ) -> Result<Option<ExtensionMetadata>> {
159 self.transaction(|tx| async move {
160 let extension = extension::Entity::find()
161 .filter(extension::Column::ExternalId.eq(extension_id))
162 .one(&*tx)
163 .await?
164 .ok_or_else(|| anyhow!("no such extension: {extension_id}"))?;
165
166 let extensions = [extension];
167 let mut versions = self
168 .get_latest_versions_for_extensions(&extensions, constraints, &tx)
169 .await?;
170 let [extension] = extensions;
171
172 Ok(versions.remove(&extension.id).map(|(max_version, _)| {
173 metadata_from_extension_and_version(extension, max_version)
174 }))
175 })
176 .await
177 }
178
179 pub async fn get_extension_version(
180 &self,
181 extension_id: &str,
182 version: &str,
183 ) -> Result<Option<ExtensionMetadata>> {
184 self.transaction(|tx| async move {
185 let extension = extension::Entity::find()
186 .filter(extension::Column::ExternalId.eq(extension_id))
187 .filter(extension_version::Column::Version.eq(version))
188 .inner_join(extension_version::Entity)
189 .select_also(extension_version::Entity)
190 .one(&*tx)
191 .await?;
192
193 Ok(extension.and_then(|(extension, version)| {
194 Some(metadata_from_extension_and_version(extension, version?))
195 }))
196 })
197 .await
198 }
199
200 pub async fn get_known_extension_versions<'a>(&self) -> Result<HashMap<String, Vec<String>>> {
201 self.transaction(|tx| async move {
202 let mut extension_external_ids_by_id = HashMap::default();
203
204 let mut rows = extension::Entity::find().stream(&*tx).await?;
205 while let Some(row) = rows.next().await {
206 let row = row?;
207 extension_external_ids_by_id.insert(row.id, row.external_id);
208 }
209 drop(rows);
210
211 let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
212 HashMap::default();
213 let mut rows = extension_version::Entity::find().stream(&*tx).await?;
214 while let Some(row) = rows.next().await {
215 let row = row?;
216
217 let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
218 continue;
219 };
220
221 let versions = known_versions_by_extension_id
222 .entry(extension_id.clone())
223 .or_default();
224 if let Err(ix) = versions.binary_search(&row.version) {
225 versions.insert(ix, row.version);
226 }
227 }
228 drop(rows);
229
230 Ok(known_versions_by_extension_id)
231 })
232 .await
233 }
234
235 pub async fn insert_extension_versions(
236 &self,
237 versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
238 ) -> Result<()> {
239 self.transaction(|tx| async move {
240 for (external_id, versions) in versions_by_extension_id {
241 if versions.is_empty() {
242 continue;
243 }
244
245 let latest_version = versions
246 .iter()
247 .max_by_key(|version| &version.version)
248 .unwrap();
249
250 let insert = extension::Entity::insert(extension::ActiveModel {
251 name: ActiveValue::Set(latest_version.name.clone()),
252 external_id: ActiveValue::Set(external_id.to_string()),
253 id: ActiveValue::NotSet,
254 latest_version: ActiveValue::Set(latest_version.version.to_string()),
255 total_download_count: ActiveValue::NotSet,
256 })
257 .on_conflict(
258 OnConflict::columns([extension::Column::ExternalId])
259 .update_column(extension::Column::ExternalId)
260 .to_owned(),
261 );
262
263 let extension = if tx.support_returning() {
264 insert.exec_with_returning(&*tx).await?
265 } else {
266 // Sqlite
267 insert.exec_without_returning(&*tx).await?;
268 extension::Entity::find()
269 .filter(extension::Column::ExternalId.eq(*external_id))
270 .one(&*tx)
271 .await?
272 .ok_or_else(|| anyhow!("failed to insert extension"))?
273 };
274
275 extension_version::Entity::insert_many(versions.iter().map(|version| {
276 extension_version::ActiveModel {
277 extension_id: ActiveValue::Set(extension.id),
278 published_at: ActiveValue::Set(version.published_at),
279 version: ActiveValue::Set(version.version.to_string()),
280 authors: ActiveValue::Set(version.authors.join(", ")),
281 repository: ActiveValue::Set(version.repository.clone()),
282 description: ActiveValue::Set(version.description.clone()),
283 schema_version: ActiveValue::Set(version.schema_version),
284 wasm_api_version: ActiveValue::Set(version.wasm_api_version.clone()),
285 download_count: ActiveValue::NotSet,
286 }
287 }))
288 .on_conflict(OnConflict::new().do_nothing().to_owned())
289 .exec_without_returning(&*tx)
290 .await?;
291
292 if let Ok(db_version) = semver::Version::parse(&extension.latest_version) {
293 if db_version >= latest_version.version {
294 continue;
295 }
296 }
297
298 let mut extension = extension.into_active_model();
299 extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
300 extension.name = ActiveValue::set(latest_version.name.clone());
301 extension::Entity::update(extension).exec(&*tx).await?;
302 }
303
304 Ok(())
305 })
306 .await
307 }
308
309 pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
310 self.transaction(|tx| async move {
311 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
312 enum QueryId {
313 Id,
314 }
315
316 let extension_id: Option<ExtensionId> = extension::Entity::find()
317 .filter(extension::Column::ExternalId.eq(extension))
318 .select_only()
319 .column(extension::Column::Id)
320 .into_values::<_, QueryId>()
321 .one(&*tx)
322 .await?;
323 let Some(extension_id) = extension_id else {
324 return Ok(false);
325 };
326
327 extension_version::Entity::update_many()
328 .col_expr(
329 extension_version::Column::DownloadCount,
330 extension_version::Column::DownloadCount.into_expr().add(1),
331 )
332 .filter(
333 extension_version::Column::ExtensionId
334 .eq(extension_id)
335 .and(extension_version::Column::Version.eq(version)),
336 )
337 .exec(&*tx)
338 .await?;
339
340 extension::Entity::update_many()
341 .col_expr(
342 extension::Column::TotalDownloadCount,
343 extension::Column::TotalDownloadCount.into_expr().add(1),
344 )
345 .filter(extension::Column::Id.eq(extension_id))
346 .exec(&*tx)
347 .await?;
348
349 Ok(true)
350 })
351 .await
352 }
353}
354
355fn metadata_from_extension_and_version(
356 extension: extension::Model,
357 version: extension_version::Model,
358) -> ExtensionMetadata {
359 ExtensionMetadata {
360 id: extension.external_id.into(),
361 manifest: rpc::ExtensionApiManifest {
362 name: extension.name,
363 version: version.version.into(),
364 authors: version
365 .authors
366 .split(',')
367 .map(|author| author.trim().to_string())
368 .collect::<Vec<_>>(),
369 description: Some(version.description),
370 repository: version.repository,
371 schema_version: Some(version.schema_version),
372 wasm_api_version: version.wasm_api_version,
373 },
374
375 published_at: convert_time_to_chrono(version.published_at),
376 download_count: extension.total_download_count as u64,
377 }
378}
379
380pub fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> chrono::DateTime<Utc> {
381 chrono::DateTime::from_naive_utc_and_offset(
382 #[allow(deprecated)]
383 chrono::NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
384 Utc,
385 )
386}