KCaverly and maxbrunsfeld created this commit
Co-authored-by: maxbrunsfeld <max@zed.dev>
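
This commit reworks how the semantic index tracks indexing progress. The debounce-based re-indexing on save (the project-event subscription, `pending_files` map, and `reindexing_delay` handling) is removed. Instead, `add_project` becomes `index_project` and returns the number of files queued for indexing, each `PendingFile` carries a `JobHandle` that is dropped once the file's embeddings are written to the database, a new `remaining_files_to_index_for_project` method reports the jobs still outstanding, and `search` is renamed to `search_project`.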
Cargo.lock                                    |   1 +
crates/vector_store/Cargo.toml                |   1 +
crates/vector_store/src/modal.rs              |   2 +-
crates/vector_store/src/vector_store.rs       | 379 +++++++-------------
crates/vector_store/src/vector_store_tests.rs |  78 +++
5 files changed, 208 insertions(+), 253 deletions(-)

Cargo.lock
@@ -8493,6 +8493,7 @@ dependencies = [
"lazy_static",
"log",
"matrixmultiply",
+ "parking_lot 0.11.2",
"picker",
"project",
"rand 0.8.5",

crates/vector_store/Cargo.toml
@@ -33,6 +33,7 @@ async-trait.workspace = true
bincode = "1.3.3"
matrixmultiply = "0.3.7"
tiktoken-rs = "0.5.0"
+parking_lot.workspace = true
rand.workspace = true
schemars.workspace = true

crates/vector_store/src/modal.rs
@@ -124,7 +124,7 @@ impl PickerDelegate for SemanticSearchDelegate {
if let Some(retrieved) = retrieved_cached.log_err() {
if !retrieved {
let task = vector_store.update(&mut cx, |store, cx| {
- store.search(project.clone(), query.to_string(), 10, cx)
+ store.search_project(project.clone(), query.to_string(), 10, cx)
});
if let Some(results) = task.await.log_err() {

crates/vector_store/src/vector_store.rs
@@ -18,15 +18,19 @@ use gpui::{
};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
+use parking_lot::Mutex;
use parsing::{CodeContextRetriever, Document};
-use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
+use project::{Fs, Project, WorktreeId};
use smol::channel;
use std::{
- collections::HashMap,
+ collections::{HashMap, HashSet},
ops::Range,
path::{Path, PathBuf},
- sync::Arc,
- time::{Duration, Instant, SystemTime},
+ sync::{
+ atomic::{self, AtomicUsize},
+ Arc, Weak,
+ },
+ time::{Instant, SystemTime},
};
use util::{
channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
@@ -99,7 +103,7 @@ pub fn init(
let project = workspace.read(cx).project().clone();
if project.read(cx).is_local() {
vector_store.update(cx, |store, cx| {
- store.add_project(project, cx).detach();
+ store.index_project(project, cx).detach();
});
}
}
@@ -124,13 +128,20 @@ pub struct VectorStore {
_embed_batch_task: Task<()>,
_batch_files_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
+ next_job_id: Arc<AtomicUsize>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
struct ProjectState {
worktree_db_ids: Vec<(WorktreeId, i64)>,
- pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
- _subscription: gpui::Subscription,
+ outstanding_jobs: Arc<Mutex<HashSet<JobId>>>,
+}
+
+type JobId = usize;
+
+struct JobHandle {
+ id: JobId,
+ set: Weak<Mutex<HashSet<JobId>>>,
}
impl ProjectState {
@@ -157,54 +168,15 @@ impl ProjectState {
}
})
}
-
- fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
- // If Pending File Already Exists, Replace it with the new one
- // but keep the old indexing time
- if let Some(old_file) = self
- .pending_files
- .remove(&pending_file.relative_path.clone())
- {
- self.pending_files.insert(
- pending_file.relative_path.clone(),
- (pending_file, old_file.1),
- );
- } else {
- self.pending_files.insert(
- pending_file.relative_path.clone(),
- (pending_file, indexing_time),
- );
- };
- }
-
- fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
- let mut outstanding_files = vec![];
- let mut remove_keys = vec![];
- for key in self.pending_files.keys().into_iter() {
- if let Some(pending_details) = self.pending_files.get(key) {
- let (pending_file, index_time) = pending_details;
- if index_time <= &SystemTime::now() {
- outstanding_files.push(pending_file.clone());
- remove_keys.push(key.clone());
- }
- }
- }
-
- for key in remove_keys.iter() {
- self.pending_files.remove(key);
- }
-
- return outstanding_files;
- }
}
-#[derive(Clone, Debug)]
pub struct PendingFile {
worktree_db_id: i64,
relative_path: PathBuf,
absolute_path: PathBuf,
language: Arc<Language>,
modified_time: SystemTime,
+ job_handle: JobHandle,
}
#[derive(Debug, Clone)]
@@ -221,6 +193,7 @@ enum DbOperation {
documents: Vec<Document>,
path: PathBuf,
mtime: SystemTime,
+ job_handle: JobHandle,
},
Delete {
worktree_id: i64,
@@ -242,6 +215,7 @@ enum EmbeddingJob {
path: PathBuf,
mtime: SystemTime,
documents: Vec<Document>,
+ job_handle: JobHandle,
},
Flush,
}
@@ -274,9 +248,11 @@ impl VectorStore {
documents,
path,
mtime,
+ job_handle,
} => {
db.insert_file(worktree_id, path, mtime, documents)
.log_err();
+ drop(job_handle)
}
DbOperation::Delete { worktree_id, path } => {
db.delete_file(worktree_id, path).log_err();
@@ -298,7 +274,7 @@ impl VectorStore {
// embed_tx/rx: Embed Batch and Send to Database
let (embed_batch_tx, embed_batch_rx) =
- channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime)>>();
+ channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
let _embed_batch_task = cx.background().spawn({
let db_update_tx = db_update_tx.clone();
let embedding_provider = embedding_provider.clone();
@@ -306,7 +282,7 @@ impl VectorStore {
while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
// Construct Batch
let mut batch_documents = vec![];
- for (_, documents, _, _) in embeddings_queue.iter() {
+ for (_, documents, _, _, _) in embeddings_queue.iter() {
batch_documents
.extend(documents.iter().map(|document| document.content.as_str()));
}
@@ -333,7 +309,7 @@ impl VectorStore {
j += 1;
}
- for (worktree_id, documents, path, mtime) in
+ for (worktree_id, documents, path, mtime, job_handle) in
embeddings_queue.into_iter()
{
for document in documents.iter() {
@@ -350,6 +326,7 @@ impl VectorStore {
documents,
path,
mtime,
+ job_handle,
})
.await
.unwrap();
@@ -372,9 +349,16 @@ impl VectorStore {
worktree_id,
path,
mtime,
+ job_handle,
} => {
queue_len += &documents.len();
- embeddings_queue.push((worktree_id, documents, path, mtime));
+ embeddings_queue.push((
+ worktree_id,
+ documents,
+ path,
+ mtime,
+ job_handle,
+ ));
queue_len >= EMBEDDINGS_BATCH_SIZE
}
EmbeddingJob::Flush => true,
@@ -420,6 +404,7 @@ impl VectorStore {
worktree_id: pending_file.worktree_db_id,
path: pending_file.relative_path,
mtime: pending_file.modified_time,
+ job_handle: pending_file.job_handle,
documents,
})
.unwrap();
@@ -439,6 +424,7 @@ impl VectorStore {
embedding_provider,
language_registry,
db_update_tx,
+ next_job_id: Default::default(),
parsing_files_tx,
_db_update_task,
_embed_batch_task,
@@ -471,11 +457,11 @@ impl VectorStore {
async move { rx.await? }
}
- fn add_project(
+ fn index_project(
&mut self,
project: ModelHandle<Project>,
cx: &mut ModelContext<Self>,
- ) -> Task<Result<()>> {
+ ) -> Task<Result<usize>> {
let worktree_scans_complete = project
.read(cx)
.worktrees(cx)
@@ -494,21 +480,16 @@ impl VectorStore {
})
.collect::<Vec<_>>();
- let fs = self.fs.clone();
let language_registry = self.language_registry.clone();
- let database_url = self.database_url.clone();
let db_update_tx = self.db_update_tx.clone();
let parsing_files_tx = self.parsing_files_tx.clone();
+ let next_job_id = self.next_job_id.clone();
cx.spawn(|this, mut cx| async move {
futures::future::join_all(worktree_scans_complete).await;
let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
- if let Some(db_directory) = database_url.parent() {
- fs.create_dir(db_directory).await.log_err();
- }
-
let worktrees = project.read_with(&cx, |project, cx| {
project
.worktrees(cx)
@@ -516,109 +497,115 @@ impl VectorStore {
.collect::<Vec<_>>()
});
- let mut worktree_file_times = HashMap::new();
+ let mut worktree_file_mtimes = HashMap::new();
let mut db_ids_by_worktree_id = HashMap::new();
for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
let db_id = db_id?;
db_ids_by_worktree_id.insert(worktree.id(), db_id);
- worktree_file_times.insert(
+ worktree_file_mtimes.insert(
worktree.id(),
this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
.await?,
);
}
- cx.background()
- .spawn({
- let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
- let db_update_tx = db_update_tx.clone();
- let language_registry = language_registry.clone();
- let parsing_files_tx = parsing_files_tx.clone();
- async move {
- let t0 = Instant::now();
- for worktree in worktrees.into_iter() {
- let mut file_mtimes =
- worktree_file_times.remove(&worktree.id()).unwrap();
- for file in worktree.files(false, 0) {
- let absolute_path = worktree.absolutize(&file.path);
-
- if let Ok(language) = language_registry
- .language_for_file(&absolute_path, None)
- .await
- {
- if language
- .grammar()
- .and_then(|grammar| grammar.embedding_config.as_ref())
- .is_none()
- {
- continue;
- }
-
- let path_buf = file.path.to_path_buf();
- let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
- let already_stored = stored_mtime
- .map_or(false, |existing_mtime| {
- existing_mtime == file.mtime
- });
-
- if !already_stored {
- log::trace!("sending for parsing: {:?}", path_buf);
- parsing_files_tx
- .try_send(PendingFile {
- worktree_db_id: db_ids_by_worktree_id
- [&worktree.id()],
- relative_path: path_buf,
- absolute_path,
- language,
- modified_time: file.mtime,
- })
- .unwrap();
- }
- }
- }
- for file in file_mtimes.keys() {
- db_update_tx
- .try_send(DbOperation::Delete {
- worktree_id: db_ids_by_worktree_id[&worktree.id()],
- path: file.to_owned(),
- })
- .unwrap();
- }
- }
- log::trace!(
- "parsing worktree completed in {:?}",
- t0.elapsed().as_millis()
- );
- }
- })
- .detach();
-
// let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
- this.update(&mut cx, |this, cx| {
- // The below is managing for updated on save
- // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
- // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
- let _subscription = cx.subscribe(&project, |this, project, event, cx| {
- if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
- this.project_entries_changed(project, changes.clone(), cx, worktree_id);
- }
- });
-
+ let outstanding_jobs = Arc::new(Mutex::new(HashSet::new()));
+ this.update(&mut cx, |this, _| {
this.projects.insert(
project.downgrade(),
ProjectState {
- pending_files: HashMap::new(),
- worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
- _subscription,
+ worktree_db_ids: db_ids_by_worktree_id
+ .iter()
+ .map(|(a, b)| (*a, *b))
+ .collect(),
+ outstanding_jobs: outstanding_jobs.clone(),
},
);
});
- anyhow::Ok(())
+ cx.background()
+ .spawn(async move {
+ let mut count = 0;
+ let t0 = Instant::now();
+ for worktree in worktrees.into_iter() {
+ let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
+ for file in worktree.files(false, 0) {
+ let absolute_path = worktree.absolutize(&file.path);
+
+ if let Ok(language) = language_registry
+ .language_for_file(&absolute_path, None)
+ .await
+ {
+ if language
+ .grammar()
+ .and_then(|grammar| grammar.embedding_config.as_ref())
+ .is_none()
+ {
+ continue;
+ }
+
+ let path_buf = file.path.to_path_buf();
+ let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
+ let already_stored = stored_mtime
+ .map_or(false, |existing_mtime| existing_mtime == file.mtime);
+
+ if !already_stored {
+ log::trace!("sending for parsing: {:?}", path_buf);
+ count += 1;
+ let job_id = next_job_id.fetch_add(1, atomic::Ordering::SeqCst);
+ let job_handle = JobHandle {
+ id: job_id,
+ set: Arc::downgrade(&outstanding_jobs),
+ };
+ outstanding_jobs.lock().insert(job_id);
+ parsing_files_tx
+ .try_send(PendingFile {
+ worktree_db_id: db_ids_by_worktree_id[&worktree.id()],
+ relative_path: path_buf,
+ absolute_path,
+ language,
+ job_handle,
+ modified_time: file.mtime,
+ })
+ .unwrap();
+ }
+ }
+ }
+ for file in file_mtimes.keys() {
+ db_update_tx
+ .try_send(DbOperation::Delete {
+ worktree_id: db_ids_by_worktree_id[&worktree.id()],
+ path: file.to_owned(),
+ })
+ .unwrap();
+ }
+ }
+ log::trace!(
+ "parsing worktree completed in {:?}",
+ t0.elapsed().as_millis()
+ );
+
+ Ok(count)
+ })
+ .await
})
}
- pub fn search(
+ pub fn remaining_files_to_index_for_project(
+ &self,
+ project: &ModelHandle<Project>,
+ ) -> Option<usize> {
+ Some(
+ self.projects
+ .get(&project.downgrade())?
+ .outstanding_jobs
+ .lock()
+ .len(),
+ )
+ }
+
+ pub fn search_project(
&mut self,
project: ModelHandle<Project>,
phrase: String,
@@ -682,110 +669,16 @@ impl VectorStore {
})
})
}
-
- fn project_entries_changed(
- &mut self,
- project: ModelHandle<Project>,
- changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
- cx: &mut ModelContext<'_, VectorStore>,
- worktree_id: &WorktreeId,
- ) -> Option<()> {
- let reindexing_delay = settings::get::<VectorStoreSettings>(cx).reindexing_delay_seconds;
-
- let worktree = project
- .read(cx)
- .worktree_for_id(worktree_id.clone(), cx)?
- .read(cx)
- .snapshot();
-
- let worktree_db_id = self
- .projects
- .get(&project.downgrade())?
- .db_id_for_worktree_id(worktree.id())?;
- let file_mtimes = self.get_file_mtimes(worktree_db_id);
-
- let language_registry = self.language_registry.clone();
-
- cx.spawn(|this, mut cx| async move {
- let file_mtimes = file_mtimes.await.log_err()?;
-
- for change in changes.into_iter() {
- let change_path = change.0.clone();
- let absolute_path = worktree.absolutize(&change_path);
-
- // Skip if git ignored or symlink
- if let Some(entry) = worktree.entry_for_id(change.1) {
- if entry.is_ignored || entry.is_symlink || entry.is_external {
- continue;
- }
- }
-
- match change.2 {
- PathChange::Removed => this.update(&mut cx, |this, _| {
- this.db_update_tx
- .try_send(DbOperation::Delete {
- worktree_id: worktree_db_id,
- path: absolute_path,
- })
- .unwrap();
- }),
- _ => {
- if let Ok(language) = language_registry
- .language_for_file(&change_path.to_path_buf(), None)
- .await
- {
- if language
- .grammar()
- .and_then(|grammar| grammar.embedding_config.as_ref())
- .is_none()
- {
- continue;
- }
-
- let modified_time =
- change_path.metadata().log_err()?.modified().log_err()?;
-
- let existing_time = file_mtimes.get(&change_path.to_path_buf());
- let already_stored = existing_time
- .map_or(false, |existing_time| &modified_time != existing_time);
-
- if !already_stored {
- this.update(&mut cx, |this, _| {
- let reindex_time = modified_time
- + Duration::from_secs(reindexing_delay as u64);
-
- let project_state =
- this.projects.get_mut(&project.downgrade())?;
- project_state.update_pending_files(
- PendingFile {
- relative_path: change_path.to_path_buf(),
- absolute_path,
- modified_time,
- worktree_db_id,
- language: language.clone(),
- },
- reindex_time,
- );
-
- for file in project_state.get_outstanding_files() {
- this.parsing_files_tx.try_send(file).unwrap();
- }
- Some(())
- });
- }
- }
- }
- }
- }
-
- Some(())
- })
- .detach();
-
- Some(())
- }
}
impl Entity for VectorStore {
type Event = ();
}
+
+impl Drop for JobHandle {
+ fn drop(&mut self) {
+ if let Some(set) = self.set.upgrade() {
+ set.lock().remove(&self.id);
+ }
+ }
+}
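
The heart of the change is this RAII pattern: each queued file owns a JobHandle holding a Weak reference to the project's outstanding_jobs set, and wherever in the pipeline the handle is dropped (normally after DbOperation::Insert persists the file), the job id disappears from the set, so remaining_files_to_index_for_project is simply the set's length. Below is a minimal standalone sketch of the same idea, using std::sync::Mutex rather than the commit's parking_lot::Mutex; the names JobHandle::new and outstanding are illustrative, not from the commit.

// A self-removing job handle: the id is registered on creation and
// removed again on Drop, no matter where the drop happens.
use std::collections::HashSet;
use std::sync::{Arc, Mutex, Weak};

struct JobHandle {
    id: usize,
    set: Weak<Mutex<HashSet<usize>>>,
}

impl JobHandle {
    fn new(id: usize, set: &Arc<Mutex<HashSet<usize>>>) -> Self {
        set.lock().unwrap().insert(id);
        Self {
            id,
            set: Arc::downgrade(set),
        }
    }
}

impl Drop for JobHandle {
    fn drop(&mut self) {
        // If the owning project is gone, there is no set left to update.
        if let Some(set) = self.set.upgrade() {
            set.lock().unwrap().remove(&self.id);
        }
    }
}

fn main() {
    let outstanding = Arc::new(Mutex::new(HashSet::new()));
    let job = JobHandle::new(0, &outstanding);
    assert_eq!(outstanding.lock().unwrap().len(), 1); // one job in flight
    drop(job); // e.g. after the file's embeddings were written to the db
    assert_eq!(outstanding.lock().unwrap().len(), 0); // reads as fully indexed
}

Holding only a Weak pointer means a handle that outlives its project does nothing on drop, rather than keeping the project's state alive.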

crates/vector_store/src/vector_store_tests.rs
@@ -9,11 +9,17 @@ use anyhow::Result;
use async_trait::async_trait;
use gpui::{Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry};
-use project::{project_settings::ProjectSettings, FakeFs, Project};
+use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
use rand::{rngs::StdRng, Rng};
use serde_json::json;
use settings::SettingsStore;
-use std::{path::Path, sync::Arc};
+use std::{
+ path::Path,
+ sync::{
+ atomic::{self, AtomicUsize},
+ Arc,
+ },
+};
use unindent::Unindent;
#[ctor::ctor]
@@ -62,29 +68,37 @@ async fn test_vector_store(cx: &mut TestAppContext) {
let db_dir = tempdir::TempDir::new("vector-store").unwrap();
let db_path = db_dir.path().join("db.sqlite");
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let store = VectorStore::new(
fs.clone(),
db_path,
- Arc::new(FakeEmbeddingProvider),
+ embedding_provider.clone(),
languages,
cx.to_async(),
)
.await
.unwrap();
- let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
+ let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
let worktree_id = project.read_with(cx, |project, cx| {
project.worktrees(cx).next().unwrap().read(cx).id()
});
- store
- .update(cx, |store, cx| store.add_project(project.clone(), cx))
+ let file_count = store
+ .update(cx, |store, cx| store.index_project(project.clone(), cx))
.await
.unwrap();
+ assert_eq!(file_count, 2);
cx.foreground().run_until_parked();
+ store.update(cx, |store, _cx| {
+ assert_eq!(
+ store.remaining_files_to_index_for_project(&project),
+ Some(0)
+ );
+ });
let search_results = store
.update(cx, |store, cx| {
- store.search(project.clone(), "aaaa".to_string(), 5, cx)
+ store.search_project(project.clone(), "aaaa".to_string(), 5, cx)
})
.await
.unwrap();
@@ -92,10 +106,45 @@ async fn test_vector_store(cx: &mut TestAppContext) {
assert_eq!(search_results[0].byte_range.start, 0);
assert_eq!(search_results[0].name, "aaa");
assert_eq!(search_results[0].worktree_id, worktree_id);
+
+ fs.save(
+ "/the-root/src/file2.rs".as_ref(),
+ &"
+ fn dddd() { println!(\"ddddd!\"); }
+ struct pqpqpqp {}
+ "
+ .unindent()
+ .into(),
+ Default::default(),
+ )
+ .await
+ .unwrap();
+
+ cx.foreground().run_until_parked();
+
+ let prev_embedding_count = embedding_provider.embedding_count();
+ let file_count = store
+ .update(cx, |store, cx| store.index_project(project.clone(), cx))
+ .await
+ .unwrap();
+ assert_eq!(file_count, 1);
+
+ cx.foreground().run_until_parked();
+ store.update(cx, |store, _cx| {
+ assert_eq!(
+ store.remaining_files_to_index_for_project(&project),
+ Some(0)
+ );
+ });
+
+ assert_eq!(
+ embedding_provider.embedding_count() - prev_embedding_count,
+ 2
+ );
}
#[gpui::test]
-async fn test_code_context_retrieval(cx: &mut TestAppContext) {
+async fn test_code_context_retrieval() {
let language = rust_lang();
let mut retriever = CodeContextRetriever::new();
@@ -181,11 +230,22 @@ fn test_dot_product(mut rng: StdRng) {
}
}
-struct FakeEmbeddingProvider;
+#[derive(Default)]
+struct FakeEmbeddingProvider {
+ embedding_count: AtomicUsize,
+}
+
+impl FakeEmbeddingProvider {
+ fn embedding_count(&self) -> usize {
+ self.embedding_count.load(atomic::Ordering::SeqCst)
+ }
+}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+ self.embedding_count
+ .fetch_add(spans.len(), atomic::Ordering::SeqCst);
Ok(spans
.iter()
.map(|span| {