Authored by KCaverly and maxbrunsfeld
Co-authored-by: maxbrunsfeld <max@zed.dev>
Cargo.lock | 1
crates/vector_store/Cargo.toml | 3
crates/vector_store/src/db.rs | 184 ++++++++++++++++----
crates/vector_store/src/vector_store.rs | 170 +++++++++++-------
crates/vector_store/src/vector_store_tests.rs | 23 +
5 files changed, 269 insertions(+), 112 deletions(-)
Cargo.lock
@@ -7967,6 +7967,7 @@ dependencies = [
"serde_json",
"sha-1 0.10.1",
"smol",
+ "tempdir",
"tree-sitter",
"tree-sitter-rust",
"unindent",
crates/vector_store/Cargo.toml
@@ -17,7 +17,7 @@ util = { path = "../util" }
anyhow.workspace = true
futures.workspace = true
smol.workspace = true
-rusqlite = { version = "0.27.0", features=["blob"] }
+rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
isahc.workspace = true
log.workspace = true
tree-sitter.workspace = true
@@ -38,3 +38,4 @@ workspace = { path = "../workspace", features = ["test-support"] }
tree-sitter-rust = "*"
rand.workspace = true
unindent.workspace = true
+tempdir.workspace = true
crates/vector_store/src/db.rs
@@ -7,9 +7,10 @@ use anyhow::{anyhow, Result};
use rusqlite::{
params,
- types::{FromSql, FromSqlResult, ValueRef},
- Connection,
+ types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
+ ToSql,
};
+use sha1::{Digest, Sha1};
use crate::IndexedFile;
@@ -32,7 +33,60 @@ pub struct DocumentRecord {
pub struct FileRecord {
pub id: usize,
pub relative_path: String,
- pub sha1: String,
+ pub sha1: FileSha1,
+}
+
+#[derive(Debug)]
+pub struct FileSha1(pub Vec<u8>);
+
+impl FileSha1 {
+ pub fn from_str(content: String) -> Self {
+ let mut hasher = Sha1::new();
+ hasher.update(content);
+ let sha1 = hasher.finalize()[..]
+ .into_iter()
+ .map(|val| val.to_owned())
+ .collect::<Vec<u8>>();
+ return FileSha1(sha1);
+ }
+
+ pub fn equals(&self, content: &String) -> bool {
+ let mut hasher = Sha1::new();
+ hasher.update(content);
+ let sha1 = hasher.finalize()[..]
+ .into_iter()
+ .map(|val| val.to_owned())
+ .collect::<Vec<u8>>();
+
+ let equal = self
+ .0
+ .clone()
+ .into_iter()
+ .zip(sha1)
+ .filter(|&(a, b)| a == b)
+ .count()
+ == self.0.len();
+
+ equal
+ }
+}
+
+impl ToSql for FileSha1 {
+ fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
+ return self.0.to_sql();
+ }
+}
+
+impl FromSql for FileSha1 {
+ fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+ let bytes = value.as_blob()?;
+ Ok(FileSha1(
+ bytes
+ .into_iter()
+ .map(|val| val.to_owned())
+ .collect::<Vec<u8>>(),
+ ))
+ }
}
#[derive(Debug)]
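
Note on the `FileSha1` helpers above: the hash bytes can be produced and compared without the intermediate iterator chains. A minimal equivalent sketch (not part of the commit), assuming the same RustCrypto `sha1` crate:

```rust
use sha1::{Digest, Sha1};

#[derive(Debug, PartialEq, Eq)]
pub struct FileSha1(pub Vec<u8>);

impl FileSha1 {
    pub fn from_str(content: &str) -> Self {
        // Sha1::digest hashes in one call; the output is always 20 bytes.
        FileSha1(Sha1::digest(content.as_bytes()).to_vec())
    }

    pub fn equals(&self, content: &str) -> bool {
        // Slice equality already checks length and bytes, replacing the
        // zip/filter/count comparison above.
        self.0 == Sha1::digest(content.as_bytes()).as_slice()
    }
}
```

Deriving `PartialEq` also lets two stored hashes be compared directly.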
@@ -63,6 +117,8 @@ impl VectorDatabase {
}
fn initialize_database(&self) -> Result<()> {
+ rusqlite::vtab::array::load_module(&self.db)?;
+
// This will create the database if it doesn't exist
// Initialize Vector Databasing Tables
@@ -81,7 +137,7 @@ impl VectorDatabase {
id INTEGER PRIMARY KEY AUTOINCREMENT,
worktree_id INTEGER NOT NULL,
relative_path VARCHAR NOT NULL,
- sha1 NVARCHAR(40) NOT NULL,
+ sha1 BLOB NOT NULL,
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
)",
[],
@@ -102,30 +158,23 @@ impl VectorDatabase {
Ok(())
}
- // pub async fn get_or_create_project(project_path: PathBuf) -> Result<usize> {
- // // Check if we have the project, if we do, return the ID
- // // If we do not have the project, insert the project and return the ID
-
- // let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
-
- // let projects_query = db.prepare(&format!(
- // "SELECT id FROM projects WHERE path = {}",
- // project_path.to_str().unwrap() // This is unsafe
- // ))?;
-
- // let project_id = db.last_insert_rowid();
-
- // return Ok(project_id as usize);
- // }
-
- pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
+ pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> {
// Write to files table, and return generated id.
- let files_insert = self.db.execute(
- "INSERT INTO files (relative_path, sha1) VALUES (?1, ?2)",
- params![indexed_file.path.to_str(), indexed_file.sha1],
+ log::info!("Inserting File!");
+ self.db.execute(
+ "
+ DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
+ ",
+ params![worktree_id, indexed_file.path.to_str()],
+ )?;
+ self.db.execute(
+ "
+ INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, $3);
+ ",
+ params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1],
)?;
- let inserted_id = self.db.last_insert_rowid();
+ let file_id = self.db.last_insert_rowid();
// Currently inserting at approximately 3400 documents a second
// I imagine we can speed this up with a bulk insert of some kind.
@@ -135,7 +184,7 @@ impl VectorDatabase {
self.db.execute(
"INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
params![
- inserted_id,
+ file_id,
document.offset.to_string(),
document.name,
embedding_blob
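
On the comment above about bulk inserts: the usual rusqlite pattern is one transaction with a reused prepared statement. A hedged sketch (not part of the commit); `Document` and `serialize_embedding` are hypothetical stand-ins, since the embedding's serialization isn't shown in this hunk:

```rust
use anyhow::Result;
use rusqlite::{params, Connection};

// Hypothetical document shape and serializer, standing in for the real types.
struct Document {
    offset: usize,
    name: String,
    embedding: Vec<f32>,
}

fn serialize_embedding(embedding: &[f32]) -> Vec<u8> {
    embedding.iter().flat_map(|f| f.to_le_bytes()).collect()
}

fn insert_documents(db: &Connection, file_id: i64, documents: Vec<Document>) -> Result<()> {
    // One transaction plus a reused prepared statement amortizes the
    // per-statement fsync and SQL parsing across all documents.
    // `unchecked_transaction` works through a shared `&Connection`.
    let tx = db.unchecked_transaction()?;
    {
        let mut stmt = tx.prepare(
            "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
        )?;
        for document in documents {
            stmt.execute(params![
                file_id,
                document.offset.to_string(),
                document.name,
                serialize_embedding(&document.embedding),
            ])?;
        }
    }
    tx.commit()?;
    Ok(())
}
```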
@@ -147,25 +196,41 @@ impl VectorDatabase {
}
pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
+ // Check whether this absolute path is already stored; if so, return its id
+ let mut worktree_query = self
+ .db
+ .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+
+ let worktree_id = worktree_query
+ .query_row(params![worktree_root_path.to_string_lossy()], |row| {
+ Ok(row.get::<_, i64>(0)?)
+ })
+ .map_err(|err| anyhow!(err));
+
+ if worktree_id.is_ok() {
+ return worktree_id;
+ }
+
+ // If worktree_id is Err, insert new worktree
self.db.execute(
"
INSERT into worktrees (absolute_path) VALUES (?1)
- ON CONFLICT DO NOTHING
",
params![worktree_root_path.to_string_lossy()],
)?;
Ok(self.db.last_insert_rowid())
}
- pub fn get_file_hashes(&self, worktree_id: i64) -> Result<Vec<(PathBuf, String)>> {
- let mut statement = self
- .db
- .prepare("SELECT relative_path, sha1 FROM files ORDER BY relative_path")?;
- let mut result = Vec::new();
- for row in
- statement.query_map([], |row| Ok((row.get::<_, String>(0)?.into(), row.get(1)?)))?
- {
- result.push(row?);
+ pub fn get_file_hashes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, FileSha1>> {
+ let mut statement = self.db.prepare(
+ "SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path",
+ )?;
+ let mut result: HashMap<PathBuf, FileSha1> = HashMap::new();
+ for row in statement.query_map(params![worktree_id], |row| {
+ Ok((row.get::<_, String>(0)?.into(), row.get(1)?))
+ })? {
+ let row = row?;
+ result.insert(row.0, row.1);
}
Ok(result)
}
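
One caveat in `find_or_create_worktree` above: branching on `worktree_id.is_ok()` treats every query error, including genuine database failures, as "worktree not found". A sketch (not part of the commit) using rusqlite's `OptionalExtension`, which distinguishes the no-row case from real errors, written as a free function for brevity:

```rust
use anyhow::Result;
use rusqlite::{params, Connection, OptionalExtension};
use std::path::Path;

fn find_or_create_worktree(db: &Connection, worktree_root_path: &Path) -> Result<i64> {
    // `.optional()` maps QueryReturnedNoRows to Ok(None) and keeps other errors.
    let existing: Option<i64> = db
        .query_row(
            "SELECT id FROM worktrees WHERE absolute_path = ?1",
            params![worktree_root_path.to_string_lossy()],
            |row| row.get(0),
        )
        .optional()?;
    if let Some(id) = existing {
        return Ok(id);
    }
    db.execute(
        "INSERT INTO worktrees (absolute_path) VALUES (?1)",
        params![worktree_root_path.to_string_lossy()],
    )?;
    Ok(db.last_insert_rowid())
}
```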
@@ -204,6 +269,53 @@ impl VectorDatabase {
Ok(())
}
+ pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(PathBuf, usize, String)>> {
+ let mut statement = self.db.prepare(
+ "
+ SELECT
+ documents.id, files.relative_path, documents.offset, documents.name
+ FROM
+ documents, files
+ WHERE
+ documents.file_id = files.id AND
+ documents.id in rarray(?)
+ ",
+ )?;
+
+ let result_iter = statement.query_map(
+ params![std::rc::Rc::new(
+ ids.iter()
+ .copied()
+ .map(|v| rusqlite::types::Value::from(v))
+ .collect::<Vec<_>>()
+ )],
+ |row| {
+ Ok((
+ row.get::<_, i64>(0)?,
+ row.get::<_, String>(1)?.into(),
+ row.get(2)?,
+ row.get(3)?,
+ ))
+ },
+ )?;
+
+ let mut values_by_id = HashMap::<i64, (PathBuf, usize, String)>::default();
+ for row in result_iter {
+ let (id, path, offset, name) = row?;
+ values_by_id.insert(id, (path, offset, name));
+ }
+
+ let mut results = Vec::with_capacity(ids.len());
+ for id in ids {
+ let (path, offset, name) = values_by_id
+ .remove(id)
+ .ok_or(anyhow!("missing document id {}", id))?;
+ results.push((path, offset, name));
+ }
+
+ Ok(results)
+ }
+
pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
let mut query_statement = self
.db
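
The `rarray(?)` form in `get_documents_by_ids` is what the `array` feature added in Cargo.toml enables: an `Rc<Vec<Value>>` binds as a one-column virtual table, so an `IN (...)` query needs no string-built SQL. A self-contained sketch against a hypothetical `items` table:

```rust
use rusqlite::{params, types::Value, Connection, Result};
use std::rc::Rc;

fn names_for_ids(conn: &Connection, ids: &[i64]) -> Result<Vec<String>> {
    // Must be registered once per connection before rarray() is available.
    rusqlite::vtab::array::load_module(conn)?;
    let id_values = Rc::new(ids.iter().copied().map(Value::from).collect::<Vec<Value>>());
    let mut stmt = conn.prepare("SELECT name FROM items WHERE id IN rarray(?1)")?;
    let rows = stmt.query_map(params![id_values], |row| row.get(0))?;
    rows.collect()
}
```

The HashMap pass at the end of `get_documents_by_ids` then restores the caller's id order, since `rarray` gives no ordering guarantee.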
crates/vector_store/src/vector_store.rs
@@ -7,15 +7,14 @@ mod search;
mod vector_store_tests;
use anyhow::{anyhow, Result};
-use db::{VectorDatabase, VECTOR_DB_URL};
-use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
+use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
+use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
-use language::LanguageRegistry;
+use language::{Language, LanguageRegistry};
use parsing::Document;
use project::{Fs, Project};
-use search::{BruteForceSearch, VectorSearch};
use smol::channel;
-use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant};
+use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
use tree_sitter::{Parser, QueryCursor};
use util::{http::HttpClient, ResultExt, TryFutureExt};
use workspace::WorkspaceCreated;
@@ -45,7 +44,7 @@ pub fn init(
let project = workspace.read(cx).project().clone();
if project.read(cx).is_local() {
vector_store.update(cx, |store, cx| {
- store.add_project(project, cx);
+ store.add_project(project, cx).detach();
});
}
}
@@ -57,16 +56,10 @@ pub fn init(
#[derive(Debug)]
pub struct IndexedFile {
path: PathBuf,
- sha1: String,
+ sha1: FileSha1,
documents: Vec<Document>,
}
-// struct SearchResult {
-// path: PathBuf,
-// offset: usize,
-// name: String,
-// distance: f32,
-// }
struct VectorStore {
fs: Arc<dyn Fs>,
database_url: Arc<str>,
@@ -99,20 +92,10 @@ impl VectorStore {
cursor: &mut QueryCursor,
parser: &mut Parser,
embedding_provider: &dyn EmbeddingProvider,
- language_registry: &Arc<LanguageRegistry>,
+ language: Arc<Language>,
file_path: PathBuf,
content: String,
) -> Result<IndexedFile> {
- dbg!(&file_path, &content);
-
- let language = language_registry
- .language_for_file(&file_path, None)
- .await?;
-
- if language.name().as_ref() != "Rust" {
- Err(anyhow!("unsupported language"))?;
- }
-
let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
let outline_config = grammar
.outline_config
@@ -156,9 +139,11 @@ impl VectorStore {
document.embedding = embedding;
}
+ let sha1 = FileSha1::from_str(content);
+
return Ok(IndexedFile {
path: file_path,
- sha1: String::new(),
+ sha1,
documents,
});
}
@@ -171,7 +156,13 @@ impl VectorStore {
let worktree_scans_complete = project
.read(cx)
.worktrees(cx)
- .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete())
+ .map(|worktree| {
+ let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
+ async move {
+ scan_complete.await;
+ log::info!("worktree scan completed");
+ }
+ })
.collect::<Vec<_>>();
let fs = self.fs.clone();
@@ -182,6 +173,13 @@ impl VectorStore {
cx.spawn(|_, cx| async move {
futures::future::join_all(worktree_scans_complete).await;
+ // TODO: remove this after fixing the bug in scan_complete
+ cx.background()
+ .timer(std::time::Duration::from_secs(3))
+ .await;
+
+ let db = VectorDatabase::new(&database_url)?;
+
let worktrees = project.read_with(&cx, |project, cx| {
project
.worktrees(cx)
@@ -189,37 +187,74 @@ impl VectorStore {
.collect::<Vec<_>>()
});
- let db = VectorDatabase::new(&database_url)?;
let worktree_root_paths = worktrees
.iter()
.map(|worktree| worktree.abs_path().clone())
.collect::<Vec<_>>();
- let (db, file_hashes) = cx
+
+ // Here we query the worktree ids, yet we don't have them elsewhere;
+ // we likely want to clean up these data structures
+ let (db, worktree_hashes, worktree_ids) = cx
.background()
.spawn(async move {
- let mut hashes = Vec::new();
+ let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
+ let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
for worktree_root_path in worktree_root_paths {
let worktree_id =
db.find_or_create_worktree(worktree_root_path.as_ref())?;
- hashes.push((worktree_id, db.get_file_hashes(worktree_id)?));
+ worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
+ hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
}
- anyhow::Ok((db, hashes))
+ anyhow::Ok((db, hashes, worktree_ids))
})
.await?;
- let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>();
- let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
+ let (paths_tx, paths_rx) =
+ channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
+ let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
cx.background()
.spawn({
let fs = fs.clone();
async move {
for worktree in worktrees.into_iter() {
+ let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
+ let file_hashes = &worktree_hashes[&worktree_id];
for file in worktree.files(false, 0) {
let absolute_path = worktree.absolutize(&file.path);
- dbg!(&absolute_path);
- if let Some(content) = fs.load(&absolute_path).await.log_err() {
- dbg!(&content);
- paths_tx.try_send((0, absolute_path, content)).unwrap();
+
+ if let Ok(language) = language_registry
+ .language_for_file(&absolute_path, None)
+ .await
+ {
+ if language.name().as_ref() != "Rust" {
+ continue;
+ }
+
+ if let Some(content) = fs.load(&absolute_path).await.log_err() {
+ log::info!("loaded file: {absolute_path:?}");
+
+ let path_buf = file.path.to_path_buf();
+ let already_stored = file_hashes
+ .get(&path_buf)
+ .map_or(false, |existing_hash| {
+ existing_hash.equals(&content)
+ });
+
+ if !already_stored {
+ log::info!(
+ "File Changed (Sending to Parse): {:?}",
+ &path_buf
+ );
+ paths_tx
+ .try_send((
+ worktree_id,
+ path_buf,
+ content,
+ language,
+ ))
+ .unwrap();
+ }
+ }
}
}
}
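
The re-index decision buried in the loop above, in isolation: a file is parsed again only when no hash is stored for its path or the stored hash no longer matches the loaded content. A sketch using the `FileSha1` type from db.rs (its `equals` takes `&String` in this commit):

```rust
use std::collections::HashMap;
use std::path::PathBuf;

fn needs_reindex(
    file_hashes: &HashMap<PathBuf, FileSha1>,
    path: &PathBuf,
    content: &String,
) -> bool {
    // Missing entry, or stale hash: either way, send the file to the parser.
    !file_hashes
        .get(path)
        .map_or(false, |existing| existing.equals(content))
}
```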
@@ -230,8 +265,8 @@ impl VectorStore {
let db_write_task = cx.background().spawn(
async move {
// Initialize Database, creates database and tables if not exists
- while let Ok(indexed_file) = indexed_files_rx.recv().await {
- db.insert_file(indexed_file).log_err();
+ while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
+ db.insert_file(worktree_id, indexed_file).log_err();
}
// ALL OF THE BELOW IS FOR TESTING,
@@ -271,29 +306,29 @@ impl VectorStore {
.log_err(),
);
- let provider = DummyEmbeddings {};
- // let provider = OpenAIEmbeddings { client };
-
cx.background()
.scoped(|scope| {
for _ in 0..cx.background().num_cpus() {
scope.spawn(async {
let mut parser = Parser::new();
let mut cursor = QueryCursor::new();
- while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await
+ while let Ok((worktree_id, file_path, content, language)) =
+ paths_rx.recv().await
{
if let Some(indexed_file) = Self::index_file(
&mut cursor,
&mut parser,
- &provider,
- &language_registry,
+ embedding_provider.as_ref(),
+ language,
file_path,
content,
)
.await
.log_err()
{
- indexed_files_tx.try_send(indexed_file).unwrap();
+ indexed_files_tx
+ .try_send((worktree_id, indexed_file))
+ .unwrap();
}
}
});
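
The scoped-worker pattern above relies on `smol::channel` being MPMC: each worker clones the receiver and drains it until every sender drops. A standalone sketch of just that fan-out:

```rust
use smol::channel;

fn main() {
    smol::block_on(async {
        let (tx, rx) = channel::unbounded::<u32>();
        let workers: Vec<_> = (0..4)
            .map(|_| {
                let rx = rx.clone();
                smol::spawn(async move {
                    // recv() returns Err once the channel is closed and drained.
                    while let Ok(item) = rx.recv().await {
                        println!("worker got {item}");
                    }
                })
            })
            .collect();
        for i in 0..16 {
            tx.send(i).await.unwrap();
        }
        drop(tx); // close the channel so the workers stop
        for worker in workers {
            worker.await;
        }
    });
}
```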
@@ -315,41 +350,42 @@ impl VectorStore {
) -> Task<Result<Vec<SearchResult>>> {
let embedding_provider = self.embedding_provider.clone();
let database_url = self.database_url.clone();
- cx.spawn(|this, cx| async move {
+ cx.background().spawn(async move {
let database = VectorDatabase::new(database_url.as_ref())?;
- // let embedding = embedding_provider.embed_batch(vec![&phrase]).await?;
- //
- let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+ let phrase_embedding = embedding_provider
+ .embed_batch(vec![&phrase])
+ .await?
+ .into_iter()
+ .next()
+ .unwrap();
+ let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
database.for_each_document(0, |id, embedding| {
- dbg!(id, &embedding);
-
- let similarity = dot(&embedding.0, &embedding.0);
+ let similarity = dot(&embedding.0, &phrase_embedding);
let ix = match results.binary_search_by(|(_, s)| {
- s.partial_cmp(&similarity).unwrap_or(Ordering::Equal)
+ similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};
-
results.insert(ix, (id, similarity));
results.truncate(limit);
})?;
- dbg!(&results);
-
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
- // let documents = database.get_documents_by_ids(ids)?;
-
- // let search_provider = cx
- // .background()
- // .spawn(async move { BruteForceSearch::load(&database) })
- // .await?;
-
- // let results = search_provider.top_k_search(&embedding, limit))
-
- anyhow::Ok(vec![])
+ let documents = database.get_documents_by_ids(&ids)?;
+
+ anyhow::Ok(
+ documents
+ .into_iter()
+ .map(|(file_path, offset, name)| SearchResult {
+ name,
+ offset,
+ file_path,
+ })
+ .collect(),
+ )
})
}
}
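
The search loop above maintains a bounded top-k list. Flipping the comparator to `similarity.partial_cmp(&s)` (target versus element) makes `binary_search_by` treat the descending-sorted vector as ascending, so each insert lands in rank order and `truncate(limit)` keeps only the best matches. The bookkeeping in isolation:

```rust
use std::cmp::Ordering;

fn push_top_k(results: &mut Vec<(i64, f32)>, id: i64, similarity: f32, limit: usize) {
    // `results` stays sorted by descending similarity; the reversed
    // comparator is what lets binary_search_by find the slot.
    let ix = match results
        .binary_search_by(|&(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
    {
        Ok(ix) | Err(ix) => ix,
    };
    results.insert(ix, (id, similarity));
    results.truncate(limit);
}
```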
crates/vector_store/src/vector_store_tests.rs
@@ -57,20 +57,26 @@ async fn test_vector_store(cx: &mut TestAppContext) {
);
languages.add(rust_language);
+ let db_dir = tempdir::TempDir::new("vector-store").unwrap();
+ let db_path = db_dir.path().join("db.sqlite");
+
let store = cx.add_model(|_| {
VectorStore::new(
fs.clone(),
- "foo".to_string(),
+ db_path.to_string_lossy().to_string(),
Arc::new(FakeEmbeddingProvider),
languages,
)
});
let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
- store
- .update(cx, |store, cx| store.add_project(project, cx))
- .await
- .unwrap();
+ let add_project = store.update(cx, |store, cx| store.add_project(project, cx));
+
+ // TODO - remove
+ cx.foreground()
+ .advance_clock(std::time::Duration::from_secs(3));
+
+ add_project.await.unwrap();
let search_results = store
.update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
@@ -78,7 +84,7 @@ async fn test_vector_store(cx: &mut TestAppContext) {
.unwrap();
assert_eq!(search_results[0].offset, 0);
- assert_eq!(search_results[1].name, "aaa");
+ assert_eq!(search_results[0].name, "aaa");
}
#[test]
@@ -114,9 +120,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
Ok(spans
.iter()
.map(|span| {
- let mut result = vec![0.0; 26];
+ let mut result = vec![1.0; 26];
for letter in span.chars() {
- if letter as u32 > 'a' as u32 {
+ let letter = letter.to_ascii_lowercase();
+ if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
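
Why the test changes above hold together: the fake provider now embeds letter frequencies over a baseline of 1.0 and lowercases its input, and ranking uses a dot product. A sketch showing the corrected assertion's logic (`dot` is assumed here to be a plain inner product; its definition is outside this excerpt):

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn fake_embed(span: &str) -> Vec<f32> {
    // Letter-frequency embedding over a baseline of 1.0, as in the hunk above.
    let mut result = vec![1.0; 26];
    for letter in span.chars() {
        let letter = letter.to_ascii_lowercase();
        if letter.is_ascii_lowercase() {
            result[(letter as u32 - 'a' as u32) as usize] += 1.0;
        }
    }
    result
}

fn main() {
    let query = fake_embed("aaaa");
    // "aaa" shares letters with the query, so it outranks "zzz" --
    // hence search_results[0].name == "aaa" in the fixed assertion.
    assert!(dot(&fake_embed("aaa"), &query) > dot(&fake_embed("zzz"), &query));
}
```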