embedding_index.rs

  1use crate::{
  2    chunking::{self, Chunk},
  3    embedding::{Embedding, EmbeddingProvider, TextToEmbed},
  4    indexing::{IndexingEntryHandle, IndexingEntrySet},
  5};
  6use anyhow::{anyhow, Context as _, Result};
  7use collections::Bound;
  8use fs::Fs;
  9use futures::stream::StreamExt;
 10use futures_batch::ChunksTimeoutStreamExt;
 11use gpui::{AppContext, Model, Task};
 12use heed::types::{SerdeBincode, Str};
 13use language::LanguageRegistry;
 14use log;
 15use project::{Entry, UpdatedEntriesSet, Worktree};
 16use serde::{Deserialize, Serialize};
 17use smol::channel;
 18use std::{
 19    cmp::Ordering,
 20    future::Future,
 21    iter,
 22    path::Path,
 23    sync::Arc,
 24    time::{Duration, SystemTime},
 25};
 26use util::ResultExt;
 27use worktree::Snapshot;
 28
/// Builds and stores vector embeddings for the files of a single worktree.
///
/// Indexing runs as a pipeline of concurrent stages connected by bounded
/// channels: scan -> chunk -> embed -> persist (see
/// `index_entries_changed_on_disk` / `index_updated_entries`).
pub struct EmbeddingIndex {
    /// The worktree whose files are indexed.
    worktree: Model<Worktree>,
    /// heed environment owning `db`; cloned into background tasks to open
    /// read/write transactions.
    db_connection: heed::Env,
    /// Maps a path-derived string key (see `db_key_for_path`) to the file's
    /// embedded chunks, serialized with bincode.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    /// Filesystem used to load file contents for chunking.
    fs: Arc<dyn Fs>,
    /// Used to pick a language for syntax-aware chunking.
    language_registry: Arc<LanguageRegistry>,
    /// Produces embeddings for batches of text chunks.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Tracks entries whose indexing is currently in flight.
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
}
 38
 39impl EmbeddingIndex {
 40    pub fn new(
 41        worktree: Model<Worktree>,
 42        fs: Arc<dyn Fs>,
 43        db_connection: heed::Env,
 44        embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
 45        language_registry: Arc<LanguageRegistry>,
 46        embedding_provider: Arc<dyn EmbeddingProvider>,
 47        entry_ids_being_indexed: Arc<IndexingEntrySet>,
 48    ) -> Self {
 49        Self {
 50            worktree,
 51            fs,
 52            db_connection,
 53            db: embedding_db,
 54            language_registry,
 55            embedding_provider,
 56            entry_ids_being_indexed,
 57        }
 58    }
 59
 60    pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
 61        &self.db
 62    }
 63
 64    pub fn index_entries_changed_on_disk(
 65        &self,
 66        cx: &AppContext,
 67    ) -> impl Future<Output = Result<()>> {
 68        let worktree = self.worktree.read(cx).snapshot();
 69        let worktree_abs_path = worktree.abs_path().clone();
 70        let scan = self.scan_entries(worktree, cx);
 71        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
 72        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
 73        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
 74        async move {
 75            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
 76            Ok(())
 77        }
 78    }
 79
 80    pub fn index_updated_entries(
 81        &self,
 82        updated_entries: UpdatedEntriesSet,
 83        cx: &AppContext,
 84    ) -> impl Future<Output = Result<()>> {
 85        let worktree = self.worktree.read(cx).snapshot();
 86        let worktree_abs_path = worktree.abs_path().clone();
 87        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
 88        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
 89        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
 90        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
 91        async move {
 92            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
 93            Ok(())
 94        }
 95    }
 96
 97    fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
 98        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
 99        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
100        let db_connection = self.db_connection.clone();
101        let db = self.db;
102        let entries_being_indexed = self.entry_ids_being_indexed.clone();
103        let task = cx.background_executor().spawn(async move {
104            let txn = db_connection
105                .read_txn()
106                .context("failed to create read transaction")?;
107            let mut db_entries = db
108                .iter(&txn)
109                .context("failed to create iterator")?
110                .move_between_keys()
111                .peekable();
112
113            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
114            for entry in worktree.files(false, 0) {
115                log::trace!("scanning for embedding index: {:?}", &entry.path);
116
117                let entry_db_key = db_key_for_path(&entry.path);
118
119                let mut saved_mtime = None;
120                while let Some(db_entry) = db_entries.peek() {
121                    match db_entry {
122                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
123                            Ordering::Less => {
124                                if let Some(deletion_range) = deletion_range.as_mut() {
125                                    deletion_range.1 = Bound::Included(db_path);
126                                } else {
127                                    deletion_range =
128                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
129                                }
130
131                                db_entries.next();
132                            }
133                            Ordering::Equal => {
134                                if let Some(deletion_range) = deletion_range.take() {
135                                    deleted_entry_ranges_tx
136                                        .send((
137                                            deletion_range.0.map(ToString::to_string),
138                                            deletion_range.1.map(ToString::to_string),
139                                        ))
140                                        .await?;
141                                }
142                                saved_mtime = db_embedded_file.mtime;
143                                db_entries.next();
144                                break;
145                            }
146                            Ordering::Greater => {
147                                break;
148                            }
149                        },
150                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
151                    }
152                }
153
154                if entry.mtime != saved_mtime {
155                    let handle = entries_being_indexed.insert(entry.id);
156                    updated_entries_tx.send((entry.clone(), handle)).await?;
157                }
158            }
159
160            if let Some(db_entry) = db_entries.next() {
161                let (db_path, _) = db_entry?;
162                deleted_entry_ranges_tx
163                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
164                    .await?;
165            }
166
167            Ok(())
168        });
169
170        ScanEntries {
171            updated_entries: updated_entries_rx,
172            deleted_entry_ranges: deleted_entry_ranges_rx,
173            task,
174        }
175    }
176
177    fn scan_updated_entries(
178        &self,
179        worktree: Snapshot,
180        updated_entries: UpdatedEntriesSet,
181        cx: &AppContext,
182    ) -> ScanEntries {
183        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
184        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
185        let entries_being_indexed = self.entry_ids_being_indexed.clone();
186        let task = cx.background_executor().spawn(async move {
187            for (path, entry_id, status) in updated_entries.iter() {
188                match status {
189                    project::PathChange::Added
190                    | project::PathChange::Updated
191                    | project::PathChange::AddedOrUpdated => {
192                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
193                            if entry.is_file() {
194                                let handle = entries_being_indexed.insert(entry.id);
195                                updated_entries_tx.send((entry.clone(), handle)).await?;
196                            }
197                        }
198                    }
199                    project::PathChange::Removed => {
200                        let db_path = db_key_for_path(path);
201                        deleted_entry_ranges_tx
202                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
203                            .await?;
204                    }
205                    project::PathChange::Loaded => {
206                        // Do nothing.
207                    }
208                }
209            }
210
211            Ok(())
212        });
213
214        ScanEntries {
215            updated_entries: updated_entries_rx,
216            deleted_entry_ranges: deleted_entry_ranges_rx,
217            task,
218        }
219    }
220
221    fn chunk_files(
222        &self,
223        worktree_abs_path: Arc<Path>,
224        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
225        cx: &AppContext,
226    ) -> ChunkFiles {
227        let language_registry = self.language_registry.clone();
228        let fs = self.fs.clone();
229        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
230        let task = cx.spawn(|cx| async move {
231            cx.background_executor()
232                .scoped(|cx| {
233                    for _ in 0..cx.num_cpus() {
234                        cx.spawn(async {
235                            while let Ok((entry, handle)) = entries.recv().await {
236                                let entry_abs_path = worktree_abs_path.join(&entry.path);
237                                match fs.load(&entry_abs_path).await {
238                                    Ok(text) => {
239                                        let language = language_registry
240                                            .language_for_file_path(&entry.path)
241                                            .await
242                                            .ok();
243                                        let chunked_file = ChunkedFile {
244                                            chunks: chunking::chunk_text(
245                                                &text,
246                                                language.as_ref(),
247                                                &entry.path,
248                                            ),
249                                            handle,
250                                            path: entry.path,
251                                            mtime: entry.mtime,
252                                            text,
253                                        };
254
255                                        if chunked_files_tx.send(chunked_file).await.is_err() {
256                                            return;
257                                        }
258                                    }
259                                    Err(_)=> {
260                                        log::error!("Failed to read contents into a UTF-8 string: {entry_abs_path:?}");
261                                    }
262                                }
263                            }
264                        });
265                    }
266                })
267                .await;
268            Ok(())
269        });
270
271        ChunkFiles {
272            files: chunked_files_rx,
273            task,
274        }
275    }
276
277    pub fn embed_files(
278        embedding_provider: Arc<dyn EmbeddingProvider>,
279        chunked_files: channel::Receiver<ChunkedFile>,
280        cx: &AppContext,
281    ) -> EmbedFiles {
282        let embedding_provider = embedding_provider.clone();
283        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
284        let task = cx.background_executor().spawn(async move {
285            let mut chunked_file_batches =
286                chunked_files.chunks_timeout(512, Duration::from_secs(2));
287            while let Some(chunked_files) = chunked_file_batches.next().await {
288                // View the batch of files as a vec of chunks
289                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
290                // Once those are done, reassemble them back into the files in which they belong
291                // If any embeddings fail for a file, the entire file is discarded
292
293                let chunks: Vec<TextToEmbed> = chunked_files
294                    .iter()
295                    .flat_map(|file| {
296                        file.chunks.iter().map(|chunk| TextToEmbed {
297                            text: &file.text[chunk.range.clone()],
298                            digest: chunk.digest,
299                        })
300                    })
301                    .collect::<Vec<_>>();
302
303                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
304                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
305                    if let Some(batch_embeddings) =
306                        embedding_provider.embed(embedding_batch).await.log_err()
307                    {
308                        if batch_embeddings.len() == embedding_batch.len() {
309                            embeddings.extend(batch_embeddings.into_iter().map(Some));
310                            continue;
311                        }
312                        log::error!(
313                            "embedding provider returned unexpected embedding count {}, expected {}",
314                            batch_embeddings.len(), embedding_batch.len()
315                        );
316                    }
317
318                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
319                }
320
321                let mut embeddings = embeddings.into_iter();
322                for chunked_file in chunked_files {
323                    let mut embedded_file = EmbeddedFile {
324                        path: chunked_file.path,
325                        mtime: chunked_file.mtime,
326                        chunks: Vec::new(),
327                    };
328
329                    let mut embedded_all_chunks = true;
330                    for (chunk, embedding) in
331                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
332                    {
333                        if let Some(embedding) = embedding {
334                            embedded_file
335                                .chunks
336                                .push(EmbeddedChunk { chunk, embedding });
337                        } else {
338                            embedded_all_chunks = false;
339                        }
340                    }
341
342                    if embedded_all_chunks {
343                        embedded_files_tx
344                            .send((embedded_file, chunked_file.handle))
345                            .await?;
346                    }
347                }
348            }
349            Ok(())
350        });
351
352        EmbedFiles {
353            files: embedded_files_rx,
354            task,
355        }
356    }
357
    /// Applies the pipeline's output to the database: first every deletion
    /// range, then the embedded files in batched write transactions.
    ///
    /// Note the ordering: the deletion-range channel is drained to completion
    /// (it closes when the upstream scan task finishes) before any embedded
    /// file is written.
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            // One write transaction per deletion range.
            while let Some(deletion_range) = deleted_entry_ranges.next().await {
                let mut txn = db_connection.write_txn()?;
                let start = deletion_range.0.as_ref().map(|start| start.as_str());
                let end = deletion_range.1.as_ref().map(|end| end.as_str());
                log::debug!("deleting embeddings in range {:?}", &(start, end));
                db.delete_range(&mut txn, &(start, end))?;
                txn.commit()?;
            }

            // Write up to 4096 files per transaction, flushing whatever has
            // arrived at least every two seconds.
            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(embedded_files) = embedded_files.next().await {
                let mut txn = db_connection.write_txn()?;
                for (file, _) in &embedded_files {
                    log::debug!("saving embedding for file {:?}", file.path);
                    let key = db_key_for_path(&file.path);
                    db.put(&mut txn, &key, file)?;
                }
                txn.commit()?;

                // Deliberately drop the batch only after the commit succeeds:
                // this releases the `IndexingEntryHandle`s, which presumably
                // marks the entries as fully indexed — confirm against
                // `IndexingEntrySet`.
                drop(embedded_files);
                log::debug!("committed");
            }

            Ok(())
        })
    }
393
394    pub fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
395        let connection = self.db_connection.clone();
396        let db = self.db;
397        cx.background_executor().spawn(async move {
398            let tx = connection
399                .read_txn()
400                .context("failed to create read transaction")?;
401            let result = db
402                .iter(&tx)?
403                .map(|entry| Ok(entry?.1.path.clone()))
404                .collect::<Result<Vec<Arc<Path>>>>();
405            drop(tx);
406            result
407        })
408    }
409
410    pub fn chunks_for_path(
411        &self,
412        path: Arc<Path>,
413        cx: &AppContext,
414    ) -> Task<Result<Vec<EmbeddedChunk>>> {
415        let connection = self.db_connection.clone();
416        let db = self.db;
417        cx.background_executor().spawn(async move {
418            let tx = connection
419                .read_txn()
420                .context("failed to create read transaction")?;
421            Ok(db
422                .get(&tx, &db_key_for_path(&path))?
423                .ok_or_else(|| anyhow!("no such path"))?
424                .chunks
425                .clone())
426        })
427    }
428}
429
/// Output of `scan_entries` / `scan_updated_entries`: the scanning task plus
/// the channels feeding the next pipeline stages.
struct ScanEntries {
    /// Entries whose contents need (re-)embedding, with in-flight handles.
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    /// Key ranges (keys as produced by `db_key_for_path`) to delete.
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    /// The background scan; resolves when both channels are closed.
    task: Task<Result<()>>,
}
435
/// Output of `chunk_files`: chunked file contents plus the chunking task.
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}
440
/// A file's text split into chunks, ready to be embedded.
pub struct ChunkedFile {
    /// Worktree-relative path of the file.
    pub path: Arc<Path>,
    /// Modification time captured when the entry was scanned.
    pub mtime: Option<SystemTime>,
    /// Keeps the entry marked as being indexed until dropped.
    pub handle: IndexingEntryHandle,
    /// Full file contents; `chunks` index into this via their ranges.
    pub text: String,
    pub chunks: Vec<Chunk>,
}
448
/// Output of `embed_files`: fully-embedded files plus the embedding task.
pub struct EmbedFiles {
    pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    pub task: Task<Result<()>>,
}
453
/// The value stored in the database for each indexed file (bincode-encoded).
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
    pub path: Arc<Path>,
    /// Modification time at index time; compared against the worktree entry's
    /// mtime to decide whether re-indexing is needed.
    pub mtime: Option<SystemTime>,
    pub chunks: Vec<EmbeddedChunk>,
}
460
/// A single chunk of a file together with its embedding vector.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    pub chunk: Chunk,
    pub embedding: Embedding,
}
466
/// Converts a worktree-relative path into the string key used in the database.
///
/// Path separators are replaced with NUL bytes so that keys sort byte-wise in
/// a directory-grouped order ('\0' sorts before any character that can appear
/// in a file name).
///
/// Generalized from `&Arc<Path>` to `&Path`; existing callers holding an
/// `Arc<Path>` still work via deref coercion.
///
/// NOTE(review): only '/' is replaced — assumes worktree paths use Unix-style
/// separators even on Windows; confirm against the worktree implementation.
fn db_key_for_path(path: &Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}