Updated batching to accommodate full flushes, and cleaned up reindexing.

Created by KCaverly and maxbrunsfeld

Co-authored-by: maxbrunsfeld <max@zed.dev>

Change summary

crates/vector_store/src/embedding.rs    |   4 
crates/vector_store/src/vector_store.rs | 300 +++++++++++++-------------
2 files changed, 150 insertions(+), 154 deletions(-)

Detailed changes

crates/vector_store/src/embedding.rs 🔗

@@ -1,6 +1,7 @@
 use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use futures::AsyncReadExt;
+use gpui::executor::Background;
 use gpui::serde_json;
 use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
@@ -21,6 +22,7 @@ lazy_static! {
 #[derive(Clone)]
 pub struct OpenAIEmbeddings {
     pub client: Arc<dyn HttpClient>,
+    pub executor: Arc<Background>,
 }
 
 #[derive(Serialize)]
@@ -128,7 +130,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
             match response.status() {
                 StatusCode::TOO_MANY_REQUESTS => {
                     let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
-                    std::thread::sleep(delay);
+                    self.executor.timer(delay).await;
                 }
                 StatusCode::BAD_REQUEST => {
                     log::info!("BAD REQUEST: {:?}", &response.status());

crates/vector_store/src/vector_store.rs 🔗

@@ -17,14 +17,12 @@ use gpui::{
 use language::{Language, LanguageRegistry};
 use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
 use parsing::{CodeContextRetriever, ParsedFile};
-use project::{Fs, Project, WorktreeId};
+use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
 use smol::channel;
 use std::{
-    cell::RefCell,
     cmp::Ordering,
     collections::HashMap,
     path::{Path, PathBuf},
-    rc::Rc,
     sync::Arc,
     time::{Duration, Instant, SystemTime},
 };
@@ -61,6 +59,7 @@ pub fn init(
             // Arc::new(embedding::DummyEmbeddings {}),
             Arc::new(OpenAIEmbeddings {
                 client: http_client,
+                executor: cx.background(),
             }),
             language_registry,
             cx.clone(),
@@ -119,7 +118,7 @@ pub struct VectorStore {
     _embed_batch_task: Vec<Task<()>>,
     _batch_files_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
-    projects: HashMap<WeakModelHandle<Project>, Rc<RefCell<ProjectState>>>,
+    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 }
 
 struct ProjectState {
@@ -201,6 +200,15 @@ enum DbWrite {
     },
 }
 
+enum EmbeddingJob {
+    Enqueue {
+        worktree_id: i64,
+        parsed_file: ParsedFile,
+        document_spans: Vec<String>,
+    },
+    Flush,
+}
+
 impl VectorStore {
     async fn new(
         fs: Arc<dyn Fs>,
@@ -309,29 +317,32 @@ impl VectorStore {
                     }
                 }))
             }
-
             // batch_tx/rx: Batch Files to Send for Embeddings
-            let (batch_files_tx, batch_files_rx) =
-                channel::unbounded::<(i64, ParsedFile, Vec<String>)>();
+            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
             let _batch_files_task = cx.background().spawn(async move {
                 let mut queue_len = 0;
                 let mut embeddings_queue = vec![];
-                while let Ok((worktree_id, indexed_file, document_spans)) =
-                    batch_files_rx.recv().await
-                {
-                    dbg!("Batching in while loop");
-                    queue_len += &document_spans.len();
-                    embeddings_queue.push((worktree_id, indexed_file, document_spans));
-                    if queue_len >= EMBEDDINGS_BATCH_SIZE {
+
+                while let Ok(job) = batch_files_rx.recv().await {
+                    let should_flush = match job {
+                        EmbeddingJob::Enqueue {
+                            document_spans,
+                            worktree_id,
+                            parsed_file,
+                        } => {
+                            queue_len += &document_spans.len();
+                            embeddings_queue.push((worktree_id, parsed_file, document_spans));
+                            queue_len >= EMBEDDINGS_BATCH_SIZE
+                        }
+                        EmbeddingJob::Flush => true,
+                    };
+
+                    if should_flush {
                         embed_batch_tx.try_send(embeddings_queue).unwrap();
                         embeddings_queue = vec![];
                         queue_len = 0;
                     }
                 }
-                // TODO: This is never getting called, We've gotta manage for how to clear the embedding batch if its less than the necessary batch size.
-                if queue_len > 0 {
-                    embed_batch_tx.try_send(embeddings_queue).unwrap();
-                }
             });
 
             // parsing_files_tx/rx: Parsing Files to Embeddable Documents
@@ -353,13 +364,17 @@ impl VectorStore {
                             retriever.parse_file(pending_file.clone()).await.log_err()
                         {
                             batch_files_tx
-                                .try_send((
-                                    pending_file.worktree_db_id,
-                                    indexed_file,
+                                .try_send(EmbeddingJob::Enqueue {
+                                    worktree_id: pending_file.worktree_db_id,
+                                    parsed_file: indexed_file,
                                     document_spans,
-                                ))
+                                })
                                 .unwrap();
                         }
+
+                        if parsing_files_rx.len() == 0 {
+                            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
+                        }
                     }
                 }));
             }
@@ -526,143 +541,18 @@ impl VectorStore {
                 // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
                 // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
                 let _subscription = cx.subscribe(&project, |this, project, event, cx| {
-                    if let Some(project_state) = this.projects.get(&project.downgrade()) {
-                        let mut project_state = project_state.borrow_mut();
-                        let worktree_db_ids = project_state.worktree_db_ids.clone();
-
-                        if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
-                        {
-                            // Get Worktree Object
-                            let worktree =
-                                project.read(cx).worktree_for_id(worktree_id.clone(), cx);
-                            if worktree.is_none() {
-                                return;
-                            }
-                            let worktree = worktree.unwrap();
-
-                            // Get Database
-                            let db_values = {
-                                if let Ok(db) =
-                                    VectorDatabase::new(this.database_url.to_string_lossy().into())
-                                {
-                                    let worktree_db_id: Option<i64> = {
-                                        let mut found_db_id = None;
-                                        for (w_id, db_id) in worktree_db_ids.into_iter() {
-                                            if &w_id == &worktree.read(cx).id() {
-                                                found_db_id = Some(db_id)
-                                            }
-                                        }
-                                        found_db_id
-                                    };
-                                    if worktree_db_id.is_none() {
-                                        return;
-                                    }
-                                    let worktree_db_id = worktree_db_id.unwrap();
-
-                                    let file_mtimes = db.get_file_mtimes(worktree_db_id);
-                                    if file_mtimes.is_err() {
-                                        return;
-                                    }
-
-                                    let file_mtimes = file_mtimes.unwrap();
-                                    Some((file_mtimes, worktree_db_id))
-                                } else {
-                                    return;
-                                }
-                            };
-
-                            if db_values.is_none() {
-                                return;
-                            }
-
-                            let (file_mtimes, worktree_db_id) = db_values.unwrap();
-
-                            // Iterate Through Changes
-                            let language_registry = this.language_registry.clone();
-                            let parsing_files_tx = this.parsing_files_tx.clone();
-
-                            smol::block_on(async move {
-                                for change in changes.into_iter() {
-                                    let change_path = change.0.clone();
-                                    let absolute_path = worktree.read(cx).absolutize(&change_path);
-                                    // Skip if git ignored or symlink
-                                    if let Some(entry) = worktree.read(cx).entry_for_id(change.1) {
-                                        if entry.is_ignored || entry.is_symlink {
-                                            continue;
-                                        } else {
-                                            log::info!(
-                                                "Testing for Reindexing: {:?}",
-                                                &change_path
-                                            );
-                                        }
-                                    };
-
-                                    if let Ok(language) = language_registry
-                                        .language_for_file(&change_path.to_path_buf(), None)
-                                        .await
-                                    {
-                                        if language
-                                            .grammar()
-                                            .and_then(|grammar| grammar.embedding_config.as_ref())
-                                            .is_none()
-                                        {
-                                            continue;
-                                        }
-
-                                        if let Some(modified_time) = {
-                                            let metadata = change_path.metadata();
-                                            if metadata.is_err() {
-                                                None
-                                            } else {
-                                                let mtime = metadata.unwrap().modified();
-                                                if mtime.is_err() {
-                                                    None
-                                                } else {
-                                                    Some(mtime.unwrap())
-                                                }
-                                            }
-                                        } {
-                                            let existing_time =
-                                                file_mtimes.get(&change_path.to_path_buf());
-                                            let already_stored = existing_time
-                                                .map_or(false, |existing_time| {
-                                                    &modified_time != existing_time
-                                                });
-
-                                            let reindex_time = modified_time
-                                                + Duration::from_secs(REINDEXING_DELAY_SECONDS);
-
-                                            if !already_stored {
-                                                project_state.update_pending_files(
-                                                    PendingFile {
-                                                        relative_path: change_path.to_path_buf(),
-                                                        absolute_path,
-                                                        modified_time,
-                                                        worktree_db_id,
-                                                        language: language.clone(),
-                                                    },
-                                                    reindex_time,
-                                                );
-
-                                                for file in project_state.get_outstanding_files() {
-                                                    parsing_files_tx.try_send(file).unwrap();
-                                                }
-                                            }
-                                        }
-                                    }
-                                }
-                            });
-                        };
+                    if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
+                        this.project_entries_changed(project, changes, cx, worktree_id);
                     }
                 });
 
                 this.projects.insert(
                     project.downgrade(),
-                    Rc::new(RefCell::new(ProjectState {
+                    ProjectState {
                         pending_files: HashMap::new(),
                         worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
                         _subscription,
-                    })),
+                    },
                 );
             });
 
@@ -678,7 +568,7 @@ impl VectorStore {
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<SearchResult>>> {
         let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
-            state.borrow()
+            state
         } else {
             return Task::ready(Err(anyhow!("project not added")));
         };
@@ -736,7 +626,7 @@ impl VectorStore {
 
             this.read_with(&cx, |this, _| {
                 let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
-                    state.borrow()
+                    state
                 } else {
                     return Err(anyhow!("project not added"));
                 };
@@ -766,6 +656,110 @@ impl VectorStore {
             })
         })
     }
+
+    fn project_entries_changed(
+        &mut self,
+        project: ModelHandle<Project>,
+        changes: &[(Arc<Path>, ProjectEntryId, PathChange)],
+        cx: &mut ModelContext<'_, VectorStore>,
+        worktree_id: &WorktreeId,
+    ) -> Option<()> {
+        let project_state = self.projects.get_mut(&project.downgrade())?;
+        let worktree_db_ids = project_state.worktree_db_ids.clone();
+        let worktree = project.read(cx).worktree_for_id(worktree_id.clone(), cx)?;
+
+        // Get Database
+        let (file_mtimes, worktree_db_id) = {
+            if let Ok(db) = VectorDatabase::new(self.database_url.to_string_lossy().into()) {
+                let worktree_db_id = {
+                    let mut found_db_id = None;
+                    for (w_id, db_id) in worktree_db_ids.into_iter() {
+                        if &w_id == &worktree.read(cx).id() {
+                            found_db_id = Some(db_id)
+                        }
+                    }
+                    found_db_id
+                }?;
+
+                let file_mtimes = db.get_file_mtimes(worktree_db_id).log_err()?;
+
+                Some((file_mtimes, worktree_db_id))
+            } else {
+                return None;
+            }
+        }?;
+
+        // Iterate Through Changes
+        let language_registry = self.language_registry.clone();
+        let parsing_files_tx = self.parsing_files_tx.clone();
+
+        smol::block_on(async move {
+            for change in changes.into_iter() {
+                let change_path = change.0.clone();
+                let absolute_path = worktree.read(cx).absolutize(&change_path);
+                // Skip if git ignored or symlink
+                if let Some(entry) = worktree.read(cx).entry_for_id(change.1) {
+                    if entry.is_ignored || entry.is_symlink {
+                        continue;
+                    } else {
+                        log::info!("Testing for Reindexing: {:?}", &change_path);
+                    }
+                };
+
+                if let Ok(language) = language_registry
+                    .language_for_file(&change_path.to_path_buf(), None)
+                    .await
+                {
+                    if language
+                        .grammar()
+                        .and_then(|grammar| grammar.embedding_config.as_ref())
+                        .is_none()
+                    {
+                        continue;
+                    }
+
+                    if let Some(modified_time) = {
+                        let metadata = change_path.metadata();
+                        if metadata.is_err() {
+                            None
+                        } else {
+                            let mtime = metadata.unwrap().modified();
+                            if mtime.is_err() {
+                                None
+                            } else {
+                                Some(mtime.unwrap())
+                            }
+                        }
+                    } {
+                        let existing_time = file_mtimes.get(&change_path.to_path_buf());
+                        let already_stored = existing_time
+                            .map_or(false, |existing_time| &modified_time != existing_time);
+
+                        let reindex_time =
+                            modified_time + Duration::from_secs(REINDEXING_DELAY_SECONDS);
+
+                        if !already_stored {
+                            project_state.update_pending_files(
+                                PendingFile {
+                                    relative_path: change_path.to_path_buf(),
+                                    absolute_path,
+                                    modified_time,
+                                    worktree_db_id,
+                                    language: language.clone(),
+                                },
+                                reindex_time,
+                            );
+
+                            for file in project_state.get_outstanding_files() {
+                                parsing_files_tx.try_send(file).unwrap();
+                            }
+                        }
+                    }
+                }
+            }
+        });
+        Some(())
+    }
 }
 
 impl Entity for VectorStore {