@@ -1,4 +1,5 @@
use std::{
+ cmp::Ordering,
collections::HashMap,
path::{Path, PathBuf},
rc::Rc,
@@ -14,16 +15,6 @@ use rusqlite::{
types::{FromSql, FromSqlResult, ValueRef},
};
-// Note this is not an appropriate document
-#[derive(Debug)]
-pub struct DocumentRecord {
- pub id: usize,
- pub file_id: usize,
- pub offset: usize,
- pub name: String,
- pub embedding: Embedding,
-}
-
#[derive(Debug)]
pub struct FileRecord {
pub id: usize,
@@ -32,7 +23,7 @@ pub struct FileRecord {
}
#[derive(Debug)]
-pub struct Embedding(pub Vec<f32>);
+struct Embedding(pub Vec<f32>);
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
@@ -205,10 +196,35 @@ impl VectorDatabase {
Ok(result)
}
- pub fn for_each_document(
+ pub fn top_k_search(
+ &self,
+ worktree_ids: &[i64],
+ query_embedding: &Vec<f32>,
+ limit: usize,
+ ) -> Result<Vec<(i64, PathBuf, usize, String)>> {
+ let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+ self.for_each_document(&worktree_ids, |id, embedding| {
+ eprintln!("document {id} {embedding:?}");
+
+ let similarity = dot(&embedding, &query_embedding);
+ let ix = match results
+ .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
+ {
+ Ok(ix) => ix,
+ Err(ix) => ix,
+ };
+ results.insert(ix, (id, similarity));
+ results.truncate(limit);
+ })?;
+
+ let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
+ self.get_documents_by_ids(&ids)
+ }
+
+ fn for_each_document(
&self,
worktree_ids: &[i64],
- mut f: impl FnMut(i64, Embedding),
+ mut f: impl FnMut(i64, Vec<f32>),
) -> Result<()> {
let mut query_statement = self.db.prepare(
"
@@ -221,16 +237,20 @@ impl VectorDatabase {
files.worktree_id IN rarray(?)
",
)?;
+
query_statement
.query_map(params![ids_to_sql(worktree_ids)], |row| {
- Ok((row.get(0)?, row.get(1)?))
+ Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
})?
.filter_map(|row| row.ok())
- .for_each(|row| f(row.0, row.1));
+ .for_each(|(id, embedding)| {
+ dbg!("id");
+ f(id, embedding.0)
+ });
Ok(())
}
- pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
+ fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
let mut statement = self.db.prepare(
"
SELECT
@@ -279,3 +299,29 @@ fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
.collect::<Vec<_>>(),
)
}
+
+pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
+ let len = vec_a.len();
+ assert_eq!(len, vec_b.len());
+
+ let mut result = 0.0;
+ unsafe {
+ matrixmultiply::sgemm(
+ 1,
+ len,
+ 1,
+ 1.0,
+ vec_a.as_ptr(),
+ len as isize,
+ 1,
+ vec_b.as_ptr(),
+ 1,
+ len as isize,
+ 0.0,
+ &mut result as *mut f32,
+ 1,
+ 1,
+ );
+ }
+ result
+}
@@ -20,7 +20,6 @@ use parsing::{CodeContextRetriever, ParsedFile};
use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
use smol::channel;
use std::{
- cmp::Ordering,
collections::HashMap,
path::{Path, PathBuf},
sync::Arc,
@@ -112,10 +111,10 @@ pub struct VectorStore {
database_url: Arc<PathBuf>,
embedding_provider: Arc<dyn EmbeddingProvider>,
language_registry: Arc<LanguageRegistry>,
- db_update_tx: channel::Sender<DbWrite>,
+ db_update_tx: channel::Sender<DbOperation>,
parsing_files_tx: channel::Sender<PendingFile>,
_db_update_task: Task<()>,
- _embed_batch_task: Vec<Task<()>>,
+ _embed_batch_task: Task<()>,
_batch_files_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
@@ -128,6 +127,30 @@ struct ProjectState {
}
impl ProjectState {
+ fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
+ self.worktree_db_ids
+ .iter()
+ .find_map(|(worktree_id, db_id)| {
+ if *worktree_id == id {
+ Some(*db_id)
+ } else {
+ None
+ }
+ })
+ }
+
+ fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
+ self.worktree_db_ids
+ .iter()
+ .find_map(|(worktree_id, db_id)| {
+ if *db_id == id {
+ Some(*worktree_id)
+ } else {
+ None
+ }
+ })
+ }
+
fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
// If Pending File Already Exists, Replace it with the new one
// but keep the old indexing time
@@ -185,7 +208,7 @@ pub struct SearchResult {
pub file_path: PathBuf,
}
-enum DbWrite {
+enum DbOperation {
InsertFile {
worktree_id: i64,
indexed_file: ParsedFile,
@@ -198,6 +221,10 @@ enum DbWrite {
path: PathBuf,
sender: oneshot::Sender<Result<i64>>,
},
+ FileMTimes {
+ worktree_id: i64,
+ sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
+ },
}
enum EmbeddingJob {
@@ -243,20 +270,27 @@ impl VectorStore {
let _db_update_task = cx.background().spawn(async move {
while let Ok(job) = db_update_rx.recv().await {
match job {
- DbWrite::InsertFile {
+ DbOperation::InsertFile {
worktree_id,
indexed_file,
} => {
log::info!("Inserting Data for {:?}", &indexed_file.path);
db.insert_file(worktree_id, indexed_file).log_err();
}
- DbWrite::Delete { worktree_id, path } => {
+ DbOperation::Delete { worktree_id, path } => {
db.delete_file(worktree_id, path).log_err();
}
- DbWrite::FindOrCreateWorktree { path, sender } => {
+ DbOperation::FindOrCreateWorktree { path, sender } => {
let id = db.find_or_create_worktree(&path);
sender.send(id).ok();
}
+ DbOperation::FileMTimes {
+ worktree_id: worktree_db_id,
+ sender,
+ } => {
+ let file_mtimes = db.get_file_mtimes(worktree_db_id);
+ sender.send(file_mtimes).ok();
+ }
}
}
});
@@ -264,24 +298,18 @@ impl VectorStore {
// embed_tx/rx: Embed Batch and Send to Database
let (embed_batch_tx, embed_batch_rx) =
channel::unbounded::<Vec<(i64, ParsedFile, Vec<String>)>>();
- let mut _embed_batch_task = Vec::new();
- for _ in 0..1 {
- //cx.background().num_cpus() {
+ let _embed_batch_task = cx.background().spawn({
let db_update_tx = db_update_tx.clone();
- let embed_batch_rx = embed_batch_rx.clone();
let embedding_provider = embedding_provider.clone();
- _embed_batch_task.push(cx.background().spawn(async move {
- while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
+ async move {
+ while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
// Construct Batch
- let mut embeddings_queue = embeddings_queue.clone();
let mut document_spans = vec![];
- for (_, _, document_span) in embeddings_queue.clone().into_iter() {
- document_spans.extend(document_span);
+ for (_, _, document_span) in embeddings_queue.iter() {
+ document_spans.extend(document_span.iter().map(|s| s.as_str()));
}
- if let Ok(embeddings) = embedding_provider
- .embed_batch(document_spans.iter().map(|x| &**x).collect())
- .await
+ if let Ok(embeddings) = embedding_provider.embed_batch(document_spans).await
{
let mut i = 0;
let mut j = 0;
@@ -306,7 +334,7 @@ impl VectorStore {
}
db_update_tx
- .send(DbWrite::InsertFile {
+ .send(DbOperation::InsertFile {
worktree_id,
indexed_file,
})
@@ -315,8 +343,9 @@ impl VectorStore {
}
}
}
- }))
- }
+ }
+ });
+
// batch_tx/rx: Batch Files to Send for Embeddings
let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
let _batch_files_task = cx.background().spawn(async move {
@@ -398,7 +427,21 @@ impl VectorStore {
fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
let (tx, rx) = oneshot::channel();
self.db_update_tx
- .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
+ .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
+ .unwrap();
+ async move { rx.await? }
+ }
+
+ fn get_file_mtimes(
+ &self,
+ worktree_id: i64,
+ ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
+ let (tx, rx) = oneshot::channel();
+ self.db_update_tx
+ .try_send(DbOperation::FileMTimes {
+ worktree_id,
+ sender: tx,
+ })
.unwrap();
async move { rx.await? }
}
@@ -450,26 +493,17 @@ impl VectorStore {
.collect::<Vec<_>>()
});
- // Here we query the worktree ids, and yet we dont have them elsewhere
- // We likely want to clean up these datastructures
- let (mut worktree_file_times, db_ids_by_worktree_id) = cx
- .background()
- .spawn({
- let worktrees = worktrees.clone();
- async move {
- let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
- let mut db_ids_by_worktree_id = HashMap::new();
- let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
- HashMap::new();
- for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
- let db_id = db_id?;
- db_ids_by_worktree_id.insert(worktree.id(), db_id);
- file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
- }
- anyhow::Ok((file_times, db_ids_by_worktree_id))
- }
- })
- .await?;
+ let mut worktree_file_times = HashMap::new();
+ let mut db_ids_by_worktree_id = HashMap::new();
+ for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
+ let db_id = db_id?;
+ db_ids_by_worktree_id.insert(worktree.id(), db_id);
+ worktree_file_times.insert(
+ worktree.id(),
+ this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
+ .await?,
+ );
+ }
cx.background()
.spawn({
@@ -520,7 +554,7 @@ impl VectorStore {
}
for file in file_mtimes.keys() {
db_update_tx
- .try_send(DbWrite::Delete {
+ .try_send(DbOperation::Delete {
worktree_id: db_ids_by_worktree_id[&worktree.id()],
path: file.to_owned(),
})
@@ -542,7 +576,7 @@ impl VectorStore {
// greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
let _subscription = cx.subscribe(&project, |this, project, event, cx| {
if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
- this.project_entries_changed(project, changes, cx, worktree_id);
+ this.project_entries_changed(project, changes.clone(), cx, worktree_id);
}
});
@@ -578,16 +612,7 @@ impl VectorStore {
.worktrees(cx)
.filter_map(|worktree| {
let worktree_id = worktree.read(cx).id();
- project_state
- .worktree_db_ids
- .iter()
- .find_map(|(id, db_id)| {
- if *id == worktree_id {
- Some(*db_id)
- } else {
- None
- }
- })
+ project_state.db_id_for_worktree_id(worktree_id)
})
.collect::<Vec<_>>();
@@ -606,24 +631,12 @@ impl VectorStore {
.next()
.unwrap();
- let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
- database.for_each_document(&worktree_db_ids, |id, embedding| {
- let similarity = dot(&embedding.0, &phrase_embedding);
- let ix = match results.binary_search_by(|(_, s)| {
- similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
- }) {
- Ok(ix) => ix,
- Err(ix) => ix,
- };
- results.insert(ix, (id, similarity));
- results.truncate(limit);
- })?;
-
- let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
- database.get_documents_by_ids(&ids)
+ database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
})
.await?;
+ dbg!(&documents);
+
this.read_with(&cx, |this, _| {
let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
state
@@ -634,17 +647,7 @@ impl VectorStore {
Ok(documents
.into_iter()
.filter_map(|(worktree_db_id, file_path, offset, name)| {
- let worktree_id =
- project_state
- .worktree_db_ids
- .iter()
- .find_map(|(id, db_id)| {
- if *db_id == worktree_db_id {
- Some(*id)
- } else {
- None
- }
- })?;
+ let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
Some(SearchResult {
worktree_id,
name,
@@ -660,51 +663,36 @@ impl VectorStore {
fn project_entries_changed(
&mut self,
project: ModelHandle<Project>,
- changes: &[(Arc<Path>, ProjectEntryId, PathChange)],
+ changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
cx: &mut ModelContext<'_, VectorStore>,
worktree_id: &WorktreeId,
) -> Option<()> {
- let project_state = self.projects.get_mut(&project.downgrade())?;
- let worktree_db_ids = project_state.worktree_db_ids.clone();
- let worktree = project.read(cx).worktree_for_id(worktree_id.clone(), cx)?;
-
- // Get Database
- let (file_mtimes, worktree_db_id) = {
- if let Ok(db) = VectorDatabase::new(self.database_url.to_string_lossy().into()) {
- let worktree_db_id = {
- let mut found_db_id = None;
- for (w_id, db_id) in worktree_db_ids.into_iter() {
- if &w_id == &worktree.read(cx).id() {
- found_db_id = Some(db_id)
- }
- }
- found_db_id
- }?;
-
- let file_mtimes = db.get_file_mtimes(worktree_db_id).log_err()?;
+ let worktree = project
+ .read(cx)
+ .worktree_for_id(worktree_id.clone(), cx)?
+ .read(cx)
+ .snapshot();
- Some((file_mtimes, worktree_db_id))
- } else {
- return None;
- }
- }?;
+ let worktree_db_id = self
+ .projects
+ .get(&project.downgrade())?
+ .db_id_for_worktree_id(worktree.id())?;
+ let file_mtimes = self.get_file_mtimes(worktree_db_id);
- // Iterate Through Changes
let language_registry = self.language_registry.clone();
- let parsing_files_tx = self.parsing_files_tx.clone();
- smol::block_on(async move {
+ cx.spawn(|this, mut cx| async move {
+ let file_mtimes = file_mtimes.await.log_err()?;
+
for change in changes.into_iter() {
let change_path = change.0.clone();
- let absolute_path = worktree.read(cx).absolutize(&change_path);
+ let absolute_path = worktree.absolutize(&change_path);
// Skip if git ignored or symlink
- if let Some(entry) = worktree.read(cx).entry_for_id(change.1) {
- if entry.is_ignored || entry.is_symlink {
+ if let Some(entry) = worktree.entry_for_id(change.1) {
+ if entry.is_ignored || entry.is_symlink || entry.is_external {
continue;
- } else {
- log::info!("Testing for Reindexing: {:?}", &change_path);
}
- };
+ }
if let Ok(language) = language_registry
.language_for_file(&change_path.to_path_buf(), None)
@@ -718,27 +706,18 @@ impl VectorStore {
continue;
}
- if let Some(modified_time) = {
- let metadata = change_path.metadata();
- if metadata.is_err() {
- None
- } else {
- let mtime = metadata.unwrap().modified();
- if mtime.is_err() {
- None
- } else {
- Some(mtime.unwrap())
- }
- }
- } {
- let existing_time = file_mtimes.get(&change_path.to_path_buf());
- let already_stored = existing_time
- .map_or(false, |existing_time| &modified_time != existing_time);
+ let modified_time = change_path.metadata().log_err()?.modified().log_err()?;
- let reindex_time =
- modified_time + Duration::from_secs(REINDEXING_DELAY_SECONDS);
+ let existing_time = file_mtimes.get(&change_path.to_path_buf());
+ let already_stored = existing_time
+ .map_or(false, |existing_time| &modified_time != existing_time);
- if !already_stored {
+ if !already_stored {
+ this.update(&mut cx, |this, _| {
+ let reindex_time =
+ modified_time + Duration::from_secs(REINDEXING_DELAY_SECONDS);
+
+ let project_state = this.projects.get_mut(&project.downgrade())?;
project_state.update_pending_files(
PendingFile {
relative_path: change_path.to_path_buf(),
@@ -751,13 +730,18 @@ impl VectorStore {
);
for file in project_state.get_outstanding_files() {
- parsing_files_tx.try_send(file).unwrap();
+ this.parsing_files_tx.try_send(file).unwrap();
}
- }
+ Some(())
+ });
}
}
}
- });
+
+ Some(())
+ })
+ .detach();
+
Some(())
}
}
@@ -765,29 +749,3 @@ impl VectorStore {
impl Entity for VectorStore {
type Event = ();
}
-
-fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
- let len = vec_a.len();
- assert_eq!(len, vec_b.len());
-
- let mut result = 0.0;
- unsafe {
- matrixmultiply::sgemm(
- 1,
- len,
- 1,
- 1.0,
- vec_a.as_ptr(),
- len as isize,
- 1,
- vec_b.as_ptr(),
- 1,
- len as isize,
- 0.0,
- &mut result as *mut f32,
- 1,
- 1,
- );
- }
- result
-}