1use std::str::FromStr;
2
3use anyhow::Context;
4use chrono::Utc;
5use sea_orm::sea_query::IntoCondition;
6use util::ResultExt;
7
8use super::*;
9
10impl Database {
11 pub async fn get_extensions_by_ids(
12 &self,
13 ids: &[&str],
14 constraints: Option<&ExtensionVersionConstraints>,
15 ) -> Result<Vec<ExtensionMetadata>> {
16 self.transaction(|tx| async move {
17 let extensions = extension::Entity::find()
18 .filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
19 .all(&*tx)
20 .await?;
21
22 let mut max_versions = self
23 .get_latest_versions_for_extensions(&extensions, constraints, &tx)
24 .await?;
25
26 Ok(extensions
27 .into_iter()
28 .filter_map(|extension| {
29 let (version, _) = max_versions.remove(&extension.id)?;
30 Some(metadata_from_extension_and_version(extension, version))
31 })
32 .collect())
33 })
34 .await
35 }
36
37 async fn get_latest_versions_for_extensions(
38 &self,
39 extensions: &[extension::Model],
40 constraints: Option<&ExtensionVersionConstraints>,
41 tx: &DatabaseTransaction,
42 ) -> Result<HashMap<ExtensionId, (extension_version::Model, Version)>> {
43 let mut versions = extension_version::Entity::find()
44 .filter(
45 extension_version::Column::ExtensionId
46 .is_in(extensions.iter().map(|extension| extension.id)),
47 )
48 .stream(tx)
49 .await?;
50
51 let mut max_versions =
52 HashMap::<ExtensionId, (extension_version::Model, Version)>::default();
53 while let Some(version) = versions.next().await {
54 let version = version?;
55 let Some(extension_version) = Version::from_str(&version.version).log_err() else {
56 continue;
57 };
58
59 if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id)
60 && max_extension_version > &extension_version
61 {
62 continue;
63 }
64
65 if let Some(constraints) = constraints {
66 if !constraints
67 .schema_versions
68 .contains(&version.schema_version)
69 {
70 continue;
71 }
72
73 if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
74 if let Some(version) = Version::from_str(wasm_api_version).log_err() {
75 if !constraints.wasm_api_versions.contains(&version) {
76 continue;
77 }
78 } else {
79 continue;
80 }
81 }
82 }
83
84 max_versions.insert(version.extension_id, (version, extension_version));
85 }
86
87 Ok(max_versions)
88 }
89
90 /// Returns all of the versions for the extension with the given ID.
91 pub async fn get_extension_versions(
92 &self,
93 extension_id: &str,
94 ) -> Result<Vec<ExtensionMetadata>> {
95 self.transaction(|tx| async move {
96 let condition = extension::Column::ExternalId
97 .eq(extension_id)
98 .into_condition();
99
100 self.get_extensions_where(condition, None, &tx).await
101 })
102 .await
103 }
104
105 async fn get_extensions_where(
106 &self,
107 condition: Condition,
108 limit: Option<u64>,
109 tx: &DatabaseTransaction,
110 ) -> Result<Vec<ExtensionMetadata>> {
111 let extensions = extension::Entity::find()
112 .inner_join(extension_version::Entity)
113 .select_also(extension_version::Entity)
114 .filter(condition)
115 .order_by_desc(extension::Column::TotalDownloadCount)
116 .order_by_asc(extension::Column::Name)
117 .limit(limit)
118 .all(tx)
119 .await?;
120
121 Ok(extensions
122 .into_iter()
123 .filter_map(|(extension, version)| {
124 Some(metadata_from_extension_and_version(extension, version?))
125 })
126 .collect())
127 }
128
129 pub async fn get_extension(
130 &self,
131 extension_id: &str,
132 constraints: Option<&ExtensionVersionConstraints>,
133 ) -> Result<Option<ExtensionMetadata>> {
134 self.transaction(|tx| async move {
135 let extension = extension::Entity::find()
136 .filter(extension::Column::ExternalId.eq(extension_id))
137 .one(&*tx)
138 .await?
139 .with_context(|| format!("no such extension: {extension_id}"))?;
140
141 let extensions = [extension];
142 let mut versions = self
143 .get_latest_versions_for_extensions(&extensions, constraints, &tx)
144 .await?;
145 let [extension] = extensions;
146
147 Ok(versions.remove(&extension.id).map(|(max_version, _)| {
148 metadata_from_extension_and_version(extension, max_version)
149 }))
150 })
151 .await
152 }
153
154 pub async fn get_extension_version(
155 &self,
156 extension_id: &str,
157 version: &str,
158 ) -> Result<Option<ExtensionMetadata>> {
159 self.transaction(|tx| async move {
160 let extension = extension::Entity::find()
161 .filter(extension::Column::ExternalId.eq(extension_id))
162 .filter(extension_version::Column::Version.eq(version))
163 .inner_join(extension_version::Entity)
164 .select_also(extension_version::Entity)
165 .one(&*tx)
166 .await?;
167
168 Ok(extension.and_then(|(extension, version)| {
169 Some(metadata_from_extension_and_version(extension, version?))
170 }))
171 })
172 .await
173 }
174
175 pub async fn get_known_extension_versions(&self) -> Result<HashMap<String, Vec<String>>> {
176 self.transaction(|tx| async move {
177 let mut extension_external_ids_by_id = HashMap::default();
178
179 let mut rows = extension::Entity::find().stream(&*tx).await?;
180 while let Some(row) = rows.next().await {
181 let row = row?;
182 extension_external_ids_by_id.insert(row.id, row.external_id);
183 }
184 drop(rows);
185
186 let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
187 HashMap::default();
188 let mut rows = extension_version::Entity::find().stream(&*tx).await?;
189 while let Some(row) = rows.next().await {
190 let row = row?;
191
192 let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
193 continue;
194 };
195
196 let versions = known_versions_by_extension_id
197 .entry(extension_id.clone())
198 .or_default();
199 if let Err(ix) = versions.binary_search(&row.version) {
200 versions.insert(ix, row.version);
201 }
202 }
203 drop(rows);
204
205 Ok(known_versions_by_extension_id)
206 })
207 .await
208 }
209
210 pub async fn insert_extension_versions(
211 &self,
212 versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
213 ) -> Result<()> {
214 self.transaction(|tx| async move {
215 for (external_id, versions) in versions_by_extension_id {
216 if versions.is_empty() {
217 continue;
218 }
219
220 let latest_version = versions
221 .iter()
222 .max_by_key(|version| &version.version)
223 .unwrap();
224
225 let insert = extension::Entity::insert(extension::ActiveModel {
226 name: ActiveValue::Set(latest_version.name.clone()),
227 external_id: ActiveValue::Set((*external_id).to_owned()),
228 id: ActiveValue::NotSet,
229 latest_version: ActiveValue::Set(latest_version.version.to_string()),
230 total_download_count: ActiveValue::NotSet,
231 })
232 .on_conflict(
233 OnConflict::columns([extension::Column::ExternalId])
234 .update_column(extension::Column::ExternalId)
235 .to_owned(),
236 );
237
238 let extension = if tx.support_returning() {
239 insert.exec_with_returning(&*tx).await?
240 } else {
241 // Sqlite
242 insert.exec_without_returning(&*tx).await?;
243 extension::Entity::find()
244 .filter(extension::Column::ExternalId.eq(*external_id))
245 .one(&*tx)
246 .await?
247 .context("failed to insert extension")?
248 };
249
250 extension_version::Entity::insert_many(versions.iter().map(|version| {
251 extension_version::ActiveModel {
252 extension_id: ActiveValue::Set(extension.id),
253 published_at: ActiveValue::Set(version.published_at),
254 version: ActiveValue::Set(version.version.to_string()),
255 authors: ActiveValue::Set(version.authors.join(", ")),
256 repository: ActiveValue::Set(version.repository.clone()),
257 description: ActiveValue::Set(version.description.clone()),
258 schema_version: ActiveValue::Set(version.schema_version),
259 wasm_api_version: ActiveValue::Set(version.wasm_api_version.clone()),
260 provides_themes: ActiveValue::Set(
261 version.provides.contains(&ExtensionProvides::Themes),
262 ),
263 provides_icon_themes: ActiveValue::Set(
264 version.provides.contains(&ExtensionProvides::IconThemes),
265 ),
266 provides_languages: ActiveValue::Set(
267 version.provides.contains(&ExtensionProvides::Languages),
268 ),
269 provides_grammars: ActiveValue::Set(
270 version.provides.contains(&ExtensionProvides::Grammars),
271 ),
272 provides_language_servers: ActiveValue::Set(
273 version
274 .provides
275 .contains(&ExtensionProvides::LanguageServers),
276 ),
277 provides_context_servers: ActiveValue::Set(
278 version
279 .provides
280 .contains(&ExtensionProvides::ContextServers),
281 ),
282 provides_agent_servers: ActiveValue::Set(
283 version.provides.contains(&ExtensionProvides::AgentServers),
284 ),
285 provides_slash_commands: ActiveValue::Set(
286 version.provides.contains(&ExtensionProvides::SlashCommands),
287 ),
288 provides_indexed_docs_providers: ActiveValue::Set(
289 version
290 .provides
291 .contains(&ExtensionProvides::IndexedDocsProviders),
292 ),
293 provides_snippets: ActiveValue::Set(
294 version.provides.contains(&ExtensionProvides::Snippets),
295 ),
296 provides_debug_adapters: ActiveValue::Set(
297 version.provides.contains(&ExtensionProvides::DebugAdapters),
298 ),
299 download_count: ActiveValue::NotSet,
300 }
301 }))
302 .on_conflict(OnConflict::new().do_nothing().to_owned())
303 .exec_without_returning(&*tx)
304 .await?;
305
306 if let Ok(db_version) = semver::Version::parse(&extension.latest_version)
307 && db_version >= latest_version.version
308 {
309 continue;
310 }
311
312 let mut extension = extension.into_active_model();
313 extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
314 extension.name = ActiveValue::set(latest_version.name.clone());
315 extension::Entity::update(extension).exec(&*tx).await?;
316 }
317
318 Ok(())
319 })
320 .await
321 }
322
323 pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
324 self.transaction(|tx| async move {
325 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
326 enum QueryId {
327 Id,
328 }
329
330 let extension_id: Option<ExtensionId> = extension::Entity::find()
331 .filter(extension::Column::ExternalId.eq(extension))
332 .select_only()
333 .column(extension::Column::Id)
334 .into_values::<_, QueryId>()
335 .one(&*tx)
336 .await?;
337 let Some(extension_id) = extension_id else {
338 return Ok(false);
339 };
340
341 extension_version::Entity::update_many()
342 .col_expr(
343 extension_version::Column::DownloadCount,
344 extension_version::Column::DownloadCount.into_expr().add(1),
345 )
346 .filter(
347 extension_version::Column::ExtensionId
348 .eq(extension_id)
349 .and(extension_version::Column::Version.eq(version)),
350 )
351 .exec(&*tx)
352 .await?;
353
354 extension::Entity::update_many()
355 .col_expr(
356 extension::Column::TotalDownloadCount,
357 extension::Column::TotalDownloadCount.into_expr().add(1),
358 )
359 .filter(extension::Column::Id.eq(extension_id))
360 .exec(&*tx)
361 .await?;
362
363 Ok(true)
364 })
365 .await
366 }
367}
368
369fn metadata_from_extension_and_version(
370 extension: extension::Model,
371 version: extension_version::Model,
372) -> ExtensionMetadata {
373 let provides = version.provides();
374
375 ExtensionMetadata {
376 id: extension.external_id.into(),
377 manifest: cloud_api_types::ExtensionApiManifest {
378 name: extension.name,
379 version: version.version.into(),
380 authors: version
381 .authors
382 .split(',')
383 .map(|author| author.trim().to_string())
384 .collect::<Vec<_>>(),
385 description: Some(version.description),
386 repository: version.repository,
387 schema_version: Some(version.schema_version),
388 wasm_api_version: version.wasm_api_version,
389 provides,
390 },
391
392 published_at: convert_time_to_chrono(version.published_at),
393 download_count: extension.total_download_count as u64,
394 }
395}
396
397pub fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> chrono::DateTime<Utc> {
398 chrono::DateTime::from_naive_utc_and_offset(
399 #[allow(deprecated)]
400 chrono::NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
401 Utc,
402 )
403}