@@ -1,10 +1,8 @@
-use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
-
-use gpui::AppContext;
+use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
+use gpui::executor::Background;
use parking_lot::Mutex;
use smol::channel;
-
-use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
+use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
#[derive(Clone)]
pub struct FileToEmbed {
@@ -38,6 +36,7 @@ impl PartialEq for FileToEmbed {
pub struct EmbeddingQueue {
embedding_provider: Arc<dyn EmbeddingProvider>,
pending_batch: Vec<FileToEmbedFragment>,
+ executor: Arc<Background>,
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
@@ -49,10 +48,11 @@ pub struct FileToEmbedFragment {
}
impl EmbeddingQueue {
- pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
+ pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
embedding_provider,
+ executor,
pending_batch: Vec::new(),
pending_batch_token_count: 0,
finished_files_tx,
@@ -60,7 +60,12 @@ impl EmbeddingQueue {
}
}
- pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) {
+ pub fn push(&mut self, file: FileToEmbed) {
+ if file.documents.is_empty() {
+ self.finished_files_tx.try_send(file).unwrap();
+ return;
+ }
+
let file = Arc::new(Mutex::new(file));
self.pending_batch.push(FileToEmbedFragment {
@@ -73,7 +78,7 @@ impl EmbeddingQueue {
let next_token_count = self.pending_batch_token_count + document.token_count;
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
let range_end = fragment_range.end;
- self.flush(cx);
+ self.flush();
self.pending_batch.push(FileToEmbedFragment {
file: file.clone(),
document_range: range_end..range_end,
@@ -86,7 +91,7 @@ impl EmbeddingQueue {
}
}
- pub fn flush(&mut self, cx: &mut AppContext) {
+ pub fn flush(&mut self) {
let batch = mem::take(&mut self.pending_batch);
self.pending_batch_token_count = 0;
if batch.is_empty() {
@@ -95,7 +100,7 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
- cx.background().spawn(async move {
+ self.executor.spawn(async move {
let mut spans = Vec::new();
for fragment in &batch {
let file = fragment.file.lock();
@@ -1,5 +1,6 @@
mod db;
mod embedding;
+mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;
@@ -10,6 +11,7 @@ use crate::semantic_index_settings::SemanticIndexSettings;
use anyhow::{anyhow, Result};
use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
+use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{channel::oneshot, Future};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Anchor, Buffer, Language, LanguageRegistry};
@@ -23,7 +25,6 @@ use smol::channel;
use std::{
cmp::Ordering,
collections::{BTreeMap, HashMap},
- mem,
ops::Range,
path::{Path, PathBuf},
sync::{Arc, Weak},
@@ -38,7 +39,6 @@ use util::{
use workspace::WorkspaceCreated;
const SEMANTIC_INDEX_VERSION: usize = 7;
-const EMBEDDINGS_BATCH_SIZE: usize = 80;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600);
pub fn init(
@@ -106,9 +106,8 @@ pub struct SemanticIndex {
language_registry: Arc<LanguageRegistry>,
db_update_tx: channel::Sender<DbOperation>,
parsing_files_tx: channel::Sender<PendingFile>,
+ _embedding_task: Task<()>,
_db_update_task: Task<()>,
- _embed_batch_tasks: Vec<Task<()>>,
- _batch_files_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
@@ -128,7 +127,7 @@ struct ChangedPathInfo {
}
#[derive(Clone)]
-struct JobHandle {
+pub struct JobHandle {
/// The outer Arc is here to count the clones of a JobHandle instance;
/// when the last handle to a given job is dropped, we decrement a counter (just once).
tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
@@ -230,17 +229,6 @@ enum DbOperation {
},
}
-enum EmbeddingJob {
- Enqueue {
- worktree_id: i64,
- path: PathBuf,
- mtime: SystemTime,
- documents: Vec<Document>,
- job_handle: JobHandle,
- },
- Flush,
-}
-
impl SemanticIndex {
pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
if cx.has_global::<ModelHandle<Self>>() {
@@ -287,52 +275,35 @@ impl SemanticIndex {
}
});
- // Group documents into batches and send them to the embedding provider.
- let (embed_batch_tx, embed_batch_rx) =
- channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
- let mut _embed_batch_tasks = Vec::new();
- for _ in 0..cx.background().num_cpus() {
- let embed_batch_rx = embed_batch_rx.clone();
- _embed_batch_tasks.push(cx.background().spawn({
- let db_update_tx = db_update_tx.clone();
- let embedding_provider = embedding_provider.clone();
- async move {
- while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
- Self::compute_embeddings_for_batch(
- embeddings_queue,
- &embedding_provider,
- &db_update_tx,
- )
- .await;
- }
+ let embedding_queue =
+ EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
+ let _embedding_task = cx.background().spawn({
+ let embedded_files = embedding_queue.finished_files();
+ let db_update_tx = db_update_tx.clone();
+ async move {
+ while let Ok(file) = embedded_files.recv().await {
+ db_update_tx
+ .try_send(DbOperation::InsertFile {
+ worktree_id: file.worktree_id,
+ documents: file.documents,
+ path: file.path,
+ mtime: file.mtime,
+ job_handle: file.job_handle,
+ })
+ .ok();
}
- }));
- }
-
- // Group documents into batches and send them to the embedding provider.
- let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
- let _batch_files_task = cx.background().spawn(async move {
- let mut queue_len = 0;
- let mut embeddings_queue = vec![];
- while let Ok(job) = batch_files_rx.recv().await {
- Self::enqueue_documents_to_embed(
- job,
- &mut queue_len,
- &mut embeddings_queue,
- &embed_batch_tx,
- );
}
});
// Parse files into embeddable documents.
let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
+ let embedding_queue = Arc::new(Mutex::new(embedding_queue));
let mut _parsing_files_tasks = Vec::new();
for _ in 0..cx.background().num_cpus() {
let fs = fs.clone();
let parsing_files_rx = parsing_files_rx.clone();
- let batch_files_tx = batch_files_tx.clone();
- let db_update_tx = db_update_tx.clone();
let embedding_provider = embedding_provider.clone();
+ let embedding_queue = embedding_queue.clone();
_parsing_files_tasks.push(cx.background().spawn(async move {
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
while let Ok(pending_file) = parsing_files_rx.recv().await {
@@ -340,9 +311,8 @@ impl SemanticIndex {
&fs,
pending_file,
&mut retriever,
- &batch_files_tx,
+ &embedding_queue,
&parsing_files_rx,
- &db_update_tx,
)
.await;
}
@@ -361,8 +331,7 @@ impl SemanticIndex {
db_update_tx,
parsing_files_tx,
_db_update_task,
- _embed_batch_tasks,
- _batch_files_task,
+ _embedding_task,
_parsing_files_tasks,
projects: HashMap::new(),
}
@@ -403,136 +372,12 @@ impl SemanticIndex {
}
}
- async fn compute_embeddings_for_batch(
- mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
- embedding_provider: &Arc<dyn EmbeddingProvider>,
- db_update_tx: &channel::Sender<DbOperation>,
- ) {
- let mut batch_documents = vec![];
- for (_, documents, _, _, _) in embeddings_queue.iter() {
- batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
- }
-
- if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
- log::trace!(
- "created {} embeddings for {} files",
- embeddings.len(),
- embeddings_queue.len(),
- );
-
- let mut i = 0;
- let mut j = 0;
-
- for embedding in embeddings.iter() {
- while embeddings_queue[i].1.len() == j {
- i += 1;
- j = 0;
- }
-
- embeddings_queue[i].1[j].embedding = embedding.to_owned();
- j += 1;
- }
-
- for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
- db_update_tx
- .send(DbOperation::InsertFile {
- worktree_id,
- documents,
- path,
- mtime,
- job_handle,
- })
- .await
- .unwrap();
- }
- } else {
- // Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed).
- for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() {
- db_update_tx
- .send(DbOperation::InsertFile {
- worktree_id,
- documents: vec![],
- path,
- mtime,
- job_handle,
- })
- .await
- .unwrap();
- }
- }
- }
-
- fn enqueue_documents_to_embed(
- job: EmbeddingJob,
- queue_len: &mut usize,
- embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
- embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
- ) {
- // Handle edge case where individual file has more documents than max batch size
- let should_flush = match job {
- EmbeddingJob::Enqueue {
- documents,
- worktree_id,
- path,
- mtime,
- job_handle,
- } => {
- // If documents is greater than embeddings batch size, recursively batch existing rows.
- if &documents.len() > &EMBEDDINGS_BATCH_SIZE {
- let first_job = EmbeddingJob::Enqueue {
- documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(),
- worktree_id,
- path: path.clone(),
- mtime,
- job_handle: job_handle.clone(),
- };
-
- Self::enqueue_documents_to_embed(
- first_job,
- queue_len,
- embeddings_queue,
- embed_batch_tx,
- );
-
- let second_job = EmbeddingJob::Enqueue {
- documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(),
- worktree_id,
- path: path.clone(),
- mtime,
- job_handle: job_handle.clone(),
- };
-
- Self::enqueue_documents_to_embed(
- second_job,
- queue_len,
- embeddings_queue,
- embed_batch_tx,
- );
- return;
- } else {
- *queue_len += &documents.len();
- embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
- *queue_len >= EMBEDDINGS_BATCH_SIZE
- }
- }
- EmbeddingJob::Flush => true,
- };
-
- if should_flush {
- embed_batch_tx
- .try_send(mem::take(embeddings_queue))
- .unwrap();
- *queue_len = 0;
- }
- }
-
async fn parse_file(
fs: &Arc<dyn Fs>,
pending_file: PendingFile,
retriever: &mut CodeContextRetriever,
- batch_files_tx: &channel::Sender<EmbeddingJob>,
+ embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
parsing_files_rx: &channel::Receiver<PendingFile>,
- db_update_tx: &channel::Sender<DbOperation>,
) {
let Some(language) = pending_file.language else {
return;
@@ -549,33 +394,18 @@ impl SemanticIndex {
documents.len()
);
- if documents.len() == 0 {
- db_update_tx
- .send(DbOperation::InsertFile {
- worktree_id: pending_file.worktree_db_id,
- documents,
- path: pending_file.relative_path,
- mtime: pending_file.modified_time,
- job_handle: pending_file.job_handle,
- })
- .await
- .unwrap();
- } else {
- batch_files_tx
- .try_send(EmbeddingJob::Enqueue {
- worktree_id: pending_file.worktree_db_id,
- path: pending_file.relative_path,
- mtime: pending_file.modified_time,
- job_handle: pending_file.job_handle,
- documents,
- })
- .unwrap();
- }
+ embedding_queue.lock().push(FileToEmbed {
+ worktree_id: pending_file.worktree_db_id,
+ path: pending_file.relative_path,
+ mtime: pending_file.modified_time,
+ job_handle: pending_file.job_handle,
+ documents,
+ });
}
}
if parsing_files_rx.len() == 0 {
- batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
+ embedding_queue.lock().flush();
}
}
@@ -881,7 +711,7 @@ impl SemanticIndex {
let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
let phrase_embedding = embedding_provider
- .embed_batch(vec![&phrase])
+ .embed_batch(vec![phrase])
.await?
.into_iter()
.next()