Detailed changes
@@ -53,36 +53,30 @@ struct OpenAIEmbeddingUsage {
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
- async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
- fn count_tokens(&self, span: &str) -> usize;
- fn should_truncate(&self, span: &str) -> bool;
- fn truncate(&self, span: &str) -> String;
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>>;
+ fn max_tokens_per_batch(&self) -> usize;
+ fn truncate(&self, span: &str) -> (String, usize);
}
pub struct DummyEmbeddings {}
#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
- async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
// 1024 is the OpenAI Embeddings size for ada models.
// the model we will likely be starting with.
let dummy_vec = vec![0.32 as f32; 1536];
return Ok(vec![dummy_vec; spans.len()]);
}
- fn count_tokens(&self, span: &str) -> usize {
- // For Dummy Providers, we are going to use OpenAI tokenization for ease
- let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
- tokens.len()
+ fn max_tokens_per_batch(&self) -> usize {
+ OPENAI_INPUT_LIMIT
}
- fn should_truncate(&self, span: &str) -> bool {
- self.count_tokens(span) > OPENAI_INPUT_LIMIT
- }
-
- fn truncate(&self, span: &str) -> String {
+ fn truncate(&self, span: &str) -> (String, usize) {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
- let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+ let token_count = tokens.len();
+ let output = if token_count > OPENAI_INPUT_LIMIT {
tokens.truncate(OPENAI_INPUT_LIMIT);
OPENAI_BPE_TOKENIZER
.decode(tokens)
@@ -92,7 +86,7 @@ impl EmbeddingProvider for DummyEmbeddings {
span.to_string()
};
- output
+ (output, token_count)
}
}
@@ -125,19 +119,14 @@ impl OpenAIEmbeddings {
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
- fn count_tokens(&self, span: &str) -> usize {
- // For Dummy Providers, we are going to use OpenAI tokenization for ease
- let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
- tokens.len()
- }
-
- fn should_truncate(&self, span: &str) -> bool {
- self.count_tokens(span) > OPENAI_INPUT_LIMIT
+ fn max_tokens_per_batch(&self) -> usize {
+ OPENAI_INPUT_LIMIT
}
- fn truncate(&self, span: &str) -> String {
+ fn truncate(&self, span: &str) -> (String, usize) {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
- let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+ let token_count = tokens.len();
+ let output = if token_count > OPENAI_INPUT_LIMIT {
tokens.truncate(OPENAI_INPUT_LIMIT);
OPENAI_BPE_TOKENIZER
.decode(tokens)
@@ -147,10 +136,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
span.to_string()
};
- output
+ (output, token_count)
}
- async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
@@ -160,9 +149,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
let mut request_number = 0;
let mut request_timeout: u64 = 10;
- let mut truncated = false;
let mut response: Response<AsyncBody>;
- let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
while request_number < MAX_RETRIES {
response = self
.send_request(
@@ -0,0 +1,140 @@
+use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
+
+use gpui::AppContext;
+use parking_lot::Mutex;
+use smol::channel;
+
+use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
+
+/// A file and the documents parsed from it, as handed to the embedding queue.
+///
+/// The queue writes each document's `embedding` in place and then returns the
+/// whole value through its finished-files channel.
+#[derive(Clone)]
+pub struct FileToEmbed {
+    /// Identifier of the worktree this file belongs to.
+    pub worktree_id: i64,
+    /// Path of the file.
+    pub path: PathBuf,
+    /// Modification time recorded for the file.
+    pub mtime: SystemTime,
+    /// Documents parsed from the file; each carries its own `token_count`,
+    /// `content` and (once embedded) `embedding`.
+    pub documents: Vec<Document>,
+    /// Job handle travelling with the file.
+    /// NOTE(review): presumably tracks outstanding indexing work — confirm
+    /// against `JobHandle`.
+    pub job_handle: JobHandle,
+}
+
+impl std::fmt::Debug for FileToEmbed {
+    /// Formats every field except `job_handle`; `finish_non_exhaustive`
+    /// signals in the output that a field was left out.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("FileToEmbed")
+            .field("worktree_id", &self.worktree_id)
+            .field("path", &self.path)
+            .field("mtime", &self.mtime)
+            // The label mirrors the field name (`documents`, plural).
+            .field("documents", &self.documents)
+            .finish_non_exhaustive()
+    }
+}
+
+impl PartialEq for FileToEmbed {
+    /// Two files are equal when worktree, path, mtime and documents all
+    /// match; `job_handle` is deliberately not part of the comparison.
+    fn eq(&self, other: &Self) -> bool {
+        self.worktree_id == other.worktree_id
+            && self.path == other.path
+            && self.mtime == other.mtime
+            && self.documents == other.documents
+    }
+}
+
+/// Accumulates the documents of pushed files into batches and embeds each
+/// batch on the background executor.
+///
+/// A batch is flushed before a document would push its accumulated token
+/// count past the provider's `max_tokens_per_batch`. A file is delivered on
+/// the finished-files channel once the last batch holding one of its
+/// fragments has been embedded.
+pub struct EmbeddingQueue {
+    embedding_provider: Arc<dyn EmbeddingProvider>,
+    /// Fragments collected for the batch that has not been sent yet.
+    pending_batch: Vec<FileToEmbedFragment>,
+    /// Sum of `token_count` over all documents covered by `pending_batch`.
+    pending_batch_token_count: usize,
+    finished_files_tx: channel::Sender<FileToEmbed>,
+    finished_files_rx: channel::Receiver<FileToEmbed>,
+}
+
+/// A contiguous run of one file's documents that belongs to a single batch.
+pub struct FileToEmbedFragment {
+    /// The file, shared between all fragments cut from it.
+    file: Arc<Mutex<FileToEmbed>>,
+    /// Indices into `file.documents` covered by this fragment.
+    document_range: Range<usize>,
+}
+
+impl EmbeddingQueue {
+    /// Creates an empty queue that embeds with `embedding_provider`.
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
+        let (finished_files_tx, finished_files_rx) = channel::unbounded();
+        Self {
+            embedding_provider,
+            pending_batch: Vec::new(),
+            pending_batch_token_count: 0,
+            finished_files_tx,
+            finished_files_rx,
+        }
+    }
+
+    /// Adds all documents of `file` to the queue, flushing the pending batch
+    /// whenever the next document would not fit into it.
+    ///
+    /// Documents left in the pending batch afterwards are only embedded once
+    /// a later `push` overflows the batch or `flush` is called.
+    pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) {
+        let file = Arc::new(Mutex::new(file));
+
+        self.pending_batch.push(FileToEmbedFragment {
+            file: file.clone(),
+            document_range: 0..0,
+        });
+
+        let mut fragment_range = &mut self
+            .pending_batch
+            .last_mut()
+            .expect("a fragment was pushed just above")
+            .document_range;
+        for (ix, document) in file.lock().documents.iter().enumerate() {
+            let next_token_count = self.pending_batch_token_count + document.token_count;
+            // Flushing a batch that holds no tokens cannot make room, so a
+            // document that is over the limit on its own is simply placed
+            // into the (otherwise token-free) current batch.
+            if self.pending_batch_token_count > 0
+                && next_token_count > self.embedding_provider.max_tokens_per_batch()
+            {
+                let range_end = fragment_range.end;
+                self.flush(cx);
+                // Continue the same file in a fresh fragment of the new batch.
+                self.pending_batch.push(FileToEmbedFragment {
+                    file: file.clone(),
+                    document_range: range_end..range_end,
+                });
+                fragment_range = &mut self
+                    .pending_batch
+                    .last_mut()
+                    .expect("a fragment was pushed just above")
+                    .document_range;
+            }
+
+            fragment_range.end = ix + 1;
+            self.pending_batch_token_count += document.token_count;
+        }
+    }
+
+    /// Sends the pending batch off to be embedded on the background executor
+    /// and resets the queue to an empty batch. Does nothing if no fragments
+    /// are pending.
+    ///
+    /// If the provider returns an error it is logged and the batch's files
+    /// are not reported as finished.
+    pub fn flush(&mut self, cx: &mut AppContext) {
+        let batch = mem::take(&mut self.pending_batch);
+        self.pending_batch_token_count = 0;
+        if batch.is_empty() {
+            return;
+        }
+
+        let finished_files_tx = self.finished_files_tx.clone();
+        let embedding_provider = self.embedding_provider.clone();
+        cx.background()
+            .spawn(async move {
+                let mut spans = Vec::new();
+                for fragment in &batch {
+                    let file = fragment.file.lock();
+                    spans.extend(
+                        file.documents[fragment.document_range.clone()]
+                            .iter()
+                            .map(|d| d.content.clone()),
+                    );
+                }
+
+                // A batch made only of empty fragments (e.g. files without
+                // documents) has nothing to embed: skip the provider call but
+                // still run the completion logic below.
+                let result = if spans.is_empty() {
+                    Ok(Vec::new())
+                } else {
+                    embedding_provider.embed_batch(spans).await
+                };
+
+                match result {
+                    Ok(embeddings) => {
+                        let mut embeddings = embeddings.into_iter();
+                        for fragment in batch {
+                            for document in
+                                &mut fragment.file.lock().documents[fragment.document_range.clone()]
+                            {
+                                if let Some(embedding) = embeddings.next() {
+                                    document.embedding = embedding;
+                                } else {
+                                    log::error!("number of embeddings returned different from number of documents");
+                                }
+                            }
+
+                            // Only the last fragment still referencing a file
+                            // gets the inner value back, so each file is
+                            // reported once, after all of its fragments.
+                            if let Some(file) = Arc::into_inner(fragment.file) {
+                                // The channel is unbounded, so sending fails
+                                // only when every receiver is gone; that must
+                                // not panic the background task.
+                                if finished_files_tx.try_send(file.into_inner()).is_err() {
+                                    log::warn!(
+                                        "finished-files channel is closed; dropping embedded file"
+                                    );
+                                }
+                            }
+                        }
+                    }
+                    Err(error) => {
+                        log::error!("{:?}", error);
+                    }
+                }
+            })
+            .detach();
+    }
+
+    /// Returns a receiver on which fully embedded files are delivered.
+    pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
+        self.finished_files_rx.clone()
+    }
+}
@@ -72,8 +72,7 @@ impl CodeContextRetriever {
let mut sha1 = Sha1::new();
sha1.update(&document_span);
- let token_count = self.embedding_provider.count_tokens(&document_span);
- let document_span = self.embedding_provider.truncate(&document_span);
+ let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
Ok(vec![Document {
range: 0..content.len(),
@@ -93,8 +92,7 @@ impl CodeContextRetriever {
let mut sha1 = Sha1::new();
sha1.update(&document_span);
- let token_count = self.embedding_provider.count_tokens(&document_span);
- let document_span = self.embedding_provider.truncate(&document_span);
+ let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
Ok(vec![Document {
range: 0..content.len(),
@@ -183,8 +181,8 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref())
.replace("item", &document.content);
- let token_count = self.embedding_provider.count_tokens(&document_content);
- let document_content = self.embedding_provider.truncate(&document_content);
+ let (document_content, token_count) =
+ self.embedding_provider.truncate(&document_content);
document.content = document_content;
document.token_count = token_count;
@@ -1,14 +1,16 @@
use crate::{
db::dot,
embedding::{DummyEmbeddings, EmbeddingProvider},
+ embedding_queue::EmbeddingQueue,
parsing::{subtract_ranges, CodeContextRetriever, Document},
semantic_index_settings::SemanticIndexSettings,
- SearchResult, SemanticIndex,
+ FileToEmbed, JobHandle, SearchResult, SemanticIndex,
};
use anyhow::Result;
use async_trait::async_trait;
use gpui::{Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
+use parking_lot::Mutex;
use pretty_assertions::assert_eq;
use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
use rand::{rngs::StdRng, Rng};
@@ -20,8 +22,10 @@ use std::{
atomic::{self, AtomicUsize},
Arc,
},
+ time::SystemTime,
};
use unindent::Unindent;
+use util::RandomCharIter;
#[ctor::ctor]
fn init_logger() {
@@ -32,11 +36,7 @@ fn init_logger() {
#[gpui::test]
async fn test_semantic_index(cx: &mut TestAppContext) {
- cx.update(|cx| {
- cx.set_global(SettingsStore::test(cx));
- settings::register::<SemanticIndexSettings>(cx);
- settings::register::<ProjectSettings>(cx);
- });
+ init_test(cx);
let fs = FakeFs::new(cx.background());
fs.insert_tree(
@@ -75,7 +75,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
let db_path = db_dir.path().join("db.sqlite");
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let store = SemanticIndex::new(
+ let semantic_index = SemanticIndex::new(
fs.clone(),
db_path,
embedding_provider.clone(),
@@ -87,13 +87,13 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
- let _ = store
+ let _ = semantic_index
.update(cx, |store, cx| {
store.initialize_project(project.clone(), cx)
})
.await;
- let (file_count, outstanding_file_count) = store
+ let (file_count, outstanding_file_count) = semantic_index
.update(cx, |store, cx| store.index_project(project.clone(), cx))
.await
.unwrap();
@@ -101,7 +101,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
cx.foreground().run_until_parked();
assert_eq!(*outstanding_file_count.borrow(), 0);
- let search_results = store
+ let search_results = semantic_index
.update(cx, |store, cx| {
store.search_project(
project.clone(),
@@ -129,7 +129,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
// Test Include Files Functonality
let include_files = vec![PathMatcher::new("*.rs").unwrap()];
let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
- let rust_only_search_results = store
+ let rust_only_search_results = semantic_index
.update(cx, |store, cx| {
store.search_project(
project.clone(),
@@ -153,7 +153,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
cx,
);
- let no_rust_search_results = store
+ let no_rust_search_results = semantic_index
.update(cx, |store, cx| {
store.search_project(
project.clone(),
@@ -189,7 +189,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
cx.foreground().run_until_parked();
let prev_embedding_count = embedding_provider.embedding_count();
- let (file_count, outstanding_file_count) = store
+ let (file_count, outstanding_file_count) = semantic_index
.update(cx, |store, cx| store.index_project(project.clone(), cx))
.await
.unwrap();
@@ -204,6 +204,69 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
);
}
+/// Pushes three randomly generated files through an `EmbeddingQueue` and
+/// checks that every file comes back with exactly the embeddings the fake
+/// provider computes for its documents, regardless of how they were batched.
+#[gpui::test(iterations = 10)]
+async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
+    // Sender half backing every file's job handle; the receiver is not needed.
+    let (job_count_tx, _) = postage::watch::channel_with(0);
+    let job_count_tx = Arc::new(Mutex::new(job_count_tx));
+
+    let mut files = Vec::new();
+    for file_ix in 1..=3 {
+        let document_count = rng.gen_range(4..22);
+        let mut documents = Vec::new();
+        for document_ix in 0..document_count {
+            let content_len = rng.gen_range(10..100);
+            documents.push(Document {
+                range: 0..10,
+                embedding: Vec::new(),
+                name: format!("document {document_ix}"),
+                content: RandomCharIter::new(&mut rng)
+                    .with_simple_text()
+                    .take(content_len)
+                    .collect(),
+                sha1: rng.gen(),
+                token_count: rng.gen_range(10..30),
+            });
+        }
+
+        files.push(FileToEmbed {
+            worktree_id: 5,
+            path: format!("path-{file_ix}").into(),
+            mtime: SystemTime::now(),
+            documents,
+            job_handle: JobHandle::new(&job_count_tx),
+        });
+    }
+
+    let provider = Arc::new(FakeEmbeddingProvider::default());
+    let mut queue = EmbeddingQueue::new(provider.clone());
+
+    // Enqueue everything, then force out whatever is left in the last batch.
+    let finished_rx = cx.update(|cx| {
+        for file in files.iter().cloned() {
+            queue.push(file, cx);
+        }
+        queue.flush(cx);
+        queue.finished_files()
+    });
+
+    cx.foreground().run_until_parked();
+
+    // One finished file is expected per pushed file.
+    let mut actual = Vec::with_capacity(files.len());
+    for _ in &files {
+        actual.push(finished_rx.try_recv().expect("no finished file"));
+    }
+    // Sort by path so the comparison follows the order the files were created in.
+    actual.sort_by_key(|file| file.path.clone());
+
+    // The expectation is the input with each document embedded directly.
+    let mut expected = files.clone();
+    for document in expected
+        .iter_mut()
+        .flat_map(|file| file.documents.iter_mut())
+    {
+        document.embedding = provider.embed_sync(&document.content);
+    }
+
+    assert_eq!(actual, expected);
+}
+
#[track_caller]
fn assert_search_results(
actual: &[SearchResult],
@@ -1220,47 +1283,42 @@ impl FakeEmbeddingProvider {
fn embedding_count(&self) -> usize {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
+
+    /// Computes the fake embedding of `span`: a 26-slot histogram of its
+    /// ASCII letters (case-insensitive, every slot starting at 1.0), scaled
+    /// to unit length. All other characters are ignored.
+    fn embed_sync(&self, span: &str) -> Vec<f32> {
+        let mut histogram = vec![1.0; 26];
+        for ch in span.chars().map(|c| c.to_ascii_lowercase()) {
+            // After lowercasing, only 'a'..='z' map to a slot.
+            if ch.is_ascii_lowercase() {
+                histogram[(ch as u8 - b'a') as usize] += 1.0;
+            }
+        }
+
+        // Every slot is at least 1.0, so the magnitude is never zero.
+        let magnitude = histogram.iter().map(|x| x * x).sum::<f32>().sqrt();
+        histogram.iter_mut().for_each(|x| *x /= magnitude);
+        histogram
+    }
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
- fn count_tokens(&self, span: &str) -> usize {
- span.len()
- }
-
- fn should_truncate(&self, span: &str) -> bool {
- false
+ fn truncate(&self, span: &str) -> (String, usize) {
+ (span.to_string(), 1)
}
- fn truncate(&self, span: &str) -> String {
- span.to_string()
+ fn max_tokens_per_batch(&self) -> usize {
+ 200
}
- async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
- Ok(spans
- .iter()
- .map(|span| {
- let mut result = vec![1.0; 26];
- for letter in span.chars() {
- let letter = letter.to_ascii_lowercase();
- if letter as u32 >= 'a' as u32 {
- let ix = (letter as u32) - ('a' as u32);
- if ix < 26 {
- result[ix as usize] += 1.0;
- }
- }
- }
-
- let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
- for x in &mut result {
- *x /= norm;
- }
-
- result
- })
- .collect())
+ Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
@@ -1704,3 +1762,11 @@ fn test_subtract_ranges() {
assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
}
+
+/// Shared test setup: installs a test `SettingsStore` as a global and
+/// registers `SemanticIndexSettings` and `ProjectSettings` with it.
+fn init_test(cx: &mut TestAppContext) {
+    cx.update(|app| {
+        let settings_store = SettingsStore::test(app);
+        app.set_global(settings_store);
+        settings::register::<SemanticIndexSettings>(app);
+        settings::register::<ProjectSettings>(app);
+    });
+}
@@ -260,11 +260,22 @@ pub fn defer<F: FnOnce()>(f: F) -> impl Drop {
Defer(Some(f))
}
-pub struct RandomCharIter<T: Rng>(T);
+pub struct RandomCharIter<T: Rng> {
+ rng: T,
+ simple_text: bool,
+}
impl<T: Rng> RandomCharIter<T> {
pub fn new(rng: T) -> Self {
- Self(rng)
+ Self {
+ rng,
+ simple_text: std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()),
+ }
+ }
+
+ pub fn with_simple_text(mut self) -> Self {
+ self.simple_text = true;
+ self
}
}
@@ -272,25 +283,27 @@ impl<T: Rng> Iterator for RandomCharIter<T> {
type Item = char;
fn next(&mut self) -> Option<Self::Item> {
- if std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()) {
- return if self.0.gen_range(0..100) < 5 {
+ if self.simple_text {
+ return if self.rng.gen_range(0..100) < 5 {
Some('\n')
} else {
- Some(self.0.gen_range(b'a'..b'z' + 1).into())
+ Some(self.rng.gen_range(b'a'..b'z' + 1).into())
};
}
- match self.0.gen_range(0..100) {
+ match self.rng.gen_range(0..100) {
// whitespace
- 0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.0).copied(),
+ 0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.rng).copied(),
// two-byte greek letters
- 20..=32 => char::from_u32(self.0.gen_range(('α' as u32)..('ω' as u32 + 1))),
+ 20..=32 => char::from_u32(self.rng.gen_range(('α' as u32)..('ω' as u32 + 1))),
// // three-byte characters
- 33..=45 => ['✋', '✅', '❌', '❎', '⭐'].choose(&mut self.0).copied(),
+ 33..=45 => ['✋', '✅', '❌', '❎', '⭐']
+ .choose(&mut self.rng)
+ .copied(),
// // four-byte characters
- 46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.0).copied(),
+ 46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.rng).copied(),
// ascii letters
- _ => Some(self.0.gen_range(b'a'..b'z' + 1).into()),
+ _ => Some(self.rng.gen_range(b'a'..b'z' + 1).into()),
}
}
}