semantic_index.rs

  1mod db;
  2mod embedding;
  3mod parsing;
  4pub mod semantic_index_settings;
  5
  6#[cfg(test)]
  7mod semantic_index_tests;
  8
  9use crate::semantic_index_settings::SemanticIndexSettings;
 10use anyhow::{anyhow, Result};
 11use db::VectorDatabase;
 12use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 13use futures::{channel::oneshot, Future};
 14use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 15use language::{Anchor, Buffer, Language, LanguageRegistry};
 16use parking_lot::Mutex;
 17use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
 18use postage::watch;
 19use project::{search::PathMatcher, Fs, Project, WorktreeId};
 20use smol::channel;
 21use std::{
 22    cmp::Ordering,
 23    collections::HashMap,
 24    mem,
 25    ops::Range,
 26    path::{Path, PathBuf},
 27    sync::{Arc, Weak},
 28    time::{Instant, SystemTime},
 29};
 30use util::{
 31    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
 32    http::HttpClient,
 33    paths::EMBEDDINGS_DIR,
 34    ResultExt,
 35};
 36
 37const SEMANTIC_INDEX_VERSION: usize = 6;
 38const EMBEDDINGS_BATCH_SIZE: usize = 80;
 39
 40pub fn init(
 41    fs: Arc<dyn Fs>,
 42    http_client: Arc<dyn HttpClient>,
 43    language_registry: Arc<LanguageRegistry>,
 44    cx: &mut AppContext,
 45) {
 46    settings::register::<SemanticIndexSettings>(cx);
 47
 48    let db_file_path = EMBEDDINGS_DIR
 49        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
 50        .join("embeddings_db");
 51
 52    // This needs to be removed at some point before stable.
 53    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
 54        return;
 55    }
 56
 57    cx.spawn(move |mut cx| async move {
 58        let semantic_index = SemanticIndex::new(
 59            fs,
 60            db_file_path,
 61            Arc::new(OpenAIEmbeddings {
 62                client: http_client,
 63                executor: cx.background(),
 64            }),
 65            language_registry,
 66            cx.clone(),
 67        )
 68        .await?;
 69
 70        cx.update(|cx| {
 71            cx.set_global(semantic_index.clone());
 72        });
 73
 74        anyhow::Ok(())
 75    })
 76    .detach();
 77}
 78
/// Model that indexes projects into a vector database of embedded code documents
/// and answers semantic searches over the indexed files.
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    /// Path of the on-disk vector database; `search_project` opens its own connections to it.
    database_url: Arc<PathBuf>,
    /// Computes embeddings for parsed documents and for search phrases.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Used to resolve the language of each file during indexing.
    language_registry: Arc<LanguageRegistry>,
    /// Requests for the background task that runs every `DbOperation` in order.
    db_update_tx: channel::Sender<DbOperation>,
    /// Files waiting to be parsed into embeddable documents.
    parsing_files_tx: channel::Sender<PendingFile>,
    // Handles to the background pipeline tasks spawned in `new`; stored here and never read.
    _db_update_task: Task<()>,
    _embed_batch_tasks: Vec<Task<()>>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    /// Indexing state for every project passed to `index_project`, keyed by weak handle.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
 92
/// Per-project bookkeeping created by `SemanticIndex::index_project`.
struct ProjectState {
    /// Pairs of (project worktree id, database worktree id).
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    /// Observes the number of files still being indexed for this project.
    outstanding_job_count_rx: watch::Receiver<usize>,
    // Keeps the counter's sender alive; `JobHandle`s only hold a weak reference to it.
    _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
}
 98
/// Token representing one file that is still being indexed.
///
/// Creating a handle increments the project's outstanding-job counter; all clones of a
/// handle together account for that single increment.
#[derive(Clone)]
struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}
105
106impl JobHandle {
107    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
108        *tx.lock().borrow_mut() += 1;
109        Self {
110            tx: Arc::new(Arc::downgrade(&tx)),
111        }
112    }
113}
114impl ProjectState {
115    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
116        self.worktree_db_ids
117            .iter()
118            .find_map(|(worktree_id, db_id)| {
119                if *worktree_id == id {
120                    Some(*db_id)
121                } else {
122                    None
123                }
124            })
125    }
126
127    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
128        self.worktree_db_ids
129            .iter()
130            .find_map(|(worktree_id, db_id)| {
131                if *db_id == id {
132                    Some(*worktree_id)
133                } else {
134                    None
135                }
136            })
137    }
138}
139
/// A file queued for parsing by `SemanticIndex::index_project`.
pub struct PendingFile {
    /// Database id of the worktree containing the file.
    worktree_db_id: i64,
    /// Path of the file relative to its worktree root; this is the path stored in the database.
    relative_path: PathBuf,
    /// Absolute path used to load the file's contents.
    absolute_path: PathBuf,
    /// Language used to split the file into documents.
    language: Arc<Language>,
    /// Worktree modification time, recorded alongside the file's documents.
    modified_time: SystemTime,
    /// Keeps this file counted as outstanding until its handle (and all clones) are dropped.
    job_handle: JobHandle,
}
148
/// One match returned by `SemanticIndex::search_project`.
pub struct SearchResult {
    /// Buffer containing the matching document.
    pub buffer: ModelHandle<Buffer>,
    /// Anchored range of the stored document's byte range within `buffer`.
    pub range: Range<Anchor>,
}
153
/// Requests executed one after another by the database task spawned in `SemanticIndex::new`.
enum DbOperation {
    /// Stores `documents` for `path` in worktree `worktree_id` via `VectorDatabase::insert_file`,
    /// then releases `job_handle`.
    InsertFile {
        worktree_id: i64,
        documents: Vec<Document>,
        path: PathBuf,
        mtime: SystemTime,
        job_handle: JobHandle,
    },
    /// Removes `path` from worktree `worktree_id` via `VectorDatabase::delete_file`.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Resolves the database id for the worktree at `path` via
    /// `VectorDatabase::find_or_create_worktree`, replying on `sender`.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    /// Fetches the stored mtime of each file in worktree `worktree_id`, replying on `sender`.
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
    /// Asks `VectorDatabase::worktree_previously_indexed` about `path`, replying on `sender`.
    WorktreePreviouslyIndexed {
        path: Arc<Path>,
        sender: oneshot::Sender<Result<bool>>,
    },
}
179
/// Messages for the batching task, which groups parsed documents into batches of about
/// `EMBEDDINGS_BATCH_SIZE` before handing them to the embedding tasks.
enum EmbeddingJob {
    /// Adds one file's parsed documents to the pending batch.
    Enqueue {
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
        job_handle: JobHandle,
    },
    /// Sends the pending batch even though it has not reached the batch size.
    Flush,
}
190
191impl SemanticIndex {
192    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
193        if cx.has_global::<ModelHandle<Self>>() {
194            Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
195        } else {
196            None
197        }
198    }
199
200    pub fn enabled(cx: &AppContext) -> bool {
201        settings::get::<SemanticIndexSettings>(cx).enabled
202            && *RELEASE_CHANNEL != ReleaseChannel::Stable
203    }
204
205    async fn new(
206        fs: Arc<dyn Fs>,
207        database_url: PathBuf,
208        embedding_provider: Arc<dyn EmbeddingProvider>,
209        language_registry: Arc<LanguageRegistry>,
210        mut cx: AsyncAppContext,
211    ) -> Result<ModelHandle<Self>> {
212        let t0 = Instant::now();
213        let database_url = Arc::new(database_url);
214
215        let db = cx
216            .background()
217            .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
218            .await?;
219
220        log::trace!(
221            "db initialization took {:?} milliseconds",
222            t0.elapsed().as_millis()
223        );
224
225        Ok(cx.add_model(|cx| {
226            let t0 = Instant::now();
227            // Perform database operations
228            let (db_update_tx, db_update_rx) = channel::unbounded();
229            let _db_update_task = cx.background().spawn({
230                async move {
231                    while let Ok(job) = db_update_rx.recv().await {
232                        Self::run_db_operation(&db, job)
233                    }
234                }
235            });
236
237            // Group documents into batches and send them to the embedding provider.
238            let (embed_batch_tx, embed_batch_rx) =
239                channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
240            let mut _embed_batch_tasks = Vec::new();
241            for _ in 0..cx.background().num_cpus() {
242                let embed_batch_rx = embed_batch_rx.clone();
243                _embed_batch_tasks.push(cx.background().spawn({
244                    let db_update_tx = db_update_tx.clone();
245                    let embedding_provider = embedding_provider.clone();
246                    async move {
247                        while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
248                            Self::compute_embeddings_for_batch(
249                                embeddings_queue,
250                                &embedding_provider,
251                                &db_update_tx,
252                            )
253                            .await;
254                        }
255                    }
256                }));
257            }
258
259            // Group documents into batches and send them to the embedding provider.
260            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
261            let _batch_files_task = cx.background().spawn(async move {
262                let mut queue_len = 0;
263                let mut embeddings_queue = vec![];
264                while let Ok(job) = batch_files_rx.recv().await {
265                    Self::enqueue_documents_to_embed(
266                        job,
267                        &mut queue_len,
268                        &mut embeddings_queue,
269                        &embed_batch_tx,
270                    );
271                }
272            });
273
274            // Parse files into embeddable documents.
275            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
276            let mut _parsing_files_tasks = Vec::new();
277            for _ in 0..cx.background().num_cpus() {
278                let fs = fs.clone();
279                let parsing_files_rx = parsing_files_rx.clone();
280                let batch_files_tx = batch_files_tx.clone();
281                let db_update_tx = db_update_tx.clone();
282                _parsing_files_tasks.push(cx.background().spawn(async move {
283                    let mut retriever = CodeContextRetriever::new();
284                    while let Ok(pending_file) = parsing_files_rx.recv().await {
285                        Self::parse_file(
286                            &fs,
287                            pending_file,
288                            &mut retriever,
289                            &batch_files_tx,
290                            &parsing_files_rx,
291                            &db_update_tx,
292                        )
293                        .await;
294                    }
295                }));
296            }
297
298            log::trace!(
299                "semantic index task initialization took {:?} milliseconds",
300                t0.elapsed().as_millis()
301            );
302            Self {
303                fs,
304                database_url,
305                embedding_provider,
306                language_registry,
307                db_update_tx,
308                parsing_files_tx,
309                _db_update_task,
310                _embed_batch_tasks,
311                _batch_files_task,
312                _parsing_files_tasks,
313                projects: HashMap::new(),
314            }
315        }))
316    }
317
318    fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
319        match job {
320            DbOperation::InsertFile {
321                worktree_id,
322                documents,
323                path,
324                mtime,
325                job_handle,
326            } => {
327                db.insert_file(worktree_id, path, mtime, documents)
328                    .log_err();
329                drop(job_handle)
330            }
331            DbOperation::Delete { worktree_id, path } => {
332                db.delete_file(worktree_id, path).log_err();
333            }
334            DbOperation::FindOrCreateWorktree { path, sender } => {
335                let id = db.find_or_create_worktree(&path);
336                sender.send(id).ok();
337            }
338            DbOperation::FileMTimes {
339                worktree_id: worktree_db_id,
340                sender,
341            } => {
342                let file_mtimes = db.get_file_mtimes(worktree_db_id);
343                sender.send(file_mtimes).ok();
344            }
345            DbOperation::WorktreePreviouslyIndexed { path, sender } => {
346                let worktree_indexed = db.worktree_previously_indexed(path.as_ref());
347                sender.send(worktree_indexed).ok();
348            }
349        }
350    }
351
352    async fn compute_embeddings_for_batch(
353        mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
354        embedding_provider: &Arc<dyn EmbeddingProvider>,
355        db_update_tx: &channel::Sender<DbOperation>,
356    ) {
357        let mut batch_documents = vec![];
358        for (_, documents, _, _, _) in embeddings_queue.iter() {
359            batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
360        }
361
362        if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
363            log::trace!(
364                "created {} embeddings for {} files",
365                embeddings.len(),
366                embeddings_queue.len(),
367            );
368
369            let mut i = 0;
370            let mut j = 0;
371
372            for embedding in embeddings.iter() {
373                while embeddings_queue[i].1.len() == j {
374                    i += 1;
375                    j = 0;
376                }
377
378                embeddings_queue[i].1[j].embedding = embedding.to_owned();
379                j += 1;
380            }
381
382            for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
383                db_update_tx
384                    .send(DbOperation::InsertFile {
385                        worktree_id,
386                        documents,
387                        path,
388                        mtime,
389                        job_handle,
390                    })
391                    .await
392                    .unwrap();
393            }
394        } else {
395            // Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed).
396            for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() {
397                db_update_tx
398                    .send(DbOperation::InsertFile {
399                        worktree_id,
400                        documents: vec![],
401                        path,
402                        mtime,
403                        job_handle,
404                    })
405                    .await
406                    .unwrap();
407            }
408        }
409    }
410
411    fn enqueue_documents_to_embed(
412        job: EmbeddingJob,
413        queue_len: &mut usize,
414        embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
415        embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
416    ) {
417        // Handle edge case where individual file has more documents than max batch size
418        let should_flush = match job {
419            EmbeddingJob::Enqueue {
420                documents,
421                worktree_id,
422                path,
423                mtime,
424                job_handle,
425            } => {
426                // If documents is greater than embeddings batch size, recursively batch existing rows.
427                if &documents.len() > &EMBEDDINGS_BATCH_SIZE {
428                    let first_job = EmbeddingJob::Enqueue {
429                        documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(),
430                        worktree_id,
431                        path: path.clone(),
432                        mtime,
433                        job_handle: job_handle.clone(),
434                    };
435
436                    Self::enqueue_documents_to_embed(
437                        first_job,
438                        queue_len,
439                        embeddings_queue,
440                        embed_batch_tx,
441                    );
442
443                    let second_job = EmbeddingJob::Enqueue {
444                        documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(),
445                        worktree_id,
446                        path: path.clone(),
447                        mtime,
448                        job_handle: job_handle.clone(),
449                    };
450
451                    Self::enqueue_documents_to_embed(
452                        second_job,
453                        queue_len,
454                        embeddings_queue,
455                        embed_batch_tx,
456                    );
457                    return;
458                } else {
459                    *queue_len += &documents.len();
460                    embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
461                    *queue_len >= EMBEDDINGS_BATCH_SIZE
462                }
463            }
464            EmbeddingJob::Flush => true,
465        };
466
467        if should_flush {
468            embed_batch_tx
469                .try_send(mem::take(embeddings_queue))
470                .unwrap();
471            *queue_len = 0;
472        }
473    }
474
475    async fn parse_file(
476        fs: &Arc<dyn Fs>,
477        pending_file: PendingFile,
478        retriever: &mut CodeContextRetriever,
479        batch_files_tx: &channel::Sender<EmbeddingJob>,
480        parsing_files_rx: &channel::Receiver<PendingFile>,
481        db_update_tx: &channel::Sender<DbOperation>,
482    ) {
483        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
484            if let Some(documents) = retriever
485                .parse_file_with_template(
486                    &pending_file.relative_path,
487                    &content,
488                    pending_file.language,
489                )
490                .log_err()
491            {
492                log::trace!(
493                    "parsed path {:?}: {} documents",
494                    pending_file.relative_path,
495                    documents.len()
496                );
497
498                if documents.len() == 0 {
499                    db_update_tx
500                        .send(DbOperation::InsertFile {
501                            worktree_id: pending_file.worktree_db_id,
502                            documents,
503                            path: pending_file.relative_path,
504                            mtime: pending_file.modified_time,
505                            job_handle: pending_file.job_handle,
506                        })
507                        .await
508                        .unwrap();
509                } else {
510                    batch_files_tx
511                        .try_send(EmbeddingJob::Enqueue {
512                            worktree_id: pending_file.worktree_db_id,
513                            path: pending_file.relative_path,
514                            mtime: pending_file.modified_time,
515                            job_handle: pending_file.job_handle,
516                            documents,
517                        })
518                        .unwrap();
519                }
520            }
521        }
522
523        if parsing_files_rx.len() == 0 {
524            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
525        }
526    }
527
528    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
529        let (tx, rx) = oneshot::channel();
530        self.db_update_tx
531            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
532            .unwrap();
533        async move { rx.await? }
534    }
535
536    fn get_file_mtimes(
537        &self,
538        worktree_id: i64,
539    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
540        let (tx, rx) = oneshot::channel();
541        self.db_update_tx
542            .try_send(DbOperation::FileMTimes {
543                worktree_id,
544                sender: tx,
545            })
546            .unwrap();
547        async move { rx.await? }
548    }
549
550    fn worktree_previously_indexed(&self, path: Arc<Path>) -> impl Future<Output = Result<bool>> {
551        let (tx, rx) = oneshot::channel();
552        self.db_update_tx
553            .try_send(DbOperation::WorktreePreviouslyIndexed { path, sender: tx })
554            .unwrap();
555        async move { rx.await? }
556    }
557
558    pub fn project_previously_indexed(
559        &mut self,
560        project: ModelHandle<Project>,
561        cx: &mut ModelContext<Self>,
562    ) -> Task<Result<bool>> {
563        let worktrees_indexed_previously = project
564            .read(cx)
565            .worktrees(cx)
566            .map(|worktree| self.worktree_previously_indexed(worktree.read(cx).abs_path()))
567            .collect::<Vec<_>>();
568        cx.spawn(|_, _cx| async move {
569            let worktree_indexed_previously =
570                futures::future::join_all(worktrees_indexed_previously).await;
571
572            Ok(worktree_indexed_previously
573                .iter()
574                .filter(|worktree| worktree.is_ok())
575                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
576        })
577    }
578
579    pub fn index_project(
580        &mut self,
581        project: ModelHandle<Project>,
582        cx: &mut ModelContext<Self>,
583    ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
584        let t0 = Instant::now();
585        let worktree_scans_complete = project
586            .read(cx)
587            .worktrees(cx)
588            .map(|worktree| {
589                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
590                async move {
591                    scan_complete.await;
592                }
593            })
594            .collect::<Vec<_>>();
595        let worktree_db_ids = project
596            .read(cx)
597            .worktrees(cx)
598            .map(|worktree| {
599                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
600            })
601            .collect::<Vec<_>>();
602
603        let language_registry = self.language_registry.clone();
604        let db_update_tx = self.db_update_tx.clone();
605        let parsing_files_tx = self.parsing_files_tx.clone();
606
607        cx.spawn(|this, mut cx| async move {
608            futures::future::join_all(worktree_scans_complete).await;
609
610            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
611
612            let worktrees = project.read_with(&cx, |project, cx| {
613                project
614                    .worktrees(cx)
615                    .map(|worktree| worktree.read(cx).snapshot())
616                    .collect::<Vec<_>>()
617            });
618
619            let mut worktree_file_mtimes = HashMap::new();
620            let mut db_ids_by_worktree_id = HashMap::new();
621            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
622                let db_id = db_id?;
623                db_ids_by_worktree_id.insert(worktree.id(), db_id);
624                worktree_file_mtimes.insert(
625                    worktree.id(),
626                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
627                        .await?,
628                );
629            }
630
631            let (job_count_tx, job_count_rx) = watch::channel_with(0);
632            let job_count_tx = Arc::new(Mutex::new(job_count_tx));
633            this.update(&mut cx, |this, _| {
634                this.projects.insert(
635                    project.downgrade(),
636                    ProjectState {
637                        worktree_db_ids: db_ids_by_worktree_id
638                            .iter()
639                            .map(|(a, b)| (*a, *b))
640                            .collect(),
641                        outstanding_job_count_rx: job_count_rx.clone(),
642                        _outstanding_job_count_tx: job_count_tx.clone(),
643                    },
644                );
645            });
646
647            cx.background()
648                .spawn(async move {
649                    let mut count = 0;
650                    for worktree in worktrees.into_iter() {
651                        let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
652                        for file in worktree.files(false, 0) {
653                            let absolute_path = worktree.absolutize(&file.path);
654
655                            if let Ok(language) = language_registry
656                                .language_for_file(&absolute_path, None)
657                                .await
658                            {
659                                if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
660                                    && &language.name().as_ref() != &"Markdown"
661                                    && language
662                                        .grammar()
663                                        .and_then(|grammar| grammar.embedding_config.as_ref())
664                                        .is_none()
665                                {
666                                    continue;
667                                }
668
669                                let path_buf = file.path.to_path_buf();
670                                let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
671                                let already_stored = stored_mtime
672                                    .map_or(false, |existing_mtime| existing_mtime == file.mtime);
673
674                                if !already_stored {
675                                    count += 1;
676
677                                    let job_handle = JobHandle::new(&job_count_tx);
678                                    parsing_files_tx
679                                        .try_send(PendingFile {
680                                            worktree_db_id: db_ids_by_worktree_id[&worktree.id()],
681                                            relative_path: path_buf,
682                                            absolute_path,
683                                            language,
684                                            job_handle,
685                                            modified_time: file.mtime,
686                                        })
687                                        .unwrap();
688                                }
689                            }
690                        }
691                        for file in file_mtimes.keys() {
692                            db_update_tx
693                                .try_send(DbOperation::Delete {
694                                    worktree_id: db_ids_by_worktree_id[&worktree.id()],
695                                    path: file.to_owned(),
696                                })
697                                .unwrap();
698                        }
699                    }
700
701                    log::trace!(
702                        "walking worktree took {:?} milliseconds",
703                        t0.elapsed().as_millis()
704                    );
705                    anyhow::Ok((count, job_count_rx))
706                })
707                .await
708        })
709    }
710
711    pub fn outstanding_job_count_rx(
712        &self,
713        project: &ModelHandle<Project>,
714    ) -> Option<watch::Receiver<usize>> {
715        Some(
716            self.projects
717                .get(&project.downgrade())?
718                .outstanding_job_count_rx
719                .clone(),
720        )
721    }
722
723    pub fn search_project(
724        &mut self,
725        project: ModelHandle<Project>,
726        phrase: String,
727        limit: usize,
728        includes: Vec<PathMatcher>,
729        excludes: Vec<PathMatcher>,
730        cx: &mut ModelContext<Self>,
731    ) -> Task<Result<Vec<SearchResult>>> {
732        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
733            state
734        } else {
735            return Task::ready(Err(anyhow!("project not added")));
736        };
737
738        let worktree_db_ids = project
739            .read(cx)
740            .worktrees(cx)
741            .filter_map(|worktree| {
742                let worktree_id = worktree.read(cx).id();
743                project_state.db_id_for_worktree_id(worktree_id)
744            })
745            .collect::<Vec<_>>();
746
747        let embedding_provider = self.embedding_provider.clone();
748        let database_url = self.database_url.clone();
749        let fs = self.fs.clone();
750        cx.spawn(|this, mut cx| async move {
751            let t0 = Instant::now();
752            let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
753
754            let phrase_embedding = embedding_provider
755                .embed_batch(vec![&phrase])
756                .await?
757                .into_iter()
758                .next()
759                .unwrap();
760
761            log::trace!(
762                "Embedding search phrase took: {:?} milliseconds",
763                t0.elapsed().as_millis()
764            );
765
766            let file_ids =
767                database.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)?;
768
769            let batch_n = cx.background().num_cpus();
770            let ids_len = file_ids.clone().len();
771            let batch_size = if ids_len <= batch_n {
772                ids_len
773            } else {
774                ids_len / batch_n
775            };
776
777            let mut result_tasks = Vec::new();
778            for batch in file_ids.chunks(batch_size) {
779                let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
780                let limit = limit.clone();
781                let fs = fs.clone();
782                let database_url = database_url.clone();
783                let phrase_embedding = phrase_embedding.clone();
784                let task = cx.background().spawn(async move {
785                    let database = VectorDatabase::new(fs, database_url).await.log_err();
786                    if database.is_none() {
787                        return Err(anyhow!("failed to acquire database connection"));
788                    } else {
789                        database
790                            .unwrap()
791                            .top_k_search(&phrase_embedding, limit, batch.as_slice())
792                    }
793                });
794                result_tasks.push(task);
795            }
796
797            let batch_results = futures::future::join_all(result_tasks).await;
798
799            let mut results = Vec::new();
800            for batch_result in batch_results {
801                if batch_result.is_ok() {
802                    for (id, similarity) in batch_result.unwrap() {
803                        let ix = match results.binary_search_by(|(_, s)| {
804                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
805                        }) {
806                            Ok(ix) => ix,
807                            Err(ix) => ix,
808                        };
809                        results.insert(ix, (id, similarity));
810                        results.truncate(limit);
811                    }
812                }
813            }
814
815            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
816            let documents = database.get_documents_by_ids(ids.as_slice())?;
817
818            let mut tasks = Vec::new();
819            let mut ranges = Vec::new();
820            let weak_project = project.downgrade();
821            project.update(&mut cx, |project, cx| {
822                for (worktree_db_id, file_path, byte_range) in documents {
823                    let project_state =
824                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
825                            state
826                        } else {
827                            return Err(anyhow!("project not added"));
828                        };
829                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
830                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
831                        ranges.push(byte_range);
832                    }
833                }
834
835                Ok(())
836            })?;
837
838            let buffers = futures::future::join_all(tasks).await;
839
840            log::trace!(
841                "Semantic Searching took: {:?} milliseconds in total",
842                t0.elapsed().as_millis()
843            );
844
845            Ok(buffers
846                .into_iter()
847                .zip(ranges)
848                .filter_map(|(buffer, range)| {
849                    let buffer = buffer.log_err()?;
850                    let range = buffer.read_with(&cx, |buffer, _| {
851                        buffer.anchor_before(range.start)..buffer.anchor_after(range.end)
852                    });
853                    Some(SearchResult { buffer, range })
854                })
855                .collect::<Vec<_>>())
856        })
857    }
858}
859
// Makes `SemanticIndex` a gpui model; it emits no events.
impl Entity for SemanticIndex {
    type Event = ();
}
863
864impl Drop for JobHandle {
865    fn drop(&mut self) {
866        if let Some(inner) = Arc::get_mut(&mut self.tx) {
867            // This is the last instance of the JobHandle (regardless of it's origin - whether it was cloned or not)
868            if let Some(tx) = inner.upgrade() {
869                let mut tx = tx.lock();
870                *tx.borrow_mut() -= 1;
871            }
872        }
873    }
874}
875
#[cfg(test)]
mod tests {
    use super::*;

    /// Clones of a `JobHandle` share one unit of outstanding work: the counter only
    /// returns to zero after every clone has been dropped.
    #[test]
    fn test_job_handle() {
        let (count_tx, count_rx) = watch::channel_with(0);
        let shared_tx = Arc::new(Mutex::new(count_tx));

        let original = JobHandle::new(&shared_tx);
        assert_eq!(1, *count_rx.borrow());

        let cloned = original.clone();
        assert_eq!(1, *count_rx.borrow());

        drop(original);
        assert_eq!(1, *count_rx.borrow());

        drop(cloned);
        assert_eq!(0, *count_rx.borrow());
    }
}