fill embeddings with database values and skip during embeddings queue

KCaverly created

Change summary

crates/semantic_index/src/embedding_queue.rs | 34 ++++++++++++++++++--
crates/semantic_index/src/semantic_index.rs  | 35 +++++++++++----------
2 files changed, 48 insertions(+), 21 deletions(-)

Detailed changes

crates/semantic_index/src/embedding_queue.rs 🔗

@@ -42,6 +42,7 @@ pub struct EmbeddingQueue {
     finished_files_rx: channel::Receiver<FileToEmbed>,
 }
 
+#[derive(Clone)]
 pub struct FileToEmbedFragment {
     file: Arc<Mutex<FileToEmbed>>,
     document_range: Range<usize>,
@@ -74,8 +75,16 @@ impl EmbeddingQueue {
         });
 
         let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
+        let mut saved_tokens = 0;
         for (ix, document) in file.lock().documents.iter().enumerate() {
-            let next_token_count = self.pending_batch_token_count + document.token_count;
+            let document_token_count = if document.embedding.is_none() {
+                document.token_count
+            } else {
+                saved_tokens += document.token_count;
+                0
+            };
+
+            let next_token_count = self.pending_batch_token_count + document_token_count;
             if next_token_count > self.embedding_provider.max_tokens_per_batch() {
                 let range_end = fragment_range.end;
                 self.flush();
@@ -87,8 +96,9 @@ impl EmbeddingQueue {
             }
 
             fragment_range.end = ix + 1;
-            self.pending_batch_token_count += document.token_count;
+            self.pending_batch_token_count += document_token_count;
         }
+        log::trace!("Saved Tokens: {:?}", saved_tokens);
     }
 
     pub fn flush(&mut self) {
@@ -100,25 +110,41 @@ impl EmbeddingQueue {
 
         let finished_files_tx = self.finished_files_tx.clone();
         let embedding_provider = self.embedding_provider.clone();
+
         self.executor.spawn(async move {
             let mut spans = Vec::new();
+            let mut document_count = 0;
             for fragment in &batch {
                 let file = fragment.file.lock();
+                document_count += file.documents[fragment.document_range.clone()].len();
                 spans.extend(
                     {
                         file.documents[fragment.document_range.clone()]
-                            .iter()
+                            .iter().filter(|d| d.embedding.is_none())
                             .map(|d| d.content.clone())
                         }
                 );
             }
 
+            log::trace!("Documents Length: {:?}", document_count);
+            log::trace!("Span Length: {:?}", spans.clone().len());
+
+            // If there are no spans to embed, just send each file to the finished files if this fragment holds the last reference to it.
+            if spans.len() == 0 {
+                for fragment in batch.clone() {
+                    if let Some(file) = Arc::into_inner(fragment.file) {
+                        finished_files_tx.try_send(file.into_inner()).unwrap();
+                    }
+                }
+                return;
+            };
+
             match embedding_provider.embed_batch(spans).await {
                 Ok(embeddings) => {
                     let mut embeddings = embeddings.into_iter();
                     for fragment in batch {
                         for document in
-                            &mut fragment.file.lock().documents[fragment.document_range.clone()]
+                            &mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
                         {
                             if let Some(embedding) = embeddings.next() {
                                 document.embedding = Some(embedding);

crates/semantic_index/src/semantic_index.rs 🔗

@@ -255,6 +255,7 @@ impl SemanticIndex {
                 let parsing_files_rx = parsing_files_rx.clone();
                 let embedding_provider = embedding_provider.clone();
                 let embedding_queue = embedding_queue.clone();
+                let db = db.clone();
                 _parsing_files_tasks.push(cx.background().spawn(async move {
                     let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                     while let Ok(pending_file) = parsing_files_rx.recv().await {
@@ -264,6 +265,7 @@ impl SemanticIndex {
                             &mut retriever,
                             &embedding_queue,
                             &parsing_files_rx,
+                            &db,
                         )
                         .await;
                     }
@@ -293,13 +295,14 @@ impl SemanticIndex {
         retriever: &mut CodeContextRetriever,
         embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
         parsing_files_rx: &channel::Receiver<PendingFile>,
+        db: &VectorDatabase,
     ) {
         let Some(language) = pending_file.language else {
             return;
         };
 
         if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
-            if let Some(documents) = retriever
+            if let Some(mut documents) = retriever
                 .parse_file_with_template(&pending_file.relative_path, &content, language)
                 .log_err()
             {
@@ -309,22 +312,20 @@ impl SemanticIndex {
                     documents.len()
                 );
 
-                todo!();
-                // if let Some(embeddings) = db
-                //     .embeddings_for_documents(
-                //         pending_file.worktree_db_id,
-                //         pending_file.relative_path,
-                //         &documents,
-                //     )
-                //     .await
-                //     .log_err()
-                // {
-                //     for (document, embedding) in documents.iter_mut().zip(embeddings) {
-                //         if let Some(embedding) = embedding {
-                //             document.embedding = embedding;
-                //         }
-                //     }
-                // }
+                if let Some(sha_to_embeddings) = db
+                    .embeddings_for_file(
+                        pending_file.worktree_db_id,
+                        pending_file.relative_path.clone(),
+                    )
+                    .await
+                    .log_err()
+                {
+                    for document in documents.iter_mut() {
+                        if let Some(embedding) = sha_to_embeddings.get(&document.digest) {
+                            document.embedding = Some(embedding.to_owned());
+                        }
+                    }
+                }
 
                 embedding_queue.lock().push(FileToEmbed {
                     worktree_id: pending_file.worktree_db_id,