// embedding_index.rs — incremental embedding index for a single worktree.

  1use crate::{
  2    chunking::{self, Chunk},
  3    embedding::{Embedding, EmbeddingProvider, TextToEmbed},
  4    indexing::{IndexingEntryHandle, IndexingEntrySet},
  5};
  6use anyhow::{anyhow, Context as _, Result};
  7use collections::Bound;
  8use feature_flags::FeatureFlagAppExt;
  9use fs::Fs;
 10use fs::MTime;
 11use futures::{stream::StreamExt, FutureExt as _};
 12use futures_batch::ChunksTimeoutStreamExt;
 13use gpui::{App, AppContext as _, Entity, Task};
 14use heed::types::{SerdeBincode, Str};
 15use language::LanguageRegistry;
 16use log;
 17use project::{Entry, UpdatedEntriesSet, Worktree};
 18use serde::{Deserialize, Serialize};
 19use smol::channel;
 20use std::{cmp::Ordering, future::Future, iter, path::Path, pin::pin, sync::Arc, time::Duration};
 21use util::ResultExt;
 22use worktree::Snapshot;
 23
/// Maintains an on-disk index of text embeddings for the files of a single
/// worktree, backed by an LMDB (heed) database.
pub struct EmbeddingIndex {
    // The worktree whose files are being indexed.
    worktree: Entity<Worktree>,
    // LMDB environment that owns `db`; used to open read/write transactions.
    db_connection: heed::Env,
    // Maps a NUL-separated path key (see `db_key_for_path`) to the file's
    // chunks and their embeddings.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    // Filesystem abstraction used to load file contents.
    fs: Arc<dyn Fs>,
    // Used to detect each file's language before chunking.
    language_registry: Arc<LanguageRegistry>,
    // Produces embeddings for batches of text chunks.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Tracks which entries are currently mid-index, for progress reporting.
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
}
 33
 34impl EmbeddingIndex {
 35    pub fn new(
 36        worktree: Entity<Worktree>,
 37        fs: Arc<dyn Fs>,
 38        db_connection: heed::Env,
 39        embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
 40        language_registry: Arc<LanguageRegistry>,
 41        embedding_provider: Arc<dyn EmbeddingProvider>,
 42        entry_ids_being_indexed: Arc<IndexingEntrySet>,
 43    ) -> Self {
 44        Self {
 45            worktree,
 46            fs,
 47            db_connection,
 48            db: embedding_db,
 49            language_registry,
 50            embedding_provider,
 51            entry_ids_being_indexed,
 52        }
 53    }
 54
 55    pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
 56        &self.db
 57    }
 58
 59    pub fn index_entries_changed_on_disk(&self, cx: &App) -> impl Future<Output = Result<()>> {
 60        if !cx.is_staff() {
 61            return async move { Ok(()) }.boxed();
 62        }
 63
 64        let worktree = self.worktree.read(cx).snapshot();
 65        let worktree_abs_path = worktree.abs_path().clone();
 66        let scan = self.scan_entries(worktree, cx);
 67        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
 68        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
 69        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
 70        async move {
 71            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
 72            Ok(())
 73        }
 74        .boxed()
 75    }
 76
 77    pub fn index_updated_entries(
 78        &self,
 79        updated_entries: UpdatedEntriesSet,
 80        cx: &App,
 81    ) -> impl Future<Output = Result<()>> {
 82        if !cx.is_staff() {
 83            return async move { Ok(()) }.boxed();
 84        }
 85
 86        let worktree = self.worktree.read(cx).snapshot();
 87        let worktree_abs_path = worktree.abs_path().clone();
 88        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
 89        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
 90        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
 91        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
 92        async move {
 93            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
 94            Ok(())
 95        }
 96        .boxed()
 97    }
 98
 99    fn scan_entries(&self, worktree: Snapshot, cx: &App) -> ScanEntries {
100        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
101        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
102        let db_connection = self.db_connection.clone();
103        let db = self.db;
104        let entries_being_indexed = self.entry_ids_being_indexed.clone();
105        let task = cx.background_spawn(async move {
106            let txn = db_connection
107                .read_txn()
108                .context("failed to create read transaction")?;
109            let mut db_entries = db
110                .iter(&txn)
111                .context("failed to create iterator")?
112                .move_between_keys()
113                .peekable();
114
115            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
116            for entry in worktree.files(false, 0) {
117                log::trace!("scanning for embedding index: {:?}", &entry.path);
118
119                let entry_db_key = db_key_for_path(&entry.path);
120
121                let mut saved_mtime = None;
122                while let Some(db_entry) = db_entries.peek() {
123                    match db_entry {
124                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
125                            Ordering::Less => {
126                                if let Some(deletion_range) = deletion_range.as_mut() {
127                                    deletion_range.1 = Bound::Included(db_path);
128                                } else {
129                                    deletion_range =
130                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
131                                }
132
133                                db_entries.next();
134                            }
135                            Ordering::Equal => {
136                                if let Some(deletion_range) = deletion_range.take() {
137                                    deleted_entry_ranges_tx
138                                        .send((
139                                            deletion_range.0.map(ToString::to_string),
140                                            deletion_range.1.map(ToString::to_string),
141                                        ))
142                                        .await?;
143                                }
144                                saved_mtime = db_embedded_file.mtime;
145                                db_entries.next();
146                                break;
147                            }
148                            Ordering::Greater => {
149                                break;
150                            }
151                        },
152                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
153                    }
154                }
155
156                if entry.mtime != saved_mtime {
157                    let handle = entries_being_indexed.insert(entry.id);
158                    updated_entries_tx.send((entry.clone(), handle)).await?;
159                }
160            }
161
162            if let Some(db_entry) = db_entries.next() {
163                let (db_path, _) = db_entry?;
164                deleted_entry_ranges_tx
165                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
166                    .await?;
167            }
168
169            Ok(())
170        });
171
172        ScanEntries {
173            updated_entries: updated_entries_rx,
174            deleted_entry_ranges: deleted_entry_ranges_rx,
175            task,
176        }
177    }
178
179    fn scan_updated_entries(
180        &self,
181        worktree: Snapshot,
182        updated_entries: UpdatedEntriesSet,
183        cx: &App,
184    ) -> ScanEntries {
185        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
186        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
187        let entries_being_indexed = self.entry_ids_being_indexed.clone();
188        let task = cx.background_spawn(async move {
189            for (path, entry_id, status) in updated_entries.iter() {
190                match status {
191                    project::PathChange::Added
192                    | project::PathChange::Updated
193                    | project::PathChange::AddedOrUpdated => {
194                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
195                            if entry.is_file() {
196                                let handle = entries_being_indexed.insert(entry.id);
197                                updated_entries_tx.send((entry.clone(), handle)).await?;
198                            }
199                        }
200                    }
201                    project::PathChange::Removed => {
202                        let db_path = db_key_for_path(path);
203                        deleted_entry_ranges_tx
204                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
205                            .await?;
206                    }
207                    project::PathChange::Loaded => {
208                        // Do nothing.
209                    }
210                }
211            }
212
213            Ok(())
214        });
215
216        ScanEntries {
217            updated_entries: updated_entries_rx,
218            deleted_entry_ranges: deleted_entry_ranges_rx,
219            task,
220        }
221    }
222
223    fn chunk_files(
224        &self,
225        worktree_abs_path: Arc<Path>,
226        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
227        cx: &App,
228    ) -> ChunkFiles {
229        let language_registry = self.language_registry.clone();
230        let fs = self.fs.clone();
231        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
232        let task = cx.spawn(async move |cx| {
233            cx.background_executor()
234                .scoped(|cx| {
235                    for _ in 0..cx.num_cpus() {
236                        cx.spawn(async {
237                            while let Ok((entry, handle)) = entries.recv().await {
238                                let entry_abs_path = worktree_abs_path.join(&entry.path);
239                                if let Some(text) = fs.load(&entry_abs_path).await.ok() {
240                                    let language = language_registry
241                                        .language_for_file_path(&entry.path)
242                                        .await
243                                        .ok();
244                                    let chunked_file = ChunkedFile {
245                                        chunks: chunking::chunk_text(
246                                            &text,
247                                            language.as_ref(),
248                                            &entry.path,
249                                        ),
250                                        handle,
251                                        path: entry.path,
252                                        mtime: entry.mtime,
253                                        text,
254                                    };
255
256                                    if chunked_files_tx.send(chunked_file).await.is_err() {
257                                        return;
258                                    }
259                                }
260                            }
261                        });
262                    }
263                })
264                .await;
265            Ok(())
266        });
267
268        ChunkFiles {
269            files: chunked_files_rx,
270            task,
271        }
272    }
273
274    pub fn embed_files(
275        embedding_provider: Arc<dyn EmbeddingProvider>,
276        chunked_files: channel::Receiver<ChunkedFile>,
277        cx: &App,
278    ) -> EmbedFiles {
279        let embedding_provider = embedding_provider.clone();
280        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
281        let task = cx.background_spawn(async move {
282            let mut chunked_file_batches =
283                pin!(chunked_files.chunks_timeout(512, Duration::from_secs(2)));
284            while let Some(chunked_files) = chunked_file_batches.next().await {
285                // View the batch of files as a vec of chunks
286                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
287                // Once those are done, reassemble them back into the files in which they belong
288                // If any embeddings fail for a file, the entire file is discarded
289
290                let chunks: Vec<TextToEmbed> = chunked_files
291                    .iter()
292                    .flat_map(|file| {
293                        file.chunks.iter().map(|chunk| TextToEmbed {
294                            text: &file.text[chunk.range.clone()],
295                            digest: chunk.digest,
296                        })
297                    })
298                    .collect::<Vec<_>>();
299
300                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
301                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
302                    if let Some(batch_embeddings) =
303                        embedding_provider.embed(embedding_batch).await.log_err()
304                    {
305                        if batch_embeddings.len() == embedding_batch.len() {
306                            embeddings.extend(batch_embeddings.into_iter().map(Some));
307                            continue;
308                        }
309                        log::error!(
310                            "embedding provider returned unexpected embedding count {}, expected {}",
311                            batch_embeddings.len(), embedding_batch.len()
312                        );
313                    }
314
315                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
316                }
317
318                let mut embeddings = embeddings.into_iter();
319                for chunked_file in chunked_files {
320                    let mut embedded_file = EmbeddedFile {
321                        path: chunked_file.path,
322                        mtime: chunked_file.mtime,
323                        chunks: Vec::new(),
324                    };
325
326                    let mut embedded_all_chunks = true;
327                    for (chunk, embedding) in
328                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
329                    {
330                        if let Some(embedding) = embedding {
331                            embedded_file
332                                .chunks
333                                .push(EmbeddedChunk { chunk, embedding });
334                        } else {
335                            embedded_all_chunks = false;
336                        }
337                    }
338
339                    if embedded_all_chunks {
340                        embedded_files_tx
341                            .send((embedded_file, chunked_file.handle))
342                            .await?;
343                    }
344                }
345            }
346            Ok(())
347        });
348
349        EmbedFiles {
350            files: embedded_files_rx,
351            task,
352        }
353    }
354
355    fn persist_embeddings(
356        &self,
357        deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
358        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
359        cx: &App,
360    ) -> Task<Result<()>> {
361        let db_connection = self.db_connection.clone();
362        let db = self.db;
363
364        cx.background_spawn(async move {
365            let mut deleted_entry_ranges = pin!(deleted_entry_ranges);
366            let mut embedded_files = pin!(embedded_files);
367            loop {
368                // Interleave deletions and persists of embedded files
369                futures::select_biased! {
370                    deletion_range = deleted_entry_ranges.next() => {
371                        if let Some(deletion_range) = deletion_range {
372                            let mut txn = db_connection.write_txn()?;
373                            let start = deletion_range.0.as_ref().map(|start| start.as_str());
374                            let end = deletion_range.1.as_ref().map(|end| end.as_str());
375                            log::debug!("deleting embeddings in range {:?}", &(start, end));
376                            db.delete_range(&mut txn, &(start, end))?;
377                            txn.commit()?;
378                        }
379                    },
380                    file = embedded_files.next() => {
381                        if let Some((file, _)) = file {
382                            let mut txn = db_connection.write_txn()?;
383                            log::debug!("saving embedding for file {:?}", file.path);
384                            let key = db_key_for_path(&file.path);
385                            db.put(&mut txn, &key, &file)?;
386                            txn.commit()?;
387                        }
388                    },
389                    complete => break,
390                }
391            }
392
393            Ok(())
394        })
395    }
396
397    pub fn paths(&self, cx: &App) -> Task<Result<Vec<Arc<Path>>>> {
398        let connection = self.db_connection.clone();
399        let db = self.db;
400        cx.background_spawn(async move {
401            let tx = connection
402                .read_txn()
403                .context("failed to create read transaction")?;
404            let result = db
405                .iter(&tx)?
406                .map(|entry| Ok(entry?.1.path.clone()))
407                .collect::<Result<Vec<Arc<Path>>>>();
408            drop(tx);
409            result
410        })
411    }
412
413    pub fn chunks_for_path(&self, path: Arc<Path>, cx: &App) -> Task<Result<Vec<EmbeddedChunk>>> {
414        let connection = self.db_connection.clone();
415        let db = self.db;
416        cx.background_spawn(async move {
417            let tx = connection
418                .read_txn()
419                .context("failed to create read transaction")?;
420            Ok(db
421                .get(&tx, &db_key_for_path(&path))?
422                .ok_or_else(|| anyhow!("no such path"))?
423                .chunks
424                .clone())
425        })
426    }
427}
428
/// Output of a scan stage: channels feeding the rest of the indexing
/// pipeline, plus the task driving the scan itself.
struct ScanEntries {
    // Files whose content needs (re-)embedding, with their in-progress handles.
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    // Key ranges to delete from the embeddings database.
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    // Must be polled (e.g. via `try_join!`) for the channels to produce values.
    task: Task<Result<()>>,
}
434
/// Output of the chunking stage: chunked files plus the task producing them.
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}
439
/// A file that has been loaded and split into chunks, ready for embedding.
pub struct ChunkedFile {
    // Worktree-relative path of the file.
    pub path: Arc<Path>,
    // Modification time captured at scan time; stored alongside embeddings
    // so unchanged files can be skipped on the next scan.
    pub mtime: Option<MTime>,
    // Keeps the entry marked as "being indexed" until dropped.
    pub handle: IndexingEntryHandle,
    // Full file contents; chunk ranges index into this string.
    pub text: String,
    pub chunks: Vec<Chunk>,
}
447
/// Output of the embedding stage: fully embedded files plus the task
/// producing them.
pub struct EmbedFiles {
    pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    pub task: Task<Result<()>>,
}
452
/// The value type persisted in the embeddings database: a file's path,
/// mtime, and every chunk together with its embedding.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
    pub path: Arc<Path>,
    // Compared against the on-disk mtime to decide whether to re-embed.
    pub mtime: Option<MTime>,
    pub chunks: Vec<EmbeddedChunk>,
}
459
/// A single text chunk paired with its embedding vector.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    pub chunk: Chunk,
    pub embedding: Embedding,
}
465
/// Converts a worktree-relative path into the database key for a file.
///
/// `/` separators are replaced with NUL bytes so that keys sort bytewise in
/// the same component order the worktree scan produces, which the merge in
/// `scan_entries` relies on.
///
/// Takes `&Path` rather than `&Arc<Path>` (the original anti-pattern); all
/// existing call sites pass `&Arc<Path>`, which deref-coerces to `&Path`.
fn db_key_for_path(path: &Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}