1mod chunking;
2mod embedding;
3mod project_index_debug_view;
4
5use anyhow::{anyhow, Context as _, Result};
6use chunking::{chunk_text, Chunk};
7use collections::{Bound, HashMap, HashSet};
8pub use embedding::*;
9use fs::Fs;
10use futures::stream::StreamExt;
11use futures_batch::ChunksTimeoutStreamExt;
12use gpui::{
13 AppContext, AsyncAppContext, BorrowAppContext, Context, Entity, EntityId, EventEmitter, Global,
14 Model, ModelContext, Subscription, Task, WeakModel,
15};
16use heed::types::{SerdeBincode, Str};
17use language::LanguageRegistry;
18use parking_lot::Mutex;
19use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree, WorktreeId};
20use serde::{Deserialize, Serialize};
21use smol::channel;
22use std::{
23 cmp::Ordering,
24 future::Future,
25 iter,
26 num::NonZeroUsize,
27 ops::Range,
28 path::{Path, PathBuf},
29 sync::{Arc, Weak},
30 time::{Duration, SystemTime},
31};
32use util::ResultExt;
33use worktree::LocalSnapshot;
34
35pub use project_index_debug_view::ProjectIndexDebugView;
36
/// Application-wide entry point for semantic search. Owns the LMDB
/// environment and hands out one [`ProjectIndex`] per open project.
pub struct SemanticIndex {
    // Provider used to embed both indexed chunks and search queries.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Shared heed/LMDB environment; each worktree gets its own database in it.
    db_connection: heed::Env,
    // Keyed by weak handle so a dropped project doesn't keep its index alive.
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}
42
43impl Global for SemanticIndex {}
44
45impl SemanticIndex {
46 pub async fn new(
47 db_path: PathBuf,
48 embedding_provider: Arc<dyn EmbeddingProvider>,
49 cx: &mut AsyncAppContext,
50 ) -> Result<Self> {
51 let db_connection = cx
52 .background_executor()
53 .spawn(async move {
54 std::fs::create_dir_all(&db_path)?;
55 unsafe {
56 heed::EnvOpenOptions::new()
57 .map_size(1024 * 1024 * 1024)
58 .max_dbs(3000)
59 .open(db_path)
60 }
61 })
62 .await
63 .context("opening database connection")?;
64
65 Ok(SemanticIndex {
66 db_connection,
67 embedding_provider,
68 project_indices: HashMap::default(),
69 })
70 }
71
72 pub fn project_index(
73 &mut self,
74 project: Model<Project>,
75 cx: &mut AppContext,
76 ) -> Model<ProjectIndex> {
77 let project_weak = project.downgrade();
78 project.update(cx, move |_, cx| {
79 cx.on_release(move |_, cx| {
80 if cx.has_global::<SemanticIndex>() {
81 cx.update_global::<SemanticIndex, _>(|this, _| {
82 this.project_indices.remove(&project_weak);
83 })
84 }
85 })
86 .detach();
87 });
88
89 self.project_indices
90 .entry(project.downgrade())
91 .or_insert_with(|| {
92 cx.new_model(|cx| {
93 ProjectIndex::new(
94 project,
95 self.db_connection.clone(),
96 self.embedding_provider.clone(),
97 cx,
98 )
99 })
100 })
101 .clone()
102 }
103}
104
/// Per-project aggregation of worktree indices. Tracks each visible local
/// worktree, relays indexing progress as [`Status`] events, and fans search
/// queries out across all loaded worktree databases.
pub struct ProjectIndex {
    db_connection: heed::Env,
    project: WeakModel<Project>,
    // One entry per visible local worktree, keyed by the worktree's entity id.
    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    // Last status we emitted; used to dedupe `Status` events.
    last_status: Status,
    // Signaled by worktree indices whenever indexing progress changes.
    status_tx: channel::Sender<()>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Drains `status_tx` notifications and recomputes `last_status`.
    _maintain_status: Task<()>,
    // Keeps us subscribed to worktree add/remove events on the project.
    _subscription: Subscription,
}
117
/// State of a single worktree's index: still opening its database
/// (`Loading`, holding the task alive) or ready to serve queries (`Loaded`).
enum WorktreeIndexHandle {
    Loading { _task: Task<Result<()>> },
    Loaded { index: Model<WorktreeIndex> },
}
122
impl ProjectIndex {
    /// Creates the index, subscribes to project worktree changes, spawns the
    /// status-maintenance loop, and kicks off indexing of current worktrees.
    fn new(
        project: Model<Project>,
        db_connection: heed::Env,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let language_registry = project.read(cx).languages().clone();
        let fs = project.read(cx).fs().clone();
        let (status_tx, mut status_rx) = channel::unbounded();
        let mut this = ProjectIndex {
            db_connection,
            project: project.downgrade(),
            worktree_indices: HashMap::default(),
            language_registry,
            fs,
            status_tx,
            last_status: Status::Idle,
            embedding_provider,
            _subscription: cx.subscribe(&project, Self::handle_project_event),
            // Recompute status whenever any worktree signals progress; exit
            // once this model is released (`update` fails).
            _maintain_status: cx.spawn(|this, mut cx| async move {
                while status_rx.next().await.is_some() {
                    if this
                        .update(&mut cx, |this, cx| this.update_status(cx))
                        .is_err()
                    {
                        break;
                    }
                }
            }),
        };
        this.update_worktree_indices(cx);
        this
    }

    /// The most recently computed indexing status.
    pub fn status(&self) -> Status {
        self.last_status
    }

    pub fn project(&self) -> WeakModel<Project> {
        self.project.clone()
    }

    /// Keeps `worktree_indices` in sync as worktrees come and go.
    fn handle_project_event(
        &mut self,
        _: Model<Project>,
        event: &project::Event,
        cx: &mut ModelContext<Self>,
    ) {
        match event {
            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                self.update_worktree_indices(cx);
            }
            _ => {}
        }
    }

    /// Reconciles `worktree_indices` with the project's current set of
    /// visible *local* worktrees: drops indices for removed worktrees and
    /// starts loading indices for new ones.
    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
        let Some(project) = self.project.upgrade() else {
            return;
        };

        // Only local worktrees are indexed (remote worktrees have no
        // filesystem access here).
        let worktrees = project
            .read(cx)
            .visible_worktrees(cx)
            .filter_map(|worktree| {
                if worktree.read(cx).is_local() {
                    Some((worktree.entity_id(), worktree))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        // Drop indices whose worktrees are gone, then add missing ones.
        self.worktree_indices
            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
        for (worktree_id, worktree) in worktrees {
            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
                let worktree_index = WorktreeIndex::load(
                    worktree.clone(),
                    self.db_connection.clone(),
                    self.language_registry.clone(),
                    self.fs.clone(),
                    self.status_tx.clone(),
                    self.embedding_provider.clone(),
                    cx,
                );

                // Transition Loading -> Loaded once the database is open,
                // or remove the entry entirely if loading failed.
                let load_worktree = cx.spawn(|this, mut cx| async move {
                    if let Some(worktree_index) = worktree_index.await.log_err() {
                        this.update(&mut cx, |this, _| {
                            this.worktree_indices.insert(
                                worktree_id,
                                WorktreeIndexHandle::Loaded {
                                    index: worktree_index,
                                },
                            );
                        })?;
                    } else {
                        this.update(&mut cx, |this, _cx| {
                            this.worktree_indices.remove(&worktree_id)
                        })?;
                    }

                    this.update(&mut cx, |this, cx| this.update_status(cx))
                });

                WorktreeIndexHandle::Loading {
                    _task: load_worktree,
                }
            });
        }

        self.update_status(cx);
    }

    /// Derives the aggregate [`Status`] from all worktree indices and emits
    /// it as an event if it changed since last time.
    fn update_status(&mut self, cx: &mut ModelContext<Self>) {
        let mut indexing_count = 0;
        let mut any_loading = false;

        for index in self.worktree_indices.values_mut() {
            match index {
                WorktreeIndexHandle::Loading { .. } => {
                    // Loading dominates every other state, so stop early.
                    any_loading = true;
                    break;
                }
                WorktreeIndexHandle::Loaded { index, .. } => {
                    indexing_count += index.read(cx).entry_ids_being_indexed.len();
                }
            }
        }

        let status = if any_loading {
            Status::Loading
        } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
            Status::Scanning { remaining_count }
        } else {
            Status::Idle
        };

        if status != self.last_status {
            self.last_status = status;
            cx.emit(status);
        }
    }

    /// Searches every loaded worktree index for the `limit` chunks most
    /// similar to `query`.
    ///
    /// Pipeline: one background task per worktree streams stored chunks into
    /// a bounded channel; one worker per CPU scores chunks against the query
    /// embedding, keeping a per-worker top-`limit` list; results are then
    /// merged, sorted by descending score, and truncated to `limit`.
    pub fn search(
        &self,
        query: String,
        limit: usize,
        cx: &AppContext,
    ) -> Task<Result<Vec<SearchResult>>> {
        let (chunks_tx, chunks_rx) = channel::bounded(1024);
        let mut worktree_scan_tasks = Vec::new();
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                let chunks_tx = chunks_tx.clone();
                index.read_with(cx, |index, cx| {
                    let worktree_id = index.worktree.read(cx).id();
                    let db_connection = index.db_connection.clone();
                    let db = index.db;
                    worktree_scan_tasks.push(cx.background_executor().spawn({
                        async move {
                            let txn = db_connection
                                .read_txn()
                                .context("failed to create read transaction")?;
                            let db_entries = db.iter(&txn).context("failed to iterate database")?;
                            for db_entry in db_entries {
                                let (_key, db_embedded_file) = db_entry?;
                                for chunk in db_embedded_file.chunks {
                                    chunks_tx
                                        .send((worktree_id, db_embedded_file.path.clone(), chunk))
                                        .await?;
                                }
                            }
                            anyhow::Ok(())
                        }
                    }));
                })
            }
        }
        // Drop our own sender so the workers' `recv` loops terminate once all
        // scan tasks have finished sending.
        drop(chunks_tx);

        let project = self.project.clone();
        let embedding_provider = self.embedding_provider.clone();
        cx.spawn(|cx| async move {
            #[cfg(debug_assertions)]
            let embedding_query_start = std::time::Instant::now();
            log::info!("Searching for {query}");

            let query_embeddings = embedding_provider
                .embed(&[TextToEmbed::new(&query)])
                .await?;
            let query_embedding = query_embeddings
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("no embedding for query"))?;

            let mut results_by_worker = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                results_by_worker.push(Vec::<WorktreeSearchResult>::new());
            }

            #[cfg(debug_assertions)]
            let search_start = std::time::Instant::now();

            cx.background_executor()
                .scoped(|cx| {
                    for results in results_by_worker.iter_mut() {
                        cx.spawn(async {
                            while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
                                let score = chunk.embedding.similarity(&query_embedding);
                                // Each worker keeps its list sorted by
                                // descending score and capped at `limit`.
                                let ix = match results.binary_search_by(|probe| {
                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
                                }) {
                                    Ok(ix) | Err(ix) => ix,
                                };
                                results.insert(
                                    ix,
                                    WorktreeSearchResult {
                                        worktree_id,
                                        path: path.clone(),
                                        range: chunk.chunk.range.clone(),
                                        score,
                                    },
                                );
                                results.truncate(limit);
                            }
                        });
                    }
                })
                .await;

            // Surface any error from the scan tasks (workers already drained
            // the channel by this point).
            futures::future::try_join_all(worktree_scan_tasks).await?;

            project.read_with(&cx, |project, cx| {
                let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
                for worker_results in results_by_worker {
                    search_results.extend(worker_results.into_iter().filter_map(|result| {
                        Some(SearchResult {
                            worktree: project.worktree_for_id(result.worktree_id, cx)?,
                            path: result.path,
                            range: result.range,
                            score: result.score,
                        })
                    }));
                }
                search_results.sort_unstable_by(|a, b| {
                    b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
                });
                search_results.truncate(limit);

                #[cfg(debug_assertions)]
                {
                    let search_elapsed = search_start.elapsed();
                    log::debug!(
                        "searched {} entries in {:?}",
                        search_results.len(),
                        search_elapsed
                    );
                    let embedding_query_elapsed = embedding_query_start.elapsed();
                    log::debug!("embedding query took {:?}", embedding_query_elapsed);
                }

                search_results
            })
        })
    }

    /// Total number of indexed paths across all loaded worktrees.
    #[cfg(test)]
    pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
        let mut result = 0;
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                result += index.read(cx).path_count()?;
            }
        }
        Ok(result)
    }

    /// Finds the loaded index for `worktree_id`, if any.
    pub(crate) fn worktree_index(
        &self,
        worktree_id: WorktreeId,
        cx: &AppContext,
    ) -> Option<Model<WorktreeIndex>> {
        for index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = index {
                if index.read(cx).worktree.read(cx).id() == worktree_id {
                    return Some(index.clone());
                }
            }
        }
        None
    }

    /// All loaded worktree indices, ordered by worktree id for stable output.
    pub(crate) fn worktree_indices(&self, cx: &AppContext) -> Vec<Model<WorktreeIndex>> {
        let mut result = self
            .worktree_indices
            .values()
            .filter_map(|index| {
                if let WorktreeIndexHandle::Loaded { index, .. } = index {
                    Some(index.clone())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        result.sort_by_key(|index| index.read(cx).worktree.read(cx).id());
        result
    }
}
434
/// A search hit resolved against a live worktree handle, ready for display.
pub struct SearchResult {
    pub worktree: Model<Worktree>,
    // Worktree-relative path of the file containing the chunk.
    pub path: Arc<Path>,
    // Byte range of the chunk within the file's text.
    pub range: Range<usize>,
    // Similarity score against the query embedding (higher is better).
    pub score: f32,
}
441
/// An intermediate search hit identified by worktree id; converted into a
/// [`SearchResult`] once resolved against the project's worktree handles.
pub struct WorktreeSearchResult {
    pub worktree_id: WorktreeId,
    pub path: Arc<Path>,
    pub range: Range<usize>,
    pub score: f32,
}
448
/// Aggregate indexing state of a [`ProjectIndex`], emitted as an event
/// whenever it changes.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    Idle,
    // At least one worktree index is still opening its database.
    Loading,
    // All indices are loaded; `remaining_count` entries are still being indexed.
    Scanning { remaining_count: NonZeroUsize },
}
455
456impl EventEmitter<Status> for ProjectIndex {}
457
/// Index for a single local worktree: watches for file changes, chunks and
/// embeds file contents, and persists embeddings in a dedicated LMDB database.
struct WorktreeIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    // Keyed by `db_key_for_path`; values are whole embedded files.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Entries currently in the indexing pipeline; used for status reporting.
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
    // Long-running task that performs the initial scan and reacts to updates.
    _index_entries: Task<Result<()>>,
    // Keeps us subscribed to the worktree's entry-update events.
    _subscription: Subscription,
}
469
impl WorktreeIndex {
    /// Opens (or creates) this worktree's database — named after the
    /// worktree's absolute path — then constructs the index model.
    pub fn load(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        status_tx: channel::Sender<()>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let worktree_abs_path = worktree.read(cx).abs_path();
        cx.spawn(|mut cx| async move {
            let db = cx
                .background_executor()
                .spawn({
                    let db_connection = db_connection.clone();
                    async move {
                        let mut txn = db_connection.write_txn()?;
                        let db_name = worktree_abs_path.to_string_lossy();
                        let db = db_connection.create_database(&mut txn, Some(&db_name))?;
                        txn.commit()?;
                        anyhow::Ok(db)
                    }
                })
                .await?;
            cx.new_model(|cx| {
                Self::new(
                    worktree,
                    db_connection,
                    db,
                    status_tx,
                    language_registry,
                    fs,
                    embedding_provider,
                    cx,
                )
            })
        })
    }

    /// Wires up the worktree-change subscription and spawns the long-running
    /// indexing loop (`index_entries`).
    #[allow(clippy::too_many_arguments)]
    fn new(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        status: channel::Sender<()>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
        // Forward entry-update batches from the worktree into the indexing loop.
        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
            if let worktree::Event::UpdatedEntries(update) = event {
                _ = updated_entries_tx.try_send(update.clone());
            }
        });

        Self {
            db_connection,
            db,
            worktree,
            language_registry,
            fs,
            embedding_provider,
            entry_ids_being_indexed: Arc::new(IndexingEntrySet::new(status)),
            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
            _subscription,
        }
    }

    /// Indexing loop: first reconciles the database with the on-disk state,
    /// then processes incremental update batches until the model is dropped.
    async fn index_entries(
        this: WeakModel<Self>,
        updated_entries: channel::Receiver<UpdatedEntriesSet>,
        mut cx: AsyncAppContext,
    ) -> Result<()> {
        let index = this.update(&mut cx, |this, cx| this.index_entries_changed_on_disk(cx))?;
        index.await.log_err();

        while let Ok(updated_entries) = updated_entries.recv().await {
            let index = this.update(&mut cx, |this, cx| {
                this.index_updated_entries(updated_entries, cx)
            })?;
            index.await.log_err();
        }

        Ok(())
    }

    /// Builds the full scan → chunk → embed → persist pipeline over the
    /// entire worktree snapshot; the stages run concurrently, connected by
    /// bounded channels.
    fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    /// Same pipeline as `index_entries_changed_on_disk`, but seeded only with
    /// the entries from an incremental update batch.
    fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    /// Merge-walks the sorted database entries against the worktree snapshot
    /// (both ordered by `db_key_for_path`), emitting:
    /// - entries whose mtime differs from the saved one (need re-indexing), and
    /// - key ranges present in the database but absent on disk (need deletion).
    fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            // Accumulates a contiguous run of stale db keys so they can be
            // deleted with a single range operation.
            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                let entry_db_key = db_key_for_path(&entry.path);

                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                // Db key not present on disk: extend (or start)
                                // the pending deletion range.
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                // Found the db row for this file: flush any
                                // pending deletions, remember its saved mtime.
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                break;
                            }
                        },
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
                    }
                }

                // New file (no db row) or modified since last index.
                if entry.mtime != saved_mtime {
                    let handle = entries_being_indexed.insert(entry.id);
                    updated_entries_tx.send((entry.clone(), handle)).await?;
                }
            }

            // Everything left in the db cursor is past the last on-disk file.
            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

    /// Translates an incremental update batch into the same two streams as
    /// `scan_entries`: changed files to re-index and removed keys to delete.
    fn scan_updated_entries(
        &self,
        worktree: LocalSnapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            if entry.is_file() {
                                let handle = entries_being_indexed.insert(entry.id);
                                updated_entries_tx.send((entry.clone(), handle)).await?;
                            }
                        }
                    }
                    project::PathChange::Removed => {
                        // Single-key deletion expressed as a degenerate range.
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

    /// Reads each updated entry's file and splits it into chunks, using one
    /// worker per CPU. Unreadable files are logged and skipped.
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok((entry, handle)) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                let Some(text) = fs
                                    .load(&entry_abs_path)
                                    .await
                                    .with_context(|| {
                                        format!("failed to read path {entry_abs_path:?}")
                                    })
                                    .log_err()
                                else {
                                    continue;
                                };
                                // Language detection only refines chunking;
                                // a `None` language still chunks as plain text.
                                let language = language_registry
                                    .language_for_file_path(&entry.path)
                                    .await
                                    .ok();
                                let chunked_file = ChunkedFile {
                                    chunks: chunk_text(&text, language.as_ref(), &entry.path),
                                    handle,
                                    path: entry.path,
                                    mtime: entry.mtime,
                                    text,
                                };

                                // Downstream hung up: stop this worker.
                                if chunked_files_tx.send(chunked_file).await.is_err() {
                                    return;
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }

    /// Embeds chunked files in provider-sized batches.
    ///
    /// Files are gathered into timed batches, their chunks flattened and
    /// embedded batch-by-batch, then reassembled per file. If any chunk of a
    /// file fails to embed, the whole file is discarded (it will be retried
    /// on a future scan since its mtime was never persisted).
    fn embed_files(
        embedding_provider: Arc<dyn EmbeddingProvider>,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &AppContext,
    ) -> EmbedFiles {
        let embedding_provider = embedding_provider.clone();
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_executor().spawn(async move {
            let mut chunked_file_batches =
                chunked_files.chunks_timeout(512, Duration::from_secs(2));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // View the batch of files as a vec of chunks
                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
                // Once those are done, reassemble them back into the files in which they belong
                // If any embeddings fail for a file, the entire file is discarded

                let chunks: Vec<TextToEmbed> = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect::<Vec<_>>();

                // `embeddings[i]` corresponds positionally to `chunks[i]`;
                // `None` marks a chunk whose batch failed to embed.
                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    if let Some(batch_embeddings) =
                        embedding_provider.embed(embedding_batch).await.log_err()
                    {
                        if batch_embeddings.len() == embedding_batch.len() {
                            embeddings.extend(batch_embeddings.into_iter().map(Some));
                            continue;
                        }
                        log::error!(
                            "embedding provider returned unexpected embedding count {}, expected {}",
                            batch_embeddings.len(), embedding_batch.len()
                        );
                    }

                    // Keep positions aligned even when a batch fails.
                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
                }

                // Reassemble: walk the flat embedding stream back into files.
                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let mut embedded_file = EmbeddedFile {
                        path: chunked_file.path,
                        mtime: chunked_file.mtime,
                        chunks: Vec::new(),
                    };

                    let mut embedded_all_chunks = true;
                    for (chunk, embedding) in
                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
                    {
                        if let Some(embedding) = embedding {
                            embedded_file
                                .chunks
                                .push(EmbeddedChunk { chunk, embedding });
                        } else {
                            embedded_all_chunks = false;
                        }
                    }

                    if embedded_all_chunks {
                        embedded_files_tx
                            .send((embedded_file, chunked_file.handle))
                            .await?;
                    }
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }

    /// Applies the pipeline's output to the database: first deletes all stale
    /// key ranges, then writes embedded files in timed batches, one write
    /// transaction per batch.
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            while let Some(deletion_range) = deleted_entry_ranges.next().await {
                let mut txn = db_connection.write_txn()?;
                let start = deletion_range.0.as_ref().map(|start| start.as_str());
                let end = deletion_range.1.as_ref().map(|end| end.as_str());
                log::debug!("deleting embeddings in range {:?}", &(start, end));
                db.delete_range(&mut txn, &(start, end))?;
                txn.commit()?;
            }

            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(embedded_files) = embedded_files.next().await {
                let mut txn = db_connection.write_txn()?;
                for (file, _) in &embedded_files {
                    log::debug!("saving embedding for file {:?}", file.path);
                    let key = db_key_for_path(&file.path);
                    db.put(&mut txn, &key, file)?;
                }
                txn.commit()?;

                // Dropping the batch releases the IndexingEntryHandles only
                // after the commit, keeping the progress count accurate.
                drop(embedded_files);
                log::debug!("committed");
            }

            Ok(())
        })
    }

    /// All indexed file paths in this worktree's database.
    fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            let result = db
                .iter(&tx)?
                .map(|entry| Ok(entry?.1.path.clone()))
                .collect::<Result<Vec<Arc<Path>>>>();
            drop(tx);
            result
        })
    }

    /// The stored embedded chunks for a single path.
    ///
    /// # Errors
    /// Fails if the path is not present in the database.
    fn chunks_for_path(
        &self,
        path: Arc<Path>,
        cx: &AppContext,
    ) -> Task<Result<Vec<EmbeddedChunk>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            Ok(db
                .get(&tx, &db_key_for_path(&path))?
                .ok_or_else(|| anyhow!("no such path"))?
                .chunks
                .clone())
        })
    }

    /// Number of files currently stored in this worktree's database.
    #[cfg(test)]
    fn path_count(&self) -> Result<u64> {
        let txn = self
            .db_connection
            .read_txn()
            .context("failed to create read transaction")?;
        Ok(self.db.len(&txn)?)
    }
}
927
/// Output of the scan stage: changed entries to re-index, stale key ranges
/// to delete, and the task driving the scan.
struct ScanEntries {
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}
933
/// Output of the chunking stage: chunked files plus the driving task.
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}
938
/// A file's text split into chunks, awaiting embedding.
struct ChunkedFile {
    pub path: Arc<Path>,
    pub mtime: Option<SystemTime>,
    // Marks the entry as in-flight until the file is persisted (or dropped).
    pub handle: IndexingEntryHandle,
    pub text: String,
    // Chunk ranges index into `text`.
    pub chunks: Vec<Chunk>,
}
946
/// Output of the embedding stage: fully embedded files plus the driving task.
struct EmbedFiles {
    files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    task: Task<Result<()>>,
}
951
/// The database value type: one record per indexed file, bincode-serialized.
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedFile {
    path: Arc<Path>,
    // Compared against the on-disk mtime during scans to detect changes.
    mtime: Option<SystemTime>,
    chunks: Vec<EmbeddedChunk>,
}
958
/// A chunk paired with its embedding vector.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct EmbeddedChunk {
    chunk: Chunk,
    embedding: Embedding,
}
964
965/// The set of entries that are currently being indexed.
/// The set of entries that are currently being indexed.
struct IndexingEntrySet {
    entry_ids: Mutex<HashSet<ProjectEntryId>>,
    // Signaled on every insertion/removal so observers can refresh status.
    tx: channel::Sender<()>,
}
970
971/// When dropped, removes the entry from the set of entries that are being indexed.
/// When dropped, removes the entry from the set of entries that are being indexed.
#[derive(Clone)]
struct IndexingEntryHandle {
    entry_id: ProjectEntryId,
    // Weak so the handle doesn't keep the set alive past its owner.
    set: Weak<IndexingEntrySet>,
}
977
978impl IndexingEntrySet {
979 fn new(tx: channel::Sender<()>) -> Self {
980 Self {
981 entry_ids: Default::default(),
982 tx,
983 }
984 }
985
986 fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
987 self.entry_ids.lock().insert(entry_id);
988 self.tx.send_blocking(()).ok();
989 IndexingEntryHandle {
990 entry_id,
991 set: Arc::downgrade(self),
992 }
993 }
994
995 pub fn len(&self) -> usize {
996 self.entry_ids.lock().len()
997 }
998}
999
1000impl Drop for IndexingEntryHandle {
1001 fn drop(&mut self) {
1002 if let Some(set) = self.set.upgrade() {
1003 set.tx.send_blocking(()).ok();
1004 set.entry_ids.lock().remove(&self.entry_id);
1005 }
1006 }
1007}
1008
/// Converts a worktree-relative path into the string key that addresses its
/// embeddings in the per-worktree database.
///
/// `/` separators are replaced with NUL bytes so that a directory's contents
/// sort contiguously: NUL orders before every other character, whereas `/`
/// sorts after characters like `-`, which would interleave sibling files and
/// directory contents. (Assumes `/`-separated paths; worktree-relative paths
/// appear to use them on all platforms — TODO confirm for Windows.)
///
/// Takes `&Path` rather than `&Arc<Path>`; existing `&Arc<Path>` call sites
/// still compile via deref coercion.
fn db_key_for_path(path: &Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}
1012
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, FutureExt};
    use gpui::TestAppContext;
    use language::language_settings::AllLanguageSettings;
    use project::Project;
    use settings::SettingsStore;
    use std::{future, path::Path, sync::Arc};

    /// Installs the minimal global state (settings, languages) that
    /// `Project` requires in tests.
    fn init_test(cx: &mut TestAppContext) {
        _ = cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

    /// Deterministic embedding provider backed by a caller-supplied closure,
    /// so tests can control both embeddings and failures.
    pub struct TestEmbeddingProvider {
        batch_size: usize,
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            // Expression tail instead of the previous `return …;` (clippy
            // `needless_return`).
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            // Short-circuits on the first chunk the closure rejects.
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        // Embeddings encode whether the text mentions "garbage in" /
        // "garbage out" on two axes, so the fixture's needle scores highest.
        let mut semantic_index = SemanticIndex::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // if the text contains garbage, give it a 1 in the first dimension
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let project_path = Path::new("./fixture");

        let project = cx
            .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
            .await;

        cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
        });

        let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));

        // Wait for the initial scan to index at least one path.
        while project_index
            .read_with(cx, |index, cx| index.path_count(cx))
            .unwrap()
            == 0
        {
            project_index.next_event(cx).await;
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query.into(), 4, cx)
            })
            .await
            .unwrap();

        assert!(results.len() > 1, "should have found some results");

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find result that is greater than 0.5
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "needle.md");

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_executor()
                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // The provider rejects any text containing 'g', so the first file
        // (which contains one) must be discarded entirely.
        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                Ok(Embedding::new(
                    ('a'..'z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        // Close the channel so `embed_files` sees end-of-stream.
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| WorktreeIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        let mut embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
            embedded_files.push(embedded_file);
        }

        // Only test2.md survives; its chunks embed in order.
        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }
}