@@ -19,7 +19,7 @@ use gpui::{
use language::{language_settings::SoftWrap, LanguageRegistry};
use open_ai::{FunctionContent, ToolCall, ToolCallContent};
use rich_text::RichText;
-use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
+use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex};
use serde::Deserialize;
use settings::Settings;
use std::sync::Arc;
@@ -51,7 +51,7 @@ pub enum SubmitMode {
Codebase,
}
-gpui::actions!(assistant2, [Cancel, ToggleFocus]);
+gpui::actions!(assistant2, [Cancel, ToggleFocus, DebugProjectIndex]);
gpui::impl_actions!(assistant2, [Submit]);
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
@@ -131,7 +131,13 @@ impl AssistantPanel {
let tool_registry = Arc::new(tool_registry);
- Self::new(app_state.languages.clone(), tool_registry, user_store, cx)
+ Self::new(
+ app_state.languages.clone(),
+ tool_registry,
+ user_store,
+ Some(project_index),
+ cx,
+ )
})
})
}
@@ -140,6 +146,7 @@ impl AssistantPanel {
language_registry: Arc<LanguageRegistry>,
tool_registry: Arc<ToolRegistry>,
user_store: Model<UserStore>,
+ project_index: Option<Model<ProjectIndex>>,
cx: &mut ViewContext<Self>,
) -> Self {
let chat = cx.new_view(|cx| {
@@ -147,6 +154,7 @@ impl AssistantPanel {
language_registry.clone(),
tool_registry.clone(),
user_store,
+ project_index,
cx,
)
});
@@ -225,6 +233,7 @@ struct AssistantChat {
collapsed_messages: HashMap<MessageId, bool>,
pending_completion: Option<Task<()>>,
tool_registry: Arc<ToolRegistry>,
+ project_index: Option<Model<ProjectIndex>>,
}
impl AssistantChat {
@@ -232,6 +241,7 @@ impl AssistantChat {
language_registry: Arc<LanguageRegistry>,
tool_registry: Arc<ToolRegistry>,
user_store: Model<UserStore>,
+ project_index: Option<Model<ProjectIndex>>,
cx: &mut ViewContext<Self>,
) -> Self {
let model = CompletionProvider::get(cx).default_model();
@@ -258,6 +268,7 @@ impl AssistantChat {
list_state,
user_store,
language_registry,
+ project_index,
next_message_id: MessageId(0),
collapsed_messages: HashMap::default(),
pending_completion: None,
@@ -342,6 +353,14 @@ impl AssistantChat {
self.pending_completion.is_none()
}
+ fn debug_project_index(&mut self, _: &DebugProjectIndex, cx: &mut ViewContext<Self>) {
+ if let Some(index) = &self.project_index {
+ index.update(cx, |project_index, cx| {
+ project_index.debug(cx).detach_and_log_err(cx)
+ });
+ }
+ }
+
async fn request_completion(
this: WeakView<Self>,
mode: SubmitMode,
@@ -686,6 +705,7 @@ impl Render for AssistantChat {
.key_context("AssistantChat")
.on_action(cx.listener(Self::submit))
.on_action(cx.listener(Self::cancel))
+ .on_action(cx.listener(Self::debug_project_index))
.text_color(Color::Default.color(cx))
.child(list(self.list_state.clone()).flex_1())
.child(Composer::new(
@@ -21,6 +21,7 @@ use smol::channel;
use std::{
cmp::Ordering,
future::Future,
+ iter,
num::NonZeroUsize,
ops::Range,
path::{Path, PathBuf},
@@ -295,6 +296,28 @@ impl ProjectIndex {
}
Ok(result)
}
+
+ pub fn debug(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+ let indices = self
+ .worktree_indices
+ .values()
+ .filter_map(|worktree_index| {
+ if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
+ Some(index.clone())
+ } else {
+ None
+ }
+ })
+ .collect::<Vec<_>>();
+
+ cx.spawn(|_, mut cx| async move {
+ eprintln!("semantic index contents:");
+ for index in indices {
+ index.update(&mut cx, |index, cx| index.debug(cx))?.await?
+ }
+ Ok(())
+ })
+ }
}
pub struct SearchResult {
@@ -419,7 +442,7 @@ impl WorktreeIndex {
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_entries(worktree.clone(), cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
- let embed = self.embed_files(chunk.files, cx);
+ let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
async move {
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
@@ -436,7 +459,7 @@ impl WorktreeIndex {
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
- let embed = self.embed_files(chunk.files, cx);
+ let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
async move {
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
@@ -500,7 +523,7 @@ impl WorktreeIndex {
}
if entry.mtime != saved_mtime {
- let handle = entries_being_indexed.insert(&entry);
+ let handle = entries_being_indexed.insert(entry.id);
updated_entries_tx.send((entry.clone(), handle)).await?;
}
}
@@ -539,7 +562,7 @@ impl WorktreeIndex {
| project::PathChange::AddedOrUpdated => {
if let Some(entry) = worktree.entry_for_id(*entry_id) {
if entry.is_file() {
- let handle = entries_being_indexed.insert(&entry);
+ let handle = entries_being_indexed.insert(entry.id);
updated_entries_tx.send((entry.clone(), handle)).await?;
}
}
@@ -601,7 +624,8 @@ impl WorktreeIndex {
let chunked_file = ChunkedFile {
chunks: chunk_text(&text, grammar),
handle,
- entry,
+ path: entry.path,
+ mtime: entry.mtime,
text,
};
@@ -623,11 +647,11 @@ impl WorktreeIndex {
}
fn embed_files(
- &self,
+ embedding_provider: Arc<dyn EmbeddingProvider>,
chunked_files: channel::Receiver<ChunkedFile>,
cx: &AppContext,
) -> EmbedFiles {
- let embedding_provider = self.embedding_provider.clone();
+ let embedding_provider = embedding_provider.clone();
let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
let task = cx.background_executor().spawn(async move {
let mut chunked_file_batches =
@@ -635,9 +659,10 @@ impl WorktreeIndex {
while let Some(chunked_files) = chunked_file_batches.next().await {
// View the batch of files as a vec of chunks
// Flatten out to a vec of chunks that we can subdivide into batch sized pieces
- // Once those are done, reassemble it back into which files they belong to
+ // Once those are done, reassemble them back into the files to which they belong
+ // If any embeddings fail for a file, the entire file is discarded
- let chunks = chunked_files
+ let chunks: Vec<TextToEmbed> = chunked_files
.iter()
.flat_map(|file| {
file.chunks.iter().map(|chunk| TextToEmbed {
@@ -647,36 +672,50 @@ impl WorktreeIndex {
})
.collect::<Vec<_>>();
- let mut embeddings = Vec::new();
+ let mut embeddings: Vec<Option<Embedding>> = Vec::new();
for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
if let Some(batch_embeddings) =
embedding_provider.embed(embedding_batch).await.log_err()
{
- embeddings.extend_from_slice(&batch_embeddings);
+ if batch_embeddings.len() == embedding_batch.len() {
+ embeddings.extend(batch_embeddings.into_iter().map(Some));
+ continue;
+ }
+ log::error!(
+ "embedding provider returned unexpected embedding count {}, expected {}",
+ batch_embeddings.len(), embedding_batch.len()
+ );
}
+
+ embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
}
let mut embeddings = embeddings.into_iter();
for chunked_file in chunked_files {
- let chunk_embeddings = embeddings
- .by_ref()
- .take(chunked_file.chunks.len())
- .collect::<Vec<_>>();
- let embedded_chunks = chunked_file
- .chunks
- .into_iter()
- .zip(chunk_embeddings)
- .map(|(chunk, embedding)| EmbeddedChunk { chunk, embedding })
- .collect();
- let embedded_file = EmbeddedFile {
- path: chunked_file.entry.path.clone(),
- mtime: chunked_file.entry.mtime,
- chunks: embedded_chunks,
+ let mut embedded_file = EmbeddedFile {
+ path: chunked_file.path,
+ mtime: chunked_file.mtime,
+ chunks: Vec::new(),
};
- embedded_files_tx
- .send((embedded_file, chunked_file.handle))
- .await?;
+ let mut embedded_all_chunks = true;
+ for (chunk, embedding) in
+ chunked_file.chunks.into_iter().zip(embeddings.by_ref())
+ {
+ if let Some(embedding) = embedding {
+ embedded_file
+ .chunks
+ .push(EmbeddedChunk { chunk, embedding });
+ } else {
+ embedded_all_chunks = false;
+ }
+ }
+
+ if embedded_all_chunks {
+ embedded_files_tx
+ .send((embedded_file, chunked_file.handle))
+ .await?;
+ }
}
}
Ok(())
@@ -826,6 +865,21 @@ impl WorktreeIndex {
})
}
+ fn debug(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+ let connection = self.db_connection.clone();
+ let db = self.db;
+ cx.background_executor().spawn(async move {
+ let tx = connection
+ .read_txn()
+ .context("failed to create read transaction")?;
+ for record in db.iter(&tx)? {
+ let (key, _) = record?;
+ eprintln!("{}", path_for_db_key(key));
+ }
+ Ok(())
+ })
+ }
+
#[cfg(test)]
fn path_count(&self) -> Result<u64> {
let txn = self
@@ -848,7 +902,8 @@ struct ChunkFiles {
}
struct ChunkedFile {
- pub entry: Entry,
+ pub path: Arc<Path>,
+ pub mtime: Option<SystemTime>,
pub handle: IndexingEntryHandle,
pub text: String,
pub chunks: Vec<Chunk>,
@@ -872,11 +927,14 @@ struct EmbeddedChunk {
embedding: Embedding,
}
+/// The set of entries that are currently being indexed.
struct IndexingEntrySet {
entry_ids: Mutex<HashSet<ProjectEntryId>>,
tx: channel::Sender<()>,
}
+/// When dropped, removes the entry from the set of entries that are being indexed.
+#[derive(Clone)]
struct IndexingEntryHandle {
entry_id: ProjectEntryId,
set: Weak<IndexingEntrySet>,
@@ -890,11 +948,11 @@ impl IndexingEntrySet {
}
}
- fn insert(self: &Arc<Self>, entry: &project::Entry) -> IndexingEntryHandle {
- self.entry_ids.lock().insert(entry.id);
+ fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
+ self.entry_ids.lock().insert(entry_id);
self.tx.send_blocking(()).ok();
IndexingEntryHandle {
- entry_id: entry.id,
+ entry_id,
set: Arc::downgrade(self),
}
}
@@ -917,6 +975,10 @@ fn db_key_for_path(path: &Arc<Path>) -> String {
path.to_string_lossy().replace('/', "\0")
}
+fn path_for_db_key(key: &str) -> String {
+ key.replace('\0', "/")
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -939,7 +1001,22 @@ mod tests {
});
}
- pub struct TestEmbeddingProvider;
+ pub struct TestEmbeddingProvider {
+ batch_size: usize,
+ compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
+ }
+
+ impl TestEmbeddingProvider {
+ pub fn new(
+ batch_size: usize,
+ compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
+ ) -> Self {
+ Self {
+ batch_size,
+ compute_embedding: Box::new(compute_embedding),
+ }
+ }
+ }
impl EmbeddingProvider for TestEmbeddingProvider {
fn embed<'a>(
@@ -948,29 +1025,13 @@ mod tests {
) -> BoxFuture<'a, Result<Vec<Embedding>>> {
let embeddings = texts
.iter()
- .map(|text| {
- let mut embedding = vec![0f32; 2];
- // if the text contains garbage, give it a 1 in the first dimension
- if text.text.contains("garbage in") {
- embedding[0] = 0.9;
- } else {
- embedding[0] = -0.9;
- }
-
- if text.text.contains("garbage out") {
- embedding[1] = 0.9;
- } else {
- embedding[1] = -0.9;
- }
-
- Embedding::new(embedding)
- })
+ .map(|to_embed| (self.compute_embedding)(to_embed.text))
.collect();
- future::ready(Ok(embeddings)).boxed()
+ future::ready(embeddings).boxed()
}
fn batch_size(&self) -> usize {
- 16
+ self.batch_size
}
}
@@ -984,7 +1045,23 @@ mod tests {
let mut semantic_index = SemanticIndex::new(
temp_dir.path().into(),
- Arc::new(TestEmbeddingProvider),
+ Arc::new(TestEmbeddingProvider::new(16, |text| {
+ let mut embedding = vec![0f32; 2];
+ // if the text contains garbage, give it a 1 in the first dimension
+ if text.contains("garbage in") {
+ embedding[0] = 0.9;
+ } else {
+ embedding[0] = -0.9;
+ }
+
+ if text.contains("garbage out") {
+ embedding[1] = 0.9;
+ } else {
+ embedding[1] = -0.9;
+ }
+
+ Ok(Embedding::new(embedding))
+ })),
&mut cx.to_async(),
)
.await
@@ -1046,4 +1123,82 @@ mod tests {
assert!(content.contains("garbage in, garbage out"));
}
+
+ #[gpui::test]
+ async fn test_embed_files(cx: &mut TestAppContext) {
+ cx.executor().allow_parking();
+
+ let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
+ if text.contains('g') {
+ Err(anyhow!("cannot embed text containing a 'g' character"))
+ } else {
+ Ok(Embedding::new(
+ ('a'..'z')
+ .map(|char| text.chars().filter(|c| *c == char).count() as f32)
+ .collect(),
+ ))
+ }
+ }));
+
+ let (indexing_progress_tx, _) = channel::unbounded();
+ let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));
+
+ let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
+ chunked_files_tx
+ .send_blocking(ChunkedFile {
+ path: Path::new("test1.md").into(),
+ mtime: None,
+ handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
+ text: "abcdefghijklmnop".to_string(),
+ chunks: [0..4, 4..8, 8..12, 12..16]
+ .into_iter()
+ .map(|range| Chunk {
+ range,
+ digest: Default::default(),
+ })
+ .collect(),
+ })
+ .unwrap();
+ chunked_files_tx
+ .send_blocking(ChunkedFile {
+ path: Path::new("test2.md").into(),
+ mtime: None,
+ handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
+ text: "qrstuvwxyz".to_string(),
+ chunks: [0..4, 4..8, 8..10]
+ .into_iter()
+ .map(|range| Chunk {
+ range,
+ digest: Default::default(),
+ })
+ .collect(),
+ })
+ .unwrap();
+ chunked_files_tx.close();
+
+ let embed_files_task =
+ cx.update(|cx| WorktreeIndex::embed_files(provider.clone(), chunked_files_rx, cx));
+ embed_files_task.task.await.unwrap();
+
+ let mut embedded_files_rx = embed_files_task.files;
+ let mut embedded_files = Vec::new();
+ while let Some((embedded_file, _)) = embedded_files_rx.next().await {
+ embedded_files.push(embedded_file);
+ }
+
+ assert_eq!(embedded_files.len(), 1);
+ assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
+ assert_eq!(
+ embedded_files[0]
+ .chunks
+ .iter()
+ .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
+ .collect::<Vec<Embedding>>(),
+ vec![
+ (provider.compute_embedding)("qrst").unwrap(),
+ (provider.compute_embedding)("uvwx").unwrap(),
+ (provider.compute_embedding)("yz").unwrap(),
+ ],
+ );
+ }
}