WIP: work towards wiring up an embeddings_for_digest hashmap that is stored for all indexed files

KCaverly created

Change summary

crates/semantic_index/src/db.rs             | 36 +++++++++
crates/semantic_index/src/semantic_index.rs | 91 +++++++++++++++++-----
2 files changed, 104 insertions(+), 23 deletions(-)

Detailed changes

crates/semantic_index/src/db.rs 🔗

@@ -9,6 +9,7 @@ use gpui::executor;
 use project::{search::PathMatcher, Fs};
 use rpc::proto::Timestamp;
 use rusqlite::params;
+use rusqlite::types::Value;
 use std::{
     cmp::Ordering,
     collections::HashMap,
@@ -283,6 +284,41 @@ impl VectorDatabase {
         })
     }
 
+    pub fn embeddings_for_files(
+        &self,
+        worktree_id_file_paths: Vec<(i64, PathBuf)>,
+    ) -> impl Future<Output = Result<HashMap<DocumentDigest, Embedding>>> {
+        todo!();
+        // The remainder of the code is wired up.
+        // I'm having a bit of trouble figuring out the rusqlite syntax for a WHERE (files.worktree_id, files.relative_path) IN (VALUES (?, ?), (?, ?)) query
+        async { Ok(HashMap::new()) }
+        // let mut embeddings_by_digest = HashMap::new();
+        // self.transact(move |db| {
+
+        //     let worktree_ids: Rc<Vec<Value>> = Rc::new(
+        //         worktree_id_file_paths
+        //             .iter()
+        //             .map(|(id, _)| Value::from(*id))
+        //             .collect(),
+        //     );
+        //     let file_paths: Rc<Vec<Value>> = Rc::new(worktree_id_file_paths
+        //         .iter()
+        //         .map(|(_, path)| Value::from(path.to_string_lossy().to_string()))
+        //         .collect());
+
+        //     let mut query = db.prepare("SELECT digest, embedding FROM documents LEFT JOIN files ON files.id = documents.file_id WHERE (files.worktree_id, files.relative_path) IN (VALUES (rarray = (?1), rarray = (?2))")?;
+
+        //     for row in query.query_map(params![worktree_ids, file_paths], |row| {
+        //         Ok((row.get::<_, DocumentDigest>(0)?, row.get::<_, Embedding>(1)?))
+        //     })? {
+        //         if let Ok(row) = row {
+        //             embeddings_by_digest.insert(row.0, row.1);
+        //         }
+        //     }
+        //     Ok(embeddings_by_digest)
+        // })
+    }
+
     pub fn find_or_create_worktree(
         &self,
         worktree_root_path: PathBuf,

crates/semantic_index/src/semantic_index.rs 🔗

@@ -10,12 +10,12 @@ mod semantic_index_tests;
 use crate::semantic_index_settings::SemanticIndexSettings;
 use anyhow::{anyhow, Result};
 use db::VectorDatabase;
-use embedding::{EmbeddingProvider, OpenAIEmbeddings};
+use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
 use embedding_queue::{EmbeddingQueue, FileToEmbed};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Anchor, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;
-use parsing::{CodeContextRetriever, PARSEABLE_ENTIRE_FILE_TYPES};
+use parsing::{CodeContextRetriever, DocumentDigest, PARSEABLE_ENTIRE_FILE_TYPES};
 use postage::watch;
 use project::{
     search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId,
@@ -103,7 +103,7 @@ pub struct SemanticIndex {
     db: VectorDatabase,
     embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
-    parsing_files_tx: channel::Sender<PendingFile>,
+    parsing_files_tx: channel::Sender<(Arc<HashMap<DocumentDigest, Embedding>>, PendingFile)>,
     _embedding_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
@@ -247,7 +247,8 @@ impl SemanticIndex {
             });
 
             // Parse files into embeddable documents.
-            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
+            let (parsing_files_tx, parsing_files_rx) =
+                channel::unbounded::<(Arc<HashMap<DocumentDigest, Embedding>>, PendingFile)>();
             let embedding_queue = Arc::new(Mutex::new(embedding_queue));
             let mut _parsing_files_tasks = Vec::new();
             for _ in 0..cx.background().num_cpus() {
@@ -258,14 +259,16 @@ impl SemanticIndex {
                 let db = db.clone();
                 _parsing_files_tasks.push(cx.background().spawn(async move {
                     let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
-                    while let Ok(pending_file) = parsing_files_rx.recv().await {
+                    while let Ok((embeddings_for_digest, pending_file)) =
+                        parsing_files_rx.recv().await
+                    {
                         Self::parse_file(
                             &fs,
                             pending_file,
                             &mut retriever,
                             &embedding_queue,
                             &parsing_files_rx,
-                            &db,
+                            &embeddings_for_digest,
                         )
                         .await;
                     }
@@ -294,8 +297,11 @@ impl SemanticIndex {
         pending_file: PendingFile,
         retriever: &mut CodeContextRetriever,
         embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
-        parsing_files_rx: &channel::Receiver<PendingFile>,
-        db: &VectorDatabase,
+        parsing_files_rx: &channel::Receiver<(
+            Arc<HashMap<DocumentDigest, Embedding>>,
+            PendingFile,
+        )>,
+        embeddings_for_digest: &HashMap<DocumentDigest, Embedding>,
     ) {
         let Some(language) = pending_file.language else {
             return;
@@ -312,18 +318,9 @@ impl SemanticIndex {
                     documents.len()
                 );
 
-                if let Some(sha_to_embeddings) = db
-                    .embeddings_for_file(
-                        pending_file.worktree_db_id,
-                        pending_file.relative_path.clone(),
-                    )
-                    .await
-                    .log_err()
-                {
-                    for document in documents.iter_mut() {
-                        if let Some(embedding) = sha_to_embeddings.get(&document.digest) {
-                            document.embedding = Some(embedding.to_owned());
-                        }
+                for document in documents.iter_mut() {
+                    if let Some(embedding) = embeddings_for_digest.get(&document.digest) {
+                        document.embedding = Some(embedding.to_owned());
                     }
                 }
 
@@ -381,6 +378,17 @@ impl SemanticIndex {
             return;
         };
 
+        let embeddings_for_digest = {
+            let mut worktree_id_file_paths = Vec::new();
+            for (path, _) in &project_state.changed_paths {
+                if let Some(worktree_db_id) = project_state.db_id_for_worktree_id(path.worktree_id)
+                {
+                    worktree_id_file_paths.push((worktree_db_id, path.path.to_path_buf()));
+                }
+            }
+            self.db.embeddings_for_files(worktree_id_file_paths)
+        };
+
         let worktree = worktree.read(cx);
         let change_time = Instant::now();
         for (path, entry_id, change) in changes.iter() {
@@ -405,9 +413,18 @@ impl SemanticIndex {
         }
 
         cx.spawn_weak(|this, mut cx| async move {
+            let embeddings_for_digest = embeddings_for_digest.await.log_err().unwrap_or_default();
+
             cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
             if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
-                Self::reindex_changed_paths(this, project, Some(change_time), &mut cx).await;
+                Self::reindex_changed_paths(
+                    this,
+                    project,
+                    Some(change_time),
+                    &mut cx,
+                    Arc::new(embeddings_for_digest),
+                )
+                .await;
             }
         })
         .detach();
@@ -561,7 +578,32 @@ impl SemanticIndex {
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
         cx.spawn(|this, mut cx| async move {
-            Self::reindex_changed_paths(this.clone(), project.clone(), None, &mut cx).await;
+            let embeddings_for_digest = this.read_with(&cx, |this, cx| {
+                if let Some(state) = this.projects.get(&project.downgrade()) {
+                    let mut worktree_id_file_paths = Vec::new();
+                    for (path, _) in &state.changed_paths {
+                        if let Some(worktree_db_id) = state.db_id_for_worktree_id(path.worktree_id)
+                        {
+                            worktree_id_file_paths.push((worktree_db_id, path.path.to_path_buf()));
+                        }
+                    }
+
+                    Ok(this.db.embeddings_for_files(worktree_id_file_paths))
+                } else {
+                    Err(anyhow!("Project not yet initialized"))
+                }
+            })?;
+
+            let embeddings_for_digest = Arc::new(embeddings_for_digest.await?);
+
+            Self::reindex_changed_paths(
+                this.clone(),
+                project.clone(),
+                None,
+                &mut cx,
+                embeddings_for_digest,
+            )
+            .await;
 
             this.update(&mut cx, |this, _cx| {
                 let Some(state) = this.projects.get(&project.downgrade()) else {
@@ -726,6 +768,7 @@ impl SemanticIndex {
         project: ModelHandle<Project>,
         last_changed_before: Option<Instant>,
         cx: &mut AsyncAppContext,
+        embeddings_for_digest: Arc<HashMap<DocumentDigest, Embedding>>,
     ) {
         let mut pending_files = Vec::new();
         let mut files_to_delete = Vec::new();
@@ -805,7 +848,9 @@ impl SemanticIndex {
                 }
                 pending_file.language = Some(language);
             }
-            parsing_files_tx.try_send(pending_file).ok();
+            parsing_files_tx
+                .try_send((embeddings_for_digest.clone(), pending_file))
+                .ok();
         }
     }
 }