@@ -1,6 +1,7 @@
use std::{
collections::HashMap,
path::{Path, PathBuf},
+ rc::Rc,
};
use anyhow::{anyhow, Result};
@@ -258,22 +259,34 @@ impl VectorDatabase {
pub fn for_each_document(
&self,
- worktree_id: i64,
+ worktree_ids: &[i64], // only documents whose file belongs to one of these worktree ids are visited
mut f: impl FnMut(i64, Embedding),
) -> Result<()> {
- let mut query_statement = self.db.prepare("SELECT id, embedding FROM documents")?;
+ let mut query_statement = self.db.prepare(
+ "
+ SELECT
+ documents.id, documents.embedding
+ FROM
+ documents, files
+ WHERE
+ documents.file_id = files.id AND
+ files.worktree_id IN rarray(?)
+ ",
+ )?;
query_statement
- .query_map(params![], |row| Ok((row.get(0)?, row.get(1)?)))?
+ .query_map(params![ids_to_sql(worktree_ids)], |row| { // the id list is bound to `rarray(?)`
+ Ok((row.get(0)?, row.get(1)?))
+ })? // NOTE(review): rows that fail to decode are dropped by `row.ok()` below, not reported
.filter_map(|row| row.ok())
.for_each(|row| f(row.0, row.1));
Ok(())
}
- pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(PathBuf, usize, String)>> {
+ pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
let mut statement = self.db.prepare(
"
SELECT
- documents.id, files.relative_path, documents.offset, documents.name
+ documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name
FROM
documents, files
WHERE
@@ -282,35 +295,28 @@ impl VectorDatabase {
",
)?;
- let result_iter = statement.query_map(
- params![std::rc::Rc::new(
- ids.iter()
- .copied()
- .map(|v| rusqlite::types::Value::from(v))
- .collect::<Vec<_>>()
- )],
- |row| {
- Ok((
- row.get::<_, i64>(0)?,
- row.get::<_, String>(1)?.into(),
- row.get(2)?,
- row.get(3)?,
- ))
- },
- )?;
+ let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
+ Ok((
+ row.get::<_, i64>(0)?,
+ row.get::<_, i64>(1)?,
+ row.get::<_, String>(2)?.into(),
+ row.get(3)?,
+ row.get(4)?,
+ ))
+ })?;
- let mut values_by_id = HashMap::<i64, (PathBuf, usize, String)>::default();
+ let mut values_by_id = HashMap::<i64, (i64, PathBuf, usize, String)>::default();
for row in result_iter {
- let (id, path, offset, name) = row?;
- values_by_id.insert(id, (path, offset, name));
+ let (id, worktree_id, path, offset, name) = row?;
+ values_by_id.insert(id, (worktree_id, path, offset, name));
}
let mut results = Vec::with_capacity(ids.len());
for id in ids {
- let (path, offset, name) = values_by_id
+ let value = values_by_id
.remove(id)
.ok_or(anyhow!("missing document id {}", id))?;
- results.push((path, offset, name));
+ results.push(value);
}
Ok(results)
@@ -339,3 +345,12 @@ impl VectorDatabase {
return Ok(documents);
}
}
+
+/// Packs `ids` into the `Rc<Vec<Value>>` that callers bind to an `rarray(?)` parameter.
+///
+/// The order of `ids` is preserved; an empty slice yields an empty vector.
+fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
+    // `Value::from` accepts an `i64` directly, so no wrapping closure is needed.
+    let values = ids.iter().copied().map(rusqlite::types::Value::from);
+    Rc::new(values.collect())
+}
@@ -48,7 +48,9 @@ impl PickerDelegate for SemanticSearchDelegate {
}
fn confirm(&mut self, cx: &mut ViewContext<SemanticSearch>) {
- todo!()
+ if let Some(search_result) = self.matches.get(self.selected_match_index) {
+ // TODO: act on the selected match (e.g. open `search_result.file_path`); this branch is currently a no-op.
+ }
}
fn dismissed(&mut self, _cx: &mut ViewContext<SemanticSearch>) {}
@@ -66,9 +68,9 @@ impl PickerDelegate for SemanticSearchDelegate {
}
fn update_matches(&mut self, query: String, cx: &mut ViewContext<SemanticSearch>) -> Task<()> {
- let task = self
- .vector_store
- .update(cx, |store, cx| store.search(query.to_string(), 10, cx));
+ let task = self.vector_store.update(cx, |store, cx| {
+ store.search(&self.project, query.to_string(), 10, cx)
+ });
cx.spawn(|this, mut cx| async move {
let results = task.await.log_err();
@@ -90,7 +92,7 @@ impl PickerDelegate for SemanticSearchDelegate {
) -> AnyElement<Picker<Self>> {
let theme = theme::current(cx);
let style = &theme.picker.item;
- let current_style = style.style_for(mouse_state, selected);
+ let current_style = style.in_state(selected).style_for(mouse_state);
let search_result = &self.matches[ix];
@@ -99,7 +101,10 @@ impl PickerDelegate for SemanticSearchDelegate {
Flex::column()
.with_child(Text::new(name, current_style.label.text.clone()).with_soft_wrap(false))
- .with_child(Label::new(path.to_string(), style.default.label.clone()))
+ .with_child(Label::new(
+ path.to_string(),
+ style.inactive_state().default.label.clone(),
+ ))
.contained()
.with_style(current_style.container)
.into_any()
@@ -8,11 +8,11 @@ mod vector_store_tests;
use anyhow::{anyhow, Result};
use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
-use embedding::{EmbeddingProvider, OpenAIEmbeddings};
+use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
-use project::{Fs, Project};
+use project::{Fs, Project, WorktreeId};
use smol::channel;
use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
use tree_sitter::{Parser, QueryCursor};
@@ -36,9 +36,10 @@ pub fn init(
VectorStore::new(
fs,
VECTOR_DB_URL.to_string(),
- Arc::new(OpenAIEmbeddings {
- client: http_client,
- }),
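+ // NOTE(review): the real provider, `Arc::new(OpenAIEmbeddings { client: http_client })`,
+ // is commented out here in favour of `DummyEmbeddings` below. Presumably a temporary
+ // stand-in for local testing; confirm with the author and restore before merging.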
+ Arc::new(DummyEmbeddings {}),
language_registry,
)
});
@@ -75,25 +76,6 @@ pub fn init(
}
});
SemanticSearch::init(cx);
- // cx.add_action({
- // let vector_store = vector_store.clone();
- // move |workspace: &mut Workspace, _: &TestSearch, cx: &mut ViewContext<Workspace>| {
- // let t0 = std::time::Instant::now();
- // let task = vector_store.update(cx, |store, cx| {
- // store.search("compute embeddings for all of the symbols in the codebase and write them to a database".to_string(), 10, cx)
- // });
-
- // cx.spawn(|this, cx| async move {
- // let results = task.await?;
- // let duration = t0.elapsed();
-
- // println!("search took {:?}", duration);
- // println!("results {:?}", results);
-
- // anyhow::Ok(())
- // }).detach()
- // }
- // });
}
#[derive(Debug)]
@@ -108,10 +90,12 @@ pub struct VectorStore {
database_url: Arc<str>,
embedding_provider: Arc<dyn EmbeddingProvider>,
language_registry: Arc<LanguageRegistry>,
+ worktree_db_ids: Vec<(WorktreeId, i64)>,
}
#[derive(Debug)]
pub struct SearchResult {
+ pub worktree_id: WorktreeId,
pub name: String,
pub offset: usize,
pub file_path: PathBuf,
@@ -129,6 +113,7 @@ impl VectorStore {
database_url: database_url.into(),
embedding_provider,
language_registry,
+ worktree_db_ids: Vec::new(),
}
}
@@ -178,9 +163,11 @@ impl VectorStore {
}
}
- let embeddings = embedding_provider.embed_batch(context_spans).await?;
- for (document, embedding) in documents.iter_mut().zip(embeddings) {
- document.embedding = embedding;
+ if !documents.is_empty() {
+ let embeddings = embedding_provider.embed_batch(context_spans).await?;
+ for (document, embedding) in documents.iter_mut().zip(embeddings) {
+ document.embedding = embedding;
+ }
}
let sha1 = FileSha1::from_str(content);
@@ -214,7 +201,7 @@ impl VectorStore {
let embedding_provider = self.embedding_provider.clone();
let database_url = self.database_url.clone();
- cx.spawn(|_, cx| async move {
+ cx.spawn(|this, mut cx| async move {
futures::future::join_all(worktree_scans_complete).await;
// TODO: remove this after fixing the bug in scan_complete
@@ -231,25 +218,24 @@ impl VectorStore {
.collect::<Vec<_>>()
});
- let worktree_root_paths = worktrees
- .iter()
- .map(|worktree| worktree.abs_path().clone())
- .collect::<Vec<_>>();
-
// Here we query the worktree ids, and yet we dont have them elsewhere
// We likely want to clean up these datastructures
- let (db, worktree_hashes, worktree_ids) = cx
+ let (db, worktree_hashes, worktree_db_ids) = cx
.background()
- .spawn(async move {
- let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
- let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
- for worktree_root_path in worktree_root_paths {
- let worktree_id =
- db.find_or_create_worktree(worktree_root_path.as_ref())?;
- worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
- hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
+ .spawn({
+ let worktrees = worktrees.clone();
+ async move {
+ let mut worktree_db_ids: HashMap<WorktreeId, i64> = HashMap::new();
+ let mut hashes: HashMap<WorktreeId, HashMap<PathBuf, FileSha1>> =
+ HashMap::new();
+ for worktree in worktrees {
+ let worktree_db_id =
+ db.find_or_create_worktree(worktree.abs_path().as_ref())?;
+ worktree_db_ids.insert(worktree.id(), worktree_db_id);
+ hashes.insert(worktree.id(), db.get_file_hashes(worktree_db_id)?);
+ }
+ anyhow::Ok((db, hashes, worktree_db_ids))
}
- anyhow::Ok((db, hashes, worktree_ids))
})
.await?;
@@ -259,10 +245,10 @@ impl VectorStore {
cx.background()
.spawn({
let fs = fs.clone();
+ let worktree_db_ids = worktree_db_ids.clone();
async move {
for worktree in worktrees.into_iter() {
- let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
- let file_hashes = &worktree_hashes[&worktree_id];
+ let file_hashes = &worktree_hashes[&worktree.id()];
for file in worktree.files(false, 0) {
let absolute_path = worktree.absolutize(&file.path);
@@ -291,7 +277,7 @@ impl VectorStore {
);
paths_tx
.try_send((
- worktree_id,
+ worktree_db_ids[&worktree.id()],
path_buf,
content,
language,
@@ -382,54 +368,92 @@ impl VectorStore {
drop(indexed_files_tx);
db_write_task.await;
+
+ this.update(&mut cx, |this, _| {
+ this.worktree_db_ids.extend(worktree_db_ids);
+ });
+
anyhow::Ok(())
})
}
+ /// Returns up to `limit` indexed documents, best first, ranked by their `dot` score
+ /// against the embedding of `phrase`; only worktrees of `project` that appear in
+ /// `self.worktree_db_ids` are searched. Fails if the database cannot be opened or
+ /// queried, or if the embedding provider errors or yields no embedding for `phrase`.
pub fn search(
&mut self,
+ project: &ModelHandle<Project>,
phrase: String,
limit: usize,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<SearchResult>>> {
+ let project = project.read(cx);
+ let worktree_db_ids = project
+ .worktrees(cx)
+ .filter_map(|worktree| {
+ let worktree_id = worktree.read(cx).id();
+ self.worktree_db_ids
+ .iter()
+ .find(|(id, _)| *id == worktree_id)
+ .map(|(_, db_id)| *db_id)
+ })
+ .collect::<Vec<_>>();
+
let embedding_provider = self.embedding_provider.clone();
let database_url = self.database_url.clone();
- cx.background().spawn(async move {
- let database = VectorDatabase::new(database_url.as_ref())?;
-
- let phrase_embedding = embedding_provider
- .embed_batch(vec![&phrase])
- .await?
- .into_iter()
- .next()
- .unwrap();
-
- let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
- database.for_each_document(0, |id, embedding| {
- let similarity = dot(&embedding.0, &phrase_embedding);
- let ix = match results.binary_search_by(|(_, s)| {
- similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
- }) {
- Ok(ix) => ix,
- Err(ix) => ix,
- };
- results.insert(ix, (id, similarity));
- results.truncate(limit);
- })?;
-
- let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
- let documents = database.get_documents_by_ids(&ids)?;
-
- anyhow::Ok(
+ cx.spawn(|this, cx| async move {
+ let documents = cx
+ .background()
+ .spawn(async move {
+ let database = VectorDatabase::new(database_url.as_ref())?;
+
+ let phrase_embedding = embedding_provider
+ .embed_batch(vec![&phrase])
+ .await?
+ .into_iter()
+ .next()
+ .ok_or_else(|| anyhow!("no embedding returned for search phrase"))?;
+
+ let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+ database.for_each_document(&worktree_db_ids, |id, embedding| {
+ let similarity = dot(&embedding.0, &phrase_embedding);
+ let ix = match results.binary_search_by(|(_, s)| {
+ similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+ }) {
+ Ok(ix) => ix,
+ Err(ix) => ix,
+ };
+ results.insert(ix, (id, similarity));
+ results.truncate(limit);
+ })?;
+
+ let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
+ database.get_documents_by_ids(&ids)
+ })
+ .await?;
+
+ // Map database worktree ids back to `WorktreeId`s; documents with no mapping are dropped.
+ let results = this.read_with(&cx, |this, _| {
documents
.into_iter()
- .map(|(file_path, offset, name)| SearchResult {
- name,
- offset,
- file_path,
+ .filter_map(|(worktree_db_id, file_path, offset, name)| {
+ let worktree_id = this
+ .worktree_db_ids
+ .iter()
+ .find(|(_, db_id)| *db_id == worktree_db_id)
+ .map(|(id, _)| *id)?;
+ Some(SearchResult {
+ worktree_id,
+ name,
+ offset,
+ file_path,
+ })
})
- .collect(),
- )
+ .collect()
+ });
+
+ anyhow::Ok(results)
})
}
}
@@ -70,7 +70,10 @@ async fn test_vector_store(cx: &mut TestAppContext) {
});
let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
- let add_project = store.update(cx, |store, cx| store.add_project(project, cx));
+ let worktree_id = project.read_with(cx, |project, cx| {
+ project.worktrees(cx).next().unwrap().read(cx).id()
+ });
+ let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx));
// TODO - remove
cx.foreground()
@@ -79,12 +82,15 @@ async fn test_vector_store(cx: &mut TestAppContext) {
add_project.await.unwrap();
let search_results = store
- .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
+ .update(cx, |store, cx| {
+ store.search(&project, "aaaa".to_string(), 5, cx)
+ })
.await
.unwrap();
assert_eq!(search_results[0].offset, 0);
assert_eq!(search_results[0].name, "aaa");
+ assert_eq!(search_results[0].worktree_id, worktree_id);
}
#[test]