@@ -4580,6 +4580,19 @@ dependencies = [
"tempfile",
]
+[[package]]
+name = "ndarray"
+version = "0.15.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
+dependencies = [
+ "matrixmultiply",
+ "num-complex 0.4.4",
+ "num-integer",
+ "num-traits",
+ "rawpointer",
+]
+
[[package]]
name = "ndk"
version = "0.7.0"
@@ -4706,7 +4719,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
dependencies = [
"num-bigint 0.2.6",
- "num-complex",
+ "num-complex 0.2.4",
"num-integer",
"num-iter",
"num-rational 0.2.4",
@@ -4762,6 +4775,15 @@ dependencies = [
"num-traits",
]
+[[package]]
+name = "num-complex"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214"
+dependencies = [
+ "num-traits",
+]
+
[[package]]
name = "num-derive"
version = "0.3.3"
@@ -6751,6 +6773,7 @@ dependencies = [
"language",
"lazy_static",
"log",
+ "ndarray",
"node_runtime",
"ordered-float",
"parking_lot 0.11.2",
@@ -7,13 +7,13 @@ use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::channel::oneshot;
use gpui::executor;
+use ndarray::{Array1, Array2};
use ordered_float::OrderedFloat;
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
use rusqlite::params;
use rusqlite::types::Value;
use std::{
- cmp::Reverse,
future::Future,
ops::Range,
path::{Path, PathBuf},
@@ -23,6 +23,13 @@ use std::{
};
use util::TryFutureExt;
+pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
+ let mut indices = (0..data.len()).collect::<Vec<_>>();
+ indices.sort_by_key(|&i| &data[i]);
+ indices.reverse();
+ indices
+}
+
#[derive(Debug)]
pub struct FileRecord {
pub id: usize,
@@ -409,23 +416,91 @@ impl VectorDatabase {
limit: usize,
file_ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
- let query_embedding = query_embedding.clone();
let file_ids = file_ids.to_vec();
+ let query = query_embedding.clone().0;
+ let query = Array1::from_vec(query);
self.transact(move |db| {
- let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
- Self::for_each_span(db, &file_ids, |id, embedding| {
- let similarity = embedding.similarity(&query_embedding);
- let ix = match results
- .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
- {
- Ok(ix) => ix,
- Err(ix) => ix,
- };
- results.insert(ix, (id, similarity));
- results.truncate(limit);
- })?;
+ let mut query_statement = db.prepare(
+ "
+ SELECT
+ id, embedding
+ FROM
+ spans
+ WHERE
+ file_id IN rarray(?)
+ ",
+ )?;
+
+ let deserialized_rows = query_statement
+ .query_map(params![ids_to_sql(&file_ids)], |row| {
+ Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
+ })?
+ .filter_map(|row| row.ok())
+ .collect::<Vec<(usize, Embedding)>>();
+
+ if deserialized_rows.len() == 0 {
+ return Ok(Vec::new());
+ }
+
+ // Get Length of Embeddings Returned
+ let embedding_len = deserialized_rows[0].1 .0.len();
+
+ let batch_n = 1000;
+ let mut batches = Vec::new();
+ let mut batch_ids = Vec::new();
+ let mut batch_embeddings: Vec<f32> = Vec::new();
+ deserialized_rows.iter().for_each(|(id, embedding)| {
+ batch_ids.push(id);
+ batch_embeddings.extend(&embedding.0);
+
+ if batch_ids.len() == batch_n {
+ let embeddings = std::mem::take(&mut batch_embeddings);
+ let ids = std::mem::take(&mut batch_ids);
+ let array =
+ Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings);
+ match array {
+ Ok(array) => {
+ batches.push((ids, array));
+ }
+ Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
+ }
+ }
+ });
- anyhow::Ok(results)
+ if batch_ids.len() > 0 {
+ let array = Array2::from_shape_vec(
+ (batch_ids.len(), embedding_len),
+ batch_embeddings.clone(),
+ );
+ match array {
+ Ok(array) => {
+ batches.push((batch_ids.clone(), array));
+ }
+ Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
+ }
+ }
+
+ let mut ids: Vec<usize> = Vec::new();
+ let mut results = Vec::new();
+ for (batch_ids, array) in batches {
+ let scores = array
+ .dot(&query.t())
+ .to_vec()
+ .iter()
+ .map(|score| OrderedFloat(*score))
+ .collect::<Vec<OrderedFloat<f32>>>();
+ results.extend(scores);
+ ids.extend(batch_ids);
+ }
+
+ let sorted_idx = argsort(&results);
+ let mut sorted_results = Vec::new();
+ let last_idx = limit.min(sorted_idx.len());
+ for idx in &sorted_idx[0..last_idx] {
+ sorted_results.push((ids[*idx] as i64, results[*idx]))
+ }
+
+ Ok(sorted_results)
})
}
@@ -468,31 +543,6 @@ impl VectorDatabase {
})
}
- fn for_each_span(
- db: &rusqlite::Connection,
- file_ids: &[i64],
- mut f: impl FnMut(i64, Embedding),
- ) -> Result<()> {
- let mut query_statement = db.prepare(
- "
- SELECT
- id, embedding
- FROM
- spans
- WHERE
- file_id IN rarray(?)
- ",
- )?;
-
- query_statement
- .query_map(params![ids_to_sql(&file_ids)], |row| {
- Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
- })?
- .filter_map(|row| row.ok())
- .for_each(|(id, embedding)| f(id, embedding));
- Ok(())
- }
-
pub fn spans_for_ids(
&self,
ids: &[i64],
@@ -705,11 +705,13 @@ impl SemanticIndex {
cx.spawn(|this, mut cx| async move {
index.await?;
+ let t0 = Instant::now();
let query = embedding_provider
.embed_batch(vec![query])
.await?
.pop()
.ok_or_else(|| anyhow!("could not embed query"))?;
+ log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());
let search_start = Instant::now();
let modified_buffer_results = this.update(&mut cx, |this, cx| {
@@ -787,10 +789,15 @@ impl SemanticIndex {
let batch_n = cx.background().num_cpus();
let ids_len = file_ids.clone().len();
- let batch_size = if ids_len <= batch_n {
- ids_len
- } else {
- ids_len / batch_n
+ let minimum_batch_size = 50;
+
+ let batch_size = {
+ let size = ids_len / batch_n;
+ if size < minimum_batch_size {
+ minimum_batch_size
+ } else {
+ size
+ }
};
let mut batch_results = Vec::new();
@@ -822,6 +829,7 @@ impl SemanticIndex {
Ok(ix) => ix,
Err(ix) => ix,
};
+
results.insert(ix, (id, similarity));
results.truncate(limit);
}
@@ -856,7 +864,6 @@ impl SemanticIndex {
})?;
let buffers = futures::future::join_all(tasks).await;
-
Ok(buffers
.into_iter()
.zip(ranges)