From 400d39740ca505c3b5f143818c0ebe8eeead0e6e Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Wed, 28 Jun 2023 16:21:03 -0400
Subject: [PATCH] updated both indexing and search methods for vector store,
 to maintain both zed worktree ids and db worktree ids

Co-authored-by: maxbrunsfeld
---
 crates/vector_store/src/db.rs                 |  67 ++++---
 crates/vector_store/src/modal.rs              |  17 +-
 crates/vector_store/src/vector_store.rs       | 182 ++++++++++--------
 crates/vector_store/src/vector_store_tests.rs |  10 +-
 4 files changed, 163 insertions(+), 113 deletions(-)

diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs
index f074a7066b08389ac55fd44c67598e20509f34cc..96856936fc6a9c3e3b3dd60b2a7a171642b27581 100644
--- a/crates/vector_store/src/db.rs
+++ b/crates/vector_store/src/db.rs
@@ -1,6 +1,7 @@
 use std::{
     collections::HashMap,
     path::{Path, PathBuf},
+    rc::Rc,
 };
 
 use anyhow::{anyhow, Result};
@@ -258,22 +259,34 @@ impl VectorDatabase {
 
     pub fn for_each_document(
         &self,
-        worktree_id: i64,
+        worktree_ids: &[i64],
         mut f: impl FnMut(i64, Embedding),
     ) -> Result<()> {
-        let mut query_statement = self.db.prepare("SELECT id, embedding FROM documents")?;
+        let mut query_statement = self.db.prepare(
+            "
+            SELECT
+                documents.id, documents.embedding
+            FROM
+                documents, files
+            WHERE
+                documents.file_id = files.id AND
+                files.worktree_id IN rarray(?)
+            ",
+        )?;
         query_statement
-            .query_map(params![], |row| Ok((row.get(0)?, row.get(1)?)))?
+            .query_map(params![ids_to_sql(worktree_ids)], |row| {
+                Ok((row.get(0)?, row.get(1)?))
+            })?
             .filter_map(|row| row.ok())
            .for_each(|row| f(row.0, row.1));
         Ok(())
     }
 
-    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(PathBuf, usize, String)>> {
+    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
         let mut statement = self.db.prepare(
             "
             SELECT
-                documents.id, files.relative_path, documents.offset, documents.name
+                documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name
             FROM
                 documents, files
             WHERE
@@ -282,35 +295,28 @@ impl VectorDatabase {
             ",
         )?;
-        let result_iter = statement.query_map(
-            params![std::rc::Rc::new(
-                ids.iter()
-                    .copied()
-                    .map(|v| rusqlite::types::Value::from(v))
-                    .collect::<Vec<_>>()
-            )],
-            |row| {
-                Ok((
-                    row.get::<_, i64>(0)?,
-                    row.get::<_, String>(1)?.into(),
-                    row.get(2)?,
-                    row.get(3)?,
-                ))
-            },
-        )?;
+        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
+            Ok((
+                row.get::<_, i64>(0)?,
+                row.get::<_, i64>(1)?,
+                row.get::<_, String>(2)?.into(),
+                row.get(3)?,
+                row.get(4)?,
+            ))
+        })?;
 
-        let mut values_by_id = HashMap::<i64, (PathBuf, usize, String)>::default();
+        let mut values_by_id = HashMap::<i64, (i64, PathBuf, usize, String)>::default();
 
         for row in result_iter {
-            let (id, path, offset, name) = row?;
-            values_by_id.insert(id, (path, offset, name));
+            let (id, worktree_id, path, offset, name) = row?;
+            values_by_id.insert(id, (worktree_id, path, offset, name));
         }
 
         let mut results = Vec::with_capacity(ids.len());
         for id in ids {
-            let (path, offset, name) = values_by_id
+            let value = values_by_id
                 .remove(id)
                 .ok_or(anyhow!("missing document id {}", id))?;
-            results.push((path, offset, name));
+            results.push(value);
         }
         Ok(results)
     }
@@ -339,3 +345,12 @@
         return Ok(documents);
     }
 }
+
+fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
+    Rc::new(
+        ids.iter()
+            .copied()
+            .map(|v| rusqlite::types::Value::from(v))
+            .collect::<Vec<_>>(),
+    )
+}
diff --git a/crates/vector_store/src/modal.rs b/crates/vector_store/src/modal.rs
index 48429150cd88a3946edcd6aa68eea883f76786c0..8052277a0bf37b13fcc16d26e4081484cec05894 100644
--- a/crates/vector_store/src/modal.rs
+++ b/crates/vector_store/src/modal.rs
@@ -48,7 +48,9 @@ impl PickerDelegate for SemanticSearchDelegate {
     }
 
     fn confirm(&mut self, cx: &mut ViewContext<SemanticSearch>) {
-        todo!()
+        if let Some(search_result) = self.matches.get(self.selected_match_index) {
+            // search_result.file_path
+        }
     }
 
     fn dismissed(&mut self, _cx: &mut ViewContext<SemanticSearch>) {}
@@ -66,9 +68,9 @@ impl PickerDelegate for SemanticSearchDelegate {
     }
 
     fn update_matches(&mut self, query: String, cx: &mut ViewContext<SemanticSearch>) -> Task<()> {
-        let task = self
-            .vector_store
-            .update(cx, |store, cx| store.search(query.to_string(), 10, cx));
+        let task = self.vector_store.update(cx, |store, cx| {
+            store.search(&self.project, query.to_string(), 10, cx)
+        });
 
         cx.spawn(|this, mut cx| async move {
             let results = task.await.log_err();
@@ -90,7 +92,7 @@
     ) -> AnyElement<Picker<Self>> {
         let theme = theme::current(cx);
         let style = &theme.picker.item;
-        let current_style = style.style_for(mouse_state, selected);
+        let current_style = style.in_state(selected).style_for(mouse_state);
 
         let search_result = &self.matches[ix];
 
@@ -99,7 +101,10 @@
 
         Flex::column()
             .with_child(Text::new(name, current_style.label.text.clone()).with_soft_wrap(false))
-            .with_child(Label::new(path.to_string(), style.default.label.clone()))
+            .with_child(Label::new(
+                path.to_string(),
+                style.inactive_state().default.label.clone(),
+            ))
             .contained()
             .with_style(current_style.container)
             .into_any()
diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs
index 2dc479045f5ae7e4cbd26bf63b5a797d77a438b2..92926b1f752adb799eaef7ae4f63dec92df553ce 100644
--- a/crates/vector_store/src/vector_store.rs
+++ b/crates/vector_store/src/vector_store.rs
@@ -8,11 +8,11 @@ mod vector_store_tests;
 
 use anyhow::{anyhow, Result};
 use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
-use embedding::{EmbeddingProvider, OpenAIEmbeddings};
+use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
 use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
 use language::{Language, LanguageRegistry};
 use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
-use project::{Fs, Project};
+use project::{Fs, Project, WorktreeId};
 use smol::channel;
 use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
 use tree_sitter::{Parser, QueryCursor};
@@ -36,9 +36,10 @@ pub fn init(
         VectorStore::new(
             fs,
             VECTOR_DB_URL.to_string(),
-            Arc::new(OpenAIEmbeddings {
-                client: http_client,
-            }),
+            // Arc::new(OpenAIEmbeddings {
+            //     client: http_client,
+            // }),
+            Arc::new(DummyEmbeddings {}),
             language_registry,
         )
     });
@@ -75,25 +76,6 @@
         }
     });
     SemanticSearch::init(cx);
-    // cx.add_action({
-    //     let vector_store = vector_store.clone();
-    //     move |workspace: &mut Workspace, _: &TestSearch, cx: &mut ViewContext<Workspace>| {
-    //         let t0 = std::time::Instant::now();
-    //         let task = vector_store.update(cx, |store, cx| {
-    //             store.search("compute embeddings for all of the symbols in the codebase and write them to a database".to_string(), 10, cx)
-    //         });
-
-    //         cx.spawn(|this, cx| async move {
-    //             let results = task.await?;
-    //             let duration = t0.elapsed();
-
-    //             println!("search took {:?}", duration);
-    //             println!("results {:?}", results);
-
-    //             anyhow::Ok(())
-    //         }).detach()
-    //     }
-    // });
 }
 #[derive(Debug)]
 pub struct IndexedFile {
@@ -108,10 +90,12 @@ pub struct VectorStore {
     database_url: Arc<str>,
     embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
+    worktree_db_ids: Vec<(WorktreeId, i64)>,
 }
 
 #[derive(Debug)]
 pub struct SearchResult {
+    pub worktree_id: WorktreeId,
     pub name: String,
     pub offset: usize,
     pub file_path: PathBuf,
@@ -129,6 +113,7 @@ impl VectorStore {
             database_url: database_url.into(),
             embedding_provider,
             language_registry,
+            worktree_db_ids: Vec::new(),
         }
     }
 
@@ -178,9 +163,11 @@ impl VectorStore {
             }
         }
 
-        let embeddings = embedding_provider.embed_batch(context_spans).await?;
-        for (document, embedding) in documents.iter_mut().zip(embeddings) {
-            document.embedding = embedding;
+        if !documents.is_empty() {
+            let embeddings = embedding_provider.embed_batch(context_spans).await?;
+            for (document, embedding) in documents.iter_mut().zip(embeddings) {
+                document.embedding = embedding;
+            }
         }
 
         let sha1 = FileSha1::from_str(content);
@@ -214,7 +201,7 @@
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
 
-        cx.spawn(|_, cx| async move {
+        cx.spawn(|this, mut cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
 
             // TODO: remove this after fixing the bug in scan_complete
@@ -231,25 +218,24 @@
                     .collect::<Vec<_>>()
             });
 
-            let worktree_root_paths = worktrees
-                .iter()
-                .map(|worktree| worktree.abs_path().clone())
-                .collect::<Vec<_>>();
-
             // Here we query the worktree ids, and yet we dont have them elsewhere
             // We likely want to clean up these datastructures
-            let (db, worktree_hashes, worktree_ids) = cx
+            let (db, worktree_hashes, worktree_db_ids) = cx
                 .background()
-                .spawn(async move {
-                    let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
-                    let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
-                    for worktree_root_path in worktree_root_paths {
-                        let worktree_id =
-                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
-                        worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
-                        hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
+                .spawn({
+                    let worktrees = worktrees.clone();
+                    async move {
+                        let mut worktree_db_ids: HashMap<WorktreeId, i64> = HashMap::new();
+                        let mut hashes: HashMap<WorktreeId, HashMap<PathBuf, FileSha1>> =
+                            HashMap::new();
+                        for worktree in worktrees {
+                            let worktree_db_id =
+                                db.find_or_create_worktree(worktree.abs_path().as_ref())?;
+                            worktree_db_ids.insert(worktree.id(), worktree_db_id);
+                            hashes.insert(worktree.id(), db.get_file_hashes(worktree_db_id)?);
+                        }
+                        anyhow::Ok((db, hashes, worktree_db_ids))
                     }
-                    anyhow::Ok((db, hashes, worktree_ids))
                 })
                 .await?;
 
@@ -259,10 +245,10 @@
             cx.background()
                 .spawn({
                     let fs = fs.clone();
+                    let worktree_db_ids = worktree_db_ids.clone();
                     async move {
                         for worktree in worktrees.into_iter() {
-                            let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
-                            let file_hashes = &worktree_hashes[&worktree_id];
+                            let file_hashes = &worktree_hashes[&worktree.id()];
                             for file in worktree.files(false, 0) {
                                 let absolute_path = worktree.absolutize(&file.path);
 
@@ -291,7 +277,7 @@
                                     );
                                     paths_tx
                                         .try_send((
-                                            worktree_id,
+                                            worktree_db_ids[&worktree.id()],
                                             path_buf,
                                             content,
                                             language,
@@ -382,54 +368,92 @@
 
             drop(indexed_files_tx);
             db_write_task.await;
+
+            this.update(&mut cx, |this, _| {
+                this.worktree_db_ids.extend(worktree_db_ids);
+            });
+
             anyhow::Ok(())
         })
     }
 
     pub fn search(
         &mut self,
+        project: &ModelHandle<Project>,
         phrase: String,
         limit: usize,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<SearchResult>>> {
+        let project = project.read(cx);
+        let worktree_db_ids = project
+            .worktrees(cx)
+            .filter_map(|worktree| {
+                let worktree_id = worktree.read(cx).id();
+                self.worktree_db_ids.iter().find_map(|(id, db_id)| {
+                    if *id == worktree_id {
+                        Some(*db_id)
+                    } else {
+                        None
+                    }
+                })
+            })
+            .collect::<Vec<_>>();
+
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
-        cx.background().spawn(async move {
-            let database = VectorDatabase::new(database_url.as_ref())?;
-
-            let phrase_embedding = embedding_provider
-                .embed_batch(vec![&phrase])
-                .await?
-                .into_iter()
-                .next()
-                .unwrap();
-
-            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
-            database.for_each_document(0, |id, embedding| {
-                let similarity = dot(&embedding.0, &phrase_embedding);
-                let ix = match results.binary_search_by(|(_, s)| {
-                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
-                }) {
-                    Ok(ix) => ix,
-                    Err(ix) => ix,
-                };
-                results.insert(ix, (id, similarity));
-                results.truncate(limit);
-            })?;
-
-            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
-            let documents = database.get_documents_by_ids(&ids)?;
-
-            anyhow::Ok(
+        cx.spawn(|this, cx| async move {
+            let documents = cx
+                .background()
+                .spawn(async move {
+                    let database = VectorDatabase::new(database_url.as_ref())?;
+
+                    let phrase_embedding = embedding_provider
+                        .embed_batch(vec![&phrase])
+                        .await?
+                        .into_iter()
+                        .next()
+                        .unwrap();
+
+                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+                    database.for_each_document(&worktree_db_ids, |id, embedding| {
+                        let similarity = dot(&embedding.0, &phrase_embedding);
+                        let ix = match results.binary_search_by(|(_, s)| {
+                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+                        }) {
+                            Ok(ix) => ix,
+                            Err(ix) => ix,
+                        };
+                        results.insert(ix, (id, similarity));
+                        results.truncate(limit);
+                    })?;
+
+                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
+                    database.get_documents_by_ids(&ids)
+                })
+                .await?;
+
+            let results = this.read_with(&cx, |this, _| {
                 documents
                     .into_iter()
-                    .map(|(file_path, offset, name)| SearchResult {
-                        name,
-                        offset,
-                        file_path,
+                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
+                        let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| {
+                            if *db_id == worktree_db_id {
+                                Some(*id)
+                            } else {
+                                None
+                            }
+                        })?;
+                        Some(SearchResult {
+                            worktree_id,
+                            name,
+                            offset,
+                            file_path,
+                        })
                     })
-                    .collect(),
-            )
+                    .collect()
+            });
+
+            anyhow::Ok(results)
         })
     }
 }
diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs
index c67bb9954fd884e75715e6641b7eb87b04810b36..6f8856c4fb898392d8d969a293a3a5e8f8970222 100644
--- a/crates/vector_store/src/vector_store_tests.rs
+++ b/crates/vector_store/src/vector_store_tests.rs
@@ -70,7 +70,10 @@ async fn test_vector_store(cx: &mut TestAppContext) {
     });
 
     let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
-    let add_project = store.update(cx, |store, cx| store.add_project(project, cx));
+    let worktree_id = project.read_with(cx, |project, cx| {
+        project.worktrees(cx).next().unwrap().read(cx).id()
+    });
+    let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx));
 
     // TODO - remove
     cx.foreground()
@@ -79,12 +82,15 @@
     add_project.await.unwrap();
 
     let search_results = store
-        .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
+        .update(cx, |store, cx| {
+            store.search(&project, "aaaa".to_string(), 5, cx)
+        })
         .await
        .unwrap();
 
     assert_eq!(search_results[0].offset, 0);
     assert_eq!(search_results[0].name, "aaa");
+    assert_eq!(search_results[0].worktree_id, worktree_id);
 }
 
 #[test]
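
Editorial notes follow; nothing below is part of the patch itself.

The heart of this change is the new `worktree_db_ids: Vec<(WorktreeId, i64)>` field on `VectorStore`, which pairs Zed's in-memory worktree ids with the SQLite row ids returned by `find_or_create_worktree`. The pairing is consulted in both directions: Zed id to db id when scoping a search to the open project, and db id back to Zed id when turning database rows into `SearchResult`s. Here is a minimal sketch of those two lookups; the `WorktreeIdMap` wrapper is hypothetical, since the patch stores the `Vec` directly and inlines the `find_map` calls:

    use project::WorktreeId;

    /// Hypothetical wrapper around the `worktree_db_ids` field added to `VectorStore`.
    struct WorktreeIdMap(Vec<(WorktreeId, i64)>);

    impl WorktreeIdMap {
        /// Zed worktree id -> database row id, used when scoping a search.
        fn db_id(&self, id: WorktreeId) -> Option<i64> {
            self.0
                .iter()
                .find_map(|&(zed_id, db_id)| if zed_id == id { Some(db_id) } else { None })
        }

        /// Database row id -> Zed worktree id, used when building `SearchResult`s.
        fn zed_id(&self, db_id: i64) -> Option<WorktreeId> {
            self.0
                .iter()
                .find_map(|&(zed_id, id)| if id == db_id { Some(zed_id) } else { None })
        }
    }

A linear scan is reasonable here: a project rarely has more than a handful of worktrees, so a `HashMap` would buy little.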
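
Both rewritten queries in db.rs filter with `IN rarray(?)`. That relies on rusqlite's array extension (Cargo feature `array`), which lets an `Rc<Vec<rusqlite::types::Value>>` be bound as a table-valued parameter, so a Rust slice can serve as an id list without splicing text into the SQL; this is exactly what the new `ids_to_sql` helper produces. A self-contained sketch of the pattern, assuming the `array` feature is enabled and a simplified schema:

    use std::rc::Rc;

    use rusqlite::{params, types::Value, Connection, Result};

    fn document_ids_in_worktrees(db: &Connection, worktree_ids: &[i64]) -> Result<Vec<i64>> {
        // Register the `rarray()` table-valued function for this connection.
        rusqlite::vtab::array::load_module(db)?;

        // Equivalent of the patch's `ids_to_sql`: rusqlite binds an
        // Rc<Vec<Value>> as the virtual table backing `rarray(?)`.
        let ids: Rc<Vec<Value>> = Rc::new(worktree_ids.iter().copied().map(Value::from).collect());

        let mut statement = db.prepare(
            "SELECT documents.id
             FROM documents, files
             WHERE documents.file_id = files.id AND files.worktree_id IN rarray(?)",
        )?;
        let rows = statement.query_map(params![ids], |row| row.get::<_, i64>(0))?;
        rows.collect()
    }

The patch presumably performs the module registration once during `VectorDatabase` setup; that code is outside this diff, so the call is shown inline above.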
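
The rewritten search streams over every stored embedding while keeping a small sorted buffer: each dot-product score is inserted with `binary_search_by` and the buffer is truncated back to `limit`. Note the comparator is deliberately reversed; it computes `similarity.partial_cmp(&s)` (candidate versus element) instead of the usual element-versus-candidate, which keeps the vector ordered highest-similarity-first. The same logic as a standalone sketch:

    use std::cmp::Ordering;

    /// Insert `(id, similarity)` into `results`, which is kept sorted in
    /// descending order of similarity, retaining only the top `limit` entries.
    fn push_top_k(results: &mut Vec<(i64, f32)>, id: i64, similarity: f32, limit: usize) {
        let ix = match results
            .binary_search_by(|&(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
        {
            Ok(ix) | Err(ix) => ix,
        };
        results.insert(ix, (id, similarity));
        results.truncate(limit);
    }

For example, pushing scores 0.1, 0.9, and 0.5 with `limit = 2` leaves `[(_, 0.9), (_, 0.5)]`: the 0.1 entry is inserted first but truncated away once two better candidates arrive.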