1use std::str::FromStr;
2
3use anyhow::Context;
4use chrono::Utc;
5use sea_orm::sea_query::IntoCondition;
6use util::ResultExt;
7
8use super::*;
9
10impl Database {
11 pub async fn get_extensions(
12 &self,
13 filter: Option<&str>,
14 provides_filter: Option<&BTreeSet<ExtensionProvides>>,
15 max_schema_version: i32,
16 limit: usize,
17 ) -> Result<Vec<ExtensionMetadata>> {
18 self.transaction(|tx| async move {
19 let mut condition = Condition::all()
20 .add(
21 extension::Column::LatestVersion
22 .into_expr()
23 .eq(extension_version::Column::Version.into_expr()),
24 )
25 .add(extension_version::Column::SchemaVersion.lte(max_schema_version));
26 if let Some(filter) = filter {
27 let fuzzy_name_filter = Self::fuzzy_like_string(filter);
28 condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
29 }
30
31 if let Some(provides_filter) = provides_filter {
32 condition = apply_provides_filter(condition, provides_filter);
33 }
34
35 self.get_extensions_where(condition, Some(limit as u64), &tx)
36 .await
37 })
38 .await
39 }
40
41 pub async fn get_extensions_by_ids(
42 &self,
43 ids: &[&str],
44 constraints: Option<&ExtensionVersionConstraints>,
45 ) -> Result<Vec<ExtensionMetadata>> {
46 self.transaction(|tx| async move {
47 let extensions = extension::Entity::find()
48 .filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
49 .all(&*tx)
50 .await?;
51
52 let mut max_versions = self
53 .get_latest_versions_for_extensions(&extensions, constraints, &tx)
54 .await?;
55
56 Ok(extensions
57 .into_iter()
58 .filter_map(|extension| {
59 let (version, _) = max_versions.remove(&extension.id)?;
60 Some(metadata_from_extension_and_version(extension, version))
61 })
62 .collect())
63 })
64 .await
65 }
66
67 async fn get_latest_versions_for_extensions(
68 &self,
69 extensions: &[extension::Model],
70 constraints: Option<&ExtensionVersionConstraints>,
71 tx: &DatabaseTransaction,
72 ) -> Result<HashMap<ExtensionId, (extension_version::Model, Version)>> {
73 let mut versions = extension_version::Entity::find()
74 .filter(
75 extension_version::Column::ExtensionId
76 .is_in(extensions.iter().map(|extension| extension.id)),
77 )
78 .stream(tx)
79 .await?;
80
81 let mut max_versions =
82 HashMap::<ExtensionId, (extension_version::Model, Version)>::default();
83 while let Some(version) = versions.next().await {
84 let version = version?;
85 let Some(extension_version) = Version::from_str(&version.version).log_err() else {
86 continue;
87 };
88
89 if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id)
90 && max_extension_version > &extension_version
91 {
92 continue;
93 }
94
95 if let Some(constraints) = constraints {
96 if !constraints
97 .schema_versions
98 .contains(&version.schema_version)
99 {
100 continue;
101 }
102
103 if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
104 if let Some(version) = Version::from_str(wasm_api_version).log_err() {
105 if !constraints.wasm_api_versions.contains(&version) {
106 continue;
107 }
108 } else {
109 continue;
110 }
111 }
112 }
113
114 max_versions.insert(version.extension_id, (version, extension_version));
115 }
116
117 Ok(max_versions)
118 }
119
120 /// Returns all of the versions for the extension with the given ID.
121 pub async fn get_extension_versions(
122 &self,
123 extension_id: &str,
124 ) -> Result<Vec<ExtensionMetadata>> {
125 self.transaction(|tx| async move {
126 let condition = extension::Column::ExternalId
127 .eq(extension_id)
128 .into_condition();
129
130 self.get_extensions_where(condition, None, &tx).await
131 })
132 .await
133 }
134
135 async fn get_extensions_where(
136 &self,
137 condition: Condition,
138 limit: Option<u64>,
139 tx: &DatabaseTransaction,
140 ) -> Result<Vec<ExtensionMetadata>> {
141 let extensions = extension::Entity::find()
142 .inner_join(extension_version::Entity)
143 .select_also(extension_version::Entity)
144 .filter(condition)
145 .order_by_desc(extension::Column::TotalDownloadCount)
146 .order_by_asc(extension::Column::Name)
147 .limit(limit)
148 .all(tx)
149 .await?;
150
151 Ok(extensions
152 .into_iter()
153 .filter_map(|(extension, version)| {
154 Some(metadata_from_extension_and_version(extension, version?))
155 })
156 .collect())
157 }
158
159 pub async fn get_extension(
160 &self,
161 extension_id: &str,
162 constraints: Option<&ExtensionVersionConstraints>,
163 ) -> Result<Option<ExtensionMetadata>> {
164 self.transaction(|tx| async move {
165 let extension = extension::Entity::find()
166 .filter(extension::Column::ExternalId.eq(extension_id))
167 .one(&*tx)
168 .await?
169 .with_context(|| format!("no such extension: {extension_id}"))?;
170
171 let extensions = [extension];
172 let mut versions = self
173 .get_latest_versions_for_extensions(&extensions, constraints, &tx)
174 .await?;
175 let [extension] = extensions;
176
177 Ok(versions.remove(&extension.id).map(|(max_version, _)| {
178 metadata_from_extension_and_version(extension, max_version)
179 }))
180 })
181 .await
182 }
183
184 pub async fn get_extension_version(
185 &self,
186 extension_id: &str,
187 version: &str,
188 ) -> Result<Option<ExtensionMetadata>> {
189 self.transaction(|tx| async move {
190 let extension = extension::Entity::find()
191 .filter(extension::Column::ExternalId.eq(extension_id))
192 .filter(extension_version::Column::Version.eq(version))
193 .inner_join(extension_version::Entity)
194 .select_also(extension_version::Entity)
195 .one(&*tx)
196 .await?;
197
198 Ok(extension.and_then(|(extension, version)| {
199 Some(metadata_from_extension_and_version(extension, version?))
200 }))
201 })
202 .await
203 }
204
205 pub async fn get_known_extension_versions(&self) -> Result<HashMap<String, Vec<String>>> {
206 self.transaction(|tx| async move {
207 let mut extension_external_ids_by_id = HashMap::default();
208
209 let mut rows = extension::Entity::find().stream(&*tx).await?;
210 while let Some(row) = rows.next().await {
211 let row = row?;
212 extension_external_ids_by_id.insert(row.id, row.external_id);
213 }
214 drop(rows);
215
216 let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
217 HashMap::default();
218 let mut rows = extension_version::Entity::find().stream(&*tx).await?;
219 while let Some(row) = rows.next().await {
220 let row = row?;
221
222 let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
223 continue;
224 };
225
226 let versions = known_versions_by_extension_id
227 .entry(extension_id.clone())
228 .or_default();
229 if let Err(ix) = versions.binary_search(&row.version) {
230 versions.insert(ix, row.version);
231 }
232 }
233 drop(rows);
234
235 Ok(known_versions_by_extension_id)
236 })
237 .await
238 }
239
240 pub async fn insert_extension_versions(
241 &self,
242 versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
243 ) -> Result<()> {
244 self.transaction(|tx| async move {
245 for (external_id, versions) in versions_by_extension_id {
246 if versions.is_empty() {
247 continue;
248 }
249
250 let latest_version = versions
251 .iter()
252 .max_by_key(|version| &version.version)
253 .unwrap();
254
255 let insert = extension::Entity::insert(extension::ActiveModel {
256 name: ActiveValue::Set(latest_version.name.clone()),
257 external_id: ActiveValue::Set((*external_id).to_owned()),
258 id: ActiveValue::NotSet,
259 latest_version: ActiveValue::Set(latest_version.version.to_string()),
260 total_download_count: ActiveValue::NotSet,
261 })
262 .on_conflict(
263 OnConflict::columns([extension::Column::ExternalId])
264 .update_column(extension::Column::ExternalId)
265 .to_owned(),
266 );
267
268 let extension = if tx.support_returning() {
269 insert.exec_with_returning(&*tx).await?
270 } else {
271 // Sqlite
272 insert.exec_without_returning(&*tx).await?;
273 extension::Entity::find()
274 .filter(extension::Column::ExternalId.eq(*external_id))
275 .one(&*tx)
276 .await?
277 .context("failed to insert extension")?
278 };
279
280 extension_version::Entity::insert_many(versions.iter().map(|version| {
281 extension_version::ActiveModel {
282 extension_id: ActiveValue::Set(extension.id),
283 published_at: ActiveValue::Set(version.published_at),
284 version: ActiveValue::Set(version.version.to_string()),
285 authors: ActiveValue::Set(version.authors.join(", ")),
286 repository: ActiveValue::Set(version.repository.clone()),
287 description: ActiveValue::Set(version.description.clone()),
288 schema_version: ActiveValue::Set(version.schema_version),
289 wasm_api_version: ActiveValue::Set(version.wasm_api_version.clone()),
290 provides_themes: ActiveValue::Set(
291 version.provides.contains(&ExtensionProvides::Themes),
292 ),
293 provides_icon_themes: ActiveValue::Set(
294 version.provides.contains(&ExtensionProvides::IconThemes),
295 ),
296 provides_languages: ActiveValue::Set(
297 version.provides.contains(&ExtensionProvides::Languages),
298 ),
299 provides_grammars: ActiveValue::Set(
300 version.provides.contains(&ExtensionProvides::Grammars),
301 ),
302 provides_language_servers: ActiveValue::Set(
303 version
304 .provides
305 .contains(&ExtensionProvides::LanguageServers),
306 ),
307 provides_context_servers: ActiveValue::Set(
308 version
309 .provides
310 .contains(&ExtensionProvides::ContextServers),
311 ),
312 provides_agent_servers: ActiveValue::Set(
313 version.provides.contains(&ExtensionProvides::AgentServers),
314 ),
315 provides_slash_commands: ActiveValue::Set(
316 version.provides.contains(&ExtensionProvides::SlashCommands),
317 ),
318 provides_indexed_docs_providers: ActiveValue::Set(
319 version
320 .provides
321 .contains(&ExtensionProvides::IndexedDocsProviders),
322 ),
323 provides_snippets: ActiveValue::Set(
324 version.provides.contains(&ExtensionProvides::Snippets),
325 ),
326 provides_debug_adapters: ActiveValue::Set(
327 version.provides.contains(&ExtensionProvides::DebugAdapters),
328 ),
329 download_count: ActiveValue::NotSet,
330 }
331 }))
332 .on_conflict(OnConflict::new().do_nothing().to_owned())
333 .exec_without_returning(&*tx)
334 .await?;
335
336 if let Ok(db_version) = semver::Version::parse(&extension.latest_version)
337 && db_version >= latest_version.version
338 {
339 continue;
340 }
341
342 let mut extension = extension.into_active_model();
343 extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
344 extension.name = ActiveValue::set(latest_version.name.clone());
345 extension::Entity::update(extension).exec(&*tx).await?;
346 }
347
348 Ok(())
349 })
350 .await
351 }
352
353 pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
354 self.transaction(|tx| async move {
355 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
356 enum QueryId {
357 Id,
358 }
359
360 let extension_id: Option<ExtensionId> = extension::Entity::find()
361 .filter(extension::Column::ExternalId.eq(extension))
362 .select_only()
363 .column(extension::Column::Id)
364 .into_values::<_, QueryId>()
365 .one(&*tx)
366 .await?;
367 let Some(extension_id) = extension_id else {
368 return Ok(false);
369 };
370
371 extension_version::Entity::update_many()
372 .col_expr(
373 extension_version::Column::DownloadCount,
374 extension_version::Column::DownloadCount.into_expr().add(1),
375 )
376 .filter(
377 extension_version::Column::ExtensionId
378 .eq(extension_id)
379 .and(extension_version::Column::Version.eq(version)),
380 )
381 .exec(&*tx)
382 .await?;
383
384 extension::Entity::update_many()
385 .col_expr(
386 extension::Column::TotalDownloadCount,
387 extension::Column::TotalDownloadCount.into_expr().add(1),
388 )
389 .filter(extension::Column::Id.eq(extension_id))
390 .exec(&*tx)
391 .await?;
392
393 Ok(true)
394 })
395 .await
396 }
397}
398
399fn apply_provides_filter(
400 mut condition: Condition,
401 provides_filter: &BTreeSet<ExtensionProvides>,
402) -> Condition {
403 if provides_filter.contains(&ExtensionProvides::Themes) {
404 condition = condition.add(extension_version::Column::ProvidesThemes.eq(true));
405 }
406
407 if provides_filter.contains(&ExtensionProvides::IconThemes) {
408 condition = condition.add(extension_version::Column::ProvidesIconThemes.eq(true));
409 }
410
411 if provides_filter.contains(&ExtensionProvides::Languages) {
412 condition = condition.add(extension_version::Column::ProvidesLanguages.eq(true));
413 }
414
415 if provides_filter.contains(&ExtensionProvides::Grammars) {
416 condition = condition.add(extension_version::Column::ProvidesGrammars.eq(true));
417 }
418
419 if provides_filter.contains(&ExtensionProvides::LanguageServers) {
420 condition = condition.add(extension_version::Column::ProvidesLanguageServers.eq(true));
421 }
422
423 if provides_filter.contains(&ExtensionProvides::ContextServers) {
424 condition = condition.add(extension_version::Column::ProvidesContextServers.eq(true));
425 }
426
427 if provides_filter.contains(&ExtensionProvides::AgentServers) {
428 condition = condition.add(extension_version::Column::ProvidesAgentServers.eq(true));
429 }
430
431 if provides_filter.contains(&ExtensionProvides::SlashCommands) {
432 condition = condition.add(extension_version::Column::ProvidesSlashCommands.eq(true));
433 }
434
435 if provides_filter.contains(&ExtensionProvides::IndexedDocsProviders) {
436 condition = condition.add(extension_version::Column::ProvidesIndexedDocsProviders.eq(true));
437 }
438
439 if provides_filter.contains(&ExtensionProvides::Snippets) {
440 condition = condition.add(extension_version::Column::ProvidesSnippets.eq(true));
441 }
442
443 if provides_filter.contains(&ExtensionProvides::DebugAdapters) {
444 condition = condition.add(extension_version::Column::ProvidesDebugAdapters.eq(true));
445 }
446
447 condition
448}
449
450fn metadata_from_extension_and_version(
451 extension: extension::Model,
452 version: extension_version::Model,
453) -> ExtensionMetadata {
454 let provides = version.provides();
455
456 ExtensionMetadata {
457 id: extension.external_id.into(),
458 manifest: rpc::ExtensionApiManifest {
459 name: extension.name,
460 version: version.version.into(),
461 authors: version
462 .authors
463 .split(',')
464 .map(|author| author.trim().to_string())
465 .collect::<Vec<_>>(),
466 description: Some(version.description),
467 repository: version.repository,
468 schema_version: Some(version.schema_version),
469 wasm_api_version: version.wasm_api_version,
470 provides,
471 },
472
473 published_at: convert_time_to_chrono(version.published_at),
474 download_count: extension.total_download_count as u64,
475 }
476}
477
478pub fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> chrono::DateTime<Utc> {
479 chrono::DateTime::from_naive_utc_and_offset(
480 #[allow(deprecated)]
481 chrono::NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
482 Utc,
483 )
484}