1use crate::{
2 db::dot,
3 embedding::EmbeddingProvider,
4 parsing::{CodeContextRetriever, Document},
5 vector_store_settings::VectorStoreSettings,
6 VectorStore,
7};
8use anyhow::Result;
9use async_trait::async_trait;
10use gpui::{Task, TestAppContext};
11use language::{Language, LanguageConfig, LanguageRegistry};
12use project::{project_settings::ProjectSettings, FakeFs, Project};
13use rand::{rngs::StdRng, Rng};
14use serde_json::json;
15use settings::SettingsStore;
16use std::{path::Path, sync::Arc};
17use unindent::Unindent;
18
19#[ctor::ctor]
20fn init_logger() {
21 if std::env::var("RUST_LOG").is_ok() {
22 env_logger::init();
23 }
24}
25
26#[gpui::test]
27async fn test_vector_store(cx: &mut TestAppContext) {
28 cx.update(|cx| {
29 cx.set_global(SettingsStore::test(cx));
30 settings::register::<VectorStoreSettings>(cx);
31 settings::register::<ProjectSettings>(cx);
32 });
33
34 let fs = FakeFs::new(cx.background());
35 fs.insert_tree(
36 "/the-root",
37 json!({
38 "src": {
39 "file1.rs": "
40 fn aaa() {
41 println!(\"aaaa!\");
42 }
43
44 fn zzzzzzzzz() {
45 println!(\"SLEEPING\");
46 }
47 ".unindent(),
48 "file2.rs": "
49 fn bbb() {
50 println!(\"bbbb!\");
51 }
52 ".unindent(),
53 }
54 }),
55 )
56 .await;
57
58 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
59 let rust_language = rust_lang();
60 languages.add(rust_language);
61
62 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
63 let db_path = db_dir.path().join("db.sqlite");
64
65 let store = VectorStore::new(
66 fs.clone(),
67 db_path,
68 Arc::new(FakeEmbeddingProvider),
69 languages,
70 cx.to_async(),
71 )
72 .await
73 .unwrap();
74
75 let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
76 let worktree_id = project.read_with(cx, |project, cx| {
77 project.worktrees(cx).next().unwrap().read(cx).id()
78 });
79 store
80 .update(cx, |store, cx| store.add_project(project.clone(), cx))
81 .await
82 .unwrap();
83 cx.foreground().run_until_parked();
84
85 let search_results = store
86 .update(cx, |store, cx| {
87 store.search(project.clone(), "aaaa".to_string(), 5, cx)
88 })
89 .await
90 .unwrap();
91
92 assert_eq!(search_results[0].byte_range.start, 0);
93 assert_eq!(search_results[0].name, "aaa");
94 assert_eq!(search_results[0].worktree_id, worktree_id);
95}
96
97#[gpui::test]
98async fn test_code_context_retrieval(cx: &mut TestAppContext) {
99 let language = rust_lang();
100 let mut retriever = CodeContextRetriever::new();
101
102 let text = "
103 /// A doc comment
104 /// that spans multiple lines
105 fn a() {
106 b
107 }
108
109 impl C for D {
110 }
111 "
112 .unindent();
113
114 let parsed_files = retriever
115 .parse_file(Path::new("foo.rs"), &text, language)
116 .unwrap();
117
118 assert_eq!(
119 parsed_files,
120 &[
121 Document {
122 name: "a".into(),
123 range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1),
124 content: "
125 The below code snippet is from file 'foo.rs'
126
127 ```rust
128 /// A doc comment
129 /// that spans multiple lines
130 fn a() {
131 b
132 }
133 ```"
134 .unindent(),
135 embedding: vec![],
136 },
137 Document {
138 name: "C for D".into(),
139 range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1),
140 content: "
141 The below code snippet is from file 'foo.rs'
142
143 ```rust
144 impl C for D {
145 }
146 ```"
147 .unindent(),
148 embedding: vec![],
149 }
150 ]
151 );
152}
153
154#[gpui::test]
155fn test_dot_product(mut rng: StdRng) {
156 assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
157 assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
158
159 for _ in 0..100 {
160 let size = 1536;
161 let mut a = vec![0.; size];
162 let mut b = vec![0.; size];
163 for (a, b) in a.iter_mut().zip(b.iter_mut()) {
164 *a = rng.gen();
165 *b = rng.gen();
166 }
167
168 assert_eq!(
169 round_to_decimals(dot(&a, &b), 1),
170 round_to_decimals(reference_dot(&a, &b), 1)
171 );
172 }
173
174 fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
175 let factor = (10.0 as f32).powi(decimal_places);
176 (n * factor).round() / factor
177 }
178
179 fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
180 a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
181 }
182}
183
184struct FakeEmbeddingProvider;
185
186#[async_trait]
187impl EmbeddingProvider for FakeEmbeddingProvider {
188 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
189 Ok(spans
190 .iter()
191 .map(|span| {
192 let mut result = vec![1.0; 26];
193 for letter in span.chars() {
194 let letter = letter.to_ascii_lowercase();
195 if letter as u32 >= 'a' as u32 {
196 let ix = (letter as u32) - ('a' as u32);
197 if ix < 26 {
198 result[ix as usize] += 1.0;
199 }
200 }
201 }
202
203 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
204 for x in &mut result {
205 *x /= norm;
206 }
207
208 result
209 })
210 .collect())
211 }
212}
213
214fn rust_lang() -> Arc<Language> {
215 Arc::new(
216 Language::new(
217 LanguageConfig {
218 name: "Rust".into(),
219 path_suffixes: vec!["rs".into()],
220 ..Default::default()
221 },
222 Some(tree_sitter_rust::language()),
223 )
224 .with_embedding_query(
225 r#"
226 (
227 (line_comment)* @context
228 .
229 (enum_item
230 name: (_) @name) @item
231 )
232
233 (
234 (line_comment)* @context
235 .
236 (struct_item
237 name: (_) @name) @item
238 )
239
240 (
241 (line_comment)* @context
242 .
243 (impl_item
244 trait: (_)? @name
245 "for"? @name
246 type: (_) @name) @item
247 )
248
249 (
250 (line_comment)* @context
251 .
252 (trait_item
253 name: (_) @name) @item
254 )
255
256 (
257 (line_comment)* @context
258 .
259 (function_item
260 name: (_) @name) @item
261 )
262
263 (
264 (line_comment)* @context
265 .
266 (macro_definition
267 name: (_) @name) @item
268 )
269
270 (
271 (line_comment)* @context
272 .
273 (function_signature_item
274 name: (_) @name) @item
275 )
276 "#,
277 )
278 .unwrap(),
279 )
280}