// embedding_index.rs

  1use crate::{
  2    chunking::{self, Chunk},
  3    embedding::{Embedding, EmbeddingProvider, TextToEmbed},
  4    indexing::{IndexingEntryHandle, IndexingEntrySet},
  5};
  6use anyhow::{anyhow, Context as _, Result};
  7use collections::Bound;
  8use feature_flags::FeatureFlagAppExt;
  9use fs::Fs;
 10use futures::stream::StreamExt;
 11use futures_batch::ChunksTimeoutStreamExt;
 12use gpui::{AppContext, Model, Task};
 13use heed::types::{SerdeBincode, Str};
 14use language::LanguageRegistry;
 15use log;
 16use project::{Entry, UpdatedEntriesSet, Worktree};
 17use serde::{Deserialize, Serialize};
 18use smol::channel;
 19use smol::future::FutureExt;
 20use std::{
 21    cmp::Ordering,
 22    future::Future,
 23    iter,
 24    path::Path,
 25    sync::Arc,
 26    time::{Duration, SystemTime},
 27};
 28use util::ResultExt;
 29use worktree::Snapshot;
 30
/// Maintains the embedding database for a single worktree: a heed key-value
/// store mapping a path-derived key (see `db_key_for_path`) to the file's
/// chunked text embeddings (`EmbeddedFile`).
pub struct EmbeddingIndex {
    worktree: Model<Worktree>,
    // Environment owning `db`; cloned into background tasks to open
    // read/write transactions off the main thread.
    db_connection: heed::Env,
    // path key -> embedded chunks for that file.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    fs: Arc<dyn Fs>,
    language_registry: Arc<LanguageRegistry>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Set of entries currently being indexed; handles inserted here are
    // carried through the pipeline and dropped when persistence completes.
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
}
 40
 41impl EmbeddingIndex {
 42    pub fn new(
 43        worktree: Model<Worktree>,
 44        fs: Arc<dyn Fs>,
 45        db_connection: heed::Env,
 46        embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
 47        language_registry: Arc<LanguageRegistry>,
 48        embedding_provider: Arc<dyn EmbeddingProvider>,
 49        entry_ids_being_indexed: Arc<IndexingEntrySet>,
 50    ) -> Self {
 51        Self {
 52            worktree,
 53            fs,
 54            db_connection,
 55            db: embedding_db,
 56            language_registry,
 57            embedding_provider,
 58            entry_ids_being_indexed,
 59        }
 60    }
 61
    /// Returns the underlying path-key → `EmbeddedFile` database.
    pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
        &self.db
    }
 65
    /// Re-indexes every file in the worktree whose on-disk mtime differs from
    /// the mtime recorded in the database, and removes database rows for files
    /// that no longer exist.
    ///
    /// Currently gated to staff users; resolves immediately with `Ok(())`
    /// otherwise.
    ///
    /// The work is a four-stage pipeline connected by bounded channels:
    /// scan → chunk → embed → persist. The returned future completes when all
    /// four stage tasks finish, failing if any stage fails.
    pub fn index_entries_changed_on_disk(
        &self,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        if !cx.is_staff() {
            return async move { Ok(()) }.boxed();
        }

        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree, cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
        .boxed()
    }
 86
    /// Indexes a batch of filesystem events reported by the worktree, instead
    /// of rescanning the entire tree.
    ///
    /// Currently gated to staff users; resolves immediately with `Ok(())`
    /// otherwise.
    ///
    /// Runs the same scan → chunk → embed → persist pipeline as
    /// `index_entries_changed_on_disk`, but seeds the scan stage from
    /// `updated_entries` rather than from a full worktree walk.
    pub fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        if !cx.is_staff() {
            return async move { Ok(()) }.boxed();
        }

        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
        .boxed()
    }
108
109    fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
110        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
111        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
112        let db_connection = self.db_connection.clone();
113        let db = self.db;
114        let entries_being_indexed = self.entry_ids_being_indexed.clone();
115        let task = cx.background_executor().spawn(async move {
116            let txn = db_connection
117                .read_txn()
118                .context("failed to create read transaction")?;
119            let mut db_entries = db
120                .iter(&txn)
121                .context("failed to create iterator")?
122                .move_between_keys()
123                .peekable();
124
125            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
126            for entry in worktree.files(false, 0) {
127                log::trace!("scanning for embedding index: {:?}", &entry.path);
128
129                let entry_db_key = db_key_for_path(&entry.path);
130
131                let mut saved_mtime = None;
132                while let Some(db_entry) = db_entries.peek() {
133                    match db_entry {
134                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
135                            Ordering::Less => {
136                                if let Some(deletion_range) = deletion_range.as_mut() {
137                                    deletion_range.1 = Bound::Included(db_path);
138                                } else {
139                                    deletion_range =
140                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
141                                }
142
143                                db_entries.next();
144                            }
145                            Ordering::Equal => {
146                                if let Some(deletion_range) = deletion_range.take() {
147                                    deleted_entry_ranges_tx
148                                        .send((
149                                            deletion_range.0.map(ToString::to_string),
150                                            deletion_range.1.map(ToString::to_string),
151                                        ))
152                                        .await?;
153                                }
154                                saved_mtime = db_embedded_file.mtime;
155                                db_entries.next();
156                                break;
157                            }
158                            Ordering::Greater => {
159                                break;
160                            }
161                        },
162                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
163                    }
164                }
165
166                if entry.mtime != saved_mtime {
167                    let handle = entries_being_indexed.insert(entry.id);
168                    updated_entries_tx.send((entry.clone(), handle)).await?;
169                }
170            }
171
172            if let Some(db_entry) = db_entries.next() {
173                let (db_path, _) = db_entry?;
174                deleted_entry_ranges_tx
175                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
176                    .await?;
177            }
178
179            Ok(())
180        });
181
182        ScanEntries {
183            updated_entries: updated_entries_rx,
184            deleted_entry_ranges: deleted_entry_ranges_rx,
185            task,
186        }
187    }
188
189    fn scan_updated_entries(
190        &self,
191        worktree: Snapshot,
192        updated_entries: UpdatedEntriesSet,
193        cx: &AppContext,
194    ) -> ScanEntries {
195        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
196        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
197        let entries_being_indexed = self.entry_ids_being_indexed.clone();
198        let task = cx.background_executor().spawn(async move {
199            for (path, entry_id, status) in updated_entries.iter() {
200                match status {
201                    project::PathChange::Added
202                    | project::PathChange::Updated
203                    | project::PathChange::AddedOrUpdated => {
204                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
205                            if entry.is_file() {
206                                let handle = entries_being_indexed.insert(entry.id);
207                                updated_entries_tx.send((entry.clone(), handle)).await?;
208                            }
209                        }
210                    }
211                    project::PathChange::Removed => {
212                        let db_path = db_key_for_path(path);
213                        deleted_entry_ranges_tx
214                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
215                            .await?;
216                    }
217                    project::PathChange::Loaded => {
218                        // Do nothing.
219                    }
220                }
221            }
222
223            Ok(())
224        });
225
226        ScanEntries {
227            updated_entries: updated_entries_rx,
228            deleted_entry_ranges: deleted_entry_ranges_rx,
229            task,
230        }
231    }
232
233    fn chunk_files(
234        &self,
235        worktree_abs_path: Arc<Path>,
236        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
237        cx: &AppContext,
238    ) -> ChunkFiles {
239        let language_registry = self.language_registry.clone();
240        let fs = self.fs.clone();
241        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
242        let task = cx.spawn(|cx| async move {
243            cx.background_executor()
244                .scoped(|cx| {
245                    for _ in 0..cx.num_cpus() {
246                        cx.spawn(async {
247                            while let Ok((entry, handle)) = entries.recv().await {
248                                let entry_abs_path = worktree_abs_path.join(&entry.path);
249                                if let Some(text) = fs.load(&entry_abs_path).await.ok() {
250                                    let language = language_registry
251                                        .language_for_file_path(&entry.path)
252                                        .await
253                                        .ok();
254                                    let chunked_file = ChunkedFile {
255                                        chunks: chunking::chunk_text(
256                                            &text,
257                                            language.as_ref(),
258                                            &entry.path,
259                                        ),
260                                        handle,
261                                        path: entry.path,
262                                        mtime: entry.mtime,
263                                        text,
264                                    };
265
266                                    if chunked_files_tx.send(chunked_file).await.is_err() {
267                                        return;
268                                    }
269                                }
270                            }
271                        });
272                    }
273                })
274                .await;
275            Ok(())
276        });
277
278        ChunkFiles {
279            files: chunked_files_rx,
280            task,
281        }
282    }
283
284    pub fn embed_files(
285        embedding_provider: Arc<dyn EmbeddingProvider>,
286        chunked_files: channel::Receiver<ChunkedFile>,
287        cx: &AppContext,
288    ) -> EmbedFiles {
289        let embedding_provider = embedding_provider.clone();
290        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
291        let task = cx.background_executor().spawn(async move {
292            let mut chunked_file_batches =
293                chunked_files.chunks_timeout(512, Duration::from_secs(2));
294            while let Some(chunked_files) = chunked_file_batches.next().await {
295                // View the batch of files as a vec of chunks
296                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
297                // Once those are done, reassemble them back into the files in which they belong
298                // If any embeddings fail for a file, the entire file is discarded
299
300                let chunks: Vec<TextToEmbed> = chunked_files
301                    .iter()
302                    .flat_map(|file| {
303                        file.chunks.iter().map(|chunk| TextToEmbed {
304                            text: &file.text[chunk.range.clone()],
305                            digest: chunk.digest,
306                        })
307                    })
308                    .collect::<Vec<_>>();
309
310                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
311                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
312                    if let Some(batch_embeddings) =
313                        embedding_provider.embed(embedding_batch).await.log_err()
314                    {
315                        if batch_embeddings.len() == embedding_batch.len() {
316                            embeddings.extend(batch_embeddings.into_iter().map(Some));
317                            continue;
318                        }
319                        log::error!(
320                            "embedding provider returned unexpected embedding count {}, expected {}",
321                            batch_embeddings.len(), embedding_batch.len()
322                        );
323                    }
324
325                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
326                }
327
328                let mut embeddings = embeddings.into_iter();
329                for chunked_file in chunked_files {
330                    let mut embedded_file = EmbeddedFile {
331                        path: chunked_file.path,
332                        mtime: chunked_file.mtime,
333                        chunks: Vec::new(),
334                    };
335
336                    let mut embedded_all_chunks = true;
337                    for (chunk, embedding) in
338                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
339                    {
340                        if let Some(embedding) = embedding {
341                            embedded_file
342                                .chunks
343                                .push(EmbeddedChunk { chunk, embedding });
344                        } else {
345                            embedded_all_chunks = false;
346                        }
347                    }
348
349                    if embedded_all_chunks {
350                        embedded_files_tx
351                            .send((embedded_file, chunked_file.handle))
352                            .await?;
353                    }
354                }
355            }
356            Ok(())
357        });
358
359        EmbedFiles {
360            files: embedded_files_rx,
361            task,
362        }
363    }
364
    /// Applies the scan and embed results to the database: deletes the key
    /// ranges received on `deleted_entry_ranges` and upserts the files
    /// received on `embedded_files`, interleaving the two streams as items
    /// arrive (deletions are polled first via `select_biased!`).
    ///
    /// Each deletion/upsert runs in its own short-lived write transaction.
    /// The task completes once both channels are closed and drained.
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        mut embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;

        cx.background_executor().spawn(async move {
            loop {
                // Interleave deletions and persists of embedded files
                futures::select_biased! {
                    deletion_range = deleted_entry_ranges.next() => {
                        if let Some(deletion_range) = deletion_range {
                            let mut txn = db_connection.write_txn()?;
                            let start = deletion_range.0.as_ref().map(|start| start.as_str());
                            let end = deletion_range.1.as_ref().map(|end| end.as_str());
                            log::debug!("deleting embeddings in range {:?}", &(start, end));
                            db.delete_range(&mut txn, &(start, end))?;
                            txn.commit()?;
                        }
                    },
                    // The `IndexingEntryHandle` is dropped here after the
                    // commit — presumably marking the entry as done indexing;
                    // confirm against `IndexingEntrySet`.
                    file = embedded_files.next() => {
                        if let Some((file, _)) = file {
                            let mut txn = db_connection.write_txn()?;
                            log::debug!("saving embedding for file {:?}", file.path);
                            let key = db_key_for_path(&file.path);
                            db.put(&mut txn, &key, &file)?;
                            txn.commit()?;
                        }
                    },
                    // Both channels exhausted: all pipeline output handled.
                    complete => break,
                }
            }

            Ok(())
        })
    }
404
405    pub fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
406        let connection = self.db_connection.clone();
407        let db = self.db;
408        cx.background_executor().spawn(async move {
409            let tx = connection
410                .read_txn()
411                .context("failed to create read transaction")?;
412            let result = db
413                .iter(&tx)?
414                .map(|entry| Ok(entry?.1.path.clone()))
415                .collect::<Result<Vec<Arc<Path>>>>();
416            drop(tx);
417            result
418        })
419    }
420
421    pub fn chunks_for_path(
422        &self,
423        path: Arc<Path>,
424        cx: &AppContext,
425    ) -> Task<Result<Vec<EmbeddedChunk>>> {
426        let connection = self.db_connection.clone();
427        let db = self.db;
428        cx.background_executor().spawn(async move {
429            let tx = connection
430                .read_txn()
431                .context("failed to create read transaction")?;
432            Ok(db
433                .get(&tx, &db_key_for_path(&path))?
434                .ok_or_else(|| anyhow!("no such path"))?
435                .chunks
436                .clone())
437        })
438    }
439}
440
/// Output of the scan stage: two channels feeding the rest of the pipeline,
/// plus the task driving the scan.
struct ScanEntries {
    // Files needing (re-)embedding, each paired with its indexing handle.
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    // Database key ranges to delete (start bound, end bound).
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}
446
/// Output of the chunking stage: chunked files plus the task producing them.
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}
451
/// A file's text split into chunks, ready for embedding.
pub struct ChunkedFile {
    // Worktree-relative path of the file.
    pub path: Arc<Path>,
    pub mtime: Option<SystemTime>,
    // Keeps the entry marked as "being indexed" until dropped.
    pub handle: IndexingEntryHandle,
    // Full file text; `chunks` hold ranges into it.
    pub text: String,
    pub chunks: Vec<Chunk>,
}
459
/// Output of the embedding stage: fully-embedded files plus the driving task.
pub struct EmbedFiles {
    pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    pub task: Task<Result<()>>,
}
464
/// The database value type: a file's path, its mtime at embedding time
/// (used by the scan stage to detect changes), and its embedded chunks.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
    pub path: Arc<Path>,
    pub mtime: Option<SystemTime>,
    pub chunks: Vec<EmbeddedChunk>,
}
471
/// A single chunk of a file together with its embedding vector.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    pub chunk: Chunk,
    pub embedding: Embedding,
}
477
/// Converts a worktree-relative path into the key under which the file's
/// embeddings are stored, replacing `/` separators with NUL bytes —
/// presumably so keys sort consistently with the worktree's file ordering
/// (the scan stage merge-joins the two sequences); confirm against
/// `Worktree::files`.
///
/// NOTE(review): on Windows `to_string_lossy` may yield `\` separators,
/// which this does not replace — verify cross-platform behavior.
///
/// Takes `&Path` instead of the former `&Arc<Path>` (anti-pattern); all
/// callers coerce via deref, so this is backward compatible.
fn db_key_for_path(path: &Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}