embedding_index.rs

  1use crate::{
  2    chunking::{self, Chunk},
  3    embedding::{Embedding, EmbeddingProvider, TextToEmbed},
  4    indexing::{IndexingEntryHandle, IndexingEntrySet},
  5};
  6use anyhow::{anyhow, Context as _, Result};
  7use collections::Bound;
  8use feature_flags::FeatureFlagAppExt;
  9use fs::Fs;
 10use fs::MTime;
 11use futures::stream::StreamExt;
 12use futures_batch::ChunksTimeoutStreamExt;
 13use gpui::{AppContext, Model, Task};
 14use heed::types::{SerdeBincode, Str};
 15use language::LanguageRegistry;
 16use log;
 17use project::{Entry, UpdatedEntriesSet, Worktree};
 18use serde::{Deserialize, Serialize};
 19use smol::channel;
 20use smol::future::FutureExt;
 21use std::{cmp::Ordering, future::Future, iter, path::Path, sync::Arc, time::Duration};
 22use util::ResultExt;
 23use worktree::Snapshot;
 24
/// Maintains a worktree's embedding database: scans entries for changes,
/// chunks and embeds changed files, and persists the results to a heed
/// (LMDB) database keyed by worktree-relative path (see `db_key_for_path`).
pub struct EmbeddingIndex {
    // The worktree whose files are indexed.
    worktree: Model<Worktree>,
    // heed environment used to open read/write transactions.
    db_connection: heed::Env,
    // Maps a path key (path with `/` replaced by NUL) to its embedded file.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    // Used to load file contents when chunking.
    fs: Arc<dyn Fs>,
    // Resolves the language for each file so chunking can be syntax-aware.
    language_registry: Arc<LanguageRegistry>,
    // Produces embeddings for batches of text chunks.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Tracks entries currently flowing through the indexing pipeline;
    // handles are carried alongside each entry until it is persisted.
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
}
 34
 35impl EmbeddingIndex {
 36    pub fn new(
 37        worktree: Model<Worktree>,
 38        fs: Arc<dyn Fs>,
 39        db_connection: heed::Env,
 40        embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
 41        language_registry: Arc<LanguageRegistry>,
 42        embedding_provider: Arc<dyn EmbeddingProvider>,
 43        entry_ids_being_indexed: Arc<IndexingEntrySet>,
 44    ) -> Self {
 45        Self {
 46            worktree,
 47            fs,
 48            db_connection,
 49            db: embedding_db,
 50            language_registry,
 51            embedding_provider,
 52            entry_ids_being_indexed,
 53        }
 54    }
 55
    /// Returns the underlying heed database mapping path keys to embedded files.
    pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
        &self.db
    }
 59
 60    pub fn index_entries_changed_on_disk(
 61        &self,
 62        cx: &AppContext,
 63    ) -> impl Future<Output = Result<()>> {
 64        if !cx.is_staff() {
 65            return async move { Ok(()) }.boxed();
 66        }
 67
 68        let worktree = self.worktree.read(cx).snapshot();
 69        let worktree_abs_path = worktree.abs_path().clone();
 70        let scan = self.scan_entries(worktree, cx);
 71        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
 72        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
 73        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
 74        async move {
 75            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
 76            Ok(())
 77        }
 78        .boxed()
 79    }
 80
 81    pub fn index_updated_entries(
 82        &self,
 83        updated_entries: UpdatedEntriesSet,
 84        cx: &AppContext,
 85    ) -> impl Future<Output = Result<()>> {
 86        if !cx.is_staff() {
 87            return async move { Ok(()) }.boxed();
 88        }
 89
 90        let worktree = self.worktree.read(cx).snapshot();
 91        let worktree_abs_path = worktree.abs_path().clone();
 92        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
 93        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
 94        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
 95        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
 96        async move {
 97            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
 98            Ok(())
 99        }
100        .boxed()
101    }
102
103    fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
104        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
105        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
106        let db_connection = self.db_connection.clone();
107        let db = self.db;
108        let entries_being_indexed = self.entry_ids_being_indexed.clone();
109        let task = cx.background_executor().spawn(async move {
110            let txn = db_connection
111                .read_txn()
112                .context("failed to create read transaction")?;
113            let mut db_entries = db
114                .iter(&txn)
115                .context("failed to create iterator")?
116                .move_between_keys()
117                .peekable();
118
119            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
120            for entry in worktree.files(false, 0) {
121                log::trace!("scanning for embedding index: {:?}", &entry.path);
122
123                let entry_db_key = db_key_for_path(&entry.path);
124
125                let mut saved_mtime = None;
126                while let Some(db_entry) = db_entries.peek() {
127                    match db_entry {
128                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
129                            Ordering::Less => {
130                                if let Some(deletion_range) = deletion_range.as_mut() {
131                                    deletion_range.1 = Bound::Included(db_path);
132                                } else {
133                                    deletion_range =
134                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
135                                }
136
137                                db_entries.next();
138                            }
139                            Ordering::Equal => {
140                                if let Some(deletion_range) = deletion_range.take() {
141                                    deleted_entry_ranges_tx
142                                        .send((
143                                            deletion_range.0.map(ToString::to_string),
144                                            deletion_range.1.map(ToString::to_string),
145                                        ))
146                                        .await?;
147                                }
148                                saved_mtime = db_embedded_file.mtime;
149                                db_entries.next();
150                                break;
151                            }
152                            Ordering::Greater => {
153                                break;
154                            }
155                        },
156                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
157                    }
158                }
159
160                if entry.mtime != saved_mtime {
161                    let handle = entries_being_indexed.insert(entry.id);
162                    updated_entries_tx.send((entry.clone(), handle)).await?;
163                }
164            }
165
166            if let Some(db_entry) = db_entries.next() {
167                let (db_path, _) = db_entry?;
168                deleted_entry_ranges_tx
169                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
170                    .await?;
171            }
172
173            Ok(())
174        });
175
176        ScanEntries {
177            updated_entries: updated_entries_rx,
178            deleted_entry_ranges: deleted_entry_ranges_rx,
179            task,
180        }
181    }
182
183    fn scan_updated_entries(
184        &self,
185        worktree: Snapshot,
186        updated_entries: UpdatedEntriesSet,
187        cx: &AppContext,
188    ) -> ScanEntries {
189        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
190        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
191        let entries_being_indexed = self.entry_ids_being_indexed.clone();
192        let task = cx.background_executor().spawn(async move {
193            for (path, entry_id, status) in updated_entries.iter() {
194                match status {
195                    project::PathChange::Added
196                    | project::PathChange::Updated
197                    | project::PathChange::AddedOrUpdated => {
198                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
199                            if entry.is_file() {
200                                let handle = entries_being_indexed.insert(entry.id);
201                                updated_entries_tx.send((entry.clone(), handle)).await?;
202                            }
203                        }
204                    }
205                    project::PathChange::Removed => {
206                        let db_path = db_key_for_path(path);
207                        deleted_entry_ranges_tx
208                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
209                            .await?;
210                    }
211                    project::PathChange::Loaded => {
212                        // Do nothing.
213                    }
214                }
215            }
216
217            Ok(())
218        });
219
220        ScanEntries {
221            updated_entries: updated_entries_rx,
222            deleted_entry_ranges: deleted_entry_ranges_rx,
223            task,
224        }
225    }
226
227    fn chunk_files(
228        &self,
229        worktree_abs_path: Arc<Path>,
230        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
231        cx: &AppContext,
232    ) -> ChunkFiles {
233        let language_registry = self.language_registry.clone();
234        let fs = self.fs.clone();
235        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
236        let task = cx.spawn(|cx| async move {
237            cx.background_executor()
238                .scoped(|cx| {
239                    for _ in 0..cx.num_cpus() {
240                        cx.spawn(async {
241                            while let Ok((entry, handle)) = entries.recv().await {
242                                let entry_abs_path = worktree_abs_path.join(&entry.path);
243                                if let Some(text) = fs.load(&entry_abs_path).await.ok() {
244                                    let language = language_registry
245                                        .language_for_file_path(&entry.path)
246                                        .await
247                                        .ok();
248                                    let chunked_file = ChunkedFile {
249                                        chunks: chunking::chunk_text(
250                                            &text,
251                                            language.as_ref(),
252                                            &entry.path,
253                                        ),
254                                        handle,
255                                        path: entry.path,
256                                        mtime: entry.mtime,
257                                        text,
258                                    };
259
260                                    if chunked_files_tx.send(chunked_file).await.is_err() {
261                                        return;
262                                    }
263                                }
264                            }
265                        });
266                    }
267                })
268                .await;
269            Ok(())
270        });
271
272        ChunkFiles {
273            files: chunked_files_rx,
274            task,
275        }
276    }
277
278    pub fn embed_files(
279        embedding_provider: Arc<dyn EmbeddingProvider>,
280        chunked_files: channel::Receiver<ChunkedFile>,
281        cx: &AppContext,
282    ) -> EmbedFiles {
283        let embedding_provider = embedding_provider.clone();
284        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
285        let task = cx.background_executor().spawn(async move {
286            let mut chunked_file_batches =
287                chunked_files.chunks_timeout(512, Duration::from_secs(2));
288            while let Some(chunked_files) = chunked_file_batches.next().await {
289                // View the batch of files as a vec of chunks
290                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
291                // Once those are done, reassemble them back into the files in which they belong
292                // If any embeddings fail for a file, the entire file is discarded
293
294                let chunks: Vec<TextToEmbed> = chunked_files
295                    .iter()
296                    .flat_map(|file| {
297                        file.chunks.iter().map(|chunk| TextToEmbed {
298                            text: &file.text[chunk.range.clone()],
299                            digest: chunk.digest,
300                        })
301                    })
302                    .collect::<Vec<_>>();
303
304                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
305                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
306                    if let Some(batch_embeddings) =
307                        embedding_provider.embed(embedding_batch).await.log_err()
308                    {
309                        if batch_embeddings.len() == embedding_batch.len() {
310                            embeddings.extend(batch_embeddings.into_iter().map(Some));
311                            continue;
312                        }
313                        log::error!(
314                            "embedding provider returned unexpected embedding count {}, expected {}",
315                            batch_embeddings.len(), embedding_batch.len()
316                        );
317                    }
318
319                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
320                }
321
322                let mut embeddings = embeddings.into_iter();
323                for chunked_file in chunked_files {
324                    let mut embedded_file = EmbeddedFile {
325                        path: chunked_file.path,
326                        mtime: chunked_file.mtime,
327                        chunks: Vec::new(),
328                    };
329
330                    let mut embedded_all_chunks = true;
331                    for (chunk, embedding) in
332                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
333                    {
334                        if let Some(embedding) = embedding {
335                            embedded_file
336                                .chunks
337                                .push(EmbeddedChunk { chunk, embedding });
338                        } else {
339                            embedded_all_chunks = false;
340                        }
341                    }
342
343                    if embedded_all_chunks {
344                        embedded_files_tx
345                            .send((embedded_file, chunked_file.handle))
346                            .await?;
347                    }
348                }
349            }
350            Ok(())
351        });
352
353        EmbedFiles {
354            files: embedded_files_rx,
355            task,
356        }
357    }
358
    /// Applies the pipeline's output to the database, interleaving deletion
    /// ranges and freshly embedded files until both channels are closed and
    /// drained. Each deletion or write commits its own write transaction.
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        mut embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;

        cx.background_executor().spawn(async move {
            loop {
                // Interleave deletions and persists of embedded files
                futures::select_biased! {
                    deletion_range = deleted_entry_ranges.next() => {
                        if let Some(deletion_range) = deletion_range {
                            let mut txn = db_connection.write_txn()?;
                            // Convert the owned String bounds into &str bounds
                            // for heed's range API.
                            let start = deletion_range.0.as_ref().map(|start| start.as_str());
                            let end = deletion_range.1.as_ref().map(|end| end.as_str());
                            log::debug!("deleting embeddings in range {:?}", &(start, end));
                            db.delete_range(&mut txn, &(start, end))?;
                            txn.commit()?;
                        }
                    },
                    file = embedded_files.next() => {
                        if let Some((file, _)) = file {
                            let mut txn = db_connection.write_txn()?;
                            log::debug!("saving embedding for file {:?}", file.path);
                            let key = db_key_for_path(&file.path);
                            db.put(&mut txn, &key, &file)?;
                            txn.commit()?;
                            // NOTE(review): the IndexingEntryHandle is dropped
                            // here after the commit — presumably that marks the
                            // entry as done in IndexingEntrySet; confirm there.
                        }
                    },
                    // `complete` fires only once both channels are exhausted.
                    complete => break,
                }
            }

            Ok(())
        })
    }
398
399    pub fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
400        let connection = self.db_connection.clone();
401        let db = self.db;
402        cx.background_executor().spawn(async move {
403            let tx = connection
404                .read_txn()
405                .context("failed to create read transaction")?;
406            let result = db
407                .iter(&tx)?
408                .map(|entry| Ok(entry?.1.path.clone()))
409                .collect::<Result<Vec<Arc<Path>>>>();
410            drop(tx);
411            result
412        })
413    }
414
415    pub fn chunks_for_path(
416        &self,
417        path: Arc<Path>,
418        cx: &AppContext,
419    ) -> Task<Result<Vec<EmbeddedChunk>>> {
420        let connection = self.db_connection.clone();
421        let db = self.db;
422        cx.background_executor().spawn(async move {
423            let tx = connection
424                .read_txn()
425                .context("failed to create read transaction")?;
426            Ok(db
427                .get(&tx, &db_key_for_path(&path))?
428                .ok_or_else(|| anyhow!("no such path"))?
429                .chunks
430                .clone())
431        })
432    }
433}
434
/// Output of the scanning stage: its driver task plus the two channels the
/// later stages consume.
struct ScanEntries {
    // Entries whose contents need (re)embedding, with their progress handles.
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    // Key ranges whose embeddings should be deleted from the database.
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}
440
/// Output of the chunking stage: its driver task plus the channel of
/// chunked files for the embedding stage.
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}
445
/// A file that has been read from disk and split into chunks, awaiting
/// embedding.
pub struct ChunkedFile {
    pub path: Arc<Path>,
    pub mtime: Option<MTime>,
    // Dropped/forwarded to signal indexing progress for this entry.
    pub handle: IndexingEntryHandle,
    // Full file text; each chunk's range indexes into this string.
    pub text: String,
    pub chunks: Vec<Chunk>,
}
453
/// Output of the embedding stage: its driver task plus the channel of fully
/// embedded files for the persistence stage.
pub struct EmbedFiles {
    pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    pub task: Task<Result<()>>,
}
458
/// The value stored in the database for each indexed file
/// (bincode-serialized via `SerdeBincode`).
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
    pub path: Arc<Path>,
    // mtime at indexing time; compared against the worktree entry's mtime
    // to decide whether the file needs re-embedding.
    pub mtime: Option<MTime>,
    pub chunks: Vec<EmbeddedChunk>,
}
465
/// A single chunk of a file together with its computed embedding.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    pub chunk: Chunk,
    pub embedding: Embedding,
}
471
/// Builds the database key for a worktree-relative path by replacing `/`
/// with NUL bytes. `scan_entries` merges the worktree's file order against
/// the DB's key order, so keys must sort consistently; NUL sorts below every
/// other byte, which presumably keeps parent/child keys adjacent — confirm
/// against the worktree traversal order if this is changed.
///
/// Accepts any `&Path`; callers holding an `Arc<Path>` coerce via deref,
/// so this is strictly more general than taking `&Arc<Path>`.
///
/// NOTE(review): only `/` is replaced, so Windows `\` separators would pass
/// through unchanged — verify whether worktree paths are always
/// `/`-separated on all platforms.
fn db_key_for_path(path: &Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}