Detailed changes
@@ -1,13 +1,10 @@
-use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
+use crate::{embedding::Embedding, parsing::Document, SEMANTIC_INDEX_VERSION};
use anyhow::{anyhow, Context, Result};
use futures::channel::oneshot;
use gpui::executor;
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
-use rusqlite::{
- params,
- types::{FromSql, FromSqlResult, ValueRef},
-};
+use rusqlite::params;
use std::{
cmp::Ordering,
collections::HashMap,
@@ -27,34 +24,6 @@ pub struct FileRecord {
pub mtime: Timestamp,
}
-#[derive(Debug)]
-struct Embedding(pub Vec<f32>);
-
-#[derive(Debug)]
-struct Sha1(pub Vec<u8>);
-
-impl FromSql for Embedding {
- fn column_result(value: ValueRef) -> FromSqlResult<Self> {
- let bytes = value.as_blob()?;
- let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
- if embedding.is_err() {
- return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
- }
- return Ok(Embedding(embedding.unwrap()));
- }
-}
-
-impl FromSql for Sha1 {
- fn column_result(value: ValueRef) -> FromSqlResult<Self> {
- let bytes = value.as_blob()?;
- let sha1: Result<Vec<u8>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
- if sha1.is_err() {
- return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err()));
- }
- return Ok(Sha1(sha1.unwrap()));
- }
-}
-
#[derive(Clone)]
pub struct VectorDatabase {
path: Arc<Path>,
@@ -255,9 +224,6 @@ impl VectorDatabase {
// Currently inserting at approximately 3400 documents a second
// I imagine we can speed this up with a bulk insert of some kind.
for document in documents {
- let embedding_blob = bincode::serialize(&document.embedding)?;
- let sha_blob = bincode::serialize(&document.sha1)?;
-
db.execute(
"INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![
@@ -265,8 +231,8 @@ impl VectorDatabase {
document.range.start.to_string(),
document.range.end.to_string(),
document.name,
- embedding_blob,
- sha_blob
+ document.embedding,
+ document.sha1
],
)?;
}
@@ -351,7 +317,7 @@ impl VectorDatabase {
pub fn top_k_search(
&self,
- query_embedding: &Vec<f32>,
+ query_embedding: &Embedding,
limit: usize,
file_ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
@@ -360,7 +326,7 @@ impl VectorDatabase {
self.transact(move |db| {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
Self::for_each_document(db, &file_ids, |id, embedding| {
- let similarity = dot(&embedding, &query_embedding);
+ let similarity = embedding.similarity(&query_embedding);
let ix = match results.binary_search_by(|(_, s)| {
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
}) {
@@ -417,7 +383,7 @@ impl VectorDatabase {
fn for_each_document(
db: &rusqlite::Connection,
file_ids: &[i64],
- mut f: impl FnMut(i64, Vec<f32>),
+ mut f: impl FnMut(i64, Embedding),
) -> Result<()> {
let mut query_statement = db.prepare(
"
@@ -435,7 +401,7 @@ impl VectorDatabase {
Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
})?
.filter_map(|row| row.ok())
- .for_each(|(id, embedding)| f(id, embedding.0));
+ .for_each(|(id, embedding)| f(id, embedding));
Ok(())
}
@@ -497,29 +463,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
.collect::<Vec<_>>(),
)
}
-
-pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
- let len = vec_a.len();
- assert_eq!(len, vec_b.len());
-
- let mut result = 0.0;
- unsafe {
- matrixmultiply::sgemm(
- 1,
- len,
- 1,
- 1.0,
- vec_a.as_ptr(),
- len as isize,
- 1,
- vec_b.as_ptr(),
- 1,
- len as isize,
- 0.0,
- &mut result as *mut f32,
- 1,
- 1,
- );
- }
- result
-}
@@ -8,6 +8,8 @@ use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parse_duration::parse;
+use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
+use rusqlite::ToSql;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
@@ -20,6 +22,62 @@ lazy_static! {
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
+#[derive(Debug, PartialEq, Clone)]
+pub struct Embedding(Vec<f32>);
+
+impl From<Vec<f32>> for Embedding {
+ fn from(value: Vec<f32>) -> Self {
+ Embedding(value)
+ }
+}
+
+impl Embedding {
+ pub fn similarity(&self, other: &Self) -> f32 {
+ let len = self.0.len();
+ assert_eq!(len, other.0.len());
+
+ let mut result = 0.0;
+ unsafe {
+ matrixmultiply::sgemm(
+ 1,
+ len,
+ 1,
+ 1.0,
+ self.0.as_ptr(),
+ len as isize,
+ 1,
+ other.0.as_ptr(),
+ 1,
+ len as isize,
+ 0.0,
+ &mut result as *mut f32,
+ 1,
+ 1,
+ );
+ }
+ result
+ }
+}
+
+impl FromSql for Embedding {
+ fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+ let bytes = value.as_blob()?;
+ let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
+ if embedding.is_err() {
+ return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
+ }
+ Ok(Embedding(embedding.unwrap()))
+ }
+}
+
+impl ToSql for Embedding {
+ fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+ let bytes = bincode::serialize(&self.0)
+ .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
+ Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
+ }
+}
+
#[derive(Clone)]
pub struct OpenAIEmbeddings {
pub client: Arc<dyn HttpClient>,
@@ -53,7 +111,7 @@ struct OpenAIEmbeddingUsage {
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
- async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>>;
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn truncate(&self, span: &str) -> (String, usize);
}
@@ -62,10 +120,10 @@ pub struct DummyEmbeddings {}
#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
- async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
// 1024 is the OpenAI Embeddings size for ada models.
// the model we will likely be starting with.
- let dummy_vec = vec![0.32 as f32; 1536];
+ let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
return Ok(vec![dummy_vec; spans.len()]);
}
@@ -137,7 +195,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
(output, token_count)
}
- async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
@@ -175,7 +233,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
return Ok(response
.data
.into_iter()
- .map(|embedding| embedding.embedding)
+ .map(|embedding| Embedding::from(embedding.embedding))
.collect());
}
StatusCode::TOO_MANY_REQUESTS => {
@@ -218,3 +276,49 @@ impl EmbeddingProvider for OpenAIEmbeddings {
Err(anyhow!("openai max retries"))
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use rand::prelude::*;
+
+ #[gpui::test]
+ fn test_similarity(mut rng: StdRng) {
+ assert_eq!(
+ Embedding::from(vec![1., 0., 0., 0., 0.])
+ .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
+ 0.
+ );
+ assert_eq!(
+ Embedding::from(vec![2., 0., 0., 0., 0.])
+ .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
+ 6.
+ );
+
+ for _ in 0..100 {
+ let size = 1536;
+ let mut a = vec![0.; size];
+ let mut b = vec![0.; size];
+ for (a, b) in a.iter_mut().zip(b.iter_mut()) {
+ *a = rng.gen();
+ *b = rng.gen();
+ }
+ let a = Embedding::from(a);
+ let b = Embedding::from(b);
+
+ assert_eq!(
+ round_to_decimals(a.similarity(&b), 1),
+ round_to_decimals(reference_dot(&a.0, &b.0), 1)
+ );
+ }
+
+ fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
+ let factor = (10.0 as f32).powi(decimal_places);
+ (n * factor).round() / factor
+ }
+
+ fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
+ a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
+ }
+ }
+}
@@ -121,7 +121,7 @@ impl EmbeddingQueue {
&mut fragment.file.lock().documents[fragment.document_range.clone()]
{
if let Some(embedding) = embeddings.next() {
- document.embedding = embedding;
+ document.embedding = Some(embedding);
} else {
//
log::error!("number of embeddings returned different from number of documents");
@@ -1,7 +1,11 @@
-use crate::embedding::EmbeddingProvider;
-use anyhow::{anyhow, Ok, Result};
+use crate::embedding::{EmbeddingProvider, Embedding};
+use anyhow::{anyhow, Result};
use language::{Grammar, Language};
-use sha1::{Digest, Sha1};
+use rusqlite::{
+ types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
+ ToSql,
+};
+use sha1::Digest;
use std::{
cmp::{self, Reverse},
collections::HashSet,
@@ -11,13 +15,43 @@ use std::{
};
use tree_sitter::{Parser, QueryCursor};
+#[derive(Debug, PartialEq, Clone)]
+pub struct Sha1([u8; 20]);
+
+impl FromSql for Sha1 {
+ fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+ let blob = value.as_blob()?;
+ let bytes =
+ blob.try_into()
+ .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
+ expected_size: 20,
+ blob_size: blob.len(),
+ })?;
+ return Ok(Sha1(bytes));
+ }
+}
+
+impl ToSql for Sha1 {
+ fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+ self.0.to_sql()
+ }
+}
+
+impl From<&'_ str> for Sha1 {
+ fn from(value: &'_ str) -> Self {
+ let mut sha1 = sha1::Sha1::new();
+ sha1.update(value);
+ Self(sha1.finalize().into())
+ }
+}
+
#[derive(Debug, PartialEq, Clone)]
pub struct Document {
pub name: String,
pub range: Range<usize>,
pub content: String,
- pub embedding: Vec<f32>,
- pub sha1: [u8; 20],
+ pub embedding: Option<Embedding>,
+ pub sha1: Sha1,
pub token_count: usize,
}
@@ -69,17 +103,16 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref())
.replace("<item>", &content);
- let mut sha1 = Sha1::new();
- sha1.update(&document_span);
+ let sha1 = Sha1::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
Ok(vec![Document {
range: 0..content.len(),
content: document_span,
- embedding: Vec::new(),
+ embedding: Default::default(),
name: language_name.to_string(),
- sha1: sha1.finalize().into(),
+ sha1,
token_count,
}])
}
@@ -88,18 +121,14 @@ impl CodeContextRetriever {
let document_span = MARKDOWN_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<item>", &content);
-
- let mut sha1 = Sha1::new();
- sha1.update(&document_span);
-
+ let sha1 = Sha1::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
-
Ok(vec![Document {
range: 0..content.len(),
content: document_span,
- embedding: Vec::new(),
+ embedding: None,
name: "Markdown".to_string(),
- sha1: sha1.finalize().into(),
+ sha1,
token_count,
}])
}
@@ -279,15 +308,13 @@ impl CodeContextRetriever {
);
}
- let mut sha1 = Sha1::new();
- sha1.update(&document_content);
-
+ let sha1 = Sha1::from(document_content.as_str());
documents.push(Document {
name,
content: document_content,
range: item_range.clone(),
- embedding: vec![],
- sha1: sha1.finalize().into(),
+ embedding: None,
+ sha1,
token_count: 0,
})
}
@@ -1,8 +1,7 @@
use crate::{
- db::dot,
- embedding::{DummyEmbeddings, EmbeddingProvider},
+ embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
embedding_queue::EmbeddingQueue,
- parsing::{subtract_ranges, CodeContextRetriever, Document},
+ parsing::{subtract_ranges, CodeContextRetriever, Document, Sha1},
semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex,
};
@@ -217,15 +216,17 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
documents: (0..rng.gen_range(4..22))
.map(|document_ix| {
let content_len = rng.gen_range(10..100);
+ let content = RandomCharIter::new(&mut rng)
+ .with_simple_text()
+ .take(content_len)
+ .collect::<String>();
+ let sha1 = Sha1::from(content.as_str());
Document {
range: 0..10,
- embedding: Vec::new(),
+ embedding: None,
name: format!("document {document_ix}"),
- content: RandomCharIter::new(&mut rng)
- .with_simple_text()
- .take(content_len)
- .collect(),
- sha1: rng.gen(),
+ content,
+ sha1,
token_count: rng.gen_range(10..30),
}
})
@@ -254,7 +255,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
.map(|file| {
let mut file = file.clone();
for doc in &mut file.documents {
- doc.embedding = embedding_provider.embed_sync(doc.content.as_ref());
+ doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
}
file
})
@@ -1242,36 +1243,6 @@ async fn test_code_context_retrieval_php() {
);
}
-#[gpui::test]
-fn test_dot_product(mut rng: StdRng) {
- assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
- assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
-
- for _ in 0..100 {
- let size = 1536;
- let mut a = vec![0.; size];
- let mut b = vec![0.; size];
- for (a, b) in a.iter_mut().zip(b.iter_mut()) {
- *a = rng.gen();
- *b = rng.gen();
- }
-
- assert_eq!(
- round_to_decimals(dot(&a, &b), 1),
- round_to_decimals(reference_dot(&a, &b), 1)
- );
- }
-
- fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
- let factor = (10.0 as f32).powi(decimal_places);
- (n * factor).round() / factor
- }
-
- fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
- a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
- }
-}
-
#[derive(Default)]
struct FakeEmbeddingProvider {
embedding_count: AtomicUsize,
@@ -1282,7 +1253,7 @@ impl FakeEmbeddingProvider {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
- fn embed_sync(&self, span: &str) -> Vec<f32> {
+ fn embed_sync(&self, span: &str) -> Embedding {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
@@ -1299,7 +1270,7 @@ impl FakeEmbeddingProvider {
*x /= norm;
}
- result
+ result.into()
}
}
@@ -1313,7 +1284,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
200
}
- async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())