1use std::collections::BTreeSet;
  2use std::sync::Arc;
  3
  4use rpc::ExtensionProvides;
  5
  6use super::Database;
  7use crate::db::ExtensionVersionConstraints;
  8use crate::{
  9    db::{ExtensionMetadata, NewExtensionVersion, queries::extensions::convert_time_to_chrono},
 10    test_both_dbs,
 11};
 12
 13test_both_dbs!(
 14    test_extensions,
 15    test_extensions_postgres,
 16    test_extensions_sqlite
 17);
 18
 19test_both_dbs!(
 20    test_agent_servers_filter,
 21    test_agent_servers_filter_postgres,
 22    test_agent_servers_filter_sqlite
 23);
 24
 25async fn test_agent_servers_filter(db: &Arc<Database>) {
 26    // No extensions initially
 27    let versions = db.get_known_extension_versions().await.unwrap();
 28    assert!(versions.is_empty());
 29
 30    // Shared timestamp
 31    let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
 32    let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time());
 33
 34    // Insert two extensions, only one provides AgentServers
 35    db.insert_extension_versions(
 36        &[
 37            (
 38                "ext_agent_servers",
 39                vec![NewExtensionVersion {
 40                    name: "Agent Servers Provider".into(),
 41                    version: semver::Version::parse("1.0.0").unwrap(),
 42                    description: "has agent servers".into(),
 43                    authors: vec!["author".into()],
 44                    repository: "org/agent-servers".into(),
 45                    schema_version: 1,
 46                    wasm_api_version: None,
 47                    provides: BTreeSet::from_iter([ExtensionProvides::AgentServers]),
 48                    published_at: t0,
 49                }],
 50            ),
 51            (
 52                "ext_plain",
 53                vec![NewExtensionVersion {
 54                    name: "Plain Extension".into(),
 55                    version: semver::Version::parse("0.1.0").unwrap(),
 56                    description: "no agent servers".into(),
 57                    authors: vec!["author2".into()],
 58                    repository: "org/plain".into(),
 59                    schema_version: 1,
 60                    wasm_api_version: None,
 61                    provides: BTreeSet::default(),
 62                    published_at: t0,
 63                }],
 64            ),
 65        ]
 66        .into_iter()
 67        .collect(),
 68    )
 69    .await
 70    .unwrap();
 71
 72    // Filter by AgentServers provides
 73    let provides_filter = BTreeSet::from_iter([ExtensionProvides::AgentServers]);
 74
 75    let filtered = db
 76        .get_extensions(None, Some(&provides_filter), 1, 10)
 77        .await
 78        .unwrap();
 79
 80    // Expect only the extension that declared AgentServers
 81    assert_eq!(filtered.len(), 1);
 82    assert_eq!(filtered[0].id.as_ref(), "ext_agent_servers");
 83}
 84
 85async fn test_extensions(db: &Arc<Database>) {
 86    let versions = db.get_known_extension_versions().await.unwrap();
 87    assert!(versions.is_empty());
 88
 89    let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
 90    assert!(extensions.is_empty());
 91
 92    let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
 93    let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time());
 94
 95    let t0_chrono = convert_time_to_chrono(t0);
 96
 97    db.insert_extension_versions(
 98        &[
 99            (
100                "ext1",
101                vec![
102                    NewExtensionVersion {
103                        name: "Extension 1".into(),
104                        version: semver::Version::parse("0.0.1").unwrap(),
105                        description: "an extension".into(),
106                        authors: vec!["max".into()],
107                        repository: "ext1/repo".into(),
108                        schema_version: 1,
109                        wasm_api_version: None,
110                        provides: BTreeSet::default(),
111                        published_at: t0,
112                    },
113                    NewExtensionVersion {
114                        name: "Extension One".into(),
115                        version: semver::Version::parse("0.0.2").unwrap(),
116                        description: "a good extension".into(),
117                        authors: vec!["max".into(), "marshall".into()],
118                        repository: "ext1/repo".into(),
119                        schema_version: 1,
120                        wasm_api_version: None,
121                        provides: BTreeSet::default(),
122                        published_at: t0,
123                    },
124                ],
125            ),
126            (
127                "ext2",
128                vec![NewExtensionVersion {
129                    name: "Extension Two".into(),
130                    version: semver::Version::parse("0.2.0").unwrap(),
131                    description: "a great extension".into(),
132                    authors: vec!["marshall".into()],
133                    repository: "ext2/repo".into(),
134                    schema_version: 0,
135                    wasm_api_version: None,
136                    provides: BTreeSet::default(),
137                    published_at: t0,
138                }],
139            ),
140        ]
141        .into_iter()
142        .collect(),
143    )
144    .await
145    .unwrap();
146
147    let versions = db.get_known_extension_versions().await.unwrap();
148    assert_eq!(
149        versions,
150        [
151            ("ext1".into(), vec!["0.0.1".into(), "0.0.2".into()]),
152            ("ext2".into(), vec!["0.2.0".into()])
153        ]
154        .into_iter()
155        .collect()
156    );
157
158    // The latest version of each extension is returned.
159    let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
160    assert_eq!(
161        extensions,
162        &[
163            ExtensionMetadata {
164                id: "ext1".into(),
165                manifest: rpc::ExtensionApiManifest {
166                    name: "Extension One".into(),
167                    version: "0.0.2".into(),
168                    authors: vec!["max".into(), "marshall".into()],
169                    description: Some("a good extension".into()),
170                    repository: "ext1/repo".into(),
171                    schema_version: Some(1),
172                    wasm_api_version: None,
173                    provides: BTreeSet::default(),
174                },
175                published_at: t0_chrono,
176                download_count: 0,
177            },
178            ExtensionMetadata {
179                id: "ext2".into(),
180                manifest: rpc::ExtensionApiManifest {
181                    name: "Extension Two".into(),
182                    version: "0.2.0".into(),
183                    authors: vec!["marshall".into()],
184                    description: Some("a great extension".into()),
185                    repository: "ext2/repo".into(),
186                    schema_version: Some(0),
187                    wasm_api_version: None,
188                    provides: BTreeSet::default(),
189                },
190                published_at: t0_chrono,
191                download_count: 0
192            },
193        ]
194    );
195
196    // Extensions with too new of a schema version are excluded.
197    let extensions = db.get_extensions(None, None, 0, 5).await.unwrap();
198    assert_eq!(
199        extensions,
200        &[ExtensionMetadata {
201            id: "ext2".into(),
202            manifest: rpc::ExtensionApiManifest {
203                name: "Extension Two".into(),
204                version: "0.2.0".into(),
205                authors: vec!["marshall".into()],
206                description: Some("a great extension".into()),
207                repository: "ext2/repo".into(),
208                schema_version: Some(0),
209                wasm_api_version: None,
210                provides: BTreeSet::default(),
211            },
212            published_at: t0_chrono,
213            download_count: 0
214        },]
215    );
216
217    // Record extensions being downloaded.
218    for _ in 0..7 {
219        assert!(db.record_extension_download("ext2", "0.0.2").await.unwrap());
220    }
221
222    for _ in 0..3 {
223        assert!(db.record_extension_download("ext1", "0.0.1").await.unwrap());
224    }
225
226    for _ in 0..2 {
227        assert!(db.record_extension_download("ext1", "0.0.2").await.unwrap());
228    }
229
230    // Record download returns false if the extension does not exist.
231    assert!(
232        !db.record_extension_download("no-such-extension", "0.0.2")
233            .await
234            .unwrap()
235    );
236
237    // Extensions are returned in descending order of total downloads.
238    let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
239    assert_eq!(
240        extensions,
241        &[
242            ExtensionMetadata {
243                id: "ext2".into(),
244                manifest: rpc::ExtensionApiManifest {
245                    name: "Extension Two".into(),
246                    version: "0.2.0".into(),
247                    authors: vec!["marshall".into()],
248                    description: Some("a great extension".into()),
249                    repository: "ext2/repo".into(),
250                    schema_version: Some(0),
251                    wasm_api_version: None,
252                    provides: BTreeSet::default(),
253                },
254                published_at: t0_chrono,
255                download_count: 7
256            },
257            ExtensionMetadata {
258                id: "ext1".into(),
259                manifest: rpc::ExtensionApiManifest {
260                    name: "Extension One".into(),
261                    version: "0.0.2".into(),
262                    authors: vec!["max".into(), "marshall".into()],
263                    description: Some("a good extension".into()),
264                    repository: "ext1/repo".into(),
265                    schema_version: Some(1),
266                    wasm_api_version: None,
267                    provides: BTreeSet::default(),
268                },
269                published_at: t0_chrono,
270                download_count: 5,
271            },
272        ]
273    );
274
275    // Add more extensions, including a new version of `ext1`, and backfilling
276    // an older version of `ext2`.
277    db.insert_extension_versions(
278        &[
279            (
280                "ext1",
281                vec![NewExtensionVersion {
282                    name: "Extension One".into(),
283                    version: semver::Version::parse("0.0.3").unwrap(),
284                    description: "a real good extension".into(),
285                    authors: vec!["max".into(), "marshall".into()],
286                    repository: "ext1/repo".into(),
287                    schema_version: 1,
288                    wasm_api_version: None,
289                    provides: BTreeSet::default(),
290                    published_at: t0,
291                }],
292            ),
293            (
294                "ext2",
295                vec![NewExtensionVersion {
296                    name: "Extension Two".into(),
297                    version: semver::Version::parse("0.1.0").unwrap(),
298                    description: "an old extension".into(),
299                    authors: vec!["marshall".into()],
300                    repository: "ext2/repo".into(),
301                    schema_version: 0,
302                    wasm_api_version: None,
303                    provides: BTreeSet::default(),
304                    published_at: t0,
305                }],
306            ),
307        ]
308        .into_iter()
309        .collect(),
310    )
311    .await
312    .unwrap();
313
314    let versions = db.get_known_extension_versions().await.unwrap();
315    assert_eq!(
316        versions,
317        [
318            (
319                "ext1".into(),
320                vec!["0.0.1".into(), "0.0.2".into(), "0.0.3".into()]
321            ),
322            ("ext2".into(), vec!["0.1.0".into(), "0.2.0".into()])
323        ]
324        .into_iter()
325        .collect()
326    );
327
328    let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
329    assert_eq!(
330        extensions,
331        &[
332            ExtensionMetadata {
333                id: "ext2".into(),
334                manifest: rpc::ExtensionApiManifest {
335                    name: "Extension Two".into(),
336                    version: "0.2.0".into(),
337                    authors: vec!["marshall".into()],
338                    description: Some("a great extension".into()),
339                    repository: "ext2/repo".into(),
340                    schema_version: Some(0),
341                    wasm_api_version: None,
342                    provides: BTreeSet::default(),
343                },
344                published_at: t0_chrono,
345                download_count: 7
346            },
347            ExtensionMetadata {
348                id: "ext1".into(),
349                manifest: rpc::ExtensionApiManifest {
350                    name: "Extension One".into(),
351                    version: "0.0.3".into(),
352                    authors: vec!["max".into(), "marshall".into()],
353                    description: Some("a real good extension".into()),
354                    repository: "ext1/repo".into(),
355                    schema_version: Some(1),
356                    wasm_api_version: None,
357                    provides: BTreeSet::default(),
358                },
359                published_at: t0_chrono,
360                download_count: 5,
361            },
362        ]
363    );
364}
365
366test_both_dbs!(
367    test_extensions_by_id,
368    test_extensions_by_id_postgres,
369    test_extensions_by_id_sqlite
370);
371
372async fn test_extensions_by_id(db: &Arc<Database>) {
373    let versions = db.get_known_extension_versions().await.unwrap();
374    assert!(versions.is_empty());
375
376    let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
377    assert!(extensions.is_empty());
378
379    let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
380    let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time());
381
382    let t0_chrono = convert_time_to_chrono(t0);
383
384    db.insert_extension_versions(
385        &[
386            (
387                "ext1",
388                vec![
389                    NewExtensionVersion {
390                        name: "Extension 1".into(),
391                        version: semver::Version::parse("0.0.1").unwrap(),
392                        description: "an extension".into(),
393                        authors: vec!["max".into()],
394                        repository: "ext1/repo".into(),
395                        schema_version: 1,
396                        wasm_api_version: Some("0.0.4".into()),
397                        provides: BTreeSet::from_iter([
398                            ExtensionProvides::Grammars,
399                            ExtensionProvides::Languages,
400                        ]),
401                        published_at: t0,
402                    },
403                    NewExtensionVersion {
404                        name: "Extension 1".into(),
405                        version: semver::Version::parse("0.0.2").unwrap(),
406                        description: "a good extension".into(),
407                        authors: vec!["max".into()],
408                        repository: "ext1/repo".into(),
409                        schema_version: 1,
410                        wasm_api_version: Some("0.0.4".into()),
411                        provides: BTreeSet::from_iter([
412                            ExtensionProvides::Grammars,
413                            ExtensionProvides::Languages,
414                            ExtensionProvides::LanguageServers,
415                        ]),
416                        published_at: t0,
417                    },
418                    NewExtensionVersion {
419                        name: "Extension 1".into(),
420                        version: semver::Version::parse("0.0.3").unwrap(),
421                        description: "a real good extension".into(),
422                        authors: vec!["max".into(), "marshall".into()],
423                        repository: "ext1/repo".into(),
424                        schema_version: 1,
425                        wasm_api_version: Some("0.0.5".into()),
426                        provides: BTreeSet::from_iter([
427                            ExtensionProvides::Grammars,
428                            ExtensionProvides::Languages,
429                            ExtensionProvides::LanguageServers,
430                        ]),
431                        published_at: t0,
432                    },
433                ],
434            ),
435            (
436                "ext2",
437                vec![NewExtensionVersion {
438                    name: "Extension 2".into(),
439                    version: semver::Version::parse("0.2.0").unwrap(),
440                    description: "a great extension".into(),
441                    authors: vec!["marshall".into()],
442                    repository: "ext2/repo".into(),
443                    schema_version: 0,
444                    wasm_api_version: None,
445                    provides: BTreeSet::default(),
446                    published_at: t0,
447                }],
448            ),
449        ]
450        .into_iter()
451        .collect(),
452    )
453    .await
454    .unwrap();
455
456    let extensions = db
457        .get_extensions_by_ids(
458            &["ext1"],
459            Some(&ExtensionVersionConstraints {
460                schema_versions: 1..=1,
461                wasm_api_versions: "0.0.1".parse().unwrap()..="0.0.4".parse().unwrap(),
462            }),
463        )
464        .await
465        .unwrap();
466
467    assert_eq!(
468        extensions,
469        &[ExtensionMetadata {
470            id: "ext1".into(),
471            manifest: rpc::ExtensionApiManifest {
472                name: "Extension 1".into(),
473                version: "0.0.2".into(),
474                authors: vec!["max".into()],
475                description: Some("a good extension".into()),
476                repository: "ext1/repo".into(),
477                schema_version: Some(1),
478                wasm_api_version: Some("0.0.4".into()),
479                provides: BTreeSet::from_iter([
480                    ExtensionProvides::Grammars,
481                    ExtensionProvides::Languages,
482                    ExtensionProvides::LanguageServers,
483                ]),
484            },
485            published_at: t0_chrono,
486            download_count: 0,
487        }]
488    );
489}