1use crate::{
2 db::dot, embedding::EmbeddingProvider, vector_store_settings::VectorStoreSettings, VectorStore,
3};
4use anyhow::Result;
5use async_trait::async_trait;
6use gpui::{Task, TestAppContext};
7use language::{Language, LanguageConfig, LanguageRegistry};
8use project::{FakeFs, Project};
9use rand::{rngs::StdRng, Rng};
10use serde_json::json;
11use settings::SettingsStore;
12use std::sync::Arc;
13use unindent::Unindent;
14
15#[gpui::test]
16async fn test_vector_store(cx: &mut TestAppContext) {
17 cx.update(|cx| {
18 cx.set_global(SettingsStore::test(cx));
19 settings::register::<VectorStoreSettings>(cx);
20 });
21
22 let fs = FakeFs::new(cx.background());
23 fs.insert_tree(
24 "/the-root",
25 json!({
26 "src": {
27 "file1.rs": "
28 fn aaa() {
29 println!(\"aaaa!\");
30 }
31
32 fn zzzzzzzzz() {
33 println!(\"SLEEPING\");
34 }
35 ".unindent(),
36 "file2.rs": "
37 fn bbb() {
38 println!(\"bbbb!\");
39 }
40 ".unindent(),
41 }
42 }),
43 )
44 .await;
45
46 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
47 let rust_language = Arc::new(
48 Language::new(
49 LanguageConfig {
50 name: "Rust".into(),
51 path_suffixes: vec!["rs".into()],
52 ..Default::default()
53 },
54 Some(tree_sitter_rust::language()),
55 )
56 .with_embedding_query(
57 r#"
58 (function_item
59 name: (identifier) @name
60 body: (block)) @item
61 "#,
62 )
63 .unwrap(),
64 );
65 languages.add(rust_language);
66
67 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
68 let db_path = db_dir.path().join("db.sqlite");
69
70 let store = VectorStore::new(
71 fs.clone(),
72 db_path,
73 Arc::new(FakeEmbeddingProvider),
74 languages,
75 cx.to_async(),
76 )
77 .await
78 .unwrap();
79
80 let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
81 let worktree_id = project.read_with(cx, |project, cx| {
82 project.worktrees(cx).next().unwrap().read(cx).id()
83 });
84 store
85 .update(cx, |store, cx| store.add_project(project.clone(), cx))
86 .await
87 .unwrap();
88 cx.foreground().run_until_parked();
89
90 let search_results = store
91 .update(cx, |store, cx| {
92 store.search(project.clone(), "aaaa".to_string(), 5, cx)
93 })
94 .await
95 .unwrap();
96
97 assert_eq!(search_results[0].offset, 0);
98 assert_eq!(search_results[0].name, "aaa");
99 assert_eq!(search_results[0].worktree_id, worktree_id);
100}
101
102#[gpui::test]
103fn test_dot_product(mut rng: StdRng) {
104 assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
105 assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
106
107 for _ in 0..100 {
108 let size = 1536;
109 let mut a = vec![0.; size];
110 let mut b = vec![0.; size];
111 for (a, b) in a.iter_mut().zip(b.iter_mut()) {
112 *a = rng.gen();
113 *b = rng.gen();
114 }
115
116 assert_eq!(
117 round_to_decimals(dot(&a, &b), 1),
118 round_to_decimals(reference_dot(&a, &b), 1)
119 );
120 }
121
122 fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
123 let factor = (10.0 as f32).powi(decimal_places);
124 (n * factor).round() / factor
125 }
126
127 fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
128 a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
129 }
130}
131
132struct FakeEmbeddingProvider;
133
134#[async_trait]
135impl EmbeddingProvider for FakeEmbeddingProvider {
136 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
137 Ok(spans
138 .iter()
139 .map(|span| {
140 let mut result = vec![1.0; 26];
141 for letter in span.chars() {
142 let letter = letter.to_ascii_lowercase();
143 if letter as u32 >= 'a' as u32 {
144 let ix = (letter as u32) - ('a' as u32);
145 if ix < 26 {
146 result[ix as usize] += 1.0;
147 }
148 }
149 }
150
151 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
152 for x in &mut result {
153 *x /= norm;
154 }
155
156 result
157 })
158 .collect())
159 }
160}