embedding_index.rs

  1use crate::{
  2    chunking::{self, Chunk},
  3    embedding::{Embedding, EmbeddingProvider, TextToEmbed},
  4    indexing::{IndexingEntryHandle, IndexingEntrySet},
  5};
  6use anyhow::{anyhow, Context as _, Result};
  7use collections::Bound;
  8use fs::Fs;
  9use futures::stream::StreamExt;
 10use futures_batch::ChunksTimeoutStreamExt;
 11use gpui::{AppContext, Model, Task};
 12use heed::types::{SerdeBincode, Str};
 13use language::LanguageRegistry;
 14use log;
 15use project::{Entry, UpdatedEntriesSet, Worktree};
 16use serde::{Deserialize, Serialize};
 17use smol::channel;
 18use std::{
 19    cmp::Ordering,
 20    future::Future,
 21    iter,
 22    path::Path,
 23    sync::Arc,
 24    time::{Duration, SystemTime},
 25};
 26use util::ResultExt;
 27use worktree::Snapshot;
 28
/// Maintains an on-disk (heed) index of chunk embeddings for the files of a
/// single worktree.
pub struct EmbeddingIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    // Maps a path-derived string key (see `db_key_for_path`) to the embedded
    // chunks of that file, serialized with bincode.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    fs: Arc<dyn Fs>,
    language_registry: Arc<LanguageRegistry>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Tracks entries currently being indexed; handles are inserted during
    // scanning and carried through the pipeline until persistence completes.
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
}
 38
 39impl EmbeddingIndex {
 40    pub fn new(
 41        worktree: Model<Worktree>,
 42        fs: Arc<dyn Fs>,
 43        db_connection: heed::Env,
 44        embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
 45        language_registry: Arc<LanguageRegistry>,
 46        embedding_provider: Arc<dyn EmbeddingProvider>,
 47        entry_ids_being_indexed: Arc<IndexingEntrySet>,
 48    ) -> Self {
 49        Self {
 50            worktree,
 51            fs,
 52            db_connection,
 53            db: embedding_db,
 54            language_registry,
 55            embedding_provider,
 56            entry_ids_being_indexed,
 57        }
 58    }
 59
 60    pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
 61        &self.db
 62    }
 63
 64    pub fn index_entries_changed_on_disk(
 65        &self,
 66        cx: &AppContext,
 67    ) -> impl Future<Output = Result<()>> {
 68        let worktree = self.worktree.read(cx).snapshot();
 69        let worktree_abs_path = worktree.abs_path().clone();
 70        let scan = self.scan_entries(worktree, cx);
 71        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
 72        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
 73        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
 74        async move {
 75            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
 76            Ok(())
 77        }
 78    }
 79
 80    pub fn index_updated_entries(
 81        &self,
 82        updated_entries: UpdatedEntriesSet,
 83        cx: &AppContext,
 84    ) -> impl Future<Output = Result<()>> {
 85        let worktree = self.worktree.read(cx).snapshot();
 86        let worktree_abs_path = worktree.abs_path().clone();
 87        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
 88        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
 89        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
 90        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
 91        async move {
 92            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
 93            Ok(())
 94        }
 95    }
 96
 97    fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
 98        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
 99        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
100        let db_connection = self.db_connection.clone();
101        let db = self.db;
102        let entries_being_indexed = self.entry_ids_being_indexed.clone();
103        let task = cx.background_executor().spawn(async move {
104            let txn = db_connection
105                .read_txn()
106                .context("failed to create read transaction")?;
107            let mut db_entries = db
108                .iter(&txn)
109                .context("failed to create iterator")?
110                .move_between_keys()
111                .peekable();
112
113            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
114            for entry in worktree.files(false, 0) {
115                log::trace!("scanning for embedding index: {:?}", &entry.path);
116
117                let entry_db_key = db_key_for_path(&entry.path);
118
119                let mut saved_mtime = None;
120                while let Some(db_entry) = db_entries.peek() {
121                    match db_entry {
122                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
123                            Ordering::Less => {
124                                if let Some(deletion_range) = deletion_range.as_mut() {
125                                    deletion_range.1 = Bound::Included(db_path);
126                                } else {
127                                    deletion_range =
128                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
129                                }
130
131                                db_entries.next();
132                            }
133                            Ordering::Equal => {
134                                if let Some(deletion_range) = deletion_range.take() {
135                                    deleted_entry_ranges_tx
136                                        .send((
137                                            deletion_range.0.map(ToString::to_string),
138                                            deletion_range.1.map(ToString::to_string),
139                                        ))
140                                        .await?;
141                                }
142                                saved_mtime = db_embedded_file.mtime;
143                                db_entries.next();
144                                break;
145                            }
146                            Ordering::Greater => {
147                                break;
148                            }
149                        },
150                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
151                    }
152                }
153
154                if entry.mtime != saved_mtime {
155                    let handle = entries_being_indexed.insert(entry.id);
156                    updated_entries_tx.send((entry.clone(), handle)).await?;
157                }
158            }
159
160            if let Some(db_entry) = db_entries.next() {
161                let (db_path, _) = db_entry?;
162                deleted_entry_ranges_tx
163                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
164                    .await?;
165            }
166
167            Ok(())
168        });
169
170        ScanEntries {
171            updated_entries: updated_entries_rx,
172            deleted_entry_ranges: deleted_entry_ranges_rx,
173            task,
174        }
175    }
176
177    fn scan_updated_entries(
178        &self,
179        worktree: Snapshot,
180        updated_entries: UpdatedEntriesSet,
181        cx: &AppContext,
182    ) -> ScanEntries {
183        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
184        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
185        let entries_being_indexed = self.entry_ids_being_indexed.clone();
186        let task = cx.background_executor().spawn(async move {
187            for (path, entry_id, status) in updated_entries.iter() {
188                match status {
189                    project::PathChange::Added
190                    | project::PathChange::Updated
191                    | project::PathChange::AddedOrUpdated => {
192                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
193                            if entry.is_file() {
194                                let handle = entries_being_indexed.insert(entry.id);
195                                updated_entries_tx.send((entry.clone(), handle)).await?;
196                            }
197                        }
198                    }
199                    project::PathChange::Removed => {
200                        let db_path = db_key_for_path(path);
201                        deleted_entry_ranges_tx
202                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
203                            .await?;
204                    }
205                    project::PathChange::Loaded => {
206                        // Do nothing.
207                    }
208                }
209            }
210
211            Ok(())
212        });
213
214        ScanEntries {
215            updated_entries: updated_entries_rx,
216            deleted_entry_ranges: deleted_entry_ranges_rx,
217            task,
218        }
219    }
220
221    fn chunk_files(
222        &self,
223        worktree_abs_path: Arc<Path>,
224        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
225        cx: &AppContext,
226    ) -> ChunkFiles {
227        let language_registry = self.language_registry.clone();
228        let fs = self.fs.clone();
229        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
230        let task = cx.spawn(|cx| async move {
231            cx.background_executor()
232                .scoped(|cx| {
233                    for _ in 0..cx.num_cpus() {
234                        cx.spawn(async {
235                            while let Ok((entry, handle)) = entries.recv().await {
236                                let entry_abs_path = worktree_abs_path.join(&entry.path);
237                                if let Some(text) = fs.load(&entry_abs_path).await.ok() {
238                                    let language = language_registry
239                                        .language_for_file_path(&entry.path)
240                                        .await
241                                        .ok();
242                                    let chunked_file = ChunkedFile {
243                                        chunks: chunking::chunk_text(
244                                            &text,
245                                            language.as_ref(),
246                                            &entry.path,
247                                        ),
248                                        handle,
249                                        path: entry.path,
250                                        mtime: entry.mtime,
251                                        text,
252                                    };
253
254                                    if chunked_files_tx.send(chunked_file).await.is_err() {
255                                        return;
256                                    }
257                                }
258                            }
259                        });
260                    }
261                })
262                .await;
263            Ok(())
264        });
265
266        ChunkFiles {
267            files: chunked_files_rx,
268            task,
269        }
270    }
271
272    pub fn embed_files(
273        embedding_provider: Arc<dyn EmbeddingProvider>,
274        chunked_files: channel::Receiver<ChunkedFile>,
275        cx: &AppContext,
276    ) -> EmbedFiles {
277        let embedding_provider = embedding_provider.clone();
278        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
279        let task = cx.background_executor().spawn(async move {
280            let mut chunked_file_batches =
281                chunked_files.chunks_timeout(512, Duration::from_secs(2));
282            while let Some(chunked_files) = chunked_file_batches.next().await {
283                // View the batch of files as a vec of chunks
284                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
285                // Once those are done, reassemble them back into the files in which they belong
286                // If any embeddings fail for a file, the entire file is discarded
287
288                let chunks: Vec<TextToEmbed> = chunked_files
289                    .iter()
290                    .flat_map(|file| {
291                        file.chunks.iter().map(|chunk| TextToEmbed {
292                            text: &file.text[chunk.range.clone()],
293                            digest: chunk.digest,
294                        })
295                    })
296                    .collect::<Vec<_>>();
297
298                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
299                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
300                    if let Some(batch_embeddings) =
301                        embedding_provider.embed(embedding_batch).await.log_err()
302                    {
303                        if batch_embeddings.len() == embedding_batch.len() {
304                            embeddings.extend(batch_embeddings.into_iter().map(Some));
305                            continue;
306                        }
307                        log::error!(
308                            "embedding provider returned unexpected embedding count {}, expected {}",
309                            batch_embeddings.len(), embedding_batch.len()
310                        );
311                    }
312
313                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
314                }
315
316                let mut embeddings = embeddings.into_iter();
317                for chunked_file in chunked_files {
318                    let mut embedded_file = EmbeddedFile {
319                        path: chunked_file.path,
320                        mtime: chunked_file.mtime,
321                        chunks: Vec::new(),
322                    };
323
324                    let mut embedded_all_chunks = true;
325                    for (chunk, embedding) in
326                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
327                    {
328                        if let Some(embedding) = embedding {
329                            embedded_file
330                                .chunks
331                                .push(EmbeddedChunk { chunk, embedding });
332                        } else {
333                            embedded_all_chunks = false;
334                        }
335                    }
336
337                    if embedded_all_chunks {
338                        embedded_files_tx
339                            .send((embedded_file, chunked_file.handle))
340                            .await?;
341                    }
342                }
343            }
344            Ok(())
345        });
346
347        EmbedFiles {
348            files: embedded_files_rx,
349            task,
350        }
351    }
352
353    fn persist_embeddings(
354        &self,
355        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
356        mut embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
357        cx: &AppContext,
358    ) -> Task<Result<()>> {
359        let db_connection = self.db_connection.clone();
360        let db = self.db;
361
362        cx.background_executor().spawn(async move {
363            loop {
364                // Interleave deletions and persists of embedded files
365                futures::select_biased! {
366                    deletion_range = deleted_entry_ranges.next() => {
367                        if let Some(deletion_range) = deletion_range {
368                            let mut txn = db_connection.write_txn()?;
369                            let start = deletion_range.0.as_ref().map(|start| start.as_str());
370                            let end = deletion_range.1.as_ref().map(|end| end.as_str());
371                            log::debug!("deleting embeddings in range {:?}", &(start, end));
372                            db.delete_range(&mut txn, &(start, end))?;
373                            txn.commit()?;
374                        }
375                    },
376                    file = embedded_files.next() => {
377                        if let Some((file, _)) = file {
378                            let mut txn = db_connection.write_txn()?;
379                            log::debug!("saving embedding for file {:?}", file.path);
380                            let key = db_key_for_path(&file.path);
381                            db.put(&mut txn, &key, &file)?;
382                            txn.commit()?;
383                        }
384                    },
385                    complete => break,
386                }
387            }
388
389            Ok(())
390        })
391    }
392
393    pub fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
394        let connection = self.db_connection.clone();
395        let db = self.db;
396        cx.background_executor().spawn(async move {
397            let tx = connection
398                .read_txn()
399                .context("failed to create read transaction")?;
400            let result = db
401                .iter(&tx)?
402                .map(|entry| Ok(entry?.1.path.clone()))
403                .collect::<Result<Vec<Arc<Path>>>>();
404            drop(tx);
405            result
406        })
407    }
408
409    pub fn chunks_for_path(
410        &self,
411        path: Arc<Path>,
412        cx: &AppContext,
413    ) -> Task<Result<Vec<EmbeddedChunk>>> {
414        let connection = self.db_connection.clone();
415        let db = self.db;
416        cx.background_executor().spawn(async move {
417            let tx = connection
418                .read_txn()
419                .context("failed to create read transaction")?;
420            Ok(db
421                .get(&tx, &db_key_for_path(&path))?
422                .ok_or_else(|| anyhow!("no such path"))?
423                .chunks
424                .clone())
425        })
426    }
427}
428
/// Channels and task produced by the scanning stage of the indexing pipeline.
struct ScanEntries {
    // Entries whose contents need (re-)embedding, paired with their
    // in-progress indexing handles.
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    // Ranges of database keys whose files were removed from the worktree.
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}
434
/// Output of the chunking stage: chunked files ready for embedding, plus the
/// task driving the chunking workers.
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}
439
/// A file's full text together with the chunk boundaries computed for it.
pub struct ChunkedFile {
    // Worktree-relative path of the file.
    pub path: Arc<Path>,
    // Modification time at the moment the file was read.
    pub mtime: Option<SystemTime>,
    // Keeps the entry marked as "being indexed" until dropped.
    pub handle: IndexingEntryHandle,
    // Full file contents; chunk ranges index into this string.
    pub text: String,
    pub chunks: Vec<Chunk>,
}
447
/// Output of the embedding stage: fully-embedded files ready to be persisted,
/// plus the task driving the embedding loop.
pub struct EmbedFiles {
    pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    pub task: Task<Result<()>>,
}
452
/// A file's embedded chunks as stored in the database.
///
/// NOTE(review): this type is bincode-serialized into heed; changing or
/// reordering fields changes the on-disk format.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
    pub path: Arc<Path>,
    // Mtime recorded at indexing time; compared against the worktree's mtime
    // to decide whether a file needs re-indexing.
    pub mtime: Option<SystemTime>,
    pub chunks: Vec<EmbeddedChunk>,
}
459
/// A single chunk of a file together with its embedding vector.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    pub chunk: Chunk,
    pub embedding: Embedding,
}
465
/// Converts a worktree-relative path into a database key.
///
/// `/` separators are replaced with NUL, presumably so that a directory's
/// contents sort together ahead of sibling names sharing the directory name
/// as a prefix — the merge in `scan_entries` relies on database key order
/// matching worktree traversal order (confirm before changing).
///
/// Takes `&Path` rather than `&Arc<Path>`; existing `&Arc<Path>` call sites
/// still work via deref coercion.
fn db_key_for_path(path: &Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}