Detailed changes
@@ -7958,13 +7958,18 @@ dependencies = [
"language",
"lazy_static",
"log",
+ "matrixmultiply",
"ndarray",
"project",
+ "rand 0.8.5",
"rusqlite",
"serde",
"serde_json",
+ "sha-1 0.10.1",
"smol",
"tree-sitter",
+ "tree-sitter-rust",
+ "unindent",
"util",
"workspace",
]
@@ -27,9 +27,14 @@ serde_json.workspace = true
async-trait.workspace = true
bincode = "1.3.3"
ndarray = "0.15.6"
+sha-1 = "0.10.1"
+matrixmultiply = "0.3.7"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
workspace = { path = "../workspace", features = ["test-support"] }
+tree-sitter-rust = "*"
+rand.workspace = true
+unindent.workspace = true
@@ -1,4 +1,7 @@
-use std::{collections::HashMap, path::PathBuf};
+use std::{
+ collections::HashMap,
+ path::{Path, PathBuf},
+};
use anyhow::{anyhow, Result};
@@ -13,7 +16,7 @@ use crate::IndexedFile;
// This is saving to a local database store within the user's dev Zed path
// Where do we want this to sit?
// Assuming near where the workspace DB sits.
-const VECTOR_DB_URL: &str = "embeddings_db";
+pub const VECTOR_DB_URL: &str = "embeddings_db";
// Note this is not an appropriate document
#[derive(Debug)]
@@ -28,7 +31,7 @@ pub struct DocumentRecord {
#[derive(Debug)]
pub struct FileRecord {
pub id: usize,
- pub path: String,
+ pub relative_path: String,
pub sha1: String,
}
@@ -51,9 +54,9 @@ pub struct VectorDatabase {
}
impl VectorDatabase {
- pub fn new() -> Result<Self> {
+ pub fn new(path: &str) -> Result<Self> {
let this = Self {
- db: rusqlite::Connection::open(VECTOR_DB_URL)?,
+ db: rusqlite::Connection::open(path)?,
};
this.initialize_database()?;
Ok(this)
@@ -63,21 +66,23 @@ impl VectorDatabase {
// This will create the database if it doesn't exist
// Initialize the vector database tables
- // self.db.execute(
- // "
- // CREATE TABLE IF NOT EXISTS projects (
- // id INTEGER PRIMARY KEY AUTOINCREMENT,
- // path NVARCHAR(100) NOT NULL
- // )
- // ",
- // [],
- // )?;
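+ // Schema: one row per indexed worktree; files reference their worktree and documents reference their file, each with ON DELETE CASCADE.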
+ self.db.execute_batch(
+ "CREATE TABLE IF NOT EXISTS worktrees (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ absolute_path VARCHAR NOT NULL
+ );
+ CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
+ ",
+ )?;
self.db.execute(
"CREATE TABLE IF NOT EXISTS files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
- path NVARCHAR(100) NOT NULL,
- sha1 NVARCHAR(40) NOT NULL
+ worktree_id INTEGER NOT NULL,
+ relative_path VARCHAR NOT NULL,
+ sha1 NVARCHAR(40) NOT NULL,
+ FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
)",
[],
)?;
@@ -87,7 +92,7 @@ impl VectorDatabase {
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL,
offset INTEGER NOT NULL,
- name NVARCHAR(100) NOT NULL,
+ name VARCHAR NOT NULL,
embedding BLOB NOT NULL,
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
)",
@@ -116,7 +121,7 @@ impl VectorDatabase {
pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
// Write to files table, and return generated id.
let files_insert = self.db.execute(
- "INSERT INTO files (path, sha1) VALUES (?1, ?2)",
+ "INSERT INTO files (relative_path, sha1) VALUES (?1, ?2)",
params![indexed_file.path.to_str(), indexed_file.sha1],
)?;
@@ -141,12 +146,38 @@ impl VectorDatabase {
Ok(())
}
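+ // Inserts `worktree_root_path` into the worktrees table if it is not already present and returns the connection's last inserted row id.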
+ pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
+ self.db.execute(
+ "
+ INSERT into worktrees (absolute_path) VALUES (?1)
+ ON CONFLICT DO NOTHING
+ ",
+ params![worktree_root_path.to_string_lossy()],
+ )?;
+ Ok(self.db.last_insert_rowid())
+ }
+
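+ // Returns every (relative_path, sha1) pair currently stored in the files table, ordered by relative path.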
+ pub fn get_file_hashes(&self, worktree_id: i64) -> Result<Vec<(PathBuf, String)>> {
+ let mut statement = self
+ .db
+ .prepare("SELECT relative_path, sha1 FROM files ORDER BY relative_path")?;
+ let mut result = Vec::new();
+ for row in
+ statement.query_map([], |row| Ok((row.get::<_, String>(0)?.into(), row.get(1)?)))?
+ {
+ result.push(row?);
+ }
+ Ok(result)
+ }
+
pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
- let mut query_statement = self.db.prepare("SELECT id, path, sha1 FROM files")?;
+ let mut query_statement = self
+ .db
+ .prepare("SELECT id, relative_path, sha1 FROM files")?;
let result_iter = query_statement.query_map([], |row| {
Ok(FileRecord {
id: row.get(0)?,
- path: row.get(1)?,
+ relative_path: row.get(1)?,
sha1: row.get(2)?,
})
})?;
@@ -160,6 +191,19 @@ impl VectorDatabase {
Ok(pages)
}
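+ // Streams each (id, embedding) row of the documents table to `f` without collecting them all into memory first.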
+ pub fn for_each_document(
+ &self,
+ worktree_id: i64,
+ mut f: impl FnMut(i64, Embedding),
+ ) -> Result<()> {
+ let mut query_statement = self.db.prepare("SELECT id, embedding FROM documents")?;
+ query_statement
+ .query_map(params![], |row| Ok((row.get(0)?, row.get(1)?)))?
+ .filter_map(|row| row.ok())
+ .for_each(|row| f(row.0, row.1));
+ Ok(())
+ }
+
pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
let mut query_statement = self
.db
@@ -44,7 +44,7 @@ struct OpenAIEmbeddingUsage {
}
#[async_trait]
-pub trait EmbeddingProvider: Sync {
+pub trait EmbeddingProvider: Sync + Send {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}
@@ -1,4 +1,4 @@
-use std::cmp::Ordering;
+use std::{cmp::Ordering, path::PathBuf};
use async_trait::async_trait;
use ndarray::{Array1, Array2};
@@ -20,7 +20,6 @@ pub struct BruteForceSearch {
impl BruteForceSearch {
pub fn load(db: &VectorDatabase) -> Result<Self> {
- // let db = VectorDatabase {};
let documents = db.get_documents()?;
let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
let mut document_ids = vec![];
@@ -63,20 +62,5 @@ impl VectorSearch for BruteForceSearch {
with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
with_indices.truncate(limit);
with_indices
-
- // // extract the sorted indices from the sorted tuple vector
- // let stored_indices = with_indices
- // .into_iter()
- // .map(|(index, value)| index)
- // .collect::<Vec<>>();
-
- // let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
-
- // let mut results = vec![];
- // for idx in sorted_indices[0..limit].to_vec() {
- // results.push((self.document_ids[idx], 1.0 - similarities[idx]));
- // }
-
- // return results;
}
}
@@ -3,16 +3,19 @@ mod embedding;
mod parsing;
mod search;
+#[cfg(test)]
+mod vector_store_tests;
+
use anyhow::{anyhow, Result};
-use db::VectorDatabase;
+use db::{VectorDatabase, VECTOR_DB_URL};
use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
-use gpui::{AppContext, Entity, ModelContext, ModelHandle};
+use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
use language::LanguageRegistry;
use parsing::Document;
use project::{Fs, Project};
use search::{BruteForceSearch, VectorSearch};
use smol::channel;
-use std::{path::PathBuf, sync::Arc, time::Instant};
+use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant};
use tree_sitter::{Parser, QueryCursor};
use util::{http::HttpClient, ResultExt, TryFutureExt};
use workspace::WorkspaceCreated;
@@ -23,7 +26,16 @@ pub fn init(
language_registry: Arc<LanguageRegistry>,
cx: &mut AppContext,
) {
- let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
+ let vector_store = cx.add_model(|cx| {
+ VectorStore::new(
+ fs,
+ VECTOR_DB_URL.to_string(),
+ Arc::new(OpenAIEmbeddings {
+ client: http_client,
+ }),
+ language_registry,
+ )
+ });
cx.subscribe_global::<WorkspaceCreated, _>({
let vector_store = vector_store.clone();
@@ -49,28 +61,36 @@ pub struct IndexedFile {
documents: Vec<Document>,
}
-struct SearchResult {
- path: PathBuf,
- offset: usize,
- name: String,
- distance: f32,
-}
-
+// struct SearchResult {
+// path: PathBuf,
+// offset: usize,
+// name: String,
+// distance: f32,
+// }
struct VectorStore {
fs: Arc<dyn Fs>,
- http_client: Arc<dyn HttpClient>,
+ database_url: Arc<str>,
+ embedding_provider: Arc<dyn EmbeddingProvider>,
language_registry: Arc<LanguageRegistry>,
}
+pub struct SearchResult {
+ pub name: String,
+ pub offset: usize,
+ pub file_path: PathBuf,
+}
+
impl VectorStore {
fn new(
fs: Arc<dyn Fs>,
- http_client: Arc<dyn HttpClient>,
+ database_url: String,
+ embedding_provider: Arc<dyn EmbeddingProvider>,
language_registry: Arc<LanguageRegistry>,
) -> Self {
Self {
fs,
- http_client,
+ database_url: database_url.into(),
+ embedding_provider,
language_registry,
}
}
@@ -79,10 +99,12 @@ impl VectorStore {
cursor: &mut QueryCursor,
parser: &mut Parser,
embedding_provider: &dyn EmbeddingProvider,
- fs: &Arc<dyn Fs>,
language_registry: &Arc<LanguageRegistry>,
file_path: PathBuf,
+ content: String,
) -> Result<IndexedFile> {
+ dbg!(&file_path, &content);
+
let language = language_registry
.language_for_file(&file_path, None)
.await?;
@@ -97,7 +119,6 @@ impl VectorStore {
.as_ref()
.ok_or_else(|| anyhow!("no outline query"))?;
- let content = fs.load(&file_path).await?;
parser.set_language(grammar.ts_language).unwrap();
let tree = parser
.parse(&content, None)
@@ -142,7 +163,11 @@ impl VectorStore {
});
}
- fn add_project(&mut self, project: ModelHandle<Project>, cx: &mut ModelContext<Self>) {
+ fn add_project(
+ &mut self,
+ project: ModelHandle<Project>,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<()>> {
let worktree_scans_complete = project
.read(cx)
.worktrees(cx)
@@ -151,7 +176,8 @@ impl VectorStore {
let fs = self.fs.clone();
let language_registry = self.language_registry.clone();
- let client = self.http_client.clone();
+ let embedding_provider = self.embedding_provider.clone();
+ let database_url = self.database_url.clone();
cx.spawn(|_, cx| async move {
futures::future::join_all(worktree_scans_complete).await;
@@ -163,24 +189,47 @@ impl VectorStore {
.collect::<Vec<_>>()
});
- let (paths_tx, paths_rx) = channel::unbounded::<PathBuf>();
- let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
- cx.background()
+ let db = VectorDatabase::new(&database_url)?;
+ let worktree_root_paths = worktrees
+ .iter()
+ .map(|worktree| worktree.abs_path().clone())
+ .collect::<Vec<_>>();
+ let (db, file_hashes) = cx
+ .background()
.spawn(async move {
- for worktree in worktrees {
- for file in worktree.files(false, 0) {
- paths_tx.try_send(worktree.absolutize(&file.path)).unwrap();
- }
+ let mut hashes = Vec::new();
+ for worktree_root_path in worktree_root_paths {
+ let worktree_id =
+ db.find_or_create_worktree(worktree_root_path.as_ref())?;
+ hashes.push((worktree_id, db.get_file_hashes(worktree_id)?));
}
+ anyhow::Ok((db, hashes))
})
- .detach();
+ .await?;
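+ // Pipeline the indexing work over channels: one background task walks the worktrees and loads file contents, a pool of tasks parses and embeds them, and a single task writes the resulting `IndexedFile`s to the database.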
+ let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>();
+ let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
cx.background()
.spawn({
- let client = client.clone();
+ let fs = fs.clone();
async move {
+ for worktree in worktrees.into_iter() {
+ for file in worktree.files(false, 0) {
+ let absolute_path = worktree.absolutize(&file.path);
+ dbg!(&absolute_path);
+ if let Some(content) = fs.load(&absolute_path).await.log_err() {
+ dbg!(&content);
+ paths_tx.try_send((0, absolute_path, content)).unwrap();
+ }
+ }
+ }
+ }
+ })
+ .detach();
+
+ let db_write_task = cx.background().spawn(
+ async move {
// Initialize the database, creating it and its tables if they do not exist
- let db = VectorDatabase::new()?;
while let Ok(indexed_file) = indexed_files_rx.recv().await {
db.insert_file(indexed_file).log_err();
}
@@ -188,39 +237,39 @@ impl VectorStore {
// ALL OF THE BELOW IS FOR TESTING,
// This should be removed once we find an appropriate place to evaluate our search.
- let embedding_provider = OpenAIEmbeddings{ client };
- let queries = vec![
- "compute embeddings for all of the symbols in the codebase, and write them to a database",
- "compute an outline view of all of the symbols in a buffer",
- "scan a directory on the file system and load all of its children into an in-memory snapshot",
- ];
- let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
-
- let t2 = Instant::now();
- let documents = db.get_documents().unwrap();
- let files = db.get_files().unwrap();
- println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
-
- let t1 = Instant::now();
- let mut bfs = BruteForceSearch::load(&db).unwrap();
- println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
- for (idx, embed) in embeddings.into_iter().enumerate() {
- let t0 = Instant::now();
- println!("\nQuery: {:?}", queries[idx]);
- let results = bfs.top_k_search(&embed, 5).await;
- println!("Search Elapsed: {}", t0.elapsed().as_millis());
- for (id, distance) in results {
- println!("");
- println!(" distance: {:?}", distance);
- println!(" document: {:?}", documents[&id].name);
- println!(" path: {:?}", files[&documents[&id].file_id].path);
- }
+ // let queries = vec![
+ // "compute embeddings for all of the symbols in the codebase, and write them to a database",
+ // "compute an outline view of all of the symbols in a buffer",
+ // "scan a directory on the file system and load all of its children into an in-memory snapshot",
+ // ];
+ // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
- }
+ // let t2 = Instant::now();
+ // let documents = db.get_documents().unwrap();
+ // let files = db.get_files().unwrap();
+ // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
+
+ // let t1 = Instant::now();
+ // let mut bfs = BruteForceSearch::load(&db).unwrap();
+ // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
+ // for (idx, embed) in embeddings.into_iter().enumerate() {
+ // let t0 = Instant::now();
+ // println!("\nQuery: {:?}", queries[idx]);
+ // let results = bfs.top_k_search(&embed, 5).await;
+ // println!("Search Elapsed: {}", t0.elapsed().as_millis());
+ // for (id, distance) in results {
+ // println!("");
+ // println!(" distance: {:?}", distance);
+ // println!(" document: {:?}", documents[&id].name);
+ // println!(" path: {:?}", files[&documents[&id].file_id].relative_path);
+ // }
+
+ // }
anyhow::Ok(())
- }}.log_err())
- .detach();
+ }
+ .log_err(),
+ );
let provider = DummyEmbeddings {};
// let provider = OpenAIEmbeddings { client };
@@ -231,14 +280,15 @@ impl VectorStore {
scope.spawn(async {
let mut parser = Parser::new();
let mut cursor = QueryCursor::new();
- while let Ok(file_path) = paths_rx.recv().await {
+ while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await
+ {
if let Some(indexed_file) = Self::index_file(
&mut cursor,
&mut parser,
&provider,
- &fs,
&language_registry,
file_path,
+ content,
)
.await
.log_err()
@@ -250,11 +300,86 @@ impl VectorStore {
}
})
.await;
+ drop(indexed_files_tx);
+
+ db_write_task.await;
+ anyhow::Ok(())
+ })
+ }
+
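+ // Work-in-progress semantic search: walks every stored document embedding, keeps at most `limit` (id, similarity) pairs, and will eventually map them back to `SearchResult`s.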
+ pub fn search(
+ &mut self,
+ phrase: String,
+ limit: usize,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<Vec<SearchResult>>> {
+ let embedding_provider = self.embedding_provider.clone();
+ let database_url = self.database_url.clone();
+ cx.spawn(|this, cx| async move {
+ let database = VectorDatabase::new(database_url.as_ref())?;
+
+ // let embedding = embedding_provider.embed_batch(vec![&phrase]).await?;
+ //
+ let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+
+ database.for_each_document(0, |id, embedding| {
+ dbg!(id, &embedding);
+
+ let similarity = dot(&embedding.0, &embedding.0);
+ let ix = match results.binary_search_by(|(_, s)| {
+ s.partial_cmp(&similarity).unwrap_or(Ordering::Equal)
+ }) {
+ Ok(ix) => ix,
+ Err(ix) => ix,
+ };
+
+ results.insert(ix, (id, similarity));
+ results.truncate(limit);
+ })?;
+
+ dbg!(&results);
+
+ let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
+ // let documents = database.get_documents_by_ids(ids)?;
+
+ // let search_provider = cx
+ // .background()
+ // .spawn(async move { BruteForceSearch::load(&database) })
+ // .await?;
+
+ // let results = search_provider.top_k_search(&embedding, limit))
+
+ anyhow::Ok(vec![])
})
- .detach();
}
}
impl Entity for VectorStore {
type Event = ();
}
+
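+// Dot product of two equal-length f32 slices, computed by treating `vec_a` as a 1×len matrix and `vec_b` as a len×1 matrix so that `matrixmultiply::sgemm` writes the 1×1 result directly into `result`.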
+fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
+ let len = vec_a.len();
+ assert_eq!(len, vec_b.len());
+
+ let mut result = 0.0;
+ unsafe {
+ matrixmultiply::sgemm(
+ 1,
+ len,
+ 1,
+ 1.0,
+ vec_a.as_ptr(),
+ len as isize,
+ 1,
+ vec_b.as_ptr(),
+ 1,
+ len as isize,
+ 0.0,
+ &mut result as *mut f32,
+ 1,
+ 1,
+ );
+ }
+ result
+}
@@ -0,0 +1,136 @@
+use std::sync::Arc;
+
+use crate::{dot, embedding::EmbeddingProvider, VectorStore};
+use anyhow::Result;
+use async_trait::async_trait;
+use gpui::{Task, TestAppContext};
+use language::{Language, LanguageConfig, LanguageRegistry};
+use project::{FakeFs, Project};
+use rand::Rng;
+use serde_json::json;
+use unindent::Unindent;
+
+#[gpui::test]
+async fn test_vector_store(cx: &mut TestAppContext) {
+ let fs = FakeFs::new(cx.background());
+ fs.insert_tree(
+ "/the-root",
+ json!({
+ "src": {
+ "file1.rs": "
+ fn aaa() {
+ println!(\"aaaa!\");
+ }
+
+ fn zzzzzzzzz() {
+ println!(\"SLEEPING\");
+ }
+ ".unindent(),
+ "file2.rs": "
+ fn bbb() {
+ println!(\"bbbb!\");
+ }
+ ".unindent(),
+ }
+ }),
+ )
+ .await;
+
+ let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
+ let rust_language = Arc::new(
+ Language::new(
+ LanguageConfig {
+ name: "Rust".into(),
+ path_suffixes: vec!["rs".into()],
+ ..Default::default()
+ },
+ Some(tree_sitter_rust::language()),
+ )
+ .with_outline_query(
+ r#"
+ (function_item
+ name: (identifier) @name
+ body: (block)) @item
+ "#,
+ )
+ .unwrap(),
+ );
+ languages.add(rust_language);
+
+ let store = cx.add_model(|_| {
+ VectorStore::new(
+ fs.clone(),
+ "foo".to_string(),
+ Arc::new(FakeEmbeddingProvider),
+ languages,
+ )
+ });
+
+ let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
+ store
+ .update(cx, |store, cx| store.add_project(project, cx))
+ .await
+ .unwrap();
+
+ let search_results = store
+ .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
+ .await
+ .unwrap();
+
+ assert_eq!(search_results[0].offset, 0);
+ assert_eq!(search_results[1].name, "aaa");
+}
+
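+// Checks the sgemm-based `dot` against a straightforward iterator implementation on random vectors.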
+#[test]
+fn test_dot_product() {
+ assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
+ assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
+
+ for _ in 0..100 {
+ let mut rng = rand::thread_rng();
+ let a: [f32; 32] = rng.gen();
+ let b: [f32; 32] = rng.gen();
+ assert_eq!(
+ round_to_decimals(dot(&a, &b), 3),
+ round_to_decimals(reference_dot(&a, &b), 3)
+ );
+ }
+
+ fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
+ let factor = (10.0 as f32).powi(decimal_places);
+ (n * factor).round() / factor
+ }
+
+ fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
+ a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
+ }
+}
+
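+// A deterministic embedding provider for tests: each span becomes a length-26 letter-frequency histogram over lowercase ASCII, normalized to unit length.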
+struct FakeEmbeddingProvider;
+
+#[async_trait]
+impl EmbeddingProvider for FakeEmbeddingProvider {
+ async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+ Ok(spans
+ .iter()
+ .map(|span| {
+ let mut result = vec![0.0; 26];
+ for letter in span.chars() {
+ if letter as u32 >= 'a' as u32 {
+ let ix = (letter as u32) - ('a' as u32);
+ if ix < 26 {
+ result[ix as usize] += 1.0;
+ }
+ }
+ }
+
+ let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+ for x in &mut result {
+ *x /= norm;
+ }
+
+ result
+ })
+ .collect())
+ }
+}