Optimize glob filtering of semantic search

Created by Max Brunsfeld and Kyle

Co-authored-by: Kyle <kyle@zed.dev>
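
Glob filtering previously happened per document: each document row carried its file's relative_path, and the include/exclude globs were matched against that path once for every document before scoring. With this change, for_each_document first queries the files table, applies the include and exclude globs once per file to collect the matching file ids, and then loads and scores embeddings only for documents belonging to those files, so glob matching scales with the number of files rather than the number of indexed documents.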

Change summary

crates/search/src/project_search.rs               |   1 
crates/semantic_index/src/db.rs                   |  66 +++++++---
crates/semantic_index/src/semantic_index.rs       |   2 
crates/semantic_index/src/semantic_index_tests.rs | 103 ++++++++++------
4 files changed, 109 insertions(+), 63 deletions(-)

Detailed changes

crates/search/src/project_search.rs

@@ -669,7 +669,6 @@ impl ProjectSearchView {
         &mut self,
         cx: &mut ViewContext<Self>,
     ) -> Option<(Vec<GlobMatcher>, Vec<GlobMatcher>)> {
-        let text = self.query_editor.read(cx).text(cx);
         let included_files =
             match Self::load_glob_set(&self.included_files_editor.read(cx).text(cx)) {
                 Ok(included_files) => {

crates/semantic_index/src/db.rs

@@ -1,6 +1,6 @@
 use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
 use anyhow::{anyhow, Context, Result};
-use globset::{Glob, GlobMatcher};
+use globset::GlobMatcher;
 use project::Fs;
 use rpc::proto::Timestamp;
 use rusqlite::{
@@ -257,16 +257,11 @@ impl VectorDatabase {
         exclude_globs: Vec<GlobMatcher>,
     ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
         let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
-        self.for_each_document(&worktree_ids, |relative_path, id, embedding| {
-            if (include_globs.is_empty()
-                || include_globs
-                    .iter()
-                    .any(|include_glob| include_glob.is_match(relative_path.clone())))
-                && (exclude_globs.is_empty()
-                    || !exclude_globs
-                        .iter()
-                        .any(|exclude_glob| exclude_glob.is_match(relative_path.clone())))
-            {
+        self.for_each_document(
+            &worktree_ids,
+            include_globs,
+            exclude_globs,
+            |id, embedding| {
                 let similarity = dot(&embedding, &query_embedding);
                 let ix = match results.binary_search_by(|(_, s)| {
                     similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
@@ -276,8 +271,8 @@ impl VectorDatabase {
                 };
                 results.insert(ix, (id, similarity));
                 results.truncate(limit);
-            }
-        })?;
+            },
+        )?;
 
         let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
         self.get_documents_by_ids(&ids)
@@ -286,26 +281,55 @@ impl VectorDatabase {
     fn for_each_document(
         &self,
         worktree_ids: &[i64],
-        mut f: impl FnMut(String, i64, Vec<f32>),
+        include_globs: Vec<GlobMatcher>,
+        exclude_globs: Vec<GlobMatcher>,
+        mut f: impl FnMut(i64, Vec<f32>),
     ) -> Result<()> {
+        let mut file_query = self.db.prepare(
+            "
+            SELECT
+                id, relative_path
+            FROM
+                files
+            WHERE
+                worktree_id IN rarray(?)
+            ",
+        )?;
+
+        let mut file_ids = Vec::<i64>::new();
+        let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
+        while let Some(row) = rows.next()? {
+            let file_id = row.get(0)?;
+            let relative_path = row.get_ref(1)?.as_str()?;
+            let included = include_globs.is_empty()
+                || include_globs
+                    .iter()
+                    .any(|glob| glob.is_match(relative_path));
+            let excluded = exclude_globs
+                .iter()
+                .any(|glob| glob.is_match(relative_path));
+            if included && !excluded {
+                file_ids.push(file_id);
+            }
+        }
+
         let mut query_statement = self.db.prepare(
             "
             SELECT
-                files.relative_path, documents.id, documents.embedding
+                id, embedding
             FROM
-                documents, files
+                documents
             WHERE
-                documents.file_id = files.id AND
-                files.worktree_id IN rarray(?)
+                file_id IN rarray(?)
             ",
         )?;
 
         query_statement
-            .query_map(params![ids_to_sql(worktree_ids)], |row| {
-                Ok((row.get(0)?, row.get(1)?, row.get::<_, Embedding>(2)?))
+            .query_map(params![ids_to_sql(&file_ids)], |row| {
+                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
             })?
             .filter_map(|row| row.ok())
-            .for_each(|(relative_path, id, embedding)| f(relative_path, id, embedding.0));
+            .for_each(|(id, embedding)| f(id, embedding.0));
         Ok(())
     }
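
The include/exclude pass itself is unchanged in spirit; it just runs once per file instead of once per document. The following is a minimal, self-contained sketch of that filtering logic, not code from this change: the matching_file_ids helper and the in-memory files slice are illustrative stand-ins for the first SQL query in for_each_document, and only the globset crate is assumed.

// Illustrative sketch only (not part of this diff): the include/exclude pass
// that now runs once per file. The in-memory `files` slice stands in for the
// (id, relative_path) rows returned by the first SQL query above.
use globset::{Glob, GlobMatcher};

fn matching_file_ids(
    files: &[(i64, &str)],
    include_globs: &[GlobMatcher],
    exclude_globs: &[GlobMatcher],
) -> Vec<i64> {
    files
        .iter()
        .copied()
        .filter(|&(_, relative_path)| {
            // Keep a file if it matches any include glob (or no includes are
            // given) and matches none of the exclude globs.
            let included = include_globs.is_empty()
                || include_globs.iter().any(|glob| glob.is_match(relative_path));
            let excluded = exclude_globs.iter().any(|glob| glob.is_match(relative_path));
            included && !excluded
        })
        .map(|(id, _)| id)
        .collect()
}

fn main() {
    let rust_only = vec![Glob::new("*.rs").unwrap().compile_matcher()];
    let files = [(1, "src/file1.rs"), (2, "src/file3.toml")];
    assert_eq!(matching_file_ids(&files, &rust_only, &[]), vec![1]);
    assert_eq!(matching_file_ids(&files, &[], &rust_only), vec![2]);
}

Because the documents query then runs only over the surviving file ids, the per-document work in the similarity callback no longer involves any path matching at all.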
 

crates/semantic_index/src/semantic_index.rs

@@ -11,7 +11,7 @@ use anyhow::{anyhow, Result};
 use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use futures::{channel::oneshot, Future};
-use globset::{Glob, GlobMatcher};
+use globset::GlobMatcher;
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Anchor, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;

crates/semantic_index/src/semantic_index_tests.rs

@@ -3,7 +3,7 @@ use crate::{
     embedding::EmbeddingProvider,
     parsing::{subtract_ranges, CodeContextRetriever, Document},
     semantic_index_settings::SemanticIndexSettings,
-    SemanticIndex,
+    SearchResult, SemanticIndex,
 };
 use anyhow::Result;
 use async_trait::async_trait;
@@ -46,21 +46,21 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
             "src": {
                 "file1.rs": "
                     fn aaa() {
-                        println!(\"aaaa!\");
+                        println!(\"aaaaaaaaaaaa!\");
                     }
 
-                    fn zzzzzzzzz() {
+                    fn zzzzz() {
                         println!(\"SLEEPING\");
                     }
                 ".unindent(),
                 "file2.rs": "
                     fn bbb() {
-                        println!(\"bbbb!\");
+                        println!(\"bbbbbbbbbbbbb!\");
                     }
                 ".unindent(),
                 "file3.toml": "
-                    ZZZZZZZ = 5
-                    ".unindent(),
+                    ZZZZZZZZZZZZZZZZZZ = 5
+                ".unindent(),
             }
         }),
     )
@@ -97,27 +97,37 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
 
     let search_results = store
         .update(cx, |store, cx| {
-            store.search_project(project.clone(), "aaaa".to_string(), 5, vec![], vec![], cx)
+            store.search_project(
+                project.clone(),
+                "aaaaaabbbbzz".to_string(),
+                5,
+                vec![],
+                vec![],
+                cx,
+            )
         })
         .await
         .unwrap();
 
-    search_results[0].buffer.read_with(cx, |buffer, _cx| {
-        assert_eq!(search_results[0].range.start.to_offset(buffer), 0);
-        assert_eq!(
-            buffer.file().unwrap().path().as_ref(),
-            Path::new("src/file1.rs")
-        );
-    });
+    assert_search_results(
+        &search_results,
+        &[
+            (Path::new("src/file1.rs").into(), 0),
+            (Path::new("src/file2.rs").into(), 0),
+            (Path::new("src/file3.toml").into(), 0),
+            (Path::new("src/file1.rs").into(), 45),
+        ],
+        cx,
+    );
 
     // Test Include Files Functionality
     let include_files = vec![Glob::new("*.rs").unwrap().compile_matcher()];
     let exclude_files = vec![Glob::new("*.rs").unwrap().compile_matcher()];
-    let search_results = store
+    let rust_only_search_results = store
         .update(cx, |store, cx| {
             store.search_project(
                 project.clone(),
-                "aaaa".to_string(),
+                "aaaaaabbbbzz".to_string(),
                 5,
                 include_files,
                 vec![],
@@ -127,23 +137,21 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
         .await
         .unwrap();
 
-    for res in &search_results {
-        res.buffer.read_with(cx, |buffer, _cx| {
-            assert!(buffer
-                .file()
-                .unwrap()
-                .path()
-                .to_str()
-                .unwrap()
-                .ends_with("rs"));
-        });
-    }
+    assert_search_results(
+        &rust_only_search_results,
+        &[
+            (Path::new("src/file1.rs").into(), 0),
+            (Path::new("src/file2.rs").into(), 0),
+            (Path::new("src/file1.rs").into(), 45),
+        ],
+        cx,
+    );
 
-    let search_results = store
+    let no_rust_search_results = store
         .update(cx, |store, cx| {
             store.search_project(
                 project.clone(),
-                "aaaa".to_string(),
+                "aaaaaabbbbzz".to_string(),
                 5,
                 vec![],
                 exclude_files,
@@ -153,17 +161,12 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
         .await
         .unwrap();
 
-    for res in &search_results {
-        res.buffer.read_with(cx, |buffer, _cx| {
-            assert!(!buffer
-                .file()
-                .unwrap()
-                .path()
-                .to_str()
-                .unwrap()
-                .ends_with("rs"));
-        });
-    }
+    assert_search_results(
+        &no_rust_search_results,
+        &[(Path::new("src/file3.toml").into(), 0)],
+        cx,
+    );
+
     fs.save(
         "/the-root/src/file2.rs".as_ref(),
         &"
@@ -195,6 +198,26 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     );
 }
 
+#[track_caller]
+fn assert_search_results(
+    actual: &[SearchResult],
+    expected: &[(Arc<Path>, usize)],
+    cx: &TestAppContext,
+) {
+    let actual = actual
+        .iter()
+        .map(|search_result| {
+            search_result.buffer.read_with(cx, |buffer, _cx| {
+                (
+                    buffer.file().unwrap().path().clone(),
+                    search_result.range.start.to_offset(buffer),
+                )
+            })
+        })
+        .collect::<Vec<_>>();
+    assert_eq!(actual, expected);
+}
+
 #[gpui::test]
 async fn test_code_context_retrieval_rust() {
     let language = rust_lang();