From e02d6bc0d41fe5006307833f5e4c2cd62ba7add1 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Thu, 20 Jul 2023 13:46:27 -0400
Subject: [PATCH] add glob filtering functionality to semantic search

---
 Cargo.lock                                  |  1 +
 crates/search/src/project_search.rs         | 60 +++++++++++++++++--
 crates/semantic_index/Cargo.toml            |  1 +
 crates/semantic_index/src/db.rs             | 39 ++++++++----
 crates/semantic_index/src/semantic_index.rs | 13 +++-
 .../src/semantic_index_tests.rs             | 57 +++++++++++++++++-
 6 files changed, 149 insertions(+), 22 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 75f66163e3fbf5048b01cbf5079f00f2e9c5ce46..f534a4fe7d68a362fd910f0bd02cbf72b24955fa 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6477,6 +6477,7 @@ dependencies = [
  "editor",
  "env_logger 0.9.3",
  "futures 0.3.28",
+ "globset",
  "gpui",
  "isahc",
  "language",
diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs
index 5feb94426eb60c67a756c564982a826699bd20a1..25fc897707af6be8b97b277a2d65b8d4cf1eeb17 100644
--- a/crates/search/src/project_search.rs
+++ b/crates/search/src/project_search.rs
@@ -187,14 +187,26 @@ impl ProjectSearch {
         cx.notify();
     }
 
-    fn semantic_search(&mut self, query: String, cx: &mut ModelContext<Self>) {
+    fn semantic_search(
+        &mut self,
+        query: String,
+        include_files: Vec<GlobMatcher>,
+        exclude_files: Vec<GlobMatcher>,
+        cx: &mut ModelContext<Self>,
+    ) {
         let search = SemanticIndex::global(cx).map(|index| {
             index.update(cx, |semantic_index, cx| {
-                semantic_index.search_project(self.project.clone(), query.clone(), 10, cx)
+                semantic_index.search_project(
+                    self.project.clone(),
+                    query.clone(),
+                    10,
+                    include_files,
+                    exclude_files,
+                    cx,
+                )
             })
         });
         self.search_id += 1;
-        // self.active_query = Some(query);
         self.match_ranges.clear();
         self.pending_search = Some(cx.spawn(|this, mut cx| async move {
             let results = search?.await.log_err()?;
@@ -638,8 +650,13 @@ impl ProjectSearchView {
         }
 
         let query = self.query_editor.read(cx).text(cx);
-        self.model
-            .update(cx, |model, cx| model.semantic_search(query, cx));
+        if let Some((included_files, exclude_files)) =
+            self.get_included_and_excluded_globsets(cx)
+        {
+            self.model.update(cx, |model, cx| {
+                model.semantic_search(query, included_files, exclude_files, cx)
+            });
+        }
         return;
     }
 
@@ -648,6 +665,39 @@ impl ProjectSearchView {
         }
     }
 
+    fn get_included_and_excluded_globsets(
+        &mut self,
+        cx: &mut ViewContext<Self>,
+    ) -> Option<(Vec<GlobMatcher>, Vec<GlobMatcher>)> {
+        let text = self.query_editor.read(cx).text(cx);
+        let included_files =
+            match Self::load_glob_set(&self.included_files_editor.read(cx).text(cx)) {
+                Ok(included_files) => {
+                    self.panels_with_errors.remove(&InputPanel::Include);
+                    included_files
+                }
+                Err(_e) => {
+                    self.panels_with_errors.insert(InputPanel::Include);
+                    cx.notify();
+                    return None;
+                }
+            };
+        let excluded_files =
+            match Self::load_glob_set(&self.excluded_files_editor.read(cx).text(cx)) {
+                Ok(excluded_files) => {
+                    self.panels_with_errors.remove(&InputPanel::Exclude);
+                    excluded_files
+                }
+                Err(_e) => {
+                    self.panels_with_errors.insert(InputPanel::Exclude);
+                    cx.notify();
+                    return None;
+                }
+            };
+
+        Some((included_files, excluded_files))
+    }
+
     fn build_search_query(&mut self, cx: &mut ViewContext<Self>) -> Option<SearchQuery> {
         let text = self.query_editor.read(cx).text(cx);
         let included_files =
diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml
index 35b97245124e8922d5e7a46a369e26c71af7731a..a1f126bfb841ecb8334aeca391ac4959ef9f57b0 100644
--- a/crates/semantic_index/Cargo.toml
+++ b/crates/semantic_index/Cargo.toml
@@ -37,6 +37,7 @@ tiktoken-rs = "0.5.0"
 parking_lot.workspace = true
 rand.workspace = true
 schemars.workspace = true
+globset.workspace = true
 
 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }
diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs
index fd99594aab578919f80bd8236270b352a8540993..3ba85a275d0a0d6b197bbad22d5ad5bd792a2fbf 100644
--- a/crates/semantic_index/src/db.rs
+++ b/crates/semantic_index/src/db.rs
@@ -1,5 +1,6 @@
 use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
 use anyhow::{anyhow, Context, Result};
+use globset::{Glob, GlobMatcher};
 use project::Fs;
 use rpc::proto::Timestamp;
 use rusqlite::{
@@ -252,18 +253,30 @@ impl VectorDatabase {
         worktree_ids: &[i64],
         query_embedding: &Vec<f32>,
         limit: usize,
+        include_globs: Vec<GlobMatcher>,
+        exclude_globs: Vec<GlobMatcher>,
     ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
         let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
-        self.for_each_document(&worktree_ids, |id, embedding| {
-            let similarity = dot(&embedding, &query_embedding);
-            let ix = match results
-                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
+        self.for_each_document(&worktree_ids, |relative_path, id, embedding| {
+            if (include_globs.is_empty()
+                || include_globs
+                    .iter()
+                    .any(|include_glob| include_glob.is_match(relative_path.clone())))
+                && (exclude_globs.is_empty()
+                    || !exclude_globs
+                        .iter()
+                        .any(|exclude_glob| exclude_glob.is_match(relative_path.clone())))
             {
-                Ok(ix) => ix,
-                Err(ix) => ix,
-            };
-            results.insert(ix, (id, similarity));
-            results.truncate(limit);
+                let similarity = dot(&embedding, &query_embedding);
+                let ix = match results.binary_search_by(|(_, s)| {
+                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+                }) {
+                    Ok(ix) => ix,
+                    Err(ix) => ix,
+                };
+                results.insert(ix, (id, similarity));
+                results.truncate(limit);
+            }
         })?;
 
         let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
@@ -273,12 +286,12 @@ impl VectorDatabase {
     fn for_each_document(
         &self,
         worktree_ids: &[i64],
-        mut f: impl FnMut(i64, Vec<f32>),
+        mut f: impl FnMut(String, i64, Vec<f32>),
     ) -> Result<()> {
         let mut query_statement = self.db.prepare(
             "
             SELECT
-                documents.id, documents.embedding
+                files.relative_path, documents.id, documents.embedding
             FROM
                 documents, files
             WHERE
@@ -289,10 +302,10 @@ impl VectorDatabase {
 
         query_statement
             .query_map(params![ids_to_sql(worktree_ids)], |row| {
-                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
+                Ok((row.get(0)?, row.get(1)?, row.get::<_, Embedding>(2)?))
             })?
             .filter_map(|row| row.ok())
-            .for_each(|(id, embedding)| f(id, embedding.0));
+            .for_each(|(relative_path, id, embedding)| f(relative_path, id, embedding.0));
 
         Ok(())
     }
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 6e0477491518a0c4a18ebfa1c24ddaf51eaf1948..32a11a42ebdcb01205869bcb273784582e291dcf 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -11,6 +11,7 @@ use anyhow::{anyhow, Result};
 use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use futures::{channel::oneshot, Future};
+use globset::{Glob, GlobMatcher};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Anchor, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;
@@ -624,6 +625,8 @@ impl SemanticIndex {
         project: ModelHandle<Project>,
         phrase: String,
         limit: usize,
+        include_globs: Vec<GlobMatcher>,
+        exclude_globs: Vec<GlobMatcher>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<SearchResult>>> {
         let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
@@ -657,12 +660,16 @@ impl SemanticIndex {
                 .next()
                 .unwrap();
 
-            database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
+            database.top_k_search(
+                &worktree_db_ids,
+                &phrase_embedding,
+                limit,
+                include_globs,
+                exclude_globs,
+            )
         })
         .await?;
 
-        dbg!(&documents);
-
         let mut tasks = Vec::new();
         let mut ranges = Vec::new();
         let weak_project = project.downgrade();
diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs
index 31c96ca207bb3da1ace202bd81df461f27ba229b..366d634ddb68df629832b23c3777d6f5cc775b7c 100644
--- a/crates/semantic_index/src/semantic_index_tests.rs
+++ b/crates/semantic_index/src/semantic_index_tests.rs
@@ -7,6 +7,7 @@ use crate::{
 };
 use anyhow::Result;
 use async_trait::async_trait;
+use globset::Glob;
 use gpui::{Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
 use pretty_assertions::assert_eq;
@@ -96,7 +97,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
 
     let search_results = store
         .update(cx, |store, cx| {
-            store.search_project(project.clone(), "aaaa".to_string(), 5, cx)
+            store.search_project(project.clone(), "aaaa".to_string(), 5, vec![], vec![], cx)
         })
         .await
         .unwrap();
@@ -109,6 +110,60 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
         );
     });
 
+    // Test Include Files Functionality
+    let include_files = vec![Glob::new("*.rs").unwrap().compile_matcher()];
+    let exclude_files = vec![Glob::new("*.rs").unwrap().compile_matcher()];
+    let search_results = store
+        .update(cx, |store, cx| {
+            store.search_project(
+                project.clone(),
+                "aaaa".to_string(),
+                5,
+                include_files,
+                vec![],
+                cx,
+            )
+        })
+        .await
+        .unwrap();
+
+    for res in &search_results {
+        res.buffer.read_with(cx, |buffer, _cx| {
+            assert!(buffer
+                .file()
+                .unwrap()
+                .path()
+                .to_str()
+                .unwrap()
+                .ends_with("rs"));
+        });
+    }
+
+    let search_results = store
+        .update(cx, |store, cx| {
+            store.search_project(
+                project.clone(),
+                "aaaa".to_string(),
+                5,
+                vec![],
+                exclude_files,
+                cx,
+            )
+        })
+        .await
+        .unwrap();
+
+    for res in &search_results {
+        res.buffer.read_with(cx, |buffer, _cx| {
+            assert!(!buffer
+                .file()
+                .unwrap()
+                .path()
+                .to_str()
+                .unwrap()
+                .ends_with("rs"));
+        });
+    }
     fs.save(
         "/the-root/src/file2.rs".as_ref(),
         &"
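
Reviewer note (not part of the patch): the filtering in `top_k_search` leans on
globset's compiled matchers. Each pattern is parsed once with `Glob::new`,
compiled to a `GlobMatcher` via `compile_matcher`, and `is_match` is then run
against each candidate's relative path. An empty include list admits every
path, and any matching exclude glob rejects one. Below is a minimal standalone
sketch of that predicate; the `passes_filters` helper and the sample paths are
illustrative only, not names from this patch:

    use globset::{Glob, GlobMatcher};

    // Mirrors the patch's predicate: an empty include list admits everything;
    // a path survives only if no exclude glob matches it.
    fn passes_filters(path: &str, include: &[GlobMatcher], exclude: &[GlobMatcher]) -> bool {
        (include.is_empty() || include.iter().any(|glob| glob.is_match(path)))
            && (exclude.is_empty() || !exclude.iter().any(|glob| glob.is_match(path)))
    }

    fn main() {
        // Compile once, match many times. By default globset's `*` also
        // matches across `/`, which is why `*.rs` matches nested paths like
        // `src/file1.rs` in the test above.
        let include = vec![Glob::new("*.rs").unwrap().compile_matcher()];
        let exclude = vec![Glob::new("*_tests.rs").unwrap().compile_matcher()];

        assert!(passes_filters("src/file1.rs", &include, &exclude));
        assert!(!passes_filters("src/semantic_index_tests.rs", &include, &exclude));
        assert!(!passes_filters("README.md", &include, &exclude));
    }

One consequence of this design is that compilation cost is paid once per
search while matching runs inside the per-document loop, which keeps the
filter cheap even across large indexes.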