1use crate::{
2 db::dot,
3 embedding::EmbeddingProvider,
4 parsing::{CodeContextRetriever, Document},
5 vector_store_settings::VectorStoreSettings,
6 VectorStore,
7};
8use anyhow::Result;
9use async_trait::async_trait;
10use gpui::{Task, TestAppContext};
11use language::{Language, LanguageConfig, LanguageRegistry};
12use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
13use rand::{rngs::StdRng, Rng};
14use serde_json::json;
15use settings::SettingsStore;
16use std::{
17 path::Path,
18 sync::{
19 atomic::{self, AtomicUsize},
20 Arc,
21 },
22};
23use unindent::Unindent;
24
25#[ctor::ctor]
26fn init_logger() {
27 if std::env::var("RUST_LOG").is_ok() {
28 env_logger::init();
29 }
30}
31
32#[gpui::test]
33async fn test_vector_store(cx: &mut TestAppContext) {
34 cx.update(|cx| {
35 cx.set_global(SettingsStore::test(cx));
36 settings::register::<VectorStoreSettings>(cx);
37 settings::register::<ProjectSettings>(cx);
38 });
39
40 let fs = FakeFs::new(cx.background());
41 fs.insert_tree(
42 "/the-root",
43 json!({
44 "src": {
45 "file1.rs": "
46 fn aaa() {
47 println!(\"aaaa!\");
48 }
49
50 fn zzzzzzzzz() {
51 println!(\"SLEEPING\");
52 }
53 ".unindent(),
54 "file2.rs": "
55 fn bbb() {
56 println!(\"bbbb!\");
57 }
58 ".unindent(),
59 "file3.toml": "
60 ZZZZZZZ = 5
61 ".unindent(),
62 }
63 }),
64 )
65 .await;
66
67 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
68 let rust_language = rust_lang();
69 let toml_language = toml_lang();
70 languages.add(rust_language);
71 languages.add(toml_language);
72
73 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
74 let db_path = db_dir.path().join("db.sqlite");
75
76 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
77 let store = VectorStore::new(
78 fs.clone(),
79 db_path,
80 embedding_provider.clone(),
81 languages,
82 cx.to_async(),
83 )
84 .await
85 .unwrap();
86
87 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
88 let worktree_id = project.read_with(cx, |project, cx| {
89 project.worktrees(cx).next().unwrap().read(cx).id()
90 });
91 let file_count = store
92 .update(cx, |store, cx| store.index_project(project.clone(), cx))
93 .await
94 .unwrap();
95 assert_eq!(file_count, 3);
96 cx.foreground().run_until_parked();
97 store.update(cx, |store, _cx| {
98 assert_eq!(
99 store.remaining_files_to_index_for_project(&project),
100 Some(0)
101 );
102 });
103
104 let search_results = store
105 .update(cx, |store, cx| {
106 store.search_project(project.clone(), "aaaa".to_string(), 5, cx)
107 })
108 .await
109 .unwrap();
110
111 assert_eq!(search_results[0].byte_range.start, 0);
112 assert_eq!(search_results[0].name, "aaa");
113 assert_eq!(search_results[0].worktree_id, worktree_id);
114
115 fs.save(
116 "/the-root/src/file2.rs".as_ref(),
117 &"
118 fn dddd() { println!(\"ddddd!\"); }
119 struct pqpqpqp {}
120 "
121 .unindent()
122 .into(),
123 Default::default(),
124 )
125 .await
126 .unwrap();
127
128 cx.foreground().run_until_parked();
129
130 let prev_embedding_count = embedding_provider.embedding_count();
131 let file_count = store
132 .update(cx, |store, cx| store.index_project(project.clone(), cx))
133 .await
134 .unwrap();
135 assert_eq!(file_count, 1);
136
137 cx.foreground().run_until_parked();
138 store.update(cx, |store, _cx| {
139 assert_eq!(
140 store.remaining_files_to_index_for_project(&project),
141 Some(0)
142 );
143 });
144
145 assert_eq!(
146 embedding_provider.embedding_count() - prev_embedding_count,
147 2
148 );
149}
150
151#[gpui::test]
152async fn test_code_context_retrieval_rust() {
153 let language = rust_lang();
154 let mut retriever = CodeContextRetriever::new();
155
156 let text = "
157 /// A doc comment
158 /// that spans multiple lines
159 fn a() {
160 b
161 }
162
163 impl C for D {
164 }
165 "
166 .unindent();
167
168 let parsed_files = retriever
169 .parse_file(Path::new("foo.rs"), &text, language)
170 .unwrap();
171
172 assert_eq!(
173 parsed_files,
174 &[
175 Document {
176 name: "a".into(),
177 range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1),
178 content: "
179 The below code snippet is from file 'foo.rs'
180
181 ```rust
182 /// A doc comment
183 /// that spans multiple lines
184 fn a() {
185 b
186 }
187 ```"
188 .unindent(),
189 embedding: vec![],
190 },
191 Document {
192 name: "C for D".into(),
193 range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1),
194 content: "
195 The below code snippet is from file 'foo.rs'
196
197 ```rust
198 impl C for D {
199 }
200 ```"
201 .unindent(),
202 embedding: vec![],
203 }
204 ]
205 );
206}
207
208#[gpui::test]
209async fn test_code_context_retrieval_javascript() {
210 let language = js_lang();
211 let mut retriever = CodeContextRetriever::new();
212
213 let text = "
214/* globals importScripts, backend */
215function _authorize() {}
216
217/**
218 * Sometimes the frontend build is way faster than backend.
219 */
220export async function authorizeBank() {
221 _authorize(pushModal, upgradingAccountId, {});
222}
223
224export class SettingsPage {
225 /* This is a test setting */
226 constructor(page) {
227 this.page = page;
228 }
229}
230
231/* This is a test comment */
232class TestClass {}
233
234/* Schema for editor_events in Clickhouse. */
235export interface ClickhouseEditorEvent {
236 installation_id: string
237 operation: string
238}
239";
240
241 let parsed_files = retriever
242 .parse_file(Path::new("foo.js"), &text, language)
243 .unwrap();
244
245 let test_documents = &[
246 Document {
247 name: "function _authorize".into(),
248 range: text.find("function _authorize").unwrap()..(text.find("}").unwrap() + 1),
249 content: "
250 The below code snippet is from file 'foo.js'
251
252 ```javascript
253 /* globals importScripts, backend */
254 function _authorize() {}
255 ```"
256 .unindent(),
257 embedding: vec![],
258 },
259 Document {
260 name: "async function authorizeBank".into(),
261 range: text.find("export async").unwrap()..224,
262 content: "
263 The below code snippet is from file 'foo.js'
264
265 ```javascript
266 /**
267 * Sometimes the frontend build is way faster than backend.
268 */
269 export async function authorizeBank() {
270 _authorize(pushModal, upgradingAccountId, {});
271 }
272 ```"
273 .unindent(),
274 embedding: vec![],
275 },
276 Document {
277 name: "class SettingsPage".into(),
278 range: 226..344,
279 content: "
280 The below code snippet is from file 'foo.js'
281
282 ```javascript
283 export class SettingsPage {
284 /* This is a test setting */
285 constructor(page) {
286 this.page = page;
287 }
288 }
289 ```"
290 .unindent(),
291 embedding: vec![],
292 },
293 Document {
294 name: "constructor".into(),
295 range: 291..342,
296 content: "
297 The below code snippet is from file 'foo.js'
298
299 ```javascript
300 /* This is a test setting */
301 constructor(page) {
302 this.page = page;
303 }
304 ```"
305 .unindent(),
306 embedding: vec![],
307 },
308 Document {
309 name: "class TestClass".into(),
310 range: 375..393,
311 content: "
312 The below code snippet is from file 'foo.js'
313
314 ```javascript
315 /* This is a test comment */
316 class TestClass {}
317 ```"
318 .unindent(),
319 embedding: vec![],
320 },
321 Document {
322 name: "interface ClickhouseEditorEvent".into(),
323 range: 441..533,
324 content: "
325 The below code snippet is from file 'foo.js'
326
327 ```javascript
328 /* Schema for editor_events in Clickhouse. */
329 export interface ClickhouseEditorEvent {
330 installation_id: string
331 operation: string
332 }
333 ```"
334 .unindent(),
335 embedding: vec![],
336 },
337 ];
338
339 for idx in 0..test_documents.len() {
340 assert_eq!(test_documents[idx], parsed_files[idx]);
341 }
342}
343
344#[gpui::test]
345fn test_dot_product(mut rng: StdRng) {
346 assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
347 assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
348
349 for _ in 0..100 {
350 let size = 1536;
351 let mut a = vec![0.; size];
352 let mut b = vec![0.; size];
353 for (a, b) in a.iter_mut().zip(b.iter_mut()) {
354 *a = rng.gen();
355 *b = rng.gen();
356 }
357
358 assert_eq!(
359 round_to_decimals(dot(&a, &b), 1),
360 round_to_decimals(reference_dot(&a, &b), 1)
361 );
362 }
363
364 fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
365 let factor = (10.0 as f32).powi(decimal_places);
366 (n * factor).round() / factor
367 }
368
369 fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
370 a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
371 }
372}
373
/// Test double for an embedding provider that keeps a running count of the
/// spans it has been asked to embed.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Total number of spans embedded so far; atomic so it can be updated
    /// through a shared reference.
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Returns the current value of the embedding counter.
    fn embedding_count(&self) -> usize {
        let Self { embedding_count } = self;
        embedding_count.load(atomic::Ordering::SeqCst)
    }
}
384
385#[async_trait]
386impl EmbeddingProvider for FakeEmbeddingProvider {
387 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
388 self.embedding_count
389 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
390 Ok(spans
391 .iter()
392 .map(|span| {
393 let mut result = vec![1.0; 26];
394 for letter in span.chars() {
395 let letter = letter.to_ascii_lowercase();
396 if letter as u32 >= 'a' as u32 {
397 let ix = (letter as u32) - ('a' as u32);
398 if ix < 26 {
399 result[ix as usize] += 1.0;
400 }
401 }
402 }
403
404 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
405 for x in &mut result {
406 *x /= norm;
407 }
408
409 result
410 })
411 .collect())
412 }
413}
414
415fn js_lang() -> Arc<Language> {
416 Arc::new(
417 Language::new(
418 LanguageConfig {
419 name: "Javascript".into(),
420 path_suffixes: vec!["js".into()],
421 ..Default::default()
422 },
423 Some(tree_sitter_typescript::language_tsx()),
424 )
425 .with_embedding_query(
426 &r#"
427
428 (
429 (comment)* @context
430 .
431 (export_statement
432 (function_declaration
433 "async"? @name
434 "function" @name
435 name: (_) @name)) @item
436 )
437
438 (
439 (comment)* @context
440 .
441 (function_declaration
442 "async"? @name
443 "function" @name
444 name: (_) @name) @item
445 )
446
447 (
448 (comment)* @context
449 .
450 (export_statement
451 (class_declaration
452 "class" @name
453 name: (_) @name)) @item
454 )
455
456 (
457 (comment)* @context
458 .
459 (class_declaration
460 "class" @name
461 name: (_) @name) @item
462 )
463
464 (
465 (comment)* @context
466 .
467 (method_definition
468 [
469 "get"
470 "set"
471 "async"
472 "*"
473 "static"
474 ]* @name
475 name: (_) @name) @item
476 )
477
478 (
479 (comment)* @context
480 .
481 (export_statement
482 (interface_declaration
483 "interface" @name
484 name: (_) @name)) @item
485 )
486
487 (
488 (comment)* @context
489 .
490 (interface_declaration
491 "interface" @name
492 name: (_) @name) @item
493 )
494
495 (
496 (comment)* @context
497 .
498 (export_statement
499 (enum_declaration
500 "enum" @name
501 name: (_) @name)) @item
502 )
503
504 (
505 (comment)* @context
506 .
507 (enum_declaration
508 "enum" @name
509 name: (_) @name) @item
510 )
511
512 "#
513 .unindent(),
514 )
515 .unwrap(),
516 )
517}
518
519fn rust_lang() -> Arc<Language> {
520 Arc::new(
521 Language::new(
522 LanguageConfig {
523 name: "Rust".into(),
524 path_suffixes: vec!["rs".into()],
525 ..Default::default()
526 },
527 Some(tree_sitter_rust::language()),
528 )
529 .with_embedding_query(
530 r#"
531 (
532 (line_comment)* @context
533 .
534 (enum_item
535 name: (_) @name) @item
536 )
537
538 (
539 (line_comment)* @context
540 .
541 (struct_item
542 name: (_) @name) @item
543 )
544
545 (
546 (line_comment)* @context
547 .
548 (impl_item
549 trait: (_)? @name
550 "for"? @name
551 type: (_) @name) @item
552 )
553
554 (
555 (line_comment)* @context
556 .
557 (trait_item
558 name: (_) @name) @item
559 )
560
561 (
562 (line_comment)* @context
563 .
564 (function_item
565 name: (_) @name) @item
566 )
567
568 (
569 (line_comment)* @context
570 .
571 (macro_definition
572 name: (_) @name) @item
573 )
574
575 (
576 (line_comment)* @context
577 .
578 (function_signature_item
579 name: (_) @name) @item
580 )
581 "#,
582 )
583 .unwrap(),
584 )
585}
586
587fn toml_lang() -> Arc<Language> {
588 Arc::new(Language::new(
589 LanguageConfig {
590 name: "TOML".into(),
591 path_suffixes: vec!["toml".into()],
592 ..Default::default()
593 },
594 Some(tree_sitter_toml::language()),
595 ))
596}