1mod chunking;
2mod embedding;
3
4use anyhow::{anyhow, Context as _, Result};
5use chunking::{chunk_text, Chunk};
6use collections::{Bound, HashMap};
7pub use embedding::*;
8use fs::Fs;
9use futures::stream::StreamExt;
10use futures_batch::ChunksTimeoutStreamExt;
11use gpui::{
12 AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Global, Model, ModelContext,
13 Subscription, Task, WeakModel,
14};
15use heed::types::{SerdeBincode, Str};
16use language::LanguageRegistry;
17use project::{Entry, Project, UpdatedEntriesSet, Worktree};
18use serde::{Deserialize, Serialize};
19use smol::channel;
20use std::{
21 cmp::Ordering,
22 future::Future,
23 ops::Range,
24 path::Path,
25 sync::Arc,
26 time::{Duration, SystemTime},
27};
28use util::ResultExt;
29use worktree::LocalSnapshot;
30
/// App-wide semantic-index state: one embedding provider, one LMDB
/// environment, and one `ProjectIndex` per open project.
pub struct SemanticIndex {
    /// Provider used to compute embeddings for chunks of text and for queries.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Shared LMDB environment backing every per-worktree database.
    db_connection: heed::Env,
    /// Cached indices, keyed by weak project handles so the map doesn't keep
    /// dropped projects alive. NOTE(review): entries for dropped projects are
    /// never pruned here — confirm whether that's intended.
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}
36
/// Registered as a GPUI global so the index is reachable from anywhere in the app.
impl Global for SemanticIndex {}
38
39impl SemanticIndex {
    /// Opens (or creates) the embedding database at `db_path` and returns a
    /// task resolving to a new `SemanticIndex`.
    ///
    /// The database is opened on the background executor so the foreground
    /// thread never blocks on filesystem I/O.
    pub fn new(
        db_path: &Path,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Self>> {
        let db_path = db_path.to_path_buf();
        cx.spawn(|cx| async move {
            let db_connection = cx
                .background_executor()
                .spawn(async move {
                    // SAFETY: `EnvOpenOptions::open` is unsafe because the
                    // caller must guarantee the same environment isn't opened
                    // twice within one process. NOTE(review): confirm no other
                    // code path opens this same `db_path`.
                    unsafe {
                        heed::EnvOpenOptions::new()
                            .map_size(1024 * 1024 * 1024) // 1 GiB memory map
                            .max_dbs(3000) // one named database per worktree
                            .open(db_path)
                    }
                })
                .await?;

            Ok(SemanticIndex {
                db_connection,
                embedding_provider,
                project_indices: HashMap::default(),
            })
        })
    }
66
67 pub fn project_index(
68 &mut self,
69 project: Model<Project>,
70 cx: &mut AppContext,
71 ) -> Model<ProjectIndex> {
72 self.project_indices
73 .entry(project.downgrade())
74 .or_insert_with(|| {
75 cx.new_model(|cx| {
76 ProjectIndex::new(
77 project,
78 self.db_connection.clone(),
79 self.embedding_provider.clone(),
80 cx,
81 )
82 })
83 })
84 .clone()
85 }
86}
87
/// Aggregates the semantic indices of all local worktrees within one project.
pub struct ProjectIndex {
    db_connection: heed::Env,
    project: Model<Project>,
    /// Per-worktree indices keyed by the worktree's entity id; each is either
    /// still loading or fully loaded.
    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    /// Last status emitted to subscribers, used to suppress duplicate events.
    last_status: Status,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Keeps the project-event subscription alive for this index's lifetime.
    _subscription: Subscription,
}
98
/// Lifecycle state of a single worktree's index inside a `ProjectIndex`.
enum WorktreeIndexHandle {
    /// The index is still being created; dropping the task cancels loading.
    Loading {
        _task: Task<Result<()>>,
    },
    /// The index finished loading; the subscription forwards status changes.
    Loaded {
        index: Model<WorktreeIndex>,
        _subscription: Subscription,
    },
}
108
109impl ProjectIndex {
    /// Creates a project index and immediately starts indexing all of the
    /// project's current visible, local worktrees.
    fn new(
        project: Model<Project>,
        db_connection: heed::Env,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let language_registry = project.read(cx).languages().clone();
        let fs = project.read(cx).fs().clone();
        let mut this = ProjectIndex {
            db_connection,
            project: project.clone(),
            worktree_indices: HashMap::default(),
            language_registry,
            fs,
            last_status: Status::Idle,
            embedding_provider,
            // Re-reconcile worktree indices whenever the project changes.
            _subscription: cx.subscribe(&project, Self::handle_project_event),
        };
        this.update_worktree_indices(cx);
        this
    }
131
132 fn handle_project_event(
133 &mut self,
134 _: Model<Project>,
135 event: &project::Event,
136 cx: &mut ModelContext<Self>,
137 ) {
138 match event {
139 project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
140 self.update_worktree_indices(cx);
141 }
142 _ => {}
143 }
144 }
145
    /// Reconciles `self.worktree_indices` with the project's current set of
    /// visible, local worktrees: drops indices for removed worktrees and kicks
    /// off loading for newly added ones.
    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
        // Snapshot the project's current local worktrees, keyed by entity id.
        let worktrees = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .filter_map(|worktree| {
                if worktree.read(cx).is_local() {
                    Some((worktree.entity_id(), worktree))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        // Drop indices for worktrees that are no longer part of the project.
        self.worktree_indices
            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
        for (worktree_id, worktree) in worktrees {
            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
                let worktree_index = WorktreeIndex::load(
                    worktree.clone(),
                    self.db_connection.clone(),
                    self.language_registry.clone(),
                    self.fs.clone(),
                    self.embedding_provider.clone(),
                    cx,
                );

                let load_worktree = cx.spawn(|this, mut cx| async move {
                    if let Some(index) = worktree_index.await.log_err() {
                        // Loading succeeded: swap the handle to `Loaded` and
                        // observe the index so status changes propagate upward.
                        this.update(&mut cx, |this, cx| {
                            this.worktree_indices.insert(
                                worktree_id,
                                WorktreeIndexHandle::Loaded {
                                    _subscription: cx
                                        .observe(&index, |this, _, cx| this.update_status(cx)),
                                    index,
                                },
                            );
                        })?;
                    } else {
                        // Loading failed (already logged); forget the handle.
                        this.update(&mut cx, |this, _cx| {
                            this.worktree_indices.remove(&worktree_id)
                        })?;
                    }

                    this.update(&mut cx, |this, cx| this.update_status(cx))
                });

                // Until the task above completes, the handle stays `Loading`.
                WorktreeIndexHandle::Loading {
                    _task: load_worktree,
                }
            });
        }

        self.update_status(cx);
    }
202
203 fn update_status(&mut self, cx: &mut ModelContext<Self>) {
204 let mut status = Status::Idle;
205 for index in self.worktree_indices.values() {
206 match index {
207 WorktreeIndexHandle::Loading { .. } => {
208 status = Status::Scanning;
209 break;
210 }
211 WorktreeIndexHandle::Loaded { index, .. } => {
212 if index.read(cx).status == Status::Scanning {
213 status = Status::Scanning;
214 break;
215 }
216 }
217 }
218 }
219
220 if status != self.last_status {
221 self.last_status = status;
222 cx.emit(status);
223 }
224 }
225
    /// Searches every loaded worktree index for chunks similar to `query` and
    /// merges the per-worktree results into a single list of at most `limit`
    /// results, ordered by descending similarity score.
    pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
        let mut worktree_searches = Vec::new();
        for worktree_index in self.worktree_indices.values() {
            // Worktrees still loading are silently skipped.
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                worktree_searches
                    .push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
            }
        }

        cx.spawn(|_| async move {
            let mut results = Vec::new();
            let worktree_searches = futures::future::join_all(worktree_searches).await;

            for worktree_search_results in worktree_searches {
                // A failed worktree search is logged and dropped rather than
                // failing the whole query.
                if let Some(worktree_search_results) = worktree_search_results.log_err() {
                    results.extend(worktree_search_results);
                }
            }

            results
                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
            results.truncate(limit);

            results
        })
    }
252}
253
/// A single semantic-search hit.
pub struct SearchResult {
    /// The worktree containing the matching file.
    pub worktree: Model<Worktree>,
    /// Path of the matching file, relative to the worktree root.
    pub path: Arc<Path>,
    /// Byte range of the matching chunk within the file's text.
    pub range: Range<usize>,
    /// Similarity score against the query embedding; higher is more similar.
    pub score: f32,
}
260
/// Indexing status, emitted by `ProjectIndex` whenever it changes.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    /// No indexing work is in progress.
    Idle,
    /// At least one worktree index is loading or re-scanning.
    Scanning,
}
266
/// `ProjectIndex` notifies subscribers of aggregate status changes.
impl EventEmitter<Status> for ProjectIndex {}
268
/// Maintains the embedding database for a single local worktree, re-indexing
/// files as they change on disk.
struct WorktreeIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    /// Maps `db_key_for_path` keys to serialized `EmbeddedFile`s.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Current scan state; observed by `ProjectIndex::update_status`.
    status: Status,
    /// Long-running task performing the initial scan and then consuming
    /// incremental updates; dropping it stops indexing.
    _index_entries: Task<Result<()>>,
    /// Keeps the worktree-event subscription alive.
    _subscription: Subscription,
}
280
281impl WorktreeIndex {
    /// Creates (or opens) this worktree's named database within the shared
    /// LMDB environment, then constructs the `WorktreeIndex` model.
    pub fn load(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let worktree_abs_path = worktree.read(cx).abs_path();
        cx.spawn(|mut cx| async move {
            // Database creation needs a write transaction, so run it on the
            // background executor.
            let db = cx
                .background_executor()
                .spawn({
                    let db_connection = db_connection.clone();
                    async move {
                        let mut txn = db_connection.write_txn()?;
                        // Each worktree's database is named by its absolute path.
                        let db_name = worktree_abs_path.to_string_lossy();
                        let db = db_connection.create_database(&mut txn, Some(&db_name))?;
                        txn.commit()?;
                        anyhow::Ok(db)
                    }
                })
                .await?;
            cx.new_model(|cx| {
                Self::new(
                    worktree,
                    db_connection,
                    db,
                    language_registry,
                    fs,
                    embedding_provider,
                    cx,
                )
            })
        })
    }
318
    /// Builds the index model and spawns the background task that performs the
    /// initial scan and consumes incremental worktree updates.
    fn new(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        // Forward worktree entry updates into a channel consumed by
        // `index_entries`. The channel is unbounded, so `try_send` can only
        // fail once the receiver is dropped, which is safe to ignore.
        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
            if let worktree::Event::UpdatedEntries(update) = event {
                _ = updated_entries_tx.try_send(update.clone());
            }
        });

        Self {
            db_connection,
            db,
            worktree,
            language_registry,
            fs,
            embedding_provider,
            status: Status::Idle,
            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
            _subscription,
        }
    }
347
    /// Top-level indexing loop: first reconciles the database with the current
    /// on-disk state of the worktree, then processes batches of updated
    /// entries as they arrive. Sets `status` to `Scanning` around each pass
    /// and notifies observers so `ProjectIndex` can aggregate status.
    async fn index_entries(
        this: WeakModel<Self>,
        updated_entries: channel::Receiver<UpdatedEntriesSet>,
        mut cx: AsyncAppContext,
    ) -> Result<()> {
        let index = this.update(&mut cx, |this, cx| {
            cx.notify();
            this.status = Status::Scanning;
            this.index_entries_changed_on_disk(cx)
        })?;
        // Indexing errors are logged but don't stop the update loop below.
        index.await.log_err();
        this.update(&mut cx, |this, cx| {
            this.status = Status::Idle;
            cx.notify();
        })?;

        // Exits when the sending side (the worktree subscription) is dropped.
        while let Ok(updated_entries) = updated_entries.recv().await {
            let index = this.update(&mut cx, |this, cx| {
                cx.notify();
                this.status = Status::Scanning;
                this.index_updated_entries(updated_entries, cx)
            })?;
            index.await.log_err();
            this.update(&mut cx, |this, cx| {
                this.status = Status::Idle;
                cx.notify();
            })?;
        }

        Ok(())
    }
379
    /// Builds the full indexing pipeline for an initial scan:
    /// scan → chunk → embed → persist, with stages connected by bounded
    /// channels and driven concurrently.
    fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
        // NOTE(review): `as_local().unwrap()` panics on a remote worktree;
        // indices are only created for local worktrees (see
        // `ProjectIndex::update_worktree_indices`).
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = self.embed_files(chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            // Fail as soon as any stage fails.
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }
392
    /// Builds the incremental indexing pipeline for a batch of filesystem
    /// change events: scan the changed entries, then chunk → embed → persist,
    /// all stages running concurrently.
    fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        // NOTE(review): `as_local().unwrap()` panics on a remote worktree;
        // indices are only created for local worktrees.
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries, cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = self.embed_files(chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }
409
410 fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries {
411 let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
412 let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
413 let db_connection = self.db_connection.clone();
414 let db = self.db;
415 let task = cx.background_executor().spawn(async move {
416 let txn = db_connection
417 .read_txn()
418 .context("failed to create read transaction")?;
419 let mut db_entries = db
420 .iter(&txn)
421 .context("failed to create iterator")?
422 .move_between_keys()
423 .peekable();
424
425 let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
426 for entry in worktree.files(false, 0) {
427 let entry_db_key = db_key_for_path(&entry.path);
428
429 let mut saved_mtime = None;
430 while let Some(db_entry) = db_entries.peek() {
431 match db_entry {
432 Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
433 Ordering::Less => {
434 if let Some(deletion_range) = deletion_range.as_mut() {
435 deletion_range.1 = Bound::Included(db_path);
436 } else {
437 deletion_range =
438 Some((Bound::Included(db_path), Bound::Included(db_path)));
439 }
440
441 db_entries.next();
442 }
443 Ordering::Equal => {
444 if let Some(deletion_range) = deletion_range.take() {
445 deleted_entry_ranges_tx
446 .send((
447 deletion_range.0.map(ToString::to_string),
448 deletion_range.1.map(ToString::to_string),
449 ))
450 .await?;
451 }
452 saved_mtime = db_embedded_file.mtime;
453 db_entries.next();
454 break;
455 }
456 Ordering::Greater => {
457 break;
458 }
459 },
460 Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
461 }
462 }
463
464 if entry.mtime != saved_mtime {
465 updated_entries_tx.send(entry.clone()).await?;
466 }
467 }
468
469 if let Some(db_entry) = db_entries.next() {
470 let (db_path, _) = db_entry?;
471 deleted_entry_ranges_tx
472 .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
473 .await?;
474 }
475
476 Ok(())
477 });
478
479 ScanEntries {
480 updated_entries: updated_entries_rx,
481 deleted_entry_ranges: deleted_entry_ranges_rx,
482 task,
483 }
484 }
485
    /// Converts a batch of worktree change events into pipeline inputs:
    /// added/updated entries are forwarded for re-indexing and removed paths
    /// become single-key deletion ranges.
    fn scan_updated_entries(
        &self,
        worktree: LocalSnapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        // The entry may have been removed again since the
                        // event was recorded; skip it in that case.
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            updated_entries_tx.send(entry.clone()).await?;
                        }
                    }
                    project::PathChange::Removed => {
                        // A removal deletes exactly one key: send an inclusive
                        // single-element range.
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Loading a file doesn't change its contents; nothing to do.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }
525
    /// Spawns one worker per CPU that loads updated files from disk and splits
    /// them into chunks suitable for embedding.
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<Entry>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok(entry) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                // Skip files that can't be read (error is logged).
                                let Some(text) = fs.load(&entry_abs_path).await.log_err() else {
                                    continue;
                                };
                                // Chunk boundaries are language-aware when a
                                // grammar is available for the file type.
                                let language = language_registry
                                    .language_for_file_path(&entry.path)
                                    .await
                                    .ok();
                                let grammar =
                                    language.as_ref().and_then(|language| language.grammar());
                                let chunked_file = ChunkedFile {
                                    worktree_root: worktree_abs_path.clone(),
                                    chunks: chunk_text(&text, grammar),
                                    entry,
                                    text,
                                };

                                // A dropped receiver means downstream shut
                                // down; stop this worker quietly.
                                if chunked_files_tx.send(chunked_file).await.is_err() {
                                    return;
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }
574
    /// Batches chunked files, requests embeddings from the provider, and
    /// reassembles the returned embeddings back into per-file records.
    fn embed_files(
        &self,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &AppContext,
    ) -> EmbedFiles {
        let embedding_provider = self.embedding_provider.clone();
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_executor().spawn(async move {
            // Collect files for up to 2 seconds (or until 512 arrive) so that
            // chunks can be embedded in large provider batches.
            let mut chunked_file_batches =
                chunked_files.chunks_timeout(512, Duration::from_secs(2));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // View the batch of files as a vec of chunks
                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
                // Once those are done, reassemble it back into which files they belong to

                let chunks = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect::<Vec<_>>();

                let mut embeddings = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    // todo!("add a retry facility")
                    embeddings.extend(embedding_provider.embed(embedding_batch).await?);
                }

                // The provider returns one embedding per chunk, in order, so
                // hand each file back exactly `chunks.len()` embeddings.
                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let chunk_embeddings = embeddings
                        .by_ref()
                        .take(chunked_file.chunks.len())
                        .collect::<Vec<_>>();
                    let embedded_chunks = chunked_file
                        .chunks
                        .into_iter()
                        .zip(chunk_embeddings)
                        .map(|(chunk, embedding)| EmbeddedChunk { chunk, embedding })
                        .collect();
                    let embedded_file = EmbeddedFile {
                        path: chunked_file.entry.path.clone(),
                        mtime: chunked_file.entry.mtime,
                        chunks: embedded_chunks,
                    };

                    embedded_files_tx.send(embedded_file).await?;
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }
635
    /// Applies deletion ranges and then writes embedded files to the database
    /// in batched write transactions.
    ///
    /// NOTE(review): all deletion ranges are drained before any embedded files
    /// are written — this relies on the deletion channel closing once the scan
    /// stage finishes; confirm the scan task always completes.
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<EmbeddedFile>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            while let Some(deletion_range) = deleted_entry_ranges.next().await {
                let mut txn = db_connection.write_txn()?;
                let start = deletion_range.0.as_ref().map(|start| start.as_str());
                let end = deletion_range.1.as_ref().map(|end| end.as_str());
                log::debug!("deleting embeddings in range {:?}", &(start, end));
                db.delete_range(&mut txn, &(start, end))?;
                txn.commit()?;
            }

            // Batch writes (up to 4096 files or 2 seconds' worth) to amortize
            // transaction overhead.
            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(embedded_files) = embedded_files.next().await {
                let mut txn = db_connection.write_txn()?;
                for file in embedded_files {
                    log::debug!("saving embedding for file {:?}", file.path);
                    let key = db_key_for_path(&file.path);
                    db.put(&mut txn, &key, &file)?;
                }
                txn.commit()?;
                log::debug!("committed");
            }

            Ok(())
        })
    }
669
    /// Performs a similarity search over every chunk stored for this worktree.
    ///
    /// One background task streams (path, chunk) pairs out of the database
    /// while a pool of workers (one per CPU) scores each chunk against the
    /// embedded query, keeping a per-worker top-`limit` list; the lists are
    /// merged, re-sorted, and truncated at the end.
    fn search(
        &self,
        query: &str,
        limit: usize,
        cx: &AppContext,
    ) -> Task<Result<Vec<SearchResult>>> {
        let (chunks_tx, chunks_rx) = channel::bounded(1024);

        let db_connection = self.db_connection.clone();
        let db = self.db;
        let scan_chunks = cx.background_executor().spawn({
            async move {
                let txn = db_connection
                    .read_txn()
                    .context("failed to create read transaction")?;
                let db_entries = db.iter(&txn).context("failed to iterate database")?;
                for db_entry in db_entries {
                    let (_, db_embedded_file) = db_entry?;
                    for chunk in db_embedded_file.chunks {
                        chunks_tx
                            .send((db_embedded_file.path.clone(), chunk))
                            .await?;
                    }
                }
                anyhow::Ok(())
            }
        });

        let query = query.to_string();
        let embedding_provider = self.embedding_provider.clone();
        let worktree = self.worktree.clone();
        cx.spawn(|cx| async move {
            #[cfg(debug_assertions)]
            let embedding_query_start = std::time::Instant::now();

            // Embed the query with the same provider used for the files.
            let mut query_embeddings = embedding_provider
                .embed(&[TextToEmbed::new(&query)])
                .await?;
            let query_embedding = query_embeddings
                .pop()
                .ok_or_else(|| anyhow!("no embedding for query"))?;
            let mut workers = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                workers.push(Vec::<SearchResult>::new());
            }

            #[cfg(debug_assertions)]
            let search_start = std::time::Instant::now();

            cx.background_executor()
                .scoped(|cx| {
                    for worker_results in workers.iter_mut() {
                        cx.spawn(async {
                            while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
                                let score = embedded_chunk.embedding.similarity(&query_embedding);
                                // Insert in descending-score order, keeping at
                                // most `limit` results per worker.
                                let ix = match worker_results.binary_search_by(|probe| {
                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
                                }) {
                                    Ok(ix) | Err(ix) => ix,
                                };
                                worker_results.insert(
                                    ix,
                                    SearchResult {
                                        worktree: worktree.clone(),
                                        path: path.clone(),
                                        range: embedded_chunk.chunk.range.clone(),
                                        score,
                                    },
                                );
                                worker_results.truncate(limit);
                            }
                        });
                    }
                })
                .await;
            // Surface any database error from the producer task.
            scan_chunks.await?;

            let mut search_results = Vec::with_capacity(workers.len() * limit);
            for worker_results in workers {
                search_results.extend(worker_results);
            }
            search_results
                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
            search_results.truncate(limit);
            #[cfg(debug_assertions)]
            {
                let search_elapsed = search_start.elapsed();
                log::debug!(
                    "searched {} entries in {:?}",
                    search_results.len(),
                    search_elapsed
                );
                let embedding_query_elapsed = embedding_query_start.elapsed();
                log::debug!("embedding query took {:?}", embedding_query_elapsed);
            }

            Ok(search_results)
        })
    }
769}
770
/// Output of the scan stage: channels feeding the later pipeline stages, plus
/// the task driving the scan.
struct ScanEntries {
    updated_entries: channel::Receiver<Entry>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}
776
/// Output of the chunking stage: a channel of chunked files plus the task
/// driving the chunking workers.
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}
781
/// A file that has been read and split into chunks, but not yet embedded.
struct ChunkedFile {
    #[allow(dead_code)]
    pub worktree_root: Arc<Path>,
    pub entry: Entry,
    /// Full text of the file; each chunk's range indexes into this.
    pub text: String,
    pub chunks: Vec<Chunk>,
}
789
/// Output of the embedding stage: a channel of fully embedded files plus the
/// task driving the embedding requests.
struct EmbedFiles {
    files: channel::Receiver<EmbeddedFile>,
    task: Task<Result<()>>,
}
794
/// A file's chunks and their embeddings, as serialized into the database.
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedFile {
    /// Worktree-relative path of the file.
    path: Arc<Path>,
    /// Modification time at indexing, compared against the filesystem on
    /// rescan to decide whether the file needs re-embedding.
    mtime: Option<SystemTime>,
    chunks: Vec<EmbeddedChunk>,
}
801
/// A single text chunk paired with its embedding vector.
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedChunk {
    chunk: Chunk,
    embedding: Embedding,
}
807
/// Converts a worktree-relative path into a database key.
///
/// Forward slashes are replaced with NUL bytes — presumably so keys sort in
/// path-component order, matching the ordered merge in `scan_entries`
/// (NOTE(review): confirm this matches the worktree's traversal order).
///
/// Takes `&Path` rather than `&Arc<Path>` (the `&Arc<T>` parameter is an
/// anti-pattern); existing call sites passing `&Arc<Path>` still compile via
/// deref coercion.
fn db_key_for_path(path: &Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}
811
812#[cfg(test)]
813mod tests {
814 use super::*;
815
816 use futures::channel::oneshot;
817 use futures::{future::BoxFuture, FutureExt};
818
819 use gpui::{Global, TestAppContext};
820 use language::language_settings::AllLanguageSettings;
821 use project::Project;
822 use settings::SettingsStore;
823 use std::{future, path::Path, sync::Arc};
824
    /// Installs the minimal global state (settings store, language support,
    /// project settings) that `Project` and the index require in tests.
    fn init_test(cx: &mut TestAppContext) {
        _ = cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }
836
    /// Deterministic, network-free embedding provider for tests.
    pub struct TestEmbeddingProvider;
838
839 impl EmbeddingProvider for TestEmbeddingProvider {
840 fn embed<'a>(
841 &'a self,
842 texts: &'a [TextToEmbed<'a>],
843 ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
844 let embeddings = texts
845 .iter()
846 .map(|text| {
847 let mut embedding = vec![0f32; 2];
848 // if the text contains garbage, give it a 1 in the first dimension
849 if text.text.contains("garbage in") {
850 embedding[0] = 0.9;
851 } else {
852 embedding[0] = -0.9;
853 }
854
855 if text.text.contains("garbage out") {
856 embedding[1] = 0.9;
857 } else {
858 embedding[1] = -0.9;
859 }
860
861 Embedding::new(embedding)
862 })
863 .collect();
864 future::ready(Ok(embeddings)).boxed()
865 }
866
867 fn batch_size(&self) -> usize {
868 16
869 }
870 }
871
872 #[gpui::test]
873 async fn test_search(cx: &mut TestAppContext) {
874 cx.executor().allow_parking();
875
876 init_test(cx);
877
878 let temp_dir = tempfile::tempdir().unwrap();
879
880 let mut semantic_index = cx
881 .update(|cx| {
882 let semantic_index = SemanticIndex::new(
883 Path::new(temp_dir.path()),
884 Arc::new(TestEmbeddingProvider),
885 cx,
886 );
887 semantic_index
888 })
889 .await
890 .unwrap();
891
892 // todo!(): use a fixture
893 let project_path = Path::new("./fixture");
894
895 let project = cx
896 .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
897 .await;
898
899 cx.update(|cx| {
900 let language_registry = project.read(cx).languages().clone();
901 let node_runtime = project.read(cx).node_runtime().unwrap().clone();
902 languages::init(language_registry, node_runtime, cx);
903 });
904
905 let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));
906
907 let (tx, rx) = oneshot::channel();
908 let mut tx = Some(tx);
909 let subscription = cx.update(|cx| {
910 cx.subscribe(&project_index, move |_, event, _| {
911 if let Some(tx) = tx.take() {
912 _ = tx.send(*event);
913 }
914 })
915 });
916
917 rx.await.expect("no event emitted");
918 drop(subscription);
919
920 let results = cx
921 .update(|cx| {
922 let project_index = project_index.read(cx);
923 let query = "garbage in, garbage out";
924 project_index.search(query, 4, cx)
925 })
926 .await;
927
928 assert!(results.len() > 1, "should have found some results");
929
930 for result in &results {
931 println!("result: {:?}", result.path);
932 println!("score: {:?}", result.score);
933 }
934
935 // Find result that is greater than 0.5
936 let search_result = results.iter().find(|result| result.score > 0.9).unwrap();
937
938 assert_eq!(search_result.path.to_string_lossy(), "needle.md");
939
940 let content = cx
941 .update(|cx| {
942 let worktree = search_result.worktree.read(cx);
943 let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
944 let fs = project.read(cx).fs().clone();
945 cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
946 })
947 .await;
948
949 let range = search_result.range.clone();
950 let content = content[range.clone()].to_owned();
951
952 assert!(content.contains("garbage in, garbage out"));
953 }
954}