1use std::str::FromStr;
2
3use chrono::Utc;
4use sea_orm::sea_query::IntoCondition;
5use util::ResultExt;
6
7use super::*;
8
9impl Database {
10 pub async fn get_extensions(
11 &self,
12 filter: Option<&str>,
13 provides_filter: Option<&BTreeSet<ExtensionProvides>>,
14 max_schema_version: i32,
15 limit: usize,
16 ) -> Result<Vec<ExtensionMetadata>> {
17 self.transaction(|tx| async move {
18 let mut condition = Condition::all()
19 .add(
20 extension::Column::LatestVersion
21 .into_expr()
22 .eq(extension_version::Column::Version.into_expr()),
23 )
24 .add(extension_version::Column::SchemaVersion.lte(max_schema_version));
25 if let Some(filter) = filter {
26 let fuzzy_name_filter = Self::fuzzy_like_string(filter);
27 condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
28 }
29
30 if let Some(provides_filter) = provides_filter {
31 condition = apply_provides_filter(condition, provides_filter);
32 }
33
34 self.get_extensions_where(condition, Some(limit as u64), &tx)
35 .await
36 })
37 .await
38 }
39
40 pub async fn get_extensions_by_ids(
41 &self,
42 ids: &[&str],
43 constraints: Option<&ExtensionVersionConstraints>,
44 ) -> Result<Vec<ExtensionMetadata>> {
45 self.transaction(|tx| async move {
46 let extensions = extension::Entity::find()
47 .filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
48 .all(&*tx)
49 .await?;
50
51 let mut max_versions = self
52 .get_latest_versions_for_extensions(&extensions, constraints, &tx)
53 .await?;
54
55 Ok(extensions
56 .into_iter()
57 .filter_map(|extension| {
58 let (version, _) = max_versions.remove(&extension.id)?;
59 Some(metadata_from_extension_and_version(extension, version))
60 })
61 .collect())
62 })
63 .await
64 }
65
66 async fn get_latest_versions_for_extensions(
67 &self,
68 extensions: &[extension::Model],
69 constraints: Option<&ExtensionVersionConstraints>,
70 tx: &DatabaseTransaction,
71 ) -> Result<HashMap<ExtensionId, (extension_version::Model, SemanticVersion)>> {
72 let mut versions = extension_version::Entity::find()
73 .filter(
74 extension_version::Column::ExtensionId
75 .is_in(extensions.iter().map(|extension| extension.id)),
76 )
77 .stream(tx)
78 .await?;
79
80 let mut max_versions =
81 HashMap::<ExtensionId, (extension_version::Model, SemanticVersion)>::default();
82 while let Some(version) = versions.next().await {
83 let version = version?;
84 let Some(extension_version) = SemanticVersion::from_str(&version.version).log_err()
85 else {
86 continue;
87 };
88
89 if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id) {
90 if max_extension_version > &extension_version {
91 continue;
92 }
93 }
94
95 if let Some(constraints) = constraints {
96 if !constraints
97 .schema_versions
98 .contains(&version.schema_version)
99 {
100 continue;
101 }
102
103 if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
104 if let Some(version) = SemanticVersion::from_str(wasm_api_version).log_err() {
105 if !constraints.wasm_api_versions.contains(&version) {
106 continue;
107 }
108 } else {
109 continue;
110 }
111 }
112 }
113
114 max_versions.insert(version.extension_id, (version, extension_version));
115 }
116
117 Ok(max_versions)
118 }
119
120 /// Returns all of the versions for the extension with the given ID.
121 pub async fn get_extension_versions(
122 &self,
123 extension_id: &str,
124 ) -> Result<Vec<ExtensionMetadata>> {
125 self.transaction(|tx| async move {
126 let condition = extension::Column::ExternalId
127 .eq(extension_id)
128 .into_condition();
129
130 self.get_extensions_where(condition, None, &tx).await
131 })
132 .await
133 }
134
135 async fn get_extensions_where(
136 &self,
137 condition: Condition,
138 limit: Option<u64>,
139 tx: &DatabaseTransaction,
140 ) -> Result<Vec<ExtensionMetadata>> {
141 let extensions = extension::Entity::find()
142 .inner_join(extension_version::Entity)
143 .select_also(extension_version::Entity)
144 .filter(condition)
145 .order_by_desc(extension::Column::TotalDownloadCount)
146 .order_by_asc(extension::Column::Name)
147 .limit(limit)
148 .all(tx)
149 .await?;
150
151 Ok(extensions
152 .into_iter()
153 .filter_map(|(extension, version)| {
154 Some(metadata_from_extension_and_version(extension, version?))
155 })
156 .collect())
157 }
158
159 pub async fn get_extension(
160 &self,
161 extension_id: &str,
162 constraints: Option<&ExtensionVersionConstraints>,
163 ) -> Result<Option<ExtensionMetadata>> {
164 self.transaction(|tx| async move {
165 let extension = extension::Entity::find()
166 .filter(extension::Column::ExternalId.eq(extension_id))
167 .one(&*tx)
168 .await?
169 .ok_or_else(|| anyhow!("no such extension: {extension_id}"))?;
170
171 let extensions = [extension];
172 let mut versions = self
173 .get_latest_versions_for_extensions(&extensions, constraints, &tx)
174 .await?;
175 let [extension] = extensions;
176
177 Ok(versions.remove(&extension.id).map(|(max_version, _)| {
178 metadata_from_extension_and_version(extension, max_version)
179 }))
180 })
181 .await
182 }
183
184 pub async fn get_extension_version(
185 &self,
186 extension_id: &str,
187 version: &str,
188 ) -> Result<Option<ExtensionMetadata>> {
189 self.transaction(|tx| async move {
190 let extension = extension::Entity::find()
191 .filter(extension::Column::ExternalId.eq(extension_id))
192 .filter(extension_version::Column::Version.eq(version))
193 .inner_join(extension_version::Entity)
194 .select_also(extension_version::Entity)
195 .one(&*tx)
196 .await?;
197
198 Ok(extension.and_then(|(extension, version)| {
199 Some(metadata_from_extension_and_version(extension, version?))
200 }))
201 })
202 .await
203 }
204
205 pub async fn get_known_extension_versions(&self) -> Result<HashMap<String, Vec<String>>> {
206 self.transaction(|tx| async move {
207 let mut extension_external_ids_by_id = HashMap::default();
208
209 let mut rows = extension::Entity::find().stream(&*tx).await?;
210 while let Some(row) = rows.next().await {
211 let row = row?;
212 extension_external_ids_by_id.insert(row.id, row.external_id);
213 }
214 drop(rows);
215
216 let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
217 HashMap::default();
218 let mut rows = extension_version::Entity::find().stream(&*tx).await?;
219 while let Some(row) = rows.next().await {
220 let row = row?;
221
222 let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
223 continue;
224 };
225
226 let versions = known_versions_by_extension_id
227 .entry(extension_id.clone())
228 .or_default();
229 if let Err(ix) = versions.binary_search(&row.version) {
230 versions.insert(ix, row.version);
231 }
232 }
233 drop(rows);
234
235 Ok(known_versions_by_extension_id)
236 })
237 .await
238 }
239
240 pub async fn insert_extension_versions(
241 &self,
242 versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
243 ) -> Result<()> {
244 self.transaction(|tx| async move {
245 for (external_id, versions) in versions_by_extension_id {
246 if versions.is_empty() {
247 continue;
248 }
249
250 let latest_version = versions
251 .iter()
252 .max_by_key(|version| &version.version)
253 .unwrap();
254
255 let insert = extension::Entity::insert(extension::ActiveModel {
256 name: ActiveValue::Set(latest_version.name.clone()),
257 external_id: ActiveValue::Set(external_id.to_string()),
258 id: ActiveValue::NotSet,
259 latest_version: ActiveValue::Set(latest_version.version.to_string()),
260 total_download_count: ActiveValue::NotSet,
261 })
262 .on_conflict(
263 OnConflict::columns([extension::Column::ExternalId])
264 .update_column(extension::Column::ExternalId)
265 .to_owned(),
266 );
267
268 let extension = if tx.support_returning() {
269 insert.exec_with_returning(&*tx).await?
270 } else {
271 // Sqlite
272 insert.exec_without_returning(&*tx).await?;
273 extension::Entity::find()
274 .filter(extension::Column::ExternalId.eq(*external_id))
275 .one(&*tx)
276 .await?
277 .ok_or_else(|| anyhow!("failed to insert extension"))?
278 };
279
280 extension_version::Entity::insert_many(versions.iter().map(|version| {
281 extension_version::ActiveModel {
282 extension_id: ActiveValue::Set(extension.id),
283 published_at: ActiveValue::Set(version.published_at),
284 version: ActiveValue::Set(version.version.to_string()),
285 authors: ActiveValue::Set(version.authors.join(", ")),
286 repository: ActiveValue::Set(version.repository.clone()),
287 description: ActiveValue::Set(version.description.clone()),
288 schema_version: ActiveValue::Set(version.schema_version),
289 wasm_api_version: ActiveValue::Set(version.wasm_api_version.clone()),
290 provides_themes: ActiveValue::Set(
291 version.provides.contains(&ExtensionProvides::Themes),
292 ),
293 provides_icon_themes: ActiveValue::Set(
294 version.provides.contains(&ExtensionProvides::IconThemes),
295 ),
296 provides_languages: ActiveValue::Set(
297 version.provides.contains(&ExtensionProvides::Languages),
298 ),
299 provides_grammars: ActiveValue::Set(
300 version.provides.contains(&ExtensionProvides::Grammars),
301 ),
302 provides_language_servers: ActiveValue::Set(
303 version
304 .provides
305 .contains(&ExtensionProvides::LanguageServers),
306 ),
307 provides_context_servers: ActiveValue::Set(
308 version
309 .provides
310 .contains(&ExtensionProvides::ContextServers),
311 ),
312 provides_slash_commands: ActiveValue::Set(
313 version.provides.contains(&ExtensionProvides::SlashCommands),
314 ),
315 provides_indexed_docs_providers: ActiveValue::Set(
316 version
317 .provides
318 .contains(&ExtensionProvides::IndexedDocsProviders),
319 ),
320 provides_snippets: ActiveValue::Set(
321 version.provides.contains(&ExtensionProvides::Snippets),
322 ),
323 download_count: ActiveValue::NotSet,
324 }
325 }))
326 .on_conflict(OnConflict::new().do_nothing().to_owned())
327 .exec_without_returning(&*tx)
328 .await?;
329
330 if let Ok(db_version) = semver::Version::parse(&extension.latest_version) {
331 if db_version >= latest_version.version {
332 continue;
333 }
334 }
335
336 let mut extension = extension.into_active_model();
337 extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
338 extension.name = ActiveValue::set(latest_version.name.clone());
339 extension::Entity::update(extension).exec(&*tx).await?;
340 }
341
342 Ok(())
343 })
344 .await
345 }
346
347 pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
348 self.transaction(|tx| async move {
349 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
350 enum QueryId {
351 Id,
352 }
353
354 let extension_id: Option<ExtensionId> = extension::Entity::find()
355 .filter(extension::Column::ExternalId.eq(extension))
356 .select_only()
357 .column(extension::Column::Id)
358 .into_values::<_, QueryId>()
359 .one(&*tx)
360 .await?;
361 let Some(extension_id) = extension_id else {
362 return Ok(false);
363 };
364
365 extension_version::Entity::update_many()
366 .col_expr(
367 extension_version::Column::DownloadCount,
368 extension_version::Column::DownloadCount.into_expr().add(1),
369 )
370 .filter(
371 extension_version::Column::ExtensionId
372 .eq(extension_id)
373 .and(extension_version::Column::Version.eq(version)),
374 )
375 .exec(&*tx)
376 .await?;
377
378 extension::Entity::update_many()
379 .col_expr(
380 extension::Column::TotalDownloadCount,
381 extension::Column::TotalDownloadCount.into_expr().add(1),
382 )
383 .filter(extension::Column::Id.eq(extension_id))
384 .exec(&*tx)
385 .await?;
386
387 Ok(true)
388 })
389 .await
390 }
391}
392
393fn apply_provides_filter(
394 mut condition: Condition,
395 provides_filter: &BTreeSet<ExtensionProvides>,
396) -> Condition {
397 if provides_filter.contains(&ExtensionProvides::Themes) {
398 condition = condition.add(extension_version::Column::ProvidesThemes.eq(true));
399 }
400
401 if provides_filter.contains(&ExtensionProvides::IconThemes) {
402 condition = condition.add(extension_version::Column::ProvidesIconThemes.eq(true));
403 }
404
405 if provides_filter.contains(&ExtensionProvides::Languages) {
406 condition = condition.add(extension_version::Column::ProvidesLanguages.eq(true));
407 }
408
409 if provides_filter.contains(&ExtensionProvides::Grammars) {
410 condition = condition.add(extension_version::Column::ProvidesGrammars.eq(true));
411 }
412
413 if provides_filter.contains(&ExtensionProvides::LanguageServers) {
414 condition = condition.add(extension_version::Column::ProvidesLanguageServers.eq(true));
415 }
416
417 if provides_filter.contains(&ExtensionProvides::ContextServers) {
418 condition = condition.add(extension_version::Column::ProvidesContextServers.eq(true));
419 }
420
421 if provides_filter.contains(&ExtensionProvides::SlashCommands) {
422 condition = condition.add(extension_version::Column::ProvidesSlashCommands.eq(true));
423 }
424
425 if provides_filter.contains(&ExtensionProvides::IndexedDocsProviders) {
426 condition = condition.add(extension_version::Column::ProvidesIndexedDocsProviders.eq(true));
427 }
428
429 if provides_filter.contains(&ExtensionProvides::Snippets) {
430 condition = condition.add(extension_version::Column::ProvidesSnippets.eq(true));
431 }
432
433 condition
434}
435
436fn metadata_from_extension_and_version(
437 extension: extension::Model,
438 version: extension_version::Model,
439) -> ExtensionMetadata {
440 let provides = version.provides();
441
442 ExtensionMetadata {
443 id: extension.external_id.into(),
444 manifest: rpc::ExtensionApiManifest {
445 name: extension.name,
446 version: version.version.into(),
447 authors: version
448 .authors
449 .split(',')
450 .map(|author| author.trim().to_string())
451 .collect::<Vec<_>>(),
452 description: Some(version.description),
453 repository: version.repository,
454 schema_version: Some(version.schema_version),
455 wasm_api_version: version.wasm_api_version,
456 provides,
457 },
458
459 published_at: convert_time_to_chrono(version.published_at),
460 download_count: extension.total_download_count as u64,
461 }
462}
463
464pub fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> chrono::DateTime<Utc> {
465 chrono::DateTime::from_naive_utc_and_offset(
466 #[allow(deprecated)]
467 chrono::NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
468 Utc,
469 )
470}