1mod chunking;
2mod embedding;
3mod project_index_debug_view;
4
5use anyhow::{anyhow, Context as _, Result};
6use chunking::{chunk_text, Chunk};
7use collections::{Bound, HashMap, HashSet};
8pub use embedding::*;
9use fs::Fs;
10use futures::{future::Shared, stream::StreamExt, FutureExt};
11use futures_batch::ChunksTimeoutStreamExt;
12use gpui::{
13 AppContext, AsyncAppContext, BorrowAppContext, Context, Entity, EntityId, EventEmitter, Global,
14 Model, ModelContext, Subscription, Task, WeakModel,
15};
16use heed::types::{SerdeBincode, Str};
17use language::LanguageRegistry;
18use parking_lot::Mutex;
19use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree, WorktreeId};
20use serde::{Deserialize, Serialize};
21use smol::channel;
22use std::{
23 cmp::Ordering,
24 future::Future,
25 iter,
26 num::NonZeroUsize,
27 ops::Range,
28 path::{Path, PathBuf},
29 sync::{Arc, Weak},
30 time::{Duration, SystemTime},
31};
32use util::ResultExt;
33use worktree::Snapshot;
34
35pub use project_index_debug_view::ProjectIndexDebugView;
36
/// Global owner of the embeddings database connection and the per-project
/// semantic indices. Installed as a gpui [`Global`] so project release
/// observers can evict their cache entries.
pub struct SemanticIndex {
    embedding_provider: Arc<dyn EmbeddingProvider>,
    db_connection: heed::Env,
    // One ProjectIndex per open project, keyed by a weak handle so entries
    // can be removed when the project model is released.
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}

impl Global for SemanticIndex {}
44
impl SemanticIndex {
    /// Opens (creating the directory if necessary) the LMDB-backed embeddings
    /// database at `db_path` on the background executor and returns the index.
    ///
    /// # Errors
    /// Fails if the directory cannot be created or the heed environment
    /// cannot be opened.
    pub async fn new(
        db_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AsyncAppContext,
    ) -> Result<Self> {
        let db_connection = cx
            .background_executor()
            .spawn(async move {
                std::fs::create_dir_all(&db_path)?;
                // SAFETY: heed's `open` is unsafe because it memory-maps the
                // database file; soundness relies on no other process mutating
                // the map underneath us.
                unsafe {
                    heed::EnvOpenOptions::new()
                        .map_size(1024 * 1024 * 1024) // 1 GiB memory map
                        .max_dbs(3000) // one named database per worktree
                        .open(db_path)
                }
            })
            .await
            .context("opening database connection")?;

        Ok(SemanticIndex {
            db_connection,
            embedding_provider,
            project_indices: HashMap::default(),
        })
    }

    /// Returns the [`ProjectIndex`] for `project`, creating and caching it on
    /// first use. Registers a release observer that removes the cached entry
    /// when the project model is dropped.
    pub fn project_index(
        &mut self,
        project: Model<Project>,
        cx: &mut AppContext,
    ) -> Model<ProjectIndex> {
        let project_weak = project.downgrade();
        project.update(cx, move |_, cx| {
            // NOTE(review): this observer is registered on every call, even
            // when the index already exists — presumably harmless since
            // removal is idempotent; confirm.
            cx.on_release(move |_, cx| {
                if cx.has_global::<SemanticIndex>() {
                    cx.update_global::<SemanticIndex, _>(|this, _| {
                        this.project_indices.remove(&project_weak);
                    })
                }
            })
            .detach();
        });

        self.project_indices
            .entry(project.downgrade())
            .or_insert_with(|| {
                cx.new_model(|cx| {
                    ProjectIndex::new(
                        project,
                        self.db_connection.clone(),
                        self.embedding_provider.clone(),
                        cx,
                    )
                })
            })
            .clone()
    }
}
104
/// Indexes every local, visible worktree of a single project, delegating to
/// one [`WorktreeIndex`] per worktree and aggregating their status into a
/// single [`Status`] event stream.
pub struct ProjectIndex {
    db_connection: heed::Env,
    project: WeakModel<Project>,
    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    // Last status emitted; used to suppress duplicate events.
    last_status: Status,
    // Signalled by worktree indices whenever indexing progress changes.
    status_tx: channel::Sender<()>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    _maintain_status: Task<()>,
    _subscription: Subscription,
}
117
/// A worktree index that may still be initializing its database.
#[derive(Clone)]
enum WorktreeIndexHandle {
    /// Database creation is still in flight; the shared task resolves to the
    /// loaded index (or an error) and can be awaited by multiple consumers.
    Loading {
        index: Shared<Task<Result<Model<WorktreeIndex>, Arc<anyhow::Error>>>>,
    },
    /// The index is ready for reads.
    Loaded {
        index: Model<WorktreeIndex>,
    },
}
127
impl ProjectIndex {
    /// Creates an index for `project`, subscribing to worktree add/remove
    /// events and spawning a task that recomputes the aggregate status
    /// whenever any worktree index signals progress on `status_tx`.
    fn new(
        project: Model<Project>,
        db_connection: heed::Env,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let language_registry = project.read(cx).languages().clone();
        let fs = project.read(cx).fs().clone();
        let (status_tx, mut status_rx) = channel::unbounded();
        let mut this = ProjectIndex {
            db_connection,
            project: project.downgrade(),
            worktree_indices: HashMap::default(),
            language_registry,
            fs,
            status_tx,
            last_status: Status::Idle,
            embedding_provider,
            _subscription: cx.subscribe(&project, Self::handle_project_event),
            _maintain_status: cx.spawn(|this, mut cx| async move {
                // Refresh status on every tick; stop once the model is dropped
                // (update returns Err) or the channel closes.
                while status_rx.next().await.is_some() {
                    if this
                        .update(&mut cx, |this, cx| this.update_status(cx))
                        .is_err()
                    {
                        break;
                    }
                }
            }),
        };
        this.update_worktree_indices(cx);
        this
    }

    /// The most recently computed aggregate indexing status.
    pub fn status(&self) -> Status {
        self.last_status
    }

    pub fn project(&self) -> WeakModel<Project> {
        self.project.clone()
    }

    pub fn fs(&self) -> Arc<dyn Fs> {
        self.fs.clone()
    }

    /// Keeps the set of worktree indices in sync as worktrees come and go.
    fn handle_project_event(
        &mut self,
        _: Model<Project>,
        event: &project::Event,
        cx: &mut ModelContext<Self>,
    ) {
        match event {
            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                self.update_worktree_indices(cx);
            }
            _ => {}
        }
    }

    /// Reconciles `worktree_indices` with the project's current set of local,
    /// visible worktrees: drops indices for removed worktrees and spawns
    /// loading tasks for new ones.
    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
        let Some(project) = self.project.upgrade() else {
            return;
        };

        // Only local worktrees are indexed; remote ones are skipped.
        let worktrees = project
            .read(cx)
            .visible_worktrees(cx)
            .filter_map(|worktree| {
                if worktree.read(cx).is_local() {
                    Some((worktree.entity_id(), worktree))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        self.worktree_indices
            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
        for (worktree_id, worktree) in worktrees {
            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
                let worktree_index = WorktreeIndex::load(
                    worktree.clone(),
                    self.db_connection.clone(),
                    self.language_registry.clone(),
                    self.fs.clone(),
                    self.status_tx.clone(),
                    self.embedding_provider.clone(),
                    cx,
                );

                // Upgrade the handle from Loading to Loaded once the database
                // is ready, or drop it if loading failed.
                let load_worktree = cx.spawn(|this, mut cx| async move {
                    let result = match worktree_index.await {
                        Ok(worktree_index) => {
                            this.update(&mut cx, |this, _| {
                                this.worktree_indices.insert(
                                    worktree_id,
                                    WorktreeIndexHandle::Loaded {
                                        index: worktree_index.clone(),
                                    },
                                );
                            })?;
                            Ok(worktree_index)
                        }
                        Err(error) => {
                            this.update(&mut cx, |this, _cx| {
                                this.worktree_indices.remove(&worktree_id)
                            })?;
                            Err(Arc::new(error))
                        }
                    };

                    this.update(&mut cx, |this, cx| this.update_status(cx))?;

                    result
                });

                // Shared so `search` can also await the in-flight load.
                WorktreeIndexHandle::Loading {
                    index: load_worktree.shared(),
                }
            });
        }

        self.update_status(cx);
    }

    /// Recomputes the aggregate status and emits it if it changed.
    /// `Loading` dominates; otherwise a nonzero in-flight entry count means
    /// `Scanning`, else `Idle`.
    fn update_status(&mut self, cx: &mut ModelContext<Self>) {
        let mut indexing_count = 0;
        let mut any_loading = false;

        for index in self.worktree_indices.values_mut() {
            match index {
                WorktreeIndexHandle::Loading { .. } => {
                    any_loading = true;
                    break;
                }
                WorktreeIndexHandle::Loaded { index, .. } => {
                    indexing_count += index.read(cx).entry_ids_being_indexed.len();
                }
            }
        }

        let status = if any_loading {
            Status::Loading
        } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
            Status::Scanning { remaining_count }
        } else {
            Status::Idle
        };

        if status != self.last_status {
            self.last_status = status;
            cx.emit(status);
        }
    }

    /// Performs a semantic search over every worktree index.
    ///
    /// Pipeline: one scan task per worktree streams `(worktree_id, path,
    /// chunk)` tuples into a bounded channel; one worker per CPU scores each
    /// chunk against the query embedding, keeping a per-worker top-`limit`
    /// list sorted by descending score; finally the per-worker lists are
    /// merged, re-sorted, and truncated to `limit`.
    pub fn search(
        &self,
        query: String,
        limit: usize,
        cx: &AppContext,
    ) -> Task<Result<Vec<SearchResult>>> {
        let (chunks_tx, chunks_rx) = channel::bounded(1024);
        let mut worktree_scan_tasks = Vec::new();
        for worktree_index in self.worktree_indices.values() {
            let worktree_index = worktree_index.clone();
            let chunks_tx = chunks_tx.clone();
            worktree_scan_tasks.push(cx.spawn(|cx| async move {
                // Wait for a still-loading index to finish before scanning it.
                let index = match worktree_index {
                    WorktreeIndexHandle::Loading { index } => {
                        index.clone().await.map_err(|error| anyhow!(error))?
                    }
                    WorktreeIndexHandle::Loaded { index } => index.clone(),
                };

                index
                    .read_with(&cx, |index, cx| {
                        let worktree_id = index.worktree.read(cx).id();
                        let db_connection = index.db_connection.clone();
                        let db = index.db;
                        cx.background_executor().spawn(async move {
                            let txn = db_connection
                                .read_txn()
                                .context("failed to create read transaction")?;
                            let db_entries = db.iter(&txn).context("failed to iterate database")?;
                            for db_entry in db_entries {
                                let (_key, db_embedded_file) = db_entry?;
                                for chunk in db_embedded_file.chunks {
                                    chunks_tx
                                        .send((worktree_id, db_embedded_file.path.clone(), chunk))
                                        .await?;
                                }
                            }
                            anyhow::Ok(())
                        })
                    })?
                    .await
            }));
        }
        // Drop the original sender so the channel closes (and the workers
        // terminate) once all scan tasks have finished.
        drop(chunks_tx);

        let project = self.project.clone();
        let embedding_provider = self.embedding_provider.clone();
        cx.spawn(|cx| async move {
            #[cfg(debug_assertions)]
            let embedding_query_start = std::time::Instant::now();
            log::info!("Searching for {query}");

            let query_embeddings = embedding_provider
                .embed(&[TextToEmbed::new(&query)])
                .await?;
            let query_embedding = query_embeddings
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("no embedding for query"))?;

            let mut results_by_worker = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                results_by_worker.push(Vec::<WorktreeSearchResult>::new());
            }

            #[cfg(debug_assertions)]
            let search_start = std::time::Instant::now();

            cx.background_executor()
                .scoped(|cx| {
                    for results in results_by_worker.iter_mut() {
                        cx.spawn(async {
                            while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
                                let score = chunk.embedding.similarity(&query_embedding);
                                // Comparator is reversed (score vs. probe), so
                                // `results` stays sorted in descending score order.
                                let ix = match results.binary_search_by(|probe| {
                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
                                }) {
                                    Ok(ix) | Err(ix) => ix,
                                };
                                results.insert(
                                    ix,
                                    WorktreeSearchResult {
                                        worktree_id,
                                        path: path.clone(),
                                        range: chunk.chunk.range.clone(),
                                        score,
                                    },
                                );
                                results.truncate(limit);
                            }
                        });
                    }
                })
                .await;

            // Scan tasks have necessarily finished once the channel closed;
            // this await only surfaces (and logs) their errors.
            for scan_task in futures::future::join_all(worktree_scan_tasks).await {
                scan_task.log_err();
            }

            project.read_with(&cx, |project, cx| {
                let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
                for worker_results in results_by_worker {
                    // Skip results whose worktree has been removed since the scan.
                    search_results.extend(worker_results.into_iter().filter_map(|result| {
                        Some(SearchResult {
                            worktree: project.worktree_for_id(result.worktree_id, cx)?,
                            path: result.path,
                            range: result.range,
                            score: result.score,
                        })
                    }));
                }
                search_results.sort_unstable_by(|a, b| {
                    b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
                });
                search_results.truncate(limit);

                #[cfg(debug_assertions)]
                {
                    let search_elapsed = search_start.elapsed();
                    log::debug!(
                        "searched {} entries in {:?}",
                        search_results.len(),
                        search_elapsed
                    );
                    let embedding_query_elapsed = embedding_query_start.elapsed();
                    log::debug!("embedding query took {:?}", embedding_query_elapsed);
                }

                search_results
            })
        })
    }

    /// Total number of indexed paths across all loaded worktree indices.
    #[cfg(test)]
    pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
        let mut result = 0;
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                result += index.read(cx).path_count()?;
            }
        }
        Ok(result)
    }

    /// Finds the loaded index for `worktree_id`, if any.
    pub(crate) fn worktree_index(
        &self,
        worktree_id: WorktreeId,
        cx: &AppContext,
    ) -> Option<Model<WorktreeIndex>> {
        for index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = index {
                if index.read(cx).worktree.read(cx).id() == worktree_id {
                    return Some(index.clone());
                }
            }
        }
        None
    }

    /// All loaded worktree indices, sorted by worktree id for stable output.
    pub(crate) fn worktree_indices(&self, cx: &AppContext) -> Vec<Model<WorktreeIndex>> {
        let mut result = self
            .worktree_indices
            .values()
            .filter_map(|index| {
                if let WorktreeIndexHandle::Loaded { index, .. } = index {
                    Some(index.clone())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        result.sort_by_key(|index| index.read(cx).worktree.read(cx).id());
        result
    }
}
460
/// A search hit resolved against a live worktree model.
pub struct SearchResult {
    pub worktree: Model<Worktree>,
    /// Worktree-relative path of the file containing the chunk.
    pub path: Arc<Path>,
    /// Byte range of the matching chunk within the file's text.
    pub range: Range<usize>,
    /// Similarity score against the query embedding (higher is better).
    pub score: f32,
}

/// A search hit as produced by the per-worker scoring loop, before the
/// worktree id is resolved to a live worktree model.
pub struct WorktreeSearchResult {
    pub worktree_id: WorktreeId,
    pub path: Arc<Path>,
    pub range: Range<usize>,
    pub score: f32,
}
474
/// Aggregate indexing state for a project, emitted as a gpui event whenever
/// it changes.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum Status {
    /// Nothing is loading or being indexed.
    Idle,
    /// At least one worktree index is still opening its database.
    Loading,
    /// Databases are open; `remaining_count` entries are still being indexed.
    Scanning { remaining_count: NonZeroUsize },
}

impl EventEmitter<Status> for ProjectIndex {}
483
/// Maintains the embeddings database for a single worktree, re-indexing
/// entries as they change on disk.
struct WorktreeIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    // Maps db keys (see `db_key_for_path`) to embedded file records.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Entries currently in flight through the indexing pipeline; drives the
    // `Scanning { remaining_count }` status.
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
    _index_entries: Task<Result<()>>,
    _subscription: Subscription,
}
495
impl WorktreeIndex {
    /// Creates (or opens) the per-worktree database — named after the
    /// worktree's absolute path — on the background executor, then
    /// constructs the index model.
    pub fn load(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        status_tx: channel::Sender<()>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let worktree_abs_path = worktree.read(cx).abs_path();
        cx.spawn(|mut cx| async move {
            let db = cx
                .background_executor()
                .spawn({
                    let db_connection = db_connection.clone();
                    async move {
                        let mut txn = db_connection.write_txn()?;
                        let db_name = worktree_abs_path.to_string_lossy();
                        let db = db_connection.create_database(&mut txn, Some(&db_name))?;
                        txn.commit()?;
                        anyhow::Ok(db)
                    }
                })
                .await?;
            cx.new_model(|cx| {
                Self::new(
                    worktree,
                    db_connection,
                    db,
                    status_tx,
                    language_registry,
                    fs,
                    embedding_provider,
                    cx,
                )
            })
        })
    }

    /// Wires up the index: forwards worktree entry updates into a channel
    /// consumed by the long-running `index_entries` task.
    #[allow(clippy::too_many_arguments)]
    fn new(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        status: channel::Sender<()>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
            if let worktree::Event::UpdatedEntries(update) = event {
                // try_send never blocks; the channel is unbounded, so this
                // only fails once the receiver is gone.
                _ = updated_entries_tx.try_send(update.clone());
            }
        });

        Self {
            db_connection,
            db,
            worktree,
            language_registry,
            fs,
            embedding_provider,
            entry_ids_being_indexed: Arc::new(IndexingEntrySet::new(status)),
            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
            _subscription,
        }
    }

    /// Long-running driver: first reconciles the database with the current
    /// on-disk state, then incrementally indexes batches of updated entries
    /// until the model is dropped or the channel closes.
    async fn index_entries(
        this: WeakModel<Self>,
        updated_entries: channel::Receiver<UpdatedEntriesSet>,
        mut cx: AsyncAppContext,
    ) -> Result<()> {
        let index = this.update(&mut cx, |this, cx| this.index_entries_changed_on_disk(cx))?;
        index.await.log_err();

        while let Ok(updated_entries) = updated_entries.recv().await {
            let index = this.update(&mut cx, |this, cx| {
                this.index_updated_entries(updated_entries, cx)
            })?;
            index.await.log_err();
        }

        Ok(())
    }

    /// Full reconciliation pass: scan the whole worktree against the
    /// database, then run the chunk → embed → persist pipeline. The four
    /// stages run concurrently, connected by bounded channels.
    fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree, cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    /// Incremental pass over an explicit set of changed entries; same
    /// pipeline shape as `index_entries_changed_on_disk`.
    fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    /// Merge-joins the worktree's files (in traversal order) with the
    /// database's records (in key order) to find entries that are new or
    /// stale (mtime mismatch) and key ranges whose files no longer exist.
    /// Relies on both sequences sorting identically under `db_key_for_path`.
    fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            // Accumulates a run of consecutive deleted keys so deletions can
            // be sent as ranges rather than one message per key.
            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                let entry_db_key = db_key_for_path(&entry.path);

                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                // Key exists in the db but not on disk: the
                                // file was deleted — extend the pending range.
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                // File exists in both places: flush any pending
                                // deletions and remember the stored mtime.
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                // Db iterator is ahead: this worktree entry is
                                // not in the db yet.
                                break;
                            }
                        },
                        // Advance the iterator to take ownership of the error.
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
                    }
                }

                // New file, or mtime changed since it was last embedded.
                if entry.mtime != saved_mtime {
                    let handle = entries_being_indexed.insert(entry.id);
                    updated_entries_tx.send((entry.clone(), handle)).await?;
                }
            }

            // Everything remaining in the db sorts after the last worktree
            // file and must have been deleted.
            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

    /// Translates an explicit change set into the same two output channels as
    /// `scan_entries`: re-index added/updated files, delete removed keys.
    fn scan_updated_entries(
        &self,
        worktree: Snapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            if entry.is_file() {
                                let handle = entries_being_indexed.insert(entry.id);
                                updated_entries_tx.send((entry.clone(), handle)).await?;
                            }
                        }
                    }
                    project::PathChange::Removed => {
                        // A single-key deletion, expressed as a degenerate range.
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

    /// Loads each updated entry from disk and chunks its text, using one
    /// worker per CPU. Unreadable files are logged and skipped.
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok((entry, handle)) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                let Some(text) = fs
                                    .load(&entry_abs_path)
                                    .await
                                    .with_context(|| {
                                        format!("failed to read path {entry_abs_path:?}")
                                    })
                                    .log_err()
                                else {
                                    continue;
                                };
                                // Language is used for syntax-aware chunk
                                // boundaries; plain-text chunking otherwise.
                                let language = language_registry
                                    .language_for_file_path(&entry.path)
                                    .await
                                    .ok();
                                let chunked_file = ChunkedFile {
                                    chunks: chunk_text(&text, language.as_ref(), &entry.path),
                                    handle,
                                    path: entry.path,
                                    mtime: entry.mtime,
                                    text,
                                };

                                // Receiver gone means the pipeline was torn
                                // down; stop quietly.
                                if chunked_files_tx.send(chunked_file).await.is_err() {
                                    return;
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }

    /// Embeds chunked files in provider-sized batches.
    fn embed_files(
        embedding_provider: Arc<dyn EmbeddingProvider>,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &AppContext,
    ) -> EmbedFiles {
        let embedding_provider = embedding_provider.clone();
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_executor().spawn(async move {
            // Batch incoming files by count or by elapsed time, whichever
            // comes first.
            let mut chunked_file_batches =
                chunked_files.chunks_timeout(512, Duration::from_secs(2));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // View the batch of files as a vec of chunks
                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
                // Once those are done, reassemble them back into the files in which they belong
                // If any embeddings fail for a file, the entire file is discarded

                let chunks: Vec<TextToEmbed> = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect::<Vec<_>>();

                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    if let Some(batch_embeddings) =
                        embedding_provider.embed(embedding_batch).await.log_err()
                    {
                        if batch_embeddings.len() == embedding_batch.len() {
                            embeddings.extend(batch_embeddings.into_iter().map(Some));
                            continue;
                        }
                        log::error!(
                            "embedding provider returned unexpected embedding count {}, expected {}",
                            batch_embeddings.len(), embedding_batch.len()
                        );
                    }

                    // Mark every chunk in the failed batch as missing so the
                    // per-file reassembly below can discard affected files.
                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
                }

                // Reassemble flat embeddings back into their files, in order.
                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let mut embedded_file = EmbeddedFile {
                        path: chunked_file.path,
                        mtime: chunked_file.mtime,
                        chunks: Vec::new(),
                    };

                    let mut embedded_all_chunks = true;
                    for (chunk, embedding) in
                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
                    {
                        if let Some(embedding) = embedding {
                            embedded_file
                                .chunks
                                .push(EmbeddedChunk { chunk, embedding });
                        } else {
                            embedded_all_chunks = false;
                        }
                    }

                    // Partially-embedded files are dropped entirely; they will
                    // be retried on the next scan since their mtime was never
                    // persisted.
                    if embedded_all_chunks {
                        embedded_files_tx
                            .send((embedded_file, chunked_file.handle))
                            .await?;
                    }
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }

    /// Applies deletions, then upserts embedded files in timed batches, one
    /// write transaction per batch.
    ///
    /// NOTE(review): all deletion ranges are drained before any upserts are
    /// committed — presumably the scan closes that channel early enough that
    /// this cannot stall the embed stage behind its bounded channel; confirm.
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            while let Some(deletion_range) = deleted_entry_ranges.next().await {
                let mut txn = db_connection.write_txn()?;
                let start = deletion_range.0.as_ref().map(|start| start.as_str());
                let end = deletion_range.1.as_ref().map(|end| end.as_str());
                log::debug!("deleting embeddings in range {:?}", &(start, end));
                db.delete_range(&mut txn, &(start, end))?;
                txn.commit()?;
            }

            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(embedded_files) = embedded_files.next().await {
                let mut txn = db_connection.write_txn()?;
                for (file, _) in &embedded_files {
                    log::debug!("saving embedding for file {:?}", file.path);
                    let key = db_key_for_path(&file.path);
                    db.put(&mut txn, &key, file)?;
                }
                txn.commit()?;

                // Dropping the batch releases the IndexingEntryHandles only
                // after the commit, so status reflects persisted work.
                drop(embedded_files);
                log::debug!("committed");
            }

            Ok(())
        })
    }

    /// All paths currently stored in this worktree's database.
    fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            let result = db
                .iter(&tx)?
                .map(|entry| Ok(entry?.1.path.clone()))
                .collect::<Result<Vec<Arc<Path>>>>();
            drop(tx);
            result
        })
    }

    /// The stored chunks (with embeddings) for a single path.
    ///
    /// # Errors
    /// Fails if the path has no record in the database.
    fn chunks_for_path(
        &self,
        path: Arc<Path>,
        cx: &AppContext,
    ) -> Task<Result<Vec<EmbeddedChunk>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            Ok(db
                .get(&tx, &db_key_for_path(&path))?
                .ok_or_else(|| anyhow!("no such path"))?
                .chunks
                .clone())
        })
    }

    /// Number of files stored in this worktree's database.
    #[cfg(test)]
    fn path_count(&self) -> Result<u64> {
        let txn = self
            .db_connection
            .read_txn()
            .context("failed to create read transaction")?;
        Ok(self.db.len(&txn)?)
    }
}
953
/// Output of the scan stage: entries to (re-)index, db key ranges to delete,
/// and the task driving the scan.
struct ScanEntries {
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}

/// Output of the chunking stage.
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}

/// A file whose text has been split into chunks but not yet embedded.
struct ChunkedFile {
    pub path: Arc<Path>,
    pub mtime: Option<SystemTime>,
    // Keeps the entry marked as in-flight until persisted (or dropped).
    pub handle: IndexingEntryHandle,
    pub text: String,
    pub chunks: Vec<Chunk>,
}

/// Output of the embedding stage.
struct EmbedFiles {
    files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    task: Task<Result<()>>,
}
977
/// The record stored in the database for each indexed file.
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedFile {
    path: Arc<Path>,
    // Mtime at indexing time; compared against disk on the next scan to
    // decide whether the file must be re-embedded.
    mtime: Option<SystemTime>,
    chunks: Vec<EmbeddedChunk>,
}

/// A text chunk together with its embedding vector.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct EmbeddedChunk {
    chunk: Chunk,
    embedding: Embedding,
}
990
/// The set of entries that are currently being indexed.
/// Every membership change is signalled on `tx` so observers can refresh
/// their progress display.
struct IndexingEntrySet {
    entry_ids: Mutex<HashSet<ProjectEntryId>>,
    tx: channel::Sender<()>,
}

/// When dropped, removes the entry from the set of entries that are being indexed.
#[derive(Clone)]
struct IndexingEntryHandle {
    entry_id: ProjectEntryId,
    // Weak so a lingering handle cannot keep the set (and its channel) alive.
    set: Weak<IndexingEntrySet>,
}
1003
1004impl IndexingEntrySet {
1005 fn new(tx: channel::Sender<()>) -> Self {
1006 Self {
1007 entry_ids: Default::default(),
1008 tx,
1009 }
1010 }
1011
1012 fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
1013 self.entry_ids.lock().insert(entry_id);
1014 self.tx.send_blocking(()).ok();
1015 IndexingEntryHandle {
1016 entry_id,
1017 set: Arc::downgrade(self),
1018 }
1019 }
1020
1021 pub fn len(&self) -> usize {
1022 self.entry_ids.lock().len()
1023 }
1024}
1025
1026impl Drop for IndexingEntryHandle {
1027 fn drop(&mut self) {
1028 if let Some(set) = self.set.upgrade() {
1029 set.tx.send_blocking(()).ok();
1030 set.entry_ids.lock().remove(&self.entry_id);
1031 }
1032 }
1033}
1034
/// Converts a worktree-relative path into the key used in the embeddings
/// database, replacing `/` with NUL — presumably so keys sort consistently
/// with worktree traversal order, which the scan's merge-join depends on
/// (TODO confirm).
///
/// Generalized to accept any `&Path`; existing callers holding `&Arc<Path>`
/// still work via deref coercion.
fn db_key_for_path(path: &Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}
1038
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, FutureExt};
    use gpui::TestAppContext;
    use language::language_settings::AllLanguageSettings;
    use project::Project;
    use settings::SettingsStore;
    use std::{future, path::Path, sync::Arc};

    /// Installs the settings, language, and project globals the index needs.
    fn init_test(cx: &mut TestAppContext) {
        _ = cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

    /// Deterministic embedding provider backed by a caller-supplied closure.
    pub struct TestEmbeddingProvider {
        batch_size: usize,
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            // Expression form instead of an explicit trailing `return`.
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticIndex::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // If the text contains "garbage in", bias the first dimension
                // positive; likewise "garbage out" for the second dimension.
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let project_path = Path::new("./fixture");

        let project = cx
            .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
            .await;

        cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
        });

        let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));

        // Wait until at least one file has been indexed.
        while project_index
            .read_with(cx, |index, cx| index.path_count(cx))
            .unwrap()
            == 0
        {
            project_index.next_event(cx).await;
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query.into(), 4, cx)
            })
            .await
            .unwrap();

        assert!(results.len() > 1, "should have found some results");

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find a result whose score is greater than 0.9.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "needle.md");

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_executor()
                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // Provider that fails on any text containing 'g', so test1.md
        // ("abcdefghijklmnop") is discarded while test2.md succeeds.
        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                Ok(Embedding::new(
                    ('a'..='z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| WorktreeIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        let mut embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
            embedded_files.push(embedded_file);
        }

        // Only test2.md survives; its chunk embeddings match the provider's
        // output for each chunk's text.
        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }
}