1use super::TestDb;
2use crate::db::embedding;
3use collections::HashMap;
4use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, sea_query::Expr};
5use std::ops::Sub;
6use time::{Duration, OffsetDateTime, PrimitiveDateTime};
7
8// SQLite does not support array arguments, so we only test this against a real postgres instance
9#[gpui::test]
10async fn test_get_embeddings_postgres(cx: &mut gpui::TestAppContext) {
11 let test_db = TestDb::postgres(cx.executor());
12 let db = test_db.db();
13
14 let provider = "test_model";
15 let digest1 = vec![1, 2, 3];
16 let digest2 = vec![4, 5, 6];
17 let embeddings = HashMap::from_iter([
18 (digest1.clone(), vec![0.1, 0.2, 0.3]),
19 (digest2.clone(), vec![0.4, 0.5, 0.6]),
20 ]);
21
22 // Save embeddings
23 db.save_embeddings(provider, &embeddings).await.unwrap();
24
25 // Retrieve embeddings
26 let retrieved_embeddings = db
27 .get_embeddings(provider, &[digest1.clone(), digest2.clone()])
28 .await
29 .unwrap();
30 assert_eq!(retrieved_embeddings.len(), 2);
31 assert!(retrieved_embeddings.contains_key(&digest1));
32 assert!(retrieved_embeddings.contains_key(&digest2));
33
34 // Check if the retrieved embeddings are correct
35 assert_eq!(retrieved_embeddings[&digest1], vec![0.1, 0.2, 0.3]);
36 assert_eq!(retrieved_embeddings[&digest2], vec![0.4, 0.5, 0.6]);
37}
38
39#[gpui::test]
40async fn test_purge_old_embeddings(cx: &mut gpui::TestAppContext) {
41 let test_db = TestDb::postgres(cx.executor());
42 let db = test_db.db();
43
44 let model = "test_model";
45 let digest = vec![7, 8, 9];
46 let embeddings = HashMap::from_iter([(digest.clone(), vec![0.7, 0.8, 0.9])]);
47
48 // Save old embeddings
49 db.save_embeddings(model, &embeddings).await.unwrap();
50
51 // Reach into the DB and change the retrieved at to be > 60 days
52 db.transaction(|tx| {
53 let digest = digest.clone();
54 async move {
55 let sixty_days_ago = OffsetDateTime::now_utc().sub(Duration::days(61));
56 let retrieved_at = PrimitiveDateTime::new(sixty_days_ago.date(), sixty_days_ago.time());
57
58 embedding::Entity::update_many()
59 .filter(
60 embedding::Column::Model
61 .eq(model)
62 .and(embedding::Column::Digest.eq(digest)),
63 )
64 .col_expr(embedding::Column::RetrievedAt, Expr::value(retrieved_at))
65 .exec(&*tx)
66 .await
67 .unwrap();
68
69 Ok(())
70 }
71 })
72 .await
73 .unwrap();
74
75 // Purge old embeddings
76 db.purge_old_embeddings().await.unwrap();
77
78 // Try to retrieve the purged embeddings
79 let retrieved_embeddings = db
80 .get_embeddings(model, std::slice::from_ref(&digest))
81 .await
82 .unwrap();
83 assert!(
84 retrieved_embeddings.is_empty(),
85 "Old embeddings should have been purged"
86 );
87}