1mod chunking;
2mod embedding;
3mod project_index_debug_view;
4
5use anyhow::{anyhow, Context as _, Result};
6use chunking::{chunk_text, Chunk};
7use collections::{Bound, HashMap, HashSet};
8pub use embedding::*;
9use fs::Fs;
10use futures::stream::StreamExt;
11use futures_batch::ChunksTimeoutStreamExt;
12use gpui::{
13 AppContext, AsyncAppContext, BorrowAppContext, Context, Entity, EntityId, EventEmitter, Global,
14 Model, ModelContext, Subscription, Task, WeakModel,
15};
16use heed::types::{SerdeBincode, Str};
17use language::LanguageRegistry;
18use parking_lot::Mutex;
19use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree, WorktreeId};
20use serde::{Deserialize, Serialize};
21use smol::channel;
22use std::{
23 cmp::Ordering,
24 future::Future,
25 iter,
26 num::NonZeroUsize,
27 ops::Range,
28 path::{Path, PathBuf},
29 sync::{Arc, Weak},
30 time::{Duration, SystemTime},
31};
32use util::ResultExt;
33use worktree::LocalSnapshot;
34
35pub use project_index_debug_view::ProjectIndexDebugView;
36
/// Global coordinator for semantic search: owns the shared embedding
/// provider and database environment, plus one `ProjectIndex` per open
/// project.
pub struct SemanticIndex {
    /// Produces embeddings for both indexed file chunks and search queries.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// LMDB environment (via heed); each worktree gets its own named
    /// database inside it (see `WorktreeIndex::load`).
    db_connection: heed::Env,
    /// Per-project indices, keyed by weak handle so entries can be removed
    /// when a project is released.
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}

impl Global for SemanticIndex {}
44
45impl SemanticIndex {
46 pub async fn new(
47 db_path: PathBuf,
48 embedding_provider: Arc<dyn EmbeddingProvider>,
49 cx: &mut AsyncAppContext,
50 ) -> Result<Self> {
51 let db_connection = cx
52 .background_executor()
53 .spawn(async move {
54 std::fs::create_dir_all(&db_path)?;
55 unsafe {
56 heed::EnvOpenOptions::new()
57 .map_size(1024 * 1024 * 1024)
58 .max_dbs(3000)
59 .open(db_path)
60 }
61 })
62 .await
63 .context("opening database connection")?;
64
65 Ok(SemanticIndex {
66 db_connection,
67 embedding_provider,
68 project_indices: HashMap::default(),
69 })
70 }
71
72 pub fn project_index(
73 &mut self,
74 project: Model<Project>,
75 cx: &mut AppContext,
76 ) -> Model<ProjectIndex> {
77 let project_weak = project.downgrade();
78 project.update(cx, move |_, cx| {
79 cx.on_release(move |_, cx| {
80 if cx.has_global::<SemanticIndex>() {
81 cx.update_global::<SemanticIndex, _>(|this, _| {
82 this.project_indices.remove(&project_weak);
83 })
84 }
85 })
86 .detach();
87 });
88
89 self.project_indices
90 .entry(project.downgrade())
91 .or_insert_with(|| {
92 cx.new_model(|cx| {
93 ProjectIndex::new(
94 project,
95 self.db_connection.clone(),
96 self.embedding_provider.clone(),
97 cx,
98 )
99 })
100 })
101 .clone()
102 }
103}
104
/// Per-project semantic index: owns one `WorktreeIndex` per visible local
/// worktree and aggregates their progress into a single `Status`.
pub struct ProjectIndex {
    /// Shared LMDB environment handed to each worktree index.
    db_connection: heed::Env,
    project: WeakModel<Project>,
    /// Keyed by worktree entity id; values may still be loading.
    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    /// Last status emitted, used to suppress duplicate events.
    last_status: Status,
    /// Cloned into each worktree index so it can signal indexing progress.
    status_tx: channel::Sender<()>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Re-runs `update_status` whenever `status_tx` is signaled.
    _maintain_status: Task<()>,
    /// Subscription to project events (worktree added/removed).
    _subscription: Subscription,
}
117
/// A worktree's index, either still being created or ready for use.
enum WorktreeIndexHandle {
    /// Index creation in flight; the task replaces this entry with
    /// `Loaded` (or removes it) when done.
    Loading { _task: Task<Result<()>> },
    Loaded { index: Model<WorktreeIndex> },
}
122
impl ProjectIndex {
    /// Builds an index for `project`: captures its language registry and
    /// filesystem, subscribes to project events, and spawns a task that
    /// recomputes `last_status` whenever a worktree index signals progress
    /// on `status_tx`.
    fn new(
        project: Model<Project>,
        db_connection: heed::Env,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let language_registry = project.read(cx).languages().clone();
        let fs = project.read(cx).fs().clone();
        let (status_tx, mut status_rx) = channel::unbounded();
        let mut this = ProjectIndex {
            db_connection,
            project: project.downgrade(),
            worktree_indices: HashMap::default(),
            language_registry,
            fs,
            status_tx,
            last_status: Status::Idle,
            embedding_provider,
            _subscription: cx.subscribe(&project, Self::handle_project_event),
            _maintain_status: cx.spawn(|this, mut cx| async move {
                // Each `()` on the channel is a hint that indexing state
                // changed; stop once this model has been released.
                while status_rx.next().await.is_some() {
                    if this
                        .update(&mut cx, |this, cx| this.update_status(cx))
                        .is_err()
                    {
                        break;
                    }
                }
            }),
        };
        this.update_worktree_indices(cx);
        this
    }

    /// The most recently computed indexing status.
    pub fn status(&self) -> Status {
        self.last_status
    }

    /// Weak handle to the project this index belongs to.
    pub fn project(&self) -> WeakModel<Project> {
        self.project.clone()
    }

    /// The filesystem used to read files for indexing.
    pub fn fs(&self) -> Arc<dyn Fs> {
        self.fs.clone()
    }

    /// Keeps `worktree_indices` in sync as worktrees are added to or
    /// removed from the project.
    fn handle_project_event(
        &mut self,
        _: Model<Project>,
        event: &project::Event,
        cx: &mut ModelContext<Self>,
    ) {
        match event {
            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                self.update_worktree_indices(cx);
            }
            _ => {}
        }
    }

    /// Reconciles `worktree_indices` with the project's current set of
    /// visible local worktrees: drops indices for removed worktrees and
    /// starts loading an index for each new one.
    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
        let Some(project) = self.project.upgrade() else {
            return;
        };

        // Only local worktrees are indexed; remote ones have no local
        // filesystem to read from here.
        let worktrees = project
            .read(cx)
            .visible_worktrees(cx)
            .filter_map(|worktree| {
                if worktree.read(cx).is_local() {
                    Some((worktree.entity_id(), worktree))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        self.worktree_indices
            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
        for (worktree_id, worktree) in worktrees {
            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
                let worktree_index = WorktreeIndex::load(
                    worktree.clone(),
                    self.db_connection.clone(),
                    self.language_registry.clone(),
                    self.fs.clone(),
                    self.status_tx.clone(),
                    self.embedding_provider.clone(),
                    cx,
                );

                // Replace the `Loading` placeholder with a `Loaded` handle
                // once the index is ready, or drop the entry entirely if
                // loading failed; either way, recompute the status.
                let load_worktree = cx.spawn(|this, mut cx| async move {
                    if let Some(worktree_index) = worktree_index.await.log_err() {
                        this.update(&mut cx, |this, _| {
                            this.worktree_indices.insert(
                                worktree_id,
                                WorktreeIndexHandle::Loaded {
                                    index: worktree_index,
                                },
                            );
                        })?;
                    } else {
                        this.update(&mut cx, |this, _cx| {
                            this.worktree_indices.remove(&worktree_id)
                        })?;
                    }

                    this.update(&mut cx, |this, cx| this.update_status(cx))
                });

                WorktreeIndexHandle::Loading {
                    _task: load_worktree,
                }
            });
        }

        self.update_status(cx);
    }

    /// Recomputes the aggregate status from all worktree indices and emits
    /// it if it changed. Any still-loading worktree dominates; otherwise
    /// the number of entries being indexed determines Scanning vs Idle.
    fn update_status(&mut self, cx: &mut ModelContext<Self>) {
        let mut indexing_count = 0;
        let mut any_loading = false;

        for index in self.worktree_indices.values_mut() {
            match index {
                WorktreeIndexHandle::Loading { .. } => {
                    any_loading = true;
                    break;
                }
                WorktreeIndexHandle::Loaded { index, .. } => {
                    indexing_count += index.read(cx).entry_ids_being_indexed.len();
                }
            }
        }

        let status = if any_loading {
            Status::Loading
        } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
            Status::Scanning { remaining_count }
        } else {
            Status::Idle
        };

        // Only emit on transitions to avoid redundant events.
        if status != self.last_status {
            self.last_status = status;
            cx.emit(status);
        }
    }

    /// Searches all loaded worktree indices for the `limit` stored chunks
    /// whose embeddings are most similar to `query`.
    ///
    /// Fan-out/fan-in: one background task per worktree streams stored
    /// chunks into a bounded channel, while one worker per CPU scores them
    /// against the query embedding and keeps a local top-`limit` list; the
    /// per-worker lists are then merged, sorted, and truncated.
    pub fn search(
        &self,
        query: String,
        limit: usize,
        cx: &AppContext,
    ) -> Task<Result<Vec<SearchResult>>> {
        let (chunks_tx, chunks_rx) = channel::bounded(1024);
        let mut worktree_scan_tasks = Vec::new();
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                let chunks_tx = chunks_tx.clone();
                index.read_with(cx, |index, cx| {
                    let worktree_id = index.worktree.read(cx).id();
                    let db_connection = index.db_connection.clone();
                    let db = index.db;
                    worktree_scan_tasks.push(cx.background_executor().spawn({
                        async move {
                            let txn = db_connection
                                .read_txn()
                                .context("failed to create read transaction")?;
                            let db_entries = db.iter(&txn).context("failed to iterate database")?;
                            for db_entry in db_entries {
                                let (_key, db_embedded_file) = db_entry?;
                                for chunk in db_embedded_file.chunks {
                                    chunks_tx
                                        .send((worktree_id, db_embedded_file.path.clone(), chunk))
                                        .await?;
                                }
                            }
                            anyhow::Ok(())
                        }
                    }));
                })
            }
        }
        // Drop our copy of the sender so the workers' `recv` loops end once
        // every scan task has finished sending.
        drop(chunks_tx);

        let project = self.project.clone();
        let embedding_provider = self.embedding_provider.clone();
        cx.spawn(|cx| async move {
            #[cfg(debug_assertions)]
            let embedding_query_start = std::time::Instant::now();
            log::info!("Searching for {query}");

            let query_embeddings = embedding_provider
                .embed(&[TextToEmbed::new(&query)])
                .await?;
            let query_embedding = query_embeddings
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("no embedding for query"))?;

            let mut results_by_worker = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                results_by_worker.push(Vec::<WorktreeSearchResult>::new());
            }

            #[cfg(debug_assertions)]
            let search_start = std::time::Instant::now();

            cx.background_executor()
                .scoped(|cx| {
                    for results in results_by_worker.iter_mut() {
                        cx.spawn(async {
                            while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
                                let score = chunk.embedding.similarity(&query_embedding);
                                // Comparator is reversed (`score` vs probe)
                                // so each worker's list stays sorted
                                // best-first; ties insert at the probe.
                                let ix = match results.binary_search_by(|probe| {
                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
                                }) {
                                    Ok(ix) | Err(ix) => ix,
                                };
                                results.insert(
                                    ix,
                                    WorktreeSearchResult {
                                        worktree_id,
                                        path: path.clone(),
                                        range: chunk.chunk.range.clone(),
                                        score,
                                    },
                                );
                                // Keep only the local top `limit`.
                                results.truncate(limit);
                            }
                        });
                    }
                })
                .await;

            // Surface any scan-side errors only after the workers have
            // drained the channel.
            futures::future::try_join_all(worktree_scan_tasks).await?;

            project.read_with(&cx, |project, cx| {
                let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
                for worker_results in results_by_worker {
                    // Drop results whose worktree has gone away since the
                    // scan started.
                    search_results.extend(worker_results.into_iter().filter_map(|result| {
                        Some(SearchResult {
                            worktree: project.worktree_for_id(result.worktree_id, cx)?,
                            path: result.path,
                            range: result.range,
                            score: result.score,
                        })
                    }));
                }
                // Merge the per-worker lists: highest score first.
                search_results.sort_unstable_by(|a, b| {
                    b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
                });
                search_results.truncate(limit);

                #[cfg(debug_assertions)]
                {
                    let search_elapsed = search_start.elapsed();
                    log::debug!(
                        "searched {} entries in {:?}",
                        search_results.len(),
                        search_elapsed
                    );
                    let embedding_query_elapsed = embedding_query_start.elapsed();
                    log::debug!("embedding query took {:?}", embedding_query_elapsed);
                }

                search_results
            })
        })
    }

    /// Total number of indexed files across all loaded worktree indices.
    #[cfg(test)]
    pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
        let mut result = 0;
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                result += index.read(cx).path_count()?;
            }
        }
        Ok(result)
    }

    /// Finds the loaded index for `worktree_id`, if any.
    pub(crate) fn worktree_index(
        &self,
        worktree_id: WorktreeId,
        cx: &AppContext,
    ) -> Option<Model<WorktreeIndex>> {
        for index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = index {
                if index.read(cx).worktree.read(cx).id() == worktree_id {
                    return Some(index.clone());
                }
            }
        }
        None
    }

    /// All loaded worktree indices, sorted by worktree id for stable output.
    pub(crate) fn worktree_indices(&self, cx: &AppContext) -> Vec<Model<WorktreeIndex>> {
        let mut result = self
            .worktree_indices
            .values()
            .filter_map(|index| {
                if let WorktreeIndexHandle::Loaded { index, .. } = index {
                    Some(index.clone())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        result.sort_by_key(|index| index.read(cx).worktree.read(cx).id());
        result
    }
}
438
/// A search hit resolved against a live worktree, ready for the UI.
pub struct SearchResult {
    pub worktree: Model<Worktree>,
    /// Worktree-relative path of the matching file.
    pub path: Arc<Path>,
    /// Byte range of the matching chunk within the file's text.
    pub range: Range<usize>,
    /// Similarity of the chunk's embedding to the query embedding.
    pub score: f32,
}
445
/// An intermediate search hit, identified by worktree id rather than a
/// model handle; converted into `SearchResult` once the worktree is looked
/// up on the main thread.
pub struct WorktreeSearchResult {
    pub worktree_id: WorktreeId,
    pub path: Arc<Path>,
    pub range: Range<usize>,
    pub score: f32,
}
452
/// Aggregate indexing state of a project, emitted as a `ProjectIndex` event
/// whenever it changes.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum Status {
    /// Nothing is loading or being indexed.
    Idle,
    /// At least one worktree index is still being created.
    Loading,
    /// All indices are loaded; `remaining_count` entries are being indexed.
    Scanning { remaining_count: NonZeroUsize },
}

impl EventEmitter<Status> for ProjectIndex {}
461
/// Index for a single local worktree: watches for file changes and keeps a
/// per-worktree embeddings database up to date.
struct WorktreeIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    /// Maps db keys (see `db_key_for_path`) to a file's embedded chunks.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Entries currently in the indexing pipeline; read by
    /// `ProjectIndex::update_status`.
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
    /// Long-running task driving initial and incremental indexing.
    _index_entries: Task<Result<()>>,
    /// Subscription forwarding worktree entry updates into the task above.
    _subscription: Subscription,
}
473
impl WorktreeIndex {
    /// Opens (or creates) the worktree's database — named by its absolute
    /// path — inside the shared environment, then constructs the index
    /// model.
    pub fn load(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        status_tx: channel::Sender<()>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let worktree_abs_path = worktree.read(cx).abs_path();
        cx.spawn(|mut cx| async move {
            // Database creation needs a write transaction, so do it on the
            // background executor.
            let db = cx
                .background_executor()
                .spawn({
                    let db_connection = db_connection.clone();
                    async move {
                        let mut txn = db_connection.write_txn()?;
                        let db_name = worktree_abs_path.to_string_lossy();
                        let db = db_connection.create_database(&mut txn, Some(&db_name))?;
                        txn.commit()?;
                        anyhow::Ok(db)
                    }
                })
                .await?;
            cx.new_model(|cx| {
                Self::new(
                    worktree,
                    db_connection,
                    db,
                    status_tx,
                    language_registry,
                    fs,
                    embedding_provider,
                    cx,
                )
            })
        })
    }

    /// Wires the worktree's `UpdatedEntries` events into a channel consumed
    /// by the `index_entries` task.
    #[allow(clippy::too_many_arguments)]
    fn new(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        status: channel::Sender<()>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
            if let worktree::Event::UpdatedEntries(update) = event {
                // The channel is unbounded, so try_send only fails if the
                // receiver is gone; that's fine to ignore.
                _ = updated_entries_tx.try_send(update.clone());
            }
        });

        Self {
            db_connection,
            db,
            worktree,
            language_registry,
            fs,
            embedding_provider,
            entry_ids_being_indexed: Arc::new(IndexingEntrySet::new(status)),
            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
            _subscription,
        }
    }

    /// Indexing main loop: one full reconciliation against the database,
    /// then an incremental pass for each batch of updated entries.
    async fn index_entries(
        this: WeakModel<Self>,
        updated_entries: channel::Receiver<UpdatedEntriesSet>,
        mut cx: AsyncAppContext,
    ) -> Result<()> {
        let index = this.update(&mut cx, |this, cx| this.index_entries_changed_on_disk(cx))?;
        // Errors are logged rather than propagated so one failed pass
        // doesn't stop future indexing.
        index.await.log_err();

        while let Ok(updated_entries) = updated_entries.recv().await {
            let index = this.update(&mut cx, |this, cx| {
                this.index_updated_entries(updated_entries, cx)
            })?;
            index.await.log_err();
        }

        Ok(())
    }

    /// Runs the full scan→chunk→embed→persist pipeline, reconciling every
    /// file in the worktree against the database.
    fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            // All four stages run concurrently, connected by channels.
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    /// Same pipeline as `index_entries_changed_on_disk`, but seeded only
    /// with the entries reported changed by the worktree.
    fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    /// Merge-joins the worktree's files against the database's key-sorted
    /// entries, emitting files whose mtime changed (to be re-embedded) and
    /// key ranges that no longer exist on disk (to be deleted).
    ///
    /// NOTE(review): this relies on `worktree.files` yielding paths in the
    /// same order as their db keys sort (see `db_key_for_path`) — confirm
    /// against the worktree crate's traversal order.
    fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            // Accumulates a run of consecutive stale db keys so they can be
            // deleted with a single range.
            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                let entry_db_key = db_key_for_path(&entry.path);

                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                // Key exists in the db but not on disk:
                                // extend (or start) the deletion run.
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                // Flush any pending deletions before
                                // recording this file's stored mtime.
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                break;
                            }
                        },
                        // An iterator error aborts the scan; consume the
                        // peeked entry to take ownership of the error.
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
                    }
                }

                // A missing or differing mtime means the file must be
                // (re-)indexed.
                if entry.mtime != saved_mtime {
                    let handle = entries_being_indexed.insert(entry.id);
                    updated_entries_tx.send((entry.clone(), handle)).await?;
                }
            }

            // Everything remaining in the db sorts after the last on-disk
            // file, so it was deleted.
            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

    /// Like `scan_entries`, but driven by an explicit change set: added or
    /// updated files are queued for indexing, removed paths become
    /// single-key deletion ranges.
    fn scan_updated_entries(
        &self,
        worktree: LocalSnapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            if entry.is_file() {
                                let handle = entries_being_indexed.insert(entry.id);
                                updated_entries_tx.send((entry.clone(), handle)).await?;
                            }
                        }
                    }
                    project::PathChange::Removed => {
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

    /// Loads each updated file and splits it into chunks, using one worker
    /// per CPU. Unreadable files are logged and skipped.
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok((entry, handle)) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                let Some(text) = fs
                                    .load(&entry_abs_path)
                                    .await
                                    .with_context(|| {
                                        format!("failed to read path {entry_abs_path:?}")
                                    })
                                    .log_err()
                                else {
                                    continue;
                                };
                                // Language (if detected) informs how the
                                // text is chunked.
                                let language = language_registry
                                    .language_for_file_path(&entry.path)
                                    .await
                                    .ok();
                                let chunked_file = ChunkedFile {
                                    chunks: chunk_text(&text, language.as_ref(), &entry.path),
                                    handle,
                                    path: entry.path,
                                    mtime: entry.mtime,
                                    text,
                                };

                                // A closed receiver means the pipeline shut
                                // down; stop this worker.
                                if chunked_files_tx.send(chunked_file).await.is_err() {
                                    return;
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }

    /// Embeds chunked files in provider-sized batches.
    fn embed_files(
        embedding_provider: Arc<dyn EmbeddingProvider>,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &AppContext,
    ) -> EmbedFiles {
        let embedding_provider = embedding_provider.clone();
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_executor().spawn(async move {
            // Group incoming files into batches of up to 512, flushing
            // after 2s of inactivity so a trickle of files still proceeds.
            let mut chunked_file_batches =
                chunked_files.chunks_timeout(512, Duration::from_secs(2));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // View the batch of files as a vec of chunks
                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
                // Once those are done, reassemble them back into the files in which they belong
                // If any embeddings fail for a file, the entire file is discarded

                let chunks: Vec<TextToEmbed> = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect::<Vec<_>>();

                // One `Option<Embedding>` per chunk, in order; `None`
                // marks chunks whose batch failed.
                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    if let Some(batch_embeddings) =
                        embedding_provider.embed(embedding_batch).await.log_err()
                    {
                        if batch_embeddings.len() == embedding_batch.len() {
                            embeddings.extend(batch_embeddings.into_iter().map(Some));
                            continue;
                        }
                        log::error!(
                            "embedding provider returned unexpected embedding count {}, expected {}",
                            batch_embeddings.len(), embedding_batch.len()
                        );
                    }

                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
                }

                // Reassemble embeddings into their files; a file with any
                // failed chunk is dropped entirely.
                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let mut embedded_file = EmbeddedFile {
                        path: chunked_file.path,
                        mtime: chunked_file.mtime,
                        chunks: Vec::new(),
                    };

                    let mut embedded_all_chunks = true;
                    for (chunk, embedding) in
                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
                    {
                        if let Some(embedding) = embedding {
                            embedded_file
                                .chunks
                                .push(EmbeddedChunk { chunk, embedding });
                        } else {
                            embedded_all_chunks = false;
                        }
                    }

                    if embedded_all_chunks {
                        embedded_files_tx
                            .send((embedded_file, chunked_file.handle))
                            .await?;
                    }
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }

    /// Applies the pipeline's output to the database: first all deletion
    /// ranges, then batched upserts of embedded files.
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            while let Some(deletion_range) = deleted_entry_ranges.next().await {
                let mut txn = db_connection.write_txn()?;
                let start = deletion_range.0.as_ref().map(|start| start.as_str());
                let end = deletion_range.1.as_ref().map(|end| end.as_str());
                log::debug!("deleting embeddings in range {:?}", &(start, end));
                db.delete_range(&mut txn, &(start, end))?;
                txn.commit()?;
            }

            // Batch writes (up to 4096 files, or a 2s lull) to amortize
            // transaction overhead.
            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(embedded_files) = embedded_files.next().await {
                let mut txn = db_connection.write_txn()?;
                for (file, _) in &embedded_files {
                    log::debug!("saving embedding for file {:?}", file.path);
                    let key = db_key_for_path(&file.path);
                    db.put(&mut txn, &key, file)?;
                }
                txn.commit()?;

                // Dropping the batch releases its IndexingEntryHandles only
                // after the commit, so progress reporting never runs ahead
                // of persistence.
                drop(embedded_files);
                log::debug!("committed");
            }

            Ok(())
        })
    }

    /// All file paths currently stored in this worktree's database.
    fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            let result = db
                .iter(&tx)?
                .map(|entry| Ok(entry?.1.path.clone()))
                .collect::<Result<Vec<Arc<Path>>>>();
            drop(tx);
            result
        })
    }

    /// The stored embedded chunks for `path`, or an error if the path is
    /// not in the database.
    fn chunks_for_path(
        &self,
        path: Arc<Path>,
        cx: &AppContext,
    ) -> Task<Result<Vec<EmbeddedChunk>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            Ok(db
                .get(&tx, &db_key_for_path(&path))?
                .ok_or_else(|| anyhow!("no such path"))?
                .chunks
                .clone())
        })
    }

    /// Number of files stored in this worktree's database.
    #[cfg(test)]
    fn path_count(&self) -> Result<u64> {
        let txn = self
            .db_connection
            .read_txn()
            .context("failed to create read transaction")?;
        Ok(self.db.len(&txn)?)
    }
}
931
/// Output of the scan stage: files to (re-)index, key ranges to delete,
/// and the task driving the scan.
struct ScanEntries {
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}
937
/// Output of the chunking stage: chunked files plus the task producing them.
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}
942
/// A file split into chunks, awaiting embedding.
struct ChunkedFile {
    pub path: Arc<Path>,
    pub mtime: Option<SystemTime>,
    /// Keeps the entry marked as "being indexed" until persisted/dropped.
    pub handle: IndexingEntryHandle,
    /// Full file text; chunk ranges index into this.
    pub text: String,
    pub chunks: Vec<Chunk>,
}
950
/// Output of the embedding stage: fully embedded files plus the task
/// producing them.
struct EmbedFiles {
    files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    task: Task<Result<()>>,
}
955
/// The database value type: a file's chunks with their embeddings, plus the
/// mtime used to detect staleness on rescan.
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedFile {
    path: Arc<Path>,
    mtime: Option<SystemTime>,
    chunks: Vec<EmbeddedChunk>,
}
962
/// A text chunk paired with its embedding.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct EmbeddedChunk {
    chunk: Chunk,
    embedding: Embedding,
}
968
/// The set of entries that are currently being indexed.
struct IndexingEntrySet {
    entry_ids: Mutex<HashSet<ProjectEntryId>>,
    /// Signaled whenever the set changes, prompting a status recompute.
    tx: channel::Sender<()>,
}
974
/// When dropped, removes the entry from the set of entries that are being indexed.
#[derive(Clone)]
struct IndexingEntryHandle {
    entry_id: ProjectEntryId,
    /// Weak so a lingering handle cannot keep the set alive.
    set: Weak<IndexingEntrySet>,
}
981
982impl IndexingEntrySet {
983 fn new(tx: channel::Sender<()>) -> Self {
984 Self {
985 entry_ids: Default::default(),
986 tx,
987 }
988 }
989
990 fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
991 self.entry_ids.lock().insert(entry_id);
992 self.tx.send_blocking(()).ok();
993 IndexingEntryHandle {
994 entry_id,
995 set: Arc::downgrade(self),
996 }
997 }
998
999 pub fn len(&self) -> usize {
1000 self.entry_ids.lock().len()
1001 }
1002}
1003
1004impl Drop for IndexingEntryHandle {
1005 fn drop(&mut self) {
1006 if let Some(set) = self.set.upgrade() {
1007 set.tx.send_blocking(()).ok();
1008 set.entry_ids.lock().remove(&self.entry_id);
1009 }
1010 }
1011}
1012
/// Converts a worktree-relative path into the key under which the file's
/// embeddings are stored in the database.
///
/// `/` is replaced with NUL, which sorts before every other byte —
/// presumably so that all keys within a directory sort contiguously;
/// TODO confirm against the scan's merge-join ordering assumptions.
///
/// Takes `&Path` rather than `&Arc<Path>` (the previous signature): more
/// general, and existing `&Arc<Path>` call sites still work via deref
/// coercion.
fn db_key_for_path(path: &Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}
1016
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, FutureExt};
    use gpui::TestAppContext;
    use language::language_settings::AllLanguageSettings;
    use project::Project;
    use settings::SettingsStore;
    use std::{future, path::Path, sync::Arc};

    /// Installs the settings, language, and project globals that the
    /// semantic index depends on.
    fn init_test(cx: &mut TestAppContext) {
        _ = cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

    /// An `EmbeddingProvider` backed by a deterministic closure, so tests
    /// can predict (and recompute) every embedding.
    pub struct TestEmbeddingProvider {
        batch_size: usize,
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            // Collecting into Result short-circuits on the first chunk the
            // closure rejects, failing the whole batch like a real provider.
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticIndex::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // if the text contains garbage, give it a 1 in the first dimension
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let project_path = Path::new("./fixture");

        let project = cx
            .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
            .await;

        cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
        });

        let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));

        // Wait for the initial scan to index at least one file.
        while project_index
            .read_with(cx, |index, cx| index.path_count(cx))
            .unwrap()
            == 0
        {
            project_index.next_event(cx).await;
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query.into(), 4, cx)
            })
            .await
            .unwrap();

        assert!(results.len() > 1, "should have found some results");

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find the result whose score shows a strong match (> 0.9) with the query.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "needle.md");

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_executor()
                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // Any chunk containing 'g' fails to embed; per `embed_files`, one
        // failed chunk discards the whole file.
        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                // NOTE(review): `'a'..'z'` excludes 'z'; harmless here since
                // both expected and actual values use this same function.
                Ok(Embedding::new(
                    ('a'..'z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        // test1.md contains 'g', so it should be discarded entirely.
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        // test2.md has no 'g'; every chunk should embed successfully.
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| WorktreeIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        let mut embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
            embedded_files.push(embedded_file);
        }

        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }
}