1use std::str::FromStr;
2
3use anyhow::Context;
4use chrono::Utc;
5use sea_orm::sea_query::IntoCondition;
6use util::ResultExt;
7
8use super::*;
9
10impl Database {
11 pub async fn get_extensions(
12 &self,
13 filter: Option<&str>,
14 provides_filter: Option<&BTreeSet<ExtensionProvides>>,
15 max_schema_version: i32,
16 limit: usize,
17 ) -> Result<Vec<ExtensionMetadata>> {
18 self.transaction(|tx| async move {
19 let mut condition = Condition::all()
20 .add(
21 extension::Column::LatestVersion
22 .into_expr()
23 .eq(extension_version::Column::Version.into_expr()),
24 )
25 .add(extension_version::Column::SchemaVersion.lte(max_schema_version));
26 if let Some(filter) = filter {
27 let fuzzy_name_filter = Self::fuzzy_like_string(filter);
28 condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
29 }
30
31 if let Some(provides_filter) = provides_filter {
32 condition = apply_provides_filter(condition, provides_filter);
33 }
34
35 self.get_extensions_where(condition, Some(limit as u64), &tx)
36 .await
37 })
38 .await
39 }
40
41 pub async fn get_extensions_by_ids(
42 &self,
43 ids: &[&str],
44 constraints: Option<&ExtensionVersionConstraints>,
45 ) -> Result<Vec<ExtensionMetadata>> {
46 self.transaction(|tx| async move {
47 let extensions = extension::Entity::find()
48 .filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
49 .all(&*tx)
50 .await?;
51
52 let mut max_versions = self
53 .get_latest_versions_for_extensions(&extensions, constraints, &tx)
54 .await?;
55
56 Ok(extensions
57 .into_iter()
58 .filter_map(|extension| {
59 let (version, _) = max_versions.remove(&extension.id)?;
60 Some(metadata_from_extension_and_version(extension, version))
61 })
62 .collect())
63 })
64 .await
65 }
66
67 async fn get_latest_versions_for_extensions(
68 &self,
69 extensions: &[extension::Model],
70 constraints: Option<&ExtensionVersionConstraints>,
71 tx: &DatabaseTransaction,
72 ) -> Result<HashMap<ExtensionId, (extension_version::Model, SemanticVersion)>> {
73 let mut versions = extension_version::Entity::find()
74 .filter(
75 extension_version::Column::ExtensionId
76 .is_in(extensions.iter().map(|extension| extension.id)),
77 )
78 .stream(tx)
79 .await?;
80
81 let mut max_versions =
82 HashMap::<ExtensionId, (extension_version::Model, SemanticVersion)>::default();
83 while let Some(version) = versions.next().await {
84 let version = version?;
85 let Some(extension_version) = SemanticVersion::from_str(&version.version).log_err()
86 else {
87 continue;
88 };
89
90 if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id)
91 && max_extension_version > &extension_version
92 {
93 continue;
94 }
95
96 if let Some(constraints) = constraints {
97 if !constraints
98 .schema_versions
99 .contains(&version.schema_version)
100 {
101 continue;
102 }
103
104 if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
105 if let Some(version) = SemanticVersion::from_str(wasm_api_version).log_err() {
106 if !constraints.wasm_api_versions.contains(&version) {
107 continue;
108 }
109 } else {
110 continue;
111 }
112 }
113 }
114
115 max_versions.insert(version.extension_id, (version, extension_version));
116 }
117
118 Ok(max_versions)
119 }
120
121 /// Returns all of the versions for the extension with the given ID.
122 pub async fn get_extension_versions(
123 &self,
124 extension_id: &str,
125 ) -> Result<Vec<ExtensionMetadata>> {
126 self.transaction(|tx| async move {
127 let condition = extension::Column::ExternalId
128 .eq(extension_id)
129 .into_condition();
130
131 self.get_extensions_where(condition, None, &tx).await
132 })
133 .await
134 }
135
136 async fn get_extensions_where(
137 &self,
138 condition: Condition,
139 limit: Option<u64>,
140 tx: &DatabaseTransaction,
141 ) -> Result<Vec<ExtensionMetadata>> {
142 let extensions = extension::Entity::find()
143 .inner_join(extension_version::Entity)
144 .select_also(extension_version::Entity)
145 .filter(condition)
146 .order_by_desc(extension::Column::TotalDownloadCount)
147 .order_by_asc(extension::Column::Name)
148 .limit(limit)
149 .all(tx)
150 .await?;
151
152 Ok(extensions
153 .into_iter()
154 .filter_map(|(extension, version)| {
155 Some(metadata_from_extension_and_version(extension, version?))
156 })
157 .collect())
158 }
159
160 pub async fn get_extension(
161 &self,
162 extension_id: &str,
163 constraints: Option<&ExtensionVersionConstraints>,
164 ) -> Result<Option<ExtensionMetadata>> {
165 self.transaction(|tx| async move {
166 let extension = extension::Entity::find()
167 .filter(extension::Column::ExternalId.eq(extension_id))
168 .one(&*tx)
169 .await?
170 .with_context(|| format!("no such extension: {extension_id}"))?;
171
172 let extensions = [extension];
173 let mut versions = self
174 .get_latest_versions_for_extensions(&extensions, constraints, &tx)
175 .await?;
176 let [extension] = extensions;
177
178 Ok(versions.remove(&extension.id).map(|(max_version, _)| {
179 metadata_from_extension_and_version(extension, max_version)
180 }))
181 })
182 .await
183 }
184
185 pub async fn get_extension_version(
186 &self,
187 extension_id: &str,
188 version: &str,
189 ) -> Result<Option<ExtensionMetadata>> {
190 self.transaction(|tx| async move {
191 let extension = extension::Entity::find()
192 .filter(extension::Column::ExternalId.eq(extension_id))
193 .filter(extension_version::Column::Version.eq(version))
194 .inner_join(extension_version::Entity)
195 .select_also(extension_version::Entity)
196 .one(&*tx)
197 .await?;
198
199 Ok(extension.and_then(|(extension, version)| {
200 Some(metadata_from_extension_and_version(extension, version?))
201 }))
202 })
203 .await
204 }
205
206 pub async fn get_known_extension_versions(&self) -> Result<HashMap<String, Vec<String>>> {
207 self.transaction(|tx| async move {
208 let mut extension_external_ids_by_id = HashMap::default();
209
210 let mut rows = extension::Entity::find().stream(&*tx).await?;
211 while let Some(row) = rows.next().await {
212 let row = row?;
213 extension_external_ids_by_id.insert(row.id, row.external_id);
214 }
215 drop(rows);
216
217 let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
218 HashMap::default();
219 let mut rows = extension_version::Entity::find().stream(&*tx).await?;
220 while let Some(row) = rows.next().await {
221 let row = row?;
222
223 let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
224 continue;
225 };
226
227 let versions = known_versions_by_extension_id
228 .entry(extension_id.clone())
229 .or_default();
230 if let Err(ix) = versions.binary_search(&row.version) {
231 versions.insert(ix, row.version);
232 }
233 }
234 drop(rows);
235
236 Ok(known_versions_by_extension_id)
237 })
238 .await
239 }
240
241 pub async fn insert_extension_versions(
242 &self,
243 versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
244 ) -> Result<()> {
245 self.transaction(|tx| async move {
246 for (external_id, versions) in versions_by_extension_id {
247 if versions.is_empty() {
248 continue;
249 }
250
251 let latest_version = versions
252 .iter()
253 .max_by_key(|version| &version.version)
254 .unwrap();
255
256 let insert = extension::Entity::insert(extension::ActiveModel {
257 name: ActiveValue::Set(latest_version.name.clone()),
258 external_id: ActiveValue::Set(external_id.to_string()),
259 id: ActiveValue::NotSet,
260 latest_version: ActiveValue::Set(latest_version.version.to_string()),
261 total_download_count: ActiveValue::NotSet,
262 })
263 .on_conflict(
264 OnConflict::columns([extension::Column::ExternalId])
265 .update_column(extension::Column::ExternalId)
266 .to_owned(),
267 );
268
269 let extension = if tx.support_returning() {
270 insert.exec_with_returning(&*tx).await?
271 } else {
272 // Sqlite
273 insert.exec_without_returning(&*tx).await?;
274 extension::Entity::find()
275 .filter(extension::Column::ExternalId.eq(*external_id))
276 .one(&*tx)
277 .await?
278 .context("failed to insert extension")?
279 };
280
281 extension_version::Entity::insert_many(versions.iter().map(|version| {
282 extension_version::ActiveModel {
283 extension_id: ActiveValue::Set(extension.id),
284 published_at: ActiveValue::Set(version.published_at),
285 version: ActiveValue::Set(version.version.to_string()),
286 authors: ActiveValue::Set(version.authors.join(", ")),
287 repository: ActiveValue::Set(version.repository.clone()),
288 description: ActiveValue::Set(version.description.clone()),
289 schema_version: ActiveValue::Set(version.schema_version),
290 wasm_api_version: ActiveValue::Set(version.wasm_api_version.clone()),
291 provides_themes: ActiveValue::Set(
292 version.provides.contains(&ExtensionProvides::Themes),
293 ),
294 provides_icon_themes: ActiveValue::Set(
295 version.provides.contains(&ExtensionProvides::IconThemes),
296 ),
297 provides_languages: ActiveValue::Set(
298 version.provides.contains(&ExtensionProvides::Languages),
299 ),
300 provides_grammars: ActiveValue::Set(
301 version.provides.contains(&ExtensionProvides::Grammars),
302 ),
303 provides_language_servers: ActiveValue::Set(
304 version
305 .provides
306 .contains(&ExtensionProvides::LanguageServers),
307 ),
308 provides_context_servers: ActiveValue::Set(
309 version
310 .provides
311 .contains(&ExtensionProvides::ContextServers),
312 ),
313 provides_slash_commands: ActiveValue::Set(
314 version.provides.contains(&ExtensionProvides::SlashCommands),
315 ),
316 provides_indexed_docs_providers: ActiveValue::Set(
317 version
318 .provides
319 .contains(&ExtensionProvides::IndexedDocsProviders),
320 ),
321 provides_snippets: ActiveValue::Set(
322 version.provides.contains(&ExtensionProvides::Snippets),
323 ),
324 provides_debug_adapters: ActiveValue::Set(
325 version.provides.contains(&ExtensionProvides::DebugAdapters),
326 ),
327 download_count: ActiveValue::NotSet,
328 }
329 }))
330 .on_conflict(OnConflict::new().do_nothing().to_owned())
331 .exec_without_returning(&*tx)
332 .await?;
333
334 if let Ok(db_version) = semver::Version::parse(&extension.latest_version)
335 && db_version >= latest_version.version
336 {
337 continue;
338 }
339
340 let mut extension = extension.into_active_model();
341 extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
342 extension.name = ActiveValue::set(latest_version.name.clone());
343 extension::Entity::update(extension).exec(&*tx).await?;
344 }
345
346 Ok(())
347 })
348 .await
349 }
350
351 pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
352 self.transaction(|tx| async move {
353 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
354 enum QueryId {
355 Id,
356 }
357
358 let extension_id: Option<ExtensionId> = extension::Entity::find()
359 .filter(extension::Column::ExternalId.eq(extension))
360 .select_only()
361 .column(extension::Column::Id)
362 .into_values::<_, QueryId>()
363 .one(&*tx)
364 .await?;
365 let Some(extension_id) = extension_id else {
366 return Ok(false);
367 };
368
369 extension_version::Entity::update_many()
370 .col_expr(
371 extension_version::Column::DownloadCount,
372 extension_version::Column::DownloadCount.into_expr().add(1),
373 )
374 .filter(
375 extension_version::Column::ExtensionId
376 .eq(extension_id)
377 .and(extension_version::Column::Version.eq(version)),
378 )
379 .exec(&*tx)
380 .await?;
381
382 extension::Entity::update_many()
383 .col_expr(
384 extension::Column::TotalDownloadCount,
385 extension::Column::TotalDownloadCount.into_expr().add(1),
386 )
387 .filter(extension::Column::Id.eq(extension_id))
388 .exec(&*tx)
389 .await?;
390
391 Ok(true)
392 })
393 .await
394 }
395}
396
397fn apply_provides_filter(
398 mut condition: Condition,
399 provides_filter: &BTreeSet<ExtensionProvides>,
400) -> Condition {
401 if provides_filter.contains(&ExtensionProvides::Themes) {
402 condition = condition.add(extension_version::Column::ProvidesThemes.eq(true));
403 }
404
405 if provides_filter.contains(&ExtensionProvides::IconThemes) {
406 condition = condition.add(extension_version::Column::ProvidesIconThemes.eq(true));
407 }
408
409 if provides_filter.contains(&ExtensionProvides::Languages) {
410 condition = condition.add(extension_version::Column::ProvidesLanguages.eq(true));
411 }
412
413 if provides_filter.contains(&ExtensionProvides::Grammars) {
414 condition = condition.add(extension_version::Column::ProvidesGrammars.eq(true));
415 }
416
417 if provides_filter.contains(&ExtensionProvides::LanguageServers) {
418 condition = condition.add(extension_version::Column::ProvidesLanguageServers.eq(true));
419 }
420
421 if provides_filter.contains(&ExtensionProvides::ContextServers) {
422 condition = condition.add(extension_version::Column::ProvidesContextServers.eq(true));
423 }
424
425 if provides_filter.contains(&ExtensionProvides::SlashCommands) {
426 condition = condition.add(extension_version::Column::ProvidesSlashCommands.eq(true));
427 }
428
429 if provides_filter.contains(&ExtensionProvides::IndexedDocsProviders) {
430 condition = condition.add(extension_version::Column::ProvidesIndexedDocsProviders.eq(true));
431 }
432
433 if provides_filter.contains(&ExtensionProvides::Snippets) {
434 condition = condition.add(extension_version::Column::ProvidesSnippets.eq(true));
435 }
436
437 if provides_filter.contains(&ExtensionProvides::DebugAdapters) {
438 condition = condition.add(extension_version::Column::ProvidesDebugAdapters.eq(true));
439 }
440
441 condition
442}
443
444fn metadata_from_extension_and_version(
445 extension: extension::Model,
446 version: extension_version::Model,
447) -> ExtensionMetadata {
448 let provides = version.provides();
449
450 ExtensionMetadata {
451 id: extension.external_id.into(),
452 manifest: rpc::ExtensionApiManifest {
453 name: extension.name,
454 version: version.version.into(),
455 authors: version
456 .authors
457 .split(',')
458 .map(|author| author.trim().to_string())
459 .collect::<Vec<_>>(),
460 description: Some(version.description),
461 repository: version.repository,
462 schema_version: Some(version.schema_version),
463 wasm_api_version: version.wasm_api_version,
464 provides,
465 },
466
467 published_at: convert_time_to_chrono(version.published_at),
468 download_count: extension.total_download_count as u64,
469 }
470}
471
472pub fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> chrono::DateTime<Utc> {
473 chrono::DateTime::from_naive_utc_and_offset(
474 #[allow(deprecated)]
475 chrono::NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
476 Utc,
477 )
478}