1use crate::{
2 db::dot, embedding::EmbeddingProvider, vector_store_settings::VectorStoreSettings, VectorStore,
3};
4use anyhow::Result;
5use async_trait::async_trait;
6use gpui::{Task, TestAppContext};
7use language::{Language, LanguageConfig, LanguageRegistry};
8use project::{project_settings::ProjectSettings, FakeFs, Project};
9use rand::{rngs::StdRng, Rng};
10use serde_json::json;
11use settings::SettingsStore;
12use std::sync::Arc;
13use unindent::Unindent;
14
15#[gpui::test]
16async fn test_vector_store(cx: &mut TestAppContext) {
17 cx.update(|cx| {
18 cx.set_global(SettingsStore::test(cx));
19 settings::register::<VectorStoreSettings>(cx);
20 settings::register::<ProjectSettings>(cx);
21 });
22
23 let fs = FakeFs::new(cx.background());
24 fs.insert_tree(
25 "/the-root",
26 json!({
27 "src": {
28 "file1.rs": "
29 fn aaa() {
30 println!(\"aaaa!\");
31 }
32
33 fn zzzzzzzzz() {
34 println!(\"SLEEPING\");
35 }
36 ".unindent(),
37 "file2.rs": "
38 fn bbb() {
39 println!(\"bbbb!\");
40 }
41 ".unindent(),
42 }
43 }),
44 )
45 .await;
46
47 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
48 let rust_language = Arc::new(
49 Language::new(
50 LanguageConfig {
51 name: "Rust".into(),
52 path_suffixes: vec!["rs".into()],
53 ..Default::default()
54 },
55 Some(tree_sitter_rust::language()),
56 )
57 .with_embedding_query(
58 r#"
59 (function_item
60 name: (identifier) @name
61 body: (block)) @item
62 "#,
63 )
64 .unwrap(),
65 );
66 languages.add(rust_language);
67
68 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
69 let db_path = db_dir.path().join("db.sqlite");
70
71 let store = VectorStore::new(
72 fs.clone(),
73 db_path,
74 Arc::new(FakeEmbeddingProvider),
75 languages,
76 cx.to_async(),
77 )
78 .await
79 .unwrap();
80
81 let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
82 let worktree_id = project.read_with(cx, |project, cx| {
83 project.worktrees(cx).next().unwrap().read(cx).id()
84 });
85 store
86 .update(cx, |store, cx| store.add_project(project.clone(), cx))
87 .await
88 .unwrap();
89 cx.foreground().run_until_parked();
90
91 let search_results = store
92 .update(cx, |store, cx| {
93 store.search(project.clone(), "aaaa".to_string(), 5, cx)
94 })
95 .await
96 .unwrap();
97
98 assert_eq!(search_results[0].offset, 0);
99 assert_eq!(search_results[0].name, "aaa");
100 assert_eq!(search_results[0].worktree_id, worktree_id);
101}
102
103#[gpui::test]
104fn test_dot_product(mut rng: StdRng) {
105 assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
106 assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
107
108 for _ in 0..100 {
109 let size = 1536;
110 let mut a = vec![0.; size];
111 let mut b = vec![0.; size];
112 for (a, b) in a.iter_mut().zip(b.iter_mut()) {
113 *a = rng.gen();
114 *b = rng.gen();
115 }
116
117 assert_eq!(
118 round_to_decimals(dot(&a, &b), 1),
119 round_to_decimals(reference_dot(&a, &b), 1)
120 );
121 }
122
123 fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
124 let factor = (10.0 as f32).powi(decimal_places);
125 (n * factor).round() / factor
126 }
127
128 fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
129 a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
130 }
131}
132
/// Test-only embedding provider passed to `VectorStore::new` in these tests.
///
/// Its `embed_batch` implementation computes each vector locally from the
/// letters of the input span (a normalized 26-slot letter count).
struct FakeEmbeddingProvider;
134
135#[async_trait]
136impl EmbeddingProvider for FakeEmbeddingProvider {
137 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
138 Ok(spans
139 .iter()
140 .map(|span| {
141 let mut result = vec![1.0; 26];
142 for letter in span.chars() {
143 let letter = letter.to_ascii_lowercase();
144 if letter as u32 >= 'a' as u32 {
145 let ix = (letter as u32) - ('a' as u32);
146 if ix < 26 {
147 result[ix as usize] += 1.0;
148 }
149 }
150 }
151
152 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
153 for x in &mut result {
154 *x /= norm;
155 }
156
157 result
158 })
159 .collect())
160 }
161}