mod chunking;
mod embedding;

use anyhow::{anyhow, Context as _, Result};
use chunking::{chunk_text, Chunk};
use collections::{Bound, HashMap, HashSet};
pub use embedding::*;
use fs::Fs;
use futures::stream::StreamExt;
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{
    AppContext, AsyncAppContext, BorrowAppContext, Context, Entity, EntityId, EventEmitter, Global,
    Model, ModelContext, Subscription, Task, WeakModel,
};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use parking_lot::Mutex;
use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
    cmp::Ordering,
    future::Future,
    iter,
    num::NonZeroUsize,
    ops::Range,
    path::{Path, PathBuf},
    sync::{Arc, Weak},
    time::{Duration, SystemTime},
};
use util::ResultExt;
use worktree::LocalSnapshot;

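/// The global semantic index, which owns the database connection and lazily
/// creates one [`ProjectIndex`] per open project.
///
/// A usage sketch mirroring the tests at the bottom of this file (`db_path`,
/// `provider`, `project`, and `cx` are assumed to be in scope):
///
/// ```ignore
/// let mut index = SemanticIndex::new(db_path, provider, &mut cx.to_async()).await?;
/// let project_index = cx.update(|cx| index.project_index(project.clone(), cx));
/// ```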
pub struct SemanticIndex {
    embedding_provider: Arc<dyn EmbeddingProvider>,
    db_connection: heed::Env,
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}

impl Global for SemanticIndex {}

impl SemanticIndex {
    pub async fn new(
        db_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AsyncAppContext,
    ) -> Result<Self> {
        let db_connection = cx
            .background_executor()
            .spawn(async move {
                std::fs::create_dir_all(&db_path)?;
                unsafe {
                    heed::EnvOpenOptions::new()
                        .map_size(1024 * 1024 * 1024)
                        .max_dbs(3000)
                        .open(db_path)
                }
            })
            .await
            .context("opening database connection")?;

        Ok(SemanticIndex {
            db_connection,
            embedding_provider,
            project_indices: HashMap::default(),
        })
    }

    pub fn project_index(
        &mut self,
        project: Model<Project>,
        cx: &mut AppContext,
    ) -> Model<ProjectIndex> {
        let project_weak = project.downgrade();
        project.update(cx, move |_, cx| {
            cx.on_release(move |_, cx| {
                if cx.has_global::<SemanticIndex>() {
                    cx.update_global::<SemanticIndex, _>(|this, _| {
                        this.project_indices.remove(&project_weak);
                    })
                }
            })
            .detach();
        });

        self.project_indices
            .entry(project.downgrade())
            .or_insert_with(|| {
                cx.new_model(|cx| {
                    ProjectIndex::new(
                        project,
                        self.db_connection.clone(),
                        self.embedding_provider.clone(),
                        cx,
                    )
                })
            })
            .clone()
    }
}

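/// Tracks indexing for a single project: one [`WorktreeIndex`] per visible
/// local worktree, plus a coalesced [`Status`] that is re-derived and emitted
/// as an event whenever any worktree reports progress.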
pub struct ProjectIndex {
    db_connection: heed::Env,
    project: WeakModel<Project>,
    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    last_status: Status,
    status_tx: channel::Sender<()>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    _maintain_status: Task<()>,
    _subscription: Subscription,
}

enum WorktreeIndexHandle {
    Loading { _task: Task<Result<()>> },
    Loaded { index: Model<WorktreeIndex> },
}

impl ProjectIndex {
    fn new(
        project: Model<Project>,
        db_connection: heed::Env,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let language_registry = project.read(cx).languages().clone();
        let fs = project.read(cx).fs().clone();
        let (status_tx, mut status_rx) = channel::unbounded();
        let mut this = ProjectIndex {
            db_connection,
            project: project.downgrade(),
            worktree_indices: HashMap::default(),
            language_registry,
            fs,
            status_tx,
            last_status: Status::Idle,
            embedding_provider,
            _subscription: cx.subscribe(&project, Self::handle_project_event),
            _maintain_status: cx.spawn(|this, mut cx| async move {
                while status_rx.next().await.is_some() {
                    if this
                        .update(&mut cx, |this, cx| this.update_status(cx))
                        .is_err()
                    {
                        break;
                    }
                }
            }),
        };
        this.update_worktree_indices(cx);
        this
    }

    pub fn status(&self) -> Status {
        self.last_status
    }

    fn handle_project_event(
        &mut self,
        _: Model<Project>,
        event: &project::Event,
        cx: &mut ModelContext<Self>,
    ) {
        match event {
            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                self.update_worktree_indices(cx);
            }
            _ => {}
        }
    }

    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
        let Some(project) = self.project.upgrade() else {
            return;
        };

        let worktrees = project
            .read(cx)
            .visible_worktrees(cx)
            .filter_map(|worktree| {
                if worktree.read(cx).is_local() {
                    Some((worktree.entity_id(), worktree))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        self.worktree_indices
            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
        for (worktree_id, worktree) in worktrees {
            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
                let worktree_index = WorktreeIndex::load(
                    worktree.clone(),
                    self.db_connection.clone(),
                    self.language_registry.clone(),
                    self.fs.clone(),
                    self.status_tx.clone(),
                    self.embedding_provider.clone(),
                    cx,
                );

                let load_worktree = cx.spawn(|this, mut cx| async move {
                    if let Some(worktree_index) = worktree_index.await.log_err() {
                        this.update(&mut cx, |this, _| {
                            this.worktree_indices.insert(
                                worktree_id,
                                WorktreeIndexHandle::Loaded {
                                    index: worktree_index,
                                },
                            );
                        })?;
                    } else {
                        this.update(&mut cx, |this, _cx| {
                            this.worktree_indices.remove(&worktree_id)
                        })?;
                    }

                    this.update(&mut cx, |this, cx| this.update_status(cx))
                });

                WorktreeIndexHandle::Loading {
                    _task: load_worktree,
                }
            });
        }

        self.update_status(cx);
    }

    fn update_status(&mut self, cx: &mut ModelContext<Self>) {
        let mut indexing_count = 0;
        let mut any_loading = false;

        for index in self.worktree_indices.values_mut() {
            match index {
                WorktreeIndexHandle::Loading { .. } => {
                    any_loading = true;
                    break;
                }
                WorktreeIndexHandle::Loaded { index, .. } => {
                    indexing_count += index.read(cx).entry_ids_being_indexed.len();
                }
            }
        }

        let status = if any_loading {
            Status::Loading
        } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
            Status::Scanning { remaining_count }
        } else {
            Status::Idle
        };

        if status != self.last_status {
            self.last_status = status;
            cx.emit(status);
        }
    }

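    /// Searches all loaded worktree indices concurrently, then merges the
    /// per-worktree results into a single list sorted by descending similarity
    /// score and truncated to `limit`.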
    pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
        let mut worktree_searches = Vec::new();
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                worktree_searches
                    .push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
            }
        }

        cx.spawn(|_| async move {
            let mut results = Vec::new();
            let worktree_searches = futures::future::join_all(worktree_searches).await;

            for worktree_search_results in worktree_searches {
                if let Some(worktree_search_results) = worktree_search_results.log_err() {
                    results.extend(worktree_search_results);
                }
            }

            results
                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
            results.truncate(limit);

            results
        })
    }

    #[cfg(test)]
    pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
        let mut result = 0;
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                result += index.read(cx).path_count()?;
            }
        }
        Ok(result)
    }

    pub fn debug(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let indices = self
            .worktree_indices
            .values()
            .filter_map(|worktree_index| {
                if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                    Some(index.clone())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();

        cx.spawn(|_, mut cx| async move {
            eprintln!("semantic index contents:");
            for index in indices {
                index.update(&mut cx, |index, cx| index.debug(cx))?.await?
            }
            Ok(())
        })
    }
}

pub struct SearchResult {
    pub worktree: Model<Worktree>,
    pub path: Arc<Path>,
    pub range: Range<usize>,
    pub score: f32,
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    Idle,
    Loading,
    Scanning { remaining_count: NonZeroUsize },
}

impl EventEmitter<Status> for ProjectIndex {}

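/// The index for a single local worktree. Embeddings are stored in a heed
/// (LMDB) database named after the worktree's absolute path and keyed by the
/// NUL-separated form of each file's relative path (see [`db_key_for_path`]).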
struct WorktreeIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
    _index_entries: Task<Result<()>>,
    _subscription: Subscription,
}

impl WorktreeIndex {
    pub fn load(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        status_tx: channel::Sender<()>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let worktree_abs_path = worktree.read(cx).abs_path();
        cx.spawn(|mut cx| async move {
            let db = cx
                .background_executor()
                .spawn({
                    let db_connection = db_connection.clone();
                    async move {
                        let mut txn = db_connection.write_txn()?;
                        let db_name = worktree_abs_path.to_string_lossy();
                        let db = db_connection.create_database(&mut txn, Some(&db_name))?;
                        txn.commit()?;
                        anyhow::Ok(db)
                    }
                })
                .await?;
            cx.new_model(|cx| {
                Self::new(
                    worktree,
                    db_connection,
                    db,
                    status_tx,
                    language_registry,
                    fs,
                    embedding_provider,
                    cx,
                )
            })
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn new(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        status: channel::Sender<()>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
            if let worktree::Event::UpdatedEntries(update) = event {
                _ = updated_entries_tx.try_send(update.clone());
            }
        });

        Self {
            db_connection,
            db,
            worktree,
            language_registry,
            fs,
            embedding_provider,
            entry_ids_being_indexed: Arc::new(IndexingEntrySet::new(status)),
            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
            _subscription,
        }
    }

    async fn index_entries(
        this: WeakModel<Self>,
        updated_entries: channel::Receiver<UpdatedEntriesSet>,
        mut cx: AsyncAppContext,
    ) -> Result<()> {
        let index = this.update(&mut cx, |this, cx| this.index_entries_changed_on_disk(cx))?;
        index.await.log_err();

        while let Ok(updated_entries) = updated_entries.recv().await {
            let index = this.update(&mut cx, |this, cx| {
                this.index_updated_entries(updated_entries, cx)
            })?;
            index.await.log_err();
        }

        Ok(())
    }

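    /// Builds the indexing pipeline for a full scan. The four stages (scan,
    /// chunk, embed, persist) run concurrently and are connected by bounded
    /// channels, so a slow stage applies backpressure to the stages upstream
    /// of it.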
    fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

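    /// Merge-joins the worktree snapshot against the database: both sides are
    /// sorted by database key, so a single pass finds files whose mtime has
    /// changed (sent on to chunking) and runs of database keys with no
    /// corresponding file on disk (sent on for range deletion).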
    fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                let entry_db_key = db_key_for_path(&entry.path);

                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                break;
                            }
                        },
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
                    }
                }

                if entry.mtime != saved_mtime {
                    let handle = entries_being_indexed.insert(entry.id);
                    updated_entries_tx.send((entry.clone(), handle)).await?;
                }
            }

            // Flush any deletion range that was still pending when the last
            // worktree file was visited, so stale rows can't outlive the scan.
            if let Some(deletion_range) = deletion_range.take() {
                deleted_entry_ranges_tx
                    .send((
                        deletion_range.0.map(ToString::to_string),
                        deletion_range.1.map(ToString::to_string),
                    ))
                    .await?;
            }

            // Any database entries that sort after the last worktree file are
            // also stale; delete from there to the end of the table.
            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

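    /// Like [`Self::scan_entries`], but driven by a set of changed paths
    /// reported by the worktree rather than a full traversal: added or updated
    /// files are queued for re-embedding, and removed paths become single-key
    /// deletion ranges.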
    fn scan_updated_entries(
        &self,
        worktree: LocalSnapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            if entry.is_file() {
                                let handle = entries_being_indexed.insert(entry.id);
                                updated_entries_tx.send((entry.clone(), handle)).await?;
                            }
                        }
                    }
                    project::PathChange::Removed => {
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

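    /// Loads each updated file and splits it into chunks, fanning the work out
    /// across one background worker per CPU. Files that fail to load are
    /// logged and skipped rather than aborting the pipeline.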
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok((entry, handle)) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                let Some(text) = fs
                                    .load(&entry_abs_path)
                                    .await
                                    .with_context(|| {
                                        format!("failed to read path {entry_abs_path:?}")
                                    })
                                    .log_err()
                                else {
                                    continue;
                                };
                                let language = language_registry
                                    .language_for_file_path(&entry.path)
                                    .await
                                    .ok();
                                let grammar =
                                    language.as_ref().and_then(|language| language.grammar());
                                let chunked_file = ChunkedFile {
                                    chunks: chunk_text(&text, grammar),
                                    handle,
                                    path: entry.path,
                                    mtime: entry.mtime,
                                    text,
                                };

                                if chunked_files_tx.send(chunked_file).await.is_err() {
                                    return;
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }

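    /// Embeds chunked files, batching them with a timeout so that partial
    /// batches still make progress. A file is emitted downstream only if every
    /// one of its chunks was embedded successfully.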
    fn embed_files(
        embedding_provider: Arc<dyn EmbeddingProvider>,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &AppContext,
    ) -> EmbedFiles {
        let embedding_provider = embedding_provider.clone();
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_executor().spawn(async move {
            let mut chunked_file_batches =
                chunked_files.chunks_timeout(512, Duration::from_secs(2));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // Flatten the batch of files into a single list of chunks, embed the
                // chunks in provider-sized batches, and then reassemble them into the
                // files to which they belong. If any chunk of a file fails to embed,
                // the entire file is discarded.

                let chunks: Vec<TextToEmbed> = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect::<Vec<_>>();

                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    if let Some(batch_embeddings) =
                        embedding_provider.embed(embedding_batch).await.log_err()
                    {
                        if batch_embeddings.len() == embedding_batch.len() {
                            embeddings.extend(batch_embeddings.into_iter().map(Some));
                            continue;
                        }
                        log::error!(
                            "embedding provider returned unexpected embedding count {}, expected {}",
                            batch_embeddings.len(),
                            embedding_batch.len()
                        );
                    }

                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
                }

                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let mut embedded_file = EmbeddedFile {
                        path: chunked_file.path,
                        mtime: chunked_file.mtime,
                        chunks: Vec::new(),
                    };

                    let mut embedded_all_chunks = true;
                    for (chunk, embedding) in
                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
                    {
                        if let Some(embedding) = embedding {
                            embedded_file
                                .chunks
                                .push(EmbeddedChunk { chunk, embedding });
                        } else {
                            embedded_all_chunks = false;
                        }
                    }

                    if embedded_all_chunks {
                        embedded_files_tx
                            .send((embedded_file, chunked_file.handle))
                            .await?;
                    }
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }

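    /// Applies the pipeline's output to the database: deletes the key ranges
    /// for removed entries, then upserts embedded files in large batches with
    /// one write transaction per batch.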
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            while let Some(deletion_range) = deleted_entry_ranges.next().await {
                let mut txn = db_connection.write_txn()?;
                let start = deletion_range.0.as_ref().map(|start| start.as_str());
                let end = deletion_range.1.as_ref().map(|end| end.as_str());
                log::debug!("deleting embeddings in range {:?}", &(start, end));
                db.delete_range(&mut txn, &(start, end))?;
                txn.commit()?;
            }

            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(embedded_files) = embedded_files.next().await {
                let mut txn = db_connection.write_txn()?;
                for (file, _) in &embedded_files {
                    log::debug!("saving embedding for file {:?}", file.path);
                    let key = db_key_for_path(&file.path);
                    db.put(&mut txn, &key, file)?;
                }
                txn.commit()?;
                log::debug!("committed {} embedded files", embedded_files.len());
            }

            Ok(())
        })
    }

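    /// Brute-force similarity search: streams every chunk in the database to
    /// one scoring worker per CPU, each of which maintains its own sorted
    /// top-`limit` list; the per-worker lists are merged, sorted, and
    /// truncated at the end.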
    fn search(
        &self,
        query: &str,
        limit: usize,
        cx: &AppContext,
    ) -> Task<Result<Vec<SearchResult>>> {
        let (chunks_tx, chunks_rx) = channel::bounded(1024);

        let db_connection = self.db_connection.clone();
        let db = self.db;
        let scan_chunks = cx.background_executor().spawn({
            async move {
                let txn = db_connection
                    .read_txn()
                    .context("failed to create read transaction")?;
                let db_entries = db.iter(&txn).context("failed to iterate database")?;
                for db_entry in db_entries {
                    let (_key, db_embedded_file) = db_entry?;
                    for chunk in db_embedded_file.chunks {
                        chunks_tx
                            .send((db_embedded_file.path.clone(), chunk))
                            .await?;
                    }
                }
                anyhow::Ok(())
            }
        });

        let query = query.to_string();
        let embedding_provider = self.embedding_provider.clone();
        let worktree = self.worktree.clone();
        cx.spawn(|cx| async move {
            #[cfg(debug_assertions)]
            let embedding_query_start = std::time::Instant::now();
            log::info!("Searching for {query}");

            let mut query_embeddings = embedding_provider
                .embed(&[TextToEmbed::new(&query)])
                .await?;
            let query_embedding = query_embeddings
                .pop()
                .ok_or_else(|| anyhow!("no embedding for query"))?;
            let mut workers = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                workers.push(Vec::<SearchResult>::new());
            }

            #[cfg(debug_assertions)]
            let search_start = std::time::Instant::now();

            cx.background_executor()
                .scoped(|cx| {
                    for worker_results in workers.iter_mut() {
                        cx.spawn(async {
                            while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
                                let score = embedded_chunk.embedding.similarity(&query_embedding);
                                let ix = match worker_results.binary_search_by(|probe| {
                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
                                }) {
                                    Ok(ix) | Err(ix) => ix,
                                };
                                worker_results.insert(
                                    ix,
                                    SearchResult {
                                        worktree: worktree.clone(),
                                        path: path.clone(),
                                        range: embedded_chunk.chunk.range.clone(),
                                        score,
                                    },
                                );
                                worker_results.truncate(limit);
                            }
                        });
                    }
                })
                .await;
            scan_chunks.await?;

            let mut search_results = Vec::with_capacity(workers.len() * limit);
            for worker_results in workers {
                search_results.extend(worker_results);
            }
            search_results
                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
            search_results.truncate(limit);
            #[cfg(debug_assertions)]
            {
                let search_elapsed = search_start.elapsed();
                log::debug!(
                    "searched {} entries in {:?}",
                    search_results.len(),
                    search_elapsed
                );
                let embedding_query_elapsed = embedding_query_start.elapsed();
                log::debug!("embedding query took {:?}", embedding_query_elapsed);
            }

            Ok(search_results)
        })
    }

    fn debug(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            for record in db.iter(&tx)? {
                let (key, _) = record?;
                eprintln!("{}", path_for_db_key(key));
            }
            Ok(())
        })
    }

    #[cfg(test)]
    fn path_count(&self) -> Result<u64> {
        let txn = self
            .db_connection
            .read_txn()
            .context("failed to create read transaction")?;
        Ok(self.db.len(&txn)?)
    }
}

struct ScanEntries {
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}

struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}

struct ChunkedFile {
    pub path: Arc<Path>,
    pub mtime: Option<SystemTime>,
    pub handle: IndexingEntryHandle,
    pub text: String,
    pub chunks: Vec<Chunk>,
}

struct EmbedFiles {
    files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    task: Task<Result<()>>,
}

#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedFile {
    path: Arc<Path>,
    mtime: Option<SystemTime>,
    chunks: Vec<EmbeddedChunk>,
}

#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedChunk {
    chunk: Chunk,
    embedding: Embedding,
}

/// The set of entries that are currently being indexed.
struct IndexingEntrySet {
    entry_ids: Mutex<HashSet<ProjectEntryId>>,
    tx: channel::Sender<()>,
}

/// When dropped, removes the entry from the set of entries that are being indexed.
#[derive(Clone)]
struct IndexingEntryHandle {
    entry_id: ProjectEntryId,
    set: Weak<IndexingEntrySet>,
}

impl IndexingEntrySet {
    fn new(tx: channel::Sender<()>) -> Self {
        Self {
            entry_ids: Default::default(),
            tx,
        }
    }

    fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
        self.entry_ids.lock().insert(entry_id);
        self.tx.send_blocking(()).ok();
        IndexingEntryHandle {
            entry_id,
            set: Arc::downgrade(self),
        }
    }

    pub fn len(&self) -> usize {
        self.entry_ids.lock().len()
    }
}

impl Drop for IndexingEntryHandle {
    fn drop(&mut self) {
        if let Some(set) = self.set.upgrade() {
            set.tx.send_blocking(()).ok();
            set.entry_ids.lock().remove(&self.entry_id);
        }
    }
}

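/// Database keys are paths with `/` replaced by NUL. Because NUL sorts before
/// every other byte, lexicographic key order matches path-component order,
/// which keeps the database iterator aligned with the worktree traversal in
/// `scan_entries` and keeps a directory's subtree contiguous for range
/// deletion.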
fn db_key_for_path(path: &Arc<Path>) -> String {
    path.to_string_lossy().replace('/', "\0")
}

fn path_for_db_key(key: &str) -> String {
    key.replace('\0', "/")
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, FutureExt};
    use gpui::TestAppContext;
    use language::language_settings::AllLanguageSettings;
    use project::Project;
    use settings::SettingsStore;
    use std::{future, path::Path, sync::Arc};

    fn init_test(cx: &mut TestAppContext) {
        _ = cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

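    /// A deterministic embedding provider for tests: embeddings are produced
    /// by a closure supplied by each test.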
    pub struct TestEmbeddingProvider {
        batch_size: usize,
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticIndex::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // If the text contains "garbage in", give it a high value in the
                // first dimension; likewise "garbage out" and the second dimension.
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let project_path = Path::new("./fixture");

        let project = cx
            .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
            .await;

        cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
        });

        let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));

        while project_index
            .read_with(cx, |index, cx| index.path_count(cx))
            .unwrap()
            == 0
        {
            project_index.next_event(cx).await;
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query, 4, cx)
            })
            .await;

        assert!(results.len() > 1, "should have found more than one result");

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find the result with a score above 0.9.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "needle.md");

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
                let fs = project.read(cx).fs().clone();
                cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                // Embed a text as the count of each letter 'a' through 'z'.
                Ok(Embedding::new(
                    ('a'..='z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| WorktreeIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        let mut embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
            embedded_files.push(embedded_file);
        }

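        // Only `test2.md` should survive: one of `test1.md`'s chunks ("efgh")
        // contains a 'g', which the test provider refuses to embed, and a file
        // is dropped unless all of its chunks embed successfully.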
        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| embedded_chunk.embedding.clone())
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }
}