add glob filtering functionality to semantic search

KCaverly created

Change summary

Cargo.lock                                        |  1 
crates/search/src/project_search.rs               | 60 +++++++++++++++-
crates/semantic_index/Cargo.toml                  |  1 
crates/semantic_index/src/db.rs                   | 39 +++++++---
crates/semantic_index/src/semantic_index.rs       | 13 ++
crates/semantic_index/src/semantic_index_tests.rs | 57 +++++++++++++++
6 files changed, 149 insertions(+), 22 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -6477,6 +6477,7 @@ dependencies = [
  "editor",
  "env_logger 0.9.3",
  "futures 0.3.28",
+ "globset",
  "gpui",
  "isahc",
  "language",

crates/search/src/project_search.rs 🔗

@@ -187,14 +187,26 @@ impl ProjectSearch {
         cx.notify();
     }
 
-    fn semantic_search(&mut self, query: String, cx: &mut ModelContext<Self>) {
+    fn semantic_search(
+        &mut self,
+        query: String,
+        include_files: Vec<GlobMatcher>,
+        exclude_files: Vec<GlobMatcher>,
+        cx: &mut ModelContext<Self>,
+    ) {
         let search = SemanticIndex::global(cx).map(|index| {
             index.update(cx, |semantic_index, cx| {
-                semantic_index.search_project(self.project.clone(), query.clone(), 10, cx)
+                semantic_index.search_project(
+                    self.project.clone(),
+                    query.clone(),
+                    10,
+                    include_files,
+                    exclude_files,
+                    cx,
+                )
             })
         });
         self.search_id += 1;
-        // self.active_query = Some(query);
         self.match_ranges.clear();
         self.pending_search = Some(cx.spawn(|this, mut cx| async move {
             let results = search?.await.log_err()?;
@@ -638,8 +650,13 @@ impl ProjectSearchView {
             }
 
             let query = self.query_editor.read(cx).text(cx);
-            self.model
-                .update(cx, |model, cx| model.semantic_search(query, cx));
+            if let Some((included_files, exclude_files)) =
+                self.get_included_and_excluded_globsets(cx)
+            {
+                self.model.update(cx, |model, cx| {
+                    model.semantic_search(query, included_files, exclude_files, cx)
+                });
+            }
             return;
         }
 
@@ -648,6 +665,39 @@ impl ProjectSearchView {
         }
     }
 
+    fn get_included_and_excluded_globsets(
+        &mut self,
+        cx: &mut ViewContext<Self>,
+    ) -> Option<(Vec<GlobMatcher>, Vec<GlobMatcher>)> {
+        let text = self.query_editor.read(cx).text(cx);
+        let included_files =
+            match Self::load_glob_set(&self.included_files_editor.read(cx).text(cx)) {
+                Ok(included_files) => {
+                    self.panels_with_errors.remove(&InputPanel::Include);
+                    included_files
+                }
+                Err(_e) => {
+                    self.panels_with_errors.insert(InputPanel::Include);
+                    cx.notify();
+                    return None;
+                }
+            };
+        let excluded_files =
+            match Self::load_glob_set(&self.excluded_files_editor.read(cx).text(cx)) {
+                Ok(excluded_files) => {
+                    self.panels_with_errors.remove(&InputPanel::Exclude);
+                    excluded_files
+                }
+                Err(_e) => {
+                    self.panels_with_errors.insert(InputPanel::Exclude);
+                    cx.notify();
+                    return None;
+                }
+            };
+
+        Some((included_files, excluded_files))
+    }
+
     fn build_search_query(&mut self, cx: &mut ViewContext<Self>) -> Option<SearchQuery> {
         let text = self.query_editor.read(cx).text(cx);
         let included_files =

crates/semantic_index/Cargo.toml 🔗

@@ -37,6 +37,7 @@ tiktoken-rs = "0.5.0"
 parking_lot.workspace = true
 rand.workspace = true
 schemars.workspace = true
+globset.workspace = true
 
 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }

crates/semantic_index/src/db.rs 🔗

@@ -1,5 +1,6 @@
 use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
 use anyhow::{anyhow, Context, Result};
+use globset::{Glob, GlobMatcher};
 use project::Fs;
 use rpc::proto::Timestamp;
 use rusqlite::{
@@ -252,18 +253,30 @@ impl VectorDatabase {
         worktree_ids: &[i64],
         query_embedding: &Vec<f32>,
         limit: usize,
+        include_globs: Vec<GlobMatcher>,
+        exclude_globs: Vec<GlobMatcher>,
     ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
         let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
-        self.for_each_document(&worktree_ids, |id, embedding| {
-            let similarity = dot(&embedding, &query_embedding);
-            let ix = match results
-                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
+        self.for_each_document(&worktree_ids, |relative_path, id, embedding| {
+            if (include_globs.is_empty()
+                || include_globs
+                    .iter()
+                    .any(|include_glob| include_glob.is_match(relative_path.clone())))
+                && (exclude_globs.is_empty()
+                    || !exclude_globs
+                        .iter()
+                        .any(|exclude_glob| exclude_glob.is_match(relative_path.clone())))
             {
-                Ok(ix) => ix,
-                Err(ix) => ix,
-            };
-            results.insert(ix, (id, similarity));
-            results.truncate(limit);
+                let similarity = dot(&embedding, &query_embedding);
+                let ix = match results.binary_search_by(|(_, s)| {
+                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+                }) {
+                    Ok(ix) => ix,
+                    Err(ix) => ix,
+                };
+                results.insert(ix, (id, similarity));
+                results.truncate(limit);
+            }
         })?;
 
         let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
@@ -273,12 +286,12 @@ impl VectorDatabase {
     fn for_each_document(
         &self,
         worktree_ids: &[i64],
-        mut f: impl FnMut(i64, Vec<f32>),
+        mut f: impl FnMut(String, i64, Vec<f32>),
     ) -> Result<()> {
         let mut query_statement = self.db.prepare(
             "
             SELECT
-                documents.id, documents.embedding
+                files.relative_path, documents.id, documents.embedding
             FROM
                 documents, files
             WHERE
@@ -289,10 +302,10 @@ impl VectorDatabase {
 
         query_statement
             .query_map(params![ids_to_sql(worktree_ids)], |row| {
-                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
+                Ok((row.get(0)?, row.get(1)?, row.get::<_, Embedding>(2)?))
             })?
             .filter_map(|row| row.ok())
-            .for_each(|(id, embedding)| f(id, embedding.0));
+            .for_each(|(relative_path, id, embedding)| f(relative_path, id, embedding.0));
         Ok(())
     }
 

crates/semantic_index/src/semantic_index.rs 🔗

@@ -11,6 +11,7 @@ use anyhow::{anyhow, Result};
 use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use futures::{channel::oneshot, Future};
+use globset::{Glob, GlobMatcher};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Anchor, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;
@@ -624,6 +625,8 @@ impl SemanticIndex {
         project: ModelHandle<Project>,
         phrase: String,
         limit: usize,
+        include_globs: Vec<GlobMatcher>,
+        exclude_globs: Vec<GlobMatcher>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<SearchResult>>> {
         let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
@@ -657,12 +660,16 @@ impl SemanticIndex {
                         .next()
                         .unwrap();
 
-                    database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
+                    database.top_k_search(
+                        &worktree_db_ids,
+                        &phrase_embedding,
+                        limit,
+                        include_globs,
+                        exclude_globs,
+                    )
                 })
                 .await?;
 
-            dbg!(&documents);
-
             let mut tasks = Vec::new();
             let mut ranges = Vec::new();
             let weak_project = project.downgrade();

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -7,6 +7,7 @@ use crate::{
 };
 use anyhow::Result;
 use async_trait::async_trait;
+use globset::Glob;
 use gpui::{Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
 use pretty_assertions::assert_eq;
@@ -96,7 +97,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
 
     let search_results = store
         .update(cx, |store, cx| {
-            store.search_project(project.clone(), "aaaa".to_string(), 5, cx)
+            store.search_project(project.clone(), "aaaa".to_string(), 5, vec![], vec![], cx)
         })
         .await
         .unwrap();
@@ -109,6 +110,60 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
         );
     });
 
+    // Test Include Files Functionality
+    let include_files = vec![Glob::new("*.rs").unwrap().compile_matcher()];
+    let exclude_files = vec![Glob::new("*.rs").unwrap().compile_matcher()];
+    let search_results = store
+        .update(cx, |store, cx| {
+            store.search_project(
+                project.clone(),
+                "aaaa".to_string(),
+                5,
+                include_files,
+                vec![],
+                cx,
+            )
+        })
+        .await
+        .unwrap();
+
+    for res in &search_results {
+        res.buffer.read_with(cx, |buffer, _cx| {
+            assert!(buffer
+                .file()
+                .unwrap()
+                .path()
+                .to_str()
+                .unwrap()
+                .ends_with("rs"));
+        });
+    }
+
+    let search_results = store
+        .update(cx, |store, cx| {
+            store.search_project(
+                project.clone(),
+                "aaaa".to_string(),
+                5,
+                vec![],
+                exclude_files,
+                cx,
+            )
+        })
+        .await
+        .unwrap();
+
+    for res in &search_results {
+        res.buffer.read_with(cx, |buffer, _cx| {
+            assert!(!buffer
+                .file()
+                .unwrap()
+                .path()
+                .to_str()
+                .unwrap()
+                .ends_with("rs"));
+        });
+    }
     fs.save(
         "/the-root/src/file2.rs".as_ref(),
         &"