// semantic_index.rs

   1mod db;
   2mod embedding;
   3mod parsing;
   4pub mod semantic_index_settings;
   5
   6#[cfg(test)]
   7mod semantic_index_tests;
   8
   9use crate::semantic_index_settings::SemanticIndexSettings;
  10use anyhow::{anyhow, Result};
  11use db::VectorDatabase;
  12use embedding::{EmbeddingProvider, OpenAIEmbeddings};
  13use futures::{channel::oneshot, Future};
  14use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
  15use isahc::http::header::OccupiedEntry;
  16use language::{Anchor, Buffer, Language, LanguageRegistry};
  17use parking_lot::Mutex;
  18use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
  19use postage::stream::Stream;
  20use postage::watch;
  21use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, WorktreeId};
  22use smol::channel;
  23use std::{
  24    cmp::Ordering,
  25    collections::HashMap,
  26    mem,
  27    ops::Range,
  28    path::{Path, PathBuf},
  29    sync::{Arc, Weak},
  30    time::{Instant, SystemTime},
  31};
  32use util::{
  33    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
  34    http::HttpClient,
  35    paths::EMBEDDINGS_DIR,
  36    ResultExt,
  37};
  38use workspace::WorkspaceCreated;
  39
  40const SEMANTIC_INDEX_VERSION: usize = 7;
  41const EMBEDDINGS_BATCH_SIZE: usize = 80;
  42
  43pub fn init(
  44    fs: Arc<dyn Fs>,
  45    http_client: Arc<dyn HttpClient>,
  46    language_registry: Arc<LanguageRegistry>,
  47    cx: &mut AppContext,
  48) {
  49    settings::register::<SemanticIndexSettings>(cx);
  50
  51    let db_file_path = EMBEDDINGS_DIR
  52        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
  53        .join("embeddings_db");
  54
  55    // This needs to be removed at some point before stable.
  56    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
  57        return;
  58    }
  59
  60    cx.subscribe_global::<WorkspaceCreated, _>({
  61        move |event, mut cx| {
  62            let Some(semantic_index) = SemanticIndex::global(cx) else { return; };
  63            let workspace = &event.0;
  64            if let Some(workspace) = workspace.upgrade(cx) {
  65                let project = workspace.read(cx).project().clone();
  66                if project.read(cx).is_local() {
  67                    semantic_index.update(cx, |index, cx| {
  68                        index.initialize_project(project, cx).detach_and_log_err(cx)
  69                    });
  70                }
  71            }
  72        }
  73    })
  74    .detach();
  75
  76    cx.spawn(move |mut cx| async move {
  77        let semantic_index = SemanticIndex::new(
  78            fs,
  79            db_file_path,
  80            Arc::new(OpenAIEmbeddings {
  81                client: http_client,
  82                executor: cx.background(),
  83            }),
  84            language_registry,
  85            cx.clone(),
  86        )
  87        .await?;
  88
  89        cx.update(|cx| {
  90            cx.set_global(semantic_index.clone());
  91        });
  92
  93        anyhow::Ok(())
  94    })
  95    .detach();
  96}
  97
/// Coordinates the background pipeline that parses project files into
/// documents, embeds them, and persists the results in the vector database.
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    // Path to the vector database file (named `database_url` historically).
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Sink for database reads/writes, consumed serially by `_db_update_task`.
    db_update_tx: channel::Sender<DbOperation>,
    // Sink for files awaiting parsing, consumed by `_parsing_files_tasks`.
    parsing_files_tx: channel::Sender<PendingFile>,
    // Background tasks are stored only to keep them alive.
    _db_update_task: Task<()>,
    _embed_batch_tasks: Vec<Task<()>>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    // Per-project state, keyed by weak handle so dropped projects can expire.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
 111
/// Indexing state tracked for a single open project.
struct ProjectState {
    // Maps in-memory worktree ids to their database row ids.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    // Keeps the project-event subscription alive for this project's lifetime.
    subscription: gpui::Subscription,
    // Observes the number of in-flight indexing jobs (see `JobHandle`).
    outstanding_job_count_rx: watch::Receiver<usize>,
    _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    // Feeds operations to the coalescing queue driven by `_queue_update_task`.
    job_queue_tx: channel::Sender<IndexOperation>,
    _queue_update_task: Task<()>,
}
 120
/// Tracks the lifetime of one indexing job; when the last clone of a handle
/// is dropped, the project's outstanding-job counter is decremented.
#[derive(Clone)]
struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}
 127
 128impl JobHandle {
 129    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
 130        *tx.lock().borrow_mut() += 1;
 131        Self {
 132            tx: Arc::new(Arc::downgrade(&tx)),
 133        }
 134    }
 135}
 136impl ProjectState {
 137    fn new(
 138        cx: &mut AppContext,
 139        subscription: gpui::Subscription,
 140        worktree_db_ids: Vec<(WorktreeId, i64)>,
 141        outstanding_job_count_rx: watch::Receiver<usize>,
 142        _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
 143    ) -> Self {
 144        let (job_count_tx, job_count_rx) = watch::channel_with(0);
 145        let job_count_tx = Arc::new(Mutex::new(job_count_tx));
 146
 147        let (job_queue_tx, job_queue_rx) = channel::unbounded();
 148        let _queue_update_task = cx.background().spawn({
 149            let mut worktree_queue = HashMap::new();
 150            async move {
 151                while let Ok(operation) = job_queue_rx.recv().await {
 152                    Self::update_queue(&mut worktree_queue, operation);
 153                }
 154            }
 155        });
 156
 157        Self {
 158            worktree_db_ids,
 159            outstanding_job_count_rx,
 160            _outstanding_job_count_tx,
 161            subscription,
 162            _queue_update_task,
 163            job_queue_tx,
 164        }
 165    }
 166
 167    pub fn get_outstanding_count(&self) -> usize {
 168        self.outstanding_job_count_rx.borrow().clone()
 169    }
 170
 171    fn update_queue(queue: &mut HashMap<PathBuf, IndexOperation>, operation: IndexOperation) {
 172        match operation {
 173            IndexOperation::FlushQueue => {
 174                let queue = std::mem::take(queue);
 175                for (_, op) in queue {
 176                    match op {
 177                        IndexOperation::IndexFile {
 178                            absolute_path,
 179                            payload,
 180                            tx,
 181                        } => {
 182                            tx.try_send(payload);
 183                        }
 184                        IndexOperation::DeleteFile {
 185                            absolute_path,
 186                            payload,
 187                            tx,
 188                        } => {
 189                            tx.try_send(payload);
 190                        }
 191                        _ => {}
 192                    }
 193                }
 194            }
 195            IndexOperation::IndexFile {
 196                ref absolute_path, ..
 197            }
 198            | IndexOperation::DeleteFile {
 199                ref absolute_path, ..
 200            } => {
 201                queue.insert(absolute_path.clone(), operation);
 202            }
 203        }
 204    }
 205
 206    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
 207        self.worktree_db_ids
 208            .iter()
 209            .find_map(|(worktree_id, db_id)| {
 210                if *worktree_id == id {
 211                    Some(*db_id)
 212                } else {
 213                    None
 214                }
 215            })
 216    }
 217
 218    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
 219        self.worktree_db_ids
 220            .iter()
 221            .find_map(|(worktree_id, db_id)| {
 222                if *db_id == id {
 223                    Some(*worktree_id)
 224                } else {
 225                    None
 226                }
 227            })
 228    }
 229}
 230
/// A file that has been scheduled for parsing and embedding.
#[derive(Clone)]
pub struct PendingFile {
    // Database id of the worktree containing this file.
    worktree_db_id: i64,
    relative_path: PathBuf,
    absolute_path: PathBuf,
    language: Arc<Language>,
    modified_time: SystemTime,
    // Keeps the outstanding-job count accurate while this file is in flight.
    job_handle: JobHandle,
}
/// Operations accepted by a project's coalescing job queue
/// (see `ProjectState::update_queue`).
enum IndexOperation {
    /// Parse and embed the file at `absolute_path`; `payload` is forwarded to `tx`.
    IndexFile {
        absolute_path: PathBuf,
        payload: PendingFile,
        tx: channel::Sender<PendingFile>,
    },
    /// Remove the file's rows from the database; `payload` is forwarded to `tx`.
    DeleteFile {
        absolute_path: PathBuf,
        payload: DbOperation,
        tx: channel::Sender<DbOperation>,
    },
    /// Drain the queue, forwarding every pending operation.
    FlushQueue,
}
 253
/// A single semantic-search hit: a buffer and the matching range within it.
pub struct SearchResult {
    pub buffer: ModelHandle<Buffer>,
    pub range: Range<Anchor>,
}
 258
/// Requests handled serially by the database task (`run_db_operation`).
enum DbOperation {
    /// Insert (or replace) a file's documents and embeddings.
    InsertFile {
        worktree_id: i64,
        documents: Vec<Document>,
        path: PathBuf,
        mtime: SystemTime,
        // Dropped after the insert, decrementing the outstanding-job counter.
        job_handle: JobHandle,
    },
    /// Delete the rows for the file at `path` in the given worktree.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Resolve (creating if needed) the database id for a worktree root path.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    /// Fetch the stored modification times of every file in a worktree.
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
    /// Report whether the worktree at `path` has been indexed before.
    WorktreePreviouslyIndexed {
        path: Arc<Path>,
        sender: oneshot::Sender<Result<bool>>,
    },
}
 284
/// Input to the batching task that groups documents for the embedding provider.
enum EmbeddingJob {
    /// Queue one file's documents for embedding.
    Enqueue {
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
        job_handle: JobHandle,
    },
    /// Flush any partially-filled batch immediately.
    Flush,
}
 295
 296impl SemanticIndex {
 297    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
 298        if cx.has_global::<ModelHandle<Self>>() {
 299            Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
 300        } else {
 301            None
 302        }
 303    }
 304
 305    pub fn enabled(cx: &AppContext) -> bool {
 306        settings::get::<SemanticIndexSettings>(cx).enabled
 307            && *RELEASE_CHANNEL != ReleaseChannel::Stable
 308    }
 309
    /// Builds the `SemanticIndex` model: opens the vector database, then wires
    /// up the background pipeline — a serial database writer, per-CPU embedding
    /// batch workers, a single document-batching task, and per-CPU file parsers
    /// — all connected with unbounded channels.
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let t0 = Instant::now();
        let database_url = Arc::new(database_url);

        let db = cx
            .background()
            .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
            .await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        Ok(cx.add_model(|cx| {
            let t0 = Instant::now();
            // Perform database operations serially on one background task.
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn({
                async move {
                    while let Ok(job) = db_update_rx.recv().await {
                        Self::run_db_operation(&db, job)
                    }
                }
            });

            // Send grouped document batches to the embedding provider,
            // one worker per CPU.
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
            let mut _embed_batch_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let embed_batch_rx = embed_batch_rx.clone();
                _embed_batch_tasks.push(cx.background().spawn({
                    let db_update_tx = db_update_tx.clone();
                    let embedding_provider = embedding_provider.clone();
                    async move {
                        while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
                            Self::compute_embeddings_for_batch(
                                embeddings_queue,
                                &embedding_provider,
                                &db_update_tx,
                            )
                            .await;
                        }
                    }
                }));
            }

            // Accumulate parsed documents into batches for the embed workers.
            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];
                while let Ok(job) = batch_files_rx.recv().await {
                    Self::enqueue_documents_to_embed(
                        job,
                        &mut queue_len,
                        &mut embeddings_queue,
                        &embed_batch_tx,
                    );
                }
            });

            // Parse files into embeddable documents, one worker per CPU.
            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                let db_update_tx = db_update_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new();
                    while let Ok(pending_file) = parsing_files_rx.recv().await {
                        Self::parse_file(
                            &fs,
                            pending_file,
                            &mut retriever,
                            &batch_files_tx,
                            &parsing_files_rx,
                            &db_update_tx,
                        )
                        .await;
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                parsing_files_tx,
                _db_update_task,
                _embed_batch_tasks,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }
 422
 423    fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
 424        match job {
 425            DbOperation::InsertFile {
 426                worktree_id,
 427                documents,
 428                path,
 429                mtime,
 430                job_handle,
 431            } => {
 432                db.insert_file(worktree_id, path, mtime, documents)
 433                    .log_err();
 434                drop(job_handle)
 435            }
 436            DbOperation::Delete { worktree_id, path } => {
 437                db.delete_file(worktree_id, path).log_err();
 438            }
 439            DbOperation::FindOrCreateWorktree { path, sender } => {
 440                let id = db.find_or_create_worktree(&path);
 441                sender.send(id).ok();
 442            }
 443            DbOperation::FileMTimes {
 444                worktree_id: worktree_db_id,
 445                sender,
 446            } => {
 447                let file_mtimes = db.get_file_mtimes(worktree_db_id);
 448                sender.send(file_mtimes).ok();
 449            }
 450            DbOperation::WorktreePreviouslyIndexed { path, sender } => {
 451                let worktree_indexed = db.worktree_previously_indexed(path.as_ref());
 452                sender.send(worktree_indexed).ok();
 453            }
 454        }
 455    }
 456
 457    async fn compute_embeddings_for_batch(
 458        mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
 459        embedding_provider: &Arc<dyn EmbeddingProvider>,
 460        db_update_tx: &channel::Sender<DbOperation>,
 461    ) {
 462        let mut batch_documents = vec![];
 463        for (_, documents, _, _, _) in embeddings_queue.iter() {
 464            batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
 465        }
 466
 467        if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
 468            log::trace!(
 469                "created {} embeddings for {} files",
 470                embeddings.len(),
 471                embeddings_queue.len(),
 472            );
 473
 474            let mut i = 0;
 475            let mut j = 0;
 476
 477            for embedding in embeddings.iter() {
 478                while embeddings_queue[i].1.len() == j {
 479                    i += 1;
 480                    j = 0;
 481                }
 482
 483                embeddings_queue[i].1[j].embedding = embedding.to_owned();
 484                j += 1;
 485            }
 486
 487            for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
 488                db_update_tx
 489                    .send(DbOperation::InsertFile {
 490                        worktree_id,
 491                        documents,
 492                        path,
 493                        mtime,
 494                        job_handle,
 495                    })
 496                    .await
 497                    .unwrap();
 498            }
 499        } else {
 500            // Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed).
 501            for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() {
 502                db_update_tx
 503                    .send(DbOperation::InsertFile {
 504                        worktree_id,
 505                        documents: vec![],
 506                        path,
 507                        mtime,
 508                        job_handle,
 509                    })
 510                    .await
 511                    .unwrap();
 512            }
 513        }
 514    }
 515
 516    fn enqueue_documents_to_embed(
 517        job: EmbeddingJob,
 518        queue_len: &mut usize,
 519        embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
 520        embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
 521    ) {
 522        // Handle edge case where individual file has more documents than max batch size
 523        let should_flush = match job {
 524            EmbeddingJob::Enqueue {
 525                documents,
 526                worktree_id,
 527                path,
 528                mtime,
 529                job_handle,
 530            } => {
 531                // If documents is greater than embeddings batch size, recursively batch existing rows.
 532                if &documents.len() > &EMBEDDINGS_BATCH_SIZE {
 533                    let first_job = EmbeddingJob::Enqueue {
 534                        documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(),
 535                        worktree_id,
 536                        path: path.clone(),
 537                        mtime,
 538                        job_handle: job_handle.clone(),
 539                    };
 540
 541                    Self::enqueue_documents_to_embed(
 542                        first_job,
 543                        queue_len,
 544                        embeddings_queue,
 545                        embed_batch_tx,
 546                    );
 547
 548                    let second_job = EmbeddingJob::Enqueue {
 549                        documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(),
 550                        worktree_id,
 551                        path: path.clone(),
 552                        mtime,
 553                        job_handle: job_handle.clone(),
 554                    };
 555
 556                    Self::enqueue_documents_to_embed(
 557                        second_job,
 558                        queue_len,
 559                        embeddings_queue,
 560                        embed_batch_tx,
 561                    );
 562                    return;
 563                } else {
 564                    *queue_len += &documents.len();
 565                    embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
 566                    *queue_len >= EMBEDDINGS_BATCH_SIZE
 567                }
 568            }
 569            EmbeddingJob::Flush => true,
 570        };
 571
 572        if should_flush {
 573            embed_batch_tx
 574                .try_send(mem::take(embeddings_queue))
 575                .unwrap();
 576            *queue_len = 0;
 577        }
 578    }
 579
 580    async fn parse_file(
 581        fs: &Arc<dyn Fs>,
 582        pending_file: PendingFile,
 583        retriever: &mut CodeContextRetriever,
 584        batch_files_tx: &channel::Sender<EmbeddingJob>,
 585        parsing_files_rx: &channel::Receiver<PendingFile>,
 586        db_update_tx: &channel::Sender<DbOperation>,
 587    ) {
 588        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
 589            if let Some(documents) = retriever
 590                .parse_file_with_template(
 591                    &pending_file.relative_path,
 592                    &content,
 593                    pending_file.language,
 594                )
 595                .log_err()
 596            {
 597                log::trace!(
 598                    "parsed path {:?}: {} documents",
 599                    pending_file.relative_path,
 600                    documents.len()
 601                );
 602
 603                if documents.len() == 0 {
 604                    db_update_tx
 605                        .send(DbOperation::InsertFile {
 606                            worktree_id: pending_file.worktree_db_id,
 607                            documents,
 608                            path: pending_file.relative_path,
 609                            mtime: pending_file.modified_time,
 610                            job_handle: pending_file.job_handle,
 611                        })
 612                        .await
 613                        .unwrap();
 614                } else {
 615                    batch_files_tx
 616                        .try_send(EmbeddingJob::Enqueue {
 617                            worktree_id: pending_file.worktree_db_id,
 618                            path: pending_file.relative_path,
 619                            mtime: pending_file.modified_time,
 620                            job_handle: pending_file.job_handle,
 621                            documents,
 622                        })
 623                        .unwrap();
 624                }
 625            }
 626        }
 627
 628        if parsing_files_rx.len() == 0 {
 629            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
 630        }
 631    }
 632
 633    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
 634        let (tx, rx) = oneshot::channel();
 635        self.db_update_tx
 636            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
 637            .unwrap();
 638        async move { rx.await? }
 639    }
 640
 641    fn get_file_mtimes(
 642        &self,
 643        worktree_id: i64,
 644    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
 645        let (tx, rx) = oneshot::channel();
 646        self.db_update_tx
 647            .try_send(DbOperation::FileMTimes {
 648                worktree_id,
 649                sender: tx,
 650            })
 651            .unwrap();
 652        async move { rx.await? }
 653    }
 654
 655    fn worktree_previously_indexed(&self, path: Arc<Path>) -> impl Future<Output = Result<bool>> {
 656        let (tx, rx) = oneshot::channel();
 657        self.db_update_tx
 658            .try_send(DbOperation::WorktreePreviouslyIndexed { path, sender: tx })
 659            .unwrap();
 660        async move { rx.await? }
 661    }
 662
 663    pub fn project_previously_indexed(
 664        &mut self,
 665        project: ModelHandle<Project>,
 666        cx: &mut ModelContext<Self>,
 667    ) -> Task<Result<bool>> {
 668        let worktrees_indexed_previously = project
 669            .read(cx)
 670            .worktrees(cx)
 671            .map(|worktree| self.worktree_previously_indexed(worktree.read(cx).abs_path()))
 672            .collect::<Vec<_>>();
 673        cx.spawn(|_, _cx| async move {
 674            let worktree_indexed_previously =
 675                futures::future::join_all(worktrees_indexed_previously).await;
 676
 677            Ok(worktree_indexed_previously
 678                .iter()
 679                .filter(|worktree| worktree.is_ok())
 680                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
 681        })
 682    }
 683
 684    fn project_entries_changed(
 685        &self,
 686        project: ModelHandle<Project>,
 687        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 688        cx: &mut ModelContext<'_, SemanticIndex>,
 689        worktree_id: &WorktreeId,
 690    ) -> Result<()> {
 691        let parsing_files_tx = self.parsing_files_tx.clone();
 692        let db_update_tx = self.db_update_tx.clone();
 693        let (job_queue_tx, outstanding_job_tx, worktree_db_id) = {
 694            let state = self
 695                .projects
 696                .get(&project.downgrade())
 697                .ok_or(anyhow!("Project not yet initialized"))?;
 698            let worktree_db_id = state
 699                .db_id_for_worktree_id(*worktree_id)
 700                .ok_or(anyhow!("Worktree ID in Database Not Available"))?;
 701            (
 702                state.job_queue_tx.clone(),
 703                state._outstanding_job_count_tx.clone(),
 704                worktree_db_id,
 705            )
 706        };
 707
 708        let language_registry = self.language_registry.clone();
 709        let parsing_files_tx = parsing_files_tx.clone();
 710        let db_update_tx = db_update_tx.clone();
 711
 712        let worktree = project
 713            .read(cx)
 714            .worktree_for_id(worktree_id.clone(), cx)
 715            .ok_or(anyhow!("Worktree not available"))?
 716            .read(cx)
 717            .snapshot();
 718        cx.spawn(|this, mut cx| async move {
 719            let worktree = worktree.clone();
 720            for (path, entry_id, path_change) in changes.iter() {
 721                let relative_path = path.to_path_buf();
 722                let absolute_path = worktree.absolutize(path);
 723
 724                let Some(entry) = worktree.entry_for_id(*entry_id) else {
 725                    continue;
 726                };
 727                if entry.is_ignored || entry.is_symlink || entry.is_external {
 728                    continue;
 729                }
 730
 731                match path_change {
 732                    PathChange::AddedOrUpdated | PathChange::Updated => {
 733                        log::trace!("File Updated: {:?}", path);
 734                        if let Ok(language) = language_registry
 735                            .language_for_file(&relative_path, None)
 736                            .await
 737                        {
 738                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
 739                                && &language.name().as_ref() != &"Markdown"
 740                                && language
 741                                    .grammar()
 742                                    .and_then(|grammar| grammar.embedding_config.as_ref())
 743                                    .is_none()
 744                            {
 745                                continue;
 746                            }
 747
 748                            let job_handle = JobHandle::new(&outstanding_job_tx);
 749                            let new_operation = IndexOperation::IndexFile {
 750                                absolute_path: absolute_path.clone(),
 751                                payload: PendingFile {
 752                                    worktree_db_id,
 753                                    relative_path,
 754                                    absolute_path,
 755                                    language,
 756                                    modified_time: entry.mtime,
 757                                    job_handle,
 758                                },
 759                                tx: parsing_files_tx.clone(),
 760                            };
 761                            job_queue_tx.try_send(new_operation);
 762                        }
 763                    }
 764                    PathChange::Removed => {
 765                        let new_operation = IndexOperation::DeleteFile {
 766                            absolute_path,
 767                            payload: DbOperation::Delete {
 768                                worktree_id: worktree_db_id,
 769                                path: relative_path,
 770                            },
 771                            tx: db_update_tx.clone(),
 772                        };
 773                        job_queue_tx.try_send(new_operation);
 774                    }
 775                    _ => {}
 776                }
 777            }
 778        })
 779        .detach();
 780
 781        Ok(())
 782    }
 783
    /// Registers `project` with the semantic index and enqueues the initial
    /// round of index operations for all files in its worktrees.
    ///
    /// Waits for each local worktree's scan to complete, resolves (or creates)
    /// a database id per worktree, then compares every file's mtime against the
    /// stored mtime to decide whether it needs (re-)indexing. Entries that have
    /// disappeared from a worktree are scheduled for deletion from the
    /// database. Subsequent entry changes are handled via a project
    /// subscription.
    pub fn initialize_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        // Futures that resolve once each local worktree's initial scan is done.
        // NOTE(review): `as_local().unwrap()` assumes all worktrees are local —
        // presumably remote projects never reach this code path; confirm at callers.
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();

        // Kick off a database-id lookup (or creation) for every worktree path.
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();

        // Keep the index up to date as entries change; the subscription is
        // moved into the ProjectState below so it lives as long as the project
        // is tracked.
        let _subscription = cx.subscribe(&project, |this, project, event, cx| {
            if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
                this.project_entries_changed(project, changes.clone(), cx, worktree_id);
            };
        });

        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let db_update_tx = self.db_update_tx.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
            // Snapshot the worktrees so the background walk below can run
            // without touching project state.
            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Previously stored file mtimes, keyed by worktree id; used to
            // skip files whose contents haven't changed since the last index.
            let mut worktree_file_mtimes = HashMap::new();
            let mut db_ids_by_worktree_id = HashMap::new();

            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                let db_id = db_id?;
                db_ids_by_worktree_id.insert(worktree.id(), db_id);
                worktree_file_mtimes.insert(
                    worktree.id(),
                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
                        .await?,
                );
            }

            let worktree_db_ids = db_ids_by_worktree_id
                .iter()
                .map(|(a, b)| (*a, *b))
                .collect();

            // Counter of outstanding indexing jobs: each JobHandle increments
            // it on creation and decrements it when its last clone is dropped
            // (see `impl Drop for JobHandle`).
            let (job_count_tx, job_count_rx) = watch::channel_with(0);
            let job_count_tx = Arc::new(Mutex::new(job_count_tx));
            let job_count_tx_longlived = job_count_tx.clone();

            // Walk every file of every worktree on the background executor and
            // build the list of index/delete operations to enqueue.
            let worktree_files = cx
                .background()
                .spawn(async move {
                    let mut worktree_files = Vec::new();
                    for worktree in worktrees.into_iter() {
                        let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
                        let worktree_db_id = db_ids_by_worktree_id[&worktree.id()];
                        for file in worktree.files(false, 0) {
                            let absolute_path = worktree.absolutize(&file.path);

                            if file.is_external || file.is_ignored || file.is_symlink {
                                continue;
                            }

                            if let Ok(language) = language_registry
                                .language_for_file(&absolute_path, None)
                                .await
                            {
                                // Test if file is valid parseable file: it must
                                // either be indexable as a whole file, be
                                // Markdown, or have an embedding grammar config.
                                if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                    && &language.name().as_ref() != &"Markdown"
                                    && language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                {
                                    continue;
                                }

                                let path_buf = file.path.to_path_buf();
                                // Remove the entry so that whatever remains in
                                // `file_mtimes` afterwards is known to be stale.
                                let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                let already_stored = stored_mtime
                                    .map_or(false, |existing_mtime| existing_mtime == file.mtime);

                                if !already_stored {
                                    let job_handle = JobHandle::new(&job_count_tx);
                                    worktree_files.push(IndexOperation::IndexFile {
                                        absolute_path: absolute_path.clone(),
                                        payload: PendingFile {
                                            worktree_db_id,
                                            relative_path: path_buf,
                                            absolute_path,
                                            language,
                                            job_handle,
                                            modified_time: file.mtime,
                                        },
                                        tx: parsing_files_tx.clone(),
                                    });
                                }
                            }
                        }
                        // Clean up entries from database that are no longer in the worktree.
                        for (path, mtime) in file_mtimes {
                            worktree_files.push(IndexOperation::DeleteFile {
                                absolute_path: worktree.absolutize(path.as_path()),
                                payload: DbOperation::Delete {
                                    worktree_id: worktree_db_id,
                                    path,
                                },
                                tx: db_update_tx.clone(),
                            });
                        }
                    }

                    anyhow::Ok(worktree_files)
                })
                .await?;

            // Record the project state and enqueue the collected operations.
            this.update(&mut cx, |this, cx| {
                let project_state = ProjectState::new(
                    cx,
                    _subscription,
                    worktree_db_ids,
                    job_count_rx,
                    job_count_tx_longlived,
                );

                for op in worktree_files {
                    project_state.job_queue_tx.try_send(op);
                }

                this.projects.insert(project.downgrade(), project_state);
            });
            Result::<(), _>::Ok(())
        })
    }
 937
 938    pub fn index_project(
 939        &mut self,
 940        project: ModelHandle<Project>,
 941        cx: &mut ModelContext<Self>,
 942    ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
 943        let state = self.projects.get_mut(&project.downgrade());
 944        let state = if state.is_none() {
 945            return Task::Ready(Some(Err(anyhow!("Project not yet initialized"))));
 946        } else {
 947            state.unwrap()
 948        };
 949
 950        let parsing_files_tx = self.parsing_files_tx.clone();
 951        let db_update_tx = self.db_update_tx.clone();
 952        let job_count_rx = state.outstanding_job_count_rx.clone();
 953        let count = state.get_outstanding_count();
 954
 955        cx.spawn(|this, mut cx| async move {
 956            this.update(&mut cx, |this, cx| {
 957                let Some(state) = this.projects.get_mut(&project.downgrade()) else {
 958                    return;
 959                };
 960                state.job_queue_tx.try_send(IndexOperation::FlushQueue);
 961            });
 962
 963            Ok((count, job_count_rx))
 964        })
 965    }
 966
 967    pub fn outstanding_job_count_rx(
 968        &self,
 969        project: &ModelHandle<Project>,
 970    ) -> Option<watch::Receiver<usize>> {
 971        Some(
 972            self.projects
 973                .get(&project.downgrade())?
 974                .outstanding_job_count_rx
 975                .clone(),
 976        )
 977    }
 978
 979    pub fn search_project(
 980        &mut self,
 981        project: ModelHandle<Project>,
 982        phrase: String,
 983        limit: usize,
 984        includes: Vec<PathMatcher>,
 985        excludes: Vec<PathMatcher>,
 986        cx: &mut ModelContext<Self>,
 987    ) -> Task<Result<Vec<SearchResult>>> {
 988        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
 989            state
 990        } else {
 991            return Task::ready(Err(anyhow!("project not added")));
 992        };
 993
 994        let worktree_db_ids = project
 995            .read(cx)
 996            .worktrees(cx)
 997            .filter_map(|worktree| {
 998                let worktree_id = worktree.read(cx).id();
 999                project_state.db_id_for_worktree_id(worktree_id)
1000            })
1001            .collect::<Vec<_>>();
1002
1003        let embedding_provider = self.embedding_provider.clone();
1004        let database_url = self.database_url.clone();
1005        let fs = self.fs.clone();
1006        cx.spawn(|this, mut cx| async move {
1007            let t0 = Instant::now();
1008            let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
1009
1010            let phrase_embedding = embedding_provider
1011                .embed_batch(vec![&phrase])
1012                .await?
1013                .into_iter()
1014                .next()
1015                .unwrap();
1016
1017            log::trace!(
1018                "Embedding search phrase took: {:?} milliseconds",
1019                t0.elapsed().as_millis()
1020            );
1021
1022            let file_ids =
1023                database.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)?;
1024
1025            let batch_n = cx.background().num_cpus();
1026            let ids_len = file_ids.clone().len();
1027            let batch_size = if ids_len <= batch_n {
1028                ids_len
1029            } else {
1030                ids_len / batch_n
1031            };
1032
1033            let mut result_tasks = Vec::new();
1034            for batch in file_ids.chunks(batch_size) {
1035                let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
1036                let limit = limit.clone();
1037                let fs = fs.clone();
1038                let database_url = database_url.clone();
1039                let phrase_embedding = phrase_embedding.clone();
1040                let task = cx.background().spawn(async move {
1041                    let database = VectorDatabase::new(fs, database_url).await.log_err();
1042                    if database.is_none() {
1043                        return Err(anyhow!("failed to acquire database connection"));
1044                    } else {
1045                        database
1046                            .unwrap()
1047                            .top_k_search(&phrase_embedding, limit, batch.as_slice())
1048                    }
1049                });
1050                result_tasks.push(task);
1051            }
1052
1053            let batch_results = futures::future::join_all(result_tasks).await;
1054
1055            let mut results = Vec::new();
1056            for batch_result in batch_results {
1057                if batch_result.is_ok() {
1058                    for (id, similarity) in batch_result.unwrap() {
1059                        let ix = match results.binary_search_by(|(_, s)| {
1060                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
1061                        }) {
1062                            Ok(ix) => ix,
1063                            Err(ix) => ix,
1064                        };
1065                        results.insert(ix, (id, similarity));
1066                        results.truncate(limit);
1067                    }
1068                }
1069            }
1070
1071            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
1072            let documents = database.get_documents_by_ids(ids.as_slice())?;
1073
1074            let mut tasks = Vec::new();
1075            let mut ranges = Vec::new();
1076            let weak_project = project.downgrade();
1077            project.update(&mut cx, |project, cx| {
1078                for (worktree_db_id, file_path, byte_range) in documents {
1079                    let project_state =
1080                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
1081                            state
1082                        } else {
1083                            return Err(anyhow!("project not added"));
1084                        };
1085                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
1086                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
1087                        ranges.push(byte_range);
1088                    }
1089                }
1090
1091                Ok(())
1092            })?;
1093
1094            let buffers = futures::future::join_all(tasks).await;
1095
1096            log::trace!(
1097                "Semantic Searching took: {:?} milliseconds in total",
1098                t0.elapsed().as_millis()
1099            );
1100
1101            Ok(buffers
1102                .into_iter()
1103                .zip(ranges)
1104                .filter_map(|(buffer, range)| {
1105                    let buffer = buffer.log_err()?;
1106                    let range = buffer.read_with(&cx, |buffer, _| {
1107                        buffer.anchor_before(range.start)..buffer.anchor_after(range.end)
1108                    });
1109                    Some(SearchResult { buffer, range })
1110                })
1111                .collect::<Vec<_>>())
1112        })
1113    }
1114}
1115
// SemanticIndex emits no events of its own; progress is observed through the
// watch channels it hands out instead.
impl Entity for SemanticIndex {
    type Event = ();
}
1119
impl Drop for JobHandle {
    /// Decrements the shared outstanding-job counter when the *final* clone of
    /// this handle is dropped.
    fn drop(&mut self) {
        // `Arc::get_mut` succeeds only while this is the sole remaining strong
        // reference, i.e. this is the last instance of the JobHandle
        // (regardless of its origin - whether it was cloned or not).
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // The inner weak pointer upgrades only while the owning counter
            // channel is still alive; otherwise there is nothing to decrement.
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}
1131
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_job_handle() {
        // Creating a handle increments the shared counter to 1.
        let (count_tx, count_rx) = watch::channel_with(0);
        let count_tx = Arc::new(Mutex::new(count_tx));
        let first_handle = JobHandle::new(&count_tx);
        assert_eq!(1, *count_rx.borrow());

        // Clones share the same job: the count stays at 1, and dropping only
        // one of the two clones must not decrement it.
        let second_handle = first_handle.clone();
        assert_eq!(1, *count_rx.borrow());
        drop(first_handle);
        assert_eq!(1, *count_rx.borrow());

        // Dropping the final clone releases the job.
        drop(second_handle);
        assert_eq!(0, *count_rx.borrow());
    }
}