vector_store_tests.rs

  1use crate::{
  2    db::dot,
  3    embedding::EmbeddingProvider,
  4    parsing::{CodeContextRetriever, Document},
  5    vector_store_settings::VectorStoreSettings,
  6    VectorStore,
  7};
  8use anyhow::Result;
  9use async_trait::async_trait;
 10use gpui::{Task, TestAppContext};
 11use language::{Language, LanguageConfig, LanguageRegistry};
 12use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
 13use rand::{rngs::StdRng, Rng};
 14use serde_json::json;
 15use settings::SettingsStore;
 16use std::{
 17    path::Path,
 18    sync::{
 19        atomic::{self, AtomicUsize},
 20        Arc,
 21    },
 22};
 23use unindent::Unindent;
 24
 25#[ctor::ctor]
 26fn init_logger() {
 27    if std::env::var("RUST_LOG").is_ok() {
 28        env_logger::init();
 29    }
 30}
 31
 32#[gpui::test]
 33async fn test_vector_store(cx: &mut TestAppContext) {
 34    cx.update(|cx| {
 35        cx.set_global(SettingsStore::test(cx));
 36        settings::register::<VectorStoreSettings>(cx);
 37        settings::register::<ProjectSettings>(cx);
 38    });
 39
 40    let fs = FakeFs::new(cx.background());
 41    fs.insert_tree(
 42        "/the-root",
 43        json!({
 44            "src": {
 45                "file1.rs": "
 46                    fn aaa() {
 47                        println!(\"aaaa!\");
 48                    }
 49
 50                    fn zzzzzzzzz() {
 51                        println!(\"SLEEPING\");
 52                    }
 53                ".unindent(),
 54                "file2.rs": "
 55                    fn bbb() {
 56                        println!(\"bbbb!\");
 57                    }
 58                ".unindent(),
 59            }
 60        }),
 61    )
 62    .await;
 63
 64    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 65    let rust_language = rust_lang();
 66    languages.add(rust_language);
 67
 68    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
 69    let db_path = db_dir.path().join("db.sqlite");
 70
 71    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 72    let store = VectorStore::new(
 73        fs.clone(),
 74        db_path,
 75        embedding_provider.clone(),
 76        languages,
 77        cx.to_async(),
 78    )
 79    .await
 80    .unwrap();
 81
 82    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
 83    let worktree_id = project.read_with(cx, |project, cx| {
 84        project.worktrees(cx).next().unwrap().read(cx).id()
 85    });
 86    let file_count = store
 87        .update(cx, |store, cx| store.index_project(project.clone(), cx))
 88        .await
 89        .unwrap();
 90    assert_eq!(file_count, 2);
 91    cx.foreground().run_until_parked();
 92    store.update(cx, |store, _cx| {
 93        assert_eq!(
 94            store.remaining_files_to_index_for_project(&project),
 95            Some(0)
 96        );
 97    });
 98
 99    let search_results = store
100        .update(cx, |store, cx| {
101            store.search_project(project.clone(), "aaaa".to_string(), 5, cx)
102        })
103        .await
104        .unwrap();
105
106    assert_eq!(search_results[0].byte_range.start, 0);
107    assert_eq!(search_results[0].name, "aaa");
108    assert_eq!(search_results[0].worktree_id, worktree_id);
109
110    fs.save(
111        "/the-root/src/file2.rs".as_ref(),
112        &"
113            fn dddd() { println!(\"ddddd!\"); }
114            struct pqpqpqp {}
115        "
116        .unindent()
117        .into(),
118        Default::default(),
119    )
120    .await
121    .unwrap();
122
123    cx.foreground().run_until_parked();
124
125    let prev_embedding_count = embedding_provider.embedding_count();
126    let file_count = store
127        .update(cx, |store, cx| store.index_project(project.clone(), cx))
128        .await
129        .unwrap();
130    assert_eq!(file_count, 1);
131
132    cx.foreground().run_until_parked();
133    store.update(cx, |store, _cx| {
134        assert_eq!(
135            store.remaining_files_to_index_for_project(&project),
136            Some(0)
137        );
138    });
139
140    assert_eq!(
141        embedding_provider.embedding_count() - prev_embedding_count,
142        2
143    );
144}
145
146#[gpui::test]
147async fn test_code_context_retrieval() {
148    let language = rust_lang();
149    let mut retriever = CodeContextRetriever::new();
150
151    let text = "
152        /// A doc comment
153        /// that spans multiple lines
154        fn a() {
155            b
156        }
157
158        impl C for D {
159        }
160    "
161    .unindent();
162
163    let parsed_files = retriever
164        .parse_file(Path::new("foo.rs"), &text, language)
165        .unwrap();
166
167    assert_eq!(
168        parsed_files,
169        &[
170            Document {
171                name: "a".into(),
172                range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1),
173                content: "
174                    The below code snippet is from file 'foo.rs'
175
176                    ```rust
177                    /// A doc comment
178                    /// that spans multiple lines
179                    fn a() {
180                        b
181                    }
182                    ```"
183                .unindent(),
184                embedding: vec![],
185            },
186            Document {
187                name: "C for D".into(),
188                range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1),
189                content: "
190                    The below code snippet is from file 'foo.rs'
191
192                    ```rust
193                    impl C for D {
194                    }
195                    ```"
196                .unindent(),
197                embedding: vec![],
198            }
199        ]
200    );
201}
202
203#[gpui::test]
204fn test_dot_product(mut rng: StdRng) {
205    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
206    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
207
208    for _ in 0..100 {
209        let size = 1536;
210        let mut a = vec![0.; size];
211        let mut b = vec![0.; size];
212        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
213            *a = rng.gen();
214            *b = rng.gen();
215        }
216
217        assert_eq!(
218            round_to_decimals(dot(&a, &b), 1),
219            round_to_decimals(reference_dot(&a, &b), 1)
220        );
221    }
222
223    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
224        let factor = (10.0 as f32).powi(decimal_places);
225        (n * factor).round() / factor
226    }
227
228    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
229        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
230    }
231}
232
/// Test stand-in for an embedding provider that keeps a running count of
/// embedded spans so tests can assert how much work was requested.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Total number of spans counted so far; starts at zero via `Default`.
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Returns the current value of the span counter.
    fn embedding_count(&self) -> usize {
        let counter = &self.embedding_count;
        counter.load(atomic::Ordering::SeqCst)
    }
}
243
244#[async_trait]
245impl EmbeddingProvider for FakeEmbeddingProvider {
246    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
247        self.embedding_count
248            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
249        Ok(spans
250            .iter()
251            .map(|span| {
252                let mut result = vec![1.0; 26];
253                for letter in span.chars() {
254                    let letter = letter.to_ascii_lowercase();
255                    if letter as u32 >= 'a' as u32 {
256                        let ix = (letter as u32) - ('a' as u32);
257                        if ix < 26 {
258                            result[ix as usize] += 1.0;
259                        }
260                    }
261                }
262
263                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
264                for x in &mut result {
265                    *x /= norm;
266                }
267
268                result
269            })
270            .collect())
271    }
272}
273
274fn rust_lang() -> Arc<Language> {
275    Arc::new(
276        Language::new(
277            LanguageConfig {
278                name: "Rust".into(),
279                path_suffixes: vec!["rs".into()],
280                ..Default::default()
281            },
282            Some(tree_sitter_rust::language()),
283        )
284        .with_embedding_query(
285            r#"
286            (
287                (line_comment)* @context
288                .
289                (enum_item
290                    name: (_) @name) @item
291            )
292
293            (
294                (line_comment)* @context
295                .
296                (struct_item
297                    name: (_) @name) @item
298            )
299
300            (
301                (line_comment)* @context
302                .
303                (impl_item
304                    trait: (_)? @name
305                    "for"? @name
306                    type: (_) @name) @item
307            )
308
309            (
310                (line_comment)* @context
311                .
312                (trait_item
313                    name: (_) @name) @item
314            )
315
316            (
317                (line_comment)* @context
318                .
319                (function_item
320                    name: (_) @name) @item
321            )
322
323            (
324                (line_comment)* @context
325                .
326                (macro_definition
327                    name: (_) @name) @item
328            )
329
330            (
331                (line_comment)* @context
332                .
333                (function_signature_item
334                    name: (_) @name) @item
335            )
336            "#,
337        )
338        .unwrap(),
339    )
340}