Detailed changes
@@ -3539,7 +3539,7 @@ dependencies = [
"gif",
"jpeg-decoder",
"num-iter",
- "num-rational",
+ "num-rational 0.3.2",
"num-traits",
"png",
"scoped_threadpool",
@@ -4631,6 +4631,31 @@ dependencies = [
"winapi 0.3.9",
]
+[[package]]
+name = "num"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
+dependencies = [
+ "num-bigint 0.2.6",
+ "num-complex",
+ "num-integer",
+ "num-iter",
+ "num-rational 0.2.4",
+ "num-traits",
+]
+
+[[package]]
+name = "num-bigint"
+version = "0.2.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304"
+dependencies = [
+ "autocfg",
+ "num-integer",
+ "num-traits",
+]
+
[[package]]
name = "num-bigint"
version = "0.4.4"
@@ -4659,6 +4684,16 @@ dependencies = [
"zeroize",
]
+[[package]]
+name = "num-complex"
+version = "0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
+dependencies = [
+ "autocfg",
+ "num-traits",
+]
+
[[package]]
name = "num-derive"
version = "0.3.3"
@@ -4691,6 +4726,18 @@ dependencies = [
"num-traits",
]
+[[package]]
+name = "num-rational"
+version = "0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef"
+dependencies = [
+ "autocfg",
+ "num-bigint 0.2.6",
+ "num-integer",
+ "num-traits",
+]
+
[[package]]
name = "num-rational"
version = "0.3.2"
@@ -5007,6 +5054,17 @@ dependencies = [
"windows-targets 0.48.5",
]
+[[package]]
+name = "parse_duration"
+version = "2.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7037e5e93e0172a5a96874380bf73bc6ecef022e26fa25f2be26864d6b3ba95d"
+dependencies = [
+ "lazy_static",
+ "num",
+ "regex",
+]
+
[[package]]
name = "password-hash"
version = "0.2.3"
@@ -6674,6 +6732,7 @@ dependencies = [
"log",
"matrixmultiply",
"parking_lot 0.11.2",
+ "parse_duration",
"picker",
"postage",
"pretty_assertions",
@@ -7005,7 +7064,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eb4ea60fb301dc81dfc113df680571045d375ab7345d171c5dc7d7e13107a80"
dependencies = [
"chrono",
- "num-bigint",
+ "num-bigint 0.4.4",
"num-traits",
"thiserror",
]
@@ -7237,7 +7296,7 @@ dependencies = [
"log",
"md-5",
"memchr",
- "num-bigint",
+ "num-bigint 0.4.4",
"once_cell",
"paste",
"percent-encoding",
@@ -39,6 +39,7 @@ rand.workspace = true
schemars.workspace = true
globset.workspace = true
sha1 = "0.10.5"
+parse_duration = "2.1.1"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }
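Note on the new parse_duration dependency: it is used further down (in embedding.rs) to turn OpenAI's x-ratelimit-reset-tokens header into a backoff delay. A minimal sketch of that pattern, assuming a header value such as "6m13s" (the exact header format is an assumption, not shown in this diff):

use parse_duration::parse;
use std::time::Duration;

fn backoff_delay(header_value: Option<&str>, fallback: Duration) -> Duration {
    // Prefer the server-provided reset time; fall back to a fixed backoff otherwise.
    header_value
        .and_then(|value| parse(value).ok())
        .unwrap_or(fallback)
}

fn main() {
    // Hypothetical header value describing when the token budget resets.
    let delay = backoff_delay(Some("6m13s"), Duration::from_secs(3));
    println!("waiting {delay:?} before retrying");
}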
@@ -1,20 +1,26 @@
-use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
+use crate::{
+ embedding::Embedding,
+ parsing::{Document, DocumentDigest},
+ SEMANTIC_INDEX_VERSION,
+};
use anyhow::{anyhow, Context, Result};
+use futures::channel::oneshot;
+use gpui::executor;
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
-use rusqlite::{
- params,
- types::{FromSql, FromSqlResult, ValueRef},
-};
+use rusqlite::params;
+use rusqlite::types::Value;
use std::{
cmp::Ordering,
collections::HashMap,
+ future::Future,
ops::Range,
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
- time::SystemTime,
+ time::{Instant, SystemTime},
};
+use util::TryFutureExt;
#[derive(Debug)]
pub struct FileRecord {
@@ -23,145 +29,181 @@ pub struct FileRecord {
pub mtime: Timestamp,
}
-#[derive(Debug)]
-struct Embedding(pub Vec<f32>);
-
-#[derive(Debug)]
-struct Sha1(pub Vec<u8>);
-
-impl FromSql for Embedding {
- fn column_result(value: ValueRef) -> FromSqlResult<Self> {
- let bytes = value.as_blob()?;
- let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
- if embedding.is_err() {
- return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
- }
- return Ok(Embedding(embedding.unwrap()));
- }
-}
-
-impl FromSql for Sha1 {
- fn column_result(value: ValueRef) -> FromSqlResult<Self> {
- let bytes = value.as_blob()?;
- let sha1: Result<Vec<u8>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
- if sha1.is_err() {
- return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err()));
- }
- return Ok(Sha1(sha1.unwrap()));
- }
-}
-
+#[derive(Clone)]
pub struct VectorDatabase {
- db: rusqlite::Connection,
+ path: Arc<Path>,
+ transactions:
+ smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
}
impl VectorDatabase {
- pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
+ pub async fn new(
+ fs: Arc<dyn Fs>,
+ path: Arc<Path>,
+ executor: Arc<executor::Background>,
+ ) -> Result<Self> {
if let Some(db_directory) = path.parent() {
fs.create_dir(db_directory).await?;
}
+ let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
+ Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
+ >();
+ executor
+ .spawn({
+ let path = path.clone();
+ async move {
+ let mut connection = rusqlite::Connection::open(&path)?;
+
+ connection.pragma_update(None, "journal_mode", "wal")?;
+ connection.pragma_update(None, "synchronous", "normal")?;
+ connection.pragma_update(None, "cache_size", 1000000)?;
+ connection.pragma_update(None, "temp_store", "MEMORY")?;
+
+ while let Ok(transaction) = transactions_rx.recv().await {
+ transaction(&mut connection);
+ }
+
+ anyhow::Ok(())
+ }
+ .log_err()
+ })
+ .detach();
let this = Self {
- db: rusqlite::Connection::open(path.as_path())?,
+ transactions: transactions_tx,
+ path,
};
- this.initialize_database()?;
+ this.initialize_database().await?;
Ok(this)
}
- fn get_existing_version(&self) -> Result<i64> {
- let mut version_query = self
- .db
- .prepare("SELECT version from semantic_index_config")?;
- version_query
- .query_row([], |row| Ok(row.get::<_, i64>(0)?))
- .map_err(|err| anyhow!("version query failed: {err}"))
+ pub fn path(&self) -> &Arc<Path> {
+ &self.path
}
- fn initialize_database(&self) -> Result<()> {
- rusqlite::vtab::array::load_module(&self.db)?;
-
- // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
- if self
- .get_existing_version()
- .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
- {
- log::trace!("vector database schema up to date");
- return Ok(());
+ fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
+ where
+ F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
+ T: 'static + Send,
+ {
+ let (tx, rx) = oneshot::channel();
+ let transactions = self.transactions.clone();
+ async move {
+ if transactions
+ .send(Box::new(|connection| {
+ let result = connection
+ .transaction()
+ .map_err(|err| anyhow!(err))
+ .and_then(|transaction| {
+ let result = f(&transaction)?;
+ transaction.commit()?;
+ Ok(result)
+ });
+ let _ = tx.send(result);
+ }))
+ .await
+ .is_err()
+ {
+ return Err(anyhow!("connection was dropped"))?;
+ }
+ rx.await?
}
+ }
- log::trace!("vector database schema out of date. updating...");
- self.db
- .execute("DROP TABLE IF EXISTS documents", [])
- .context("failed to drop 'documents' table")?;
- self.db
- .execute("DROP TABLE IF EXISTS files", [])
- .context("failed to drop 'files' table")?;
- self.db
- .execute("DROP TABLE IF EXISTS worktrees", [])
- .context("failed to drop 'worktrees' table")?;
- self.db
- .execute("DROP TABLE IF EXISTS semantic_index_config", [])
- .context("failed to drop 'semantic_index_config' table")?;
-
- // Initialize Vector Databasing Tables
- self.db.execute(
- "CREATE TABLE semantic_index_config (
- version INTEGER NOT NULL
- )",
- [],
- )?;
+ fn initialize_database(&self) -> impl Future<Output = Result<()>> {
+ self.transact(|db| {
+ rusqlite::vtab::array::load_module(&db)?;
+
+ // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
+ let version_query = db.prepare("SELECT version from semantic_index_config");
+ let version = version_query
+ .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
+ if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
+ log::trace!("vector database schema up to date");
+ return Ok(());
+ }
- self.db.execute(
- "INSERT INTO semantic_index_config (version) VALUES (?1)",
- params![SEMANTIC_INDEX_VERSION],
- )?;
+ log::trace!("vector database schema out of date. updating...");
+ db.execute("DROP TABLE IF EXISTS documents", [])
+ .context("failed to drop 'documents' table")?;
+ db.execute("DROP TABLE IF EXISTS files", [])
+ .context("failed to drop 'files' table")?;
+ db.execute("DROP TABLE IF EXISTS worktrees", [])
+ .context("failed to drop 'worktrees' table")?;
+ db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
+ .context("failed to drop 'semantic_index_config' table")?;
+
+            // Initialize Vector Database Tables
+ db.execute(
+ "CREATE TABLE semantic_index_config (
+ version INTEGER NOT NULL
+ )",
+ [],
+ )?;
- self.db.execute(
- "CREATE TABLE worktrees (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- absolute_path VARCHAR NOT NULL
- );
- CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
- ",
- [],
- )?;
+ db.execute(
+ "INSERT INTO semantic_index_config (version) VALUES (?1)",
+ params![SEMANTIC_INDEX_VERSION],
+ )?;
- self.db.execute(
- "CREATE TABLE files (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- worktree_id INTEGER NOT NULL,
- relative_path VARCHAR NOT NULL,
- mtime_seconds INTEGER NOT NULL,
- mtime_nanos INTEGER NOT NULL,
- FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
- )",
- [],
- )?;
+ db.execute(
+ "CREATE TABLE worktrees (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ absolute_path VARCHAR NOT NULL
+ );
+ CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
+ ",
+ [],
+ )?;
- self.db.execute(
- "CREATE TABLE documents (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- file_id INTEGER NOT NULL,
- start_byte INTEGER NOT NULL,
- end_byte INTEGER NOT NULL,
- name VARCHAR NOT NULL,
- embedding BLOB NOT NULL,
- sha1 BLOB NOT NULL,
- FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
- )",
- [],
- )?;
+ db.execute(
+ "CREATE TABLE files (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ worktree_id INTEGER NOT NULL,
+ relative_path VARCHAR NOT NULL,
+ mtime_seconds INTEGER NOT NULL,
+ mtime_nanos INTEGER NOT NULL,
+ FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
+ )",
+ [],
+ )?;
- log::trace!("vector database initialized with updated schema.");
- Ok(())
+ db.execute(
+ "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
+ [],
+ )?;
+
+ db.execute(
+ "CREATE TABLE documents (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ file_id INTEGER NOT NULL,
+ start_byte INTEGER NOT NULL,
+ end_byte INTEGER NOT NULL,
+ name VARCHAR NOT NULL,
+ embedding BLOB NOT NULL,
+ digest BLOB NOT NULL,
+ FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
+ )",
+ [],
+ )?;
+
+ log::trace!("vector database initialized with updated schema.");
+ Ok(())
+ })
}
- pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
- self.db.execute(
- "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
- params![worktree_id, delete_path.to_str()],
- )?;
- Ok(())
+ pub fn delete_file(
+ &self,
+ worktree_id: i64,
+ delete_path: PathBuf,
+ ) -> impl Future<Output = Result<()>> {
+ self.transact(move |db| {
+ db.execute(
+ "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
+ params![worktree_id, delete_path.to_str()],
+ )?;
+ Ok(())
+ })
}
pub fn insert_file(
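The hunk above turns VectorDatabase from a struct owning a rusqlite::Connection into a cloneable handle that sends closures over a channel to a single background connection, so every public method now returns a future. A rough call-site sketch under that assumption (index_worktree is a hypothetical helper, not part of this diff):

async fn index_worktree(db: &VectorDatabase, root: std::path::PathBuf) -> anyhow::Result<()> {
    // Both calls enqueue work on the background connection and await the result.
    let worktree_id = db.find_or_create_worktree(root).await?;
    let mtimes = db.get_file_mtimes(worktree_id).await?;
    log::trace!("worktree {} has {} previously indexed files", worktree_id, mtimes.len());
    Ok(())
}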
@@ -170,139 +212,187 @@ impl VectorDatabase {
path: PathBuf,
mtime: SystemTime,
documents: Vec<Document>,
- ) -> Result<()> {
- // Return the existing ID, if both the file and mtime match
- let mtime = Timestamp::from(mtime);
- let mut existing_id_query = self.db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?;
- let existing_id = existing_id_query
- .query_row(
+ ) -> impl Future<Output = Result<()>> {
+ self.transact(move |db| {
+ // Return the existing ID, if both the file and mtime match
+ let mtime = Timestamp::from(mtime);
+
+ db.execute(
+ "
+ REPLACE INTO files
+ (worktree_id, relative_path, mtime_seconds, mtime_nanos)
+ VALUES (?1, ?2, ?3, ?4)
+ ",
params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
- |row| Ok(row.get::<_, i64>(0)?),
- )
- .map_err(|err| anyhow!(err));
- let file_id = if existing_id.is_ok() {
- // If already exists, just return the existing id
- existing_id.unwrap()
- } else {
- // Delete Existing Row
- self.db.execute(
- "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;",
- params![worktree_id, path.to_str()],
)?;
- self.db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?;
- self.db.last_insert_rowid()
- };
- // Currently inserting at approximately 3400 documents a second
- // I imagine we can speed this up with a bulk insert of some kind.
- for document in documents {
- let embedding_blob = bincode::serialize(&document.embedding)?;
- let sha_blob = bincode::serialize(&document.sha1)?;
+ let file_id = db.last_insert_rowid();
- self.db.execute(
- "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
- params![
+ let t0 = Instant::now();
+ let mut query = db.prepare(
+ "
+ INSERT INTO documents
+ (file_id, start_byte, end_byte, name, embedding, digest)
+ VALUES (?1, ?2, ?3, ?4, ?5, ?6)
+ ",
+ )?;
+ log::trace!(
+ "Preparing Query Took: {:?} milliseconds",
+ t0.elapsed().as_millis()
+ );
+
+ for document in documents {
+ query.execute(params![
file_id,
document.range.start.to_string(),
document.range.end.to_string(),
document.name,
- embedding_blob,
- sha_blob
- ],
- )?;
- }
+ document.embedding,
+ document.digest
+ ])?;
+ }
- Ok(())
+ Ok(())
+ })
}
- pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result<bool> {
- let mut worktree_query = self
- .db
- .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
- let worktree_id = worktree_query
- .query_row(params![worktree_root_path.to_string_lossy()], |row| {
- Ok(row.get::<_, i64>(0)?)
- })
- .map_err(|err| anyhow!(err));
-
- if worktree_id.is_ok() {
- return Ok(true);
- } else {
- return Ok(false);
- }
+ pub fn worktree_previously_indexed(
+ &self,
+ worktree_root_path: &Path,
+ ) -> impl Future<Output = Result<bool>> {
+ let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
+ self.transact(move |db| {
+ let mut worktree_query =
+ db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+ let worktree_id = worktree_query
+ .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
+
+ if worktree_id.is_ok() {
+ return Ok(true);
+ } else {
+ return Ok(false);
+ }
+ })
}
- pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
- // Check that the absolute path doesnt exist
- let mut worktree_query = self
- .db
- .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+ pub fn embeddings_for_files(
+ &self,
+ worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
+ ) -> impl Future<Output = Result<HashMap<DocumentDigest, Embedding>>> {
+ self.transact(move |db| {
+ let mut query = db.prepare(
+ "
+ SELECT digest, embedding
+ FROM documents
+ LEFT JOIN files ON files.id = documents.file_id
+ WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
+ ",
+ )?;
+ let mut embeddings_by_digest = HashMap::new();
+ for (worktree_id, file_paths) in worktree_id_file_paths {
+ let file_paths = Rc::new(
+ file_paths
+ .into_iter()
+ .map(|p| Value::Text(p.to_string_lossy().into_owned()))
+ .collect::<Vec<_>>(),
+ );
+ let rows = query.query_map(params![worktree_id, file_paths], |row| {
+ Ok((
+ row.get::<_, DocumentDigest>(0)?,
+ row.get::<_, Embedding>(1)?,
+ ))
+ })?;
+
+ for row in rows {
+ if let Ok(row) = row {
+ embeddings_by_digest.insert(row.0, row.1);
+ }
+ }
+ }
- let worktree_id = worktree_query
- .query_row(params![worktree_root_path.to_string_lossy()], |row| {
- Ok(row.get::<_, i64>(0)?)
- })
- .map_err(|err| anyhow!(err));
+ Ok(embeddings_by_digest)
+ })
+ }
- if worktree_id.is_ok() {
- return worktree_id;
- }
+ pub fn find_or_create_worktree(
+ &self,
+ worktree_root_path: PathBuf,
+ ) -> impl Future<Output = Result<i64>> {
+ self.transact(move |db| {
+ let mut worktree_query =
+ db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+ let worktree_id = worktree_query
+ .query_row(params![worktree_root_path.to_string_lossy()], |row| {
+ Ok(row.get::<_, i64>(0)?)
+ });
+
+ if worktree_id.is_ok() {
+ return Ok(worktree_id?);
+ }
- // If worktree_id is Err, insert new worktree
- self.db.execute(
- "
- INSERT into worktrees (absolute_path) VALUES (?1)
- ",
- params![worktree_root_path.to_string_lossy()],
- )?;
- Ok(self.db.last_insert_rowid())
+ // If worktree_id is Err, insert new worktree
+ db.execute(
+ "INSERT into worktrees (absolute_path) VALUES (?1)",
+ params![worktree_root_path.to_string_lossy()],
+ )?;
+ Ok(db.last_insert_rowid())
+ })
}
- pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
- let mut statement = self.db.prepare(
- "
- SELECT relative_path, mtime_seconds, mtime_nanos
- FROM files
- WHERE worktree_id = ?1
- ORDER BY relative_path",
- )?;
- let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
- for row in statement.query_map(params![worktree_id], |row| {
- Ok((
- row.get::<_, String>(0)?.into(),
- Timestamp {
- seconds: row.get(1)?,
- nanos: row.get(2)?,
- }
- .into(),
- ))
- })? {
- let row = row?;
- result.insert(row.0, row.1);
- }
- Ok(result)
+ pub fn get_file_mtimes(
+ &self,
+ worktree_id: i64,
+ ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
+ self.transact(move |db| {
+ let mut statement = db.prepare(
+ "
+ SELECT relative_path, mtime_seconds, mtime_nanos
+ FROM files
+ WHERE worktree_id = ?1
+ ORDER BY relative_path",
+ )?;
+ let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
+ for row in statement.query_map(params![worktree_id], |row| {
+ Ok((
+ row.get::<_, String>(0)?.into(),
+ Timestamp {
+ seconds: row.get(1)?,
+ nanos: row.get(2)?,
+ }
+ .into(),
+ ))
+ })? {
+ let row = row?;
+ result.insert(row.0, row.1);
+ }
+ Ok(result)
+ })
}
pub fn top_k_search(
&self,
- query_embedding: &Vec<f32>,
+ query_embedding: &Embedding,
limit: usize,
file_ids: &[i64],
- ) -> Result<Vec<(i64, f32)>> {
- let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
- self.for_each_document(file_ids, |id, embedding| {
- let similarity = dot(&embedding, &query_embedding);
- let ix = match results
- .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
- {
- Ok(ix) => ix,
- Err(ix) => ix,
- };
- results.insert(ix, (id, similarity));
- results.truncate(limit);
- })?;
-
- Ok(results)
+ ) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
+ let query_embedding = query_embedding.clone();
+ let file_ids = file_ids.to_vec();
+ self.transact(move |db| {
+ let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+ Self::for_each_document(db, &file_ids, |id, embedding| {
+ let similarity = embedding.similarity(&query_embedding);
+ let ix = match results.binary_search_by(|(_, s)| {
+ similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+ }) {
+ Ok(ix) => ix,
+ Err(ix) => ix,
+ };
+ results.insert(ix, (id, similarity));
+ results.truncate(limit);
+ })?;
+
+ anyhow::Ok(results)
+ })
}
pub fn retrieve_included_file_ids(
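Similarity search follows the same pattern: top_k_search now takes an Embedding and returns a future. A brief sketch of how a caller might consume it (the search helper itself is an assumption):

async fn search(db: &VectorDatabase, query: &Embedding, file_ids: &[i64]) -> anyhow::Result<()> {
    // Up to 10 (document id, similarity) pairs, ranked by dot-product similarity.
    for (document_id, similarity) in db.top_k_search(query, 10, file_ids).await? {
        log::trace!("document {} scored {}", document_id, similarity);
    }
    Ok(())
}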
@@ -310,37 +400,46 @@ impl VectorDatabase {
worktree_ids: &[i64],
includes: &[PathMatcher],
excludes: &[PathMatcher],
- ) -> Result<Vec<i64>> {
- let mut file_query = self.db.prepare(
- "
- SELECT
- id, relative_path
- FROM
- files
- WHERE
- worktree_id IN rarray(?)
- ",
- )?;
+ ) -> impl Future<Output = Result<Vec<i64>>> {
+ let worktree_ids = worktree_ids.to_vec();
+ let includes = includes.to_vec();
+ let excludes = excludes.to_vec();
+ self.transact(move |db| {
+ let mut file_query = db.prepare(
+ "
+ SELECT
+ id, relative_path
+ FROM
+ files
+ WHERE
+ worktree_id IN rarray(?)
+ ",
+ )?;
- let mut file_ids = Vec::<i64>::new();
- let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
-
- while let Some(row) = rows.next()? {
- let file_id = row.get(0)?;
- let relative_path = row.get_ref(1)?.as_str()?;
- let included =
- includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
- let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
- if included && !excluded {
- file_ids.push(file_id);
+ let mut file_ids = Vec::<i64>::new();
+ let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
+
+ while let Some(row) = rows.next()? {
+ let file_id = row.get(0)?;
+ let relative_path = row.get_ref(1)?.as_str()?;
+ let included =
+ includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
+ let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
+ if included && !excluded {
+ file_ids.push(file_id);
+ }
}
- }
- Ok(file_ids)
+ anyhow::Ok(file_ids)
+ })
}
- fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
- let mut query_statement = self.db.prepare(
+ fn for_each_document(
+ db: &rusqlite::Connection,
+ file_ids: &[i64],
+ mut f: impl FnMut(i64, Embedding),
+ ) -> Result<()> {
+ let mut query_statement = db.prepare(
"
SELECT
id, embedding
@@ -356,51 +455,57 @@ impl VectorDatabase {
Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
})?
.filter_map(|row| row.ok())
- .for_each(|(id, embedding)| f(id, embedding.0));
+ .for_each(|(id, embedding)| f(id, embedding));
Ok(())
}
- pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
- let mut statement = self.db.prepare(
- "
- SELECT
- documents.id,
- files.worktree_id,
- files.relative_path,
- documents.start_byte,
- documents.end_byte
- FROM
- documents, files
- WHERE
- documents.file_id = files.id AND
- documents.id in rarray(?)
- ",
- )?;
+ pub fn get_documents_by_ids(
+ &self,
+ ids: &[i64],
+ ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
+ let ids = ids.to_vec();
+ self.transact(move |db| {
+ let mut statement = db.prepare(
+ "
+ SELECT
+ documents.id,
+ files.worktree_id,
+ files.relative_path,
+ documents.start_byte,
+ documents.end_byte
+ FROM
+ documents, files
+ WHERE
+ documents.file_id = files.id AND
+ documents.id in rarray(?)
+ ",
+ )?;
- let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
- Ok((
- row.get::<_, i64>(0)?,
- row.get::<_, i64>(1)?,
- row.get::<_, String>(2)?.into(),
- row.get(3)?..row.get(4)?,
- ))
- })?;
-
- let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
- for row in result_iter {
- let (id, worktree_id, path, range) = row?;
- values_by_id.insert(id, (worktree_id, path, range));
- }
+ let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
+ Ok((
+ row.get::<_, i64>(0)?,
+ row.get::<_, i64>(1)?,
+ row.get::<_, String>(2)?.into(),
+ row.get(3)?..row.get(4)?,
+ ))
+ })?;
+
+ let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
+ for row in result_iter {
+ let (id, worktree_id, path, range) = row?;
+ values_by_id.insert(id, (worktree_id, path, range));
+ }
- let mut results = Vec::with_capacity(ids.len());
- for id in ids {
- let value = values_by_id
- .remove(id)
- .ok_or(anyhow!("missing document id {}", id))?;
- results.push(value);
- }
+ let mut results = Vec::with_capacity(ids.len());
+ for id in &ids {
+ let value = values_by_id
+ .remove(id)
+ .ok_or(anyhow!("missing document id {}", id))?;
+ results.push(value);
+ }
- Ok(results)
+ Ok(results)
+ })
}
}
@@ -412,29 +517,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
.collect::<Vec<_>>(),
)
}
-
-pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
- let len = vec_a.len();
- assert_eq!(len, vec_b.len());
-
- let mut result = 0.0;
- unsafe {
- matrixmultiply::sgemm(
- 1,
- len,
- 1,
- 1.0,
- vec_a.as_ptr(),
- len as isize,
- 1,
- vec_b.as_ptr(),
- 1,
- len as isize,
- 0.0,
- &mut result as *mut f32,
- 1,
- 1,
- );
- }
- result
-}
@@ -7,6 +7,9 @@ use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
+use parse_duration::parse;
+use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
+use rusqlite::ToSql;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
@@ -19,6 +22,62 @@ lazy_static! {
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
+#[derive(Debug, PartialEq, Clone)]
+pub struct Embedding(Vec<f32>);
+
+impl From<Vec<f32>> for Embedding {
+ fn from(value: Vec<f32>) -> Self {
+ Embedding(value)
+ }
+}
+
+impl Embedding {
+ pub fn similarity(&self, other: &Self) -> f32 {
+ let len = self.0.len();
+ assert_eq!(len, other.0.len());
+
+ let mut result = 0.0;
+ unsafe {
+ matrixmultiply::sgemm(
+ 1,
+ len,
+ 1,
+ 1.0,
+ self.0.as_ptr(),
+ len as isize,
+ 1,
+ other.0.as_ptr(),
+ 1,
+ len as isize,
+ 0.0,
+ &mut result as *mut f32,
+ 1,
+ 1,
+ );
+ }
+ result
+ }
+}
+
+impl FromSql for Embedding {
+ fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+ let bytes = value.as_blob()?;
+ let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
+ if embedding.is_err() {
+ return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
+ }
+ Ok(Embedding(embedding.unwrap()))
+ }
+}
+
+impl ToSql for Embedding {
+ fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+ let bytes = bincode::serialize(&self.0)
+ .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
+ Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
+ }
+}
+
#[derive(Clone)]
pub struct OpenAIEmbeddings {
pub client: Arc<dyn HttpClient>,
@@ -52,42 +111,53 @@ struct OpenAIEmbeddingUsage {
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
- async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
+ fn max_tokens_per_batch(&self) -> usize;
+ fn truncate(&self, span: &str) -> (String, usize);
}
pub struct DummyEmbeddings {}
#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
- async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
// 1024 is the OpenAI Embeddings size for ada models.
// the model we will likely be starting with.
- let dummy_vec = vec![0.32 as f32; 1536];
+ let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
return Ok(vec![dummy_vec; spans.len()]);
}
-}
-const OPENAI_INPUT_LIMIT: usize = 8190;
+ fn max_tokens_per_batch(&self) -> usize {
+ OPENAI_INPUT_LIMIT
+ }
-impl OpenAIEmbeddings {
- fn truncate(span: String) -> String {
- let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
- if tokens.len() > OPENAI_INPUT_LIMIT {
+ fn truncate(&self, span: &str) -> (String, usize) {
+ let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+ let token_count = tokens.len();
+ let output = if token_count > OPENAI_INPUT_LIMIT {
tokens.truncate(OPENAI_INPUT_LIMIT);
- let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
- if result.is_ok() {
- let transformed = result.unwrap();
- return transformed;
- }
- }
+ let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
+ new_input.ok().unwrap_or_else(|| span.to_string())
+ } else {
+ span.to_string()
+ };
- span
+ (output, tokens.len())
}
+}
+
+const OPENAI_INPUT_LIMIT: usize = 8190;
- async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
+impl OpenAIEmbeddings {
+ async fn send_request(
+ &self,
+ api_key: &str,
+ spans: Vec<&str>,
+ request_timeout: u64,
+ ) -> Result<Response<AsyncBody>> {
let request = Request::post("https://api.openai.com/v1/embeddings")
.redirect_policy(isahc::config::RedirectPolicy::Follow)
- .timeout(Duration::from_secs(4))
+ .timeout(Duration::from_secs(request_timeout))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
@@ -105,7 +175,27 @@ impl OpenAIEmbeddings {
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
- async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+ fn max_tokens_per_batch(&self) -> usize {
+ 50000
+ }
+
+ fn truncate(&self, span: &str) -> (String, usize) {
+ let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+ let token_count = tokens.len();
+ let output = if token_count > OPENAI_INPUT_LIMIT {
+ tokens.truncate(OPENAI_INPUT_LIMIT);
+ OPENAI_BPE_TOKENIZER
+ .decode(tokens)
+ .ok()
+ .unwrap_or_else(|| span.to_string())
+ } else {
+ span.to_string()
+ };
+
+ (output, token_count)
+ }
+
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
@@ -114,45 +204,21 @@ impl EmbeddingProvider for OpenAIEmbeddings {
.ok_or_else(|| anyhow!("no api key"))?;
let mut request_number = 0;
- let mut truncated = false;
+ let mut request_timeout: u64 = 10;
let mut response: Response<AsyncBody>;
- let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
while request_number < MAX_RETRIES {
response = self
- .send_request(api_key, spans.iter().map(|x| &**x).collect())
+ .send_request(
+ api_key,
+ spans.iter().map(|x| &**x).collect(),
+ request_timeout,
+ )
.await?;
request_number += 1;
- if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK {
- return Err(anyhow!(
- "openai max retries, error: {:?}",
- &response.status()
- ));
- }
-
match response.status() {
- StatusCode::TOO_MANY_REQUESTS => {
- let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
- log::trace!(
- "open ai rate limiting, delaying request by {:?} seconds",
- delay.as_secs()
- );
- self.executor.timer(delay).await;
- }
- StatusCode::BAD_REQUEST => {
- // Only truncate if it hasnt been truncated before
- if !truncated {
- for span in spans.iter_mut() {
- *span = Self::truncate(span.clone());
- }
- truncated = true;
- } else {
- // If failing once already truncated, log the error and break the loop
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- log::trace!("open ai bad request: {:?} {:?}", &response.status(), body);
- break;
- }
+ StatusCode::REQUEST_TIMEOUT => {
+ request_timeout += 5;
}
StatusCode::OK => {
let mut body = String::new();
@@ -163,18 +229,96 @@ impl EmbeddingProvider for OpenAIEmbeddings {
"openai embedding completed. tokens: {:?}",
response.usage.total_tokens
);
+
return Ok(response
.data
.into_iter()
- .map(|embedding| embedding.embedding)
+ .map(|embedding| Embedding::from(embedding.embedding))
.collect());
}
+ StatusCode::TOO_MANY_REQUESTS => {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ let delay_duration = {
+ let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+ if let Some(time_to_reset) =
+ response.headers().get("x-ratelimit-reset-tokens")
+ {
+ if let Ok(time_str) = time_to_reset.to_str() {
+ parse(time_str).unwrap_or(delay)
+ } else {
+ delay
+ }
+ } else {
+ delay
+ }
+ };
+
+ log::trace!(
+ "openai rate limiting: waiting {:?} until lifted",
+ &delay_duration
+ );
+
+ self.executor.timer(delay_duration).await;
+ }
_ => {
- return Err(anyhow!("openai embedding failed {}", response.status()));
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Err(anyhow!(
+ "open ai bad request: {:?} {:?}",
+ &response.status(),
+ body
+ ));
}
}
}
+ Err(anyhow!("openai max retries"))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use rand::prelude::*;
+
+ #[gpui::test]
+ fn test_similarity(mut rng: StdRng) {
+ assert_eq!(
+ Embedding::from(vec![1., 0., 0., 0., 0.])
+ .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
+ 0.
+ );
+ assert_eq!(
+ Embedding::from(vec![2., 0., 0., 0., 0.])
+ .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
+ 6.
+ );
- Err(anyhow!("openai embedding failed"))
+ for _ in 0..100 {
+ let size = 1536;
+ let mut a = vec![0.; size];
+ let mut b = vec![0.; size];
+ for (a, b) in a.iter_mut().zip(b.iter_mut()) {
+ *a = rng.gen();
+ *b = rng.gen();
+ }
+ let a = Embedding::from(a);
+ let b = Embedding::from(b);
+
+ assert_eq!(
+ round_to_decimals(a.similarity(&b), 1),
+ round_to_decimals(reference_dot(&a.0, &b.0), 1)
+ );
+ }
+
+ fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
+ let factor = (10.0 as f32).powi(decimal_places);
+ (n * factor).round() / factor
+ }
+
+ fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
+ a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
+ }
}
}
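The EmbeddingProvider trait now exposes truncate and max_tokens_per_batch alongside embed_batch, and results travel as the Embedding newtype. A minimal sketch of exercising a provider under the new shape (embed_one is a hypothetical helper):

async fn embed_one(provider: &dyn EmbeddingProvider, span: &str) -> anyhow::Result<Embedding> {
    // Truncate to the provider's input limit; the reported token count lets callers
    // budget batches against max_tokens_per_batch().
    let (span, _token_count) = provider.truncate(span);
    let mut embeddings = provider.embed_batch(vec![span]).await?;
    embeddings
        .pop()
        .ok_or_else(|| anyhow::anyhow!("provider returned no embeddings"))
}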
@@ -0,0 +1,173 @@
+use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
+use gpui::executor::Background;
+use parking_lot::Mutex;
+use smol::channel;
+use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
+
+#[derive(Clone)]
+pub struct FileToEmbed {
+ pub worktree_id: i64,
+ pub path: PathBuf,
+ pub mtime: SystemTime,
+ pub documents: Vec<Document>,
+ pub job_handle: JobHandle,
+}
+
+impl std::fmt::Debug for FileToEmbed {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("FileToEmbed")
+ .field("worktree_id", &self.worktree_id)
+ .field("path", &self.path)
+ .field("mtime", &self.mtime)
+ .field("document", &self.documents)
+ .finish_non_exhaustive()
+ }
+}
+
+impl PartialEq for FileToEmbed {
+ fn eq(&self, other: &Self) -> bool {
+ self.worktree_id == other.worktree_id
+ && self.path == other.path
+ && self.mtime == other.mtime
+ && self.documents == other.documents
+ }
+}
+
+pub struct EmbeddingQueue {
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ pending_batch: Vec<FileToEmbedFragment>,
+ executor: Arc<Background>,
+ pending_batch_token_count: usize,
+ finished_files_tx: channel::Sender<FileToEmbed>,
+ finished_files_rx: channel::Receiver<FileToEmbed>,
+}
+
+#[derive(Clone)]
+pub struct FileToEmbedFragment {
+ file: Arc<Mutex<FileToEmbed>>,
+ document_range: Range<usize>,
+}
+
+impl EmbeddingQueue {
+ pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
+ let (finished_files_tx, finished_files_rx) = channel::unbounded();
+ Self {
+ embedding_provider,
+ executor,
+ pending_batch: Vec::new(),
+ pending_batch_token_count: 0,
+ finished_files_tx,
+ finished_files_rx,
+ }
+ }
+
+ pub fn push(&mut self, file: FileToEmbed) {
+ if file.documents.is_empty() {
+ self.finished_files_tx.try_send(file).unwrap();
+ return;
+ }
+
+ let file = Arc::new(Mutex::new(file));
+
+ self.pending_batch.push(FileToEmbedFragment {
+ file: file.clone(),
+ document_range: 0..0,
+ });
+
+ let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
+ let mut saved_tokens = 0;
+ for (ix, document) in file.lock().documents.iter().enumerate() {
+ let document_token_count = if document.embedding.is_none() {
+ document.token_count
+ } else {
+ saved_tokens += document.token_count;
+ 0
+ };
+
+ let next_token_count = self.pending_batch_token_count + document_token_count;
+ if next_token_count > self.embedding_provider.max_tokens_per_batch() {
+ let range_end = fragment_range.end;
+ self.flush();
+ self.pending_batch.push(FileToEmbedFragment {
+ file: file.clone(),
+ document_range: range_end..range_end,
+ });
+ fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
+ }
+
+ fragment_range.end = ix + 1;
+ self.pending_batch_token_count += document_token_count;
+ }
+ log::trace!("Saved Tokens: {:?}", saved_tokens);
+ }
+
+ pub fn flush(&mut self) {
+ let batch = mem::take(&mut self.pending_batch);
+ self.pending_batch_token_count = 0;
+ if batch.is_empty() {
+ return;
+ }
+
+ let finished_files_tx = self.finished_files_tx.clone();
+ let embedding_provider = self.embedding_provider.clone();
+
+ self.executor.spawn(async move {
+ let mut spans = Vec::new();
+ let mut document_count = 0;
+ for fragment in &batch {
+ let file = fragment.file.lock();
+ document_count += file.documents[fragment.document_range.clone()].len();
+ spans.extend(
+ {
+ file.documents[fragment.document_range.clone()]
+ .iter().filter(|d| d.embedding.is_none())
+ .map(|d| d.content.clone())
+ }
+ );
+ }
+
+ log::trace!("Documents Length: {:?}", document_count);
+ log::trace!("Span Length: {:?}", spans.clone().len());
+
+            // If spans is empty, just send the fragment to the finished files if it's the last one.
+ if spans.len() == 0 {
+ for fragment in batch.clone() {
+ if let Some(file) = Arc::into_inner(fragment.file) {
+ finished_files_tx.try_send(file.into_inner()).unwrap();
+ }
+ }
+ return;
+ };
+
+ match embedding_provider.embed_batch(spans).await {
+ Ok(embeddings) => {
+ let mut embeddings = embeddings.into_iter();
+ for fragment in batch {
+ for document in
+ &mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
+ {
+ if let Some(embedding) = embeddings.next() {
+ document.embedding = Some(embedding);
+ } else {
+ //
+ log::error!("number of embeddings returned different from number of documents");
+ }
+ }
+
+ if let Some(file) = Arc::into_inner(fragment.file) {
+ finished_files_tx.try_send(file.into_inner()).unwrap();
+ }
+ }
+ }
+ Err(error) => {
+ log::error!("{:?}", error);
+ }
+ }
+ })
+ .detach();
+ }
+
+ pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
+ self.finished_files_rx.clone()
+ }
+}
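A short usage sketch of the new EmbeddingQueue (drive_queue and its caller are assumptions; the queue itself is what this file adds): push splits files into fragments so a batch never exceeds the provider's token budget, flush sends any partial batch, and completed files arrive on finished_files().

fn drive_queue(queue: &mut EmbeddingQueue, files: Vec<FileToEmbed>) -> smol::channel::Receiver<FileToEmbed> {
    for file in files {
        // Fragments the file's documents so a single batch stays under max_tokens_per_batch.
        queue.push(file);
    }
    // Send whatever remains in the partial batch to the embedding provider.
    queue.flush();
    // Files come back here once all of their documents have embeddings.
    queue.finished_files()
}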
@@ -1,5 +1,10 @@
-use anyhow::{anyhow, Ok, Result};
+use crate::embedding::{Embedding, EmbeddingProvider};
+use anyhow::{anyhow, Result};
use language::{Grammar, Language};
+use rusqlite::{
+ types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
+ ToSql,
+};
use sha1::{Digest, Sha1};
use std::{
cmp::{self, Reverse},
@@ -10,13 +15,44 @@ use std::{
};
use tree_sitter::{Parser, QueryCursor};
+#[derive(Debug, PartialEq, Eq, Clone, Hash)]
+pub struct DocumentDigest([u8; 20]);
+
+impl FromSql for DocumentDigest {
+ fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+ let blob = value.as_blob()?;
+ let bytes =
+ blob.try_into()
+ .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
+ expected_size: 20,
+ blob_size: blob.len(),
+ })?;
+ return Ok(DocumentDigest(bytes));
+ }
+}
+
+impl ToSql for DocumentDigest {
+ fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+ self.0.to_sql()
+ }
+}
+
+impl From<&'_ str> for DocumentDigest {
+ fn from(value: &'_ str) -> Self {
+ let mut sha1 = Sha1::new();
+ sha1.update(value);
+ Self(sha1.finalize().into())
+ }
+}
+
#[derive(Debug, PartialEq, Clone)]
pub struct Document {
pub name: String,
pub range: Range<usize>,
pub content: String,
- pub embedding: Vec<f32>,
- pub sha1: [u8; 20],
+ pub embedding: Option<Embedding>,
+ pub digest: DocumentDigest,
+ pub token_count: usize,
}
const CODE_CONTEXT_TEMPLATE: &str =
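Documents now carry a DocumentDigest (a SHA-1 of the rendered span) instead of a raw sha1 array, plus an optional embedding and a token count. The digest is what lets re-indexing skip unchanged spans: embeddings fetched from the database by digest are copied onto freshly parsed documents before they reach the embedding queue. A condensed sketch of that reuse step (the helper name is an assumption; the real logic lives in SemanticIndex::parse_file further down):

use std::collections::HashMap;

fn reuse_cached_embeddings(
    documents: &mut [Document],
    embeddings_for_digest: &HashMap<DocumentDigest, Embedding>,
) {
    for document in documents {
        // An unchanged span hashes to the same digest, so its stored embedding is reused
        // and the span is skipped by the embedding queue.
        if let Some(embedding) = embeddings_for_digest.get(&document.digest) {
            document.embedding = Some(embedding.clone());
        }
    }
}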
@@ -30,6 +66,7 @@ pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
pub struct CodeContextRetriever {
pub parser: Parser,
pub cursor: QueryCursor,
+ pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
@@ -47,10 +84,11 @@ pub struct CodeContextMatch {
}
impl CodeContextRetriever {
- pub fn new() -> Self {
+ pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
Self {
parser: Parser::new(),
cursor: QueryCursor::new(),
+ embedding_provider,
}
}
@@ -64,16 +102,15 @@ impl CodeContextRetriever {
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<language>", language_name.as_ref())
.replace("<item>", &content);
-
- let mut sha1 = Sha1::new();
- sha1.update(&document_span);
-
+ let digest = DocumentDigest::from(document_span.as_str());
+ let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
Ok(vec![Document {
range: 0..content.len(),
content: document_span,
- embedding: Vec::new(),
+ embedding: Default::default(),
name: language_name.to_string(),
- sha1: sha1.finalize().into(),
+ digest,
+ token_count,
}])
}
@@ -81,16 +118,15 @@ impl CodeContextRetriever {
let document_span = MARKDOWN_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<item>", &content);
-
- let mut sha1 = Sha1::new();
- sha1.update(&document_span);
-
+ let digest = DocumentDigest::from(document_span.as_str());
+ let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
Ok(vec![Document {
range: 0..content.len(),
content: document_span,
- embedding: Vec::new(),
+ embedding: None,
name: "Markdown".to_string(),
- sha1: sha1.finalize().into(),
+ digest,
+ token_count,
}])
}
@@ -166,10 +202,16 @@ impl CodeContextRetriever {
let mut documents = self.parse_file(content, language)?;
for document in &mut documents {
- document.content = CODE_CONTEXT_TEMPLATE
+ let document_content = CODE_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<language>", language_name.as_ref())
.replace("item", &document.content);
+
+ let (document_content, token_count) =
+ self.embedding_provider.truncate(&document_content);
+
+ document.content = document_content;
+ document.token_count = token_count;
}
Ok(documents)
}
@@ -263,15 +305,14 @@ impl CodeContextRetriever {
);
}
- let mut sha1 = Sha1::new();
- sha1.update(&document_content);
-
+ let sha1 = DocumentDigest::from(document_content.as_str());
documents.push(Document {
name,
content: document_content,
range: item_range.clone(),
- embedding: vec![],
- sha1: sha1.finalize().into(),
+ embedding: None,
+ digest: sha1,
+ token_count: 0,
})
}
@@ -1,5 +1,6 @@
mod db;
mod embedding;
+mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;
@@ -9,23 +10,25 @@ mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
use anyhow::{anyhow, Result};
use db::VectorDatabase;
-use embedding::{EmbeddingProvider, OpenAIEmbeddings};
-use futures::{channel::oneshot, Future};
+use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use embedding_queue::{EmbeddingQueue, FileToEmbed};
+use futures::{FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Anchor, Buffer, Language, LanguageRegistry};
use parking_lot::Mutex;
-use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
+use parsing::{CodeContextRetriever, DocumentDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
-use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, WorktreeId};
+use project::{
+ search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId,
+};
use smol::channel;
use std::{
cmp::Ordering,
- collections::HashMap,
- mem,
+ collections::{BTreeMap, HashMap},
ops::Range,
path::{Path, PathBuf},
sync::{Arc, Weak},
- time::{Instant, SystemTime},
+ time::{Duration, Instant, SystemTime},
};
use util::{
channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
@@ -35,8 +38,9 @@ use util::{
};
use workspace::WorkspaceCreated;
-const SEMANTIC_INDEX_VERSION: usize = 7;
-const EMBEDDINGS_BATCH_SIZE: usize = 80;
+const SEMANTIC_INDEX_VERSION: usize = 9;
+const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
+const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
pub fn init(
fs: Arc<dyn Fs>,
@@ -97,14 +101,11 @@ pub fn init(
pub struct SemanticIndex {
fs: Arc<dyn Fs>,
- database_url: Arc<PathBuf>,
+ db: VectorDatabase,
embedding_provider: Arc<dyn EmbeddingProvider>,
language_registry: Arc<LanguageRegistry>,
- db_update_tx: channel::Sender<DbOperation>,
- parsing_files_tx: channel::Sender<PendingFile>,
- _db_update_task: Task<()>,
- _embed_batch_tasks: Vec<Task<()>>,
- _batch_files_task: Task<()>,
+ parsing_files_tx: channel::Sender<(Arc<HashMap<DocumentDigest, Embedding>>, PendingFile)>,
+ _embedding_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
@@ -113,13 +114,18 @@ struct ProjectState {
worktree_db_ids: Vec<(WorktreeId, i64)>,
_subscription: gpui::Subscription,
outstanding_job_count_rx: watch::Receiver<usize>,
- _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
- job_queue_tx: channel::Sender<IndexOperation>,
- _queue_update_task: Task<()>,
+ outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
+ changed_paths: BTreeMap<ProjectPath, ChangedPathInfo>,
+}
+
+struct ChangedPathInfo {
+ changed_at: Instant,
+ mtime: SystemTime,
+ is_deleted: bool,
}
#[derive(Clone)]
-struct JobHandle {
+pub struct JobHandle {
/// The outer Arc is here to count the clones of a JobHandle instance;
/// when the last handle to a given job is dropped, we decrement a counter (just once).
tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
@@ -133,31 +139,21 @@ impl JobHandle {
}
}
}
+
impl ProjectState {
fn new(
- cx: &mut AppContext,
subscription: gpui::Subscription,
worktree_db_ids: Vec<(WorktreeId, i64)>,
- outstanding_job_count_rx: watch::Receiver<usize>,
- _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
+ changed_paths: BTreeMap<ProjectPath, ChangedPathInfo>,
) -> Self {
- let (job_queue_tx, job_queue_rx) = channel::unbounded();
- let _queue_update_task = cx.background().spawn({
- let mut worktree_queue = HashMap::new();
- async move {
- while let Ok(operation) = job_queue_rx.recv().await {
- Self::update_queue(&mut worktree_queue, operation);
- }
- }
- });
-
+ let (outstanding_job_count_tx, outstanding_job_count_rx) = watch::channel_with(0);
+ let outstanding_job_count_tx = Arc::new(Mutex::new(outstanding_job_count_tx));
Self {
worktree_db_ids,
outstanding_job_count_rx,
- _outstanding_job_count_tx,
+ outstanding_job_count_tx,
+ changed_paths,
_subscription: subscription,
- _queue_update_task,
- job_queue_tx,
}
}
@@ -165,41 +161,6 @@ impl ProjectState {
self.outstanding_job_count_rx.borrow().clone()
}
- fn update_queue(queue: &mut HashMap<PathBuf, IndexOperation>, operation: IndexOperation) {
- match operation {
- IndexOperation::FlushQueue => {
- let queue = std::mem::take(queue);
- for (_, op) in queue {
- match op {
- IndexOperation::IndexFile {
- absolute_path: _,
- payload,
- tx,
- } => {
- let _ = tx.try_send(payload);
- }
- IndexOperation::DeleteFile {
- absolute_path: _,
- payload,
- tx,
- } => {
- let _ = tx.try_send(payload);
- }
- _ => {}
- }
- }
- }
- IndexOperation::IndexFile {
- ref absolute_path, ..
- }
- | IndexOperation::DeleteFile {
- ref absolute_path, ..
- } => {
- queue.insert(absolute_path.clone(), operation);
- }
- }
- }
-
fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
self.worktree_db_ids
.iter()
@@ -230,66 +191,16 @@ pub struct PendingFile {
worktree_db_id: i64,
relative_path: PathBuf,
absolute_path: PathBuf,
- language: Arc<Language>,
+ language: Option<Arc<Language>>,
modified_time: SystemTime,
job_handle: JobHandle,
}
-enum IndexOperation {
- IndexFile {
- absolute_path: PathBuf,
- payload: PendingFile,
- tx: channel::Sender<PendingFile>,
- },
- DeleteFile {
- absolute_path: PathBuf,
- payload: DbOperation,
- tx: channel::Sender<DbOperation>,
- },
- FlushQueue,
-}
pub struct SearchResult {
pub buffer: ModelHandle<Buffer>,
pub range: Range<Anchor>,
}
-enum DbOperation {
- InsertFile {
- worktree_id: i64,
- documents: Vec<Document>,
- path: PathBuf,
- mtime: SystemTime,
- job_handle: JobHandle,
- },
- Delete {
- worktree_id: i64,
- path: PathBuf,
- },
- FindOrCreateWorktree {
- path: PathBuf,
- sender: oneshot::Sender<Result<i64>>,
- },
- FileMTimes {
- worktree_id: i64,
- sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
- },
- WorktreePreviouslyIndexed {
- path: Arc<Path>,
- sender: oneshot::Sender<Result<bool>>,
- },
-}
-
-enum EmbeddingJob {
- Enqueue {
- worktree_id: i64,
- path: PathBuf,
- mtime: SystemTime,
- documents: Vec<Document>,
- job_handle: JobHandle,
- },
- Flush,
-}
-
impl SemanticIndex {
pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
if cx.has_global::<ModelHandle<Self>>() {
@@ -306,18 +217,14 @@ impl SemanticIndex {
async fn new(
fs: Arc<dyn Fs>,
- database_url: PathBuf,
+ database_path: PathBuf,
embedding_provider: Arc<dyn EmbeddingProvider>,
language_registry: Arc<LanguageRegistry>,
mut cx: AsyncAppContext,
) -> Result<ModelHandle<Self>> {
let t0 = Instant::now();
- let database_url = Arc::new(database_url);
-
- let db = cx
- .background()
- .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
- .await?;
+ let database_path = Arc::from(database_path);
+ let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;
log::trace!(
"db initialization took {:?} milliseconds",
@@ -326,73 +233,55 @@ impl SemanticIndex {
Ok(cx.add_model(|cx| {
let t0 = Instant::now();
- // Perform database operations
- let (db_update_tx, db_update_rx) = channel::unbounded();
- let _db_update_task = cx.background().spawn({
+ let embedding_queue =
+ EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
+ let _embedding_task = cx.background().spawn({
+ let embedded_files = embedding_queue.finished_files();
+ let db = db.clone();
async move {
- while let Ok(job) = db_update_rx.recv().await {
- Self::run_db_operation(&db, job)
- }
- }
- });
-
- // Group documents into batches and send them to the embedding provider.
- let (embed_batch_tx, embed_batch_rx) =
- channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
- let mut _embed_batch_tasks = Vec::new();
- for _ in 0..cx.background().num_cpus() {
- let embed_batch_rx = embed_batch_rx.clone();
- _embed_batch_tasks.push(cx.background().spawn({
- let db_update_tx = db_update_tx.clone();
- let embedding_provider = embedding_provider.clone();
- async move {
- while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
- Self::compute_embeddings_for_batch(
- embeddings_queue,
- &embedding_provider,
- &db_update_tx,
- )
- .await;
- }
+ while let Ok(file) = embedded_files.recv().await {
+ db.insert_file(file.worktree_id, file.path, file.mtime, file.documents)
+ .await
+ .log_err();
}
- }));
- }
-
- // Group documents into batches and send them to the embedding provider.
- let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
- let _batch_files_task = cx.background().spawn(async move {
- let mut queue_len = 0;
- let mut embeddings_queue = vec![];
- while let Ok(job) = batch_files_rx.recv().await {
- Self::enqueue_documents_to_embed(
- job,
- &mut queue_len,
- &mut embeddings_queue,
- &embed_batch_tx,
- );
}
});
// Parse files into embeddable documents.
- let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
+ let (parsing_files_tx, parsing_files_rx) =
+ channel::unbounded::<(Arc<HashMap<DocumentDigest, Embedding>>, PendingFile)>();
+ let embedding_queue = Arc::new(Mutex::new(embedding_queue));
let mut _parsing_files_tasks = Vec::new();
for _ in 0..cx.background().num_cpus() {
let fs = fs.clone();
- let parsing_files_rx = parsing_files_rx.clone();
- let batch_files_tx = batch_files_tx.clone();
- let db_update_tx = db_update_tx.clone();
+ let mut parsing_files_rx = parsing_files_rx.clone();
+ let embedding_provider = embedding_provider.clone();
+ let embedding_queue = embedding_queue.clone();
+ let background = cx.background().clone();
_parsing_files_tasks.push(cx.background().spawn(async move {
- let mut retriever = CodeContextRetriever::new();
- while let Ok(pending_file) = parsing_files_rx.recv().await {
- Self::parse_file(
- &fs,
- pending_file,
- &mut retriever,
- &batch_files_tx,
- &parsing_files_rx,
- &db_update_tx,
- )
- .await;
+ let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
+ loop {
+ let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
+ let mut next_file_to_parse = parsing_files_rx.next().fuse();
+ futures::select_biased! {
+ next_file_to_parse = next_file_to_parse => {
+ if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
+ Self::parse_file(
+ &fs,
+ pending_file,
+ &mut retriever,
+ &embedding_queue,
+ &embeddings_for_digest,
+ )
+ .await
+ } else {
+ break;
+ }
+ },
+ _ = timer => {
+ embedding_queue.lock().flush();
+ }
+ }
}
}));
}
@@ -403,192 +292,31 @@ impl SemanticIndex {
);
Self {
fs,
- database_url,
+ db,
embedding_provider,
language_registry,
- db_update_tx,
parsing_files_tx,
- _db_update_task,
- _embed_batch_tasks,
- _batch_files_task,
+ _embedding_task,
_parsing_files_tasks,
projects: HashMap::new(),
}
}))
}
- fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
- match job {
- DbOperation::InsertFile {
- worktree_id,
- documents,
- path,
- mtime,
- job_handle,
- } => {
- db.insert_file(worktree_id, path, mtime, documents)
- .log_err();
- drop(job_handle)
- }
- DbOperation::Delete { worktree_id, path } => {
- db.delete_file(worktree_id, path).log_err();
- }
- DbOperation::FindOrCreateWorktree { path, sender } => {
- let id = db.find_or_create_worktree(&path);
- sender.send(id).ok();
- }
- DbOperation::FileMTimes {
- worktree_id: worktree_db_id,
- sender,
- } => {
- let file_mtimes = db.get_file_mtimes(worktree_db_id);
- sender.send(file_mtimes).ok();
- }
- DbOperation::WorktreePreviouslyIndexed { path, sender } => {
- let worktree_indexed = db.worktree_previously_indexed(path.as_ref());
- sender.send(worktree_indexed).ok();
- }
- }
- }
-
- async fn compute_embeddings_for_batch(
- mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
- embedding_provider: &Arc<dyn EmbeddingProvider>,
- db_update_tx: &channel::Sender<DbOperation>,
- ) {
- let mut batch_documents = vec![];
- for (_, documents, _, _, _) in embeddings_queue.iter() {
- batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
- }
-
- if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
- log::trace!(
- "created {} embeddings for {} files",
- embeddings.len(),
- embeddings_queue.len(),
- );
-
- let mut i = 0;
- let mut j = 0;
-
- for embedding in embeddings.iter() {
- while embeddings_queue[i].1.len() == j {
- i += 1;
- j = 0;
- }
-
- embeddings_queue[i].1[j].embedding = embedding.to_owned();
- j += 1;
- }
-
- for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
- db_update_tx
- .send(DbOperation::InsertFile {
- worktree_id,
- documents,
- path,
- mtime,
- job_handle,
- })
- .await
- .unwrap();
- }
- } else {
- // Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed).
- for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() {
- db_update_tx
- .send(DbOperation::InsertFile {
- worktree_id,
- documents: vec![],
- path,
- mtime,
- job_handle,
- })
- .await
- .unwrap();
- }
- }
- }
-
- fn enqueue_documents_to_embed(
- job: EmbeddingJob,
- queue_len: &mut usize,
- embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
- embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
- ) {
- // Handle edge case where individual file has more documents than max batch size
- let should_flush = match job {
- EmbeddingJob::Enqueue {
- documents,
- worktree_id,
- path,
- mtime,
- job_handle,
- } => {
- // If documents is greater than embeddings batch size, recursively batch existing rows.
- if &documents.len() > &EMBEDDINGS_BATCH_SIZE {
- let first_job = EmbeddingJob::Enqueue {
- documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(),
- worktree_id,
- path: path.clone(),
- mtime,
- job_handle: job_handle.clone(),
- };
-
- Self::enqueue_documents_to_embed(
- first_job,
- queue_len,
- embeddings_queue,
- embed_batch_tx,
- );
-
- let second_job = EmbeddingJob::Enqueue {
- documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(),
- worktree_id,
- path: path.clone(),
- mtime,
- job_handle: job_handle.clone(),
- };
-
- Self::enqueue_documents_to_embed(
- second_job,
- queue_len,
- embeddings_queue,
- embed_batch_tx,
- );
- return;
- } else {
- *queue_len += &documents.len();
- embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
- *queue_len >= EMBEDDINGS_BATCH_SIZE
- }
- }
- EmbeddingJob::Flush => true,
- };
-
- if should_flush {
- embed_batch_tx
- .try_send(mem::take(embeddings_queue))
- .unwrap();
- *queue_len = 0;
- }
- }
-
async fn parse_file(
fs: &Arc<dyn Fs>,
pending_file: PendingFile,
retriever: &mut CodeContextRetriever,
- batch_files_tx: &channel::Sender<EmbeddingJob>,
- parsing_files_rx: &channel::Receiver<PendingFile>,
- db_update_tx: &channel::Sender<DbOperation>,
+ embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
+ embeddings_for_digest: &HashMap<DocumentDigest, Embedding>,
) {
+ let Some(language) = pending_file.language else {
+ return;
+ };
+
if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
- if let Some(documents) = retriever
- .parse_file_with_template(
- &pending_file.relative_path,
- &content,
- pending_file.language,
- )
+ if let Some(mut documents) = retriever
+ .parse_file_with_template(&pending_file.relative_path, &content, language)
.log_err()
{
log::trace!(
@@ -597,66 +325,23 @@ impl SemanticIndex {
documents.len()
);
- if documents.len() == 0 {
- db_update_tx
- .send(DbOperation::InsertFile {
- worktree_id: pending_file.worktree_db_id,
- documents,
- path: pending_file.relative_path,
- mtime: pending_file.modified_time,
- job_handle: pending_file.job_handle,
- })
- .await
- .unwrap();
- } else {
- batch_files_tx
- .try_send(EmbeddingJob::Enqueue {
- worktree_id: pending_file.worktree_db_id,
- path: pending_file.relative_path,
- mtime: pending_file.modified_time,
- job_handle: pending_file.job_handle,
- documents,
- })
- .unwrap();
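+ // Reuse embeddings previously stored for documents with the same content digest.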
+ for document in documents.iter_mut() {
+ if let Some(embedding) = embeddings_for_digest.get(&document.digest) {
+ document.embedding = Some(embedding.to_owned());
+ }
}
- }
- }
- if parsing_files_rx.len() == 0 {
- batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
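+ // Hand the parsed documents to the shared embedding queue, which batches spans
+ // across files before calling the embedding provider.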
+ embedding_queue.lock().push(FileToEmbed {
+ worktree_id: pending_file.worktree_db_id,
+ path: pending_file.relative_path,
+ mtime: pending_file.modified_time,
+ job_handle: pending_file.job_handle,
+ documents,
+ });
+ }
}
}
- fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
- let (tx, rx) = oneshot::channel();
- self.db_update_tx
- .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
- .unwrap();
- async move { rx.await? }
- }
-
- fn get_file_mtimes(
- &self,
- worktree_id: i64,
- ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
- let (tx, rx) = oneshot::channel();
- self.db_update_tx
- .try_send(DbOperation::FileMTimes {
- worktree_id,
- sender: tx,
- })
- .unwrap();
- async move { rx.await? }
- }
-
- fn worktree_previously_indexed(&self, path: Arc<Path>) -> impl Future<Output = Result<bool>> {
- let (tx, rx) = oneshot::channel();
- self.db_update_tx
- .try_send(DbOperation::WorktreePreviouslyIndexed { path, sender: tx })
- .unwrap();
- async move { rx.await? }
- }
-
pub fn project_previously_indexed(
&mut self,
project: ModelHandle<Project>,
@@ -665,7 +350,10 @@ impl SemanticIndex {
let worktrees_indexed_previously = project
.read(cx)
.worktrees(cx)
- .map(|worktree| self.worktree_previously_indexed(worktree.read(cx).abs_path()))
+ .map(|worktree| {
+ self.db
+ .worktree_previously_indexed(&worktree.read(cx).abs_path())
+ })
.collect::<Vec<_>>();
cx.spawn(|_, _cx| async move {
let worktree_indexed_previously =
@@ -679,103 +367,73 @@ impl SemanticIndex {
}
fn project_entries_changed(
- &self,
+ &mut self,
project: ModelHandle<Project>,
changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
cx: &mut ModelContext<'_, SemanticIndex>,
worktree_id: &WorktreeId,
- ) -> Result<()> {
- let parsing_files_tx = self.parsing_files_tx.clone();
- let db_update_tx = self.db_update_tx.clone();
- let (job_queue_tx, outstanding_job_tx, worktree_db_id) = {
- let state = self
- .projects
- .get(&project.downgrade())
- .ok_or(anyhow!("Project not yet initialized"))?;
- let worktree_db_id = state
- .db_id_for_worktree_id(*worktree_id)
- .ok_or(anyhow!("Worktree ID in Database Not Available"))?;
- (
- state.job_queue_tx.clone(),
- state._outstanding_job_count_tx.clone(),
- worktree_db_id,
- )
+ ) {
+ let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
+ return;
+ };
+ let project = project.downgrade();
+ let Some(project_state) = self.projects.get_mut(&project) else {
+ return;
};
- let language_registry = self.language_registry.clone();
- let parsing_files_tx = parsing_files_tx.clone();
- let db_update_tx = db_update_tx.clone();
-
- let worktree = project
- .read(cx)
- .worktree_for_id(worktree_id.clone(), cx)
- .ok_or(anyhow!("Worktree not available"))?
- .read(cx)
- .snapshot();
- cx.spawn(|_, _| async move {
- let worktree = worktree.clone();
- for (path, entry_id, path_change) in changes.iter() {
- let relative_path = path.to_path_buf();
- let absolute_path = worktree.absolutize(path);
-
- let Some(entry) = worktree.entry_for_id(*entry_id) else {
- continue;
- };
- if entry.is_ignored || entry.is_symlink || entry.is_external {
- continue;
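+ // Look up embeddings already stored for every changed path, grouped by worktree
+ // database id, so unchanged documents can be matched by digest when reparsed.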
+ let embeddings_for_digest = {
+ let mut worktree_id_file_paths = HashMap::new();
+ for (path, _) in &project_state.changed_paths {
+ if let Some(worktree_db_id) = project_state.db_id_for_worktree_id(path.worktree_id)
+ {
+ worktree_id_file_paths
+ .entry(worktree_db_id)
+ .or_insert(Vec::new())
+ .push(path.path.clone());
}
+ }
+ self.db.embeddings_for_files(worktree_id_file_paths)
+ };
- log::trace!("File Event: {:?}, Path: {:?}", &path_change, &path);
- match path_change {
- PathChange::AddedOrUpdated | PathChange::Updated | PathChange::Added => {
- if let Ok(language) = language_registry
- .language_for_file(&relative_path, None)
- .await
- {
- if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
- && &language.name().as_ref() != &"Markdown"
- && language
- .grammar()
- .and_then(|grammar| grammar.embedding_config.as_ref())
- .is_none()
- {
- continue;
- }
+ let worktree = worktree.read(cx);
+ let change_time = Instant::now();
+ for (path, entry_id, change) in changes.iter() {
+ let Some(entry) = worktree.entry_for_id(*entry_id) else {
+ continue;
+ };
+ if entry.is_ignored || entry.is_symlink || entry.is_external {
+ continue;
+ }
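+ // Record the change time, mtime, and whether the entry was removed for each
+ // changed path; the deferred reindexing pass consumes these records.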
+ let project_path = ProjectPath {
+ worktree_id: *worktree_id,
+ path: path.clone(),
+ };
+ project_state.changed_paths.insert(
+ project_path,
+ ChangedPathInfo {
+ changed_at: change_time,
+ mtime: entry.mtime,
+ is_deleted: *change == PathChange::Removed,
+ },
+ );
+ }
- let job_handle = JobHandle::new(&outstanding_job_tx);
- let new_operation = IndexOperation::IndexFile {
- absolute_path: absolute_path.clone(),
- payload: PendingFile {
- worktree_db_id,
- relative_path,
- absolute_path,
- language,
- modified_time: entry.mtime,
- job_handle,
- },
- tx: parsing_files_tx.clone(),
- };
- let _ = job_queue_tx.try_send(new_operation);
- }
- }
- PathChange::Removed => {
- let new_operation = IndexOperation::DeleteFile {
- absolute_path,
- payload: DbOperation::Delete {
- worktree_id: worktree_db_id,
- path: relative_path,
- },
- tx: db_update_tx.clone(),
- };
- let _ = job_queue_tx.try_send(new_operation);
- }
- _ => {}
- }
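+ // Reindex after BACKGROUND_INDEXING_DELAY rather than immediately, so bursts of
+ // file events can be folded into a single pass over the recorded paths.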
+ cx.spawn_weak(|this, mut cx| async move {
+ let embeddings_for_digest = embeddings_for_digest.await.log_err().unwrap_or_default();
+
+ cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
+ if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
+ Self::reindex_changed_paths(
+ this,
+ project,
+ Some(change_time),
+ &mut cx,
+ Arc::new(embeddings_for_digest),
+ )
+ .await;
}
})
.detach();
-
- Ok(())
}
pub fn initialize_project(
@@ -799,20 +457,18 @@ impl SemanticIndex {
.read(cx)
.worktrees(cx)
.map(|worktree| {
- self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
+ self.db
+ .find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
})
.collect::<Vec<_>>();
let _subscription = cx.subscribe(&project, |this, project, event, cx| {
if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
- let _ =
- this.project_entries_changed(project.clone(), changes.clone(), cx, worktree_id);
+ this.project_entries_changed(project.clone(), changes.clone(), cx, worktree_id);
};
});
let language_registry = self.language_registry.clone();
- let parsing_files_tx = self.parsing_files_tx.clone();
- let db_update_tx = self.db_update_tx.clone();
cx.spawn(|this, mut cx| async move {
futures::future::join_all(worktree_scans_complete).await;
@@ -833,7 +489,7 @@ impl SemanticIndex {
db_ids_by_worktree_id.insert(worktree.id(), db_id);
worktree_file_mtimes.insert(
worktree.id(),
- this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
+ this.read_with(&cx, |this, _| this.db.get_file_mtimes(db_id))
.await?,
);
}
@@ -843,17 +499,13 @@ impl SemanticIndex {
.map(|(a, b)| (*a, *b))
.collect();
- let (job_count_tx, job_count_rx) = watch::channel_with(0);
- let job_count_tx = Arc::new(Mutex::new(job_count_tx));
- let job_count_tx_longlived = job_count_tx.clone();
-
- let worktree_files = cx
+ let changed_paths = cx
.background()
.spawn(async move {
- let mut worktree_files = Vec::new();
+ let mut changed_paths = BTreeMap::new();
+ let now = Instant::now();
for worktree in worktrees.into_iter() {
let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
- let worktree_db_id = db_ids_by_worktree_id[&worktree.id()];
for file in worktree.files(false, 0) {
let absolute_path = worktree.absolutize(&file.path);
@@ -876,59 +528,51 @@ impl SemanticIndex {
continue;
}
- let path_buf = file.path.to_path_buf();
let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
let already_stored = stored_mtime
.map_or(false, |existing_mtime| existing_mtime == file.mtime);
if !already_stored {
- let job_handle = JobHandle::new(&job_count_tx);
- worktree_files.push(IndexOperation::IndexFile {
- absolute_path: absolute_path.clone(),
- payload: PendingFile {
- worktree_db_id,
- relative_path: path_buf,
- absolute_path,
- language,
- job_handle,
- modified_time: file.mtime,
+ changed_paths.insert(
+ ProjectPath {
+ worktree_id: worktree.id(),
+ path: file.path.clone(),
+ },
+ ChangedPathInfo {
+ changed_at: now,
+ mtime: file.mtime,
+ is_deleted: false,
},
- tx: parsing_files_tx.clone(),
- });
+ );
}
}
}
+
// Clean up entries from the database that are no longer present in the worktree.
- for (path, _) in file_mtimes {
- worktree_files.push(IndexOperation::DeleteFile {
- absolute_path: worktree.absolutize(path.as_path()),
- payload: DbOperation::Delete {
- worktree_id: worktree_db_id,
- path,
+ for (path, mtime) in file_mtimes {
+ changed_paths.insert(
+ ProjectPath {
+ worktree_id: worktree.id(),
+ path: path.into(),
},
- tx: db_update_tx.clone(),
- });
+ ChangedPathInfo {
+ changed_at: now,
+ mtime,
+ is_deleted: true,
+ },
+ );
}
}
- anyhow::Ok(worktree_files)
+ anyhow::Ok(changed_paths)
})
.await?;
- this.update(&mut cx, |this, cx| {
- let project_state = ProjectState::new(
- cx,
- _subscription,
- worktree_db_ids,
- job_count_rx,
- job_count_tx_longlived,
+ this.update(&mut cx, |this, _| {
+ this.projects.insert(
+ project.downgrade(),
+ ProjectState::new(_subscription, worktree_db_ids, changed_paths),
);
-
- for op in worktree_files {
- let _ = project_state.job_queue_tx.try_send(op);
- }
-
- this.projects.insert(project.downgrade(), project_state);
});
Result::<(), _>::Ok(())
})
@@ -939,27 +583,45 @@ impl SemanticIndex {
project: ModelHandle<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<(usize, watch::Receiver<usize>)>> {
- let state = self.projects.get_mut(&project.downgrade());
- let state = if state.is_none() {
- return Task::Ready(Some(Err(anyhow!("Project not yet initialized"))));
- } else {
- state.unwrap()
- };
+ cx.spawn(|this, mut cx| async move {
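+ // Gather previously stored embeddings for every path awaiting reindexing, keyed
+ // by content digest, before kicking off the reindex itself.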
+ let embeddings_for_digest = this.read_with(&cx, |this, _| {
+ if let Some(state) = this.projects.get(&project.downgrade()) {
+ let mut worktree_id_file_paths = HashMap::default();
+ for (path, _) in &state.changed_paths {
+ if let Some(worktree_db_id) = state.db_id_for_worktree_id(path.worktree_id)
+ {
+ worktree_id_file_paths
+ .entry(worktree_db_id)
+ .or_insert(Vec::new())
+ .push(path.path.clone());
+ }
+ }
+
+ Ok(this.db.embeddings_for_files(worktree_id_file_paths))
+ } else {
+ Err(anyhow!("Project not yet initialized"))
+ }
+ })?;
- // let parsing_files_tx = self.parsing_files_tx.clone();
- // let db_update_tx = self.db_update_tx.clone();
- let job_count_rx = state.outstanding_job_count_rx.clone();
- let count = state.get_outstanding_count();
+ let embeddings_for_digest = Arc::new(embeddings_for_digest.await?);
- cx.spawn(|this, mut cx| async move {
- this.update(&mut cx, |this, _| {
- let Some(state) = this.projects.get_mut(&project.downgrade()) else {
- return;
- };
- let _ = state.job_queue_tx.try_send(IndexOperation::FlushQueue);
- });
+ Self::reindex_changed_paths(
+ this.clone(),
+ project.clone(),
+ None,
+ &mut cx,
+ embeddings_for_digest,
+ )
+ .await;
- Ok((count, job_count_rx))
+ this.update(&mut cx, |this, _cx| {
+ let Some(state) = this.projects.get(&project.downgrade()) else {
+ return Err(anyhow!("Project not yet initialized"));
+ };
+ let job_count_rx = state.outstanding_job_count_rx.clone();
+ let count = state.get_outstanding_count();
+ Ok((count, job_count_rx))
+ })
})
}
@@ -1,14 +1,15 @@
use crate::{
- db::dot,
- embedding::EmbeddingProvider,
- parsing::{subtract_ranges, CodeContextRetriever, Document},
+ embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
+ embedding_queue::EmbeddingQueue,
+ parsing::{subtract_ranges, CodeContextRetriever, Document, DocumentDigest},
semantic_index_settings::SemanticIndexSettings,
- SearchResult, SemanticIndex,
+ FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
};
use anyhow::Result;
use async_trait::async_trait;
-use gpui::{Task, TestAppContext};
+use gpui::{executor::Deterministic, Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
+use parking_lot::Mutex;
use pretty_assertions::assert_eq;
use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
use rand::{rngs::StdRng, Rng};
@@ -20,8 +21,10 @@ use std::{
atomic::{self, AtomicUsize},
Arc,
},
+ time::SystemTime,
};
use unindent::Unindent;
+use util::RandomCharIter;
#[ctor::ctor]
fn init_logger() {
@@ -31,12 +34,8 @@ fn init_logger() {
}
#[gpui::test]
-async fn test_semantic_index(cx: &mut TestAppContext) {
- cx.update(|cx| {
- cx.set_global(SettingsStore::test(cx));
- settings::register::<SemanticIndexSettings>(cx);
- settings::register::<ProjectSettings>(cx);
- });
+async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
+ init_test(cx);
let fs = FakeFs::new(cx.background());
fs.insert_tree(
@@ -56,6 +55,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
fn bbb() {
println!(\"bbbbbbbbbbbbb!\");
}
+ struct pqpqpqp {}
".unindent(),
"file3.toml": "
ZZZZZZZZZZZZZZZZZZ = 5
@@ -75,7 +75,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
let db_path = db_dir.path().join("db.sqlite");
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let store = SemanticIndex::new(
+ let semantic_index = SemanticIndex::new(
fs.clone(),
db_path,
embedding_provider.clone(),
@@ -87,21 +87,21 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
- let _ = store
+ let _ = semantic_index
.update(cx, |store, cx| {
store.initialize_project(project.clone(), cx)
})
.await;
- let (file_count, outstanding_file_count) = store
+ let (file_count, outstanding_file_count) = semantic_index
.update(cx, |store, cx| store.index_project(project.clone(), cx))
.await
.unwrap();
assert_eq!(file_count, 3);
- cx.foreground().run_until_parked();
+ deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
assert_eq!(*outstanding_file_count.borrow(), 0);
- let search_results = store
+ let search_results = semantic_index
.update(cx, |store, cx| {
store.search_project(
project.clone(),
@@ -122,6 +122,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
(Path::new("src/file2.rs").into(), 0),
(Path::new("src/file3.toml").into(), 0),
(Path::new("src/file1.rs").into(), 45),
+ (Path::new("src/file2.rs").into(), 45),
],
cx,
);
@@ -129,7 +130,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
// Test Include Files Functionality
let include_files = vec![PathMatcher::new("*.rs").unwrap()];
let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
- let rust_only_search_results = store
+ let rust_only_search_results = semantic_index
.update(cx, |store, cx| {
store.search_project(
project.clone(),
@@ -149,11 +150,12 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
(Path::new("src/file1.rs").into(), 0),
(Path::new("src/file2.rs").into(), 0),
(Path::new("src/file1.rs").into(), 45),
+ (Path::new("src/file2.rs").into(), 45),
],
cx,
);
- let no_rust_search_results = store
+ let no_rust_search_results = semantic_index
.update(cx, |store, cx| {
store.search_project(
project.clone(),
@@ -186,24 +188,87 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
.await
.unwrap();
- cx.foreground().run_until_parked();
+ deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
let prev_embedding_count = embedding_provider.embedding_count();
- let (file_count, outstanding_file_count) = store
+ let (file_count, outstanding_file_count) = semantic_index
.update(cx, |store, cx| store.index_project(project.clone(), cx))
.await
.unwrap();
assert_eq!(file_count, 1);
- cx.foreground().run_until_parked();
+ deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
assert_eq!(*outstanding_file_count.borrow(), 0);
assert_eq!(
embedding_provider.embedding_count() - prev_embedding_count,
- 2
+ 1
);
}
+#[gpui::test(iterations = 10)]
+async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
+ let (outstanding_job_count, _) = postage::watch::channel_with(0);
+ let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
+
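+ // Generate a few files with randomly sized documents; with a small
+ // max_tokens_per_batch, batches can span files and a single file can span batches.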
+ let files = (1..=3)
+ .map(|file_ix| FileToEmbed {
+ worktree_id: 5,
+ path: format!("path-{file_ix}").into(),
+ mtime: SystemTime::now(),
+ documents: (0..rng.gen_range(4..22))
+ .map(|document_ix| {
+ let content_len = rng.gen_range(10..100);
+ let content = RandomCharIter::new(&mut rng)
+ .with_simple_text()
+ .take(content_len)
+ .collect::<String>();
+ let digest = DocumentDigest::from(content.as_str());
+ Document {
+ range: 0..10,
+ embedding: None,
+ name: format!("document {document_ix}"),
+ content,
+ digest,
+ token_count: rng.gen_range(10..30),
+ }
+ })
+ .collect(),
+ job_handle: JobHandle::new(&outstanding_job_count),
+ })
+ .collect::<Vec<_>>();
+
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+
+ let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
+ for file in &files {
+ queue.push(file.clone());
+ }
+ queue.flush();
+
+ cx.foreground().run_until_parked();
+ let finished_files = queue.finished_files();
+ let mut embedded_files: Vec<_> = files
+ .iter()
+ .map(|_| finished_files.try_recv().expect("no finished file"))
+ .collect();
+
+ let expected_files: Vec<_> = files
+ .iter()
+ .map(|file| {
+ let mut file = file.clone();
+ for doc in &mut file.documents {
+ doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
+ }
+ file
+ })
+ .collect();
+
+ embedded_files.sort_by_key(|f| f.path.clone());
+
+ assert_eq!(embedded_files, expected_files);
+}
+
#[track_caller]
fn assert_search_results(
actual: &[SearchResult],
@@ -227,7 +292,8 @@ fn assert_search_results(
#[gpui::test]
async fn test_code_context_retrieval_rust() {
let language = rust_lang();
- let mut retriever = CodeContextRetriever::new();
+ let embedding_provider = Arc::new(DummyEmbeddings {});
+ let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
/// A doc comment
@@ -314,7 +380,8 @@ async fn test_code_context_retrieval_rust() {
#[gpui::test]
async fn test_code_context_retrieval_json() {
let language = json_lang();
- let mut retriever = CodeContextRetriever::new();
+ let embedding_provider = Arc::new(DummyEmbeddings {});
+ let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
{
@@ -397,7 +464,8 @@ fn assert_documents_eq(
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
let language = js_lang();
- let mut retriever = CodeContextRetriever::new();
+ let embedding_provider = Arc::new(DummyEmbeddings {});
+ let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
/* globals importScripts, backend */
@@ -495,7 +563,8 @@ async fn test_code_context_retrieval_javascript() {
#[gpui::test]
async fn test_code_context_retrieval_lua() {
let language = lua_lang();
- let mut retriever = CodeContextRetriever::new();
+ let embedding_provider = Arc::new(DummyEmbeddings {});
+ let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
-- Creates a new class
@@ -568,7 +637,8 @@ async fn test_code_context_retrieval_lua() {
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
let language = elixir_lang();
- let mut retriever = CodeContextRetriever::new();
+ let embedding_provider = Arc::new(DummyEmbeddings {});
+ let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
defmodule File.Stream do
@@ -684,7 +754,8 @@ async fn test_code_context_retrieval_elixir() {
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
let language = cpp_lang();
- let mut retriever = CodeContextRetriever::new();
+ let embedding_provider = Arc::new(DummyEmbeddings {});
+ let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
/**
@@ -836,7 +907,8 @@ async fn test_code_context_retrieval_cpp() {
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
let language = ruby_lang();
- let mut retriever = CodeContextRetriever::new();
+ let embedding_provider = Arc::new(DummyEmbeddings {});
+ let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
# This concern is inspired by "sudo mode" on GitHub. It
@@ -1026,7 +1098,8 @@ async fn test_code_context_retrieval_ruby() {
#[gpui::test]
async fn test_code_context_retrieval_php() {
let language = php_lang();
- let mut retriever = CodeContextRetriever::new();
+ let embedding_provider = Arc::new(DummyEmbeddings {});
+ let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
<?php
@@ -1173,36 +1246,6 @@ async fn test_code_context_retrieval_php() {
);
}
-#[gpui::test]
-fn test_dot_product(mut rng: StdRng) {
- assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
- assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
-
- for _ in 0..100 {
- let size = 1536;
- let mut a = vec![0.; size];
- let mut b = vec![0.; size];
- for (a, b) in a.iter_mut().zip(b.iter_mut()) {
- *a = rng.gen();
- *b = rng.gen();
- }
-
- assert_eq!(
- round_to_decimals(dot(&a, &b), 1),
- round_to_decimals(reference_dot(&a, &b), 1)
- );
- }
-
- fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
- let factor = (10.0 as f32).powi(decimal_places);
- (n * factor).round() / factor
- }
-
- fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
- a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
- }
-}
-
#[derive(Default)]
struct FakeEmbeddingProvider {
embedding_count: AtomicUsize,
@@ -1212,35 +1255,42 @@ impl FakeEmbeddingProvider {
fn embedding_count(&self) -> usize {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
+
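+ // Deterministic fake embedding: a 26-bucket letter-count vector (seeded with
+ // ones so the norm is never zero), normalized to unit length. For "ab", the
+ // 'a' and 'b' buckets become 2.0 before normalization.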
+ fn embed_sync(&self, span: &str) -> Embedding {
+ let mut result = vec![1.0; 26];
+ for letter in span.chars() {
+ let letter = letter.to_ascii_lowercase();
+ if letter as u32 >= 'a' as u32 {
+ let ix = (letter as u32) - ('a' as u32);
+ if ix < 26 {
+ result[ix as usize] += 1.0;
+ }
+ }
+ }
+
+ let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+ for x in &mut result {
+ *x /= norm;
+ }
+
+ result.into()
+ }
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
- async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
- self.embedding_count
- .fetch_add(spans.len(), atomic::Ordering::SeqCst);
- Ok(spans
- .iter()
- .map(|span| {
- let mut result = vec![1.0; 26];
- for letter in span.chars() {
- let letter = letter.to_ascii_lowercase();
- if letter as u32 >= 'a' as u32 {
- let ix = (letter as u32) - ('a' as u32);
- if ix < 26 {
- result[ix as usize] += 1.0;
- }
- }
- }
+ fn truncate(&self, span: &str) -> (String, usize) {
+ (span.to_string(), 1)
+ }
- let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
- for x in &mut result {
- *x /= norm;
- }
+ fn max_tokens_per_batch(&self) -> usize {
+ 200
+ }
- result
- })
- .collect())
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+ self.embedding_count
+ .fetch_add(spans.len(), atomic::Ordering::SeqCst);
+ Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
@@ -1684,3 +1734,11 @@ fn test_subtract_ranges() {
assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
}
+
+fn init_test(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ cx.set_global(SettingsStore::test(cx));
+ settings::register::<SemanticIndexSettings>(cx);
+ settings::register::<ProjectSettings>(cx);
+ });
+}
@@ -260,11 +260,22 @@ pub fn defer<F: FnOnce()>(f: F) -> impl Drop {
Defer(Some(f))
}
-pub struct RandomCharIter<T: Rng>(T);
+pub struct RandomCharIter<T: Rng> {
+ rng: T,
+ simple_text: bool,
+}
impl<T: Rng> RandomCharIter<T> {
pub fn new(rng: T) -> Self {
- Self(rng)
+ Self {
+ rng,
+ simple_text: std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()),
+ }
+ }
+
+ pub fn with_simple_text(mut self) -> Self {
+ self.simple_text = true;
+ self
}
}
@@ -272,25 +283,27 @@ impl<T: Rng> Iterator for RandomCharIter<T> {
type Item = char;
fn next(&mut self) -> Option<Self::Item> {
- if std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()) {
- return if self.0.gen_range(0..100) < 5 {
+ if self.simple_text {
+ return if self.rng.gen_range(0..100) < 5 {
Some('\n')
} else {
- Some(self.0.gen_range(b'a'..b'z' + 1).into())
+ Some(self.rng.gen_range(b'a'..b'z' + 1).into())
};
}
- match self.0.gen_range(0..100) {
+ match self.rng.gen_range(0..100) {
// whitespace
- 0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.0).copied(),
+ 0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.rng).copied(),
// two-byte greek letters
- 20..=32 => char::from_u32(self.0.gen_range(('α' as u32)..('ω' as u32 + 1))),
+ 20..=32 => char::from_u32(self.rng.gen_range(('α' as u32)..('ω' as u32 + 1))),
// three-byte characters
- 33..=45 => ['✋', '✅', '❌', '❎', '⭐'].choose(&mut self.0).copied(),
+ 33..=45 => ['✋', '✅', '❌', '❎', '⭐']
+ .choose(&mut self.rng)
+ .copied(),
// four-byte characters
- 46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.0).copied(),
+ 46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.rng).copied(),
// ascii letters
- _ => Some(self.0.gen_range(b'a'..b'z' + 1).into()),
+ _ => Some(self.rng.gen_range(b'a'..b'z' + 1).into()),
}
}
}