Updated both the indexing and search methods for the vector store to maintain both Zed worktree ids and database worktree ids

Created by KCaverly and maxbrunsfeld

Co-authored-by: maxbrunsfeld <max@zed.dev>
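
This change keeps a mapping between Zed's in-memory `WorktreeId`s and the worktree row ids stored in the vector database: indexing records the database id for every worktree it scans, and search resolves a project's worktrees to database ids on the way in and maps hits back to `WorktreeId`s on the way out. The real field is `worktree_db_ids: Vec<(WorktreeId, i64)>`, a list of pairs searched in both directions; a minimal standalone sketch of those two lookups (types simplified to primitives, not the crate's exact definitions):

    struct WorktreeIdMap {
        // (Zed worktree id, database worktree row id)
        entries: Vec<(usize, i64)>,
    }

    impl WorktreeIdMap {
        fn db_id(&self, zed_id: usize) -> Option<i64> {
            self.entries
                .iter()
                .find_map(|(id, db_id)| (*id == zed_id).then_some(*db_id))
        }

        fn zed_id(&self, db_id: i64) -> Option<usize> {
            self.entries
                .iter()
                .find_map(|(id, candidate)| (*candidate == db_id).then_some(*id))
        }
    }

Indexing appends to this list once a worktree's database id is known; search uses the forward lookup to scope the query and the reverse lookup to label results.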

Change summary

crates/vector_store/src/db.rs                 |  67 ++++---
crates/vector_store/src/modal.rs              |  17 +
crates/vector_store/src/vector_store.rs       | 182 +++++++++++---------
crates/vector_store/src/vector_store_tests.rs |  10 
4 files changed, 163 insertions(+), 113 deletions(-)

Detailed changes

crates/vector_store/src/db.rs 🔗

@@ -1,6 +1,7 @@
 use std::{
     collections::HashMap,
     path::{Path, PathBuf},
+    rc::Rc,
 };
 
 use anyhow::{anyhow, Result};
@@ -258,22 +259,34 @@ impl VectorDatabase {
 
     pub fn for_each_document(
         &self,
-        worktree_id: i64,
+        worktree_ids: &[i64],
         mut f: impl FnMut(i64, Embedding),
     ) -> Result<()> {
-        let mut query_statement = self.db.prepare("SELECT id, embedding FROM documents")?;
+        let mut query_statement = self.db.prepare(
+            "
+            SELECT
+                documents.id, documents.embedding
+            FROM
+                documents, files
+            WHERE
+                documents.file_id = files.id AND
+                files.worktree_id IN rarray(?)
+            ",
+        )?;
         query_statement
-            .query_map(params![], |row| Ok((row.get(0)?, row.get(1)?)))?
+            .query_map(params![ids_to_sql(worktree_ids)], |row| {
+                Ok((row.get(0)?, row.get(1)?))
+            })?
             .filter_map(|row| row.ok())
             .for_each(|row| f(row.0, row.1));
         Ok(())
     }
 
-    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(PathBuf, usize, String)>> {
+    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
         let mut statement = self.db.prepare(
             "
                 SELECT
-                    documents.id, files.relative_path, documents.offset, documents.name
+                    documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name
                 FROM
                     documents, files
                 WHERE
@@ -282,35 +295,28 @@ impl VectorDatabase {
             ",
         )?;
 
-        let result_iter = statement.query_map(
-            params![std::rc::Rc::new(
-                ids.iter()
-                    .copied()
-                    .map(|v| rusqlite::types::Value::from(v))
-                    .collect::<Vec<_>>()
-            )],
-            |row| {
-                Ok((
-                    row.get::<_, i64>(0)?,
-                    row.get::<_, String>(1)?.into(),
-                    row.get(2)?,
-                    row.get(3)?,
-                ))
-            },
-        )?;
+        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
+            Ok((
+                row.get::<_, i64>(0)?,
+                row.get::<_, i64>(1)?,
+                row.get::<_, String>(2)?.into(),
+                row.get(3)?,
+                row.get(4)?,
+            ))
+        })?;
 
-        let mut values_by_id = HashMap::<i64, (PathBuf, usize, String)>::default();
+        let mut values_by_id = HashMap::<i64, (i64, PathBuf, usize, String)>::default();
         for row in result_iter {
-            let (id, path, offset, name) = row?;
-            values_by_id.insert(id, (path, offset, name));
+            let (id, worktree_id, path, offset, name) = row?;
+            values_by_id.insert(id, (worktree_id, path, offset, name));
         }
 
         let mut results = Vec::with_capacity(ids.len());
         for id in ids {
-            let (path, offset, name) = values_by_id
+            let value = values_by_id
                 .remove(id)
                 .ok_or(anyhow!("missing document id {}", id))?;
-            results.push((path, offset, name));
+            results.push(value);
         }
 
         Ok(results)
@@ -339,3 +345,12 @@ impl VectorDatabase {
         return Ok(documents);
     }
 }
+
+fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
+    Rc::new(
+        ids.iter()
+            .copied()
+            .map(|v| rusqlite::types::Value::from(v))
+            .collect::<Vec<_>>(),
+    )
+}

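The `IN rarray(?)` filter now shared by `for_each_document` and `get_documents_by_ids` uses the table-valued array function that rusqlite exposes behind its `array` feature; the ids are bound as the `Rc<Vec<Value>>` that the new `ids_to_sql` helper builds. A rough standalone sketch of that pattern, assuming rusqlite is compiled with the `array` feature (the connection setup and query here are illustrative, not this crate's code):

    use std::rc::Rc;

    use rusqlite::{params, types::Value, Connection, Result};

    fn document_ids_for_worktrees(conn: &Connection, worktree_ids: &[i64]) -> Result<Vec<i64>> {
        // rarray() only exists after the array virtual-table module is
        // registered on the connection.
        rusqlite::vtab::array::load_module(conn)?;

        let ids: Rc<Vec<Value>> =
            Rc::new(worktree_ids.iter().copied().map(Value::from).collect());
        let mut stmt = conn.prepare(
            "SELECT documents.id
             FROM documents, files
             WHERE documents.file_id = files.id AND files.worktree_id IN rarray(?)",
        )?;
        stmt.query_map(params![ids], |row| row.get(0))?.collect()
    }

`VectorDatabase` presumably registers the module once during setup rather than per query; the sketch only shows the pieces the new SQL depends on.
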
crates/vector_store/src/modal.rs 🔗

@@ -48,7 +48,9 @@ impl PickerDelegate for SemanticSearchDelegate {
     }
 
     fn confirm(&mut self, cx: &mut ViewContext<SemanticSearch>) {
-        todo!()
+        if let Some(search_result) = self.matches.get(self.selected_match_index) {
+            // search_result.file_path
+        }
     }
 
     fn dismissed(&mut self, _cx: &mut ViewContext<SemanticSearch>) {}
@@ -66,9 +68,9 @@ impl PickerDelegate for SemanticSearchDelegate {
     }
 
     fn update_matches(&mut self, query: String, cx: &mut ViewContext<SemanticSearch>) -> Task<()> {
-        let task = self
-            .vector_store
-            .update(cx, |store, cx| store.search(query.to_string(), 10, cx));
+        let task = self.vector_store.update(cx, |store, cx| {
+            store.search(&self.project, query.to_string(), 10, cx)
+        });
 
         cx.spawn(|this, mut cx| async move {
             let results = task.await.log_err();
@@ -90,7 +92,7 @@ impl PickerDelegate for SemanticSearchDelegate {
     ) -> AnyElement<Picker<Self>> {
         let theme = theme::current(cx);
         let style = &theme.picker.item;
-        let current_style = style.style_for(mouse_state, selected);
+        let current_style = style.in_state(selected).style_for(mouse_state);
 
         let search_result = &self.matches[ix];
 
@@ -99,7 +101,10 @@ impl PickerDelegate for SemanticSearchDelegate {
 
         Flex::column()
             .with_child(Text::new(name, current_style.label.text.clone()).with_soft_wrap(false))
-            .with_child(Label::new(path.to_string(), style.default.label.clone()))
+            .with_child(Label::new(
+                path.to_string(),
+                style.inactive_state().default.label.clone(),
+            ))
             .contained()
             .with_style(current_style.container)
             .into_any()

crates/vector_store/src/vector_store.rs 🔗

@@ -8,11 +8,11 @@ mod vector_store_tests;
 
 use anyhow::{anyhow, Result};
 use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
-use embedding::{EmbeddingProvider, OpenAIEmbeddings};
+use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
 use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
 use language::{Language, LanguageRegistry};
 use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
-use project::{Fs, Project};
+use project::{Fs, Project, WorktreeId};
 use smol::channel;
 use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
 use tree_sitter::{Parser, QueryCursor};
@@ -36,9 +36,10 @@ pub fn init(
         VectorStore::new(
             fs,
             VECTOR_DB_URL.to_string(),
-            Arc::new(OpenAIEmbeddings {
-                client: http_client,
-            }),
+            // Arc::new(OpenAIEmbeddings {
+            //     client: http_client,
+            // }),
+            Arc::new(DummyEmbeddings {}),
             language_registry,
         )
     });
@@ -75,25 +76,6 @@ pub fn init(
         }
     });
     SemanticSearch::init(cx);
-    // cx.add_action({
-    //     let vector_store = vector_store.clone();
-    //     move |workspace: &mut Workspace, _: &TestSearch, cx: &mut ViewContext<Workspace>| {
-    //         let t0 = std::time::Instant::now();
-    //         let task = vector_store.update(cx, |store, cx| {
-    //             store.search("compute embeddings for all of the symbols in the codebase and write them to a database".to_string(), 10, cx)
-    //         });
-
-    //         cx.spawn(|this, cx| async move {
-    //             let results = task.await?;
-    //             let duration = t0.elapsed();
-
-    //             println!("search took {:?}", duration);
-    //             println!("results {:?}", results);
-
-    //             anyhow::Ok(())
-    //         }).detach()
-    //     }
-    // });
 }
 
 #[derive(Debug)]
@@ -108,10 +90,12 @@ pub struct VectorStore {
     database_url: Arc<str>,
     embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
+    worktree_db_ids: Vec<(WorktreeId, i64)>,
 }
 
 #[derive(Debug)]
 pub struct SearchResult {
+    pub worktree_id: WorktreeId,
     pub name: String,
     pub offset: usize,
     pub file_path: PathBuf,
@@ -129,6 +113,7 @@ impl VectorStore {
             database_url: database_url.into(),
             embedding_provider,
             language_registry,
+            worktree_db_ids: Vec::new(),
         }
     }
 
@@ -178,9 +163,11 @@ impl VectorStore {
             }
         }
 
-        let embeddings = embedding_provider.embed_batch(context_spans).await?;
-        for (document, embedding) in documents.iter_mut().zip(embeddings) {
-            document.embedding = embedding;
+        if !documents.is_empty() {
+            let embeddings = embedding_provider.embed_batch(context_spans).await?;
+            for (document, embedding) in documents.iter_mut().zip(embeddings) {
+                document.embedding = embedding;
+            }
         }
 
         let sha1 = FileSha1::from_str(content);
@@ -214,7 +201,7 @@ impl VectorStore {
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
 
-        cx.spawn(|_, cx| async move {
+        cx.spawn(|this, mut cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
 
             // TODO: remove this after fixing the bug in scan_complete
@@ -231,25 +218,24 @@ impl VectorStore {
                     .collect::<Vec<_>>()
             });
 
-            let worktree_root_paths = worktrees
-                .iter()
-                .map(|worktree| worktree.abs_path().clone())
-                .collect::<Vec<_>>();
-
             // Here we query the worktree ids, and yet we dont have them elsewhere
             // We likely want to clean up these datastructures
-            let (db, worktree_hashes, worktree_ids) = cx
+            let (db, worktree_hashes, worktree_db_ids) = cx
                 .background()
-                .spawn(async move {
-                    let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
-                    let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
-                    for worktree_root_path in worktree_root_paths {
-                        let worktree_id =
-                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
-                        worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
-                        hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
+                .spawn({
+                    let worktrees = worktrees.clone();
+                    async move {
+                        let mut worktree_db_ids: HashMap<WorktreeId, i64> = HashMap::new();
+                        let mut hashes: HashMap<WorktreeId, HashMap<PathBuf, FileSha1>> =
+                            HashMap::new();
+                        for worktree in worktrees {
+                            let worktree_db_id =
+                                db.find_or_create_worktree(worktree.abs_path().as_ref())?;
+                            worktree_db_ids.insert(worktree.id(), worktree_db_id);
+                            hashes.insert(worktree.id(), db.get_file_hashes(worktree_db_id)?);
+                        }
+                        anyhow::Ok((db, hashes, worktree_db_ids))
                     }
-                    anyhow::Ok((db, hashes, worktree_ids))
                 })
                 .await?;
 
@@ -259,10 +245,10 @@ impl VectorStore {
             cx.background()
                 .spawn({
                     let fs = fs.clone();
+                    let worktree_db_ids = worktree_db_ids.clone();
                     async move {
                         for worktree in worktrees.into_iter() {
-                            let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
-                            let file_hashes = &worktree_hashes[&worktree_id];
+                            let file_hashes = &worktree_hashes[&worktree.id()];
                             for file in worktree.files(false, 0) {
                                 let absolute_path = worktree.absolutize(&file.path);
 
@@ -291,7 +277,7 @@ impl VectorStore {
                                             );
                                             paths_tx
                                                 .try_send((
-                                                    worktree_id,
+                                                    worktree_db_ids[&worktree.id()],
                                                     path_buf,
                                                     content,
                                                     language,
@@ -382,54 +368,92 @@ impl VectorStore {
             drop(indexed_files_tx);
 
             db_write_task.await;
+
+            this.update(&mut cx, |this, _| {
+                this.worktree_db_ids.extend(worktree_db_ids);
+            });
+
             anyhow::Ok(())
         })
     }
 
     pub fn search(
         &mut self,
+        project: &ModelHandle<Project>,
         phrase: String,
         limit: usize,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<SearchResult>>> {
+        let project = project.read(cx);
+        let worktree_db_ids = project
+            .worktrees(cx)
+            .filter_map(|worktree| {
+                let worktree_id = worktree.read(cx).id();
+                self.worktree_db_ids.iter().find_map(|(id, db_id)| {
+                    if *id == worktree_id {
+                        Some(*db_id)
+                    } else {
+                        None
+                    }
+                })
+            })
+            .collect::<Vec<_>>();
+
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
-        cx.background().spawn(async move {
-            let database = VectorDatabase::new(database_url.as_ref())?;
-
-            let phrase_embedding = embedding_provider
-                .embed_batch(vec![&phrase])
-                .await?
-                .into_iter()
-                .next()
-                .unwrap();
-
-            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
-            database.for_each_document(0, |id, embedding| {
-                let similarity = dot(&embedding.0, &phrase_embedding);
-                let ix = match results.binary_search_by(|(_, s)| {
-                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
-                }) {
-                    Ok(ix) => ix,
-                    Err(ix) => ix,
-                };
-                results.insert(ix, (id, similarity));
-                results.truncate(limit);
-            })?;
-
-            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
-            let documents = database.get_documents_by_ids(&ids)?;
-
-            anyhow::Ok(
+        cx.spawn(|this, cx| async move {
+            let documents = cx
+                .background()
+                .spawn(async move {
+                    let database = VectorDatabase::new(database_url.as_ref())?;
+
+                    let phrase_embedding = embedding_provider
+                        .embed_batch(vec![&phrase])
+                        .await?
+                        .into_iter()
+                        .next()
+                        .unwrap();
+
+                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+                    database.for_each_document(&worktree_db_ids, |id, embedding| {
+                        let similarity = dot(&embedding.0, &phrase_embedding);
+                        let ix = match results.binary_search_by(|(_, s)| {
+                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+                        }) {
+                            Ok(ix) => ix,
+                            Err(ix) => ix,
+                        };
+                        results.insert(ix, (id, similarity));
+                        results.truncate(limit);
+                    })?;
+
+                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
+                    database.get_documents_by_ids(&ids)
+                })
+                .await?;
+
+            let results = this.read_with(&cx, |this, _| {
                 documents
                     .into_iter()
-                    .map(|(file_path, offset, name)| SearchResult {
-                        name,
-                        offset,
-                        file_path,
+                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
+                        let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| {
+                            if *db_id == worktree_db_id {
+                                Some(*id)
+                            } else {
+                                None
+                            }
+                        })?;
+                        Some(SearchResult {
+                            worktree_id,
+                            name,
+                            offset,
+                            file_path,
+                        })
                     })
-                    .collect(),
-            )
+                    .collect()
+            });
+
+            anyhow::Ok(results)
         })
     }
 }

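The reworked `search` now runs in two stages: a background task embeds the query phrase, scans the candidate documents with `for_each_document`, and keeps only the best `limit` matches via a sorted insert, then a foreground step translates each hit's database worktree id back into a Zed `WorktreeId` before building `SearchResult`s. The ranking step is the least obvious part; a self-contained sketch of that pattern (standalone names, not the crate's exact code):

    use std::cmp::Ordering;

    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Keep the `limit` highest-similarity (id, score) pairs, sorted best-first.
    fn top_k(
        candidates: impl IntoIterator<Item = (i64, Vec<f32>)>,
        query: &[f32],
        limit: usize,
    ) -> Vec<(i64, f32)> {
        let mut results: Vec<(i64, f32)> = Vec::with_capacity(limit + 1);
        for (id, embedding) in candidates {
            let similarity = dot(&embedding, query);
            // Comparing the new score against each stored score keeps the vec
            // ordered from highest to lowest similarity.
            let ix = match results
                .binary_search_by(|(_, s)| similarity.partial_cmp(s).unwrap_or(Ordering::Equal))
            {
                Ok(ix) | Err(ix) => ix,
            };
            results.insert(ix, (id, similarity));
            results.truncate(limit);
        }
        results
    }

In the crate this runs inside `cx.background().spawn`, and the winning ids are handed to `get_documents_by_ids`, which now also returns the database worktree id so the foreground half can do the reverse lookup.
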
crates/vector_store/src/vector_store_tests.rs 🔗

@@ -70,7 +70,10 @@ async fn test_vector_store(cx: &mut TestAppContext) {
     });
 
     let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
-    let add_project = store.update(cx, |store, cx| store.add_project(project, cx));
+    let worktree_id = project.read_with(cx, |project, cx| {
+        project.worktrees(cx).next().unwrap().read(cx).id()
+    });
+    let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx));
 
     // TODO - remove
     cx.foreground()
@@ -79,12 +82,15 @@ async fn test_vector_store(cx: &mut TestAppContext) {
     add_project.await.unwrap();
 
     let search_results = store
-        .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
+        .update(cx, |store, cx| {
+            store.search(&project, "aaaa".to_string(), 5, cx)
+        })
         .await
         .unwrap();
 
     assert_eq!(search_results[0].offset, 0);
     assert_eq!(search_results[0].name, "aaa");
+    assert_eq!(search_results[0].worktree_id, worktree_id);
 }
 
 #[test]