Cargo.lock 🔗
@@ -6739,6 +6739,7 @@ dependencies = [
"lazy_static",
"log",
"matrixmultiply",
+ "ordered-float",
"parking_lot 0.11.2",
"parse_duration",
"picker",
Antonio Scandurra created
Cargo.lock | 1
crates/semantic_index/Cargo.toml | 1
crates/semantic_index/src/db.rs | 13
crates/semantic_index/src/embedding.rs | 11
crates/semantic_index/src/semantic_index.rs | 415 +++++++++++-----------
5 files changed, 214 insertions(+), 227 deletions(-)
@@ -6739,6 +6739,7 @@ dependencies = [
"lazy_static",
"log",
"matrixmultiply",
+ "ordered-float",
"parking_lot 0.11.2",
"parse_duration",
"picker",
@@ -23,6 +23,7 @@ settings = { path = "../settings" }
anyhow.workspace = true
postage.workspace = true
futures.workspace = true
+ordered-float.workspace = true
smol.workspace = true
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
isahc.workspace = true
@@ -7,12 +7,13 @@ use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::channel::oneshot;
use gpui::executor;
+use ordered_float::OrderedFloat;
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
use rusqlite::params;
use rusqlite::types::Value;
use std::{
- cmp::Ordering,
+ cmp::Reverse,
future::Future,
ops::Range,
path::{Path, PathBuf},
@@ -407,16 +408,16 @@ impl VectorDatabase {
query_embedding: &Embedding,
limit: usize,
file_ids: &[i64],
- ) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
+ ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
let query_embedding = query_embedding.clone();
let file_ids = file_ids.to_vec();
self.transact(move |db| {
- let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+ let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
Self::for_each_span(db, &file_ids, |id, embedding| {
let similarity = embedding.similarity(&query_embedding);
- let ix = match results.binary_search_by(|(_, s)| {
- similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
- }) {
+ let ix = match results
+ .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
+ {
Ok(ix) => ix,
Err(ix) => ix,
};
@@ -7,6 +7,7 @@ use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
+use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use parse_duration::parse;
use postage::watch;
@@ -35,7 +36,7 @@ impl From<Vec<f32>> for Embedding {
}
impl Embedding {
- pub fn similarity(&self, other: &Self) -> f32 {
+ pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
let len = self.0.len();
assert_eq!(len, other.0.len());
@@ -58,7 +59,7 @@ impl Embedding {
1,
);
}
- result
+ OrderedFloat(result)
}
}
@@ -379,13 +380,13 @@ mod tests {
);
}
- fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
+ fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor
}
- fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
- a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
+ fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
+ OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
}
}
}
@@ -16,13 +16,14 @@ use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
+use ordered_float::OrderedFloat;
use parking_lot::Mutex;
-use parsing::{CodeContextRetriever, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
+use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
use smol::channel;
use std::{
- cmp::Ordering,
+ cmp::Reverse,
future::Future,
mem,
ops::Range,
@@ -267,7 +268,7 @@ pub struct PendingFile {
pub struct SearchResult {
pub buffer: ModelHandle<Buffer>,
pub range: Range<Anchor>,
- pub similarity: f32,
+ pub similarity: OrderedFloat<f32>,
}
impl SemanticIndex {
@@ -690,38 +691,70 @@ impl SemanticIndex {
// Semantic search entry point: embeds `query`, searches both the on-disk index
// and the project's unsaved (dirty) buffers, and returns at most `limit`
// results ordered by descending similarity.
//
// An empty `query` short-circuits to an empty result list before any indexing
// or embedding work is started.
pub fn search_project(
&mut self,
project: ModelHandle<Project>,
- phrase: String,
+ query: String,
limit: usize,
includes: Vec<PathMatcher>,
- mut excludes: Vec<PathMatcher>,
+ excludes: Vec<PathMatcher>,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<SearchResult>>> {
+ if query.is_empty() {
+ return Task::ready(Ok(Vec::new()));
+ }
+
let index = self.index_project(project.clone(), cx);
let embedding_provider = self.embedding_provider.clone();
- let db_path = self.db.path().clone();
- let fs = self.fs.clone();
+
cx.spawn(|this, mut cx| async move {
// The query string is embedded as a single-element batch; an empty response
// from the provider is reported as "could not embed query".
+ let query = embedding_provider
+ .embed_batch(vec![query])
+ .await?
+ .pop()
+ .ok_or_else(|| anyhow!("could not embed query"))?;
index.await?;
- let t0 = Instant::now();
- let database =
- VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
// The timer starts after embedding and indexing, so the trace below covers
// only the two searches and the merge.
+ let search_start = Instant::now();
// Both searches are created first and then awaited together via `join!`.
+ let modified_buffer_results = this.update(&mut cx, |this, cx| {
+ this.search_modified_buffers(&project, query.clone(), limit, &excludes, cx)
+ });
+ let file_results = this.update(&mut cx, |this, cx| {
+ this.search_files(project, query, limit, includes, excludes, cx)
+ });
+ let (modified_buffer_results, file_results) =
+ futures::join!(modified_buffer_results, file_results);
- if phrase.len() == 0 {
- return Ok(Vec::new());
// A failure in either search is logged and treated as "no results" rather
// than failing the whole query.
+ // Weave together the results from modified buffers and files.
+ let mut results = Vec::new();
+ let mut modified_buffers = HashSet::default();
+ for result in modified_buffer_results.log_err().unwrap_or_default() {
+ modified_buffers.insert(result.buffer.clone());
+ results.push(result);
}
// File-backed results are dropped when the same buffer already produced a
// result from its unsaved contents.
// NOTE(review): a dirty buffer only lands in `modified_buffers` if one of its
// spans was returned by `search_modified_buffers` (which truncates to `limit`),
// and `excludes` is passed to `search_files` unchanged, so on-disk results for
// other dirty buffers are not filtered here. Confirm this is intended.
+ for result in file_results.log_err().unwrap_or_default() {
+ if !modified_buffers.contains(&result.buffer) {
+ results.push(result);
+ }
+ }
// `Reverse` turns the ascending sort into highest-similarity-first.
+ results.sort_by_key(|result| Reverse(result.similarity));
+ results.truncate(limit);
+ log::trace!("Semantic search took {:?}", search_start.elapsed());
+ Ok(results)
+ })
+ }
- let phrase_embedding = embedding_provider
- .embed_batch(vec![phrase])
- .await?
- .into_iter()
- .next()
- .unwrap();
-
- log::trace!(
- "Embedding search phrase took: {:?} milliseconds",
- t0.elapsed().as_millis()
- );
+ pub fn search_files(
+ &mut self,
+ project: ModelHandle<Project>,
+ query: Embedding,
+ limit: usize,
+ includes: Vec<PathMatcher>,
+ excludes: Vec<PathMatcher>,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<Vec<SearchResult>>> {
+ let db_path = self.db.path().clone();
+ let fs = self.fs.clone();
+ cx.spawn(|this, mut cx| async move {
+ let database =
+ VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
let worktree_db_ids = this.read_with(&cx, |this, _| {
let project_state = this
@@ -742,42 +775,6 @@ impl SemanticIndex {
anyhow::Ok(worktree_db_ids)
})?;
- let (dirty_buffers, dirty_paths) = project.read_with(&cx, |project, cx| {
- let mut dirty_paths = Vec::new();
- let dirty_buffers = project
- .opened_buffers(cx)
- .into_iter()
- .filter_map(|buffer_handle| {
- let buffer = buffer_handle.read(cx);
- if buffer.is_dirty() {
- let snapshot = buffer.snapshot();
- if let Some(file_pathbuf) = snapshot.resolve_file_path(cx, false) {
- let file_path = file_pathbuf.as_path();
-
- if excludes.iter().any(|glob| glob.is_match(file_path)) {
- return None;
- }
-
- file_pathbuf
- .to_str()
- .and_then(|path| PathMatcher::new(path).log_err())
- .and_then(|path_matcher| {
- dirty_paths.push(path_matcher);
- Some(())
- });
- }
- // TOOD: @as-cii I removed the downgrade for now to fix the compiler - @kcaverly
- Some((buffer_handle, buffer.snapshot()))
- } else {
- None
- }
- })
- .collect::<HashMap<_, _>>();
-
- (dirty_buffers, dirty_paths)
- });
-
- excludes.extend(dirty_paths);
let file_ids = database
.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
.await?;
@@ -796,155 +793,26 @@ impl SemanticIndex {
let limit = limit.clone();
let fs = fs.clone();
let db_path = db_path.clone();
- let phrase_embedding = phrase_embedding.clone();
+ let query = query.clone();
if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
.await
.log_err()
{
batch_results.push(async move {
- db.top_k_search(&phrase_embedding, limit, batch.as_slice())
- .await
+ db.top_k_search(&query, limit, batch.as_slice()).await
});
}
}
- let buffer_results = if let Some(db) =
- VectorDatabase::new(fs, db_path.clone(), cx.background())
- .await
- .log_err()
- {
- cx.background()
- .spawn({
- let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
- let embedding_provider = embedding_provider.clone();
- let phrase_embedding = phrase_embedding.clone();
- async move {
- let mut results = Vec::<SearchResult>::new();
- 'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
- let language = buffer_snapshot
- .language_at(0)
- .cloned()
- .unwrap_or_else(|| language::PLAIN_TEXT.clone());
- if let Some(spans) = retriever
- .parse_file_with_template(
- None,
- &buffer_snapshot.text(),
- language,
- )
- .log_err()
- {
- let mut batch = Vec::new();
- let mut batch_tokens = 0;
- let mut embeddings = Vec::new();
-
- let digests = spans
- .iter()
- .map(|span| span.digest.clone())
- .collect::<Vec<_>>();
- let embeddings_for_digests = db
- .embeddings_for_digests(digests)
- .await
- .map_or(Default::default(), |m| m);
-
- for span in &spans {
- if embeddings_for_digests.contains_key(&span.digest) {
- continue;
- };
-
- if batch_tokens + span.token_count
- > embedding_provider.max_tokens_per_batch()
- {
- if let Some(batch_embeddings) = embedding_provider
- .embed_batch(mem::take(&mut batch))
- .await
- .log_err()
- {
- embeddings.extend(batch_embeddings);
- batch_tokens = 0;
- } else {
- continue 'buffers;
- }
- }
-
- batch_tokens += span.token_count;
- batch.push(span.content.clone());
- }
-
- if let Some(batch_embeddings) = embedding_provider
- .embed_batch(mem::take(&mut batch))
- .await
- .log_err()
- {
- embeddings.extend(batch_embeddings);
- } else {
- continue 'buffers;
- }
-
- let mut embeddings = embeddings.into_iter();
- for span in spans {
- let embedding = if let Some(embedding) =
- embeddings_for_digests.get(&span.digest)
- {
- Some(embedding.clone())
- } else {
- embeddings.next()
- };
-
- if let Some(embedding) = embedding {
- let similarity =
- embedding.similarity(&phrase_embedding);
-
- let ix = match results.binary_search_by(|s| {
- similarity
- .partial_cmp(&s.similarity)
- .unwrap_or(Ordering::Equal)
- }) {
- Ok(ix) => ix,
- Err(ix) => ix,
- };
-
- let range = {
- let start = buffer_snapshot
- .clip_offset(span.range.start, Bias::Left);
- let end = buffer_snapshot
- .clip_offset(span.range.end, Bias::Right);
- buffer_snapshot.anchor_before(start)
- ..buffer_snapshot.anchor_after(end)
- };
-
- results.insert(
- ix,
- SearchResult {
- buffer: buffer_handle.clone(),
- range,
- similarity,
- },
- );
- results.truncate(limit);
- } else {
- log::error!("failed to embed span");
- continue 'buffers;
- }
- }
- }
- }
- anyhow::Ok(results)
- }
- })
- .await
- } else {
- Ok(Vec::new())
- };
-
let batch_results = futures::future::join_all(batch_results).await;
let mut results = Vec::new();
for batch_result in batch_results {
if batch_result.is_ok() {
for (id, similarity) in batch_result.unwrap() {
- let ix = match results.binary_search_by(|(_, s)| {
- similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
- }) {
+ let ix = match results
+ .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
+ {
Ok(ix) => ix,
Err(ix) => ix,
};
@@ -958,7 +826,7 @@ impl SemanticIndex {
let scores = results
.into_iter()
.map(|(_, score)| score)
- .collect::<Vec<f32>>();
+ .collect::<Vec<_>>();
let spans = database.spans_for_ids(ids.as_slice()).await?;
let mut tasks = Vec::new();
@@ -983,12 +851,7 @@ impl SemanticIndex {
let buffers = futures::future::join_all(tasks).await;
- log::trace!(
- "Semantic Searching took: {:?} milliseconds in total",
- t0.elapsed().as_millis()
- );
-
- let mut database_results = buffers
+ Ok(buffers
.into_iter()
.zip(ranges)
.zip(scores)
@@ -1005,26 +868,89 @@ impl SemanticIndex {
similarity,
})
})
- .collect::<Vec<_>>();
+ .collect())
+ })
+ }
- // Stitch Together Database Results & Buffer Results
- if let Ok(buffer_results) = buffer_results {
- for buffer_result in buffer_results {
- let ix = match database_results.binary_search_by(|search_result| {
- buffer_result
- .similarity
- .partial_cmp(&search_result.similarity)
- .unwrap_or(Ordering::Equal)
- }) {
- Ok(ix) => ix,
- Err(ix) => ix,
- };
- database_results.insert(ix, buffer_result);
- database_results.truncate(limit);
// Searches the unsaved contents of the project's dirty buffers: each buffer is
// re-parsed into spans, the spans are embedded, and the `limit` spans most
// similar to `query` are returned in descending similarity order.
//
// Buffers whose resolved file path matches any of `excludes` are skipped.
// Parsing and embedding failures are logged and affect only the buffer they
// occurred in; failing to open the database fails the whole task.
+ fn search_modified_buffers(
+ &self,
+ project: &ModelHandle<Project>,
+ query: Embedding,
+ limit: usize,
+ excludes: &[PathMatcher],
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<Vec<SearchResult>>> {
// Snapshots are captured here, before the background task is spawned.
// NOTE(review): the snapshot and path are computed for every open buffer,
// and `is_dirty()` is only checked afterwards; checking it first would skip
// that work for clean buffers.
+ let modified_buffers = project
+ .read(cx)
+ .opened_buffers(cx)
+ .into_iter()
+ .filter_map(|buffer_handle| {
+ let buffer = buffer_handle.read(cx);
+ let snapshot = buffer.snapshot();
// A buffer whose file path cannot be resolved is treated as not excluded.
+ let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
+ excludes.iter().any(|matcher| matcher.is_match(&path))
+ });
+ if buffer.is_dirty() && !excluded {
+ Some((buffer_handle, snapshot))
+ } else {
+ None
+ }
+ })
+ .collect::<HashMap<_, _>>();
+
+ let embedding_provider = self.embedding_provider.clone();
+ let fs = self.fs.clone();
+ let db_path = self.db.path().clone();
+ let background = cx.background().clone();
+ cx.background().spawn(async move {
+ let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
// `results` is kept sorted by descending similarity and capped at `limit`.
+ let mut results = Vec::<SearchResult>::new();
+
+ let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
+ for (buffer, snapshot) in modified_buffers {
// The language at offset 0 is used for the whole buffer, falling back to
// plain text when none is reported.
+ let language = snapshot
+ .language_at(0)
+ .cloned()
+ .unwrap_or_else(|| language::PLAIN_TEXT.clone());
// A parse error is logged and yields an empty span list for this buffer.
+ let mut spans = retriever
+ .parse_file_with_template(None, &snapshot.text(), language)
+ .log_err()
+ .unwrap_or_default();
// An embedding error is logged and the buffer contributes no results.
+ if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
+ .await
+ .log_err()
+ .is_some()
+ {
+ for span in spans {
// The `unwrap` relies on `embed_spans` having set every span's embedding
// before returning `Ok`.
+ let similarity = span.embedding.unwrap().similarity(&query);
// Either arm gives a valid insertion index that keeps `results` sorted
// highest-similarity-first.
+ let ix = match results
+ .binary_search_by_key(&Reverse(similarity), |result| {
+ Reverse(result.similarity)
+ }) {
+ Ok(ix) => ix,
+ Err(ix) => ix,
+ };
+
// The span's offsets are clipped against the snapshot before being turned
// into anchors.
+ let range = {
+ let start = snapshot.clip_offset(span.range.start, Bias::Left);
+ let end = snapshot.clip_offset(span.range.end, Bias::Right);
+ snapshot.anchor_before(start)..snapshot.anchor_after(end)
+ };
+
+ results.insert(
+ ix,
+ SearchResult {
+ buffer: buffer.clone(),
+ range,
+ similarity,
+ },
+ );
+ results.truncate(limit);
+ }
}
}
- Ok(database_results)
+ Ok(results)
})
}
@@ -1208,6 +1134,63 @@ impl SemanticIndex {
Ok(())
})
}
+
// Populates `span.embedding` for every element of `spans`.
//
// Embeddings the database returns for a span's digest are reused; the
// remaining spans are sent to `embedding_provider` in batches whose summed
// `token_count` is checked against `max_tokens_per_batch()`.
//
// Errors: propagates any `embed_batch` failure, and returns
// "failed to embed spans" if a span has neither a stored embedding nor a
// freshly computed one. On error, `spans` may be left partially populated.
+ async fn embed_spans(
+ spans: &mut [Span],
+ embedding_provider: &dyn EmbeddingProvider,
+ db: &VectorDatabase,
+ ) -> Result<()> {
+ let mut batch = Vec::new();
+ let mut batch_tokens = 0;
+ let mut embeddings = Vec::new();
+
+ let digests = spans
+ .iter()
+ .map(|span| span.digest.clone())
+ .collect::<Vec<_>>();
// A failed lookup is logged and treated as "nothing stored", so every span is
// embedded from scratch in that case.
+ let embeddings_for_digests = db
+ .embeddings_for_digests(digests)
+ .await
+ .log_err()
+ .unwrap_or_default();
+
+ for span in &*spans {
+ if embeddings_for_digests.contains_key(&span.digest) {
+ continue;
+ };
+
// Flush the pending batch before adding a span that would push it over the
// provider's token limit.
// NOTE(review): if the first span needing an embedding alone exceeds the
// limit, `batch` is still empty here and `embed_batch` is called with an
// empty Vec; confirm the provider tolerates that.
+ if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
+ let batch_embeddings = embedding_provider
+ .embed_batch(mem::take(&mut batch))
+ .await?;
+ embeddings.extend(batch_embeddings);
+ batch_tokens = 0;
+ }
+
+ batch_tokens += span.token_count;
+ batch.push(span.content.clone());
+ }
+
// Flush whatever is left after the loop.
+ if !batch.is_empty() {
+ let batch_embeddings = embedding_provider
+ .embed_batch(mem::take(&mut batch))
+ .await?;
+
+ embeddings.extend(batch_embeddings);
+ }
+
// Second pass: stored embeddings are matched by digest, fresh ones are taken
// in the order they were requested.
// NOTE(review): assumes `embed_batch` returns one embedding per input, in
// input order; verify against the provider implementation.
+ let mut embeddings = embeddings.into_iter();
+ for span in spans {
+ let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
+ Some(embedding.clone())
+ } else {
+ embeddings.next()
+ };
+ let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
+ span.embedding = Some(embedding);
+ }
+ Ok(())
+ }
}
impl Entity for SemanticIndex {