use crate::{
    chunking::{self, Chunk},
    embedding::{Embedding, EmbeddingProvider, TextToEmbed},
    indexing::{IndexingEntryHandle, IndexingEntrySet},
};
use anyhow::{anyhow, Context as _, Result};
use collections::Bound;
use feature_flags::FeatureFlagAppExt;
use fs::{Fs, MTime};
use futures::{stream::StreamExt, FutureExt as _};
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{App, AppContext as _, Entity, Task};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use project::{Entry, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{cmp::Ordering, future::Future, iter, path::Path, pin::pin, sync::Arc, time::Duration};
use util::ResultExt;
use worktree::Snapshot;

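/// An LMDB-backed index mapping worktree paths to their chunked text and
/// embeddings. Indexing runs as a four-stage pipeline (scan, chunk, embed,
/// persist) whose stages execute concurrently, connected by bounded channels.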
pub struct EmbeddingIndex {
    worktree: Entity<Worktree>,
    db_connection: heed::Env,
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    fs: Arc<dyn Fs>,
    language_registry: Arc<LanguageRegistry>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
}

impl EmbeddingIndex {
    pub fn new(
        worktree: Entity<Worktree>,
        fs: Arc<dyn Fs>,
        db_connection: heed::Env,
        embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        language_registry: Arc<LanguageRegistry>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        entry_ids_being_indexed: Arc<IndexingEntrySet>,
    ) -> Self {
        Self {
            worktree,
            fs,
            db_connection,
            db: embedding_db,
            language_registry,
            embedding_provider,
            entry_ids_being_indexed,
        }
    }

    pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
        &self.db
    }

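    /// Indexes every file whose mtime differs from the copy stored in the
    /// database, and deletes database entries whose files no longer exist on
    /// disk. Currently gated to staff via a feature flag.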
    pub fn index_entries_changed_on_disk(&self, cx: &App) -> impl Future<Output = Result<()>> {
        if !cx.is_staff() {
            return async move { Ok(()) }.boxed();
        }

        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree, cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
        .boxed()
    }

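    /// Like `index_entries_changed_on_disk`, but driven by a set of updated
    /// entries from a worktree event rather than a full rescan.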
    pub fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &App,
    ) -> impl Future<Output = Result<()>> {
        if !cx.is_staff() {
            return async move { Ok(()) }.boxed();
        }

        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
        .boxed()
    }

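    /// Streams out entries that need (re)indexing and key ranges that need
    /// deleting, by merge-joining the worktree's files with the database's
    /// entries. Both sides are visited in matching sorted key order, so a
    /// single forward pass finds stale mtimes and orphaned database rows.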
    fn scan_entries(&self, worktree: Snapshot, cx: &App) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                log::trace!("scanning for embedding index: {:?}", &entry.path);

                let entry_db_key = db_key_for_path(&entry.path);

                // Advance the database cursor until it catches up with this
                // worktree entry. Keys we pass over belong to files that no
                // longer exist, so they are folded into a deletion range.
                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                break;
                            }
                        },
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
                    }
                }

                if entry.mtime != saved_mtime {
                    let handle = entries_being_indexed.insert(entry.id);
                    updated_entries_tx.send((entry.clone(), handle)).await?;
                }
            }

            // Any database entries beyond the last worktree file belong to
            // deleted files; remove them all with one unbounded range.
            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

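    /// Like `scan_entries`, but driven by an `UpdatedEntriesSet`: added and
    /// updated files are queued for indexing, removed paths become
    /// single-key deletion ranges, and `Loaded` events are ignored.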
    fn scan_updated_entries(
        &self,
        worktree: Snapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &App,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            if entry.is_file() {
                                let handle = entries_being_indexed.insert(entry.id);
                                updated_entries_tx.send((entry.clone(), handle)).await?;
                            }
                        }
                    }
                    project::PathChange::Removed => {
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

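    /// Loads each scanned file from disk and splits it into chunks, fanning
    /// the work out across one background task per CPU core. Files that fail
    /// to load are silently skipped.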
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
        cx: &App,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(async move |cx| {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok((entry, handle)) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                if let Ok(text) = fs.load(&entry_abs_path).await {
                                    let language = language_registry
                                        .language_for_file_path(&entry.path)
                                        .await
                                        .ok();
                                    let chunked_file = ChunkedFile {
                                        chunks: chunking::chunk_text(
                                            &text,
                                            language.as_ref(),
                                            &entry.path,
                                        ),
                                        handle,
                                        path: entry.path,
                                        mtime: entry.mtime,
                                        text,
                                    };

                                    if chunked_files_tx.send(chunked_file).await.is_err() {
                                        return;
                                    }
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }

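    /// Embeds chunked files in batches. Chunks are gathered across file
    /// boundaries (up to 512 files or a 2-second timeout per batch), split
    /// into provider-sized requests, and the resulting embeddings are then
    /// reassembled into the files they came from.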
    pub fn embed_files(
        embedding_provider: Arc<dyn EmbeddingProvider>,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &App,
    ) -> EmbedFiles {
        let embedding_provider = embedding_provider.clone();
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_spawn(async move {
            let mut chunked_file_batches =
                pin!(chunked_files.chunks_timeout(512, Duration::from_secs(2)));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // Flatten the batch of files into a single list of chunks,
                // embed those in provider-sized sub-batches, and then
                // reassemble the results into the files they came from. If
                // any chunk of a file fails to embed, the whole file is
                // discarded.
                let chunks: Vec<TextToEmbed> = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect::<Vec<_>>();

                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    if let Some(batch_embeddings) =
                        embedding_provider.embed(embedding_batch).await.log_err()
                    {
                        if batch_embeddings.len() == embedding_batch.len() {
                            embeddings.extend(batch_embeddings.into_iter().map(Some));
                            continue;
                        }
                        log::error!(
                            "embedding provider returned unexpected embedding count {}, expected {}",
                            batch_embeddings.len(),
                            embedding_batch.len()
                        );
                    }

                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
                }

                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let mut embedded_file = EmbeddedFile {
                        path: chunked_file.path,
                        mtime: chunked_file.mtime,
                        chunks: Vec::new(),
                    };

                    let mut embedded_all_chunks = true;
                    for (chunk, embedding) in
                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
                    {
                        if let Some(embedding) = embedding {
                            embedded_file
                                .chunks
                                .push(EmbeddedChunk { chunk, embedding });
                        } else {
                            embedded_all_chunks = false;
                        }
                    }

                    if embedded_all_chunks {
                        embedded_files_tx
                            .send((embedded_file, chunked_file.handle))
                            .await?;
                    }
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }

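    /// Commits the pipeline's output to the database, interleaving range
    /// deletions with writes of newly embedded files. Each operation runs in
    /// its own write transaction, so progress is committed incrementally.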
    fn persist_embeddings(
        &self,
        deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &App,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;

        cx.background_spawn(async move {
            let mut deleted_entry_ranges = pin!(deleted_entry_ranges);
            let mut embedded_files = pin!(embedded_files);
            loop {
                // Interleave deletions and persists of embedded files.
                futures::select_biased! {
                    deletion_range = deleted_entry_ranges.next() => {
                        if let Some(deletion_range) = deletion_range {
                            let mut txn = db_connection.write_txn()?;
                            let start = deletion_range.0.as_ref().map(|start| start.as_str());
                            let end = deletion_range.1.as_ref().map(|end| end.as_str());
                            log::debug!("deleting embeddings in range {:?}", &(start, end));
                            db.delete_range(&mut txn, &(start, end))?;
                            txn.commit()?;
                        }
                    },
                    file = embedded_files.next() => {
                        if let Some((file, _)) = file {
                            let mut txn = db_connection.write_txn()?;
                            log::debug!("saving embedding for file {:?}", file.path);
                            let key = db_key_for_path(&file.path);
                            db.put(&mut txn, &key, &file)?;
                            txn.commit()?;
                        }
                    },
                    complete => break,
                }
            }

            Ok(())
        })
    }

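    /// Returns the paths of all files currently stored in the index.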
    pub fn paths(&self, cx: &App) -> Task<Result<Vec<Arc<Path>>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            let result = db
                .iter(&tx)?
                .map(|entry| Ok(entry?.1.path.clone()))
                .collect::<Result<Vec<Arc<Path>>>>();
            drop(tx);
            result
        })
    }

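    /// Returns the embedded chunks stored for the given path, or an error if
    /// the path is not in the index.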
    pub fn chunks_for_path(&self, path: Arc<Path>, cx: &App) -> Task<Result<Vec<EmbeddedChunk>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            Ok(db
                .get(&tx, &db_key_for_path(&path))?
                .ok_or_else(|| anyhow!("no such path"))?
                .chunks
                .clone())
        })
    }
}

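/// Output of the scan stage: channels feeding the chunking and persistence
/// stages, plus the task driving the scan.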
struct ScanEntries {
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}

struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}

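/// A file that has been read and chunked, but not yet embedded.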
pub struct ChunkedFile {
    pub path: Arc<Path>,
    pub mtime: Option<MTime>,
    pub handle: IndexingEntryHandle,
    pub text: String,
    pub chunks: Vec<Chunk>,
}

pub struct EmbedFiles {
    pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    pub task: Task<Result<()>>,
}

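/// A file's chunks and their embeddings, as serialized into the database via
/// bincode.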
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
    pub path: Arc<Path>,
    pub mtime: Option<MTime>,
    pub chunks: Vec<EmbeddedChunk>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    pub chunk: Chunk,
    pub embedding: Embedding,
}

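// Replacing '/' with '\0' appears intended to keep a directory's descendants
// contiguous in key order: NUL sorts before every other byte, so "a\0b" sorts
// immediately after "a" and before siblings like "a.txt", which lets a
// subtree be deleted as a single key range.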
fn db_key_for_path(path: &Arc<Path>) -> String {
    path.to_string_lossy().replace('/', "\0")
}