// semantic_index.rs

   1mod db;
   2mod embedding;
   3mod parsing;
   4pub mod semantic_index_settings;
   5
   6#[cfg(test)]
   7mod semantic_index_tests;
   8
   9use crate::semantic_index_settings::SemanticIndexSettings;
  10use anyhow::{anyhow, Result};
  11use db::VectorDatabase;
  12use embedding::{EmbeddingProvider, OpenAIEmbeddings};
  13use futures::{channel::oneshot, Future};
  14use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
  15use language::{Anchor, Buffer, Language, LanguageRegistry};
  16use parking_lot::Mutex;
  17use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
  18use postage::watch;
  19use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, WorktreeId};
  20use smol::channel;
  21use std::{
  22    cmp::Ordering,
  23    collections::HashMap,
  24    mem,
  25    ops::Range,
  26    path::{Path, PathBuf},
  27    sync::{Arc, Weak},
  28    time::{Instant, SystemTime},
  29};
  30use util::{
  31    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
  32    http::HttpClient,
  33    paths::EMBEDDINGS_DIR,
  34    ResultExt,
  35};
  36use workspace::WorkspaceCreated;
  37
// Version stamp of the on-disk embeddings index. NOTE(review): not referenced
// in this excerpt; presumably consulted by `VectorDatabase` to invalidate
// databases built by older code — confirm against db.rs.
const SEMANTIC_INDEX_VERSION: usize = 7;
// Maximum number of documents grouped into one embedding-provider request;
// used as the flush threshold in `enqueue_documents_to_embed`.
const EMBEDDINGS_BATCH_SIZE: usize = 80;
  40
  41pub fn init(
  42    fs: Arc<dyn Fs>,
  43    http_client: Arc<dyn HttpClient>,
  44    language_registry: Arc<LanguageRegistry>,
  45    cx: &mut AppContext,
  46) {
  47    settings::register::<SemanticIndexSettings>(cx);
  48
  49    let db_file_path = EMBEDDINGS_DIR
  50        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
  51        .join("embeddings_db");
  52
  53    // This needs to be removed at some point before stable.
  54    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
  55        return;
  56    }
  57
  58    cx.subscribe_global::<WorkspaceCreated, _>({
  59        move |event, cx| {
  60            let Some(semantic_index) = SemanticIndex::global(cx) else { return; };
  61            let workspace = &event.0;
  62            if let Some(workspace) = workspace.upgrade(cx) {
  63                let project = workspace.read(cx).project().clone();
  64                if project.read(cx).is_local() {
  65                    semantic_index.update(cx, |index, cx| {
  66                        index.initialize_project(project, cx).detach_and_log_err(cx)
  67                    });
  68                }
  69            }
  70        }
  71    })
  72    .detach();
  73
  74    cx.spawn(move |mut cx| async move {
  75        let semantic_index = SemanticIndex::new(
  76            fs,
  77            db_file_path,
  78            Arc::new(OpenAIEmbeddings {
  79                client: http_client,
  80                executor: cx.background(),
  81            }),
  82            language_registry,
  83            cx.clone(),
  84        )
  85        .await?;
  86
  87        cx.update(|cx| {
  88            cx.set_global(semantic_index.clone());
  89        });
  90
  91        anyhow::Ok(())
  92    })
  93    .detach();
  94}
  95
/// Maintains vector embeddings of project files for semantic search.
///
/// Work flows through background tasks connected by channels: files are
/// parsed into documents, documents are batched and embedded, and results
/// are written to the vector database.
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    /// Location of the on-disk vector database.
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    /// Sends operations to the single database task.
    db_update_tx: channel::Sender<DbOperation>,
    /// Sends files to the pool of parsing tasks.
    parsing_files_tx: channel::Sender<PendingFile>,
    // The tasks below are stored only to keep the background pipeline alive.
    _db_update_task: Task<()>,
    _embed_batch_tasks: Vec<Task<()>>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    /// Per-project indexing state, keyed by weak handle so dropped projects
    /// don't keep state alive.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
 109
/// Per-project bookkeeping for the semantic index.
struct ProjectState {
    /// Maps each worktree in the project to its database row id.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    // Held to keep receiving project events for the lifetime of this state.
    _subscription: gpui::Subscription,
    /// Observes the number of index/delete jobs still in flight.
    outstanding_job_count_rx: watch::Receiver<usize>,
    // Kept alive so `JobHandle`s can upgrade their weak reference and adjust
    // the counter; see `JobHandle`.
    _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    /// Feeds `IndexOperation`s to the queue-coalescing background task.
    job_queue_tx: channel::Sender<IndexOperation>,
    // Background task that deduplicates queued operations by path.
    _queue_update_task: Task<()>,
}
 118
/// Tracks the lifetime of one indexing job against a project's
/// outstanding-job counter.
#[derive(Clone)]
struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}
 125
 126impl JobHandle {
 127    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
 128        *tx.lock().borrow_mut() += 1;
 129        Self {
 130            tx: Arc::new(Arc::downgrade(&tx)),
 131        }
 132    }
 133}
 134impl ProjectState {
 135    fn new(
 136        cx: &mut AppContext,
 137        subscription: gpui::Subscription,
 138        worktree_db_ids: Vec<(WorktreeId, i64)>,
 139        outstanding_job_count_rx: watch::Receiver<usize>,
 140        _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
 141    ) -> Self {
 142        let (job_queue_tx, job_queue_rx) = channel::unbounded();
 143        let _queue_update_task = cx.background().spawn({
 144            let mut worktree_queue = HashMap::new();
 145            async move {
 146                while let Ok(operation) = job_queue_rx.recv().await {
 147                    Self::update_queue(&mut worktree_queue, operation);
 148                }
 149            }
 150        });
 151
 152        Self {
 153            worktree_db_ids,
 154            outstanding_job_count_rx,
 155            _outstanding_job_count_tx,
 156            _subscription: subscription,
 157            _queue_update_task,
 158            job_queue_tx,
 159        }
 160    }
 161
 162    pub fn get_outstanding_count(&self) -> usize {
 163        self.outstanding_job_count_rx.borrow().clone()
 164    }
 165
 166    fn update_queue(queue: &mut HashMap<PathBuf, IndexOperation>, operation: IndexOperation) {
 167        match operation {
 168            IndexOperation::FlushQueue => {
 169                let queue = std::mem::take(queue);
 170                for (_, op) in queue {
 171                    match op {
 172                        IndexOperation::IndexFile {
 173                            absolute_path: _,
 174                            payload,
 175                            tx,
 176                        } => {
 177                            let _ = tx.try_send(payload);
 178                        }
 179                        IndexOperation::DeleteFile {
 180                            absolute_path: _,
 181                            payload,
 182                            tx,
 183                        } => {
 184                            let _ = tx.try_send(payload);
 185                        }
 186                        _ => {}
 187                    }
 188                }
 189            }
 190            IndexOperation::IndexFile {
 191                ref absolute_path, ..
 192            }
 193            | IndexOperation::DeleteFile {
 194                ref absolute_path, ..
 195            } => {
 196                queue.insert(absolute_path.clone(), operation);
 197            }
 198        }
 199    }
 200
 201    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
 202        self.worktree_db_ids
 203            .iter()
 204            .find_map(|(worktree_id, db_id)| {
 205                if *worktree_id == id {
 206                    Some(*db_id)
 207                } else {
 208                    None
 209                }
 210            })
 211    }
 212
 213    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
 214        self.worktree_db_ids
 215            .iter()
 216            .find_map(|(worktree_id, db_id)| {
 217                if *db_id == id {
 218                    Some(*worktree_id)
 219                } else {
 220                    None
 221                }
 222            })
 223    }
 224}
 225
/// A file waiting to be parsed and embedded.
#[derive(Clone)]
pub struct PendingFile {
    /// Database row id of the worktree this file belongs to.
    worktree_db_id: i64,
    /// Path relative to the worktree root (used as the database key).
    relative_path: PathBuf,
    /// Path used to load the file contents from disk.
    absolute_path: PathBuf,
    language: Arc<Language>,
    modified_time: SystemTime,
    /// Keeps the project's outstanding-job counter accurate; see `JobHandle`.
    job_handle: JobHandle,
}
/// Operations queued per project, coalesced by path before being dispatched
/// to the parsing or database channels.
enum IndexOperation {
    /// Parse and embed the file at `absolute_path`.
    IndexFile {
        absolute_path: PathBuf,
        payload: PendingFile,
        /// Destination: the parsing-task pool.
        tx: channel::Sender<PendingFile>,
    },
    /// Remove the file's documents from the database.
    DeleteFile {
        absolute_path: PathBuf,
        payload: DbOperation,
        /// Destination: the database task.
        tx: channel::Sender<DbOperation>,
    },
    /// Drain the queue, forwarding every pending operation.
    FlushQueue,
}
 248
/// A single semantic-search hit: a buffer plus the matching range within it.
pub struct SearchResult {
    pub buffer: ModelHandle<Buffer>,
    pub range: Range<Anchor>,
}
 253
/// Requests processed sequentially by the dedicated database task.
enum DbOperation {
    /// Store a file's documents (with embeddings) and its mtime.
    InsertFile {
        worktree_id: i64,
        documents: Vec<Document>,
        path: PathBuf,
        mtime: SystemTime,
        /// Dropped after the insert, signalling completion of the job.
        job_handle: JobHandle,
    },
    /// Delete a file's documents from the database.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Resolve a worktree root path to a database id, creating a row if needed.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    /// Fetch stored modification times for all files of a worktree.
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
    /// Report whether a worktree path was indexed in a previous session.
    WorktreePreviouslyIndexed {
        path: Arc<Path>,
        sender: oneshot::Sender<Result<bool>>,
    },
}
 279
/// Input to the batching task that groups documents for the embedding provider.
enum EmbeddingJob {
    /// Add one file's documents to the current batch.
    Enqueue {
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
        job_handle: JobHandle,
    },
    /// Send whatever has accumulated, even if the batch is not full.
    Flush,
}
 290
 291impl SemanticIndex {
 292    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
 293        if cx.has_global::<ModelHandle<Self>>() {
 294            Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
 295        } else {
 296            None
 297        }
 298    }
 299
    /// Whether semantic indexing is active: the user setting must be enabled
    /// and the release channel must not be stable.
    pub fn enabled(cx: &AppContext) -> bool {
        settings::get::<SemanticIndexSettings>(cx).enabled
            && *RELEASE_CHANNEL != ReleaseChannel::Stable
    }
 304
    /// Opens the vector database, then constructs the index model and its
    /// background pipeline: one database task, per-CPU embedding tasks, one
    /// batching task, and per-CPU parsing tasks, linked by unbounded channels.
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let t0 = Instant::now();
        let database_url = Arc::new(database_url);

        // Open (or create) the database off the main thread.
        let db = cx
            .background()
            .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
            .await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        Ok(cx.add_model(|cx| {
            let t0 = Instant::now();
            // Perform database operations on a single dedicated task; `db`
            // is owned by this task for its whole lifetime.
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn({
                async move {
                    while let Ok(job) = db_update_rx.recv().await {
                        Self::run_db_operation(&db, job)
                    }
                }
            });

            // Embedding workers: each receives complete batches and sends
            // the results to the database task.
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
            let mut _embed_batch_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let embed_batch_rx = embed_batch_rx.clone();
                _embed_batch_tasks.push(cx.background().spawn({
                    let db_update_tx = db_update_tx.clone();
                    let embedding_provider = embedding_provider.clone();
                    async move {
                        while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
                            Self::compute_embeddings_for_batch(
                                embeddings_queue,
                                &embedding_provider,
                                &db_update_tx,
                            )
                            .await;
                        }
                    }
                }));
            }

            // Group documents into provider-sized batches before handing
            // them to the embedding workers above.
            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];
                while let Ok(job) = batch_files_rx.recv().await {
                    Self::enqueue_documents_to_embed(
                        job,
                        &mut queue_len,
                        &mut embeddings_queue,
                        &embed_batch_tx,
                    );
                }
            });

            // Parse files into embeddable documents.
            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                let db_update_tx = db_update_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new();
                    while let Ok(pending_file) = parsing_files_rx.recv().await {
                        Self::parse_file(
                            &fs,
                            pending_file,
                            &mut retriever,
                            &batch_files_tx,
                            &parsing_files_rx,
                            &db_update_tx,
                        )
                        .await;
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                parsing_files_tx,
                _db_update_task,
                _embed_batch_tasks,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }
 417
    /// Executes one queued operation against the vector database,
    /// synchronously, on the dedicated database task.
    fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
        match job {
            DbOperation::InsertFile {
                worktree_id,
                documents,
                path,
                mtime,
                job_handle,
            } => {
                db.insert_file(worktree_id, path, mtime, documents)
                    .log_err();
                // Releasing the handle marks this indexing job as finished
                // (see `JobHandle`).
                drop(job_handle)
            }
            DbOperation::Delete { worktree_id, path } => {
                db.delete_file(worktree_id, path).log_err();
            }
            DbOperation::FindOrCreateWorktree { path, sender } => {
                let id = db.find_or_create_worktree(&path);
                // The requester may have gone away; ignore send failures.
                sender.send(id).ok();
            }
            DbOperation::FileMTimes {
                worktree_id: worktree_db_id,
                sender,
            } => {
                let file_mtimes = db.get_file_mtimes(worktree_db_id);
                sender.send(file_mtimes).ok();
            }
            DbOperation::WorktreePreviouslyIndexed { path, sender } => {
                let worktree_indexed = db.worktree_previously_indexed(path.as_ref());
                sender.send(worktree_indexed).ok();
            }
        }
    }
 451
 452    async fn compute_embeddings_for_batch(
 453        mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
 454        embedding_provider: &Arc<dyn EmbeddingProvider>,
 455        db_update_tx: &channel::Sender<DbOperation>,
 456    ) {
 457        let mut batch_documents = vec![];
 458        for (_, documents, _, _, _) in embeddings_queue.iter() {
 459            batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
 460        }
 461
 462        if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
 463            log::trace!(
 464                "created {} embeddings for {} files",
 465                embeddings.len(),
 466                embeddings_queue.len(),
 467            );
 468
 469            let mut i = 0;
 470            let mut j = 0;
 471
 472            for embedding in embeddings.iter() {
 473                while embeddings_queue[i].1.len() == j {
 474                    i += 1;
 475                    j = 0;
 476                }
 477
 478                embeddings_queue[i].1[j].embedding = embedding.to_owned();
 479                j += 1;
 480            }
 481
 482            for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
 483                db_update_tx
 484                    .send(DbOperation::InsertFile {
 485                        worktree_id,
 486                        documents,
 487                        path,
 488                        mtime,
 489                        job_handle,
 490                    })
 491                    .await
 492                    .unwrap();
 493            }
 494        } else {
 495            // Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed).
 496            for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() {
 497                db_update_tx
 498                    .send(DbOperation::InsertFile {
 499                        worktree_id,
 500                        documents: vec![],
 501                        path,
 502                        mtime,
 503                        job_handle,
 504                    })
 505                    .await
 506                    .unwrap();
 507            }
 508        }
 509    }
 510
 511    fn enqueue_documents_to_embed(
 512        job: EmbeddingJob,
 513        queue_len: &mut usize,
 514        embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
 515        embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
 516    ) {
 517        // Handle edge case where individual file has more documents than max batch size
 518        let should_flush = match job {
 519            EmbeddingJob::Enqueue {
 520                documents,
 521                worktree_id,
 522                path,
 523                mtime,
 524                job_handle,
 525            } => {
 526                // If documents is greater than embeddings batch size, recursively batch existing rows.
 527                if &documents.len() > &EMBEDDINGS_BATCH_SIZE {
 528                    let first_job = EmbeddingJob::Enqueue {
 529                        documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(),
 530                        worktree_id,
 531                        path: path.clone(),
 532                        mtime,
 533                        job_handle: job_handle.clone(),
 534                    };
 535
 536                    Self::enqueue_documents_to_embed(
 537                        first_job,
 538                        queue_len,
 539                        embeddings_queue,
 540                        embed_batch_tx,
 541                    );
 542
 543                    let second_job = EmbeddingJob::Enqueue {
 544                        documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(),
 545                        worktree_id,
 546                        path: path.clone(),
 547                        mtime,
 548                        job_handle: job_handle.clone(),
 549                    };
 550
 551                    Self::enqueue_documents_to_embed(
 552                        second_job,
 553                        queue_len,
 554                        embeddings_queue,
 555                        embed_batch_tx,
 556                    );
 557                    return;
 558                } else {
 559                    *queue_len += &documents.len();
 560                    embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
 561                    *queue_len >= EMBEDDINGS_BATCH_SIZE
 562                }
 563            }
 564            EmbeddingJob::Flush => true,
 565        };
 566
 567        if should_flush {
 568            embed_batch_tx
 569                .try_send(mem::take(embeddings_queue))
 570                .unwrap();
 571            *queue_len = 0;
 572        }
 573    }
 574
 575    async fn parse_file(
 576        fs: &Arc<dyn Fs>,
 577        pending_file: PendingFile,
 578        retriever: &mut CodeContextRetriever,
 579        batch_files_tx: &channel::Sender<EmbeddingJob>,
 580        parsing_files_rx: &channel::Receiver<PendingFile>,
 581        db_update_tx: &channel::Sender<DbOperation>,
 582    ) {
 583        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
 584            if let Some(documents) = retriever
 585                .parse_file_with_template(
 586                    &pending_file.relative_path,
 587                    &content,
 588                    pending_file.language,
 589                )
 590                .log_err()
 591            {
 592                log::trace!(
 593                    "parsed path {:?}: {} documents",
 594                    pending_file.relative_path,
 595                    documents.len()
 596                );
 597
 598                if documents.len() == 0 {
 599                    db_update_tx
 600                        .send(DbOperation::InsertFile {
 601                            worktree_id: pending_file.worktree_db_id,
 602                            documents,
 603                            path: pending_file.relative_path,
 604                            mtime: pending_file.modified_time,
 605                            job_handle: pending_file.job_handle,
 606                        })
 607                        .await
 608                        .unwrap();
 609                } else {
 610                    batch_files_tx
 611                        .try_send(EmbeddingJob::Enqueue {
 612                            worktree_id: pending_file.worktree_db_id,
 613                            path: pending_file.relative_path,
 614                            mtime: pending_file.modified_time,
 615                            job_handle: pending_file.job_handle,
 616                            documents,
 617                        })
 618                        .unwrap();
 619                }
 620            }
 621        }
 622
 623        if parsing_files_rx.len() == 0 {
 624            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
 625        }
 626    }
 627
 628    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
 629        let (tx, rx) = oneshot::channel();
 630        self.db_update_tx
 631            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
 632            .unwrap();
 633        async move { rx.await? }
 634    }
 635
 636    fn get_file_mtimes(
 637        &self,
 638        worktree_id: i64,
 639    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
 640        let (tx, rx) = oneshot::channel();
 641        self.db_update_tx
 642            .try_send(DbOperation::FileMTimes {
 643                worktree_id,
 644                sender: tx,
 645            })
 646            .unwrap();
 647        async move { rx.await? }
 648    }
 649
 650    fn worktree_previously_indexed(&self, path: Arc<Path>) -> impl Future<Output = Result<bool>> {
 651        let (tx, rx) = oneshot::channel();
 652        self.db_update_tx
 653            .try_send(DbOperation::WorktreePreviouslyIndexed { path, sender: tx })
 654            .unwrap();
 655        async move { rx.await? }
 656    }
 657
 658    pub fn project_previously_indexed(
 659        &mut self,
 660        project: ModelHandle<Project>,
 661        cx: &mut ModelContext<Self>,
 662    ) -> Task<Result<bool>> {
 663        let worktrees_indexed_previously = project
 664            .read(cx)
 665            .worktrees(cx)
 666            .map(|worktree| self.worktree_previously_indexed(worktree.read(cx).abs_path()))
 667            .collect::<Vec<_>>();
 668        cx.spawn(|_, _cx| async move {
 669            let worktree_indexed_previously =
 670                futures::future::join_all(worktrees_indexed_previously).await;
 671
 672            Ok(worktree_indexed_previously
 673                .iter()
 674                .filter(|worktree| worktree.is_ok())
 675                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
 676        })
 677    }
 678
 679    fn project_entries_changed(
 680        &self,
 681        project: ModelHandle<Project>,
 682        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 683        cx: &mut ModelContext<'_, SemanticIndex>,
 684        worktree_id: &WorktreeId,
 685    ) -> Result<()> {
 686        let parsing_files_tx = self.parsing_files_tx.clone();
 687        let db_update_tx = self.db_update_tx.clone();
 688        let (job_queue_tx, outstanding_job_tx, worktree_db_id) = {
 689            let state = self
 690                .projects
 691                .get(&project.downgrade())
 692                .ok_or(anyhow!("Project not yet initialized"))?;
 693            let worktree_db_id = state
 694                .db_id_for_worktree_id(*worktree_id)
 695                .ok_or(anyhow!("Worktree ID in Database Not Available"))?;
 696            (
 697                state.job_queue_tx.clone(),
 698                state._outstanding_job_count_tx.clone(),
 699                worktree_db_id,
 700            )
 701        };
 702
 703        let language_registry = self.language_registry.clone();
 704        let parsing_files_tx = parsing_files_tx.clone();
 705        let db_update_tx = db_update_tx.clone();
 706
 707        let worktree = project
 708            .read(cx)
 709            .worktree_for_id(worktree_id.clone(), cx)
 710            .ok_or(anyhow!("Worktree not available"))?
 711            .read(cx)
 712            .snapshot();
 713        cx.spawn(|_, _| async move {
 714            let worktree = worktree.clone();
 715            for (path, entry_id, path_change) in changes.iter() {
 716                let relative_path = path.to_path_buf();
 717                let absolute_path = worktree.absolutize(path);
 718
 719                let Some(entry) = worktree.entry_for_id(*entry_id) else {
 720                    continue;
 721                };
 722                if entry.is_ignored || entry.is_symlink || entry.is_external {
 723                    continue;
 724                }
 725
 726                log::trace!("File Event: {:?}, Path: {:?}", &path_change, &path);
 727                match path_change {
 728                    PathChange::AddedOrUpdated | PathChange::Updated | PathChange::Added => {
 729                        if let Ok(language) = language_registry
 730                            .language_for_file(&relative_path, None)
 731                            .await
 732                        {
 733                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
 734                                && &language.name().as_ref() != &"Markdown"
 735                                && language
 736                                    .grammar()
 737                                    .and_then(|grammar| grammar.embedding_config.as_ref())
 738                                    .is_none()
 739                            {
 740                                continue;
 741                            }
 742
 743                            let job_handle = JobHandle::new(&outstanding_job_tx);
 744                            let new_operation = IndexOperation::IndexFile {
 745                                absolute_path: absolute_path.clone(),
 746                                payload: PendingFile {
 747                                    worktree_db_id,
 748                                    relative_path,
 749                                    absolute_path,
 750                                    language,
 751                                    modified_time: entry.mtime,
 752                                    job_handle,
 753                                },
 754                                tx: parsing_files_tx.clone(),
 755                            };
 756                            let _ = job_queue_tx.try_send(new_operation);
 757                        }
 758                    }
 759                    PathChange::Removed => {
 760                        let new_operation = IndexOperation::DeleteFile {
 761                            absolute_path,
 762                            payload: DbOperation::Delete {
 763                                worktree_id: worktree_db_id,
 764                                path: relative_path,
 765                            },
 766                            tx: db_update_tx.clone(),
 767                        };
 768                        let _ = job_queue_tx.try_send(new_operation);
 769                    }
 770                    _ => {}
 771                }
 772            }
 773        })
 774        .detach();
 775
 776        Ok(())
 777    }
 778
 779    pub fn initialize_project(
 780        &mut self,
 781        project: ModelHandle<Project>,
 782        cx: &mut ModelContext<Self>,
 783    ) -> Task<Result<()>> {
 784        log::trace!("Initializing Project for Semantic Index");
 785        let worktree_scans_complete = project
 786            .read(cx)
 787            .worktrees(cx)
 788            .map(|worktree| {
 789                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
 790                async move {
 791                    scan_complete.await;
 792                }
 793            })
 794            .collect::<Vec<_>>();
 795
 796        let worktree_db_ids = project
 797            .read(cx)
 798            .worktrees(cx)
 799            .map(|worktree| {
 800                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
 801            })
 802            .collect::<Vec<_>>();
 803
 804        let _subscription = cx.subscribe(&project, |this, project, event, cx| {
 805            if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
 806                let _ =
 807                    this.project_entries_changed(project.clone(), changes.clone(), cx, worktree_id);
 808            };
 809        });
 810
 811        let language_registry = self.language_registry.clone();
 812        let parsing_files_tx = self.parsing_files_tx.clone();
 813        let db_update_tx = self.db_update_tx.clone();
 814
 815        cx.spawn(|this, mut cx| async move {
 816            futures::future::join_all(worktree_scans_complete).await;
 817
 818            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
 819            let worktrees = project.read_with(&cx, |project, cx| {
 820                project
 821                    .worktrees(cx)
 822                    .map(|worktree| worktree.read(cx).snapshot())
 823                    .collect::<Vec<_>>()
 824            });
 825
 826            let mut worktree_file_mtimes = HashMap::new();
 827            let mut db_ids_by_worktree_id = HashMap::new();
 828
 829            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
 830                let db_id = db_id?;
 831                db_ids_by_worktree_id.insert(worktree.id(), db_id);
 832                worktree_file_mtimes.insert(
 833                    worktree.id(),
 834                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
 835                        .await?,
 836                );
 837            }
 838
 839            let worktree_db_ids = db_ids_by_worktree_id
 840                .iter()
 841                .map(|(a, b)| (*a, *b))
 842                .collect();
 843
 844            let (job_count_tx, job_count_rx) = watch::channel_with(0);
 845            let job_count_tx = Arc::new(Mutex::new(job_count_tx));
 846            let job_count_tx_longlived = job_count_tx.clone();
 847
 848            let worktree_files = cx
 849                .background()
 850                .spawn(async move {
 851                    let mut worktree_files = Vec::new();
 852                    for worktree in worktrees.into_iter() {
 853                        let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
 854                        let worktree_db_id = db_ids_by_worktree_id[&worktree.id()];
 855                        for file in worktree.files(false, 0) {
 856                            let absolute_path = worktree.absolutize(&file.path);
 857
 858                            if file.is_external || file.is_ignored || file.is_symlink {
 859                                continue;
 860                            }
 861
 862                            if let Ok(language) = language_registry
 863                                .language_for_file(&absolute_path, None)
 864                                .await
 865                            {
 866                                // Test if file is valid parseable file
 867                                if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
 868                                    && &language.name().as_ref() != &"Markdown"
 869                                    && language
 870                                        .grammar()
 871                                        .and_then(|grammar| grammar.embedding_config.as_ref())
 872                                        .is_none()
 873                                {
 874                                    continue;
 875                                }
 876
 877                                let path_buf = file.path.to_path_buf();
 878                                let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
 879                                let already_stored = stored_mtime
 880                                    .map_or(false, |existing_mtime| existing_mtime == file.mtime);
 881
 882                                if !already_stored {
 883                                    let job_handle = JobHandle::new(&job_count_tx);
 884                                    worktree_files.push(IndexOperation::IndexFile {
 885                                        absolute_path: absolute_path.clone(),
 886                                        payload: PendingFile {
 887                                            worktree_db_id,
 888                                            relative_path: path_buf,
 889                                            absolute_path,
 890                                            language,
 891                                            job_handle,
 892                                            modified_time: file.mtime,
 893                                        },
 894                                        tx: parsing_files_tx.clone(),
 895                                    });
 896                                }
 897                            }
 898                        }
 899                        // Clean up entries from database that are no longer in the worktree.
 900                        for (path, _) in file_mtimes {
 901                            worktree_files.push(IndexOperation::DeleteFile {
 902                                absolute_path: worktree.absolutize(path.as_path()),
 903                                payload: DbOperation::Delete {
 904                                    worktree_id: worktree_db_id,
 905                                    path,
 906                                },
 907                                tx: db_update_tx.clone(),
 908                            });
 909                        }
 910                    }
 911
 912                    anyhow::Ok(worktree_files)
 913                })
 914                .await?;
 915
 916            this.update(&mut cx, |this, cx| {
 917                let project_state = ProjectState::new(
 918                    cx,
 919                    _subscription,
 920                    worktree_db_ids,
 921                    job_count_rx,
 922                    job_count_tx_longlived,
 923                );
 924
 925                for op in worktree_files {
 926                    let _ = project_state.job_queue_tx.try_send(op);
 927                }
 928
 929                this.projects.insert(project.downgrade(), project_state);
 930            });
 931            Result::<(), _>::Ok(())
 932        })
 933    }
 934
 935    pub fn index_project(
 936        &mut self,
 937        project: ModelHandle<Project>,
 938        cx: &mut ModelContext<Self>,
 939    ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
 940        let state = self.projects.get_mut(&project.downgrade());
 941        let state = if state.is_none() {
 942            return Task::Ready(Some(Err(anyhow!("Project not yet initialized"))));
 943        } else {
 944            state.unwrap()
 945        };
 946
 947        // let parsing_files_tx = self.parsing_files_tx.clone();
 948        // let db_update_tx = self.db_update_tx.clone();
 949        let job_count_rx = state.outstanding_job_count_rx.clone();
 950        let count = state.get_outstanding_count();
 951
 952        cx.spawn(|this, mut cx| async move {
 953            this.update(&mut cx, |this, _| {
 954                let Some(state) = this.projects.get_mut(&project.downgrade()) else {
 955                    return;
 956                };
 957                let _ = state.job_queue_tx.try_send(IndexOperation::FlushQueue);
 958            });
 959
 960            Ok((count, job_count_rx))
 961        })
 962    }
 963
 964    pub fn outstanding_job_count_rx(
 965        &self,
 966        project: &ModelHandle<Project>,
 967    ) -> Option<watch::Receiver<usize>> {
 968        Some(
 969            self.projects
 970                .get(&project.downgrade())?
 971                .outstanding_job_count_rx
 972                .clone(),
 973        )
 974    }
 975
 976    pub fn search_project(
 977        &mut self,
 978        project: ModelHandle<Project>,
 979        phrase: String,
 980        limit: usize,
 981        includes: Vec<PathMatcher>,
 982        excludes: Vec<PathMatcher>,
 983        cx: &mut ModelContext<Self>,
 984    ) -> Task<Result<Vec<SearchResult>>> {
 985        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
 986            state
 987        } else {
 988            return Task::ready(Err(anyhow!("project not added")));
 989        };
 990
 991        let worktree_db_ids = project
 992            .read(cx)
 993            .worktrees(cx)
 994            .filter_map(|worktree| {
 995                let worktree_id = worktree.read(cx).id();
 996                project_state.db_id_for_worktree_id(worktree_id)
 997            })
 998            .collect::<Vec<_>>();
 999
1000        let embedding_provider = self.embedding_provider.clone();
1001        let database_url = self.database_url.clone();
1002        let fs = self.fs.clone();
1003        cx.spawn(|this, mut cx| async move {
1004            let t0 = Instant::now();
1005            let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
1006
1007            let phrase_embedding = embedding_provider
1008                .embed_batch(vec![&phrase])
1009                .await?
1010                .into_iter()
1011                .next()
1012                .unwrap();
1013
1014            log::trace!(
1015                "Embedding search phrase took: {:?} milliseconds",
1016                t0.elapsed().as_millis()
1017            );
1018
1019            let file_ids =
1020                database.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)?;
1021
1022            let batch_n = cx.background().num_cpus();
1023            let ids_len = file_ids.clone().len();
1024            let batch_size = if ids_len <= batch_n {
1025                ids_len
1026            } else {
1027                ids_len / batch_n
1028            };
1029
1030            let mut result_tasks = Vec::new();
1031            for batch in file_ids.chunks(batch_size) {
1032                let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
1033                let limit = limit.clone();
1034                let fs = fs.clone();
1035                let database_url = database_url.clone();
1036                let phrase_embedding = phrase_embedding.clone();
1037                let task = cx.background().spawn(async move {
1038                    let database = VectorDatabase::new(fs, database_url).await.log_err();
1039                    if database.is_none() {
1040                        return Err(anyhow!("failed to acquire database connection"));
1041                    } else {
1042                        database
1043                            .unwrap()
1044                            .top_k_search(&phrase_embedding, limit, batch.as_slice())
1045                    }
1046                });
1047                result_tasks.push(task);
1048            }
1049
1050            let batch_results = futures::future::join_all(result_tasks).await;
1051
1052            let mut results = Vec::new();
1053            for batch_result in batch_results {
1054                if batch_result.is_ok() {
1055                    for (id, similarity) in batch_result.unwrap() {
1056                        let ix = match results.binary_search_by(|(_, s)| {
1057                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
1058                        }) {
1059                            Ok(ix) => ix,
1060                            Err(ix) => ix,
1061                        };
1062                        results.insert(ix, (id, similarity));
1063                        results.truncate(limit);
1064                    }
1065                }
1066            }
1067
1068            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
1069            let documents = database.get_documents_by_ids(ids.as_slice())?;
1070
1071            let mut tasks = Vec::new();
1072            let mut ranges = Vec::new();
1073            let weak_project = project.downgrade();
1074            project.update(&mut cx, |project, cx| {
1075                for (worktree_db_id, file_path, byte_range) in documents {
1076                    let project_state =
1077                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
1078                            state
1079                        } else {
1080                            return Err(anyhow!("project not added"));
1081                        };
1082                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
1083                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
1084                        ranges.push(byte_range);
1085                    }
1086                }
1087
1088                Ok(())
1089            })?;
1090
1091            let buffers = futures::future::join_all(tasks).await;
1092
1093            log::trace!(
1094                "Semantic Searching took: {:?} milliseconds in total",
1095                t0.elapsed().as_millis()
1096            );
1097
1098            Ok(buffers
1099                .into_iter()
1100                .zip(ranges)
1101                .filter_map(|(buffer, range)| {
1102                    let buffer = buffer.log_err()?;
1103                    let range = buffer.read_with(&cx, |buffer, _| {
1104                        buffer.anchor_before(range.start)..buffer.anchor_after(range.end)
1105                    });
1106                    Some(SearchResult { buffer, range })
1107                })
1108                .collect::<Vec<_>>())
1109        })
1110    }
1111}
1112
// `SemanticIndex` participates in the gpui entity system but emits no events.
impl Entity for SemanticIndex {
    type Event = ();
}
1116
impl Drop for JobHandle {
    fn drop(&mut self) {
        // `Arc::get_mut` only returns `Some` when this is the sole remaining
        // strong reference, i.e. when the final clone of this handle is being
        // dropped — so the counter is decremented exactly once per job.
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not)
            // The inner reference is weak, so it silently does nothing if the
            // counter's owner has already gone away.
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}
1128
#[cfg(test)]
mod tests {

    use super::*;

    /// Verifies that the outstanding-job counter is decremented exactly once
    /// per job, no matter how many times its `JobHandle` is cloned.
    #[test]
    fn test_job_handle() {
        // Creating a handle bumps the counter to one.
        let (count_tx, count_rx) = watch::channel_with(0);
        let shared_tx = Arc::new(Mutex::new(count_tx));
        let first_handle = JobHandle::new(&shared_tx);
        assert_eq!(1, *count_rx.borrow());

        // Cloning must not change the count: clones share the inner Arc.
        let second_handle = first_handle.clone();
        assert_eq!(1, *count_rx.borrow());

        // Dropping one of two clones leaves the job outstanding...
        drop(first_handle);
        assert_eq!(1, *count_rx.borrow());

        // ...and dropping the last clone finally decrements the counter.
        drop(second_handle);
        assert_eq!(0, *count_rx.borrow());
    }
}