use crate::{
    chunking::{self, Chunk},
    embedding::{Embedding, EmbeddingProvider, TextToEmbed},
    indexing::{IndexingEntryHandle, IndexingEntrySet},
};
use anyhow::{anyhow, Context as _, Result};
use collections::Bound;
use fs::Fs;
use futures::stream::StreamExt;
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{AppContext, Model, Task};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use project::{Entry, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
    cmp::Ordering,
    future::Future,
    iter,
    path::Path,
    sync::Arc,
    time::{Duration, SystemTime},
};
use util::ResultExt;
use worktree::Snapshot;

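/// An on-disk index of chunk embeddings for a single worktree, backed by an
/// LMDB database (via `heed`) keyed by worktree-relative path.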
pub struct EmbeddingIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    fs: Arc<dyn Fs>,
    language_registry: Arc<LanguageRegistry>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
}

impl EmbeddingIndex {
    pub fn new(
        worktree: Model<Worktree>,
        fs: Arc<dyn Fs>,
        db_connection: heed::Env,
        embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        language_registry: Arc<LanguageRegistry>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        entry_ids_being_indexed: Arc<IndexingEntrySet>,
    ) -> Self {
        Self {
            worktree,
            fs,
            db_connection,
            db: embedding_db,
            language_registry,
            embedding_provider,
            entry_ids_being_indexed,
        }
    }

    pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
        &self.db
    }

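    /// Reindexes every file whose mtime differs from the one recorded in the
    /// database, and deletes database entries for files that no longer exist
    /// on disk. The work runs as a four-stage pipeline (scan, chunk, embed,
    /// persist) connected by bounded channels, so all stages proceed
    /// concurrently.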
    pub fn index_entries_changed_on_disk(
        &self,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree, cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

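    /// Like `index_entries_changed_on_disk`, but driven by a set of updated
    /// entries reported by the worktree rather than a full scan.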
    pub fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

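    /// Compares the worktree's current files against the database and emits
    /// the entries that need (re)indexing, along with key ranges for database
    /// entries that should be deleted.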
    fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            // Merge-join the worktree's files against the database's entries.
            // Both are ordered by database key, so a single forward pass finds
            // files that are new or modified, as well as database entries
            // whose files no longer exist on disk.
            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                log::trace!("scanning for embedding index: {:?}", &entry.path);

                let entry_db_key = db_key_for_path(&entry.path);

                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                // This database entry's file no longer exists
                                // on disk. Extend the pending deletion range
                                // to cover it.
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                break;
                            }
                        },
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err().into()),
                    }
                }

                // Reindex the file if its mtime differs from the one recorded
                // in the database.
                if entry.mtime != saved_mtime {
                    let handle = entries_being_indexed.insert(entry.id);
                    updated_entries_tx.send((entry.clone(), handle)).await?;
                }
            }

            // Any database entries that sort after the last worktree file
            // belong to deleted files.
            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

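    /// Like `scan_entries`, but only considers the given set of changed
    /// paths instead of walking the entire worktree.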
    fn scan_updated_entries(
        &self,
        worktree: Snapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            if entry.is_file() {
                                let handle = entries_being_indexed.insert(entry.id);
                                updated_entries_tx.send((entry.clone(), handle)).await?;
                            }
                        }
                    }
                    project::PathChange::Removed => {
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

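    /// Loads each updated entry from disk and splits it into chunks suitable
    /// for embedding, using language-aware chunking when the file's language
    /// can be detected.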
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    // Chunk files in parallel, one worker per CPU, all pulling
                    // from the same entries channel.
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok((entry, handle)) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                match fs.load(&entry_abs_path).await {
                                    Ok(text) => {
                                        let language = language_registry
                                            .language_for_file_path(&entry.path)
                                            .await
                                            .ok();
                                        let chunked_file = ChunkedFile {
                                            chunks: chunking::chunk_text(
                                                &text,
                                                language.as_ref(),
                                                &entry.path,
                                            ),
                                            handle,
                                            path: entry.path,
                                            mtime: entry.mtime,
                                            text,
                                        };

                                        if chunked_files_tx.send(chunked_file).await.is_err() {
                                            return;
                                        }
                                    }
                                    Err(_) => {
                                        log::error!(
                                            "Failed to read contents into a UTF-8 string: {entry_abs_path:?}"
                                        );
                                    }
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }

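    /// Requests embeddings for chunked files, batching chunks across file
    /// boundaries to fill each request up to the provider's batch size. This
    /// is an associated function rather than a method because it needs only
    /// the embedding provider, not the index itself.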
    pub fn embed_files(
        embedding_provider: Arc<dyn EmbeddingProvider>,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &AppContext,
    ) -> EmbedFiles {
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_executor().spawn(async move {
            let mut chunked_file_batches =
                chunked_files.chunks_timeout(512, Duration::from_secs(2));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // Flatten the batch of files into a single list of chunks,
                // subdivide that list into provider-sized batches, then
                // reassemble the resulting embeddings back into the files
                // they belong to. If any chunk of a file fails to embed, the
                // entire file is discarded.

                let chunks: Vec<TextToEmbed> = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect::<Vec<_>>();

                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    if let Some(batch_embeddings) =
                        embedding_provider.embed(embedding_batch).await.log_err()
                    {
                        if batch_embeddings.len() == embedding_batch.len() {
                            embeddings.extend(batch_embeddings.into_iter().map(Some));
                            continue;
                        }
                        log::error!(
                            "embedding provider returned unexpected embedding count {}, expected {}",
                            batch_embeddings.len(),
                            embedding_batch.len()
                        );
                    }

                    // Mark every chunk in the failed batch as missing.
                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
                }

                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let mut embedded_file = EmbeddedFile {
                        path: chunked_file.path,
                        mtime: chunked_file.mtime,
                        chunks: Vec::new(),
                    };

                    let mut embedded_all_chunks = true;
                    for (chunk, embedding) in
                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
                    {
                        if let Some(embedding) = embedding {
                            embedded_file
                                .chunks
                                .push(EmbeddedChunk { chunk, embedding });
                        } else {
                            embedded_all_chunks = false;
                        }
                    }

                    if embedded_all_chunks {
                        embedded_files_tx
                            .send((embedded_file, chunked_file.handle))
                            .await?;
                    }
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }

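    /// Writes the pipeline's results to the database: first deletes the given
    /// key ranges, then persists embedded files in batched write transactions.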
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            // Apply all deletions before writing any new embeddings.
            while let Some(deletion_range) = deleted_entry_ranges.next().await {
                let mut txn = db_connection.write_txn()?;
                let start = deletion_range.0.as_ref().map(|start| start.as_str());
                let end = deletion_range.1.as_ref().map(|end| end.as_str());
                log::debug!("deleting embeddings in range {:?}", &(start, end));
                db.delete_range(&mut txn, &(start, end))?;
                txn.commit()?;
            }

            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(embedded_files) = embedded_files.next().await {
                let mut txn = db_connection.write_txn()?;
                for (file, _) in &embedded_files {
                    log::debug!("saving embedding for file {:?}", file.path);
                    let key = db_key_for_path(&file.path);
                    db.put(&mut txn, &key, file)?;
                }
                txn.commit()?;

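                // Dropping the batch here, after the commit, releases the
                // IndexingEntryHandles and thereby marks these entries as no
                // longer being indexed.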
                drop(embedded_files);
                log::debug!("committed");
            }

            Ok(())
        })
    }

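    /// Returns the paths of all files currently stored in the index.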
    pub fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            let result = db
                .iter(&tx)?
                .map(|entry| Ok(entry?.1.path.clone()))
                .collect::<Result<Vec<Arc<Path>>>>();
            drop(tx);
            result
        })
    }

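    /// Returns the embedded chunks stored in the index for the given path.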
    pub fn chunks_for_path(
        &self,
        path: Arc<Path>,
        cx: &AppContext,
    ) -> Task<Result<Vec<EmbeddedChunk>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            Ok(db
                .get(&tx, &db_key_for_path(&path))?
                .ok_or_else(|| anyhow!("no such path"))?
                .chunks
                .clone())
        })
    }
}

struct ScanEntries {
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}

struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}

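/// A file that has been split into chunks but not yet embedded.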
pub struct ChunkedFile {
    pub path: Arc<Path>,
    pub mtime: Option<SystemTime>,
    pub handle: IndexingEntryHandle,
    pub text: String,
    pub chunks: Vec<Chunk>,
}

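/// The output of `EmbeddingIndex::embed_files`: a channel of embedded files
/// plus the task producing them.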
pub struct EmbedFiles {
    pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    pub task: Task<Result<()>>,
}

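/// A file's chunks and their embeddings, as stored in the database.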
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
    pub path: Arc<Path>,
    pub mtime: Option<SystemTime>,
    pub chunks: Vec<EmbeddedChunk>,
}

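/// A single chunk of a file along with its embedding.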
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    pub chunk: Chunk,
    pub embedding: Embedding,
}

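/// Derives the database key for a path by replacing path separators with null
/// bytes. Null bytes sort before every other byte, so a directory's entries
/// occupy a contiguous key range ahead of any sibling whose name extends the
/// directory's name.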
fn db_key_for_path(path: &Arc<Path>) -> String {
    path.to_string_lossy().replace('/', "\0")
}