@@ -1,6 +1,6 @@
use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
use anyhow::{anyhow, Context, Result};
-use globset::{Glob, GlobMatcher};
+use globset::GlobMatcher;
use project::Fs;
use rpc::proto::Timestamp;
use rusqlite::{
@@ -257,16 +257,11 @@ impl VectorDatabase {
exclude_globs: Vec<GlobMatcher>,
) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
- self.for_each_document(&worktree_ids, |relative_path, id, embedding| {
- if (include_globs.is_empty()
- || include_globs
- .iter()
- .any(|include_glob| include_glob.is_match(relative_path.clone())))
- && (exclude_globs.is_empty()
- || !exclude_globs
- .iter()
- .any(|exclude_glob| exclude_glob.is_match(relative_path.clone())))
- {
+ self.for_each_document(
+ &worktree_ids,
+ include_globs,
+ exclude_globs,
+ |id, embedding| {
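+            // Score each document and keep only the `limit` best matches;
+            // the reversed comparison below keeps `results` sorted by
+            // descending similarity.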
let similarity = dot(&embedding, &query_embedding);
let ix = match results.binary_search_by(|(_, s)| {
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
@@ -276,8 +271,8 @@ impl VectorDatabase {
};
results.insert(ix, (id, similarity));
results.truncate(limit);
- }
- })?;
+ },
+ )?;
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
self.get_documents_by_ids(&ids)
@@ -286,26 +281,55 @@ impl VectorDatabase {
fn for_each_document(
&self,
worktree_ids: &[i64],
- mut f: impl FnMut(String, i64, Vec<f32>),
+ include_globs: Vec<GlobMatcher>,
+ exclude_globs: Vec<GlobMatcher>,
+ mut f: impl FnMut(i64, Vec<f32>),
) -> Result<()> {
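+    // Two passes: first collect the IDs of the files that pass the glob
+    // filters, then stream embeddings for just those files. `rarray` comes
+    // from rusqlite's array vtab, loaded elsewhere on this connection.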
+ let mut file_query = self.db.prepare(
+ "
+ SELECT
+ id, relative_path
+ FROM
+ files
+ WHERE
+ worktree_id IN rarray(?)
+ ",
+ )?;
+
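+    // A file is kept if it matches at least one include glob (or the include
+    // list is empty) and matches no exclude glob.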
+ let mut file_ids = Vec::<i64>::new();
+    let mut rows = file_query.query(params![ids_to_sql(worktree_ids)])?;
+ while let Some(row) = rows.next()? {
+ let file_id = row.get(0)?;
+ let relative_path = row.get_ref(1)?.as_str()?;
+ let included = include_globs.is_empty()
+ || include_globs
+ .iter()
+ .any(|glob| glob.is_match(relative_path));
+ let excluded = exclude_globs
+ .iter()
+ .any(|glob| glob.is_match(relative_path));
+ if included && !excluded {
+ file_ids.push(file_id);
+ }
+ }
+
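+    // Second pass: load embeddings only for documents in the files kept above.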
let mut query_statement = self.db.prepare(
"
SELECT
- files.relative_path, documents.id, documents.embedding
+ id, embedding
FROM
- documents, files
+ documents
WHERE
- documents.file_id = files.id AND
- files.worktree_id IN rarray(?)
+ file_id IN rarray(?)
",
)?;
query_statement
- .query_map(params![ids_to_sql(worktree_ids)], |row| {
- Ok((row.get(0)?, row.get(1)?, row.get::<_, Embedding>(2)?))
+ .query_map(params![ids_to_sql(&file_ids)], |row| {
+ Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
})?
.filter_map(|row| row.ok())
- .for_each(|(relative_path, id, embedding)| f(relative_path, id, embedding.0));
+ .for_each(|(id, embedding)| f(id, embedding.0));
Ok(())
}
@@ -11,7 +11,7 @@ use anyhow::{anyhow, Result};
use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use futures::{channel::oneshot, Future};
-use globset::{Glob, GlobMatcher};
+use globset::GlobMatcher;
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Anchor, Buffer, Language, LanguageRegistry};
use parking_lot::Mutex;
@@ -3,7 +3,7 @@ use crate::{
embedding::EmbeddingProvider,
parsing::{subtract_ranges, CodeContextRetriever, Document},
semantic_index_settings::SemanticIndexSettings,
- SemanticIndex,
+ SearchResult, SemanticIndex,
};
use anyhow::Result;
use async_trait::async_trait;
@@ -46,21 +46,21 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
"src": {
"file1.rs": "
fn aaa() {
- println!(\"aaaa!\");
+ println!(\"aaaaaaaaaaaa!\");
}
- fn zzzzzzzzz() {
+ fn zzzzz() {
println!(\"SLEEPING\");
}
".unindent(),
"file2.rs": "
fn bbb() {
- println!(\"bbbb!\");
+ println!(\"bbbbbbbbbbbbb!\");
}
".unindent(),
"file3.toml": "
- ZZZZZZZ = 5
- ".unindent(),
+ ZZZZZZZZZZZZZZZZZZ = 5
+ ".unindent(),
}
}),
)
@@ -97,27 +97,37 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
let search_results = store
.update(cx, |store, cx| {
- store.search_project(project.clone(), "aaaa".to_string(), 5, vec![], vec![], cx)
+ store.search_project(
+ project.clone(),
+ "aaaaaabbbbzz".to_string(),
+ 5,
+ vec![],
+ vec![],
+ cx,
+ )
})
.await
.unwrap();
- search_results[0].buffer.read_with(cx, |buffer, _cx| {
- assert_eq!(search_results[0].range.start.to_offset(buffer), 0);
- assert_eq!(
- buffer.file().unwrap().path().as_ref(),
- Path::new("src/file1.rs")
- );
- });
+ assert_search_results(
+ &search_results,
+ &[
+ (Path::new("src/file1.rs").into(), 0),
+ (Path::new("src/file2.rs").into(), 0),
+ (Path::new("src/file3.toml").into(), 0),
+ (Path::new("src/file1.rs").into(), 45),
+ ],
+ cx,
+ );
    // Test include and exclude file glob functionality
let include_files = vec![Glob::new("*.rs").unwrap().compile_matcher()];
let exclude_files = vec![Glob::new("*.rs").unwrap().compile_matcher()];
- let search_results = store
+ let rust_only_search_results = store
.update(cx, |store, cx| {
store.search_project(
project.clone(),
- "aaaa".to_string(),
+ "aaaaaabbbbzz".to_string(),
5,
include_files,
vec![],
@@ -127,23 +137,21 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
.await
.unwrap();
- for res in &search_results {
- res.buffer.read_with(cx, |buffer, _cx| {
- assert!(buffer
- .file()
- .unwrap()
- .path()
- .to_str()
- .unwrap()
- .ends_with("rs"));
- });
- }
+ assert_search_results(
+ &rust_only_search_results,
+ &[
+ (Path::new("src/file1.rs").into(), 0),
+ (Path::new("src/file2.rs").into(), 0),
+ (Path::new("src/file1.rs").into(), 45),
+ ],
+ cx,
+ );
- let search_results = store
+ let no_rust_search_results = store
.update(cx, |store, cx| {
store.search_project(
project.clone(),
- "aaaa".to_string(),
+ "aaaaaabbbbzz".to_string(),
5,
vec![],
exclude_files,
@@ -153,17 +161,12 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
.await
.unwrap();
- for res in &search_results {
- res.buffer.read_with(cx, |buffer, _cx| {
- assert!(!buffer
- .file()
- .unwrap()
- .path()
- .to_str()
- .unwrap()
- .ends_with("rs"));
- });
- }
+ assert_search_results(
+ &no_rust_search_results,
+ &[(Path::new("src/file3.toml").into(), 0)],
+ cx,
+ );
+
fs.save(
"/the-root/src/file2.rs".as_ref(),
&"
@@ -195,6 +198,26 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
);
}
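+
+/// Asserts that the search results resolve, in order, to the expected
+/// `(path, offset)` pairs, where `offset` is the byte offset of the
+/// result's start in its buffer.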
+#[track_caller]
+fn assert_search_results(
+ actual: &[SearchResult],
+ expected: &[(Arc<Path>, usize)],
+ cx: &TestAppContext,
+) {
+ let actual = actual
+ .iter()
+ .map(|search_result| {
+ search_result.buffer.read_with(cx, |buffer, _cx| {
+ (
+ buffer.file().unwrap().path().clone(),
+ search_result.range.start.to_offset(buffer),
+ )
+ })
+ })
+ .collect::<Vec<_>>();
+ assert_eq!(actual, expected);
+}
+
#[gpui::test]
async fn test_code_context_retrieval_rust() {
let language = rust_lang();