1use crate::{
2 chunking::{self, Chunk},
3 embedding::{Embedding, EmbeddingProvider, TextToEmbed},
4 indexing::{IndexingEntryHandle, IndexingEntrySet},
5};
6use anyhow::{Context as _, Result};
7use collections::Bound;
8use feature_flags::FeatureFlagAppExt;
9use fs::Fs;
10use fs::MTime;
11use futures::{FutureExt as _, stream::StreamExt};
12use futures_batch::ChunksTimeoutStreamExt;
13use gpui::{App, AppContext as _, Entity, Task};
14use heed::types::{SerdeBincode, Str};
15use language::LanguageRegistry;
16use log;
17use project::{Entry, UpdatedEntriesSet, Worktree};
18use serde::{Deserialize, Serialize};
19use smol::channel;
20use std::{cmp::Ordering, future::Future, iter, path::Path, pin::pin, sync::Arc, time::Duration};
21use util::ResultExt;
22use worktree::Snapshot;
23
/// Maintains a per-worktree database of chunk embeddings, kept in sync with
/// files on disk via a scan -> chunk -> embed -> persist pipeline.
pub struct EmbeddingIndex {
    worktree: Entity<Worktree>,
    // LMDB environment used to open read/write transactions against `db`.
    db_connection: heed::Env,
    // Maps a path-derived key (see `db_key_for_path`) to that file's embedded chunks.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    fs: Arc<dyn Fs>,
    language_registry: Arc<LanguageRegistry>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Set of entries currently in flight through the pipeline; handles inserted
    // here travel with each file until it is persisted or dropped.
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
}
33
34impl EmbeddingIndex {
35 pub fn new(
36 worktree: Entity<Worktree>,
37 fs: Arc<dyn Fs>,
38 db_connection: heed::Env,
39 embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
40 language_registry: Arc<LanguageRegistry>,
41 embedding_provider: Arc<dyn EmbeddingProvider>,
42 entry_ids_being_indexed: Arc<IndexingEntrySet>,
43 ) -> Self {
44 Self {
45 worktree,
46 fs,
47 db_connection,
48 db: embedding_db,
49 language_registry,
50 embedding_provider,
51 entry_ids_being_indexed,
52 }
53 }
54
55 pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
56 &self.db
57 }
58
59 pub fn index_entries_changed_on_disk(
60 &self,
61 cx: &App,
62 ) -> impl Future<Output = Result<()>> + use<> {
63 if !cx.is_staff() {
64 return async move { Ok(()) }.boxed();
65 }
66
67 let worktree = self.worktree.read(cx).snapshot();
68 let worktree_abs_path = worktree.abs_path().clone();
69 let scan = self.scan_entries(worktree, cx);
70 let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
71 let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
72 let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
73 async move {
74 futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
75 Ok(())
76 }
77 .boxed()
78 }
79
80 pub fn index_updated_entries(
81 &self,
82 updated_entries: UpdatedEntriesSet,
83 cx: &App,
84 ) -> impl Future<Output = Result<()>> + use<> {
85 if !cx.is_staff() {
86 return async move { Ok(()) }.boxed();
87 }
88
89 let worktree = self.worktree.read(cx).snapshot();
90 let worktree_abs_path = worktree.abs_path().clone();
91 let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
92 let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
93 let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
94 let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
95 async move {
96 futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
97 Ok(())
98 }
99 .boxed()
100 }
101
102 fn scan_entries(&self, worktree: Snapshot, cx: &App) -> ScanEntries {
103 let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
104 let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
105 let db_connection = self.db_connection.clone();
106 let db = self.db;
107 let entries_being_indexed = self.entry_ids_being_indexed.clone();
108 let task = cx.background_spawn(async move {
109 let txn = db_connection
110 .read_txn()
111 .context("failed to create read transaction")?;
112 let mut db_entries = db
113 .iter(&txn)
114 .context("failed to create iterator")?
115 .move_between_keys()
116 .peekable();
117
118 let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
119 for entry in worktree.files(false, 0) {
120 log::trace!("scanning for embedding index: {:?}", &entry.path);
121
122 let entry_db_key = db_key_for_path(&entry.path);
123
124 let mut saved_mtime = None;
125 while let Some(db_entry) = db_entries.peek() {
126 match db_entry {
127 Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
128 Ordering::Less => {
129 if let Some(deletion_range) = deletion_range.as_mut() {
130 deletion_range.1 = Bound::Included(db_path);
131 } else {
132 deletion_range =
133 Some((Bound::Included(db_path), Bound::Included(db_path)));
134 }
135
136 db_entries.next();
137 }
138 Ordering::Equal => {
139 if let Some(deletion_range) = deletion_range.take() {
140 deleted_entry_ranges_tx
141 .send((
142 deletion_range.0.map(ToString::to_string),
143 deletion_range.1.map(ToString::to_string),
144 ))
145 .await?;
146 }
147 saved_mtime = db_embedded_file.mtime;
148 db_entries.next();
149 break;
150 }
151 Ordering::Greater => {
152 break;
153 }
154 },
155 Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
156 }
157 }
158
159 if entry.mtime != saved_mtime {
160 let handle = entries_being_indexed.insert(entry.id);
161 updated_entries_tx.send((entry.clone(), handle)).await?;
162 }
163 }
164
165 if let Some(db_entry) = db_entries.next() {
166 let (db_path, _) = db_entry?;
167 deleted_entry_ranges_tx
168 .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
169 .await?;
170 }
171
172 Ok(())
173 });
174
175 ScanEntries {
176 updated_entries: updated_entries_rx,
177 deleted_entry_ranges: deleted_entry_ranges_rx,
178 task,
179 }
180 }
181
182 fn scan_updated_entries(
183 &self,
184 worktree: Snapshot,
185 updated_entries: UpdatedEntriesSet,
186 cx: &App,
187 ) -> ScanEntries {
188 let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
189 let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
190 let entries_being_indexed = self.entry_ids_being_indexed.clone();
191 let task = cx.background_spawn(async move {
192 for (path, entry_id, status) in updated_entries.iter() {
193 match status {
194 project::PathChange::Added
195 | project::PathChange::Updated
196 | project::PathChange::AddedOrUpdated => {
197 if let Some(entry) = worktree.entry_for_id(*entry_id)
198 && entry.is_file() {
199 let handle = entries_being_indexed.insert(entry.id);
200 updated_entries_tx.send((entry.clone(), handle)).await?;
201 }
202 }
203 project::PathChange::Removed => {
204 let db_path = db_key_for_path(path);
205 deleted_entry_ranges_tx
206 .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
207 .await?;
208 }
209 project::PathChange::Loaded => {
210 // Do nothing.
211 }
212 }
213 }
214
215 Ok(())
216 });
217
218 ScanEntries {
219 updated_entries: updated_entries_rx,
220 deleted_entry_ranges: deleted_entry_ranges_rx,
221 task,
222 }
223 }
224
225 fn chunk_files(
226 &self,
227 worktree_abs_path: Arc<Path>,
228 entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
229 cx: &App,
230 ) -> ChunkFiles {
231 let language_registry = self.language_registry.clone();
232 let fs = self.fs.clone();
233 let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
234 let task = cx.spawn(async move |cx| {
235 cx.background_executor()
236 .scoped(|cx| {
237 for _ in 0..cx.num_cpus() {
238 cx.spawn(async {
239 while let Ok((entry, handle)) = entries.recv().await {
240 let entry_abs_path = worktree_abs_path.join(&entry.path);
241 if let Some(text) = fs.load(&entry_abs_path).await.ok() {
242 let language = language_registry
243 .language_for_file_path(&entry.path)
244 .await
245 .ok();
246 let chunked_file = ChunkedFile {
247 chunks: chunking::chunk_text(
248 &text,
249 language.as_ref(),
250 &entry.path,
251 ),
252 handle,
253 path: entry.path,
254 mtime: entry.mtime,
255 text,
256 };
257
258 if chunked_files_tx.send(chunked_file).await.is_err() {
259 return;
260 }
261 }
262 }
263 });
264 }
265 })
266 .await;
267 Ok(())
268 });
269
270 ChunkFiles {
271 files: chunked_files_rx,
272 task,
273 }
274 }
275
276 pub fn embed_files(
277 embedding_provider: Arc<dyn EmbeddingProvider>,
278 chunked_files: channel::Receiver<ChunkedFile>,
279 cx: &App,
280 ) -> EmbedFiles {
281 let embedding_provider = embedding_provider.clone();
282 let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
283 let task = cx.background_spawn(async move {
284 let mut chunked_file_batches =
285 pin!(chunked_files.chunks_timeout(512, Duration::from_secs(2)));
286 while let Some(chunked_files) = chunked_file_batches.next().await {
287 // View the batch of files as a vec of chunks
288 // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
289 // Once those are done, reassemble them back into the files in which they belong
290 // If any embeddings fail for a file, the entire file is discarded
291
292 let chunks: Vec<TextToEmbed> = chunked_files
293 .iter()
294 .flat_map(|file| {
295 file.chunks.iter().map(|chunk| TextToEmbed {
296 text: &file.text[chunk.range.clone()],
297 digest: chunk.digest,
298 })
299 })
300 .collect::<Vec<_>>();
301
302 let mut embeddings: Vec<Option<Embedding>> = Vec::new();
303 for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
304 if let Some(batch_embeddings) =
305 embedding_provider.embed(embedding_batch).await.log_err()
306 {
307 if batch_embeddings.len() == embedding_batch.len() {
308 embeddings.extend(batch_embeddings.into_iter().map(Some));
309 continue;
310 }
311 log::error!(
312 "embedding provider returned unexpected embedding count {}, expected {}",
313 batch_embeddings.len(), embedding_batch.len()
314 );
315 }
316
317 embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
318 }
319
320 let mut embeddings = embeddings.into_iter();
321 for chunked_file in chunked_files {
322 let mut embedded_file = EmbeddedFile {
323 path: chunked_file.path,
324 mtime: chunked_file.mtime,
325 chunks: Vec::new(),
326 };
327
328 let mut embedded_all_chunks = true;
329 for (chunk, embedding) in
330 chunked_file.chunks.into_iter().zip(embeddings.by_ref())
331 {
332 if let Some(embedding) = embedding {
333 embedded_file
334 .chunks
335 .push(EmbeddedChunk { chunk, embedding });
336 } else {
337 embedded_all_chunks = false;
338 }
339 }
340
341 if embedded_all_chunks {
342 embedded_files_tx
343 .send((embedded_file, chunked_file.handle))
344 .await?;
345 }
346 }
347 }
348 Ok(())
349 });
350
351 EmbedFiles {
352 files: embedded_files_rx,
353 task,
354 }
355 }
356
357 fn persist_embeddings(
358 &self,
359 deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
360 embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
361 cx: &App,
362 ) -> Task<Result<()>> {
363 let db_connection = self.db_connection.clone();
364 let db = self.db;
365
366 cx.background_spawn(async move {
367 let mut deleted_entry_ranges = pin!(deleted_entry_ranges);
368 let mut embedded_files = pin!(embedded_files);
369 loop {
370 // Interleave deletions and persists of embedded files
371 futures::select_biased! {
372 deletion_range = deleted_entry_ranges.next() => {
373 if let Some(deletion_range) = deletion_range {
374 let mut txn = db_connection.write_txn()?;
375 let start = deletion_range.0.as_ref().map(|start| start.as_str());
376 let end = deletion_range.1.as_ref().map(|end| end.as_str());
377 log::debug!("deleting embeddings in range {:?}", &(start, end));
378 db.delete_range(&mut txn, &(start, end))?;
379 txn.commit()?;
380 }
381 },
382 file = embedded_files.next() => {
383 if let Some((file, _)) = file {
384 let mut txn = db_connection.write_txn()?;
385 log::debug!("saving embedding for file {:?}", file.path);
386 let key = db_key_for_path(&file.path);
387 db.put(&mut txn, &key, &file)?;
388 txn.commit()?;
389 }
390 },
391 complete => break,
392 }
393 }
394
395 Ok(())
396 })
397 }
398
399 pub fn paths(&self, cx: &App) -> Task<Result<Vec<Arc<Path>>>> {
400 let connection = self.db_connection.clone();
401 let db = self.db;
402 cx.background_spawn(async move {
403 let tx = connection
404 .read_txn()
405 .context("failed to create read transaction")?;
406 let result = db
407 .iter(&tx)?
408 .map(|entry| Ok(entry?.1.path.clone()))
409 .collect::<Result<Vec<Arc<Path>>>>();
410 drop(tx);
411 result
412 })
413 }
414
415 pub fn chunks_for_path(&self, path: Arc<Path>, cx: &App) -> Task<Result<Vec<EmbeddedChunk>>> {
416 let connection = self.db_connection.clone();
417 let db = self.db;
418 cx.background_spawn(async move {
419 let tx = connection
420 .read_txn()
421 .context("failed to create read transaction")?;
422 Ok(db
423 .get(&tx, &db_key_for_path(&path))?
424 .context("no such path")?
425 .chunks
426 .clone())
427 })
428 }
429}
430
/// Output of a scan stage: channels feeding the rest of the pipeline, plus the
/// task that fills them.
struct ScanEntries {
    // Files that need (re-)chunking and (re-)embedding, with their in-flight handles.
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    // Key ranges to delete from the embeddings database.
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    // Drives the scan; must be awaited (e.g. via `try_join!`) for the channels to fill.
    task: Task<Result<()>>,
}
436
/// Output of the chunking stage: a stream of chunked files and the task
/// producing them.
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}
441
/// A loaded file split into chunks, ready to be embedded.
pub struct ChunkedFile {
    // Worktree-relative path of the file.
    pub path: Arc<Path>,
    pub mtime: Option<MTime>,
    // Keeps the entry marked as in-flight until processing finishes.
    pub handle: IndexingEntryHandle,
    // Full file contents; each chunk's range indexes into this string.
    pub text: String,
    pub chunks: Vec<Chunk>,
}
449
/// Output of the embedding stage: fully-embedded files (with their in-flight
/// handles) and the task producing them.
pub struct EmbedFiles {
    pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    pub task: Task<Result<()>>,
}
454
/// A file's persisted representation in the embeddings database.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
    // Worktree-relative path of the file.
    pub path: Arc<Path>,
    // mtime recorded at indexing time; compared against disk to detect staleness.
    pub mtime: Option<MTime>,
    pub chunks: Vec<EmbeddedChunk>,
}
461
/// A single chunk of a file paired with its embedding vector.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    pub chunk: Chunk,
    pub embedding: Embedding,
}
467
/// Converts a worktree-relative path into a database key by replacing `/`
/// separators with NUL bytes.
///
/// NOTE(review): NUL sorts before every printable character, which presumably
/// keeps a directory's entries contiguous in key order for the scan's
/// merge-join — confirm before changing the separator.
///
/// Takes `&Path` rather than `&Arc<Path>` (the idiomatic borrow); existing
/// call sites passing `&Arc<Path>` still compile via deref coercion.
fn db_key_for_path(path: &Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}