use crate::{
    chunking::{self, Chunk},
    embedding::{Embedding, EmbeddingProvider, TextToEmbed},
    indexing::{IndexingEntryHandle, IndexingEntrySet},
};
use anyhow::{anyhow, Context as _, Result};
use collections::Bound;
use feature_flags::FeatureFlagAppExt;
use fs::{Fs, MTime};
use futures::stream::StreamExt;
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{AppContext, Model, Task};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use project::{Entry, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::{channel, future::FutureExt};
use std::{cmp::Ordering, future::Future, iter, path::Path, sync::Arc, time::Duration};
use util::ResultExt;
use worktree::Snapshot;

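/// Maintains an on-disk index of embeddings for the files of a single worktree.
///
/// Indexing runs as a four-stage pipeline, with each stage connected to the
/// next by a bounded channel: scan the worktree for added, changed, and
/// deleted entries, chunk the changed files, embed the chunks in batches, and
/// persist the results to the `heed` database.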
pub struct EmbeddingIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    fs: Arc<dyn Fs>,
    language_registry: Arc<LanguageRegistry>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
}

impl EmbeddingIndex {
    pub fn new(
        worktree: Model<Worktree>,
        fs: Arc<dyn Fs>,
        db_connection: heed::Env,
        embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        language_registry: Arc<LanguageRegistry>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        entry_ids_being_indexed: Arc<IndexingEntrySet>,
    ) -> Self {
        Self {
            worktree,
            fs,
            db_connection,
            db: embedding_db,
            language_registry,
            embedding_provider,
            entry_ids_being_indexed,
        }
    }

    pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
        &self.db
    }

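    /// Reindexes every file whose on-disk state differs from what is stored
    /// in the database, and deletes database entries for files that no longer
    /// exist on disk. Currently gated to staff via a feature flag; for other
    /// users the returned future resolves immediately with `Ok(())`.
    ///
    /// A sketch of a hypothetical call site (`index` and `cx` are assumed
    /// from the surrounding application context, and errors are simply
    /// discarded here):
    ///
    /// ```ignore
    /// let reindex = index.index_entries_changed_on_disk(cx);
    /// cx.background_executor().spawn(reindex).detach();
    /// ```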
    pub fn index_entries_changed_on_disk(
        &self,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        if !cx.is_staff() {
            return async move { Ok(()) }.boxed();
        }

        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree, cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
        .boxed()
    }

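    /// Like [`Self::index_entries_changed_on_disk`], but restricted to the
    /// entries in `updated_entries`, as reported by the worktree while the
    /// project is open. Subject to the same staff gating.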
    pub fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        if !cx.is_staff() {
            return async move { Ok(()) }.boxed();
        }

        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
        .boxed()
    }

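    /// First pipeline stage for a full scan: performs a merge join between
    /// the worktree's files and the database's entries, both of which are
    /// visited in a consistent sort order (see [`db_key_for_path`]). Files
    /// whose mtime differs from the stored one are sent on the
    /// `updated_entries` channel; runs of database keys with no corresponding
    /// worktree file are coalesced into ranges and sent on the
    /// `deleted_entry_ranges` channel.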
    fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                log::trace!("scanning for embedding index: {:?}", &entry.path);

                let entry_db_key = db_key_for_path(&entry.path);

                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                break;
                            }
                        },
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err().into()),
                    }
                }

                if entry.mtime != saved_mtime {
                    let handle = entries_being_indexed.insert(entry.id);
                    updated_entries_tx.send((entry.clone(), handle)).await?;
                }
            }

            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

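    /// Variant of [`Self::scan_entries`] for incremental updates: added or
    /// updated files are sent for reindexing, removed paths become single-key
    /// deletion ranges, and `Loaded` events are ignored.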
    fn scan_updated_entries(
        &self,
        worktree: Snapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            if entry.is_file() {
                                let handle = entries_being_indexed.insert(entry.id);
                                updated_entries_tx.send((entry.clone(), handle)).await?;
                            }
                        }
                    }
                    project::PathChange::Removed => {
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

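    /// Second pipeline stage: spawns one worker per CPU core, each of which
    /// pulls entries off the scan channel, loads the file from disk, and
    /// splits its text into chunks using the language detected from the file
    /// path. Chunked files are forwarded to the embedding stage.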
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok((entry, handle)) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                if let Ok(text) = fs.load(&entry_abs_path).await {
                                    let language = language_registry
                                        .language_for_file_path(&entry.path)
                                        .await
                                        .ok();
                                    let chunked_file = ChunkedFile {
                                        chunks: chunking::chunk_text(
                                            &text,
                                            language.as_ref(),
                                            &entry.path,
                                        ),
                                        handle,
                                        path: entry.path,
                                        mtime: entry.mtime,
                                        text,
                                    };

                                    if chunked_files_tx.send(chunked_file).await.is_err() {
                                        return;
                                    }
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }

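    /// Third pipeline stage: batches chunked files (up to 512 files, or
    /// whatever arrives within a two-second window), flattens each batch into
    /// a list of chunks, and embeds them `batch_size` chunks at a time. A
    /// file is emitted on the output channel only if all of its chunks
    /// embedded successfully; otherwise it is dropped, and because its mtime
    /// is never persisted it will be retried on a later scan.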
    pub fn embed_files(
        embedding_provider: Arc<dyn EmbeddingProvider>,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &AppContext,
    ) -> EmbedFiles {
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_executor().spawn(async move {
            let mut chunked_file_batches =
                chunked_files.chunks_timeout(512, Duration::from_secs(2));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // Flatten the batch of files into a single list of chunks, which can
                // then be subdivided into provider-sized batches. Once the batches are
                // embedded, reassemble the embeddings back into the files they belong
                // to. If embedding fails for any chunk, the entire file is discarded.
                let chunks: Vec<TextToEmbed> = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect();

                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    if let Some(batch_embeddings) =
                        embedding_provider.embed(embedding_batch).await.log_err()
                    {
                        if batch_embeddings.len() == embedding_batch.len() {
                            embeddings.extend(batch_embeddings.into_iter().map(Some));
                            continue;
                        }
                        log::error!(
                            "embedding provider returned unexpected embedding count {}, expected {}",
                            batch_embeddings.len(),
                            embedding_batch.len()
                        );
                    }

                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
                }

                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let mut embedded_file = EmbeddedFile {
                        path: chunked_file.path,
                        mtime: chunked_file.mtime,
                        chunks: Vec::new(),
                    };

                    let mut embedded_all_chunks = true;
                    for (chunk, embedding) in
                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
                    {
                        if let Some(embedding) = embedding {
                            embedded_file
                                .chunks
                                .push(EmbeddedChunk { chunk, embedding });
                        } else {
                            embedded_all_chunks = false;
                        }
                    }

                    if embedded_all_chunks {
                        embedded_files_tx
                            .send((embedded_file, chunked_file.handle))
                            .await?;
                    }
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }

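    /// Final pipeline stage: drains the deletion-range and embedded-file
    /// channels, with a bias toward deletions, committing each message in its
    /// own write transaction so progress is durable message by message.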
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        mut embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;

        cx.background_executor().spawn(async move {
            loop {
                // Interleave deletions and persistence of embedded files.
                futures::select_biased! {
                    deletion_range = deleted_entry_ranges.next() => {
                        if let Some(deletion_range) = deletion_range {
                            let mut txn = db_connection.write_txn()?;
                            let start = deletion_range.0.as_ref().map(|start| start.as_str());
                            let end = deletion_range.1.as_ref().map(|end| end.as_str());
                            log::debug!("deleting embeddings in range {:?}", &(start, end));
                            db.delete_range(&mut txn, &(start, end))?;
                            txn.commit()?;
                        }
                    },
                    file = embedded_files.next() => {
                        if let Some((file, _)) = file {
                            let mut txn = db_connection.write_txn()?;
                            log::debug!("saving embedding for file {:?}", file.path);
                            let key = db_key_for_path(&file.path);
                            db.put(&mut txn, &key, &file)?;
                            txn.commit()?;
                        }
                    },
                    complete => break,
                }
            }

            Ok(())
        })
    }

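    /// Returns the paths of all files currently stored in the index.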
    pub fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            let result = db
                .iter(&tx)?
                .map(|entry| Ok(entry?.1.path.clone()))
                .collect::<Result<Vec<Arc<Path>>>>();
            drop(tx);
            result
        })
    }

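    /// Returns the embedded chunks stored for `path`, or an error if the path
    /// is not present in the index.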
    pub fn chunks_for_path(
        &self,
        path: Arc<Path>,
        cx: &AppContext,
    ) -> Task<Result<Vec<EmbeddedChunk>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            Ok(db
                .get(&tx, &db_key_for_path(&path))?
                .ok_or_else(|| anyhow!("no such path"))?
                .chunks
                .clone())
        })
    }
}

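/// Output of the scan stage: entries to (re)index, key ranges to delete, and
/// the task driving the scan.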
struct ScanEntries {
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}

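/// Output of the chunking stage: chunked files plus the task driving the
/// chunking workers.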
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}

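/// A file's text split into chunks, ready to be embedded. The `handle` ties
/// the chunked file back to the set of entries currently being indexed.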
pub struct ChunkedFile {
    pub path: Arc<Path>,
    pub mtime: Option<MTime>,
    pub handle: IndexingEntryHandle,
    pub text: String,
    pub chunks: Vec<Chunk>,
}

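/// Output of the embedding stage: embedded files plus the task driving the
/// embedding loop.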
pub struct EmbedFiles {
    pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    pub task: Task<Result<()>>,
}

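/// The value type stored in the database: a file's embedded chunks, along
/// with the mtime used to detect staleness on subsequent scans.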
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
    pub path: Arc<Path>,
    pub mtime: Option<MTime>,
    pub chunks: Vec<EmbeddedChunk>,
}

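/// A single chunk of a file together with its embedding.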
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    pub chunk: Chunk,
    pub embedding: Embedding,
}

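/// Builds the database key for a path by replacing `/` with NUL. Since NUL
/// sorts before every other byte, byte-wise key order matches component-wise
/// path order, which the merge join in `scan_entries` relies on.
///
/// A quick illustration (hypothetical path):
///
/// ```ignore
/// let path: Arc<Path> = Path::new("src/main.rs").into();
/// assert_eq!(db_key_for_path(&path), "src\0main.rs");
/// ```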
fn db_key_for_path(path: &Arc<Path>) -> String {
    path.to_string_lossy().replace('/', "\0")
}