use crate::{
    chunking::{self, Chunk},
    embedding::{Embedding, EmbeddingProvider, TextToEmbed},
    indexing::{IndexingEntryHandle, IndexingEntrySet},
};
use anyhow::{anyhow, Context as _, Result};
use collections::Bound;
use fs::Fs;
use futures::stream::StreamExt;
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{AppContext, Model, Task};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use project::{Entry, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
    cmp::Ordering,
    future::Future,
    iter,
    path::Path,
    sync::Arc,
    time::{Duration, SystemTime},
};
use util::ResultExt;
use worktree::Snapshot;

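/// Maintains the embedding index for a single worktree: it scans entries on
/// disk, chunks and embeds their contents, and persists the resulting
/// [`EmbeddedFile`] records in an LMDB database keyed by path.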
pub struct EmbeddingIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    fs: Arc<dyn Fs>,
    language_registry: Arc<LanguageRegistry>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
}

impl EmbeddingIndex {
    pub fn new(
        worktree: Model<Worktree>,
        fs: Arc<dyn Fs>,
        db_connection: heed::Env,
        embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        language_registry: Arc<LanguageRegistry>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        entry_ids_being_indexed: Arc<IndexingEntrySet>,
    ) -> Self {
        Self {
            worktree,
            fs,
            db_connection,
            db: embedding_db,
            language_registry,
            embedding_provider,
            entry_ids_being_indexed,
        }
    }

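    /// Returns the underlying LMDB table mapping database keys to
    /// [`EmbeddedFile`] records.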
    pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
        &self.db
    }

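    /// Re-indexes every entry whose on-disk state differs from what is stored
    /// in the database, running the scan, chunk, embed, and persist stages as
    /// a concurrent pipeline connected by bounded channels.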
    pub fn index_entries_changed_on_disk(
        &self,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree, cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

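    /// Like [`Self::index_entries_changed_on_disk`], but limits the scan to an
    /// explicit set of updated entries rather than walking the whole worktree.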
    pub fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

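    /// Merges a walk of the worktree with a sorted iteration over the
    /// database, sending entries whose mtime has changed on the
    /// `updated_entries` channel and key ranges that no longer exist on disk
    /// on the `deleted_entry_ranges` channel.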
    fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                log::trace!("scanning for embedding index: {:?}", &entry.path);

                let entry_db_key = db_key_for_path(&entry.path);

                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                break;
                            }
                        },
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
                    }
                }

                if entry.mtime != saved_mtime {
                    let handle = entries_being_indexed.insert(entry.id);
                    updated_entries_tx.send((entry.clone(), handle)).await?;
                }
            }

            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

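    /// Translates an [`UpdatedEntriesSet`] into the same channel protocol as
    /// [`Self::scan_entries`]: added or updated files are sent for
    /// re-indexing, and removed paths are sent as single-key deletion ranges.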
    fn scan_updated_entries(
        &self,
        worktree: Snapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            if entry.is_file() {
                                let handle = entries_being_indexed.insert(entry.id);
                                updated_entries_tx.send((entry.clone(), handle)).await?;
                            }
                        }
                    }
                    project::PathChange::Removed => {
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

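    /// Loads each updated entry from disk and splits its text into chunks on a
    /// pool of background tasks, one per available CPU core, forwarding the
    /// resulting [`ChunkedFile`]s to the embedding stage.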
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok((entry, handle)) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                if let Ok(text) = fs.load(&entry_abs_path).await {
                                    let language = language_registry
                                        .language_for_file_path(&entry.path)
                                        .await
                                        .ok();
                                    let chunked_file = ChunkedFile {
                                        chunks: chunking::chunk_text(
                                            &text,
                                            language.as_ref(),
                                            &entry.path,
                                        ),
                                        handle,
                                        path: entry.path,
                                        mtime: entry.mtime,
                                        text,
                                    };

                                    if chunked_files_tx.send(chunked_file).await.is_err() {
                                        return;
                                    }
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }

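    /// Batches incoming chunked files (up to 512 files or a two-second
    /// timeout), embeds their chunks through the provider in provider-sized
    /// batches, and reassembles the embeddings into [`EmbeddedFile`]s. A file
    /// in which any chunk failed to embed is dropped rather than persisted
    /// partially.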
    pub fn embed_files(
        embedding_provider: Arc<dyn EmbeddingProvider>,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &AppContext,
    ) -> EmbedFiles {
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_executor().spawn(async move {
            let mut chunked_file_batches =
                chunked_files.chunks_timeout(512, Duration::from_secs(2));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // Flatten the batch of files into a single list of chunks,
                // embed the chunks in provider-sized batches, then reassemble
                // the embeddings into the files they came from. If embedding
                // fails for any chunk, the entire file is discarded.

                let chunks: Vec<TextToEmbed> = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect();

                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    if let Some(batch_embeddings) =
                        embedding_provider.embed(embedding_batch).await.log_err()
                    {
                        if batch_embeddings.len() == embedding_batch.len() {
                            embeddings.extend(batch_embeddings.into_iter().map(Some));
                            continue;
                        }
                        log::error!(
                            "embedding provider returned unexpected embedding count {}, expected {}",
                            batch_embeddings.len(),
                            embedding_batch.len()
                        );
                    }

                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
                }

                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let mut embedded_file = EmbeddedFile {
                        path: chunked_file.path,
                        mtime: chunked_file.mtime,
                        chunks: Vec::new(),
                    };

                    let mut embedded_all_chunks = true;
                    for (chunk, embedding) in
                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
                    {
                        if let Some(embedding) = embedding {
                            embedded_file
                                .chunks
                                .push(EmbeddedChunk { chunk, embedding });
                        } else {
                            embedded_all_chunks = false;
                        }
                    }

                    if embedded_all_chunks {
                        embedded_files_tx
                            .send((embedded_file, chunked_file.handle))
                            .await?;
                    }
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }

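    /// Applies the pipeline's results to the database, draining the deletion
    /// and embedded-file channels concurrently (preferring deletions via
    /// `select_biased!`) and committing each change in its own write
    /// transaction.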
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        mut embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;

        cx.background_executor().spawn(async move {
            loop {
                // Interleave deletions and persists of embedded files.
                futures::select_biased! {
                    deletion_range = deleted_entry_ranges.next() => {
                        if let Some(deletion_range) = deletion_range {
                            let mut txn = db_connection.write_txn()?;
                            let start = deletion_range.0.as_ref().map(|start| start.as_str());
                            let end = deletion_range.1.as_ref().map(|end| end.as_str());
                            log::debug!("deleting embeddings in range {:?}", &(start, end));
                            db.delete_range(&mut txn, &(start, end))?;
                            txn.commit()?;
                        }
                    },
                    file = embedded_files.next() => {
                        if let Some((file, _)) = file {
                            let mut txn = db_connection.write_txn()?;
                            log::debug!("saving embedding for file {:?}", file.path);
                            let key = db_key_for_path(&file.path);
                            db.put(&mut txn, &key, &file)?;
                            txn.commit()?;
                        }
                    },
                    complete => break,
                }
            }

            Ok(())
        })
    }

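    /// Returns the paths of all files currently stored in the index.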
    pub fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            let result = db
                .iter(&tx)?
                .map(|entry| Ok(entry?.1.path.clone()))
                .collect::<Result<Vec<Arc<Path>>>>();
            drop(tx);
            result
        })
    }

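    /// Returns the embedded chunks stored for the given path, failing if the
    /// path is not present in the index.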
    pub fn chunks_for_path(
        &self,
        path: Arc<Path>,
        cx: &AppContext,
    ) -> Task<Result<Vec<EmbeddedChunk>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            Ok(db
                .get(&tx, &db_key_for_path(&path))?
                .ok_or_else(|| anyhow!("no such path"))?
                .chunks
                .clone())
        })
    }
}

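/// Channels produced by the scan stage: entries that need re-indexing and key
/// ranges that should be deleted from the database.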
struct ScanEntries {
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}

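/// Output of the chunking stage: a stream of files split into chunks.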
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}

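/// A file that has been loaded and split into chunks, but not yet embedded.
/// The `handle` keeps the entry marked as in-progress until it is dropped.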
pub struct ChunkedFile {
    pub path: Arc<Path>,
    pub mtime: Option<SystemTime>,
    pub handle: IndexingEntryHandle,
    pub text: String,
    pub chunks: Vec<Chunk>,
}

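/// Output of the embedding stage: fully embedded files paired with their
/// in-progress indexing handles.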
pub struct EmbedFiles {
    pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    pub task: Task<Result<()>>,
}

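/// The record persisted in the database for each indexed file: its path, the
/// mtime observed when it was indexed, and its embedded chunks.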
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
    pub path: Arc<Path>,
    pub mtime: Option<SystemTime>,
    pub chunks: Vec<EmbeddedChunk>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    pub chunk: Chunk,
    pub embedding: Embedding,
}

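/// Builds the database key for a path by replacing `/` separators with NUL
/// bytes, e.g. `src/main.rs` becomes `src\0main.rs`. Since `\0` sorts before
/// every other byte, all entries under a directory sort contiguously, ahead
/// of sibling paths like `src-old`; this keeps the database's lexicographic
/// key order consistent with the worktree's path ordering that `scan_entries`
/// relies on when merging the two sorted sequences.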
fn db_key_for_path(path: &Arc<Path>) -> String {
    path.to_string_lossy().replace('/', "\0")
}