move embeddings queue to use a single hashmap for all changed paths

Created by KCaverly and Antonio

Co-authored-by: Antonio <me@as-cii.com>

Change summary

crates/semantic_index/src/db.rs                   | 79 ++++++----------
crates/semantic_index/src/semantic_index.rs       | 14 ++
crates/semantic_index/src/semantic_index_tests.rs |  5 
3 files changed, 46 insertions(+), 52 deletions(-)

Detailed changes

crates/semantic_index/src/db.rs 🔗

@@ -265,58 +265,43 @@ impl VectorDatabase {
         })
     }
 
-    pub fn embeddings_for_file(
+    pub fn embeddings_for_files(
         &self,
-        worktree_id: i64,
-        relative_path: PathBuf,
+        worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
     ) -> impl Future<Output = Result<HashMap<DocumentDigest, Embedding>>> {
-        let relative_path = relative_path.to_string_lossy().into_owned();
         self.transact(move |db| {
-            let mut query = db.prepare("SELECT digest, embedding FROM documents LEFT JOIN files ON files.id = documents.file_id WHERE files.worktree_id = ?1 AND files.relative_path = ?2")?;
-            let mut result: HashMap<DocumentDigest, Embedding> = HashMap::new();
-            for row in query.query_map(params![worktree_id, relative_path], |row| {
-                Ok((row.get::<_, DocumentDigest>(0)?.into(), row.get::<_, Embedding>(1)?.into()))
-            })? {
-                let row = row?;
-                result.insert(row.0, row.1);
+            let mut query = db.prepare(
+                "
+                SELECT digest, embedding
+                FROM documents
+                LEFT JOIN files ON files.id = documents.file_id
+                WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
+            ",
+            )?;
+            let mut embeddings_by_digest = HashMap::new();
+            for (worktree_id, file_paths) in worktree_id_file_paths {
+                let file_paths = Rc::new(
+                    file_paths
+                        .into_iter()
+                        .map(|p| Value::Text(p.to_string_lossy().into_owned()))
+                        .collect::<Vec<_>>(),
+                );
+                let rows = query.query_map(params![worktree_id, file_paths], |row| {
+                    Ok((
+                        row.get::<_, DocumentDigest>(0)?,
+                        row.get::<_, Embedding>(1)?,
+                    ))
+                })?;
+
+                for row in rows {
+                    if let Ok(row) = row {
+                        embeddings_by_digest.insert(row.0, row.1);
+                    }
+                }
             }
-            Ok(result)
-        })
-    }
 
-    pub fn embeddings_for_files(
-        &self,
-        worktree_id_file_paths: Vec<(i64, PathBuf)>,
-    ) -> impl Future<Output = Result<HashMap<DocumentDigest, Embedding>>> {
-        todo!();
-        // The remainder of the code is wired up.
-        // I'm having a bit of trouble figuring out the rusqlite syntax for a WHERE (files.worktree_id, files.relative_path) IN (VALUES (?, ?), (?, ?)) query
-        async { Ok(HashMap::new()) }
-        // let mut embeddings_by_digest = HashMap::new();
-        // self.transact(move |db| {
-
-        //     let worktree_ids: Rc<Vec<Value>> = Rc::new(
-        //         worktree_id_file_paths
-        //             .iter()
-        //             .map(|(id, _)| Value::from(*id))
-        //             .collect(),
-        //     );
-        //     let file_paths: Rc<Vec<Value>> = Rc::new(worktree_id_file_paths
-        //         .iter()
-        //         .map(|(_, path)| Value::from(path.to_string_lossy().to_string()))
-        //         .collect());
-
-        //     let mut query = db.prepare("SELECT digest, embedding FROM documents LEFT JOIN files ON files.id = documents.file_id WHERE (files.worktree_id, files.relative_path) IN (VALUES (rarray = (?1), rarray = (?2))")?;
-
-        //     for row in query.query_map(params![worktree_ids, file_paths], |row| {
-        //         Ok((row.get::<_, DocumentDigest>(0)?, row.get::<_, Embedding>(1)?))
-        //     })? {
-        //         if let Ok(row) = row {
-        //             embeddings_by_digest.insert(row.0, row.1);
-        //         }
-        //     }
-        //     Ok(embeddings_by_digest)
-        // })
+            Ok(embeddings_by_digest)
+        })
     }
 
     pub fn find_or_create_worktree(

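The new embeddings_for_files replaces the per-file embeddings_for_file lookup: changed paths are grouped per worktree and fetched with a single IN rarray(?) query, so one round trip covers every changed path. The snippet below is a minimal standalone sketch of the rarray binding technique, not code from this commit; it assumes rusqlite compiled with the "array" feature and a simplified files table with worktree_id and relative_path columns.

    // Standalone sketch of binding a list of values to `IN rarray(?)`.
    // Assumes rusqlite is compiled with the "array" feature.
    use std::rc::Rc;

    use rusqlite::{params, types::Value, Connection, Result};

    fn embedded_paths(db: &Connection, worktree_id: i64, paths: &[&str]) -> Result<Vec<String>> {
        // The rarray() table-valued function must be registered once per connection.
        rusqlite::vtab::array::load_module(db)?;

        // Rc<Vec<Value>> is the parameter type the array module binds to rarray(?).
        let paths = Rc::new(
            paths
                .iter()
                .map(|p| Value::Text(p.to_string()))
                .collect::<Vec<Value>>(),
        );

        let mut query = db.prepare(
            "SELECT relative_path FROM files
             WHERE worktree_id = ? AND relative_path IN rarray(?)",
        )?;
        let rows = query.query_map(params![worktree_id, paths], |row| row.get::<_, String>(0))?;
        rows.collect()
    }

The commit's query presumably relies on the same array module being registered when the VectorDatabase connection is opened; without it, rarray(?) would be an unknown table-valued function.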
crates/semantic_index/src/semantic_index.rs 🔗

@@ -379,11 +379,14 @@ impl SemanticIndex {
         };
 
         let embeddings_for_digest = {
-            let mut worktree_id_file_paths = Vec::new();
+            let mut worktree_id_file_paths = HashMap::new();
             for (path, _) in &project_state.changed_paths {
                 if let Some(worktree_db_id) = project_state.db_id_for_worktree_id(path.worktree_id)
                 {
-                    worktree_id_file_paths.push((worktree_db_id, path.path.to_path_buf()));
+                    worktree_id_file_paths
+                        .entry(worktree_db_id)
+                        .or_insert(Vec::new())
+                        .push(path.path.clone());
                 }
             }
             self.db.embeddings_for_files(worktree_id_file_paths)
@@ -580,11 +583,14 @@ impl SemanticIndex {
         cx.spawn(|this, mut cx| async move {
             let embeddings_for_digest = this.read_with(&cx, |this, cx| {
                 if let Some(state) = this.projects.get(&project.downgrade()) {
-                    let mut worktree_id_file_paths = Vec::new();
+                    let mut worktree_id_file_paths = HashMap::default();
                     for (path, _) in &state.changed_paths {
                         if let Some(worktree_db_id) = state.db_id_for_worktree_id(path.worktree_id)
                         {
-                            worktree_id_file_paths.push((worktree_db_id, path.path.to_path_buf()));
+                            worktree_id_file_paths
+                                .entry(worktree_db_id)
+                                .or_insert(Vec::new())
+                                .push(path.path.clone());
                         }
                     }
 

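Both call sites now build the same HashMap<i64, Vec<Arc<Path>>> by bucketing changed paths under their worktree's database id. Below is a reduced sketch of that grouping step, assuming the worktree database id has already been resolved (the real code goes through db_id_for_worktree_id and skips paths whose worktree has no id yet):

    // Reduced sketch: bucket changed paths by worktree database id so a
    // single embeddings_for_files call can cover all of them.
    use std::{collections::HashMap, path::Path, sync::Arc};

    fn group_changed_paths(changed: &[(i64, Arc<Path>)]) -> HashMap<i64, Vec<Arc<Path>>> {
        let mut by_worktree: HashMap<i64, Vec<Arc<Path>>> = HashMap::new();
        for (worktree_db_id, path) in changed {
            by_worktree
                .entry(*worktree_db_id)
                .or_insert(Vec::new())
                .push(path.clone());
        }
        by_worktree
    }

Entry::or_default() would be the more idiomatic spelling of or_insert(Vec::new()); either way the map ends up containing only worktrees that actually had changed paths.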
crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -55,6 +55,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
                     fn bbb() {
                         println!(\"bbbbbbbbbbbbb!\");
                     }
+                    struct pqpqpqp {}
                 ".unindent(),
                 "file3.toml": "
                     ZZZZZZZZZZZZZZZZZZ = 5
@@ -121,6 +122,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
             (Path::new("src/file2.rs").into(), 0),
             (Path::new("src/file3.toml").into(), 0),
             (Path::new("src/file1.rs").into(), 45),
+            (Path::new("src/file2.rs").into(), 45),
         ],
         cx,
     );
@@ -148,6 +150,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
             (Path::new("src/file1.rs").into(), 0),
             (Path::new("src/file2.rs").into(), 0),
             (Path::new("src/file1.rs").into(), 45),
+            (Path::new("src/file2.rs").into(), 45),
         ],
         cx,
     );
@@ -199,7 +202,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
 
     assert_eq!(
         embedding_provider.embedding_count() - prev_embedding_count,
-        2
+        1
     );
 }