From 5abad58b0d81941726f81fd8e6e8ca876811163e Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 16:58:45 -0400 Subject: [PATCH] moved semantic index to use embeddings queue to batch and managed for atomic database writes Co-authored-by: Max --- crates/semantic_index/src/embedding_queue.rs | 25 +- crates/semantic_index/src/semantic_index.rs | 238 +++--------------- .../src/semantic_index_tests.rs | 14 +- 3 files changed, 55 insertions(+), 222 deletions(-) diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 6609c39e78a8ce830c55b9ed9c2621a20e850b13..2b48b7a7d68cf944c8c930861fc3f79fff37a3b4 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -1,10 +1,8 @@ -use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime}; - -use gpui::AppContext; +use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle}; +use gpui::executor::Background; use parking_lot::Mutex; use smol::channel; - -use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle}; +use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime}; #[derive(Clone)] pub struct FileToEmbed { @@ -38,6 +36,7 @@ impl PartialEq for FileToEmbed { pub struct EmbeddingQueue { embedding_provider: Arc, pending_batch: Vec, + executor: Arc, pending_batch_token_count: usize, finished_files_tx: channel::Sender, finished_files_rx: channel::Receiver, @@ -49,10 +48,11 @@ pub struct FileToEmbedFragment { } impl EmbeddingQueue { - pub fn new(embedding_provider: Arc) -> Self { + pub fn new(embedding_provider: Arc, executor: Arc) -> Self { let (finished_files_tx, finished_files_rx) = channel::unbounded(); Self { embedding_provider, + executor, pending_batch: Vec::new(), pending_batch_token_count: 0, finished_files_tx, @@ -60,7 +60,12 @@ impl EmbeddingQueue { } } - pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) { + pub fn push(&mut self, file: FileToEmbed) { + if file.documents.is_empty() { + self.finished_files_tx.try_send(file).unwrap(); + return; + } + let file = Arc::new(Mutex::new(file)); self.pending_batch.push(FileToEmbedFragment { @@ -73,7 +78,7 @@ impl EmbeddingQueue { let next_token_count = self.pending_batch_token_count + document.token_count; if next_token_count > self.embedding_provider.max_tokens_per_batch() { let range_end = fragment_range.end; - self.flush(cx); + self.flush(); self.pending_batch.push(FileToEmbedFragment { file: file.clone(), document_range: range_end..range_end, @@ -86,7 +91,7 @@ impl EmbeddingQueue { } } - pub fn flush(&mut self, cx: &mut AppContext) { + pub fn flush(&mut self) { let batch = mem::take(&mut self.pending_batch); self.pending_batch_token_count = 0; if batch.is_empty() { @@ -95,7 +100,7 @@ impl EmbeddingQueue { let finished_files_tx = self.finished_files_tx.clone(); let embedding_provider = self.embedding_provider.clone(); - cx.background().spawn(async move { + self.executor.spawn(async move { let mut spans = Vec::new(); for fragment in &batch { let file = fragment.file.lock(); diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index ab05ca7581efd908cf3fae273befe71221020da9..cde53182dcbe86b9b4e47512686e4560357e44d1 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -1,5 +1,6 @@ mod db; mod embedding; +mod embedding_queue; mod parsing; pub mod semantic_index_settings; @@ -10,6 +11,7 @@ use crate::semantic_index_settings::SemanticIndexSettings; use anyhow::{anyhow, Result}; use db::VectorDatabase; use embedding::{EmbeddingProvider, OpenAIEmbeddings}; +use embedding_queue::{EmbeddingQueue, FileToEmbed}; use futures::{channel::oneshot, Future}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use language::{Anchor, Buffer, Language, LanguageRegistry}; @@ -23,7 +25,6 @@ use smol::channel; use std::{ cmp::Ordering, collections::{BTreeMap, HashMap}, - mem, ops::Range, path::{Path, PathBuf}, sync::{Arc, Weak}, @@ -38,7 +39,6 @@ use util::{ use workspace::WorkspaceCreated; const SEMANTIC_INDEX_VERSION: usize = 7; -const EMBEDDINGS_BATCH_SIZE: usize = 80; const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600); pub fn init( @@ -106,9 +106,8 @@ pub struct SemanticIndex { language_registry: Arc, db_update_tx: channel::Sender, parsing_files_tx: channel::Sender, + _embedding_task: Task<()>, _db_update_task: Task<()>, - _embed_batch_tasks: Vec>, - _batch_files_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, } @@ -128,7 +127,7 @@ struct ChangedPathInfo { } #[derive(Clone)] -struct JobHandle { +pub struct JobHandle { /// The outer Arc is here to count the clones of a JobHandle instance; /// when the last handle to a given job is dropped, we decrement a counter (just once). tx: Arc>>>, @@ -230,17 +229,6 @@ enum DbOperation { }, } -enum EmbeddingJob { - Enqueue { - worktree_id: i64, - path: PathBuf, - mtime: SystemTime, - documents: Vec, - job_handle: JobHandle, - }, - Flush, -} - impl SemanticIndex { pub fn global(cx: &AppContext) -> Option> { if cx.has_global::>() { @@ -287,52 +275,35 @@ impl SemanticIndex { } }); - // Group documents into batches and send them to the embedding provider. - let (embed_batch_tx, embed_batch_rx) = - channel::unbounded::, PathBuf, SystemTime, JobHandle)>>(); - let mut _embed_batch_tasks = Vec::new(); - for _ in 0..cx.background().num_cpus() { - let embed_batch_rx = embed_batch_rx.clone(); - _embed_batch_tasks.push(cx.background().spawn({ - let db_update_tx = db_update_tx.clone(); - let embedding_provider = embedding_provider.clone(); - async move { - while let Ok(embeddings_queue) = embed_batch_rx.recv().await { - Self::compute_embeddings_for_batch( - embeddings_queue, - &embedding_provider, - &db_update_tx, - ) - .await; - } + let embedding_queue = + EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone()); + let _embedding_task = cx.background().spawn({ + let embedded_files = embedding_queue.finished_files(); + let db_update_tx = db_update_tx.clone(); + async move { + while let Ok(file) = embedded_files.recv().await { + db_update_tx + .try_send(DbOperation::InsertFile { + worktree_id: file.worktree_id, + documents: file.documents, + path: file.path, + mtime: file.mtime, + job_handle: file.job_handle, + }) + .ok(); } - })); - } - - // Group documents into batches and send them to the embedding provider. - let (batch_files_tx, batch_files_rx) = channel::unbounded::(); - let _batch_files_task = cx.background().spawn(async move { - let mut queue_len = 0; - let mut embeddings_queue = vec![]; - while let Ok(job) = batch_files_rx.recv().await { - Self::enqueue_documents_to_embed( - job, - &mut queue_len, - &mut embeddings_queue, - &embed_batch_tx, - ); } }); // Parse files into embeddable documents. let (parsing_files_tx, parsing_files_rx) = channel::unbounded::(); + let embedding_queue = Arc::new(Mutex::new(embedding_queue)); let mut _parsing_files_tasks = Vec::new(); for _ in 0..cx.background().num_cpus() { let fs = fs.clone(); let parsing_files_rx = parsing_files_rx.clone(); - let batch_files_tx = batch_files_tx.clone(); - let db_update_tx = db_update_tx.clone(); let embedding_provider = embedding_provider.clone(); + let embedding_queue = embedding_queue.clone(); _parsing_files_tasks.push(cx.background().spawn(async move { let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); while let Ok(pending_file) = parsing_files_rx.recv().await { @@ -340,9 +311,8 @@ impl SemanticIndex { &fs, pending_file, &mut retriever, - &batch_files_tx, + &embedding_queue, &parsing_files_rx, - &db_update_tx, ) .await; } @@ -361,8 +331,7 @@ impl SemanticIndex { db_update_tx, parsing_files_tx, _db_update_task, - _embed_batch_tasks, - _batch_files_task, + _embedding_task, _parsing_files_tasks, projects: HashMap::new(), } @@ -403,136 +372,12 @@ impl SemanticIndex { } } - async fn compute_embeddings_for_batch( - mut embeddings_queue: Vec<(i64, Vec, PathBuf, SystemTime, JobHandle)>, - embedding_provider: &Arc, - db_update_tx: &channel::Sender, - ) { - let mut batch_documents = vec![]; - for (_, documents, _, _, _) in embeddings_queue.iter() { - batch_documents.extend(documents.iter().map(|document| document.content.as_str())); - } - - if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await { - log::trace!( - "created {} embeddings for {} files", - embeddings.len(), - embeddings_queue.len(), - ); - - let mut i = 0; - let mut j = 0; - - for embedding in embeddings.iter() { - while embeddings_queue[i].1.len() == j { - i += 1; - j = 0; - } - - embeddings_queue[i].1[j].embedding = embedding.to_owned(); - j += 1; - } - - for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() { - db_update_tx - .send(DbOperation::InsertFile { - worktree_id, - documents, - path, - mtime, - job_handle, - }) - .await - .unwrap(); - } - } else { - // Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed). - for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() { - db_update_tx - .send(DbOperation::InsertFile { - worktree_id, - documents: vec![], - path, - mtime, - job_handle, - }) - .await - .unwrap(); - } - } - } - - fn enqueue_documents_to_embed( - job: EmbeddingJob, - queue_len: &mut usize, - embeddings_queue: &mut Vec<(i64, Vec, PathBuf, SystemTime, JobHandle)>, - embed_batch_tx: &channel::Sender, PathBuf, SystemTime, JobHandle)>>, - ) { - // Handle edge case where individual file has more documents than max batch size - let should_flush = match job { - EmbeddingJob::Enqueue { - documents, - worktree_id, - path, - mtime, - job_handle, - } => { - // If documents is greater than embeddings batch size, recursively batch existing rows. - if &documents.len() > &EMBEDDINGS_BATCH_SIZE { - let first_job = EmbeddingJob::Enqueue { - documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(), - worktree_id, - path: path.clone(), - mtime, - job_handle: job_handle.clone(), - }; - - Self::enqueue_documents_to_embed( - first_job, - queue_len, - embeddings_queue, - embed_batch_tx, - ); - - let second_job = EmbeddingJob::Enqueue { - documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(), - worktree_id, - path: path.clone(), - mtime, - job_handle: job_handle.clone(), - }; - - Self::enqueue_documents_to_embed( - second_job, - queue_len, - embeddings_queue, - embed_batch_tx, - ); - return; - } else { - *queue_len += &documents.len(); - embeddings_queue.push((worktree_id, documents, path, mtime, job_handle)); - *queue_len >= EMBEDDINGS_BATCH_SIZE - } - } - EmbeddingJob::Flush => true, - }; - - if should_flush { - embed_batch_tx - .try_send(mem::take(embeddings_queue)) - .unwrap(); - *queue_len = 0; - } - } - async fn parse_file( fs: &Arc, pending_file: PendingFile, retriever: &mut CodeContextRetriever, - batch_files_tx: &channel::Sender, + embedding_queue: &Arc>, parsing_files_rx: &channel::Receiver, - db_update_tx: &channel::Sender, ) { let Some(language) = pending_file.language else { return; @@ -549,33 +394,18 @@ impl SemanticIndex { documents.len() ); - if documents.len() == 0 { - db_update_tx - .send(DbOperation::InsertFile { - worktree_id: pending_file.worktree_db_id, - documents, - path: pending_file.relative_path, - mtime: pending_file.modified_time, - job_handle: pending_file.job_handle, - }) - .await - .unwrap(); - } else { - batch_files_tx - .try_send(EmbeddingJob::Enqueue { - worktree_id: pending_file.worktree_db_id, - path: pending_file.relative_path, - mtime: pending_file.modified_time, - job_handle: pending_file.job_handle, - documents, - }) - .unwrap(); - } + embedding_queue.lock().push(FileToEmbed { + worktree_id: pending_file.worktree_db_id, + path: pending_file.relative_path, + mtime: pending_file.modified_time, + job_handle: pending_file.job_handle, + documents, + }); } } if parsing_files_rx.len() == 0 { - batch_files_tx.try_send(EmbeddingJob::Flush).unwrap(); + embedding_queue.lock().flush(); } } @@ -881,7 +711,7 @@ impl SemanticIndex { let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?; let phrase_embedding = embedding_provider - .embed_batch(vec![&phrase]) + .embed_batch(vec![phrase]) .await? .into_iter() .next() diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 71789871653491fbd11f80ee41d4ae4c56cc28d7..dc41c09f7a939b60c27f612d8f51b8f34ce13650 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -235,17 +235,15 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { .collect::>(); let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut queue = EmbeddingQueue::new(embedding_provider.clone()); - let finished_files = cx.update(|cx| { - for file in &files { - queue.push(file.clone(), cx); - } - queue.flush(cx); - queue.finished_files() - }); + let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background()); + for file in &files { + queue.push(file.clone()); + } + queue.flush(); cx.foreground().run_until_parked(); + let finished_files = queue.finished_files(); let mut embedded_files: Vec<_> = files .iter() .map(|_| finished_files.try_recv().expect("no finished file"))