From 82079dd422613b98c8b1c6edfedaac1187ab2536 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Mon, 10 Jul 2023 16:33:14 -0400
Subject: [PATCH] Updated batching to accommodate full flushes, and cleaned up
 reindexing.

Co-authored-by: maxbrunsfeld
---
 crates/vector_store/src/embedding.rs    |   4 +-
 crates/vector_store/src/vector_store.rs | 300 ++++++++++++------------
 2 files changed, 150 insertions(+), 154 deletions(-)

diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs
index 029a6cdf613aca7535d8591935daac74803b0af2..ea349c8afa4a8d908d60760f8ff1eb6839e3120b 100644
--- a/crates/vector_store/src/embedding.rs
+++ b/crates/vector_store/src/embedding.rs
@@ -1,6 +1,7 @@
 use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use futures::AsyncReadExt;
+use gpui::executor::Background;
 use gpui::serde_json;
 use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
@@ -21,6 +22,7 @@ lazy_static! {
 #[derive(Clone)]
 pub struct OpenAIEmbeddings {
     pub client: Arc<dyn HttpClient>,
+    pub executor: Arc<Background>,
 }
 
 #[derive(Serialize)]
@@ -128,7 +130,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         match response.status() {
             StatusCode::TOO_MANY_REQUESTS => {
                 let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
-                std::thread::sleep(delay);
+                self.executor.timer(delay).await;
             }
             StatusCode::BAD_REQUEST => {
                 log::info!("BAD REQUEST: {:?}", &response.status());
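Why the embedding.rs change matters: awaiting `self.executor.timer(delay)` suspends only this task during the 429 backoff, whereas `std::thread::sleep` parked an entire shared background thread for the full delay. Below is a minimal standalone sketch of the same retry-with-cooperative-sleep pattern; it uses `smol::Timer` as a stand-in for gpui's executor timer, and the helper name and backoff schedule are illustrative, not the crate's.

```rust
use std::time::Duration;

// Illustrative backoff schedule; the real BACKOFF_SECONDS lives in embedding.rs.
const BACKOFF_SECONDS: [u64; 3] = [1, 3, 5];

// Retry a fallible call, sleeping cooperatively between attempts. Awaiting a
// timer hands the thread back to the executor for the duration of the delay.
async fn with_backoff<T, E>(mut attempt: impl FnMut() -> Result<T, E>) -> Result<T, E> {
    let mut result = attempt();
    for secs in BACKOFF_SECONDS {
        if result.is_ok() {
            break;
        }
        smol::Timer::after(Duration::from_secs(secs)).await;
        result = attempt();
    }
    result
}

fn main() {
    let mut calls = 0;
    let outcome = smol::block_on(with_backoff(|| {
        calls += 1;
        if calls < 3 {
            Err("429 Too Many Requests")
        } else {
            Ok("embedding")
        }
    }));
    assert_eq!(outcome, Ok("embedding"));
}
```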
diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs
index 92557fd80118550ccb52c74d889b4be47ad9ccda..c27c4992f33a72d07c5b7e082d102091617bbf99 100644
--- a/crates/vector_store/src/vector_store.rs
+++ b/crates/vector_store/src/vector_store.rs
@@ -17,14 +17,12 @@ use gpui::{
 use language::{Language, LanguageRegistry};
 use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
 use parsing::{CodeContextRetriever, ParsedFile};
-use project::{Fs, Project, WorktreeId};
+use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
 use smol::channel;
 use std::{
-    cell::RefCell,
     cmp::Ordering,
     collections::HashMap,
     path::{Path, PathBuf},
-    rc::Rc,
     sync::Arc,
     time::{Duration, Instant, SystemTime},
 };
@@ -61,6 +59,7 @@ pub fn init(
         //     Arc::new(embedding::DummyEmbeddings {}),
         Arc::new(OpenAIEmbeddings {
             client: http_client,
+            executor: cx.background(),
         }),
         language_registry,
         cx.clone(),
@@ -119,7 +118,7 @@ pub struct VectorStore {
     _embed_batch_task: Vec<Task<()>>,
     _batch_files_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
-    projects: HashMap<WeakModelHandle<Project>, Rc<RefCell<ProjectState>>>,
+    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 }
 
 struct ProjectState {
@@ -201,6 +200,15 @@ enum DbWrite {
     },
 }
 
+enum EmbeddingJob {
+    Enqueue {
+        worktree_id: i64,
+        parsed_file: ParsedFile,
+        document_spans: Vec<String>,
+    },
+    Flush,
+}
+
 impl VectorStore {
     async fn new(
         fs: Arc<dyn Fs>,
@@ -309,29 +317,32 @@ impl VectorStore {
                 }
             }))
         }
 
-        // batch_tx/rx: Batch Files to Send for Embeddings
-        let (batch_files_tx, batch_files_rx) =
-            channel::unbounded::<(i64, ParsedFile, Vec<String>)>();
+        let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
         let _batch_files_task = cx.background().spawn(async move {
             let mut queue_len = 0;
             let mut embeddings_queue = vec![];
-            while let Ok((worktree_id, indexed_file, document_spans)) =
-                batch_files_rx.recv().await
-            {
-                dbg!("Batching in while loop");
-                queue_len += &document_spans.len();
-                embeddings_queue.push((worktree_id, indexed_file, document_spans));
-                if queue_len >= EMBEDDINGS_BATCH_SIZE {
+
+            while let Ok(job) = batch_files_rx.recv().await {
+                let should_flush = match job {
+                    EmbeddingJob::Enqueue {
+                        document_spans,
+                        worktree_id,
+                        parsed_file,
+                    } => {
+                        queue_len += &document_spans.len();
+                        embeddings_queue.push((worktree_id, parsed_file, document_spans));
+                        queue_len >= EMBEDDINGS_BATCH_SIZE
+                    }
+                    EmbeddingJob::Flush => true,
+                };
+
+                if should_flush {
                     embed_batch_tx.try_send(embeddings_queue).unwrap();
                     embeddings_queue = vec![];
                     queue_len = 0;
                 }
             }
-            // TODO: This is never getting called, We've gotta manage for how to clear the embedding batch if its less than the necessary batch size.
-            if queue_len > 0 {
-                embed_batch_tx.try_send(embeddings_queue).unwrap();
-            }
         });
 
         // parsing_files_tx/rx: Parsing Files to Embeddable Documents
@@ -353,13 +364,17 @@ impl VectorStore {
                         retriever.parse_file(pending_file.clone()).await.log_err()
                     {
                         batch_files_tx
-                            .try_send((
-                                pending_file.worktree_db_id,
-                                indexed_file,
+                            .try_send(EmbeddingJob::Enqueue {
+                                worktree_id: pending_file.worktree_db_id,
+                                parsed_file: indexed_file,
                                 document_spans,
-                            ))
+                            })
                             .unwrap();
                     }
+
+                    if parsing_files_rx.len() == 0 {
+                        batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
+                    }
                 }
             }));
         }
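The two hunks above are the heart of the commit: the producer sends an explicit `EmbeddingJob::Flush` whenever `parsing_files_rx` drains, so a final partial batch gets embedded instead of waiting forever for `EMBEDDINGS_BATCH_SIZE` spans. The old post-loop flush was unreachable because the channel only closes on shutdown. Here is a runnable reduction of that enqueue-or-flush pattern; `Job` and `BATCH_SIZE` are illustrative names, not the crate's.

```rust
use smol::channel;

// Illustrative batch size; the crate uses EMBEDDINGS_BATCH_SIZE.
const BATCH_SIZE: usize = 3;

enum Job {
    Enqueue(Vec<String>),
    Flush,
}

fn main() {
    let (job_tx, job_rx) = channel::unbounded::<Job>();
    let (batch_tx, batch_rx) = channel::unbounded::<Vec<String>>();

    smol::block_on(async move {
        smol::spawn(async move {
            let mut queue: Vec<String> = Vec::new();
            while let Ok(job) = job_rx.recv().await {
                let should_flush = match job {
                    Job::Enqueue(spans) => {
                        queue.extend(spans);
                        queue.len() >= BATCH_SIZE
                    }
                    // An explicit Flush drains a partial batch; a flush
                    // placed after the loop would never run because the
                    // channel stays open for the program's lifetime.
                    Job::Flush => true,
                };
                if should_flush && !queue.is_empty() {
                    batch_tx.try_send(std::mem::take(&mut queue)).unwrap();
                }
            }
        })
        .detach();

        job_tx
            .send(Job::Enqueue(vec!["span a".into(), "span b".into()]))
            .await
            .unwrap();
        job_tx.send(Job::Flush).await.unwrap();

        // The partial batch (2 < BATCH_SIZE) arrives thanks to the flush.
        assert_eq!(batch_rx.recv().await.unwrap().len(), 2);
    });
}
```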
@@ -526,143 +541,18 @@ impl VectorStore {
         // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
         // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
         let _subscription = cx.subscribe(&project, |this, project, event, cx| {
-            if let Some(project_state) = this.projects.get(&project.downgrade()) {
-                let mut project_state = project_state.borrow_mut();
-                let worktree_db_ids = project_state.worktree_db_ids.clone();
-
-                if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
-                {
-                    // Get Worktree Object
-                    let worktree =
-                        project.read(cx).worktree_for_id(worktree_id.clone(), cx);
-                    if worktree.is_none() {
-                        return;
-                    }
-                    let worktree = worktree.unwrap();
-
-                    // Get Database
-                    let db_values = {
-                        if let Ok(db) =
-                            VectorDatabase::new(this.database_url.to_string_lossy().into())
-                        {
-                            let worktree_db_id: Option<i64> = {
-                                let mut found_db_id = None;
-                                for (w_id, db_id) in worktree_db_ids.into_iter() {
-                                    if &w_id == &worktree.read(cx).id() {
-                                        found_db_id = Some(db_id)
-                                    }
-                                }
-                                found_db_id
-                            };
-                            if worktree_db_id.is_none() {
-                                return;
-                            }
-                            let worktree_db_id = worktree_db_id.unwrap();
-
-                            let file_mtimes = db.get_file_mtimes(worktree_db_id);
-                            if file_mtimes.is_err() {
-                                return;
-                            }
-
-                            let file_mtimes = file_mtimes.unwrap();
-                            Some((file_mtimes, worktree_db_id))
-                        } else {
-                            return;
-                        }
-                    };
-
-                    if db_values.is_none() {
-                        return;
-                    }
-
-                    let (file_mtimes, worktree_db_id) = db_values.unwrap();
-
-                    // Iterate Through Changes
-                    let language_registry = this.language_registry.clone();
-                    let parsing_files_tx = this.parsing_files_tx.clone();
-
-                    smol::block_on(async move {
-                        for change in changes.into_iter() {
-                            let change_path = change.0.clone();
-                            let absolute_path = worktree.read(cx).absolutize(&change_path);
-                            // Skip if git ignored or symlink
-                            if let Some(entry) = worktree.read(cx).entry_for_id(change.1) {
-                                if entry.is_ignored || entry.is_symlink {
-                                    continue;
-                                } else {
-                                    log::info!(
-                                        "Testing for Reindexing: {:?}",
-                                        &change_path
-                                    );
-                                }
-                            };
-
-                            if let Ok(language) = language_registry
-                                .language_for_file(&change_path.to_path_buf(), None)
-                                .await
-                            {
-                                if language
-                                    .grammar()
-                                    .and_then(|grammar| grammar.embedding_config.as_ref())
-                                    .is_none()
-                                {
-                                    continue;
-                                }
-
-                                if let Some(modified_time) = {
-                                    let metadata = change_path.metadata();
-                                    if metadata.is_err() {
-                                        None
-                                    } else {
-                                        let mtime = metadata.unwrap().modified();
-                                        if mtime.is_err() {
-                                            None
-                                        } else {
-                                            Some(mtime.unwrap())
-                                        }
-                                    }
-                                } {
-                                    let existing_time =
-                                        file_mtimes.get(&change_path.to_path_buf());
-                                    let already_stored = existing_time
-                                        .map_or(false, |existing_time| {
-                                            &modified_time != existing_time
-                                        });
-
-                                    let reindex_time = modified_time
-                                        + Duration::from_secs(REINDEXING_DELAY_SECONDS);
-
-                                    if !already_stored {
-                                        project_state.update_pending_files(
-                                            PendingFile {
-                                                relative_path: change_path.to_path_buf(),
-                                                absolute_path,
-                                                modified_time,
-                                                worktree_db_id,
-                                                language: language.clone(),
-                                            },
-                                            reindex_time,
-                                        );
-
-                                        for file in project_state.get_outstanding_files() {
-                                            parsing_files_tx.try_send(file).unwrap();
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                    });
-                };
+            if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
+                this.project_entries_changed(project, changes, cx, worktree_id);
             }
         });
 
         this.projects.insert(
             project.downgrade(),
-            Rc::new(RefCell::new(ProjectState {
+            ProjectState {
                 pending_files: HashMap::new(),
                 worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
                 _subscription,
-            })),
+            },
         );
     });
@@ -678,7 +568,7 @@ impl VectorStore {
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<SearchResult>>> {
         let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
-            state.borrow()
+            state
         } else {
             return Task::ready(Err(anyhow!("project not added")));
         };
@@ -736,7 +626,7 @@ impl VectorStore {
             this.read_with(&cx, |this, _| {
                 let project_state =
                     if let Some(state) = this.projects.get(&project.downgrade()) {
-                        state.borrow()
+                        state
                     } else {
                         return Err(anyhow!("project not added"));
                     };
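The cleanup above drops `Rc<RefCell<ProjectState>>` in favor of storing `ProjectState` directly in the map: readers borrow through `get`, and the one mutating path goes through `get_mut`, trading runtime borrow panics for compile-time checks. A toy version of that ownership shape, with illustrative types and names rather than the crate's:

```rust
use std::collections::HashMap;

struct ProjectState {
    pending_files: usize,
}

// With the state owned by the map, the single mutating path borrows it via
// `get_mut`; no interior mutability, and `?` handles the missing-key case.
fn mark_pending(projects: &mut HashMap<u64, ProjectState>, id: u64) -> Option<()> {
    let state = projects.get_mut(&id)?;
    state.pending_files += 1;
    Some(())
}

fn main() {
    let mut projects = HashMap::new();
    projects.insert(7, ProjectState { pending_files: 0 });
    mark_pending(&mut projects, 7);
    assert_eq!(projects[&7].pending_files, 1);
}
```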
@@ -766,6 +656,110 @@ impl VectorStore {
             })
         })
     }
+
+    fn project_entries_changed(
+        &mut self,
+        project: ModelHandle<Project>,
+        changes: &[(Arc<Path>, ProjectEntryId, PathChange)],
+        cx: &mut ModelContext<'_, VectorStore>,
+        worktree_id: &WorktreeId,
+    ) -> Option<()> {
+        let project_state = self.projects.get_mut(&project.downgrade())?;
+        let worktree_db_ids = project_state.worktree_db_ids.clone();
+        let worktree = project.read(cx).worktree_for_id(worktree_id.clone(), cx)?;
+
+        // Get Database
+        let (file_mtimes, worktree_db_id) = {
+            if let Ok(db) = VectorDatabase::new(self.database_url.to_string_lossy().into()) {
+                let worktree_db_id = {
+                    let mut found_db_id = None;
+                    for (w_id, db_id) in worktree_db_ids.into_iter() {
+                        if &w_id == &worktree.read(cx).id() {
+                            found_db_id = Some(db_id)
+                        }
+                    }
+                    found_db_id
+                }?;
+
+                let file_mtimes = db.get_file_mtimes(worktree_db_id).log_err()?;
+
+                Some((file_mtimes, worktree_db_id))
+            } else {
+                return None;
+            }
+        }?;
+
+        // Iterate Through Changes
+        let language_registry = self.language_registry.clone();
+        let parsing_files_tx = self.parsing_files_tx.clone();
+
+        smol::block_on(async move {
+            for change in changes.into_iter() {
+                let change_path = change.0.clone();
+                let absolute_path = worktree.read(cx).absolutize(&change_path);
+                // Skip if git ignored or symlink
+                if let Some(entry) = worktree.read(cx).entry_for_id(change.1) {
+                    if entry.is_ignored || entry.is_symlink {
+                        continue;
+                    } else {
+                        log::info!("Testing for Reindexing: {:?}", &change_path);
+                    }
+                };
+
+                if let Ok(language) = language_registry
+                    .language_for_file(&change_path.to_path_buf(), None)
+                    .await
+                {
+                    if language
+                        .grammar()
+                        .and_then(|grammar| grammar.embedding_config.as_ref())
+                        .is_none()
+                    {
+                        continue;
+                    }
+
+                    if let Some(modified_time) = {
+                        let metadata = change_path.metadata();
+                        if metadata.is_err() {
+                            None
+                        } else {
+                            let mtime = metadata.unwrap().modified();
+                            if mtime.is_err() {
+                                None
+                            } else {
+                                Some(mtime.unwrap())
+                            }
+                        }
+                    } {
+                        let existing_time = file_mtimes.get(&change_path.to_path_buf());
+                        let already_stored = existing_time
+                            .map_or(false, |existing_time| &modified_time != existing_time);
+
+                        let reindex_time =
+                            modified_time + Duration::from_secs(REINDEXING_DELAY_SECONDS);
+
+                        if !already_stored {
+                            project_state.update_pending_files(
+                                PendingFile {
+                                    relative_path: change_path.to_path_buf(),
+                                    absolute_path,
+                                    modified_time,
+                                    worktree_db_id,
+                                    language: language.clone(),
+                                },
+                                reindex_time,
+                            );
+
+                            for file in project_state.get_outstanding_files() {
+                                parsing_files_tx.try_send(file).unwrap();
+                            }
+                        }
+                    }
+                }
+            }
+        });
+        Some(())
+    }
 }
 
 impl Entity for VectorStore {
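The extracted `project_entries_changed` returns `Option<()>` so that each fallible lookup can bail out with `?`, replacing the `is_none()`/`return` ladders the old subscription body used. The same shape in miniature, with hypothetical names:

```rust
use std::collections::HashMap;

// Returning Option<()> lets every fallible step exit early with `?` instead
// of an explicit `if value.is_none() { return; }` check.
fn report_mtime(mtimes: &HashMap<&str, u64>, path: &str) -> Option<()> {
    let mtime = mtimes.get(path)?; // early return on a missing entry
    println!("{path} last indexed at {mtime}");
    Some(())
}

fn main() {
    let mut mtimes = HashMap::new();
    mtimes.insert("src/vector_store.rs", 1_688_000_000_u64);
    assert!(report_mtime(&mtimes, "src/vector_store.rs").is_some());
    assert!(report_mtime(&mtimes, "src/missing.rs").is_none());
}
```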