mod db;
mod embedding;
mod parsing;
pub mod semantic_index_settings;

#[cfg(test)]
mod semantic_index_tests;

use crate::semantic_index_settings::SemanticIndexSettings;
use anyhow::{anyhow, Result};
use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use futures::{channel::oneshot, Future};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Anchor, Buffer, Language, LanguageRegistry};
use parking_lot::Mutex;
use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
use project::{search::PathMatcher, Fs, Project, WorktreeId};
use smol::channel;
use std::{
    cmp::Ordering,
    collections::HashMap,
    mem,
    ops::Range,
    path::{Path, PathBuf},
    sync::{Arc, Weak},
    time::{Instant, SystemTime},
};
use util::{
    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
    http::HttpClient,
    paths::EMBEDDINGS_DIR,
    ResultExt,
};

const SEMANTIC_INDEX_VERSION: usize = 6;
const EMBEDDINGS_BATCH_SIZE: usize = 80;

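/// Registers the semantic index settings and, on non-stable release channels,
/// spawns a task that builds the global `SemanticIndex` model on top of the
/// on-disk embeddings database.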
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    settings::register::<SemanticIndexSettings>(cx);

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    // This needs to be removed at some point before stable.
    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
        return;
    }

    cx.spawn(move |mut cx| async move {
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddings {
                client: http_client,
                executor: cx.background(),
            }),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            cx.set_global(semantic_index.clone());
        });

        anyhow::Ok(())
    })
    .detach();
}

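/// Coordinates the indexing pipeline: files are parsed into embeddable
/// documents, grouped into batches, embedded via the `EmbeddingProvider`, and
/// written to the vector database. The background tasks for each stage are
/// held here so they are dropped together with the index.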
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    db_update_tx: channel::Sender<DbOperation>,
    parsing_files_tx: channel::Sender<PendingFile>,
    _db_update_task: Task<()>,
    _embed_batch_tasks: Vec<Task<()>>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}

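/// Per-project indexing state: the worktree-id-to-database-id mapping, plus a
/// watch channel that tracks how many indexing jobs are still outstanding.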
struct ProjectState {
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    outstanding_job_count_rx: watch::Receiver<usize>,
    _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
}

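/// A guard that ties a job's lifetime to its project's outstanding-job
/// counter: the counter is incremented when the handle is created and
/// decremented when the last clone of it is dropped.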
#[derive(Clone)]
struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}

impl JobHandle {
    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
        *tx.lock().borrow_mut() += 1;
        Self {
            tx: Arc::new(Arc::downgrade(tx)),
        }
    }
}

impl ProjectState {
    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
        self.worktree_db_ids
            .iter()
            .find_map(|(worktree_id, db_id)| {
                if *worktree_id == id {
                    Some(*db_id)
                } else {
                    None
                }
            })
    }

    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktree_db_ids
            .iter()
            .find_map(|(worktree_id, db_id)| {
                if *db_id == id {
                    Some(*worktree_id)
                } else {
                    None
                }
            })
    }
}

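/// A file that has been scheduled for parsing and embedding. The `job_handle`
/// keeps the owning project's outstanding-job count accurate while the file
/// moves through the pipeline.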
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: PathBuf,
    absolute_path: PathBuf,
    language: Arc<Language>,
    modified_time: SystemTime,
    job_handle: JobHandle,
}

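/// A single semantic search hit: the opened buffer and the anchored range that
/// the matching document occupies within it.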
pub struct SearchResult {
    pub buffer: ModelHandle<Buffer>,
    pub range: Range<Anchor>,
}

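/// Operations processed one at a time by the dedicated database task. Query
/// variants carry a oneshot sender through which the result is returned to
/// the caller.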
enum DbOperation {
    InsertFile {
        worktree_id: i64,
        documents: Vec<Document>,
        path: PathBuf,
        mtime: SystemTime,
        job_handle: JobHandle,
    },
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
    WorktreePreviouslyIndexed {
        path: Arc<Path>,
        sender: oneshot::Sender<Result<bool>>,
    },
}

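/// Input to the batching task: either a parsed file's documents to enqueue,
/// or a request to flush whatever has accumulated to the embedding workers.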
enum EmbeddingJob {
    Enqueue {
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
        job_handle: JobHandle,
    },
    Flush,
}

impl SemanticIndex {
    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
        if cx.has_global::<ModelHandle<Self>>() {
            Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
        } else {
            None
        }
    }

    pub fn enabled(cx: &AppContext) -> bool {
        settings::get::<SemanticIndexSettings>(cx).enabled
            && *RELEASE_CHANNEL != ReleaseChannel::Stable
    }

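    /// Opens (or creates) the vector database and spawns the background tasks
    /// that make up the indexing pipeline: a single database writer, a
    /// batching task, and per-CPU embedding and parsing workers.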
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let t0 = Instant::now();
        let database_url = Arc::new(database_url);

        let db = cx
            .background()
            .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
            .await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        Ok(cx.add_model(|cx| {
            let t0 = Instant::now();
            // Process database operations one at a time on a dedicated background task.
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn({
                async move {
                    while let Ok(job) = db_update_rx.recv().await {
                        Self::run_db_operation(&db, job)
                    }
                }
            });

            // Compute embeddings for batches of documents and send the results to the database task.
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
            let mut _embed_batch_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let embed_batch_rx = embed_batch_rx.clone();
                _embed_batch_tasks.push(cx.background().spawn({
                    let db_update_tx = db_update_tx.clone();
                    let embedding_provider = embedding_provider.clone();
                    async move {
                        while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
                            Self::compute_embeddings_for_batch(
                                embeddings_queue,
                                &embedding_provider,
                                &db_update_tx,
                            )
                            .await;
                        }
                    }
                }));
            }

            // Group documents into batches and send them to the embedding provider.
            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];
                while let Ok(job) = batch_files_rx.recv().await {
                    Self::enqueue_documents_to_embed(
                        job,
                        &mut queue_len,
                        &mut embeddings_queue,
                        &embed_batch_tx,
                    );
                }
            });

            // Parse files into embeddable documents.
            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                let db_update_tx = db_update_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new();
                    while let Ok(pending_file) = parsing_files_rx.recv().await {
                        Self::parse_file(
                            &fs,
                            pending_file,
                            &mut retriever,
                            &batch_files_tx,
                            &parsing_files_rx,
                            &db_update_tx,
                        )
                        .await;
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                parsing_files_tx,
                _db_update_task,
                _embed_batch_tasks,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }

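    /// Applies a single `DbOperation` to the database, sending the result back
    /// through the operation's oneshot channel where one is present.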
    fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
        match job {
            DbOperation::InsertFile {
                worktree_id,
                documents,
                path,
                mtime,
                job_handle,
            } => {
                db.insert_file(worktree_id, path, mtime, documents)
                    .log_err();
                drop(job_handle)
            }
            DbOperation::Delete { worktree_id, path } => {
                db.delete_file(worktree_id, path).log_err();
            }
            DbOperation::FindOrCreateWorktree { path, sender } => {
                let id = db.find_or_create_worktree(&path);
                sender.send(id).ok();
            }
            DbOperation::FileMTimes {
                worktree_id: worktree_db_id,
                sender,
            } => {
                let file_mtimes = db.get_file_mtimes(worktree_db_id);
                sender.send(file_mtimes).ok();
            }
            DbOperation::WorktreePreviouslyIndexed { path, sender } => {
                let worktree_indexed = db.worktree_previously_indexed(path.as_ref());
                sender.send(worktree_indexed).ok();
            }
        }
    }

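    /// Embeds every document in the queued batch with one provider call,
    /// distributes the returned embeddings back onto their documents, and
    /// queues each file for insertion into the database. On failure, the files
    /// are still inserted (without documents) so they are not retried until
    /// they change on disk.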
    async fn compute_embeddings_for_batch(
        mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
        embedding_provider: &Arc<dyn EmbeddingProvider>,
        db_update_tx: &channel::Sender<DbOperation>,
    ) {
        let mut batch_documents = vec![];
        for (_, documents, _, _, _) in embeddings_queue.iter() {
            batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
        }

        if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
            log::trace!(
                "created {} embeddings for {} files",
                embeddings.len(),
                embeddings_queue.len(),
            );

            // Walk the queue in lockstep with the flat list of embeddings:
            // `i` indexes the file, `j` the document within that file.
            let mut i = 0;
            let mut j = 0;

            for embedding in embeddings.iter() {
                while embeddings_queue[i].1.len() == j {
                    i += 1;
                    j = 0;
                }

                embeddings_queue[i].1[j].embedding = embedding.to_owned();
                j += 1;
            }

            for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
                db_update_tx
                    .send(DbOperation::InsertFile {
                        worktree_id,
                        documents,
                        path,
                        mtime,
                        job_handle,
                    })
                    .await
                    .unwrap();
            }
        } else {
            // Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed).
            for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() {
                db_update_tx
                    .send(DbOperation::InsertFile {
                        worktree_id,
                        documents: vec![],
                        path,
                        mtime,
                        job_handle,
                    })
                    .await
                    .unwrap();
            }
        }
    }

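    /// Adds a parsed file's documents to the pending batch, splitting files
    /// that exceed `EMBEDDINGS_BATCH_SIZE` across multiple jobs, and flushes
    /// the batch to the embedding workers once it is full (or on an explicit
    /// `Flush`).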
    fn enqueue_documents_to_embed(
        job: EmbeddingJob,
        queue_len: &mut usize,
        embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
        embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
    ) {
        // Handle the edge case where an individual file has more documents than the max batch size.
        let should_flush = match job {
            EmbeddingJob::Enqueue {
                documents,
                worktree_id,
                path,
                mtime,
                job_handle,
            } => {
                // If the file exceeds the embeddings batch size, split it in two and enqueue each half recursively.
                if documents.len() > EMBEDDINGS_BATCH_SIZE {
                    let first_job = EmbeddingJob::Enqueue {
                        documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(),
                        worktree_id,
                        path: path.clone(),
                        mtime,
                        job_handle: job_handle.clone(),
                    };

                    Self::enqueue_documents_to_embed(
                        first_job,
                        queue_len,
                        embeddings_queue,
                        embed_batch_tx,
                    );

                    let second_job = EmbeddingJob::Enqueue {
                        documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(),
                        worktree_id,
                        path: path.clone(),
                        mtime,
                        job_handle: job_handle.clone(),
                    };

                    Self::enqueue_documents_to_embed(
                        second_job,
                        queue_len,
                        embeddings_queue,
                        embed_batch_tx,
                    );
                    return;
                } else {
                    *queue_len += documents.len();
                    embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
                    *queue_len >= EMBEDDINGS_BATCH_SIZE
                }
            }
            EmbeddingJob::Flush => true,
        };

        if should_flush {
            embed_batch_tx
                .try_send(mem::take(embeddings_queue))
                .unwrap();
            *queue_len = 0;
        }
    }

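    /// Loads a pending file from disk, parses it into documents, and forwards
    /// them to the batching task (or straight to the database when parsing
    /// yields no documents). Once the parsing queue drains, a flush is
    /// requested so partially filled batches still get embedded.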
    async fn parse_file(
        fs: &Arc<dyn Fs>,
        pending_file: PendingFile,
        retriever: &mut CodeContextRetriever,
        batch_files_tx: &channel::Sender<EmbeddingJob>,
        parsing_files_rx: &channel::Receiver<PendingFile>,
        db_update_tx: &channel::Sender<DbOperation>,
    ) {
        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
            if let Some(documents) = retriever
                .parse_file_with_template(
                    &pending_file.relative_path,
                    &content,
                    pending_file.language,
                )
                .log_err()
            {
                log::trace!(
                    "parsed path {:?}: {} documents",
                    pending_file.relative_path,
                    documents.len()
                );

                if documents.is_empty() {
                    db_update_tx
                        .send(DbOperation::InsertFile {
                            worktree_id: pending_file.worktree_db_id,
                            documents,
                            path: pending_file.relative_path,
                            mtime: pending_file.modified_time,
                            job_handle: pending_file.job_handle,
                        })
                        .await
                        .unwrap();
                } else {
                    batch_files_tx
                        .try_send(EmbeddingJob::Enqueue {
                            worktree_id: pending_file.worktree_db_id,
                            path: pending_file.relative_path,
                            mtime: pending_file.modified_time,
                            job_handle: pending_file.job_handle,
                            documents,
                        })
                        .unwrap();
                }
            }
        }

        if parsing_files_rx.is_empty() {
            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
        }
    }

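    /// Queries the database task for this worktree's database id, creating a
    /// row for it first if necessary. This helper and the two below follow the
    /// same request/oneshot-response pattern.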
    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
        let (tx, rx) = oneshot::channel();
        self.db_update_tx
            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
            .unwrap();
        async move { rx.await? }
    }

    fn get_file_mtimes(
        &self,
        worktree_id: i64,
    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
        let (tx, rx) = oneshot::channel();
        self.db_update_tx
            .try_send(DbOperation::FileMTimes {
                worktree_id,
                sender: tx,
            })
            .unwrap();
        async move { rx.await? }
    }

    fn worktree_previously_indexed(&self, path: Arc<Path>) -> impl Future<Output = Result<bool>> {
        let (tx, rx) = oneshot::channel();
        self.db_update_tx
            .try_send(DbOperation::WorktreePreviouslyIndexed { path, sender: tx })
            .unwrap();
        async move { rx.await? }
    }

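    /// Resolves to true only if every worktree in the project has been
    /// indexed before.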
    pub fn project_previously_indexed(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<bool>> {
        let worktrees_indexed_previously = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| self.worktree_previously_indexed(worktree.read(cx).abs_path()))
            .collect::<Vec<_>>();
        cx.spawn(|_, _cx| async move {
            let worktree_indexed_previously =
                futures::future::join_all(worktrees_indexed_previously).await;

            Ok(worktree_indexed_previously
                .iter()
                .filter(|worktree| worktree.is_ok())
                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
        })
    }

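    /// Kicks off indexing for every worktree in the project: waits for file
    /// scans to complete, diffs file mtimes against the database, and enqueues
    /// each added or modified file for parsing. Returns the number of files
    /// queued along with a receiver reporting the outstanding job count.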
    pub fn index_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
        let t0 = Instant::now();
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();

        let language_registry = self.language_registry.clone();
        let db_update_tx = self.db_update_tx.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let mut worktree_file_mtimes = HashMap::new();
            let mut db_ids_by_worktree_id = HashMap::new();
            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                let db_id = db_id?;
                db_ids_by_worktree_id.insert(worktree.id(), db_id);
                worktree_file_mtimes.insert(
                    worktree.id(),
                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
                        .await?,
                );
            }

            let (job_count_tx, job_count_rx) = watch::channel_with(0);
            let job_count_tx = Arc::new(Mutex::new(job_count_tx));
            this.update(&mut cx, |this, _| {
                this.projects.insert(
                    project.downgrade(),
                    ProjectState {
                        worktree_db_ids: db_ids_by_worktree_id
                            .iter()
                            .map(|(a, b)| (*a, *b))
                            .collect(),
                        outstanding_job_count_rx: job_count_rx.clone(),
                        _outstanding_job_count_tx: job_count_tx.clone(),
                    },
                );
            });

            cx.background()
                .spawn(async move {
                    let mut count = 0;
                    for worktree in worktrees.into_iter() {
                        let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
                        for file in worktree.files(false, 0) {
                            let absolute_path = worktree.absolutize(&file.path);

                            if let Ok(language) = language_registry
                                .language_for_file(&absolute_path, None)
                                .await
                            {
                                // Skip languages we cannot embed: not whole-file
                                // parseable, not Markdown, and no embedding
                                // config in the grammar.
                                if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                    && language.name().as_ref() != "Markdown"
                                    && language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                {
                                    continue;
                                }

                                let path_buf = file.path.to_path_buf();
                                let stored_mtime = file_mtimes.remove(&path_buf);
                                let already_stored = stored_mtime
                                    .map_or(false, |existing_mtime| existing_mtime == file.mtime);

                                if !already_stored {
                                    count += 1;

                                    let job_handle = JobHandle::new(&job_count_tx);
                                    parsing_files_tx
                                        .try_send(PendingFile {
                                            worktree_db_id: db_ids_by_worktree_id[&worktree.id()],
                                            relative_path: path_buf,
                                            absolute_path,
                                            language,
                                            job_handle,
                                            modified_time: file.mtime,
                                        })
                                        .unwrap();
                                }
                            }
                        }
                        // Any entries left in `file_mtimes` were deleted from the worktree.
                        for file in file_mtimes.keys() {
                            db_update_tx
                                .try_send(DbOperation::Delete {
                                    worktree_id: db_ids_by_worktree_id[&worktree.id()],
                                    path: file.to_owned(),
                                })
                                .unwrap();
                        }
                    }

                    log::trace!(
                        "walking worktree took {:?} milliseconds",
                        t0.elapsed().as_millis()
                    );
                    anyhow::Ok((count, job_count_rx))
                })
                .await
        })
    }

    pub fn outstanding_job_count_rx(
        &self,
        project: &ModelHandle<Project>,
    ) -> Option<watch::Receiver<usize>> {
        Some(
            self.projects
                .get(&project.downgrade())?
                .outstanding_job_count_rx
                .clone(),
        )
    }

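    /// Embeds the search phrase, scores it against every indexed file in the
    /// project (batched across CPUs), and opens buffers for the top `limit`
    /// results, returning anchored ranges into them.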
    pub fn search_project(
        &mut self,
        project: ModelHandle<Project>,
        phrase: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
            state
        } else {
            return Task::ready(Err(anyhow!("project not added")));
        };

        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                project_state.db_id_for_worktree_id(worktree_id)
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        let fs = self.fs.clone();
        cx.spawn(|this, mut cx| async move {
            let t0 = Instant::now();
            let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;

            let phrase_embedding = embedding_provider
                .embed_batch(vec![&phrase])
                .await?
                .into_iter()
                .next()
                .unwrap();

            log::trace!(
                "Embedding search phrase took: {:?} milliseconds",
                t0.elapsed().as_millis()
            );

            let file_ids =
                database.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)?;

            // Split the candidate files into roughly one batch per CPU.
            let batch_n = cx.background().num_cpus();
            let ids_len = file_ids.len();
            let batch_size = if ids_len <= batch_n {
                ids_len
            } else {
                ids_len / batch_n
            };

            let mut result_tasks = Vec::new();
            for batch in file_ids.chunks(batch_size) {
                let batch = batch.to_vec();
                let fs = fs.clone();
                let database_url = database_url.clone();
                let phrase_embedding = phrase_embedding.clone();
                let task = cx.background().spawn(async move {
                    if let Some(database) = VectorDatabase::new(fs, database_url).await.log_err() {
                        database.top_k_search(&phrase_embedding, limit, batch.as_slice())
                    } else {
                        Err(anyhow!("failed to acquire database connection"))
                    }
                });
                result_tasks.push(task);
            }

            let batch_results = futures::future::join_all(result_tasks).await;

            // Merge the per-batch results, keeping them sorted by descending
            // similarity and truncated to `limit`.
            let mut results = Vec::new();
            for batch_result in batch_results {
                if let Ok(batch_result) = batch_result {
                    for (id, similarity) in batch_result {
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    }
                }
            }

            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
            let documents = database.get_documents_by_ids(ids.as_slice())?;

            let mut tasks = Vec::new();
            let mut ranges = Vec::new();
            let weak_project = project.downgrade();
            project.update(&mut cx, |project, cx| {
                for (worktree_db_id, file_path, byte_range) in documents {
                    let project_state =
                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
                            state
                        } else {
                            return Err(anyhow!("project not added"));
                        };
                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
                        ranges.push(byte_range);
                    }
                }

                Ok(())
            })?;

            let buffers = futures::future::join_all(tasks).await;

            log::trace!(
                "Semantic Searching took: {:?} milliseconds in total",
                t0.elapsed().as_millis()
            );

            Ok(buffers
                .into_iter()
                .zip(ranges)
                .filter_map(|(buffer, range)| {
                    let buffer = buffer.log_err()?;
                    let range = buffer.read_with(&cx, |buffer, _| {
                        buffer.anchor_before(range.start)..buffer.anchor_after(range.end)
                    });
                    Some(SearchResult { buffer, range })
                })
                .collect::<Vec<_>>())
        })
    }
}

impl Entity for SemanticIndex {
    type Event = ();
}

impl Drop for JobHandle {
    fn drop(&mut self) {
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not).
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_job_handle() {
        let (job_count_tx, job_count_rx) = watch::channel_with(0);
        let tx = Arc::new(Mutex::new(job_count_tx));
        let job_handle = JobHandle::new(&tx);

        assert_eq!(1, *job_count_rx.borrow());
        let new_job_handle = job_handle.clone();
        assert_eq!(1, *job_count_rx.borrow());
        drop(job_handle);
        assert_eq!(1, *job_count_rx.borrow());
        drop(new_job_handle);
        assert_eq!(0, *job_count_rx.borrow());
    }
}