From ce62173534cff576776a92e154149153b183936e Mon Sep 17 00:00:00 2001
From: Antonio Scandurra
Date: Wed, 6 Sep 2023 16:48:53 +0200
Subject: [PATCH] Rename `Document` to `Span`

---
 crates/semantic_index/src/db.rs              | 57 ++++++++--------
 crates/semantic_index/src/embedding_queue.rs | 54 +++++++--------
 crates/semantic_index/src/parsing.rs         | 66 +++++++++----------
 crates/semantic_index/src/semantic_index.rs  | 32 ++++-----
 .../src/semantic_index_tests.rs              | 12 ++--
 5 files changed, 109 insertions(+), 112 deletions(-)

diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs
index 5664210388ab68ccc52e55221aedf67afbd1b635..28bbd56156b820168b13998d41c1ea061706dcf9 100644
--- a/crates/semantic_index/src/db.rs
+++ b/crates/semantic_index/src/db.rs
@@ -1,6 +1,6 @@
 use crate::{
     embedding::Embedding,
-    parsing::{Document, DocumentDigest},
+    parsing::{Span, SpanDigest},
     SEMANTIC_INDEX_VERSION,
 };
 use anyhow::{anyhow, Context, Result};
@@ -124,8 +124,8 @@ impl VectorDatabase {
         }
 
         log::trace!("vector database schema out of date. updating...");
-        db.execute("DROP TABLE IF EXISTS documents", [])
-            .context("failed to drop 'documents' table")?;
+        db.execute("DROP TABLE IF EXISTS spans", [])
+            .context("failed to drop 'spans' table")?;
         db.execute("DROP TABLE IF EXISTS files", [])
             .context("failed to drop 'files' table")?;
         db.execute("DROP TABLE IF EXISTS worktrees", [])
@@ -174,7 +174,7 @@ impl VectorDatabase {
         )?;
 
         db.execute(
-            "CREATE TABLE documents (
+            "CREATE TABLE spans (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
                 file_id INTEGER NOT NULL,
                 start_byte INTEGER NOT NULL,
@@ -211,7 +211,7 @@ impl VectorDatabase {
         worktree_id: i64,
         path: Arc<Path>,
         mtime: SystemTime,
-        documents: Vec<Document>,
+        spans: Vec<Span>,
     ) -> impl Future<Output = Result<()>> {
         self.transact(move |db| {
             // Return the existing ID, if both the file and mtime match
@@ -231,7 +231,7 @@ impl VectorDatabase {
             let t0 = Instant::now();
             let mut query = db.prepare(
                 "
-                INSERT INTO documents
+                INSERT INTO spans
                 (file_id, start_byte, end_byte, name, embedding, digest)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
                 ",
             )?;
             log::trace!(
                 "Preparing Query Took: {:?} milliseconds",
                 t0.elapsed().as_millis()
             );
 
-            for document in documents {
+            for span in spans {
                 query.execute(params![
                     file_id,
-                    document.range.start.to_string(),
-                    document.range.end.to_string(),
-                    document.name,
-                    document.embedding,
-                    document.digest
+                    span.range.start.to_string(),
+                    span.range.end.to_string(),
+                    span.name,
+                    span.embedding,
+                    span.digest
                 ])?;
             }
 
@@ -278,13 +278,13 @@ impl VectorDatabase {
     pub fn embeddings_for_files(
         &self,
         worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
-    ) -> impl Future<Output = Result<HashMap<DocumentDigest, Embedding>>> {
+    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
         self.transact(move |db| {
             let mut query = db.prepare(
                 "
                 SELECT digest, embedding
-                FROM documents
-                LEFT JOIN files ON files.id = documents.file_id
+                FROM spans
+                LEFT JOIN files ON files.id = spans.file_id
                 WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
", )?; @@ -297,10 +297,7 @@ impl VectorDatabase { .collect::>(), ); let rows = query.query_map(params![worktree_id, file_paths], |row| { - Ok(( - row.get::<_, DocumentDigest>(0)?, - row.get::<_, Embedding>(1)?, - )) + Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?)) })?; for row in rows { @@ -379,7 +376,7 @@ impl VectorDatabase { let file_ids = file_ids.to_vec(); self.transact(move |db| { let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); - Self::for_each_document(db, &file_ids, |id, embedding| { + Self::for_each_span(db, &file_ids, |id, embedding| { let similarity = embedding.similarity(&query_embedding); let ix = match results.binary_search_by(|(_, s)| { similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) @@ -434,7 +431,7 @@ impl VectorDatabase { }) } - fn for_each_document( + fn for_each_span( db: &rusqlite::Connection, file_ids: &[i64], mut f: impl FnMut(i64, Embedding), @@ -444,7 +441,7 @@ impl VectorDatabase { SELECT id, embedding FROM - documents + spans WHERE file_id IN rarray(?) ", @@ -459,7 +456,7 @@ impl VectorDatabase { Ok(()) } - pub fn get_documents_by_ids( + pub fn spans_for_ids( &self, ids: &[i64], ) -> impl Future)>>> { @@ -468,16 +465,16 @@ impl VectorDatabase { let mut statement = db.prepare( " SELECT - documents.id, + spans.id, files.worktree_id, files.relative_path, - documents.start_byte, - documents.end_byte + spans.start_byte, + spans.end_byte FROM - documents, files + spans, files WHERE - documents.file_id = files.id AND - documents.id in rarray(?) + spans.file_id = files.id AND + spans.id in rarray(?) ", )?; @@ -500,7 +497,7 @@ impl VectorDatabase { for id in &ids { let value = values_by_id .remove(id) - .ok_or(anyhow!("missing document id {}", id))?; + .ok_or(anyhow!("missing span id {}", id))?; results.push(value); } diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index f1abbde3a4e2d528bc0b0b292511462c515b4a74..024881f0b808734a4e1d0fff2f48f07ddc88470b 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -1,4 +1,4 @@ -use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle}; +use crate::{embedding::EmbeddingProvider, parsing::Span, JobHandle}; use gpui::executor::Background; use parking_lot::Mutex; use smol::channel; @@ -9,7 +9,7 @@ pub struct FileToEmbed { pub worktree_id: i64, pub path: Arc, pub mtime: SystemTime, - pub documents: Vec, + pub spans: Vec, pub job_handle: JobHandle, } @@ -19,7 +19,7 @@ impl std::fmt::Debug for FileToEmbed { .field("worktree_id", &self.worktree_id) .field("path", &self.path) .field("mtime", &self.mtime) - .field("document", &self.documents) + .field("spans", &self.spans) .finish_non_exhaustive() } } @@ -29,13 +29,13 @@ impl PartialEq for FileToEmbed { self.worktree_id == other.worktree_id && self.path == other.path && self.mtime == other.mtime - && self.documents == other.documents + && self.spans == other.spans } } pub struct EmbeddingQueue { embedding_provider: Arc, - pending_batch: Vec, + pending_batch: Vec, executor: Arc, pending_batch_token_count: usize, finished_files_tx: channel::Sender, @@ -43,9 +43,9 @@ pub struct EmbeddingQueue { } #[derive(Clone)] -pub struct FileToEmbedFragment { +pub struct FileFragmentToEmbed { file: Arc>, - document_range: Range, + span_range: Range, } impl EmbeddingQueue { @@ -62,41 +62,41 @@ impl EmbeddingQueue { } pub fn push(&mut self, file: FileToEmbed) { - if file.documents.is_empty() { + if file.spans.is_empty() { 
             self.finished_files_tx.try_send(file).unwrap();
             return;
         }
 
         let file = Arc::new(Mutex::new(file));
 
-        self.pending_batch.push(FileToEmbedFragment {
+        self.pending_batch.push(FileFragmentToEmbed {
             file: file.clone(),
-            document_range: 0..0,
+            span_range: 0..0,
         });
 
-        let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
+        let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
         let mut saved_tokens = 0;
-        for (ix, document) in file.lock().documents.iter().enumerate() {
-            let document_token_count = if document.embedding.is_none() {
-                document.token_count
+        for (ix, span) in file.lock().spans.iter().enumerate() {
+            let span_token_count = if span.embedding.is_none() {
+                span.token_count
             } else {
-                saved_tokens += document.token_count;
+                saved_tokens += span.token_count;
                 0
             };
 
-            let next_token_count = self.pending_batch_token_count + document_token_count;
+            let next_token_count = self.pending_batch_token_count + span_token_count;
             if next_token_count > self.embedding_provider.max_tokens_per_batch() {
                 let range_end = fragment_range.end;
                 self.flush();
-                self.pending_batch.push(FileToEmbedFragment {
+                self.pending_batch.push(FileFragmentToEmbed {
                     file: file.clone(),
-                    document_range: range_end..range_end,
+                    span_range: range_end..range_end,
                 });
-                fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
+                fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
             }
 
             fragment_range.end = ix + 1;
-            self.pending_batch_token_count += document_token_count;
+            self.pending_batch_token_count += span_token_count;
         }
         log::trace!("Saved Tokens: {:?}", saved_tokens);
     }
 
@@ -113,20 +113,20 @@ impl EmbeddingQueue {
         self.executor.spawn(async move {
             let mut spans = Vec::new();
-            let mut document_count = 0;
+            let mut span_count = 0;
             for fragment in &batch {
                 let file = fragment.file.lock();
-                document_count += file.documents[fragment.document_range.clone()].len();
+                span_count += file.spans[fragment.span_range.clone()].len();
                 spans.extend(
                     {
-                        file.documents[fragment.document_range.clone()]
+                        file.spans[fragment.span_range.clone()]
                         .iter().filter(|d| d.embedding.is_none())
                         .map(|d| d.content.clone())
                     }
                 );
             }
 
-            log::trace!("Documents Length: {:?}", document_count);
+            log::trace!("Span Count: {:?}", span_count);
             log::trace!("Span Length: {:?}", spans.len());
 
             // If spans is empty, just send the fragment to the finished files if it's the last one.
@@ -143,11 +143,11 @@ impl EmbeddingQueue {
                 Ok(embeddings) => {
                     let mut embeddings = embeddings.into_iter();
                     for fragment in batch {
-                        for document in
-                            &mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
+                        for span in
+                            &mut fragment.file.lock().spans[fragment.span_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
                         {
                             if let Some(embedding) = embeddings.next() {
-                                document.embedding = Some(embedding);
+                                span.embedding = Some(embedding);
                             } else {
                                 // log::error!("number of embeddings returned different from number of documents");
diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs
index c0a94c6b7355cb51bc5193a2bfd01148c28f4162..b6fc000e1dc7530483b97418e15347f1b9985832 100644
--- a/crates/semantic_index/src/parsing.rs
+++ b/crates/semantic_index/src/parsing.rs
@@ -16,9 +16,9 @@ use std::{
 use tree_sitter::{Parser, QueryCursor};
 
 #[derive(Debug, PartialEq, Eq, Clone, Hash)]
-pub struct DocumentDigest([u8; 20]);
+pub struct SpanDigest([u8; 20]);
 
-impl FromSql for DocumentDigest {
+impl FromSql for SpanDigest {
     fn column_result(value: ValueRef) -> FromSqlResult<Self> {
         let blob = value.as_blob()?;
         let bytes =
@@ -27,17 +27,17 @@ impl FromSql for DocumentDigest {
                     expected_size: 20,
                     blob_size: blob.len(),
                 })?;
-        return Ok(DocumentDigest(bytes));
+        return Ok(SpanDigest(bytes));
     }
 }
 
-impl ToSql for DocumentDigest {
+impl ToSql for SpanDigest {
     fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
         self.0.to_sql()
     }
 }
 
-impl From<&'_ str> for DocumentDigest {
+impl From<&'_ str> for SpanDigest {
     fn from(value: &'_ str) -> Self {
         let mut sha1 = Sha1::new();
         sha1.update(value);
@@ -46,12 +46,12 @@ impl From<&'_ str> for DocumentDigest {
 }
 
 #[derive(Debug, PartialEq, Clone)]
-pub struct Document {
+pub struct Span {
     pub name: String,
     pub range: Range<usize>,
     pub content: String,
     pub embedding: Option<Embedding>,
-    pub digest: DocumentDigest,
+    pub digest: SpanDigest,
     pub token_count: usize,
 }
 
@@ -97,14 +97,14 @@ impl CodeContextRetriever {
         relative_path: &Path,
         language_name: Arc<str>,
         content: &str,
-    ) -> Result<Vec<Document>> {
+    ) -> Result<Vec<Span>> {
         let document_span = ENTIRE_FILE_TEMPLATE
             .replace("<path>", relative_path.to_string_lossy().as_ref())
             .replace("<language>", language_name.as_ref())
             .replace("<item>", &content);
-        let digest = DocumentDigest::from(document_span.as_str());
+        let digest = SpanDigest::from(document_span.as_str());
         let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
-        Ok(vec![Document {
+        Ok(vec![Span {
             range: 0..content.len(),
             content: document_span,
             embedding: Default::default(),
@@ -114,13 +114,13 @@ impl CodeContextRetriever {
         }])
     }
 
-    fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Document>> {
+    fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Span>> {
         let document_span = MARKDOWN_CONTEXT_TEMPLATE
             .replace("<path>", relative_path.to_string_lossy().as_ref())
             .replace("<item>", &content);
-        let digest = DocumentDigest::from(document_span.as_str());
+        let digest = SpanDigest::from(document_span.as_str());
         let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
-        Ok(vec![Document {
+        Ok(vec![Span {
             range: 0..content.len(),
             content: document_span,
             embedding: None,
@@ -191,32 +191,32 @@ impl CodeContextRetriever {
         relative_path: &Path,
         content: &str,
         language: Arc<Language>,
-    ) -> Result<Vec<Document>> {
+    ) -> Result<Vec<Span>> {
         let language_name = language.name();
         if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
             return self.parse_entire_file(relative_path, language_name, &content);
-        } else if &language_name.to_string() == &"Markdown".to_string() {
+        } else if language_name.as_ref() == "Markdown" {
             return self.parse_markdown_file(relative_path, &content);
         }
 
-        let mut documents = self.parse_file(content, language)?;
-        for document in &mut documents {
+        let mut spans = self.parse_file(content, language)?;
+        for span in &mut spans {
             let document_content = CODE_CONTEXT_TEMPLATE
                 .replace("<path>", relative_path.to_string_lossy().as_ref())
                 .replace("<language>", language_name.as_ref())
-                .replace("<item>", &document.content);
+                .replace("<item>", &span.content);
 
             let (document_content, token_count) =
                 self.embedding_provider.truncate(&document_content);
 
-            document.content = document_content;
-            document.token_count = token_count;
+            span.content = document_content;
+            span.token_count = token_count;
         }
 
-        Ok(documents)
+        Ok(spans)
     }
 
-    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
+    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
         let grammar = language
             .grammar()
             .ok_or_else(|| anyhow!("no grammar for language"))?;
@@ -227,7 +227,7 @@ impl CodeContextRetriever {
         let language_scope = language.default_scope();
         let placeholder = language_scope.collapsed_placeholder();
 
-        let mut documents = Vec::new();
+        let mut spans = Vec::new();
         let mut collapsed_ranges_within = Vec::new();
         let mut parsed_name_ranges = HashSet::new();
         for (i, context_match) in matches.iter().enumerate() {
@@ -267,22 +267,22 @@ impl CodeContextRetriever {
             collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
 
-            let mut document_content = String::new();
+            let mut span_content = String::new();
             for context_range in &context_match.context_ranges {
                 add_content_from_range(
-                    &mut document_content,
+                    &mut span_content,
                     content,
                     context_range.clone(),
                     context_match.start_col,
                 );
-                document_content.push_str("\n");
+                span_content.push_str("\n");
             }
 
             let mut offset = item_range.start;
             for collapsed_range in &collapsed_ranges_within {
                 if collapsed_range.start > offset {
                     add_content_from_range(
-                        &mut document_content,
+                        &mut span_content,
                         content,
                         offset..collapsed_range.start,
                         context_match.start_col,
                     );
@@ -291,24 +291,24 @@ impl CodeContextRetriever {
                 }
 
                 if collapsed_range.end > offset {
-                    document_content.push_str(placeholder);
+                    span_content.push_str(placeholder);
                     offset = collapsed_range.end;
                 }
             }
 
             if offset < item_range.end {
                 add_content_from_range(
-                    &mut document_content,
+                    &mut span_content,
                     content,
                     offset..item_range.end,
                     context_match.start_col,
                 );
             }
 
-            let sha1 = DocumentDigest::from(document_content.as_str());
-            documents.push(Document {
+            let sha1 = SpanDigest::from(span_content.as_str());
+            spans.push(Span {
                 name,
-                content: document_content,
+                content: span_content,
                 range: item_range.clone(),
                 embedding: None,
                 digest: sha1,
@@ -316,7 +316,7 @@ impl CodeContextRetriever {
             })
         }
 
-        return Ok(documents);
+        return Ok(spans);
     }
 }
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index a098152784a5822f4ec31ded0230bd5e6d808315..1c1c40fa27de3fed4641498ed20eba5877410965 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -17,7 +17,7 @@ use futures::{future, FutureExt, StreamExt};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;
-use parsing::{CodeContextRetriever, DocumentDigest, PARSEABLE_ENTIRE_FILE_TYPES};
+use parsing::{CodeContextRetriever, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
 use postage::watch;
 use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
 use smol::channel;
@@ -36,7 +36,7 @@ use util::{
     ResultExt,
 };
 
-const SEMANTIC_INDEX_VERSION: usize = 9;
+const SEMANTIC_INDEX_VERSION: usize = 10;
 const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
 const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
 
@@ -84,7 +84,7 @@ pub struct SemanticIndex {
     db: VectorDatabase,
     embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
-    parsing_files_tx: channel::Sender<(Arc<HashMap<DocumentDigest, Embedding>>, PendingFile)>,
+    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
     _embedding_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
@@ -252,16 +252,16 @@ impl SemanticIndex {
             let db = db.clone();
             async move {
                 while let Ok(file) = embedded_files.recv().await {
-                    db.insert_file(file.worktree_id, file.path, file.mtime, file.documents)
+                    db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                         .await
                         .log_err();
                 }
             }
         });
 
-        // Parse files into embeddable documents.
+        // Parse files into embeddable spans.
         let (parsing_files_tx, parsing_files_rx) =
-            channel::unbounded::<(Arc<HashMap<DocumentDigest, Embedding>>, PendingFile)>();
+            channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
         let embedding_queue = Arc::new(Mutex::new(embedding_queue));
         let mut _parsing_files_tasks = Vec::new();
         for _ in 0..cx.background().num_cpus() {
@@ -320,26 +320,26 @@ impl SemanticIndex {
         pending_file: PendingFile,
         retriever: &mut CodeContextRetriever,
         embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
-        embeddings_for_digest: &HashMap<DocumentDigest, Embedding>,
+        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
     ) {
         let Some(language) = pending_file.language else {
             return;
         };
 
         if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
-            if let Some(mut documents) = retriever
+            if let Some(mut spans) = retriever
                 .parse_file_with_template(&pending_file.relative_path, &content, language)
                 .log_err()
             {
                 log::trace!(
-                    "parsed path {:?}: {} documents",
+                    "parsed path {:?}: {} spans",
                     pending_file.relative_path,
-                    documents.len()
+                    spans.len()
                 );
 
-                for document in documents.iter_mut() {
-                    if let Some(embedding) = embeddings_for_digest.get(&document.digest) {
-                        document.embedding = Some(embedding.to_owned());
+                for span in &mut spans {
+                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
+                        span.embedding = Some(embedding.to_owned());
                     }
                 }
 
@@ -348,7 +348,7 @@ impl SemanticIndex {
                     path: pending_file.relative_path,
                     mtime: pending_file.modified_time,
                     job_handle: pending_file.job_handle,
-                    documents,
+                    spans,
                 });
             }
         }
@@ -708,13 +708,13 @@ impl SemanticIndex {
         }
 
         let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
-        let documents = database.get_documents_by_ids(ids.as_slice()).await?;
+        let spans = database.spans_for_ids(ids.as_slice()).await?;
 
         let mut tasks = Vec::new();
         let mut ranges = Vec::new();
         let weak_project = project.downgrade();
         project.update(&mut cx, |project, cx| {
-            for (worktree_db_id, file_path, byte_range) in documents {
+            for (worktree_db_id, file_path, byte_range) in spans {
                 let project_state =
                     if let Some(state) = this.read(cx).projects.get(&weak_project) {
                         state
diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs
index fe1b6b9cf9bb7958a4aec99a34e411885ecbdb2a..ffd8db87814cacb09143381891b993ca86e173e7 100644
--- a/crates/semantic_index/src/semantic_index_tests.rs
+++ b/crates/semantic_index/src/semantic_index_tests.rs
@@ -1,7 +1,7 @@
 use crate::{
     embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
     embedding_queue::EmbeddingQueue,
-    parsing::{subtract_ranges, CodeContextRetriever, Document, DocumentDigest},
+    parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
     semantic_index_settings::SemanticIndexSettings,
     FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
 };
@@ -204,15 +204,15 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
             worktree_id: 5,
             path: Path::new(&format!("path-{file_ix}")).into(),
             mtime: SystemTime::now(),
-            documents: (0..rng.gen_range(4..22))
+            spans: (0..rng.gen_range(4..22))
                 .map(|document_ix| {
                     let content_len = rng.gen_range(10..100);
                     let content = RandomCharIter::new(&mut rng)
                         .with_simple_text()
                         .take(content_len)
                         .collect::<String>();
-                    let digest = DocumentDigest::from(content.as_str());
-                    Document {
+                    let digest = SpanDigest::from(content.as_str());
+                    Span {
                         range: 0..10,
                         embedding: None,
                         name: format!("document {document_ix}"),
@@ -245,7 +245,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
         .iter()
         .map(|file| {
             let mut file = file.clone();
-            for doc in &mut file.documents {
+            for doc in &mut file.spans {
                 doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
             }
             file
@@ -437,7 +437,7 @@ async fn test_code_context_retrieval_json() {
 }
 
 fn assert_documents_eq(
-    documents: &[Document],
+    documents: &[Span],
     expected_contents_and_start_offsets: &[(String, usize)],
 ) {
     assert_eq!(