use crate::{
    chunking::{self, Chunk},
    embedding::{Embedding, EmbeddingProvider, TextToEmbed},
    indexing::{IndexingEntryHandle, IndexingEntrySet},
};
use anyhow::{anyhow, Context as _, Result};
use collections::Bound;
use feature_flags::FeatureFlagAppExt;
use fs::Fs;
use futures::stream::StreamExt;
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{AppContext, Model, Task};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use log;
use project::{Entry, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use smol::future::FutureExt;
use std::{
    cmp::Ordering,
    future::Future,
    iter,
    path::Path,
    sync::Arc,
    time::{Duration, SystemTime},
};
use util::ResultExt;
use worktree::Snapshot;

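/// An index of embeddings for the files of a single worktree, backed by an
/// LMDB database (via `heed`). Entries are keyed by worktree-relative path
/// and store each file's chunks along with their embeddings.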
pub struct EmbeddingIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    fs: Arc<dyn Fs>,
    language_registry: Arc<LanguageRegistry>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
}

impl EmbeddingIndex {
    pub fn new(
        worktree: Model<Worktree>,
        fs: Arc<dyn Fs>,
        db_connection: heed::Env,
        embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        language_registry: Arc<LanguageRegistry>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        entry_ids_being_indexed: Arc<IndexingEntrySet>,
    ) -> Self {
        Self {
            worktree,
            fs,
            db_connection,
            db: embedding_db,
            language_registry,
            embedding_provider,
            entry_ids_being_indexed,
        }
    }

    pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
        &self.db
    }

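    /// Indexes every file whose on-disk state differs from what is stored in
    /// the database. Scanning, chunking, embedding, and persisting run as
    /// concurrent tasks connected by channels. Currently a no-op unless the
    /// user is a staff member.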
    pub fn index_entries_changed_on_disk(
        &self,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        if !cx.is_staff() {
            return async move { Ok(()) }.boxed();
        }

        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree, cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
        .boxed()
    }

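    /// Indexes the given set of changed entries, reusing the same
    /// chunk/embed/persist pipeline as `index_entries_changed_on_disk`.
    /// Currently a no-op unless the user is a staff member.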
    pub fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        if !cx.is_staff() {
            return async move { Ok(()) }.boxed();
        }

        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
        .boxed()
    }

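    /// Walks the worktree and compares it against the database, sending
    /// entries whose mtime has changed on one channel and key ranges of
    /// deleted entries on another.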
    fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            // The worktree and the database both iterate in key order, so the
            // two can be compared with a single sorted merge. Database entries
            // with no corresponding worktree entry are accumulated into
            // contiguous deletion ranges.
            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                log::trace!("scanning for embedding index: {:?}", &entry.path);

                let entry_db_key = db_key_for_path(&entry.path);

                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                // This database entry's file no longer exists
                                // on disk. Extend the current deletion range
                                // to cover it.
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                break;
                            }
                        },
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
                    }
                }

                if entry.mtime != saved_mtime {
                    let handle = entries_being_indexed.insert(entry.id);
                    updated_entries_tx.send((entry.clone(), handle)).await?;
                }
            }

            // Any database entries beyond the last worktree entry have been
            // deleted on disk.
            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

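    /// Converts an `UpdatedEntriesSet` into the same channel outputs as
    /// `scan_entries`, without rescanning the entire worktree.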
    fn scan_updated_entries(
        &self,
        worktree: Snapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            if entry.is_file() {
                                let handle = entries_being_indexed.insert(entry.id);
                                updated_entries_tx.send((entry.clone(), handle)).await?;
                            }
                        }
                    }
                    project::PathChange::Removed => {
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

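    /// Loads and chunks the files received on `entries`, fanning the work out
    /// across one background task per CPU core.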
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok((entry, handle)) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                if let Ok(text) = fs.load(&entry_abs_path).await {
                                    let language = language_registry
                                        .language_for_file_path(&entry.path)
                                        .await
                                        .ok();
                                    let chunked_file = ChunkedFile {
                                        chunks: chunking::chunk_text(
                                            &text,
                                            language.as_ref(),
                                            &entry.path,
                                        ),
                                        handle,
                                        path: entry.path,
                                        mtime: entry.mtime,
                                        text,
                                    };

                                    if chunked_files_tx.send(chunked_file).await.is_err() {
                                        return;
                                    }
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }

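    /// Batches incoming chunked files, embeds each batch with the provider,
    /// and reassembles the resulting embeddings into per-file outputs.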
    pub fn embed_files(
        embedding_provider: Arc<dyn EmbeddingProvider>,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &AppContext,
    ) -> EmbedFiles {
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_executor().spawn(async move {
            let mut chunked_file_batches =
                chunked_files.chunks_timeout(512, Duration::from_secs(2));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // Flatten the batch of files into a single list of chunks,
                // subdivide it into provider-sized batches, and then
                // reassemble the resulting embeddings into the files they
                // belong to. If any chunk of a file fails to embed, the
                // entire file is discarded.
                let chunks: Vec<TextToEmbed> = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect::<Vec<_>>();

                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    if let Some(batch_embeddings) =
                        embedding_provider.embed(embedding_batch).await.log_err()
                    {
                        if batch_embeddings.len() == embedding_batch.len() {
                            embeddings.extend(batch_embeddings.into_iter().map(Some));
                            continue;
                        }
                        log::error!(
                            "embedding provider returned unexpected embedding count {}, expected {}",
                            batch_embeddings.len(),
                            embedding_batch.len()
                        );
                    }

                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
                }

                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let mut embedded_file = EmbeddedFile {
                        path: chunked_file.path,
                        mtime: chunked_file.mtime,
                        chunks: Vec::new(),
                    };

                    let mut embedded_all_chunks = true;
                    for (chunk, embedding) in
                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
                    {
                        if let Some(embedding) = embedding {
                            embedded_file
                                .chunks
                                .push(EmbeddedChunk { chunk, embedding });
                        } else {
                            embedded_all_chunks = false;
                        }
                    }

                    if embedded_all_chunks {
                        embedded_files_tx
                            .send((embedded_file, chunked_file.handle))
                            .await?;
                    }
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }

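    /// Applies the output of the pipeline to the database: deletes the key
    /// ranges of removed entries and writes embedded files as they arrive.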
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        mut embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;

        cx.background_executor().spawn(async move {
            loop {
                // Interleave deletions and persists of embedded files.
                futures::select_biased! {
                    deletion_range = deleted_entry_ranges.next() => {
                        if let Some(deletion_range) = deletion_range {
                            let mut txn = db_connection.write_txn()?;
                            let start = deletion_range.0.as_ref().map(|start| start.as_str());
                            let end = deletion_range.1.as_ref().map(|end| end.as_str());
                            log::debug!("deleting embeddings in range {:?}", &(start, end));
                            db.delete_range(&mut txn, &(start, end))?;
                            txn.commit()?;
                        }
                    },
                    file = embedded_files.next() => {
                        if let Some((file, _)) = file {
                            let mut txn = db_connection.write_txn()?;
                            log::debug!("saving embedding for file {:?}", file.path);
                            let key = db_key_for_path(&file.path);
                            db.put(&mut txn, &key, &file)?;
                            txn.commit()?;
                        }
                    },
                    complete => break,
                }
            }

            Ok(())
        })
    }

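    /// Returns the paths of all files currently stored in the index.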
    pub fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            let result = db
                .iter(&tx)?
                .map(|entry| Ok(entry?.1.path.clone()))
                .collect::<Result<Vec<Arc<Path>>>>();
            drop(tx);
            result
        })
    }

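    /// Returns the embedded chunks stored for the given path, or an error if
    /// the path is not in the index.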
    pub fn chunks_for_path(
        &self,
        path: Arc<Path>,
        cx: &AppContext,
    ) -> Task<Result<Vec<EmbeddedChunk>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            Ok(db
                .get(&tx, &db_key_for_path(&path))?
                .ok_or_else(|| anyhow!("no such path"))?
                .chunks
                .clone())
        })
    }
}

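/// Channels and driver task produced by the scanning stage.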
struct ScanEntries {
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}

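/// Channel and driver task produced by the chunking stage.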
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}

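/// A file that has been split into chunks but not yet embedded.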
pub struct ChunkedFile {
    pub path: Arc<Path>,
    pub mtime: Option<SystemTime>,
    pub handle: IndexingEntryHandle,
    pub text: String,
    pub chunks: Vec<Chunk>,
}

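/// Channel and driver task produced by the embedding stage.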
pub struct EmbedFiles {
    pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    pub task: Task<Result<()>>,
}

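/// The value persisted in the database for each indexed file.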
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
    pub path: Arc<Path>,
    pub mtime: Option<SystemTime>,
    pub chunks: Vec<EmbeddedChunk>,
}

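/// A chunk of a file together with its embedding.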
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    pub chunk: Chunk,
    pub embedding: Embedding,
}

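/// Replaces `/` with `\0` in the key so that separators sort before any other
/// character, keeping the entries of a directory contiguous in key order.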
fn db_key_for_path(path: &Arc<Path>) -> String {
    path.to_string_lossy().replace('/', "\0")
}