embedding_tests.rs

 1use super::TestDb;
 2use crate::db::embedding;
 3use collections::HashMap;
 4use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, sea_query::Expr};
 5use std::ops::Sub;
 6use time::{Duration, OffsetDateTime, PrimitiveDateTime};
 7
 8// SQLite does not support array arguments, so we only test this against a real postgres instance
 9#[gpui::test]
10async fn test_get_embeddings_postgres(cx: &mut gpui::TestAppContext) {
11    let test_db = TestDb::postgres(cx.executor());
12    let db = test_db.db();
13
14    let provider = "test_model";
15    let digest1 = vec![1, 2, 3];
16    let digest2 = vec![4, 5, 6];
17    let embeddings = HashMap::from_iter([
18        (digest1.clone(), vec![0.1, 0.2, 0.3]),
19        (digest2.clone(), vec![0.4, 0.5, 0.6]),
20    ]);
21
22    // Save embeddings
23    db.save_embeddings(provider, &embeddings).await.unwrap();
24
25    // Retrieve embeddings
26    let retrieved_embeddings = db
27        .get_embeddings(provider, &[digest1.clone(), digest2.clone()])
28        .await
29        .unwrap();
30    assert_eq!(retrieved_embeddings.len(), 2);
31    assert!(retrieved_embeddings.contains_key(&digest1));
32    assert!(retrieved_embeddings.contains_key(&digest2));
33
34    // Check if the retrieved embeddings are correct
35    assert_eq!(retrieved_embeddings[&digest1], vec![0.1, 0.2, 0.3]);
36    assert_eq!(retrieved_embeddings[&digest2], vec![0.4, 0.5, 0.6]);
37}
38
39#[gpui::test]
40async fn test_purge_old_embeddings(cx: &mut gpui::TestAppContext) {
41    let test_db = TestDb::postgres(cx.executor());
42    let db = test_db.db();
43
44    let model = "test_model";
45    let digest = vec![7, 8, 9];
46    let embeddings = HashMap::from_iter([(digest.clone(), vec![0.7, 0.8, 0.9])]);
47
48    // Save old embeddings
49    db.save_embeddings(model, &embeddings).await.unwrap();
50
51    // Reach into the DB and change the retrieved at to be > 60 days
52    db.transaction(|tx| {
53        let digest = digest.clone();
54        async move {
55            let sixty_days_ago = OffsetDateTime::now_utc().sub(Duration::days(61));
56            let retrieved_at = PrimitiveDateTime::new(sixty_days_ago.date(), sixty_days_ago.time());
57
58            embedding::Entity::update_many()
59                .filter(
60                    embedding::Column::Model
61                        .eq(model)
62                        .and(embedding::Column::Digest.eq(digest)),
63                )
64                .col_expr(embedding::Column::RetrievedAt, Expr::value(retrieved_at))
65                .exec(&*tx)
66                .await
67                .unwrap();
68
69            Ok(())
70        }
71    })
72    .await
73    .unwrap();
74
75    // Purge old embeddings
76    db.purge_old_embeddings().await.unwrap();
77
78    // Try to retrieve the purged embeddings
79    let retrieved_embeddings = db
80        .get_embeddings(model, std::slice::from_ref(&digest))
81        .await
82        .unwrap();
83    assert!(
84        retrieved_embeddings.is_empty(),
85        "Old embeddings should have been purged"
86    );
87}