vector_store_tests.rs

  1use crate::{
  2    db::dot,
  3    embedding::EmbeddingProvider,
  4    parsing::{CodeContextRetriever, Document},
  5    vector_store_settings::VectorStoreSettings,
  6    VectorStore,
  7};
  8use anyhow::Result;
  9use async_trait::async_trait;
 10use gpui::{Task, TestAppContext};
 11use language::{Language, LanguageConfig, LanguageRegistry};
 12use project::{project_settings::ProjectSettings, FakeFs, Project};
 13use rand::{rngs::StdRng, Rng};
 14use serde_json::json;
 15use settings::SettingsStore;
 16use std::{path::Path, sync::Arc};
 17use unindent::Unindent;
 18
 19#[ctor::ctor]
 20fn init_logger() {
 21    if std::env::var("RUST_LOG").is_ok() {
 22        env_logger::init();
 23    }
 24}
 25
 26#[gpui::test]
 27async fn test_vector_store(cx: &mut TestAppContext) {
 28    cx.update(|cx| {
 29        cx.set_global(SettingsStore::test(cx));
 30        settings::register::<VectorStoreSettings>(cx);
 31        settings::register::<ProjectSettings>(cx);
 32    });
 33
 34    let fs = FakeFs::new(cx.background());
 35    fs.insert_tree(
 36        "/the-root",
 37        json!({
 38            "src": {
 39                "file1.rs": "
 40                    fn aaa() {
 41                        println!(\"aaaa!\");
 42                    }
 43
 44                    fn zzzzzzzzz() {
 45                        println!(\"SLEEPING\");
 46                    }
 47                ".unindent(),
 48                "file2.rs": "
 49                    fn bbb() {
 50                        println!(\"bbbb!\");
 51                    }
 52                ".unindent(),
 53            }
 54        }),
 55    )
 56    .await;
 57
 58    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 59    let rust_language = rust_lang();
 60    languages.add(rust_language);
 61
 62    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
 63    let db_path = db_dir.path().join("db.sqlite");
 64
 65    let store = VectorStore::new(
 66        fs.clone(),
 67        db_path,
 68        Arc::new(FakeEmbeddingProvider),
 69        languages,
 70        cx.to_async(),
 71    )
 72    .await
 73    .unwrap();
 74
 75    let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
 76    let worktree_id = project.read_with(cx, |project, cx| {
 77        project.worktrees(cx).next().unwrap().read(cx).id()
 78    });
 79    store
 80        .update(cx, |store, cx| store.add_project(project.clone(), cx))
 81        .await
 82        .unwrap();
 83    cx.foreground().run_until_parked();
 84
 85    let search_results = store
 86        .update(cx, |store, cx| {
 87            store.search(project.clone(), "aaaa".to_string(), 5, cx)
 88        })
 89        .await
 90        .unwrap();
 91
 92    assert_eq!(search_results[0].byte_range.start, 0);
 93    assert_eq!(search_results[0].name, "aaa");
 94    assert_eq!(search_results[0].worktree_id, worktree_id);
 95}
 96
 97#[gpui::test]
 98async fn test_code_context_retrieval(cx: &mut TestAppContext) {
 99    let language = rust_lang();
100    let mut retriever = CodeContextRetriever::new();
101
102    let text = "
103        /// A doc comment
104        /// that spans multiple lines
105        fn a() {
106            b
107        }
108
109        impl C for D {
110        }
111    "
112    .unindent();
113
114    let parsed_files = retriever
115        .parse_file(Path::new("foo.rs"), &text, language)
116        .unwrap();
117
118    assert_eq!(
119        parsed_files,
120        &[
121            Document {
122                name: "a".into(),
123                range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1),
124                content: "
125                    The below code snippet is from file 'foo.rs'
126
127                    ```rust
128                    /// A doc comment
129                    /// that spans multiple lines
130                    fn a() {
131                        b
132                    }
133                    ```"
134                .unindent(),
135                embedding: vec![],
136            },
137            Document {
138                name: "C for D".into(),
139                range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1),
140                content: "
141                    The below code snippet is from file 'foo.rs'
142
143                    ```rust
144                    impl C for D {
145                    }
146                    ```"
147                .unindent(),
148                embedding: vec![],
149            }
150        ]
151    );
152}
153
154#[gpui::test]
155fn test_dot_product(mut rng: StdRng) {
156    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
157    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
158
159    for _ in 0..100 {
160        let size = 1536;
161        let mut a = vec![0.; size];
162        let mut b = vec![0.; size];
163        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
164            *a = rng.gen();
165            *b = rng.gen();
166        }
167
168        assert_eq!(
169            round_to_decimals(dot(&a, &b), 1),
170            round_to_decimals(reference_dot(&a, &b), 1)
171        );
172    }
173
174    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
175        let factor = (10.0 as f32).powi(decimal_places);
176        (n * factor).round() / factor
177    }
178
179    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
180        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
181    }
182}
183
184struct FakeEmbeddingProvider;
185
186#[async_trait]
187impl EmbeddingProvider for FakeEmbeddingProvider {
188    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
189        Ok(spans
190            .iter()
191            .map(|span| {
192                let mut result = vec![1.0; 26];
193                for letter in span.chars() {
194                    let letter = letter.to_ascii_lowercase();
195                    if letter as u32 >= 'a' as u32 {
196                        let ix = (letter as u32) - ('a' as u32);
197                        if ix < 26 {
198                            result[ix as usize] += 1.0;
199                        }
200                    }
201                }
202
203                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
204                for x in &mut result {
205                    *x /= norm;
206                }
207
208                result
209            })
210            .collect())
211    }
212}
213
214fn rust_lang() -> Arc<Language> {
215    Arc::new(
216        Language::new(
217            LanguageConfig {
218                name: "Rust".into(),
219                path_suffixes: vec!["rs".into()],
220                ..Default::default()
221            },
222            Some(tree_sitter_rust::language()),
223        )
224        .with_embedding_query(
225            r#"
226            (
227                (line_comment)* @context
228                .
229                (enum_item
230                    name: (_) @name) @item
231            )
232
233            (
234                (line_comment)* @context
235                .
236                (struct_item
237                    name: (_) @name) @item
238            )
239
240            (
241                (line_comment)* @context
242                .
243                (impl_item
244                    trait: (_)? @name
245                    "for"? @name
246                    type: (_) @name) @item
247            )
248
249            (
250                (line_comment)* @context
251                .
252                (trait_item
253                    name: (_) @name) @item
254            )
255
256            (
257                (line_comment)* @context
258                .
259                (function_item
260                    name: (_) @name) @item
261            )
262
263            (
264                (line_comment)* @context
265                .
266                (macro_definition
267                    name: (_) @name) @item
268            )
269
270            (
271                (line_comment)* @context
272                .
273                (function_signature_item
274                    name: (_) @name) @item
275            )
276            "#,
277        )
278        .unwrap(),
279    )
280}