semantic_index.rs

   1mod db;
   2mod embedding;
   3mod parsing;
   4pub mod semantic_index_settings;
   5
   6#[cfg(test)]
   7mod semantic_index_tests;
   8
   9use crate::semantic_index_settings::SemanticIndexSettings;
  10use anyhow::{anyhow, Result};
  11use db::VectorDatabase;
  12use embedding::{EmbeddingProvider, OpenAIEmbeddings};
  13use futures::{channel::oneshot, Future};
  14use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
  15use language::{Anchor, Buffer, Language, LanguageRegistry};
  16use parking_lot::Mutex;
  17use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
  18use postage::watch;
  19use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, WorktreeId};
  20use smol::channel;
  21use std::{
  22    cmp::Ordering,
  23    collections::HashMap,
  24    mem,
  25    ops::Range,
  26    path::{Path, PathBuf},
  27    sync::{Arc, Weak},
  28    time::{Instant, SystemTime},
  29};
  30use util::{
  31    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
  32    http::HttpClient,
  33    paths::EMBEDDINGS_DIR,
  34    ResultExt,
  35};
  36use workspace::WorkspaceCreated;
  37
// Schema/format version of the semantic index; presumably bumped to force a
// re-index when the database layout or document shape changes — its consumers
// are outside this chunk, so confirm before relying on that.
const SEMANTIC_INDEX_VERSION: usize = 7;
// Maximum number of documents accumulated before a batch is sent to the
// embedding provider (see `enqueue_documents_to_embed`).
const EMBEDDINGS_BATCH_SIZE: usize = 80;
  40
  41pub fn init(
  42    fs: Arc<dyn Fs>,
  43    http_client: Arc<dyn HttpClient>,
  44    language_registry: Arc<LanguageRegistry>,
  45    cx: &mut AppContext,
  46) {
  47    settings::register::<SemanticIndexSettings>(cx);
  48
  49    let db_file_path = EMBEDDINGS_DIR
  50        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
  51        .join("embeddings_db");
  52
  53    // This needs to be removed at some point before stable.
  54    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
  55        return;
  56    }
  57
  58    cx.subscribe_global::<WorkspaceCreated, _>({
  59        move |event, cx| {
  60            let Some(semantic_index) = SemanticIndex::global(cx) else {
  61                return;
  62            };
  63            let workspace = &event.0;
  64            if let Some(workspace) = workspace.upgrade(cx) {
  65                let project = workspace.read(cx).project().clone();
  66                if project.read(cx).is_local() {
  67                    semantic_index.update(cx, |index, cx| {
  68                        index.initialize_project(project, cx).detach_and_log_err(cx)
  69                    });
  70                }
  71            }
  72        }
  73    })
  74    .detach();
  75
  76    cx.spawn(move |mut cx| async move {
  77        let semantic_index = SemanticIndex::new(
  78            fs,
  79            db_file_path,
  80            Arc::new(OpenAIEmbeddings {
  81                client: http_client,
  82                executor: cx.background(),
  83            }),
  84            language_registry,
  85            cx.clone(),
  86        )
  87        .await?;
  88
  89        cx.update(|cx| {
  90            cx.set_global(semantic_index.clone());
  91        });
  92
  93        anyhow::Ok(())
  94    })
  95    .detach();
  96}
  97
/// Coordinates semantic indexing: parsing files into documents, embedding
/// them in batches, and persisting vectors to the database. A single handle
/// is stored as a gpui global by `init`.
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    /// Filesystem path of the embeddings database (despite the `url` name).
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    /// Feeds the single background database-writer task.
    db_update_tx: channel::Sender<DbOperation>,
    /// Feeds the pool of file-parsing tasks.
    parsing_files_tx: channel::Sender<PendingFile>,
    // The tasks below are held only to keep the workers alive; dropping the
    // index cancels them.
    _db_update_task: Task<()>,
    _embed_batch_tasks: Vec<Task<()>>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    /// Per-project indexing state, keyed by weak project handle.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
 111
/// Per-project bookkeeping kept by `SemanticIndex` for each initialized project.
struct ProjectState {
    /// Maps each worktree id to its row id in the vector database.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    /// Keeps the project-event subscription alive for this project's lifetime.
    _subscription: gpui::Subscription,
    /// Read side of the outstanding-indexing-job counter.
    outstanding_job_count_rx: watch::Receiver<usize>,
    /// Shared counter sender; `JobHandle`s hold weak references to it.
    _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    /// Queue of pending per-file index/delete operations (see `update_queue`).
    job_queue_tx: channel::Sender<IndexOperation>,
    /// Drains `job_queue_tx` in the background; held only to stay alive.
    _queue_update_task: Task<()>,
}
 120
/// Cheaply-cloneable token representing one in-flight indexing job.
#[derive(Clone)]
struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}
 127
 128impl JobHandle {
 129    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
 130        *tx.lock().borrow_mut() += 1;
 131        Self {
 132            tx: Arc::new(Arc::downgrade(&tx)),
 133        }
 134    }
 135}
 136impl ProjectState {
 137    fn new(
 138        cx: &mut AppContext,
 139        subscription: gpui::Subscription,
 140        worktree_db_ids: Vec<(WorktreeId, i64)>,
 141        outstanding_job_count_rx: watch::Receiver<usize>,
 142        _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
 143    ) -> Self {
 144        let (job_queue_tx, job_queue_rx) = channel::unbounded();
 145        let _queue_update_task = cx.background().spawn({
 146            let mut worktree_queue = HashMap::new();
 147            async move {
 148                while let Ok(operation) = job_queue_rx.recv().await {
 149                    Self::update_queue(&mut worktree_queue, operation);
 150                }
 151            }
 152        });
 153
 154        Self {
 155            worktree_db_ids,
 156            outstanding_job_count_rx,
 157            _outstanding_job_count_tx,
 158            _subscription: subscription,
 159            _queue_update_task,
 160            job_queue_tx,
 161        }
 162    }
 163
 164    pub fn get_outstanding_count(&self) -> usize {
 165        self.outstanding_job_count_rx.borrow().clone()
 166    }
 167
 168    fn update_queue(queue: &mut HashMap<PathBuf, IndexOperation>, operation: IndexOperation) {
 169        match operation {
 170            IndexOperation::FlushQueue => {
 171                let queue = std::mem::take(queue);
 172                for (_, op) in queue {
 173                    match op {
 174                        IndexOperation::IndexFile {
 175                            absolute_path: _,
 176                            payload,
 177                            tx,
 178                        } => {
 179                            let _ = tx.try_send(payload);
 180                        }
 181                        IndexOperation::DeleteFile {
 182                            absolute_path: _,
 183                            payload,
 184                            tx,
 185                        } => {
 186                            let _ = tx.try_send(payload);
 187                        }
 188                        _ => {}
 189                    }
 190                }
 191            }
 192            IndexOperation::IndexFile {
 193                ref absolute_path, ..
 194            }
 195            | IndexOperation::DeleteFile {
 196                ref absolute_path, ..
 197            } => {
 198                queue.insert(absolute_path.clone(), operation);
 199            }
 200        }
 201    }
 202
 203    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
 204        self.worktree_db_ids
 205            .iter()
 206            .find_map(|(worktree_id, db_id)| {
 207                if *worktree_id == id {
 208                    Some(*db_id)
 209                } else {
 210                    None
 211                }
 212            })
 213    }
 214
 215    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
 216        self.worktree_db_ids
 217            .iter()
 218            .find_map(|(worktree_id, db_id)| {
 219                if *db_id == id {
 220                    Some(*worktree_id)
 221                } else {
 222                    None
 223                }
 224            })
 225    }
 226}
 227
/// A file queued for parsing and embedding.
#[derive(Clone)]
pub struct PendingFile {
    /// Database id of the worktree containing this file.
    worktree_db_id: i64,
    relative_path: PathBuf,
    absolute_path: PathBuf,
    language: Arc<Language>,
    modified_time: SystemTime,
    /// Keeps the project's outstanding-job counter accurate while this file
    /// is in flight.
    job_handle: JobHandle,
}
/// Operations buffered per path by `ProjectState::update_queue`.
enum IndexOperation {
    /// Parse and (re-)index the file at `absolute_path`.
    IndexFile {
        absolute_path: PathBuf,
        payload: PendingFile,
        tx: channel::Sender<PendingFile>,
    },
    /// Remove the file's rows from the vector database.
    DeleteFile {
        absolute_path: PathBuf,
        payload: DbOperation,
        tx: channel::Sender<DbOperation>,
    },
    /// Drain the queue, forwarding every buffered payload to its channel.
    FlushQueue,
}
 250
/// A single semantic-search hit: a buffer and the matching range within it.
pub struct SearchResult {
    pub buffer: ModelHandle<Buffer>,
    pub range: Range<Anchor>,
}
 255
/// Requests handled serially by the database-writer task
/// (see `SemanticIndex::run_db_operation`).
enum DbOperation {
    /// Persist a file's documents (possibly empty) with its mtime.
    InsertFile {
        worktree_id: i64,
        documents: Vec<Document>,
        path: PathBuf,
        mtime: SystemTime,
        job_handle: JobHandle,
    },
    /// Delete a file's rows.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Resolve (or create) the database id for a worktree root path.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    /// Fetch recorded modification times for every file in a worktree.
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
    /// Ask whether a worktree root has ever been indexed.
    WorktreePreviouslyIndexed {
        path: Arc<Path>,
        sender: oneshot::Sender<Result<bool>>,
    },
}
 281
/// Work items for the batching task (see `enqueue_documents_to_embed`).
enum EmbeddingJob {
    /// Add one parsed file's documents to the current batch.
    Enqueue {
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
        job_handle: JobHandle,
    },
    /// Send whatever has accumulated, even if the batch is not full.
    Flush,
}
 292
 293impl SemanticIndex {
 294    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
 295        if cx.has_global::<ModelHandle<Self>>() {
 296            Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
 297        } else {
 298            None
 299        }
 300    }
 301
 302    pub fn enabled(cx: &AppContext) -> bool {
 303        settings::get::<SemanticIndexSettings>(cx).enabled
 304            && *RELEASE_CHANNEL != ReleaseChannel::Stable
 305    }
 306
    /// Opens (or creates) the vector database at `database_url` and spawns
    /// the background worker pipeline:
    /// parsing tasks -> batching task -> embedding tasks -> db-writer task,
    /// each stage connected by an unbounded channel.
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let t0 = Instant::now();
        let database_url = Arc::new(database_url);

        let db = cx
            .background()
            .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
            .await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        Ok(cx.add_model(|cx| {
            let t0 = Instant::now();
            // Single writer task: serializes all database operations.
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn({
                async move {
                    while let Ok(job) = db_update_rx.recv().await {
                        Self::run_db_operation(&db, job)
                    }
                }
            });

            // One embedding task per CPU: sends document batches to the
            // embedding provider and forwards results to the db writer.
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
            let mut _embed_batch_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let embed_batch_rx = embed_batch_rx.clone();
                _embed_batch_tasks.push(cx.background().spawn({
                    let db_update_tx = db_update_tx.clone();
                    let embedding_provider = embedding_provider.clone();
                    async move {
                        while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
                            Self::compute_embeddings_for_batch(
                                embeddings_queue,
                                &embedding_provider,
                                &db_update_tx,
                            )
                            .await;
                        }
                    }
                }));
            }

            // Group documents into batches and send them to the embedding provider.
            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];
                while let Ok(job) = batch_files_rx.recv().await {
                    Self::enqueue_documents_to_embed(
                        job,
                        &mut queue_len,
                        &mut embeddings_queue,
                        &embed_batch_tx,
                    );
                }
            });

            // Parse files into embeddable documents.
            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                let db_update_tx = db_update_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new();
                    while let Ok(pending_file) = parsing_files_rx.recv().await {
                        Self::parse_file(
                            &fs,
                            pending_file,
                            &mut retriever,
                            &batch_files_tx,
                            &parsing_files_rx,
                            &db_update_tx,
                        )
                        .await;
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                parsing_files_tx,
                _db_update_task,
                _embed_batch_tasks,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }
 419
 420    fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
 421        match job {
 422            DbOperation::InsertFile {
 423                worktree_id,
 424                documents,
 425                path,
 426                mtime,
 427                job_handle,
 428            } => {
 429                db.insert_file(worktree_id, path, mtime, documents)
 430                    .log_err();
 431                drop(job_handle)
 432            }
 433            DbOperation::Delete { worktree_id, path } => {
 434                db.delete_file(worktree_id, path).log_err();
 435            }
 436            DbOperation::FindOrCreateWorktree { path, sender } => {
 437                let id = db.find_or_create_worktree(&path);
 438                sender.send(id).ok();
 439            }
 440            DbOperation::FileMTimes {
 441                worktree_id: worktree_db_id,
 442                sender,
 443            } => {
 444                let file_mtimes = db.get_file_mtimes(worktree_db_id);
 445                sender.send(file_mtimes).ok();
 446            }
 447            DbOperation::WorktreePreviouslyIndexed { path, sender } => {
 448                let worktree_indexed = db.worktree_previously_indexed(path.as_ref());
 449                sender.send(worktree_indexed).ok();
 450            }
 451        }
 452    }
 453
 454    async fn compute_embeddings_for_batch(
 455        mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
 456        embedding_provider: &Arc<dyn EmbeddingProvider>,
 457        db_update_tx: &channel::Sender<DbOperation>,
 458    ) {
 459        let mut batch_documents = vec![];
 460        for (_, documents, _, _, _) in embeddings_queue.iter() {
 461            batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
 462        }
 463
 464        if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
 465            log::trace!(
 466                "created {} embeddings for {} files",
 467                embeddings.len(),
 468                embeddings_queue.len(),
 469            );
 470
 471            let mut i = 0;
 472            let mut j = 0;
 473
 474            for embedding in embeddings.iter() {
 475                while embeddings_queue[i].1.len() == j {
 476                    i += 1;
 477                    j = 0;
 478                }
 479
 480                embeddings_queue[i].1[j].embedding = embedding.to_owned();
 481                j += 1;
 482            }
 483
 484            for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
 485                db_update_tx
 486                    .send(DbOperation::InsertFile {
 487                        worktree_id,
 488                        documents,
 489                        path,
 490                        mtime,
 491                        job_handle,
 492                    })
 493                    .await
 494                    .unwrap();
 495            }
 496        } else {
 497            // Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed).
 498            for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() {
 499                db_update_tx
 500                    .send(DbOperation::InsertFile {
 501                        worktree_id,
 502                        documents: vec![],
 503                        path,
 504                        mtime,
 505                        job_handle,
 506                    })
 507                    .await
 508                    .unwrap();
 509            }
 510        }
 511    }
 512
 513    fn enqueue_documents_to_embed(
 514        job: EmbeddingJob,
 515        queue_len: &mut usize,
 516        embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
 517        embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
 518    ) {
 519        // Handle edge case where individual file has more documents than max batch size
 520        let should_flush = match job {
 521            EmbeddingJob::Enqueue {
 522                documents,
 523                worktree_id,
 524                path,
 525                mtime,
 526                job_handle,
 527            } => {
 528                // If documents is greater than embeddings batch size, recursively batch existing rows.
 529                if &documents.len() > &EMBEDDINGS_BATCH_SIZE {
 530                    let first_job = EmbeddingJob::Enqueue {
 531                        documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(),
 532                        worktree_id,
 533                        path: path.clone(),
 534                        mtime,
 535                        job_handle: job_handle.clone(),
 536                    };
 537
 538                    Self::enqueue_documents_to_embed(
 539                        first_job,
 540                        queue_len,
 541                        embeddings_queue,
 542                        embed_batch_tx,
 543                    );
 544
 545                    let second_job = EmbeddingJob::Enqueue {
 546                        documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(),
 547                        worktree_id,
 548                        path: path.clone(),
 549                        mtime,
 550                        job_handle: job_handle.clone(),
 551                    };
 552
 553                    Self::enqueue_documents_to_embed(
 554                        second_job,
 555                        queue_len,
 556                        embeddings_queue,
 557                        embed_batch_tx,
 558                    );
 559                    return;
 560                } else {
 561                    *queue_len += &documents.len();
 562                    embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
 563                    *queue_len >= EMBEDDINGS_BATCH_SIZE
 564                }
 565            }
 566            EmbeddingJob::Flush => true,
 567        };
 568
 569        if should_flush {
 570            embed_batch_tx
 571                .try_send(mem::take(embeddings_queue))
 572                .unwrap();
 573            *queue_len = 0;
 574        }
 575    }
 576
 577    async fn parse_file(
 578        fs: &Arc<dyn Fs>,
 579        pending_file: PendingFile,
 580        retriever: &mut CodeContextRetriever,
 581        batch_files_tx: &channel::Sender<EmbeddingJob>,
 582        parsing_files_rx: &channel::Receiver<PendingFile>,
 583        db_update_tx: &channel::Sender<DbOperation>,
 584    ) {
 585        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
 586            if let Some(documents) = retriever
 587                .parse_file_with_template(
 588                    &pending_file.relative_path,
 589                    &content,
 590                    pending_file.language,
 591                )
 592                .log_err()
 593            {
 594                log::trace!(
 595                    "parsed path {:?}: {} documents",
 596                    pending_file.relative_path,
 597                    documents.len()
 598                );
 599
 600                if documents.len() == 0 {
 601                    db_update_tx
 602                        .send(DbOperation::InsertFile {
 603                            worktree_id: pending_file.worktree_db_id,
 604                            documents,
 605                            path: pending_file.relative_path,
 606                            mtime: pending_file.modified_time,
 607                            job_handle: pending_file.job_handle,
 608                        })
 609                        .await
 610                        .unwrap();
 611                } else {
 612                    batch_files_tx
 613                        .try_send(EmbeddingJob::Enqueue {
 614                            worktree_id: pending_file.worktree_db_id,
 615                            path: pending_file.relative_path,
 616                            mtime: pending_file.modified_time,
 617                            job_handle: pending_file.job_handle,
 618                            documents,
 619                        })
 620                        .unwrap();
 621                }
 622            }
 623        }
 624
 625        if parsing_files_rx.len() == 0 {
 626            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
 627        }
 628    }
 629
 630    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
 631        let (tx, rx) = oneshot::channel();
 632        self.db_update_tx
 633            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
 634            .unwrap();
 635        async move { rx.await? }
 636    }
 637
 638    fn get_file_mtimes(
 639        &self,
 640        worktree_id: i64,
 641    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
 642        let (tx, rx) = oneshot::channel();
 643        self.db_update_tx
 644            .try_send(DbOperation::FileMTimes {
 645                worktree_id,
 646                sender: tx,
 647            })
 648            .unwrap();
 649        async move { rx.await? }
 650    }
 651
 652    fn worktree_previously_indexed(&self, path: Arc<Path>) -> impl Future<Output = Result<bool>> {
 653        let (tx, rx) = oneshot::channel();
 654        self.db_update_tx
 655            .try_send(DbOperation::WorktreePreviouslyIndexed { path, sender: tx })
 656            .unwrap();
 657        async move { rx.await? }
 658    }
 659
 660    pub fn project_previously_indexed(
 661        &mut self,
 662        project: ModelHandle<Project>,
 663        cx: &mut ModelContext<Self>,
 664    ) -> Task<Result<bool>> {
 665        let worktrees_indexed_previously = project
 666            .read(cx)
 667            .worktrees(cx)
 668            .map(|worktree| self.worktree_previously_indexed(worktree.read(cx).abs_path()))
 669            .collect::<Vec<_>>();
 670        cx.spawn(|_, _cx| async move {
 671            let worktree_indexed_previously =
 672                futures::future::join_all(worktrees_indexed_previously).await;
 673
 674            Ok(worktree_indexed_previously
 675                .iter()
 676                .filter(|worktree| worktree.is_ok())
 677                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
 678        })
 679    }
 680
 681    fn project_entries_changed(
 682        &self,
 683        project: ModelHandle<Project>,
 684        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 685        cx: &mut ModelContext<'_, SemanticIndex>,
 686        worktree_id: &WorktreeId,
 687    ) -> Result<()> {
 688        let parsing_files_tx = self.parsing_files_tx.clone();
 689        let db_update_tx = self.db_update_tx.clone();
 690        let (job_queue_tx, outstanding_job_tx, worktree_db_id) = {
 691            let state = self
 692                .projects
 693                .get(&project.downgrade())
 694                .ok_or(anyhow!("Project not yet initialized"))?;
 695            let worktree_db_id = state
 696                .db_id_for_worktree_id(*worktree_id)
 697                .ok_or(anyhow!("Worktree ID in Database Not Available"))?;
 698            (
 699                state.job_queue_tx.clone(),
 700                state._outstanding_job_count_tx.clone(),
 701                worktree_db_id,
 702            )
 703        };
 704
 705        let language_registry = self.language_registry.clone();
 706        let parsing_files_tx = parsing_files_tx.clone();
 707        let db_update_tx = db_update_tx.clone();
 708
 709        let worktree = project
 710            .read(cx)
 711            .worktree_for_id(worktree_id.clone(), cx)
 712            .ok_or(anyhow!("Worktree not available"))?
 713            .read(cx)
 714            .snapshot();
 715        cx.spawn(|_, _| async move {
 716            let worktree = worktree.clone();
 717            for (path, entry_id, path_change) in changes.iter() {
 718                let relative_path = path.to_path_buf();
 719                let absolute_path = worktree.absolutize(path);
 720
 721                let Some(entry) = worktree.entry_for_id(*entry_id) else {
 722                    continue;
 723                };
 724                if entry.is_ignored || entry.is_symlink || entry.is_external {
 725                    continue;
 726                }
 727
 728                log::trace!("File Event: {:?}, Path: {:?}", &path_change, &path);
 729                match path_change {
 730                    PathChange::AddedOrUpdated | PathChange::Updated | PathChange::Added => {
 731                        if let Ok(language) = language_registry
 732                            .language_for_file(&relative_path, None)
 733                            .await
 734                        {
 735                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
 736                                && &language.name().as_ref() != &"Markdown"
 737                                && language
 738                                    .grammar()
 739                                    .and_then(|grammar| grammar.embedding_config.as_ref())
 740                                    .is_none()
 741                            {
 742                                continue;
 743                            }
 744
 745                            let job_handle = JobHandle::new(&outstanding_job_tx);
 746                            let new_operation = IndexOperation::IndexFile {
 747                                absolute_path: absolute_path.clone(),
 748                                payload: PendingFile {
 749                                    worktree_db_id,
 750                                    relative_path,
 751                                    absolute_path,
 752                                    language,
 753                                    modified_time: entry.mtime,
 754                                    job_handle,
 755                                },
 756                                tx: parsing_files_tx.clone(),
 757                            };
 758                            let _ = job_queue_tx.try_send(new_operation);
 759                        }
 760                    }
 761                    PathChange::Removed => {
 762                        let new_operation = IndexOperation::DeleteFile {
 763                            absolute_path,
 764                            payload: DbOperation::Delete {
 765                                worktree_id: worktree_db_id,
 766                                path: relative_path,
 767                            },
 768                            tx: db_update_tx.clone(),
 769                        };
 770                        let _ = job_queue_tx.try_send(new_operation);
 771                    }
 772                    _ => {}
 773                }
 774            }
 775        })
 776        .detach();
 777
 778        Ok(())
 779    }
 780
    /// Registers `project` with the semantic index.
    ///
    /// Waits for all worktree scans to complete, resolves (or creates) a
    /// database id for each worktree, then walks every file off the main
    /// thread to decide what needs (re)indexing: files whose mtime differs
    /// from the stored value are queued for parsing/embedding, and database
    /// rows for files no longer present are queued for deletion. The
    /// resulting `ProjectState` — including a subscription that keeps the
    /// index in sync with later worktree entry changes — is stored in
    /// `self.projects`, keyed by the project's weak handle.
    pub fn initialize_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        log::trace!("Initializing Project for Semantic Index");
        // One future per worktree that resolves when its initial scan is done.
        // NOTE(review): `as_local().unwrap()` assumes every worktree is local —
        // confirm remote/collab projects can never reach this path.
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();

        // Kick off a database-id lookup (or creation) for each worktree; the
        // results are awaited together after the scans complete.
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();

        // Keep the index up to date after initialization: any entry change in
        // a worktree is forwarded to `project_entries_changed`.
        let _subscription = cx.subscribe(&project, |this, project, event, cx| {
            if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
                let _ =
                    this.project_entries_changed(project.clone(), changes.clone(), cx, worktree_id);
            };
        });

        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let db_update_tx = self.db_update_tx.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Previously stored mtimes per worktree; used below to skip files
            // that have not changed since they were last indexed.
            let mut worktree_file_mtimes = HashMap::new();
            let mut db_ids_by_worktree_id = HashMap::new();

            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                let db_id = db_id?;
                db_ids_by_worktree_id.insert(worktree.id(), db_id);
                worktree_file_mtimes.insert(
                    worktree.id(),
                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
                        .await?,
                );
            }

            // Copy of the worktree-id -> db-id mapping handed to ProjectState.
            let worktree_db_ids = db_ids_by_worktree_id
                .iter()
                .map(|(a, b)| (*a, *b))
                .collect();

            // Shared outstanding-job counter; every queued file holds a
            // JobHandle that decrements it when its work finishes.
            let (job_count_tx, job_count_rx) = watch::channel_with(0);
            let job_count_tx = Arc::new(Mutex::new(job_count_tx));
            let job_count_tx_longlived = job_count_tx.clone();

            // Walk every file on a background thread and build the list of
            // index/delete operations to enqueue.
            let worktree_files = cx
                .background()
                .spawn(async move {
                    let mut worktree_files = Vec::new();
                    for worktree in worktrees.into_iter() {
                        let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
                        let worktree_db_id = db_ids_by_worktree_id[&worktree.id()];
                        for file in worktree.files(false, 0) {
                            let absolute_path = worktree.absolutize(&file.path);

                            if file.is_external || file.is_ignored || file.is_symlink {
                                continue;
                            }

                            if let Ok(language) = language_registry
                                .language_for_file(&absolute_path, None)
                                .await
                            {
                                // Skip files we cannot embed: not a
                                // parse-entire-file type, not Markdown, and the
                                // grammar has no embedding configuration.
                                if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                    && &language.name().as_ref() != &"Markdown"
                                    && language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                {
                                    continue;
                                }

                                // Removing the entry here means whatever is left
                                // in `file_mtimes` afterwards is stale (deleted).
                                let path_buf = file.path.to_path_buf();
                                let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                let already_stored = stored_mtime
                                    .map_or(false, |existing_mtime| existing_mtime == file.mtime);

                                if !already_stored {
                                    let job_handle = JobHandle::new(&job_count_tx);
                                    worktree_files.push(IndexOperation::IndexFile {
                                        absolute_path: absolute_path.clone(),
                                        payload: PendingFile {
                                            worktree_db_id,
                                            relative_path: path_buf,
                                            absolute_path,
                                            language,
                                            job_handle,
                                            modified_time: file.mtime,
                                        },
                                        tx: parsing_files_tx.clone(),
                                    });
                                }
                            }
                        }
                        // Clean up entries from database that are no longer in the worktree.
                        for (path, _) in file_mtimes {
                            worktree_files.push(IndexOperation::DeleteFile {
                                absolute_path: worktree.absolutize(path.as_path()),
                                payload: DbOperation::Delete {
                                    worktree_id: worktree_db_id,
                                    path,
                                },
                                tx: db_update_tx.clone(),
                            });
                        }
                    }

                    anyhow::Ok(worktree_files)
                })
                .await?;

            // Record the project state and enqueue all the collected
            // operations; sends are best-effort (`try_send`).
            this.update(&mut cx, |this, cx| {
                let project_state = ProjectState::new(
                    cx,
                    _subscription,
                    worktree_db_ids,
                    job_count_rx,
                    job_count_tx_longlived,
                );

                for op in worktree_files {
                    let _ = project_state.job_queue_tx.try_send(op);
                }

                this.projects.insert(project.downgrade(), project_state);
            });
            Result::<(), _>::Ok(())
        })
    }
 936
 937    pub fn index_project(
 938        &mut self,
 939        project: ModelHandle<Project>,
 940        cx: &mut ModelContext<Self>,
 941    ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
 942        let state = self.projects.get_mut(&project.downgrade());
 943        let state = if state.is_none() {
 944            return Task::Ready(Some(Err(anyhow!("Project not yet initialized"))));
 945        } else {
 946            state.unwrap()
 947        };
 948
 949        // let parsing_files_tx = self.parsing_files_tx.clone();
 950        // let db_update_tx = self.db_update_tx.clone();
 951        let job_count_rx = state.outstanding_job_count_rx.clone();
 952        let count = state.get_outstanding_count();
 953
 954        cx.spawn(|this, mut cx| async move {
 955            this.update(&mut cx, |this, _| {
 956                let Some(state) = this.projects.get_mut(&project.downgrade()) else {
 957                    return;
 958                };
 959                let _ = state.job_queue_tx.try_send(IndexOperation::FlushQueue);
 960            });
 961
 962            Ok((count, job_count_rx))
 963        })
 964    }
 965
 966    pub fn outstanding_job_count_rx(
 967        &self,
 968        project: &ModelHandle<Project>,
 969    ) -> Option<watch::Receiver<usize>> {
 970        Some(
 971            self.projects
 972                .get(&project.downgrade())?
 973                .outstanding_job_count_rx
 974                .clone(),
 975        )
 976    }
 977
 978    pub fn search_project(
 979        &mut self,
 980        project: ModelHandle<Project>,
 981        phrase: String,
 982        limit: usize,
 983        includes: Vec<PathMatcher>,
 984        excludes: Vec<PathMatcher>,
 985        cx: &mut ModelContext<Self>,
 986    ) -> Task<Result<Vec<SearchResult>>> {
 987        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
 988            state
 989        } else {
 990            return Task::ready(Err(anyhow!("project not added")));
 991        };
 992
 993        let worktree_db_ids = project
 994            .read(cx)
 995            .worktrees(cx)
 996            .filter_map(|worktree| {
 997                let worktree_id = worktree.read(cx).id();
 998                project_state.db_id_for_worktree_id(worktree_id)
 999            })
1000            .collect::<Vec<_>>();
1001
1002        let embedding_provider = self.embedding_provider.clone();
1003        let database_url = self.database_url.clone();
1004        let fs = self.fs.clone();
1005        cx.spawn(|this, mut cx| async move {
1006            let t0 = Instant::now();
1007            let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
1008
1009            let phrase_embedding = embedding_provider
1010                .embed_batch(vec![&phrase])
1011                .await?
1012                .into_iter()
1013                .next()
1014                .unwrap();
1015
1016            log::trace!(
1017                "Embedding search phrase took: {:?} milliseconds",
1018                t0.elapsed().as_millis()
1019            );
1020
1021            let file_ids =
1022                database.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)?;
1023
1024            let batch_n = cx.background().num_cpus();
1025            let ids_len = file_ids.clone().len();
1026            let batch_size = if ids_len <= batch_n {
1027                ids_len
1028            } else {
1029                ids_len / batch_n
1030            };
1031
1032            let mut result_tasks = Vec::new();
1033            for batch in file_ids.chunks(batch_size) {
1034                let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
1035                let limit = limit.clone();
1036                let fs = fs.clone();
1037                let database_url = database_url.clone();
1038                let phrase_embedding = phrase_embedding.clone();
1039                let task = cx.background().spawn(async move {
1040                    let database = VectorDatabase::new(fs, database_url).await.log_err();
1041                    if database.is_none() {
1042                        return Err(anyhow!("failed to acquire database connection"));
1043                    } else {
1044                        database
1045                            .unwrap()
1046                            .top_k_search(&phrase_embedding, limit, batch.as_slice())
1047                    }
1048                });
1049                result_tasks.push(task);
1050            }
1051
1052            let batch_results = futures::future::join_all(result_tasks).await;
1053
1054            let mut results = Vec::new();
1055            for batch_result in batch_results {
1056                if batch_result.is_ok() {
1057                    for (id, similarity) in batch_result.unwrap() {
1058                        let ix = match results.binary_search_by(|(_, s)| {
1059                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
1060                        }) {
1061                            Ok(ix) => ix,
1062                            Err(ix) => ix,
1063                        };
1064                        results.insert(ix, (id, similarity));
1065                        results.truncate(limit);
1066                    }
1067                }
1068            }
1069
1070            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
1071            let documents = database.get_documents_by_ids(ids.as_slice())?;
1072
1073            let mut tasks = Vec::new();
1074            let mut ranges = Vec::new();
1075            let weak_project = project.downgrade();
1076            project.update(&mut cx, |project, cx| {
1077                for (worktree_db_id, file_path, byte_range) in documents {
1078                    let project_state =
1079                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
1080                            state
1081                        } else {
1082                            return Err(anyhow!("project not added"));
1083                        };
1084                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
1085                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
1086                        ranges.push(byte_range);
1087                    }
1088                }
1089
1090                Ok(())
1091            })?;
1092
1093            let buffers = futures::future::join_all(tasks).await;
1094
1095            log::trace!(
1096                "Semantic Searching took: {:?} milliseconds in total",
1097                t0.elapsed().as_millis()
1098            );
1099
1100            Ok(buffers
1101                .into_iter()
1102                .zip(ranges)
1103                .filter_map(|(buffer, range)| {
1104                    let buffer = buffer.log_err()?;
1105                    let range = buffer.read_with(&cx, |buffer, _| {
1106                        buffer.anchor_before(range.start)..buffer.anchor_after(range.end)
1107                    });
1108                    Some(SearchResult { buffer, range })
1109                })
1110                .collect::<Vec<_>>())
1111        })
1112    }
1113}
1114
// SemanticIndex emits no events; Entity is implemented only so it can be
// held in a ModelHandle and driven through gpui's model context.
impl Entity for SemanticIndex {
    type Event = ();
}
1118
impl Drop for JobHandle {
    fn drop(&mut self) {
        // Decrement the outstanding-job counter exactly once per logical
        // job: `Arc::get_mut` succeeds only when this is the last strong
        // reference, i.e. every clone of this JobHandle is already gone.
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not)
            if let Some(tx) = inner.upgrade() {
                // Upgrading the Weak fails if the long-lived sender
                // (presumably owned by the ProjectState — verify) has been
                // dropped; then there is no counter left to update.
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}
1130
#[cfg(test)]
mod tests {

    use super::*;

    /// JobHandle counts each logical job once: cloning a handle must not
    /// bump the counter, and only dropping the final clone decrements it.
    #[test]
    fn test_job_handle() {
        let (count_tx, count_rx) = watch::channel_with(0);
        let shared_tx = Arc::new(Mutex::new(count_tx));

        let first_handle = JobHandle::new(&shared_tx);
        assert_eq!(1, *count_rx.borrow());

        let second_handle = first_handle.clone();
        // Cloning does not start a new job.
        assert_eq!(1, *count_rx.borrow());

        drop(first_handle);
        // One clone remains, so the job is still outstanding.
        assert_eq!(1, *count_rx.borrow());

        drop(second_handle);
        // Last clone gone: the job is complete.
        assert_eq!(0, *count_rx.borrow());
    }
}