@@ -42,6 +42,7 @@ pub struct EmbeddingQueue {
finished_files_rx: channel::Receiver<FileToEmbed>,
}
+#[derive(Clone)]
pub struct FileToEmbedFragment {
file: Arc<Mutex<FileToEmbed>>,
document_range: Range<usize>,
@@ -74,8 +75,16 @@ impl EmbeddingQueue {
});
let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
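+ // Running total of tokens skipped because their documents already have embeddings (logged below).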
+ let mut saved_tokens = 0;
for (ix, document) in file.lock().documents.iter().enumerate() {
- let next_token_count = self.pending_batch_token_count + document.token_count;
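+ // Documents that already carry an embedding add nothing to the pending batch's token budget.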
+ let document_token_count = if document.embedding.is_none() {
+ document.token_count
+ } else {
+ saved_tokens += document.token_count;
+ 0
+ };
+
+ let next_token_count = self.pending_batch_token_count + document_token_count;
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
let range_end = fragment_range.end;
self.flush();
@@ -87,8 +96,9 @@ impl EmbeddingQueue {
}
fragment_range.end = ix + 1;
- self.pending_batch_token_count += document.token_count;
+ self.pending_batch_token_count += document_token_count;
}
+ log::trace!("Saved Tokens: {:?}", saved_tokens);
}
pub fn flush(&mut self) {
@@ -100,25 +110,41 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
+
self.executor.spawn(async move {
let mut spans = Vec::new();
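+ // Total documents in the batch, logged against the number of spans actually sent to the provider.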
+ let mut document_count = 0;
for fragment in &batch {
let file = fragment.file.lock();
+ document_count += file.documents[fragment.document_range.clone()].len();
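+ // Only collect content for documents that still need an embedding.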
spans.extend(
{
file.documents[fragment.document_range.clone()]
- .iter()
+ .iter().filter(|d| d.embedding.is_none())
.map(|d| d.content.clone())
}
);
}
+ log::trace!("Documents Length: {:?}", document_count);
+ log::trace!("Span Length: {:?}", spans.clone().len());
+
+ // If nothing is left to embed, skip the provider call and forward each file
+ // whose last fragment this is to the finished files channel.
+ if spans.is_empty() {
+ for fragment in batch {
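+ // Arc::into_inner only succeeds for the last outstanding fragment of a file, so each file is forwarded exactly once.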
+ if let Some(file) = Arc::into_inner(fragment.file) {
+ finished_files_tx.try_send(file.into_inner()).unwrap();
+ }
+ }
+ return;
+ }
+
match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
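+ // The returned embeddings correspond, in order, to the documents that were missing embeddings above.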
for fragment in batch {
for document in
- &mut fragment.file.lock().documents[fragment.document_range.clone()]
+ &mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
{
if let Some(embedding) = embeddings.next() {
document.embedding = Some(embedding);
@@ -255,6 +255,7 @@ impl SemanticIndex {
let parsing_files_rx = parsing_files_rx.clone();
let embedding_provider = embedding_provider.clone();
let embedding_queue = embedding_queue.clone();
+ let db = db.clone();
_parsing_files_tasks.push(cx.background().spawn(async move {
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
while let Ok(pending_file) = parsing_files_rx.recv().await {
@@ -264,6 +265,7 @@ impl SemanticIndex {
&mut retriever,
&embedding_queue,
&parsing_files_rx,
+ &db,
)
.await;
}
@@ -293,13 +295,14 @@ impl SemanticIndex {
retriever: &mut CodeContextRetriever,
embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
parsing_files_rx: &channel::Receiver<PendingFile>,
+ db: &VectorDatabase,
) {
let Some(language) = pending_file.language else {
return;
};
if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
- if let Some(documents) = retriever
+ if let Some(mut documents) = retriever
.parse_file_with_template(&pending_file.relative_path, &content, language)
.log_err()
{
@@ -309,22 +312,20 @@ impl SemanticIndex {
documents.len()
);
- todo!();
- // if let Some(embeddings) = db
- // .embeddings_for_documents(
- // pending_file.worktree_db_id,
- // pending_file.relative_path,
- // &documents,
- // )
- // .await
- // .log_err()
- // {
- // for (document, embedding) in documents.iter_mut().zip(embeddings) {
- // if let Some(embedding) = embedding {
- // document.embedding = embedding;
- // }
- // }
- // }
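+ // Look up embeddings already stored for this file, keyed by each document's content digest, so unchanged documents are not re-embedded.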
+ if let Some(sha_to_embeddings) = db
+ .embeddings_for_file(
+ pending_file.worktree_db_id,
+ pending_file.relative_path.clone(),
+ )
+ .await
+ .log_err()
+ {
+ for document in documents.iter_mut() {
+ if let Some(embedding) = sha_to_embeddings.get(&document.digest) {
+ document.embedding = Some(embedding.to_owned());
+ }
+ }
+ }
embedding_queue.lock().push(FileToEmbed {
worktree_id: pending_file.worktree_db_id,