1use crate::{
2 db::dot, embedding::EmbeddingProvider, vector_store_settings::VectorStoreSettings, VectorStore,
3};
4use anyhow::Result;
5use async_trait::async_trait;
6use gpui::{Task, TestAppContext};
7use language::{Language, LanguageConfig, LanguageRegistry};
8use project::{FakeFs, Project};
9use rand::Rng;
10use serde_json::json;
11use settings::SettingsStore;
12use std::sync::Arc;
13use unindent::Unindent;
14
15#[gpui::test]
16async fn test_vector_store(cx: &mut TestAppContext) {
17 cx.update(|cx| {
18 cx.set_global(SettingsStore::test(cx));
19 settings::register::<VectorStoreSettings>(cx);
20 });
21
22 let fs = FakeFs::new(cx.background());
23 fs.insert_tree(
24 "/the-root",
25 json!({
26 "src": {
27 "file1.rs": "
28 fn aaa() {
29 println!(\"aaaa!\");
30 }
31
32 fn zzzzzzzzz() {
33 println!(\"SLEEPING\");
34 }
35 ".unindent(),
36 "file2.rs": "
37 fn bbb() {
38 println!(\"bbbb!\");
39 }
40 ".unindent(),
41 }
42 }),
43 )
44 .await;
45
46 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
47 let rust_language = Arc::new(
48 Language::new(
49 LanguageConfig {
50 name: "Rust".into(),
51 path_suffixes: vec!["rs".into()],
52 ..Default::default()
53 },
54 Some(tree_sitter_rust::language()),
55 )
56 .with_embedding_query(
57 r#"
58 (function_item
59 name: (identifier) @name
60 body: (block)) @item
61 "#,
62 )
63 .unwrap(),
64 );
65 languages.add(rust_language);
66
67 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
68 let db_path = db_dir.path().join("db.sqlite");
69
70 let store = VectorStore::new(
71 fs.clone(),
72 db_path,
73 Arc::new(FakeEmbeddingProvider),
74 languages,
75 cx.to_async(),
76 )
77 .await
78 .unwrap();
79
80 let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
81 let worktree_id = project.read_with(cx, |project, cx| {
82 project.worktrees(cx).next().unwrap().read(cx).id()
83 });
84 let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx));
85
86 add_project.await.unwrap();
87
88 let search_results = store
89 .update(cx, |store, cx| {
90 store.search(project.clone(), "aaaa".to_string(), 5, cx)
91 })
92 .await
93 .unwrap();
94
95 assert_eq!(search_results[0].offset, 0);
96 assert_eq!(search_results[0].name, "aaa");
97 assert_eq!(search_results[0].worktree_id, worktree_id);
98}
99
// Verifies `dot` on hand-checked inputs and against a naive reference
// implementation on random 32-element vectors.
#[test]
fn test_dot_product() {
    // Maximum absolute difference tolerated between `dot` and the reference.
    // Same precision (three decimal places) the test has always targeted.
    const TOLERANCE: f32 = 1e-3;

    // These inputs and results are exactly representable, so exact equality is safe.
    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);

    let mut rng = rand::thread_rng();
    for _ in 0..100 {
        let a: [f32; 32] = rng.gen();
        let b: [f32; 32] = rng.gen();

        let actual = dot(&a, &b);
        let expected = reference_dot(&a, &b);

        // Compare with an absolute tolerance rather than rounding both sides and
        // testing equality: two values that differ by a single ULP can still fall
        // on opposite sides of a rounding boundary, which would fail spuriously.
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "dot product mismatch: got {actual}, expected {expected}"
        );
    }

    // Straightforward element-wise multiply-and-sum used as the ground truth.
    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
    }
}
124
125struct FakeEmbeddingProvider;
126
127#[async_trait]
128impl EmbeddingProvider for FakeEmbeddingProvider {
129 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
130 Ok(spans
131 .iter()
132 .map(|span| {
133 let mut result = vec![1.0; 26];
134 for letter in span.chars() {
135 let letter = letter.to_ascii_lowercase();
136 if letter as u32 >= 'a' as u32 {
137 let ix = (letter as u32) - ('a' as u32);
138 if ix < 26 {
139 result[ix as usize] += 1.0;
140 }
141 }
142 }
143
144 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
145 for x in &mut result {
146 *x /= norm;
147 }
148
149 result
150 })
151 .collect())
152 }
153}