@@ -278,6 +278,39 @@ impl VectorDatabase {
})
}
+ /// Loads previously computed embeddings for the given span digests.
+ pub fn embeddings_for_digests(
+ &self,
+ digests: Vec<SpanDigest>,
+ ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
+ self.transact(move |db| {
+ let mut query = db.prepare(
+ "
+ SELECT digest, embedding
+ FROM spans
+ WHERE digest IN rarray(?)
+ ",
+ )?;
+ let mut embeddings_by_digest = HashMap::default();
+ // rarray() binds its value list as an Rc<Vec<Value>> (rusqlite array vtab).
+ let digests = Rc::new(
+ digests
+ .into_iter()
+ .map(|digest| Value::Blob(digest.0.to_vec()))
+ .collect::<Vec<_>>(),
+ );
+ let rows = query.query_map(params![digests], |row| {
+ Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
+ })?;
+ // Propagate row decoding errors instead of silently dropping rows.
+ for row in rows {
+ let (digest, embedding) = row?;
+ embeddings_by_digest.insert(digest, embedding);
+ }
+ Ok(embeddings_by_digest)
+ })
+ }
+
pub fn embeddings_for_files(
&self,
worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
@@ -17,7 +17,9 @@ use std::{
use tree_sitter::{Parser, QueryCursor};
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
-pub struct SpanDigest([u8; 20]);
+// NOTE(review): the inner bytes are made public so search code can build SQL
+// blob values from the raw digest; consider an `as_bytes()` accessor instead.
+pub struct SpanDigest(pub [u8; 20]);
impl FromSql for SpanDigest {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
@@ -263,9 +263,13 @@ pub struct PendingFile {
job_handle: JobHandle,
}
+// Clone lets search results be regrouped per buffer when merging database
+// hits with freshly embedded dirty-buffer hits.
+#[derive(Clone)]
pub struct SearchResult {
pub buffer: ModelHandle<Buffer>,
pub range: Range<Anchor>,
+ pub similarity: f32,
}
@@ -775,7 +777,8 @@ impl SemanticIndex {
.filter_map(|buffer_handle| {
let buffer = buffer_handle.read(cx);
if buffer.is_dirty() {
- Some((buffer_handle.downgrade(), buffer.snapshot()))
+ // TODO: @as-cii I removed the downgrade for now to fix the compiler - @kcaverly
+ Some((buffer_handle, buffer.snapshot()))
} else {
None
}
@@ -783,77 +786,133 @@ impl SemanticIndex {
.collect::<HashMap<_, _>>()
});
- cx.background()
- .spawn({
- let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
- let embedding_provider = embedding_provider.clone();
- let phrase_embedding = phrase_embedding.clone();
- async move {
- let mut results = Vec::new();
- 'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
- let language = buffer_snapshot
- .language_at(0)
- .cloned()
- .unwrap_or_else(|| language::PLAIN_TEXT.clone());
- if let Some(spans) = retriever
- .parse_file_with_template(None, &buffer_snapshot.text(), language)
- .log_err()
- {
- let mut batch = Vec::new();
- let mut batch_tokens = 0;
- let mut embeddings = Vec::new();
-
- // TODO: query span digests in the database to avoid embedding them again.
+ let buffer_results = if let Some(db) =
+ VectorDatabase::new(fs, db_path.clone(), cx.background())
+ .await
+ .log_err()
+ {
+ cx.background()
+ .spawn({
+ let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
+ let embedding_provider = embedding_provider.clone();
+ let phrase_embedding = phrase_embedding.clone();
+ async move {
+ let mut results = Vec::<SearchResult>::new();
+ 'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
+ let language = buffer_snapshot
+ .language_at(0)
+ .cloned()
+ .unwrap_or_else(|| language::PLAIN_TEXT.clone());
+ if let Some(spans) = retriever
+ .parse_file_with_template(
+ None,
+ &buffer_snapshot.text(),
+ language,
+ )
+ .log_err()
+ {
+ let mut batch = Vec::new();
+ let mut batch_tokens = 0;
+ let mut embeddings = Vec::new();
+
+ let digests = spans
+ .iter()
+ .map(|span| span.digest.clone())
+ .collect::<Vec<_>>();
+ let embeddings_for_digests = db
+ .embeddings_for_digests(digests)
+ .await
+ .log_err().unwrap_or_default();
+
+ for span in &spans {
+ if embeddings_for_digests.contains_key(&span.digest) {
+ continue;
+ };
+
+ if batch_tokens + span.token_count
+ > embedding_provider.max_tokens_per_batch()
+ {
+ if let Some(batch_embeddings) = embedding_provider
+ .embed_batch(mem::take(&mut batch))
+ .await
+ .log_err()
+ {
+ embeddings.extend(batch_embeddings);
+ batch_tokens = 0;
+ } else {
+ continue 'buffers;
+ }
+ }
- for span in &spans {
- if span.embedding.is_some() {
- continue;
+ batch_tokens += span.token_count;
+ batch.push(span.content.clone());
}
- if batch_tokens + span.token_count
- > embedding_provider.max_tokens_per_batch()
+ if let Some(batch_embeddings) = embedding_provider
+ .embed_batch(mem::take(&mut batch))
+ .await
+ .log_err()
{
- if let Some(batch_embeddings) = embedding_provider
- .embed_batch(mem::take(&mut batch))
- .await
- .log_err()
+ embeddings.extend(batch_embeddings);
+ } else {
+ continue 'buffers;
+ }
+
+ let mut embeddings = embeddings.into_iter();
+ for span in spans {
+ let embedding = if let Some(embedding) =
+ embeddings_for_digests.get(&span.digest)
{
- embeddings.extend(batch_embeddings);
- batch_tokens = 0;
+ Some(embedding.clone())
} else {
+ embeddings.next()
+ };
+
+ if let Some(embedding) = embedding {
+ let similarity =
+ embedding.similarity(&phrase_embedding);
+
+ let ix = match results.binary_search_by(|s| {
+ similarity
+ .partial_cmp(&s.similarity)
+ .unwrap_or(Ordering::Equal)
+ }) {
+ Ok(ix) => ix,
+ Err(ix) => ix,
+ };
+
+ let range = {
+ let start = buffer_snapshot
+ .clip_offset(span.range.start, Bias::Left);
+ let end = buffer_snapshot
+ .clip_offset(span.range.end, Bias::Right);
+ buffer_snapshot.anchor_before(start)
+ ..buffer_snapshot.anchor_after(end)
+ };
+
+ results.insert(
+ ix,
+ SearchResult {
+ buffer: buffer_handle.clone(),
+ range,
+ similarity,
+ },
+ );
+ results.truncate(limit);
+ } else {
+ log::error!("failed to embed span");
continue 'buffers;
}
}
-
- batch_tokens += span.token_count;
- batch.push(span.content.clone());
- }
-
- if let Some(batch_embeddings) = embedding_provider
- .embed_batch(mem::take(&mut batch))
- .await
- .log_err()
- {
- embeddings.extend(batch_embeddings);
- } else {
- continue 'buffers;
- }
-
- let mut embeddings = embeddings.into_iter();
- for span in spans {
- let embedding = span.embedding.or_else(|| embeddings.next());
- if let Some(embedding) = embedding {
- todo!()
- } else {
- log::error!("failed to embed span");
- continue 'buffers;
- }
}
}
+ anyhow::Ok(results)
}
- }
- })
- .await;
+ })
+ .await
+ } else {
+ Ok(Vec::new())
+ };
let batch_results = futures::future::join_all(batch_results).await;
@@ -873,7 +932,11 @@ impl SemanticIndex {
}
}
- let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
+ let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
+ let scores = results
+ .into_iter()
+ .map(|(_, score)| score)
+ .collect::<Vec<f32>>();
let spans = database.spans_for_ids(ids.as_slice()).await?;
let mut tasks = Vec::new();
@@ -903,19 +966,74 @@ impl SemanticIndex {
t0.elapsed().as_millis()
);
- Ok(buffers
+ let database_results = buffers
.into_iter()
.zip(ranges)
- .filter_map(|(buffer, range)| {
+ .zip(scores)
+ .filter_map(|((buffer, range), similarity)| {
let buffer = buffer.log_err()?;
let range = buffer.read_with(&cx, |buffer, _| {
let start = buffer.clip_offset(range.start, Bias::Left);
let end = buffer.clip_offset(range.end, Bias::Right);
buffer.anchor_before(start)..buffer.anchor_after(end)
});
- Some(SearchResult { buffer, range })
+ Some(SearchResult {
+ buffer,
+ range,
+ similarity,
+ })
})
- .collect::<Vec<_>>())
+ .collect::<Vec<_>>();
+
+ // Stitch Together Database Results & Buffer Results
+ if let Ok(buffer_results) = buffer_results {
+ let mut buffer_map = HashMap::default();
+ for buffer_result in buffer_results {
+ buffer_map
+ .entry(buffer_result.buffer.clone())
+ .or_default()
+ .push(buffer_result);
+ }
+
+ for db_result in database_results {
+ if !buffer_map.contains_key(&db_result.buffer) {
+ buffer_map
+ .entry(db_result.buffer.clone())
+ .or_default()
+ .push(db_result);
+ }
+ }
+
+ let mut full_results = Vec::<SearchResult>::new();
+
+ for (_, results) in buffer_map {
+ for res in results.into_iter() {
+ let ix = match full_results.binary_search_by(|search_result| {
+ res.similarity
+ .partial_cmp(&search_result.similarity)
+ .unwrap_or(Ordering::Equal)
+ }) {
+ Ok(ix) => ix,
+ Err(ix) => ix,
+ };
+ full_results.insert(ix, res);
+ full_results.truncate(limit);
+ }
+ }
+
+ return Ok(full_results);
+ } else {
+ return Ok(database_results);
+ }
+
+ // let ix = match results.binary_search_by(|(_, s)| {
+ // similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+ // }) {
+ // Ok(ix) => ix,
+ // Err(ix) => ix,
+ // };
+ // results.insert(ix, (id, similarity));
+ // results.truncate(limit);
})
}