1mod chunking;
2mod embedding;
3
4use anyhow::{anyhow, Context as _, Result};
5use chunking::{chunk_text, Chunk};
6use collections::{Bound, HashMap};
7pub use embedding::*;
8use fs::Fs;
9use futures::stream::StreamExt;
10use futures_batch::ChunksTimeoutStreamExt;
11use gpui::{
12 AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Global, Model, ModelContext,
13 Subscription, Task, WeakModel,
14};
15use heed::types::{SerdeBincode, Str};
16use language::LanguageRegistry;
17use project::{Entry, Project, UpdatedEntriesSet, Worktree};
18use serde::{Deserialize, Serialize};
19use smol::channel;
20use std::{
21 cmp::Ordering,
22 future::Future,
23 ops::Range,
24 path::{Path, PathBuf},
25 sync::Arc,
26 time::{Duration, SystemTime},
27};
28use util::ResultExt;
29use worktree::LocalSnapshot;
30
/// Top-level semantic-search service: one `ProjectIndex` per open project,
/// all sharing a single LMDB environment and one embedding provider.
pub struct SemanticIndex {
    embedding_provider: Arc<dyn EmbeddingProvider>,
    db_connection: heed::Env,
    // Keyed by weak project handles so this map does not keep projects alive.
    // NOTE(review): entries for dropped projects are never pruned here —
    // confirm stale entries are acceptable.
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}
36
37impl Global for SemanticIndex {}
38
impl SemanticIndex {
    /// Opens (creating if necessary) the LMDB environment at `db_path` and
    /// returns a new, empty index registry.
    ///
    /// The blocking filesystem/database work runs on the background executor.
    pub async fn new(
        db_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AsyncAppContext,
    ) -> Result<Self> {
        let db_connection = cx
            .background_executor()
            .spawn(async move {
                std::fs::create_dir_all(&db_path)?;
                // SAFETY: `EnvOpenOptions::open` is unsafe because heed
                // requires that the same environment not be opened in
                // conflicting ways (e.g. concurrently by another process).
                // NOTE(review): that invariant is assumed rather than
                // enforced here — confirm nothing else opens this path.
                unsafe {
                    heed::EnvOpenOptions::new()
                        .map_size(1024 * 1024 * 1024) // 1 GiB memory map
                        .max_dbs(3000) // cap on named databases (one per worktree)
                        .open(db_path)
                }
            })
            .await
            .context("opening database connection")?;

        Ok(SemanticIndex {
            db_connection,
            embedding_provider,
            project_indices: HashMap::default(),
        })
    }

    /// Returns the `ProjectIndex` for `project`, creating and caching it on
    /// first use.
    pub fn project_index(
        &mut self,
        project: Model<Project>,
        cx: &mut AppContext,
    ) -> Model<ProjectIndex> {
        self.project_indices
            .entry(project.downgrade())
            .or_insert_with(|| {
                cx.new_model(|cx| {
                    ProjectIndex::new(
                        project,
                        self.db_connection.clone(),
                        self.embedding_provider.clone(),
                        cx,
                    )
                })
            })
            .clone()
    }
}
86
/// Per-project index: owns one `WorktreeIndex` per visible local worktree and
/// aggregates their scanning status into a single `Status`.
pub struct ProjectIndex {
    db_connection: heed::Env,
    project: Model<Project>,
    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    // Last aggregate status emitted to subscribers (see `update_status`).
    pub last_status: Status,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Keeps the project-event subscription alive for this index's lifetime.
    _subscription: Subscription,
}
97
/// A worktree's index is either still loading its database or fully loaded.
enum WorktreeIndexHandle {
    Loading {
        // Dropping this task cancels the in-flight load.
        _task: Task<Result<()>>,
    },
    Loaded {
        index: Model<WorktreeIndex>,
        // Observes the index so its status changes roll up to the project.
        _subscription: Subscription,
    },
}
107
impl ProjectIndex {
    /// Builds an index for `project` and subscribes to project events so the
    /// set of worktree indices tracks worktree additions/removals.
    fn new(
        project: Model<Project>,
        db_connection: heed::Env,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let language_registry = project.read(cx).languages().clone();
        let fs = project.read(cx).fs().clone();
        let mut this = ProjectIndex {
            db_connection,
            project: project.clone(),
            worktree_indices: HashMap::default(),
            language_registry,
            fs,
            last_status: Status::Idle,
            embedding_provider,
            _subscription: cx.subscribe(&project, Self::handle_project_event),
        };
        // Start indexing the worktrees the project already has.
        this.update_worktree_indices(cx);
        this
    }

    /// Reacts to worktree-set changes; every other project event is ignored.
    fn handle_project_event(
        &mut self,
        _: Model<Project>,
        event: &project::Event,
        cx: &mut ModelContext<Self>,
    ) {
        match event {
            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                self.update_worktree_indices(cx);
            }
            _ => {}
        }
    }

    /// Reconciles `worktree_indices` with the project's current visible local
    /// worktrees: drops indices for worktrees that went away and starts
    /// loading indices for new ones.
    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
        // Only local worktrees are indexed; remote ones are skipped.
        let worktrees = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .filter_map(|worktree| {
                if worktree.read(cx).is_local() {
                    Some((worktree.entity_id(), worktree))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        self.worktree_indices
            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
        for (worktree_id, worktree) in worktrees {
            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
                let worktree_index = WorktreeIndex::load(
                    worktree.clone(),
                    self.db_connection.clone(),
                    self.language_registry.clone(),
                    self.fs.clone(),
                    self.embedding_provider.clone(),
                    cx,
                );

                let load_worktree = cx.spawn(|this, mut cx| async move {
                    if let Some(index) = worktree_index.await.log_err() {
                        // Replace the `Loading` placeholder with the loaded
                        // index, observing it so status changes propagate.
                        this.update(&mut cx, |this, cx| {
                            this.worktree_indices.insert(
                                worktree_id,
                                WorktreeIndexHandle::Loaded {
                                    _subscription: cx
                                        .observe(&index, |this, _, cx| this.update_status(cx)),
                                    index,
                                },
                            );
                        })?;
                    } else {
                        // Loading failed (already logged); drop the placeholder.
                        this.update(&mut cx, |this, _cx| {
                            this.worktree_indices.remove(&worktree_id)
                        })?;
                    }

                    this.update(&mut cx, |this, cx| this.update_status(cx))
                });

                WorktreeIndexHandle::Loading {
                    _task: load_worktree,
                }
            });
        }

        self.update_status(cx);
    }

    /// Recomputes the aggregate status — `Scanning` if any worktree index is
    /// loading or scanning, otherwise `Idle` — and emits it when it changes.
    fn update_status(&mut self, cx: &mut ModelContext<Self>) {
        let mut status = Status::Idle;
        for index in self.worktree_indices.values() {
            match index {
                WorktreeIndexHandle::Loading { .. } => {
                    status = Status::Scanning;
                    break;
                }
                WorktreeIndexHandle::Loaded { index, .. } => {
                    if index.read(cx).status == Status::Scanning {
                        status = Status::Scanning;
                        break;
                    }
                }
            }
        }

        // Only emit on transitions, so subscribers see edges, not every poll.
        if status != self.last_status {
            self.last_status = status;
            cx.emit(status);
        }
    }

    /// Searches all loaded worktree indices concurrently, then merges the
    /// per-worktree hits into one list of at most `limit` results sorted by
    /// descending score. Worktrees that are still loading are skipped.
    pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
        let mut worktree_searches = Vec::new();
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                worktree_searches
                    .push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
            }
        }

        cx.spawn(|_| async move {
            let mut results = Vec::new();
            let worktree_searches = futures::future::join_all(worktree_searches).await;

            for worktree_search_results in worktree_searches {
                // A failed per-worktree search is logged and skipped rather
                // than failing the whole query.
                if let Some(worktree_search_results) = worktree_search_results.log_err() {
                    results.extend(worktree_search_results);
                }
            }

            results
                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
            results.truncate(limit);

            results
        })
    }
}
252
/// A single search hit: a chunk of a file in a worktree, scored against the
/// query embedding (see `Embedding::similarity`; higher scores sort first).
pub struct SearchResult {
    pub worktree: Model<Worktree>,
    pub path: Arc<Path>,
    // Range of the matching chunk within the file's text.
    pub range: Range<usize>,
    pub score: f32,
}
259
/// Indexing status reported by `ProjectIndex` and `WorktreeIndex`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    Idle,
    Scanning,
}
265
266impl EventEmitter<Status> for ProjectIndex {}
267
/// Index for a single local worktree: scans files, chunks and embeds them,
/// and persists the embeddings in an LMDB database keyed by file path.
struct WorktreeIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    // Maps `db_key_for_path` keys to each file's persisted embeddings.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    status: Status,
    // Long-lived task driving (re)indexing; cancelled when `self` drops.
    _index_entries: Task<Result<()>>,
    // Keeps the worktree-event subscription alive.
    _subscription: Subscription,
}
279
impl WorktreeIndex {
    /// Creates (or opens) this worktree's LMDB database — named after the
    /// worktree's absolute path — then constructs the index model.
    pub fn load(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let worktree_abs_path = worktree.read(cx).abs_path();
        cx.spawn(|mut cx| async move {
            // Database creation needs a write transaction, so do it on the
            // background executor before constructing the model.
            let db = cx
                .background_executor()
                .spawn({
                    let db_connection = db_connection.clone();
                    async move {
                        let mut txn = db_connection.write_txn()?;
                        let db_name = worktree_abs_path.to_string_lossy();
                        let db = db_connection.create_database(&mut txn, Some(&db_name))?;
                        txn.commit()?;
                        anyhow::Ok(db)
                    }
                })
                .await?;
            cx.new_model(|cx| {
                Self::new(
                    worktree,
                    db_connection,
                    db,
                    language_registry,
                    fs,
                    embedding_provider,
                    cx,
                )
            })
        })
    }

    /// Wires up the index: forwards `UpdatedEntries` worktree events into an
    /// unbounded channel consumed by the long-running `index_entries` task.
    fn new(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
            if let worktree::Event::UpdatedEntries(update) = event {
                // `try_send` on an unbounded channel only fails if the
                // receiver is gone, in which case the update is irrelevant.
                _ = updated_entries_tx.try_send(update.clone());
            }
        });

        Self {
            db_connection,
            db,
            worktree,
            language_registry,
            fs,
            embedding_provider,
            status: Status::Idle,
            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
            _subscription,
        }
    }

    /// Main indexing loop: first reconciles the database with the worktree's
    /// on-disk state, then processes incremental updates as they arrive.
    /// `status` is flipped to `Scanning` around each pass (with `cx.notify`
    /// so observers re-read it) and back to `Idle` afterwards.
    async fn index_entries(
        this: WeakModel<Self>,
        updated_entries: channel::Receiver<UpdatedEntriesSet>,
        mut cx: AsyncAppContext,
    ) -> Result<()> {
        let index = this.update(&mut cx, |this, cx| {
            cx.notify();
            this.status = Status::Scanning;
            this.index_entries_changed_on_disk(cx)
        })?;
        // Indexing failures are logged, not fatal to the loop.
        index.await.log_err();
        this.update(&mut cx, |this, cx| {
            this.status = Status::Idle;
            cx.notify();
        })?;

        // Terminates when the sender (held by the worktree subscription) drops.
        while let Ok(updated_entries) = updated_entries.recv().await {
            let index = this.update(&mut cx, |this, cx| {
                cx.notify();
                this.status = Status::Scanning;
                this.index_updated_entries(updated_entries, cx)
            })?;
            index.await.log_err();
            this.update(&mut cx, |this, cx| {
                this.status = Status::Idle;
                cx.notify();
            })?;
        }

        Ok(())
    }

    /// Full pass: scan → chunk → embed → persist, as four concurrent stages
    /// connected by bounded channels; fails if any stage fails.
    fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = self.embed_files(chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    /// Incremental pass over an explicit set of changed entries; same
    /// pipeline shape as `index_entries_changed_on_disk`.
    fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = self.embed_files(chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    /// Walks the worktree's files and the database in parallel (both are in
    /// `db_key_for_path` key order) to produce:
    /// - entries whose mtime differs from the saved one → `updated_entries`
    /// - runs of database keys with no matching file → `deleted_entry_ranges`
    fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let task = cx.background_executor().spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            // Accumulates a contiguous run of deleted keys so adjacent
            // deletions are reported as one range.
            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                let entry_db_key = db_key_for_path(&entry.path);

                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                // DB key precedes the current file: the file
                                // it described no longer exists. Extend (or
                                // start) the pending deletion range.
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                // Flush any pending deletions before this key.
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                // DB is past the current file: file is new.
                                break;
                            }
                        },
                        // Peeked an error: consume it and propagate.
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
                    }
                }

                // Re-embed when the file changed since it was last persisted
                // (or was never persisted, i.e. `saved_mtime` is `None`).
                if entry.mtime != saved_mtime {
                    updated_entries_tx.send(entry.clone()).await?;
                }
            }

            // Everything left in the DB has no corresponding file; delete the
            // whole tail in one unbounded range.
            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

    /// Like `scan_entries`, but driven by an explicit change set instead of a
    /// full DB/worktree diff.
    fn scan_updated_entries(
        &self,
        worktree: LocalSnapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            if entry.is_file() {
                                updated_entries_tx.send(entry.clone()).await?;
                            }
                        }
                    }
                    project::PathChange::Removed => {
                        // A single-key deletion is a degenerate [key, key] range.
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

    /// Loads each updated entry's text and splits it into chunks, using one
    /// worker per CPU. Workers stop when the entry channel closes or when the
    /// downstream consumer drops the chunked-file receiver.
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<Entry>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok(entry) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                // Unreadable files are logged and skipped.
                                let Some(text) = fs
                                    .load(&entry_abs_path)
                                    .await
                                    .with_context(|| {
                                        format!("failed to read path {entry_abs_path:?}")
                                    })
                                    .log_err()
                                else {
                                    continue;
                                };
                                // The grammar (if any) enables syntax-aware
                                // chunk boundaries in `chunk_text`.
                                let language = language_registry
                                    .language_for_file_path(&entry.path)
                                    .await
                                    .ok();
                                let grammar =
                                    language.as_ref().and_then(|language| language.grammar());
                                let chunked_file = ChunkedFile {
                                    worktree_root: worktree_abs_path.clone(),
                                    chunks: chunk_text(&text, grammar),
                                    entry,
                                    text,
                                };

                                // Receiver gone means the pipeline was
                                // cancelled; stop quietly.
                                if chunked_files_tx.send(chunked_file).await.is_err() {
                                    return;
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }

    /// Embeds chunked files in batches (up to 512 files or a 2s lull), then
    /// reassembles the flat embedding list back into per-file results.
    fn embed_files(
        &self,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &AppContext,
    ) -> EmbedFiles {
        let embedding_provider = self.embedding_provider.clone();
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_executor().spawn(async move {
            let mut chunked_file_batches =
                chunked_files.chunks_timeout(512, Duration::from_secs(2));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // View the batch of files as a vec of chunks
                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
                // Once those are done, reassemble it back into which files they belong to

                let chunks = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect::<Vec<_>>();

                // Respect the provider's maximum batch size; order of
                // embeddings matches the order of `chunks`.
                let mut embeddings = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    embeddings.extend(embedding_provider.embed(embedding_batch).await?);
                }

                // Hand each file back exactly as many embeddings as it
                // contributed chunks, in order.
                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let chunk_embeddings = embeddings
                        .by_ref()
                        .take(chunked_file.chunks.len())
                        .collect::<Vec<_>>();
                    let embedded_chunks = chunked_file
                        .chunks
                        .into_iter()
                        .zip(chunk_embeddings)
                        .map(|(chunk, embedding)| EmbeddedChunk { chunk, embedding })
                        .collect();
                    let embedded_file = EmbeddedFile {
                        path: chunked_file.entry.path.clone(),
                        mtime: chunked_file.entry.mtime,
                        chunks: embedded_chunks,
                    };

                    embedded_files_tx.send(embedded_file).await?;
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }

    /// Applies all key-range deletions, then writes embedded files in batches
    /// (up to 4096 files or a 2s lull), one write transaction per batch.
    ///
    /// NOTE(review): the deletion loop runs to completion (i.e. until the
    /// scanning task drops its sender) before any embeddings are written —
    /// confirm this ordering is intended rather than interleaved processing.
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<EmbeddedFile>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            while let Some(deletion_range) = deleted_entry_ranges.next().await {
                let mut txn = db_connection.write_txn()?;
                let start = deletion_range.0.as_ref().map(|start| start.as_str());
                let end = deletion_range.1.as_ref().map(|end| end.as_str());
                log::debug!("deleting embeddings in range {:?}", &(start, end));
                db.delete_range(&mut txn, &(start, end))?;
                txn.commit()?;
            }

            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(embedded_files) = embedded_files.next().await {
                let mut txn = db_connection.write_txn()?;
                for file in embedded_files {
                    log::debug!("saving embedding for file {:?}", file.path);
                    let key = db_key_for_path(&file.path);
                    db.put(&mut txn, &key, &file)?;
                }
                txn.commit()?;
                log::debug!("committed");
            }

            Ok(())
        })
    }

    /// Embeds `query`, streams every persisted chunk out of the database, and
    /// scores them in parallel. Each worker keeps its own descending-sorted
    /// top-`limit` list; the lists are merged, re-sorted, and truncated.
    fn search(
        &self,
        query: &str,
        limit: usize,
        cx: &AppContext,
    ) -> Task<Result<Vec<SearchResult>>> {
        let (chunks_tx, chunks_rx) = channel::bounded(1024);

        let db_connection = self.db_connection.clone();
        let db = self.db;
        let scan_chunks = cx.background_executor().spawn({
            async move {
                let txn = db_connection
                    .read_txn()
                    .context("failed to create read transaction")?;
                let db_entries = db.iter(&txn).context("failed to iterate database")?;
                for db_entry in db_entries {
                    let (_key, db_embedded_file) = db_entry?;
                    for chunk in db_embedded_file.chunks {
                        chunks_tx
                            .send((db_embedded_file.path.clone(), chunk))
                            .await?;
                    }
                }
                anyhow::Ok(())
            }
        });

        let query = query.to_string();
        let embedding_provider = self.embedding_provider.clone();
        let worktree = self.worktree.clone();
        cx.spawn(|cx| async move {
            #[cfg(debug_assertions)]
            let embedding_query_start = std::time::Instant::now();
            log::info!("Searching for {query}");

            let mut query_embeddings = embedding_provider
                .embed(&[TextToEmbed::new(&query)])
                .await?;
            let query_embedding = query_embeddings
                .pop()
                .ok_or_else(|| anyhow!("no embedding for query"))?;
            // One result accumulator per worker; no locking needed since each
            // scoped task gets exclusive access to its own vec.
            let mut workers = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                workers.push(Vec::<SearchResult>::new());
            }

            #[cfg(debug_assertions)]
            let search_start = std::time::Instant::now();

            cx.background_executor()
                .scoped(|cx| {
                    for worker_results in workers.iter_mut() {
                        cx.spawn(async {
                            while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
                                let score = embedded_chunk.embedding.similarity(&query_embedding);
                                // Binary search keeps the vec sorted by
                                // descending score; Ok/Err both give a valid
                                // insertion point.
                                let ix = match worker_results.binary_search_by(|probe| {
                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
                                }) {
                                    Ok(ix) | Err(ix) => ix,
                                };
                                worker_results.insert(
                                    ix,
                                    SearchResult {
                                        worktree: worktree.clone(),
                                        path: path.clone(),
                                        range: embedded_chunk.chunk.range.clone(),
                                        score,
                                    },
                                );
                                // Bound per-worker memory to the final limit.
                                worker_results.truncate(limit);
                            }
                        });
                    }
                })
                .await;
            // Surface any database error from the producer side.
            scan_chunks.await?;

            let mut search_results = Vec::with_capacity(workers.len() * limit);
            for worker_results in workers {
                search_results.extend(worker_results);
            }
            search_results
                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
            search_results.truncate(limit);
            #[cfg(debug_assertions)]
            {
                let search_elapsed = search_start.elapsed();
                log::debug!(
                    "searched {} entries in {:?}",
                    search_results.len(),
                    search_elapsed
                );
                let embedding_query_elapsed = embedding_query_start.elapsed();
                log::debug!("embedding query took {:?}", embedding_query_elapsed);
            }

            Ok(search_results)
        })
    }
}
778
/// Output of a scan stage: changed entries to (re)index, DB key ranges to
/// delete, and the task driving the scan.
struct ScanEntries {
    updated_entries: channel::Receiver<Entry>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}
784
/// Output of the chunking stage: chunked files plus the task producing them.
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}
789
/// A file's full text split into chunks, awaiting embedding.
struct ChunkedFile {
    #[allow(dead_code)]
    pub worktree_root: Arc<Path>,
    pub entry: Entry,
    // Full file text; chunk ranges index into this.
    pub text: String,
    pub chunks: Vec<Chunk>,
}
797
/// Output of the embedding stage: embedded files plus the producing task.
struct EmbedFiles {
    files: channel::Receiver<EmbeddedFile>,
    task: Task<Result<()>>,
}
802
/// Persisted database value: a file's chunk embeddings plus the mtime used to
/// detect staleness on the next scan (see `scan_entries`).
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedFile {
    path: Arc<Path>,
    mtime: Option<SystemTime>,
    chunks: Vec<EmbeddedChunk>,
}
809
/// One chunk of a file together with its embedding vector.
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedChunk {
    chunk: Chunk,
    embedding: Embedding,
}
815
/// Converts a worktree-relative path into its database key by replacing each
/// `/` separator with NUL. Presumably chosen to control lexicographic key
/// ordering in the database (NUL sorts before every other byte) — confirm
/// against `scan_entries`, which relies on DB-key order matching the
/// worktree's file order.
fn db_key_for_path(path: &Arc<Path>) -> String {
    let lossy = path.to_string_lossy();
    lossy
        .chars()
        .map(|ch| if ch == '/' { '\0' } else { ch })
        .collect()
}
819
#[cfg(test)]
mod tests {
    use super::*;

    use futures::channel::oneshot;
    use futures::{future::BoxFuture, FutureExt};

    use gpui::{Global, TestAppContext};
    use language::language_settings::AllLanguageSettings;
    use project::Project;
    use settings::SettingsStore;
    use std::{future, path::Path, sync::Arc};

    /// Installs the minimal global state (settings, languages, project
    /// settings) required to construct a `Project` in tests.
    fn init_test(cx: &mut TestAppContext) {
        _ = cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

    /// Deterministic provider producing 2-dimensional embeddings whose signs
    /// encode whether the text mentions "garbage in" / "garbage out".
    pub struct TestEmbeddingProvider;

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            let embeddings = texts
                .iter()
                .map(|text| {
                    let mut embedding = vec![0f32; 2];
                    // if the text contains garbage, give it a 1 in the first dimension
                    if text.text.contains("garbage in") {
                        embedding[0] = 0.9;
                    } else {
                        embedding[0] = -0.9;
                    }

                    if text.text.contains("garbage out") {
                        embedding[1] = 0.9;
                    } else {
                        embedding[1] = -0.9;
                    }

                    Embedding::new(embedding)
                })
                .collect();
            // Embeddings are computed synchronously; wrap in a ready future.
            future::ready(Ok(embeddings)).boxed()
        }

        fn batch_size(&self) -> usize {
            16
        }
    }

    /// End-to-end test: index the `./fixture` directory, wait for the first
    /// status event, search, and verify the best hit comes from `needle.md`.
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        // Real filesystem and database I/O happens below, so blocking is OK.
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticIndex::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let project_path = Path::new("./fixture");

        let project = cx
            .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
            .await;

        cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
        });

        let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));

        // Wait for the first status event as a signal that indexing ran.
        let (tx, rx) = oneshot::channel();
        let mut tx = Some(tx);
        let subscription = cx.update(|cx| {
            cx.subscribe(&project_index, move |_, event, _| {
                // Only forward the first event; later ones are ignored.
                if let Some(tx) = tx.take() {
                    _ = tx.send(*event);
                }
            })
        });

        rx.await.expect("no event emitted");
        drop(subscription);

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query, 4, cx)
            })
            .await;

        assert!(results.len() > 1, "should have found some results");

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find a result whose score exceeds 0.9 — i.e. matches the query on
        // both embedding dimensions of `TestEmbeddingProvider`.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "needle.md");

        // Load the matched file and confirm the returned range really
        // contains the query phrase.
        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
                let fs = project.read(cx).fs().clone();
                cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }
}
957}