1use crate::{db::dot, embedding::EmbeddingProvider, VectorStore};
2use anyhow::Result;
3use async_trait::async_trait;
4use gpui::{Task, TestAppContext};
5use language::{Language, LanguageConfig, LanguageRegistry};
6use project::{FakeFs, Project};
7use rand::Rng;
8use serde_json::json;
9use std::sync::Arc;
10use unindent::Unindent;
11
12#[gpui::test]
13async fn test_vector_store(cx: &mut TestAppContext) {
14 let fs = FakeFs::new(cx.background());
15 fs.insert_tree(
16 "/the-root",
17 json!({
18 "src": {
19 "file1.rs": "
20 fn aaa() {
21 println!(\"aaaa!\");
22 }
23
24 fn zzzzzzzzz() {
25 println!(\"SLEEPING\");
26 }
27 ".unindent(),
28 "file2.rs": "
29 fn bbb() {
30 println!(\"bbbb!\");
31 }
32 ".unindent(),
33 }
34 }),
35 )
36 .await;
37
38 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
39 let rust_language = Arc::new(
40 Language::new(
41 LanguageConfig {
42 name: "Rust".into(),
43 path_suffixes: vec!["rs".into()],
44 ..Default::default()
45 },
46 Some(tree_sitter_rust::language()),
47 )
48 .with_embedding_query(
49 r#"
50 (function_item
51 name: (identifier) @name
52 body: (block)) @item
53 "#,
54 )
55 .unwrap(),
56 );
57 languages.add(rust_language);
58
59 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
60 let db_path = db_dir.path().join("db.sqlite");
61
62 let store = VectorStore::new(
63 fs.clone(),
64 db_path,
65 Arc::new(FakeEmbeddingProvider),
66 languages,
67 cx.to_async(),
68 )
69 .await
70 .unwrap();
71
72 let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
73 let worktree_id = project.read_with(cx, |project, cx| {
74 project.worktrees(cx).next().unwrap().read(cx).id()
75 });
76 let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx));
77
78 add_project.await.unwrap();
79
80 let search_results = store
81 .update(cx, |store, cx| {
82 store.search(project.clone(), "aaaa".to_string(), 5, cx)
83 })
84 .await
85 .unwrap();
86
87 assert_eq!(search_results[0].offset, 0);
88 assert_eq!(search_results[0].name, "aaa");
89 assert_eq!(search_results[0].worktree_id, worktree_id);
90}
91
/// Checks `dot` against hand-computed values, then against a naive reference
/// implementation on random 32-dimensional vectors.
#[test]
fn test_dot_product() {
    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);

    // `dot` may accumulate products in a different order than the reference
    // loop, so the two results can differ by a few ULPs. Compare with a
    // tolerance. Rounding both sides to a fixed number of decimals and testing
    // for equality flakes whenever the values straddle a rounding boundary.
    // For 32 components drawn from [0, 1), the sums are at most 32 and the
    // accumulated rounding error is orders of magnitude below this bound.
    const TOLERANCE: f32 = 1e-4;

    let mut rng = rand::thread_rng();
    for _ in 0..100 {
        let a: [f32; 32] = rng.gen();
        let b: [f32; 32] = rng.gen();
        let actual = dot(&a, &b);
        let expected = reference_dot(&a, &b);
        assert!(
            (actual - expected).abs() <= TOLERANCE,
            "dot = {}, reference = {}",
            actual,
            expected
        );
    }

    /// Straightforward sum of element-wise products, used as the oracle.
    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(a, b)| a * b).sum()
    }
}
116
117struct FakeEmbeddingProvider;
118
119#[async_trait]
120impl EmbeddingProvider for FakeEmbeddingProvider {
121 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
122 Ok(spans
123 .iter()
124 .map(|span| {
125 let mut result = vec![1.0; 26];
126 for letter in span.chars() {
127 let letter = letter.to_ascii_lowercase();
128 if letter as u32 >= 'a' as u32 {
129 let ix = (letter as u32) - ('a' as u32);
130 if ix < 26 {
131 result[ix as usize] += 1.0;
132 }
133 }
134 }
135
136 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
137 for x in &mut result {
138 *x /= norm;
139 }
140
141 result
142 })
143 .collect())
144 }
145}