// embedding_index.rs

  1use crate::{
  2    chunking::{self, Chunk},
  3    embedding::{Embedding, EmbeddingProvider, TextToEmbed},
  4    indexing::{IndexingEntryHandle, IndexingEntrySet},
  5};
  6use anyhow::{Context as _, Result};
  7use collections::Bound;
  8use feature_flags::FeatureFlagAppExt;
  9use fs::Fs;
 10use fs::MTime;
 11use futures::{FutureExt as _, stream::StreamExt};
 12use futures_batch::ChunksTimeoutStreamExt;
 13use gpui::{App, AppContext as _, Entity, Task};
 14use heed::types::{SerdeBincode, Str};
 15use language::LanguageRegistry;
 16use log;
 17use project::{Entry, UpdatedEntriesSet, Worktree};
 18use serde::{Deserialize, Serialize};
 19use smol::channel;
 20use std::{cmp::Ordering, future::Future, iter, path::Path, pin::pin, sync::Arc, time::Duration};
 21use util::ResultExt;
 22use worktree::Snapshot;
 23
/// Maintains a worktree's file embeddings in an LMDB (heed) database.
///
/// Rows are keyed by `db_key_for_path(path)` and hold an [`EmbeddedFile`]
/// (path, mtime, and embedded chunks), serialized with bincode.
pub struct EmbeddingIndex {
    worktree: Entity<Worktree>,
    db_connection: heed::Env,
    // Maps db_key_for_path(worktree-relative path) -> EmbeddedFile.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    fs: Arc<dyn Fs>,
    language_registry: Arc<LanguageRegistry>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Entries currently flowing through the indexing pipeline; handles are
    // inserted during the scan stage and travel alongside each file.
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
}
 33
 34impl EmbeddingIndex {
 35    pub fn new(
 36        worktree: Entity<Worktree>,
 37        fs: Arc<dyn Fs>,
 38        db_connection: heed::Env,
 39        embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
 40        language_registry: Arc<LanguageRegistry>,
 41        embedding_provider: Arc<dyn EmbeddingProvider>,
 42        entry_ids_being_indexed: Arc<IndexingEntrySet>,
 43    ) -> Self {
 44        Self {
 45            worktree,
 46            fs,
 47            db_connection,
 48            db: embedding_db,
 49            language_registry,
 50            embedding_provider,
 51            entry_ids_being_indexed,
 52        }
 53    }
 54
    /// Returns the underlying path → `EmbeddedFile` database handle.
    pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
        &self.db
    }
 58
 59    pub fn index_entries_changed_on_disk(
 60        &self,
 61        cx: &App,
 62    ) -> impl Future<Output = Result<()>> + use<> {
 63        if !cx.is_staff() {
 64            return async move { Ok(()) }.boxed();
 65        }
 66
 67        let worktree = self.worktree.read(cx).snapshot();
 68        let worktree_abs_path = worktree.abs_path().clone();
 69        let scan = self.scan_entries(worktree, cx);
 70        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
 71        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
 72        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
 73        async move {
 74            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
 75            Ok(())
 76        }
 77        .boxed()
 78    }
 79
 80    pub fn index_updated_entries(
 81        &self,
 82        updated_entries: UpdatedEntriesSet,
 83        cx: &App,
 84    ) -> impl Future<Output = Result<()>> + use<> {
 85        if !cx.is_staff() {
 86            return async move { Ok(()) }.boxed();
 87        }
 88
 89        let worktree = self.worktree.read(cx).snapshot();
 90        let worktree_abs_path = worktree.abs_path().clone();
 91        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
 92        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
 93        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
 94        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
 95        async move {
 96            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
 97            Ok(())
 98        }
 99        .boxed()
100    }
101
102    fn scan_entries(&self, worktree: Snapshot, cx: &App) -> ScanEntries {
103        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
104        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
105        let db_connection = self.db_connection.clone();
106        let db = self.db;
107        let entries_being_indexed = self.entry_ids_being_indexed.clone();
108        let task = cx.background_spawn(async move {
109            let txn = db_connection
110                .read_txn()
111                .context("failed to create read transaction")?;
112            let mut db_entries = db
113                .iter(&txn)
114                .context("failed to create iterator")?
115                .move_between_keys()
116                .peekable();
117
118            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
119            for entry in worktree.files(false, 0) {
120                log::trace!("scanning for embedding index: {:?}", &entry.path);
121
122                let entry_db_key = db_key_for_path(&entry.path);
123
124                let mut saved_mtime = None;
125                while let Some(db_entry) = db_entries.peek() {
126                    match db_entry {
127                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
128                            Ordering::Less => {
129                                if let Some(deletion_range) = deletion_range.as_mut() {
130                                    deletion_range.1 = Bound::Included(db_path);
131                                } else {
132                                    deletion_range =
133                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
134                                }
135
136                                db_entries.next();
137                            }
138                            Ordering::Equal => {
139                                if let Some(deletion_range) = deletion_range.take() {
140                                    deleted_entry_ranges_tx
141                                        .send((
142                                            deletion_range.0.map(ToString::to_string),
143                                            deletion_range.1.map(ToString::to_string),
144                                        ))
145                                        .await?;
146                                }
147                                saved_mtime = db_embedded_file.mtime;
148                                db_entries.next();
149                                break;
150                            }
151                            Ordering::Greater => {
152                                break;
153                            }
154                        },
155                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
156                    }
157                }
158
159                if entry.mtime != saved_mtime {
160                    let handle = entries_being_indexed.insert(entry.id);
161                    updated_entries_tx.send((entry.clone(), handle)).await?;
162                }
163            }
164
165            if let Some(db_entry) = db_entries.next() {
166                let (db_path, _) = db_entry?;
167                deleted_entry_ranges_tx
168                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
169                    .await?;
170            }
171
172            Ok(())
173        });
174
175        ScanEntries {
176            updated_entries: updated_entries_rx,
177            deleted_entry_ranges: deleted_entry_ranges_rx,
178            task,
179        }
180    }
181
182    fn scan_updated_entries(
183        &self,
184        worktree: Snapshot,
185        updated_entries: UpdatedEntriesSet,
186        cx: &App,
187    ) -> ScanEntries {
188        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
189        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
190        let entries_being_indexed = self.entry_ids_being_indexed.clone();
191        let task = cx.background_spawn(async move {
192            for (path, entry_id, status) in updated_entries.iter() {
193                match status {
194                    project::PathChange::Added
195                    | project::PathChange::Updated
196                    | project::PathChange::AddedOrUpdated => {
197                        if let Some(entry) = worktree.entry_for_id(*entry_id)
198                            && entry.is_file() {
199                                let handle = entries_being_indexed.insert(entry.id);
200                                updated_entries_tx.send((entry.clone(), handle)).await?;
201                            }
202                    }
203                    project::PathChange::Removed => {
204                        let db_path = db_key_for_path(path);
205                        deleted_entry_ranges_tx
206                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
207                            .await?;
208                    }
209                    project::PathChange::Loaded => {
210                        // Do nothing.
211                    }
212                }
213            }
214
215            Ok(())
216        });
217
218        ScanEntries {
219            updated_entries: updated_entries_rx,
220            deleted_entry_ranges: deleted_entry_ranges_rx,
221            task,
222        }
223    }
224
225    fn chunk_files(
226        &self,
227        worktree_abs_path: Arc<Path>,
228        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
229        cx: &App,
230    ) -> ChunkFiles {
231        let language_registry = self.language_registry.clone();
232        let fs = self.fs.clone();
233        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
234        let task = cx.spawn(async move |cx| {
235            cx.background_executor()
236                .scoped(|cx| {
237                    for _ in 0..cx.num_cpus() {
238                        cx.spawn(async {
239                            while let Ok((entry, handle)) = entries.recv().await {
240                                let entry_abs_path = worktree_abs_path.join(&entry.path);
241                                if let Some(text) = fs.load(&entry_abs_path).await.ok() {
242                                    let language = language_registry
243                                        .language_for_file_path(&entry.path)
244                                        .await
245                                        .ok();
246                                    let chunked_file = ChunkedFile {
247                                        chunks: chunking::chunk_text(
248                                            &text,
249                                            language.as_ref(),
250                                            &entry.path,
251                                        ),
252                                        handle,
253                                        path: entry.path,
254                                        mtime: entry.mtime,
255                                        text,
256                                    };
257
258                                    if chunked_files_tx.send(chunked_file).await.is_err() {
259                                        return;
260                                    }
261                                }
262                            }
263                        });
264                    }
265                })
266                .await;
267            Ok(())
268        });
269
270        ChunkFiles {
271            files: chunked_files_rx,
272            task,
273        }
274    }
275
276    pub fn embed_files(
277        embedding_provider: Arc<dyn EmbeddingProvider>,
278        chunked_files: channel::Receiver<ChunkedFile>,
279        cx: &App,
280    ) -> EmbedFiles {
281        let embedding_provider = embedding_provider.clone();
282        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
283        let task = cx.background_spawn(async move {
284            let mut chunked_file_batches =
285                pin!(chunked_files.chunks_timeout(512, Duration::from_secs(2)));
286            while let Some(chunked_files) = chunked_file_batches.next().await {
287                // View the batch of files as a vec of chunks
288                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
289                // Once those are done, reassemble them back into the files in which they belong
290                // If any embeddings fail for a file, the entire file is discarded
291
292                let chunks: Vec<TextToEmbed> = chunked_files
293                    .iter()
294                    .flat_map(|file| {
295                        file.chunks.iter().map(|chunk| TextToEmbed {
296                            text: &file.text[chunk.range.clone()],
297                            digest: chunk.digest,
298                        })
299                    })
300                    .collect::<Vec<_>>();
301
302                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
303                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
304                    if let Some(batch_embeddings) =
305                        embedding_provider.embed(embedding_batch).await.log_err()
306                    {
307                        if batch_embeddings.len() == embedding_batch.len() {
308                            embeddings.extend(batch_embeddings.into_iter().map(Some));
309                            continue;
310                        }
311                        log::error!(
312                            "embedding provider returned unexpected embedding count {}, expected {}",
313                            batch_embeddings.len(), embedding_batch.len()
314                        );
315                    }
316
317                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
318                }
319
320                let mut embeddings = embeddings.into_iter();
321                for chunked_file in chunked_files {
322                    let mut embedded_file = EmbeddedFile {
323                        path: chunked_file.path,
324                        mtime: chunked_file.mtime,
325                        chunks: Vec::new(),
326                    };
327
328                    let mut embedded_all_chunks = true;
329                    for (chunk, embedding) in
330                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
331                    {
332                        if let Some(embedding) = embedding {
333                            embedded_file
334                                .chunks
335                                .push(EmbeddedChunk { chunk, embedding });
336                        } else {
337                            embedded_all_chunks = false;
338                        }
339                    }
340
341                    if embedded_all_chunks {
342                        embedded_files_tx
343                            .send((embedded_file, chunked_file.handle))
344                            .await?;
345                    }
346                }
347            }
348            Ok(())
349        });
350
351        EmbedFiles {
352            files: embedded_files_rx,
353            task,
354        }
355    }
356
    /// Drains the deletion-range and embedded-file channels, applying each
    /// item to the database in its own short-lived write transaction.
    fn persist_embeddings(
        &self,
        deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &App,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;

        cx.background_spawn(async move {
            let mut deleted_entry_ranges = pin!(deleted_entry_ranges);
            let mut embedded_files = pin!(embedded_files);
            loop {
                // Interleave deletions and persists of embedded files;
                // `select_biased!` prefers deletions when both are ready.
                futures::select_biased! {
                    deletion_range = deleted_entry_ranges.next() => {
                        if let Some(deletion_range) = deletion_range {
                            let mut txn = db_connection.write_txn()?;
                            let start = deletion_range.0.as_ref().map(|start| start.as_str());
                            let end = deletion_range.1.as_ref().map(|end| end.as_str());
                            log::debug!("deleting embeddings in range {:?}", &(start, end));
                            db.delete_range(&mut txn, &(start, end))?;
                            txn.commit()?;
                        }
                    },
                    file = embedded_files.next() => {
                        // NOTE(review): the ignored tuple element is the
                        // IndexingEntryHandle, dropped once the file is
                        // persisted — presumably this un-marks the entry as
                        // in-flight; confirm against IndexingEntrySet.
                        if let Some((file, _)) = file {
                            let mut txn = db_connection.write_txn()?;
                            log::debug!("saving embedding for file {:?}", file.path);
                            let key = db_key_for_path(&file.path);
                            db.put(&mut txn, &key, &file)?;
                            txn.commit()?;
                        }
                    },
                    // Both channels are closed and fully drained.
                    complete => break,
                }
            }

            Ok(())
        })
    }
398
399    pub fn paths(&self, cx: &App) -> Task<Result<Vec<Arc<Path>>>> {
400        let connection = self.db_connection.clone();
401        let db = self.db;
402        cx.background_spawn(async move {
403            let tx = connection
404                .read_txn()
405                .context("failed to create read transaction")?;
406            let result = db
407                .iter(&tx)?
408                .map(|entry| Ok(entry?.1.path.clone()))
409                .collect::<Result<Vec<Arc<Path>>>>();
410            drop(tx);
411            result
412        })
413    }
414
415    pub fn chunks_for_path(&self, path: Arc<Path>, cx: &App) -> Task<Result<Vec<EmbeddedChunk>>> {
416        let connection = self.db_connection.clone();
417        let db = self.db;
418        cx.background_spawn(async move {
419            let tx = connection
420                .read_txn()
421                .context("failed to create read transaction")?;
422            Ok(db
423                .get(&tx, &db_key_for_path(&path))?
424                .context("no such path")?
425                .chunks
426                .clone())
427        })
428    }
429}
430
/// Output of a scan stage: entries to (re)index, key ranges to delete, and
/// the task driving the scan.
struct ScanEntries {
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    // Inclusive/unbounded key ranges (db_key_for_path form) to delete.
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}
436
/// Output of the chunking stage: chunked files plus the task producing them.
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}
441
/// A file's text split into chunks, ready for embedding.
pub struct ChunkedFile {
    pub path: Arc<Path>,
    pub mtime: Option<MTime>,
    // Handle obtained from IndexingEntrySet::insert when the entry was
    // scanned; carried through the pipeline with the file.
    pub handle: IndexingEntryHandle,
    pub text: String,
    // Chunks produced by `chunking::chunk_text`; ranges index into `text`.
    pub chunks: Vec<Chunk>,
}
449
/// Output of the embedding stage: fully embedded files (with their indexing
/// handles) plus the task producing them.
pub struct EmbedFiles {
    pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    pub task: Task<Result<()>>,
}
454
/// Database row value: one indexed file with its embedded chunks.
/// Serialized with bincode via heed's `SerdeBincode` codec.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
    pub path: Arc<Path>,
    // Mtime recorded at indexing time; compared against the worktree entry's
    // mtime to decide whether a file needs re-indexing.
    pub mtime: Option<MTime>,
    pub chunks: Vec<EmbeddedChunk>,
}
461
/// A single chunk of file text together with its embedding vector.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    pub chunk: Chunk,
    pub embedding: Embedding,
}
467
/// Builds the database key for a worktree-relative path.
///
/// Now takes `&Path` instead of `&Arc<Path>` (the former was needless
/// indirection; all call sites coerce via deref, so this is
/// backward-compatible). Path separators are replaced with NUL bytes, which
/// sort before every other byte, so lexicographic key order keeps a
/// directory's descendants grouped together — `scan_entries` relies on this
/// order matching the worktree's file traversal order.
fn db_key_for_path(path: &Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}