1use std::sync::Arc;
2
3use crate::{dot, embedding::EmbeddingProvider, VectorStore};
4use anyhow::Result;
5use async_trait::async_trait;
6use gpui::{Task, TestAppContext};
7use language::{Language, LanguageConfig, LanguageRegistry};
8use project::{FakeFs, Project};
9use rand::Rng;
10use serde_json::json;
11use unindent::Unindent;
12
13#[gpui::test]
14async fn test_vector_store(cx: &mut TestAppContext) {
15 let fs = FakeFs::new(cx.background());
16 fs.insert_tree(
17 "/the-root",
18 json!({
19 "src": {
20 "file1.rs": "
21 fn aaa() {
22 println!(\"aaaa!\");
23 }
24
25 fn zzzzzzzzz() {
26 println!(\"SLEEPING\");
27 }
28 ".unindent(),
29 "file2.rs": "
30 fn bbb() {
31 println!(\"bbbb!\");
32 }
33 ".unindent(),
34 }
35 }),
36 )
37 .await;
38
39 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
40 let rust_language = Arc::new(
41 Language::new(
42 LanguageConfig {
43 name: "Rust".into(),
44 path_suffixes: vec!["rs".into()],
45 ..Default::default()
46 },
47 Some(tree_sitter_rust::language()),
48 )
49 .with_embedding_query(
50 r#"
51 (function_item
52 name: (identifier) @name
53 body: (block)) @item
54 "#,
55 )
56 .unwrap(),
57 );
58 languages.add(rust_language);
59
60 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
61 let db_path = db_dir.path().join("db.sqlite");
62
63 let store = cx.add_model(|_| {
64 VectorStore::new(
65 fs.clone(),
66 db_path,
67 Arc::new(FakeEmbeddingProvider),
68 languages,
69 )
70 });
71
72 let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
73 let worktree_id = project.read_with(cx, |project, cx| {
74 project.worktrees(cx).next().unwrap().read(cx).id()
75 });
76 let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx));
77
78 // TODO - remove
79 cx.foreground()
80 .advance_clock(std::time::Duration::from_secs(3));
81
82 add_project.await.unwrap();
83
84 let search_results = store
85 .update(cx, |store, cx| {
86 store.search(&project, "aaaa".to_string(), 5, cx)
87 })
88 .await
89 .unwrap();
90
91 assert_eq!(search_results[0].offset, 0);
92 assert_eq!(search_results[0].name, "aaa");
93 assert_eq!(search_results[0].worktree_id, worktree_id);
94}
95
/// Verifies `dot` on two hand-computed cases and then, for 100 pairs of random
/// 32-element vectors, against a naive scalar reference implementation.
#[test]
fn test_dot_product() {
    /// Largest absolute difference accepted between `dot` and the reference.
    ///
    /// Mirrors the three-decimal precision the test has always targeted, but
    /// as a tolerance rather than by comparing rounded values: two results
    /// that differ by a single ULP can still round to different values when
    /// they straddle a rounding boundary, which made the rounded comparison
    /// fail spuriously.
    const TOLERANCE: f32 = 1e-3;

    /// Straightforward sequential multiply-and-sum used as ground truth.
    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
    }

    // Small integer-valued inputs are exactly representable, so exact
    // equality is safe here.
    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);

    let mut rng = rand::thread_rng();
    for _ in 0..100 {
        let a: [f32; 32] = rng.gen();
        let b: [f32; 32] = rng.gen();

        let actual = dot(&a, &b);
        let expected = reference_dot(&a, &b);
        assert!(
            (actual - expected).abs() <= TOLERANCE,
            "dot returned {} but the reference computed {} (a = {:?}, b = {:?})",
            actual,
            expected,
            a,
            b
        );
    }
}
120
121struct FakeEmbeddingProvider;
122
123#[async_trait]
124impl EmbeddingProvider for FakeEmbeddingProvider {
125 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
126 Ok(spans
127 .iter()
128 .map(|span| {
129 let mut result = vec![1.0; 26];
130 for letter in span.chars() {
131 let letter = letter.to_ascii_lowercase();
132 if letter as u32 >= 'a' as u32 {
133 let ix = (letter as u32) - ('a' as u32);
134 if ix < 26 {
135 result[ix as usize] += 1.0;
136 }
137 }
138 }
139
140 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
141 for x in &mut result {
142 *x /= norm;
143 }
144
145 result
146 })
147 .collect())
148 }
149}