semantic_index.rs

mod db;
pub mod embedding;
mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;

#[cfg(test)]
mod semantic_index_tests;

use crate::semantic_index_settings::SemanticIndexSettings;
use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
use smol::channel;
use std::{
    cmp::Reverse,
    future::Future,
    mem,
    ops::Range,
    path::{Path, PathBuf},
    sync::{Arc, Weak},
    time::{Duration, Instant, SystemTime},
};
use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
use workspace::WorkspaceCreated;

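// SEMANTIC_INDEX_VERSION is bumped when the index format or parsing logic changes
// (presumably so the database layer can discard and rebuild stale indices).
// BACKGROUND_INDEXING_DELAY is how long we wait after worktree changes before
// re-indexing a project in the background, and EMBEDDING_QUEUE_FLUSH_TIMEOUT is
// how long a parsing task may sit idle before flushing the shared embedding queue.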
const SEMANTIC_INDEX_VERSION: usize = 11;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);

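/// Registers the semantic index settings, constructs the global `SemanticIndex`
/// backed by a vector database under `EMBEDDINGS_DIR` for the current release
/// channel, and re-indexes local projects in newly created workspaces if they
/// were indexed before.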
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    settings::register::<SemanticIndexSettings>(cx);

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.subscribe_global::<WorkspaceCreated, _>({
        move |event, cx| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    cx.spawn(|mut cx| async move {
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                }
            }
        }
    })
    .detach();

    cx.spawn(move |mut cx| async move {
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            cx.set_global(semantic_index.clone());
        });

        anyhow::Ok(())
    })
    .detach();
}

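/// The indexing status of a project: never indexed, fully indexed, or still
/// indexing (with the number of files remaining and, if the embedding provider
/// is currently rate limited, when that limit is expected to expire).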
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    NotIndexed,
    Indexed,
    Indexing {
        remaining_files: usize,
        rate_limit_expiry: Option<Instant>,
    },
}

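/// Global model that maintains a vector index of project files. Files are
/// parsed into spans on background tasks, embedded via the configured
/// `EmbeddingProvider`, and persisted in the `VectorDatabase`.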
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    _embedding_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}

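/// Per-project indexing state: one entry per worktree, plus a shared counter of
/// files that are still waiting to be parsed and embedded.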
struct ProjectState {
    worktrees: HashMap<WorktreeId, WorktreeState>,
    pending_file_count_rx: watch::Receiver<usize>,
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    pending_index: usize,
    _subscription: gpui::Subscription,
    _observe_pending_file_count: Task<()>,
}

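/// A worktree starts out `Registering` (being scanned and diffed against the
/// database) and becomes `Registered` once it has a database id.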
enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}

impl WorktreeState {
    fn is_registered(&self) -> bool {
        matches!(self, Self::Registered(_))
    }

    fn paths_changed(
        &mut self,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        worktree: &Worktree,
    ) {
        let changed_paths = match self {
            Self::Registering(state) => &mut state.changed_paths,
            Self::Registered(state) => &mut state.changed_paths,
        };

        for (path, entry_id, change) in changes.iter() {
            let Some(entry) = worktree.entry_for_id(*entry_id) else {
                continue;
            };
            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
                continue;
            }
            changed_paths.insert(
                path.clone(),
                ChangedPathInfo {
                    mtime: entry.mtime,
                    is_deleted: *change == PathChange::Removed,
                },
            );
        }
    }
}

struct RegisteringWorktreeState {
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    done_rx: watch::Receiver<Option<()>>,
    _registration: Task<()>,
}

impl RegisteringWorktreeState {
    fn done(&self) -> impl Future<Output = ()> {
        let mut done_rx = self.done_rx.clone();
        async move {
            while let Some(result) = done_rx.next().await {
                if result.is_some() {
                    break;
                }
            }
        }
    }
}

struct RegisteredWorktreeState {
    db_id: i64,
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}

struct ChangedPathInfo {
    mtime: SystemTime,
    is_deleted: bool,
}

#[derive(Clone)]
pub struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}

impl JobHandle {
    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
        *tx.lock().borrow_mut() += 1;
        Self {
            tx: Arc::new(Arc::downgrade(&tx)),
        }
    }
}

impl ProjectState {
    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
        Self {
            worktrees: Default::default(),
            pending_file_count_rx: pending_file_count_rx.clone(),
            pending_file_count_tx,
            pending_index: 0,
            _subscription: subscription,
            _observe_pending_file_count: cx.spawn_weak({
                let mut pending_file_count_rx = pending_file_count_rx.clone();
                |this, mut cx| async move {
                    while let Some(_) = pending_file_count_rx.next().await {
                        if let Some(this) = this.upgrade(&cx) {
                            this.update(&mut cx, |_, cx| cx.notify());
                        }
                    }
                }
            }),
        }
    }

    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktrees
            .iter()
            .find_map(|(worktree_id, worktree_state)| match worktree_state {
                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
                _ => None,
            })
    }
}

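/// A file that is waiting to be parsed and embedded, along with everything the
/// background parsing tasks need in order to process it.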
#[derive(Clone)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    job_handle: JobHandle,
}

#[derive(Clone)]
pub struct SearchResult {
    pub buffer: ModelHandle<Buffer>,
    pub range: Range<Anchor>,
    pub similarity: OrderedFloat<f32>,
}

impl SemanticIndex {
    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
        if cx.has_global::<ModelHandle<Self>>() {
            Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
        } else {
            None
        }
    }

    pub fn enabled(cx: &AppContext) -> bool {
        settings::get::<SemanticIndexSettings>(cx).enabled
    }

    pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
        if let Some(project_state) = self.projects.get(&project.downgrade()) {
            if project_state
                .worktrees
                .values()
                .all(|worktree| worktree.is_registered())
                && project_state.pending_index == 0
            {
                SemanticIndexStatus::Indexed
            } else {
                SemanticIndexStatus::Indexing {
                    remaining_files: *project_state.pending_file_count_rx.borrow(),
                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
                }
            }
        } else {
            SemanticIndexStatus::NotIndexed
        }
    }

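    /// Opens (or creates) the vector database at `database_path` and spawns the
    /// long-lived background tasks: one that writes embedded files to the
    /// database as they finish, and one parsing task per CPU that turns pending
    /// files into embeddable spans.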
    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        Ok(cx.add_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
            let _embedding_task = cx.background().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background().clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        }))
    }

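    /// Loads a pending file from disk, parses it into spans, reuses any
    /// embeddings we already have for identical span digests, and pushes the
    /// result onto the embedding queue.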
    async fn parse_file(
        fs: &Arc<dyn Fs>,
        pending_file: PendingFile,
        retriever: &mut CodeContextRetriever,
        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
    ) {
        let Some(language) = pending_file.language else {
            return;
        };

        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
            if let Some(mut spans) = retriever
                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
                .log_err()
            {
                log::trace!(
                    "parsed path {:?}: {} spans",
                    pending_file.relative_path,
                    spans.len()
                );

                for span in &mut spans {
                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
                        span.embedding = Some(embedding.to_owned());
                    }
                }

                embedding_queue.lock().push(FileToEmbed {
                    worktree_id: pending_file.worktree_db_id,
                    path: pending_file.relative_path,
                    mtime: pending_file.modified_time,
                    job_handle: pending_file.job_handle,
                    spans,
                });
            }
        }
    }
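
    /// Returns whether every worktree of `project` has been indexed at some
    /// point in the past, according to the database.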
    pub fn project_previously_indexed(
        &mut self,
        project: &ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<bool>> {
        let worktrees_indexed_previously = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.db
                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
            })
            .collect::<Vec<_>>();
        cx.spawn(|_, _cx| async move {
            let worktree_indexed_previously =
                futures::future::join_all(worktrees_indexed_previously).await;

            Ok(worktree_indexed_previously
                .iter()
                .filter(|worktree| worktree.is_ok())
                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
        })
    }

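    /// Records changed paths on the corresponding worktree's state and, if that
    /// worktree is already registered, schedules a re-index of the project after
    /// `BACKGROUND_INDEXING_DELAY`.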
    fn project_entries_changed(
        &mut self,
        project: ModelHandle<Project>,
        worktree_id: WorktreeId,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        cx: &mut ModelContext<Self>,
    ) {
        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
            return;
        };
        let project = project.downgrade();
        let Some(project_state) = self.projects.get_mut(&project) else {
            return;
        };

        let worktree = worktree.read(cx);
        let worktree_state =
            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
                worktree_state
            } else {
                return;
            };
        worktree_state.paths_changed(changes, worktree);
        if let WorktreeState::Registered(_) = worktree_state {
            cx.spawn_weak(|this, mut cx| async move {
                cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
                if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
                    this.update(&mut cx, |this, cx| {
                        this.index_project(project, cx).detach_and_log_err(cx)
                    });
                }
            })
            .detach();
        }
    }

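    /// Starts tracking a local worktree: waits for its scan to complete, finds
    /// or creates its database id, diffs the worktree against the stored file
    /// mtimes to collect changed and deleted paths, and then kicks off indexing
    /// for the project.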
    fn register_worktree(
        &mut self,
        project: ModelHandle<Project>,
        worktree: ModelHandle<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade(&cx) {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok_or_else(|| anyhow!("worktree not found"))?
                    } else {
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
                    let mut changed_paths = cx
                        .background()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Skip files whose language we can't parse into embeddable spans.
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && &language.name().as_ref() != &"Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from the database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .ok_or_else(|| anyhow!("project not registered"))?;
                        let project = project
                            .upgrade(cx)
                            .ok_or_else(|| anyhow!("project was dropped"))?;

                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })?;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        this.projects.get_mut(&project).map(|project_state| {
                            project_state.worktrees.remove(&worktree_id);
                        });
                    })
                }

                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }

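    /// Reconciles the tracked worktrees with the project's current set: drops
    /// state for worktrees that were removed and registers any new local ones.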
    fn project_worktrees_changed(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) {
        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
        {
            project_state
        } else {
            return;
        };

        let mut worktrees = project
            .read(cx)
            .worktrees(cx)
            .filter(|worktree| worktree.read(cx).is_local())
            .collect::<Vec<_>>();
        let worktree_ids = worktrees
            .iter()
            .map(|worktree| worktree.read(cx).id())
            .collect::<HashSet<_>>();

        // Remove worktrees that are no longer present
        project_state
            .worktrees
            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));

        // Register new worktrees
        worktrees.retain(|worktree| {
            let worktree_id = worktree.read(cx).id();
            !project_state.worktrees.contains_key(&worktree_id)
        });
        for worktree in worktrees {
            self.register_worktree(project.clone(), worktree, cx);
        }
    }

    pub fn pending_file_count(
        &self,
        project: &ModelHandle<Project>,
    ) -> Option<watch::Receiver<usize>> {
        Some(
            self.projects
                .get(&project.downgrade())?
                .pending_file_count_rx
                .clone(),
        )
    }

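    /// Embeds `query` and searches both the project's indexed files and any
    /// dirty open buffers, returning up to `limit` results ordered by similarity.
    ///
    /// A minimal caller-side sketch (the handles and query string are hypothetical):
    ///
    /// ```ignore
    /// let search = semantic_index.update(cx, |index, cx| {
    ///     index.search_project(
    ///         project.clone(),
    ///         "where is the embedding queue flushed?".into(),
    ///         10,
    ///         Vec::new(), // includes: empty means "search everything"
    ///         Vec::new(), // excludes
    ///         cx,
    ///     )
    /// });
    /// let results: Vec<SearchResult> = search.await?;
    /// ```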
    pub fn search_project(
        &mut self,
        project: ModelHandle<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .ok_or_else(|| anyhow!("could not embed query"))?;
            index.await?;

            let search_start = Instant::now();
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            });
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            });
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("Semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }

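    /// Searches the on-disk index: collects the candidate file ids for the
    /// project's registered worktrees, runs a batched top-k search over them in
    /// parallel, and opens the matching buffers to anchor the resulting ranges.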
    pub fn search_files(
        &mut self,
        project: ModelHandle<Project>,
        query: Embedding,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let db_path = self.db.path().clone();
        let fs = self.fs.clone();
        cx.spawn(|this, mut cx| async move {
            let database =
                VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;

            let worktree_db_ids = this.read_with(&cx, |this, _| {
                let project_state = this
                    .projects
                    .get(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was not indexed"))?;
                let worktree_db_ids = project_state
                    .worktrees
                    .values()
                    .filter_map(|worktree| {
                        if let WorktreeState::Registered(worktree) = worktree {
                            Some(worktree.db_id)
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<i64>>();
                anyhow::Ok(worktree_db_ids)
            })?;

            let file_ids = database
                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
                .await?;

            let batch_n = cx.background().num_cpus();
            let ids_len = file_ids.len();
            let batch_size = if ids_len <= batch_n {
                // Avoid a zero chunk size (which would panic) when there are no candidate files.
                ids_len.max(1)
            } else {
                ids_len / batch_n
            };

            let mut batch_results = Vec::new();
            for batch in file_ids.chunks(batch_size) {
                let batch = batch.to_vec();
                let fs = fs.clone();
                let db_path = db_path.clone();
                let query = query.clone();
                if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
                    .await
                    .log_err()
                {
                    batch_results.push(async move {
                        db.top_k_search(&query, limit, batch.as_slice()).await
                    });
                }
            }

            let batch_results = futures::future::join_all(batch_results).await;

            let mut results = Vec::new();
            for batch_result in batch_results {
                if let Ok(batch_result) = batch_result {
                    for (id, similarity) in batch_result {
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
                        {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    }
                }
            }

            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
            let scores = results
                .into_iter()
                .map(|(_, score)| score)
                .collect::<Vec<_>>();
            let spans = database.spans_for_ids(ids.as_slice()).await?;

            let mut tasks = Vec::new();
            let mut ranges = Vec::new();
            let weak_project = project.downgrade();
            project.update(&mut cx, |project, cx| {
                for (worktree_db_id, file_path, byte_range) in spans {
                    let project_state =
                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
                            state
                        } else {
                            return Err(anyhow!("project not added"));
                        };
                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
                        ranges.push(byte_range);
                    }
                }

                Ok(())
            })?;

            let buffers = futures::future::join_all(tasks).await;

            Ok(buffers
                .into_iter()
                .zip(ranges)
                .zip(scores)
                .filter_map(|((buffer, range), similarity)| {
                    let buffer = buffer.log_err()?;
                    let range = buffer.read_with(&cx, |buffer, _| {
                        let start = buffer.clip_offset(range.start, Bias::Left);
                        let end = buffer.clip_offset(range.end, Bias::Right);
                        buffer.anchor_before(start)..buffer.anchor_after(end)
                    });
                    Some(SearchResult {
                        buffer,
                        range,
                        similarity,
                    })
                })
                .collect())
        })
    }

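    /// Searches dirty open buffers, which may be out of sync with the on-disk
    /// index: each matching buffer is re-parsed and embedded on the fly before
    /// its spans are scored against the query.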
    fn search_modified_buffers(
        &self,
        project: &ModelHandle<Project>,
        query: Embedding,
        limit: usize,
        includes: &[PathMatcher],
        excludes: &[PathMatcher],
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let modified_buffers = project
            .read(cx)
            .opened_buffers(cx)
            .into_iter()
            .filter_map(|buffer_handle| {
                let buffer = buffer_handle.read(cx);
                let snapshot = buffer.snapshot();
                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                    excludes.iter().any(|matcher| matcher.is_match(&path))
                });

                let included = if includes.is_empty() {
                    true
                } else {
                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                        includes.iter().any(|matcher| matcher.is_match(&path))
                    })
                };

                if buffer.is_dirty() && !excluded && included {
                    Some((buffer_handle, snapshot))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        let embedding_provider = self.embedding_provider.clone();
        let fs = self.fs.clone();
        let db_path = self.db.path().clone();
        let background = cx.background().clone();
        cx.background().spawn(async move {
            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
            let mut results = Vec::<SearchResult>::new();

            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
            for (buffer, snapshot) in modified_buffers {
                let language = snapshot
                    .language_at(0)
                    .cloned()
                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
                let mut spans = retriever
                    .parse_file_with_template(None, &snapshot.text(), language)
                    .log_err()
                    .unwrap_or_default();
                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
                    .await
                    .log_err()
                    .is_some()
                {
                    for span in spans {
                        let similarity = span.embedding.unwrap().similarity(&query);
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |result| {
                                Reverse(result.similarity)
                            }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        let range = {
                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
                        };

                        results.insert(
                            ix,
                            SearchResult {
                                buffer: buffer.clone(),
                                range,
                                similarity,
                            },
                        );
                        results.truncate(limit);
                    }
                }
            }

            Ok(results)
        })
    }

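    /// Ensures the project is tracked, waits for all of its worktrees to finish
    /// registering, deletes removed files from the database, queues every
    /// changed file for parsing and embedding, and resolves once the
    /// pending-file count drops back to zero.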
    pub fn index_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            worktree_registration.await?;

            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was dropped"))?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        let worktree = if let Some(worktree) =
                            project.read(cx).worktree_for_id(*worktree_id, cx)
                        {
                            worktree
                        } else {
                            return false;
                        };
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        worktree_state.changed_paths.retain(|path, info| {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else {
                                let absolute_path = worktree.read(cx).absolutize(path);
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }

                            false
                        });
                        true
                    });

                anyhow::Ok(())
            })?;

            cx.background()
                .spawn(async move {
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file(&pending_file.relative_path, None)
                            .await
                        {
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && &language.name().as_ref() != &"Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was dropped"))?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })?;

            Ok(())
        })
    }

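    /// Resolves once none of the project's worktrees are still in the
    /// `Registering` state.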
    fn wait_for_worktree_registration(
        &self,
        project: &ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let project = project.downgrade();
        cx.spawn_weak(|this, cx| async move {
            loop {
                let mut pending_worktrees = Vec::new();
                this.upgrade(&cx)
                    .ok_or_else(|| anyhow!("semantic index dropped"))?
                    .read_with(&cx, |this, _| {
                        if let Some(project) = this.projects.get(&project) {
                            for worktree in project.worktrees.values() {
                                if let WorktreeState::Registering(worktree) = worktree {
                                    pending_worktrees.push(worktree.done());
                                }
                            }
                        }
                    });

                if pending_worktrees.is_empty() {
                    break;
                } else {
                    future::join_all(pending_worktrees).await;
                }
            }
            Ok(())
        })
    }

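    /// Embeds the given spans in place: spans whose digests already have stored
    /// embeddings are skipped, and the rest are sent to the provider in batches
    /// that respect its per-batch token limit.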
    async fn embed_spans(
        spans: &mut [Span],
        embedding_provider: &dyn EmbeddingProvider,
        db: &VectorDatabase,
    ) -> Result<()> {
        let mut batch = Vec::new();
        let mut batch_tokens = 0;
        let mut embeddings = Vec::new();

        let digests = spans
            .iter()
            .map(|span| span.digest.clone())
            .collect::<Vec<_>>();
        let embeddings_for_digests = db
            .embeddings_for_digests(digests)
            .await
            .log_err()
            .unwrap_or_default();

        for span in &*spans {
            if embeddings_for_digests.contains_key(&span.digest) {
                continue;
            };

            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                let batch_embeddings = embedding_provider
                    .embed_batch(mem::take(&mut batch))
                    .await?;
                embeddings.extend(batch_embeddings);
                batch_tokens = 0;
            }

            batch_tokens += span.token_count;
            batch.push(span.content.clone());
        }

        if !batch.is_empty() {
            let batch_embeddings = embedding_provider
                .embed_batch(mem::take(&mut batch))
                .await?;

            embeddings.extend(batch_embeddings);
        }

        let mut embeddings = embeddings.into_iter();
        for span in spans {
            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
                Some(embedding.clone())
            } else {
                embeddings.next()
            };
            let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
            span.embedding = Some(embedding);
        }
        Ok(())
    }
}

impl Entity for SemanticIndex {
    type Event = ();
}

impl Drop for JobHandle {
    fn drop(&mut self) {
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not).
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_job_handle() {
        let (job_count_tx, job_count_rx) = watch::channel_with(0);
        let tx = Arc::new(Mutex::new(job_count_tx));
        let job_handle = JobHandle::new(&tx);

        assert_eq!(1, *job_count_rx.borrow());
        let new_job_handle = job_handle.clone();
        assert_eq!(1, *job_count_rx.borrow());
        drop(job_handle);
        assert_eq!(1, *job_count_rx.borrow());
        drop(new_job_handle);
        assert_eq!(0, *job_count_rx.borrow());
    }
}