semantic_index.rs

   1mod db;
   2mod embedding;
   3mod embedding_queue;
   4mod parsing;
   5pub mod semantic_index_settings;
   6
   7#[cfg(test)]
   8mod semantic_index_tests;
   9
  10use crate::semantic_index_settings::SemanticIndexSettings;
  11use anyhow::{anyhow, Result};
  12use collections::{BTreeMap, HashMap, HashSet};
  13use db::VectorDatabase;
  14use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
  15use embedding_queue::{EmbeddingQueue, FileToEmbed};
  16use futures::{future, FutureExt, StreamExt};
  17use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
  18use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
  19use ordered_float::OrderedFloat;
  20use parking_lot::Mutex;
  21use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
  22use postage::watch;
  23use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
  24use smol::channel;
  25use std::{
  26    cmp::Reverse,
  27    future::Future,
  28    mem,
  29    ops::Range,
  30    path::{Path, PathBuf},
  31    sync::{Arc, Weak},
  32    time::{Duration, Instant, SystemTime},
  33};
  34use util::{
  35    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
  36    http::HttpClient,
  37    paths::EMBEDDINGS_DIR,
  38    ResultExt,
  39};
  40use workspace::WorkspaceCreated;
  41
  42const SEMANTIC_INDEX_VERSION: usize = 11;
  43const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
  44const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
  45
  46pub fn init(
  47    fs: Arc<dyn Fs>,
  48    http_client: Arc<dyn HttpClient>,
  49    language_registry: Arc<LanguageRegistry>,
  50    cx: &mut AppContext,
  51) {
  52    settings::register::<SemanticIndexSettings>(cx);
  53
  54    let db_file_path = EMBEDDINGS_DIR
  55        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
  56        .join("embeddings_db");
  57
  58    // This needs to be removed at some point before stable.
  59    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
  60        return;
  61    }
  62
  63    cx.subscribe_global::<WorkspaceCreated, _>({
  64        move |event, cx| {
  65            let Some(semantic_index) = SemanticIndex::global(cx) else {
  66                return;
  67            };
  68            let workspace = &event.0;
  69            if let Some(workspace) = workspace.upgrade(cx) {
  70                let project = workspace.read(cx).project().clone();
  71                if project.read(cx).is_local() {
  72                    cx.spawn(|mut cx| async move {
  73                        let previously_indexed = semantic_index
  74                            .update(&mut cx, |index, cx| {
  75                                index.project_previously_indexed(&project, cx)
  76                            })
  77                            .await?;
  78                        if previously_indexed {
  79                            semantic_index
  80                                .update(&mut cx, |index, cx| index.index_project(project, cx))
  81                                .await?;
  82                        }
  83                        anyhow::Ok(())
  84                    })
  85                    .detach_and_log_err(cx);
  86                }
  87            }
  88        }
  89    })
  90    .detach();
  91
  92    cx.spawn(move |mut cx| async move {
  93        let semantic_index = SemanticIndex::new(
  94            fs,
  95            db_file_path,
  96            Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
  97            language_registry,
  98            cx.clone(),
  99        )
 100        .await?;
 101
 102        cx.update(|cx| {
 103            cx.set_global(semantic_index.clone());
 104        });
 105
 106        anyhow::Ok(())
 107    })
 108    .detach();
 109}
 110
 111#[derive(Copy, Clone, Debug)]
 112pub enum SemanticIndexStatus {
 113    NotIndexed,
 114    Indexed,
 115    Indexing {
 116        remaining_files: usize,
 117        rate_limit_expiry: Option<Instant>,
 118    },
 119}
 120
 121pub struct SemanticIndex {
 122    fs: Arc<dyn Fs>,
 123    db: VectorDatabase,
 124    embedding_provider: Arc<dyn EmbeddingProvider>,
 125    language_registry: Arc<LanguageRegistry>,
 126    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
 127    _embedding_task: Task<()>,
 128    _parsing_files_tasks: Vec<Task<()>>,
 129    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 130}
 131
 132struct ProjectState {
 133    worktrees: HashMap<WorktreeId, WorktreeState>,
 134    pending_file_count_rx: watch::Receiver<usize>,
 135    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
 136    pending_index: usize,
 137    _subscription: gpui::Subscription,
 138    _observe_pending_file_count: Task<()>,
 139}
 140
 141enum WorktreeState {
 142    Registering(RegisteringWorktreeState),
 143    Registered(RegisteredWorktreeState),
 144}
 145
 146impl WorktreeState {
 147    fn is_registered(&self) -> bool {
 148        matches!(self, Self::Registered(_))
 149    }
 150
 151    fn paths_changed(
 152        &mut self,
 153        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 154        worktree: &Worktree,
 155    ) {
 156        let changed_paths = match self {
 157            Self::Registering(state) => &mut state.changed_paths,
 158            Self::Registered(state) => &mut state.changed_paths,
 159        };
 160
 161        for (path, entry_id, change) in changes.iter() {
 162            let Some(entry) = worktree.entry_for_id(*entry_id) else {
 163                continue;
 164            };
 165            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
 166                continue;
 167            }
 168            changed_paths.insert(
 169                path.clone(),
 170                ChangedPathInfo {
 171                    mtime: entry.mtime,
 172                    is_deleted: *change == PathChange::Removed,
 173                },
 174            );
 175        }
 176    }
 177}
 178
 179struct RegisteringWorktreeState {
 180    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
 181    done_rx: watch::Receiver<Option<()>>,
 182    _registration: Task<()>,
 183}
 184
 185impl RegisteringWorktreeState {
 186    fn done(&self) -> impl Future<Output = ()> {
 187        let mut done_rx = self.done_rx.clone();
 188        async move {
 189            while let Some(result) = done_rx.next().await {
 190                if result.is_some() {
 191                    break;
 192                }
 193            }
 194        }
 195    }
 196}
 197
 198struct RegisteredWorktreeState {
 199    db_id: i64,
 200    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
 201}
 202
 203struct ChangedPathInfo {
 204    mtime: SystemTime,
 205    is_deleted: bool,
 206}
 207
 208#[derive(Clone)]
 209pub struct JobHandle {
 210    /// The outer Arc is here to count the clones of a JobHandle instance;
 211    /// when the last handle to a given job is dropped, we decrement a counter (just once).
 212    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
 213}
 214
 215impl JobHandle {
 216    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
 217        *tx.lock().borrow_mut() += 1;
 218        Self {
 219            tx: Arc::new(Arc::downgrade(&tx)),
 220        }
 221    }
 222}
 223
 224impl ProjectState {
 225    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
 226        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
 227        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
 228        Self {
 229            worktrees: Default::default(),
 230            pending_file_count_rx: pending_file_count_rx.clone(),
 231            pending_file_count_tx,
 232            pending_index: 0,
 233            _subscription: subscription,
 234            _observe_pending_file_count: cx.spawn_weak({
 235                let mut pending_file_count_rx = pending_file_count_rx.clone();
 236                |this, mut cx| async move {
 237                    while let Some(_) = pending_file_count_rx.next().await {
 238                        if let Some(this) = this.upgrade(&cx) {
 239                            this.update(&mut cx, |_, cx| cx.notify());
 240                        }
 241                    }
 242                }
 243            }),
 244        }
 245    }
 246
 247    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
 248        self.worktrees
 249            .iter()
 250            .find_map(|(worktree_id, worktree_state)| match worktree_state {
 251                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
 252                _ => None,
 253            })
 254    }
 255}
 256
 257#[derive(Clone)]
 258pub struct PendingFile {
 259    worktree_db_id: i64,
 260    relative_path: Arc<Path>,
 261    absolute_path: PathBuf,
 262    language: Option<Arc<Language>>,
 263    modified_time: SystemTime,
 264    job_handle: JobHandle,
 265}
 266
 267#[derive(Clone)]
 268pub struct SearchResult {
 269    pub buffer: ModelHandle<Buffer>,
 270    pub range: Range<Anchor>,
 271    pub similarity: OrderedFloat<f32>,
 272}
 273
 274impl SemanticIndex {
 275    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
 276        if cx.has_global::<ModelHandle<Self>>() {
 277            Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
 278        } else {
 279            None
 280        }
 281    }
 282
 283    pub fn enabled(cx: &AppContext) -> bool {
 284        settings::get::<SemanticIndexSettings>(cx).enabled
 285            && *RELEASE_CHANNEL != ReleaseChannel::Stable
 286    }
 287
 288    pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
 289        if let Some(project_state) = self.projects.get(&project.downgrade()) {
 290            if project_state
 291                .worktrees
 292                .values()
 293                .all(|worktree| worktree.is_registered())
 294                && project_state.pending_index == 0
 295            {
 296                SemanticIndexStatus::Indexed
 297            } else {
 298                SemanticIndexStatus::Indexing {
 299                    remaining_files: project_state.pending_file_count_rx.borrow().clone(),
 300                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
 301                }
 302            }
 303        } else {
 304            SemanticIndexStatus::NotIndexed
 305        }
 306    }
 307
 308    async fn new(
 309        fs: Arc<dyn Fs>,
 310        database_path: PathBuf,
 311        embedding_provider: Arc<dyn EmbeddingProvider>,
 312        language_registry: Arc<LanguageRegistry>,
 313        mut cx: AsyncAppContext,
 314    ) -> Result<ModelHandle<Self>> {
 315        let t0 = Instant::now();
 316        let database_path = Arc::from(database_path);
 317        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;
 318
 319        log::trace!(
 320            "db initialization took {:?} milliseconds",
 321            t0.elapsed().as_millis()
 322        );
 323
 324        Ok(cx.add_model(|cx| {
 325            let t0 = Instant::now();
 326            let embedding_queue =
 327                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
 328            let _embedding_task = cx.background().spawn({
 329                let embedded_files = embedding_queue.finished_files();
 330                let db = db.clone();
 331                async move {
 332                    while let Ok(file) = embedded_files.recv().await {
 333                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
 334                            .await
 335                            .log_err();
 336                    }
 337                }
 338            });
 339
 340            // Parse files into embeddable spans.
 341            let (parsing_files_tx, parsing_files_rx) =
 342                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
 343            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
 344            let mut _parsing_files_tasks = Vec::new();
 345            for _ in 0..cx.background().num_cpus() {
 346                let fs = fs.clone();
 347                let mut parsing_files_rx = parsing_files_rx.clone();
 348                let embedding_provider = embedding_provider.clone();
 349                let embedding_queue = embedding_queue.clone();
 350                let background = cx.background().clone();
 351                _parsing_files_tasks.push(cx.background().spawn(async move {
 352                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
 353                    loop {
 354                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
 355                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
 356                        futures::select_biased! {
 357                            next_file_to_parse = next_file_to_parse => {
 358                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
 359                                    Self::parse_file(
 360                                        &fs,
 361                                        pending_file,
 362                                        &mut retriever,
 363                                        &embedding_queue,
 364                                        &embeddings_for_digest,
 365                                    )
 366                                    .await
 367                                } else {
 368                                    break;
 369                                }
 370                            },
 371                            _ = timer => {
 372                                embedding_queue.lock().flush();
 373                            }
 374                        }
 375                    }
 376                }));
 377            }
 378
 379            log::trace!(
 380                "semantic index task initialization took {:?} milliseconds",
 381                t0.elapsed().as_millis()
 382            );
 383            Self {
 384                fs,
 385                db,
 386                embedding_provider,
 387                language_registry,
 388                parsing_files_tx,
 389                _embedding_task,
 390                _parsing_files_tasks,
 391                projects: Default::default(),
 392            }
 393        }))
 394    }
 395
 396    async fn parse_file(
 397        fs: &Arc<dyn Fs>,
 398        pending_file: PendingFile,
 399        retriever: &mut CodeContextRetriever,
 400        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
 401        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
 402    ) {
 403        let Some(language) = pending_file.language else {
 404            return;
 405        };
 406
 407        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
 408            if let Some(mut spans) = retriever
 409                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
 410                .log_err()
 411            {
 412                log::trace!(
 413                    "parsed path {:?}: {} spans",
 414                    pending_file.relative_path,
 415                    spans.len()
 416                );
 417
 418                for span in &mut spans {
 419                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
 420                        span.embedding = Some(embedding.to_owned());
 421                    }
 422                }
 423
 424                embedding_queue.lock().push(FileToEmbed {
 425                    worktree_id: pending_file.worktree_db_id,
 426                    path: pending_file.relative_path,
 427                    mtime: pending_file.modified_time,
 428                    job_handle: pending_file.job_handle,
 429                    spans,
 430                });
 431            }
 432        }
 433    }
 434
 435    pub fn project_previously_indexed(
 436        &mut self,
 437        project: &ModelHandle<Project>,
 438        cx: &mut ModelContext<Self>,
 439    ) -> Task<Result<bool>> {
 440        let worktrees_indexed_previously = project
 441            .read(cx)
 442            .worktrees(cx)
 443            .map(|worktree| {
 444                self.db
 445                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
 446            })
 447            .collect::<Vec<_>>();
 448        cx.spawn(|_, _cx| async move {
 449            let worktree_indexed_previously =
 450                futures::future::join_all(worktrees_indexed_previously).await;
 451
 452            Ok(worktree_indexed_previously
 453                .iter()
 454                .filter(|worktree| worktree.is_ok())
 455                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
 456        })
 457    }
 458
 459    fn project_entries_changed(
 460        &mut self,
 461        project: ModelHandle<Project>,
 462        worktree_id: WorktreeId,
 463        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 464        cx: &mut ModelContext<Self>,
 465    ) {
 466        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
 467            return;
 468        };
 469        let project = project.downgrade();
 470        let Some(project_state) = self.projects.get_mut(&project) else {
 471            return;
 472        };
 473
 474        let worktree = worktree.read(cx);
 475        let worktree_state =
 476            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
 477                worktree_state
 478            } else {
 479                return;
 480            };
 481        worktree_state.paths_changed(changes, worktree);
 482        if let WorktreeState::Registered(_) = worktree_state {
 483            cx.spawn_weak(|this, mut cx| async move {
 484                cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
 485                if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
 486                    this.update(&mut cx, |this, cx| {
 487                        this.index_project(project, cx).detach_and_log_err(cx)
 488                    });
 489                }
 490            })
 491            .detach();
 492        }
 493    }
 494
 495    fn register_worktree(
 496        &mut self,
 497        project: ModelHandle<Project>,
 498        worktree: ModelHandle<Worktree>,
 499        cx: &mut ModelContext<Self>,
 500    ) {
 501        let project = project.downgrade();
 502        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
 503            project_state
 504        } else {
 505            return;
 506        };
 507        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
 508            worktree
 509        } else {
 510            return;
 511        };
 512        let worktree_abs_path = worktree.abs_path().clone();
 513        let scan_complete = worktree.scan_complete();
 514        let worktree_id = worktree.id();
 515        let db = self.db.clone();
 516        let language_registry = self.language_registry.clone();
 517        let (mut done_tx, done_rx) = watch::channel();
 518        let registration = cx.spawn(|this, mut cx| {
 519            async move {
 520                let register = async {
 521                    scan_complete.await;
 522                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
 523                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
 524                    let worktree = if let Some(project) = project.upgrade(&cx) {
 525                        project
 526                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
 527                            .ok_or_else(|| anyhow!("worktree not found"))?
 528                    } else {
 529                        return anyhow::Ok(());
 530                    };
 531                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
 532                    let mut changed_paths = cx
 533                        .background()
 534                        .spawn(async move {
 535                            let mut changed_paths = BTreeMap::new();
 536                            for file in worktree.files(false, 0) {
 537                                let absolute_path = worktree.absolutize(&file.path);
 538
 539                                if file.is_external || file.is_ignored || file.is_symlink {
 540                                    continue;
 541                                }
 542
 543                                if let Ok(language) = language_registry
 544                                    .language_for_file(&absolute_path, None)
 545                                    .await
 546                                {
 547                                    // Test if file is valid parseable file
 548                                    if !PARSEABLE_ENTIRE_FILE_TYPES
 549                                        .contains(&language.name().as_ref())
 550                                        && &language.name().as_ref() != &"Markdown"
 551                                        && language
 552                                            .grammar()
 553                                            .and_then(|grammar| grammar.embedding_config.as_ref())
 554                                            .is_none()
 555                                    {
 556                                        continue;
 557                                    }
 558
 559                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
 560                                    let already_stored = stored_mtime
 561                                        .map_or(false, |existing_mtime| {
 562                                            existing_mtime == file.mtime
 563                                        });
 564
 565                                    if !already_stored {
 566                                        changed_paths.insert(
 567                                            file.path.clone(),
 568                                            ChangedPathInfo {
 569                                                mtime: file.mtime,
 570                                                is_deleted: false,
 571                                            },
 572                                        );
 573                                    }
 574                                }
 575                            }
 576
 577                            // Clean up entries from database that are no longer in the worktree.
 578                            for (path, mtime) in file_mtimes {
 579                                changed_paths.insert(
 580                                    path.into(),
 581                                    ChangedPathInfo {
 582                                        mtime,
 583                                        is_deleted: true,
 584                                    },
 585                                );
 586                            }
 587
 588                            anyhow::Ok(changed_paths)
 589                        })
 590                        .await?;
 591                    this.update(&mut cx, |this, cx| {
 592                        let project_state = this
 593                            .projects
 594                            .get_mut(&project)
 595                            .ok_or_else(|| anyhow!("project not registered"))?;
 596                        let project = project
 597                            .upgrade(cx)
 598                            .ok_or_else(|| anyhow!("project was dropped"))?;
 599
 600                        if let Some(WorktreeState::Registering(state)) =
 601                            project_state.worktrees.remove(&worktree_id)
 602                        {
 603                            changed_paths.extend(state.changed_paths);
 604                        }
 605                        project_state.worktrees.insert(
 606                            worktree_id,
 607                            WorktreeState::Registered(RegisteredWorktreeState {
 608                                db_id,
 609                                changed_paths,
 610                            }),
 611                        );
 612                        this.index_project(project, cx).detach_and_log_err(cx);
 613
 614                        anyhow::Ok(())
 615                    })?;
 616
 617                    anyhow::Ok(())
 618                };
 619
 620                if register.await.log_err().is_none() {
 621                    // Stop tracking this worktree if the registration failed.
 622                    this.update(&mut cx, |this, _| {
 623                        this.projects.get_mut(&project).map(|project_state| {
 624                            project_state.worktrees.remove(&worktree_id);
 625                        });
 626                    })
 627                }
 628
 629                *done_tx.borrow_mut() = Some(());
 630            }
 631        });
 632        project_state.worktrees.insert(
 633            worktree_id,
 634            WorktreeState::Registering(RegisteringWorktreeState {
 635                changed_paths: Default::default(),
 636                done_rx,
 637                _registration: registration,
 638            }),
 639        );
 640    }
 641
 642    fn project_worktrees_changed(
 643        &mut self,
 644        project: ModelHandle<Project>,
 645        cx: &mut ModelContext<Self>,
 646    ) {
 647        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
 648        {
 649            project_state
 650        } else {
 651            return;
 652        };
 653
 654        let mut worktrees = project
 655            .read(cx)
 656            .worktrees(cx)
 657            .filter(|worktree| worktree.read(cx).is_local())
 658            .collect::<Vec<_>>();
 659        let worktree_ids = worktrees
 660            .iter()
 661            .map(|worktree| worktree.read(cx).id())
 662            .collect::<HashSet<_>>();
 663
 664        // Remove worktrees that are no longer present
 665        project_state
 666            .worktrees
 667            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
 668
 669        // Register new worktrees
 670        worktrees.retain(|worktree| {
 671            let worktree_id = worktree.read(cx).id();
 672            !project_state.worktrees.contains_key(&worktree_id)
 673        });
 674        for worktree in worktrees {
 675            self.register_worktree(project.clone(), worktree, cx);
 676        }
 677    }
 678
 679    pub fn pending_file_count(
 680        &self,
 681        project: &ModelHandle<Project>,
 682    ) -> Option<watch::Receiver<usize>> {
 683        Some(
 684            self.projects
 685                .get(&project.downgrade())?
 686                .pending_file_count_rx
 687                .clone(),
 688        )
 689    }
 690
 691    pub fn search_project(
 692        &mut self,
 693        project: ModelHandle<Project>,
 694        query: String,
 695        limit: usize,
 696        includes: Vec<PathMatcher>,
 697        excludes: Vec<PathMatcher>,
 698        cx: &mut ModelContext<Self>,
 699    ) -> Task<Result<Vec<SearchResult>>> {
 700        if query.is_empty() {
 701            return Task::ready(Ok(Vec::new()));
 702        }
 703
 704        let index = self.index_project(project.clone(), cx);
 705        let embedding_provider = self.embedding_provider.clone();
 706
 707        cx.spawn(|this, mut cx| async move {
 708            let query = embedding_provider
 709                .embed_batch(vec![query])
 710                .await?
 711                .pop()
 712                .ok_or_else(|| anyhow!("could not embed query"))?;
 713            index.await?;
 714
 715            let search_start = Instant::now();
 716            let modified_buffer_results = this.update(&mut cx, |this, cx| {
 717                this.search_modified_buffers(&project, query.clone(), limit, &excludes, cx)
 718            });
 719            let file_results = this.update(&mut cx, |this, cx| {
 720                this.search_files(project, query, limit, includes, excludes, cx)
 721            });
 722            let (modified_buffer_results, file_results) =
 723                futures::join!(modified_buffer_results, file_results);
 724
 725            // Weave together the results from modified buffers and files.
 726            let mut results = Vec::new();
 727            let mut modified_buffers = HashSet::default();
 728            for result in modified_buffer_results.log_err().unwrap_or_default() {
 729                modified_buffers.insert(result.buffer.clone());
 730                results.push(result);
 731            }
 732            for result in file_results.log_err().unwrap_or_default() {
 733                if !modified_buffers.contains(&result.buffer) {
 734                    results.push(result);
 735                }
 736            }
 737            results.sort_by_key(|result| Reverse(result.similarity));
 738            results.truncate(limit);
 739            log::trace!("Semantic search took {:?}", search_start.elapsed());
 740            Ok(results)
 741        })
 742    }
 743
 744    pub fn search_files(
 745        &mut self,
 746        project: ModelHandle<Project>,
 747        query: Embedding,
 748        limit: usize,
 749        includes: Vec<PathMatcher>,
 750        excludes: Vec<PathMatcher>,
 751        cx: &mut ModelContext<Self>,
 752    ) -> Task<Result<Vec<SearchResult>>> {
 753        let db_path = self.db.path().clone();
 754        let fs = self.fs.clone();
 755        cx.spawn(|this, mut cx| async move {
 756            let database =
 757                VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
 758
 759            let worktree_db_ids = this.read_with(&cx, |this, _| {
 760                let project_state = this
 761                    .projects
 762                    .get(&project.downgrade())
 763                    .ok_or_else(|| anyhow!("project was not indexed"))?;
 764                let worktree_db_ids = project_state
 765                    .worktrees
 766                    .values()
 767                    .filter_map(|worktree| {
 768                        if let WorktreeState::Registered(worktree) = worktree {
 769                            Some(worktree.db_id)
 770                        } else {
 771                            None
 772                        }
 773                    })
 774                    .collect::<Vec<i64>>();
 775                anyhow::Ok(worktree_db_ids)
 776            })?;
 777
 778            let file_ids = database
 779                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
 780                .await?;
 781
 782            let batch_n = cx.background().num_cpus();
 783            let ids_len = file_ids.clone().len();
 784            let batch_size = if ids_len <= batch_n {
 785                ids_len
 786            } else {
 787                ids_len / batch_n
 788            };
 789
 790            let mut batch_results = Vec::new();
 791            for batch in file_ids.chunks(batch_size) {
 792                let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
 793                let limit = limit.clone();
 794                let fs = fs.clone();
 795                let db_path = db_path.clone();
 796                let query = query.clone();
 797                if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
 798                    .await
 799                    .log_err()
 800                {
 801                    batch_results.push(async move {
 802                        db.top_k_search(&query, limit, batch.as_slice()).await
 803                    });
 804                }
 805            }
 806
 807            let batch_results = futures::future::join_all(batch_results).await;
 808
 809            let mut results = Vec::new();
 810            for batch_result in batch_results {
 811                if batch_result.is_ok() {
 812                    for (id, similarity) in batch_result.unwrap() {
 813                        let ix = match results
 814                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
 815                        {
 816                            Ok(ix) => ix,
 817                            Err(ix) => ix,
 818                        };
 819                        results.insert(ix, (id, similarity));
 820                        results.truncate(limit);
 821                    }
 822                }
 823            }
 824
 825            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
 826            let scores = results
 827                .into_iter()
 828                .map(|(_, score)| score)
 829                .collect::<Vec<_>>();
 830            let spans = database.spans_for_ids(ids.as_slice()).await?;
 831
 832            let mut tasks = Vec::new();
 833            let mut ranges = Vec::new();
 834            let weak_project = project.downgrade();
 835            project.update(&mut cx, |project, cx| {
 836                for (worktree_db_id, file_path, byte_range) in spans {
 837                    let project_state =
 838                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
 839                            state
 840                        } else {
 841                            return Err(anyhow!("project not added"));
 842                        };
 843                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
 844                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
 845                        ranges.push(byte_range);
 846                    }
 847                }
 848
 849                Ok(())
 850            })?;
 851
 852            let buffers = futures::future::join_all(tasks).await;
 853
 854            Ok(buffers
 855                .into_iter()
 856                .zip(ranges)
 857                .zip(scores)
 858                .filter_map(|((buffer, range), similarity)| {
 859                    let buffer = buffer.log_err()?;
 860                    let range = buffer.read_with(&cx, |buffer, _| {
 861                        let start = buffer.clip_offset(range.start, Bias::Left);
 862                        let end = buffer.clip_offset(range.end, Bias::Right);
 863                        buffer.anchor_before(start)..buffer.anchor_after(end)
 864                    });
 865                    Some(SearchResult {
 866                        buffer,
 867                        range,
 868                        similarity,
 869                    })
 870                })
 871                .collect())
 872        })
 873    }
 874
 875    fn search_modified_buffers(
 876        &self,
 877        project: &ModelHandle<Project>,
 878        query: Embedding,
 879        limit: usize,
 880        excludes: &[PathMatcher],
 881        cx: &mut ModelContext<Self>,
 882    ) -> Task<Result<Vec<SearchResult>>> {
 883        let modified_buffers = project
 884            .read(cx)
 885            .opened_buffers(cx)
 886            .into_iter()
 887            .filter_map(|buffer_handle| {
 888                let buffer = buffer_handle.read(cx);
 889                let snapshot = buffer.snapshot();
 890                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 891                    excludes.iter().any(|matcher| matcher.is_match(&path))
 892                });
 893                if buffer.is_dirty() && !excluded {
 894                    Some((buffer_handle, snapshot))
 895                } else {
 896                    None
 897                }
 898            })
 899            .collect::<HashMap<_, _>>();
 900
 901        let embedding_provider = self.embedding_provider.clone();
 902        let fs = self.fs.clone();
 903        let db_path = self.db.path().clone();
 904        let background = cx.background().clone();
 905        cx.background().spawn(async move {
 906            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
 907            let mut results = Vec::<SearchResult>::new();
 908
 909            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
 910            for (buffer, snapshot) in modified_buffers {
 911                let language = snapshot
 912                    .language_at(0)
 913                    .cloned()
 914                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
 915                let mut spans = retriever
 916                    .parse_file_with_template(None, &snapshot.text(), language)
 917                    .log_err()
 918                    .unwrap_or_default();
 919                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
 920                    .await
 921                    .log_err()
 922                    .is_some()
 923                {
 924                    for span in spans {
 925                        let similarity = span.embedding.unwrap().similarity(&query);
 926                        let ix = match results
 927                            .binary_search_by_key(&Reverse(similarity), |result| {
 928                                Reverse(result.similarity)
 929                            }) {
 930                            Ok(ix) => ix,
 931                            Err(ix) => ix,
 932                        };
 933
 934                        let range = {
 935                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
 936                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
 937                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
 938                        };
 939
 940                        results.insert(
 941                            ix,
 942                            SearchResult {
 943                                buffer: buffer.clone(),
 944                                range,
 945                                similarity,
 946                            },
 947                        );
 948                        results.truncate(limit);
 949                    }
 950                }
 951            }
 952
 953            Ok(results)
 954        })
 955    }
 956
 957    pub fn index_project(
 958        &mut self,
 959        project: ModelHandle<Project>,
 960        cx: &mut ModelContext<Self>,
 961    ) -> Task<Result<()>> {
 962        if !self.projects.contains_key(&project.downgrade()) {
 963            log::trace!("Registering Project for Semantic Index");
 964
 965            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
 966                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
 967                    this.project_worktrees_changed(project.clone(), cx);
 968                }
 969                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
 970                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
 971                }
 972                _ => {}
 973            });
 974            let project_state = ProjectState::new(subscription, cx);
 975            self.projects.insert(project.downgrade(), project_state);
 976            self.project_worktrees_changed(project.clone(), cx);
 977        }
 978        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
 979        project_state.pending_index += 1;
 980        cx.notify();
 981
 982        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
 983        let db = self.db.clone();
 984        let language_registry = self.language_registry.clone();
 985        let parsing_files_tx = self.parsing_files_tx.clone();
 986        let worktree_registration = self.wait_for_worktree_registration(&project, cx);
 987
 988        cx.spawn(|this, mut cx| async move {
 989            worktree_registration.await?;
 990
 991            let mut pending_files = Vec::new();
 992            let mut files_to_delete = Vec::new();
 993            this.update(&mut cx, |this, cx| {
 994                let project_state = this
 995                    .projects
 996                    .get_mut(&project.downgrade())
 997                    .ok_or_else(|| anyhow!("project was dropped"))?;
 998                let pending_file_count_tx = &project_state.pending_file_count_tx;
 999
1000                project_state
1001                    .worktrees
1002                    .retain(|worktree_id, worktree_state| {
1003                        let worktree = if let Some(worktree) =
1004                            project.read(cx).worktree_for_id(*worktree_id, cx)
1005                        {
1006                            worktree
1007                        } else {
1008                            return false;
1009                        };
1010                        let worktree_state =
1011                            if let WorktreeState::Registered(worktree_state) = worktree_state {
1012                                worktree_state
1013                            } else {
1014                                return true;
1015                            };
1016
1017                        worktree_state.changed_paths.retain(|path, info| {
1018                            if info.is_deleted {
1019                                files_to_delete.push((worktree_state.db_id, path.clone()));
1020                            } else {
1021                                let absolute_path = worktree.read(cx).absolutize(path);
1022                                let job_handle = JobHandle::new(pending_file_count_tx);
1023                                pending_files.push(PendingFile {
1024                                    absolute_path,
1025                                    relative_path: path.clone(),
1026                                    language: None,
1027                                    job_handle,
1028                                    modified_time: info.mtime,
1029                                    worktree_db_id: worktree_state.db_id,
1030                                });
1031                            }
1032
1033                            false
1034                        });
1035                        true
1036                    });
1037
1038                anyhow::Ok(())
1039            })?;
1040
1041            cx.background()
1042                .spawn(async move {
1043                    for (worktree_db_id, path) in files_to_delete {
1044                        db.delete_file(worktree_db_id, path).await.log_err();
1045                    }
1046
1047                    let embeddings_for_digest = {
1048                        let mut files = HashMap::default();
1049                        for pending_file in &pending_files {
1050                            files
1051                                .entry(pending_file.worktree_db_id)
1052                                .or_insert(Vec::new())
1053                                .push(pending_file.relative_path.clone());
1054                        }
1055                        Arc::new(
1056                            db.embeddings_for_files(files)
1057                                .await
1058                                .log_err()
1059                                .unwrap_or_default(),
1060                        )
1061                    };
1062
1063                    for mut pending_file in pending_files {
1064                        if let Ok(language) = language_registry
1065                            .language_for_file(&pending_file.relative_path, None)
1066                            .await
1067                        {
1068                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
1069                                && &language.name().as_ref() != &"Markdown"
1070                                && language
1071                                    .grammar()
1072                                    .and_then(|grammar| grammar.embedding_config.as_ref())
1073                                    .is_none()
1074                            {
1075                                continue;
1076                            }
1077                            pending_file.language = Some(language);
1078                        }
1079                        parsing_files_tx
1080                            .try_send((embeddings_for_digest.clone(), pending_file))
1081                            .ok();
1082                    }
1083
1084                    // Wait until we're done indexing.
1085                    while let Some(count) = pending_file_count_rx.next().await {
1086                        if count == 0 {
1087                            break;
1088                        }
1089                    }
1090                })
1091                .await;
1092
1093            this.update(&mut cx, |this, cx| {
1094                let project_state = this
1095                    .projects
1096                    .get_mut(&project.downgrade())
1097                    .ok_or_else(|| anyhow!("project was dropped"))?;
1098                project_state.pending_index -= 1;
1099                cx.notify();
1100                anyhow::Ok(())
1101            })?;
1102
1103            Ok(())
1104        })
1105    }
1106
1107    fn wait_for_worktree_registration(
1108        &self,
1109        project: &ModelHandle<Project>,
1110        cx: &mut ModelContext<Self>,
1111    ) -> Task<Result<()>> {
1112        let project = project.downgrade();
1113        cx.spawn_weak(|this, cx| async move {
1114            loop {
1115                let mut pending_worktrees = Vec::new();
1116                this.upgrade(&cx)
1117                    .ok_or_else(|| anyhow!("semantic index dropped"))?
1118                    .read_with(&cx, |this, _| {
1119                        if let Some(project) = this.projects.get(&project) {
1120                            for worktree in project.worktrees.values() {
1121                                if let WorktreeState::Registering(worktree) = worktree {
1122                                    pending_worktrees.push(worktree.done());
1123                                }
1124                            }
1125                        }
1126                    });
1127
1128                if pending_worktrees.is_empty() {
1129                    break;
1130                } else {
1131                    future::join_all(pending_worktrees).await;
1132                }
1133            }
1134            Ok(())
1135        })
1136    }
1137
1138    async fn embed_spans(
1139        spans: &mut [Span],
1140        embedding_provider: &dyn EmbeddingProvider,
1141        db: &VectorDatabase,
1142    ) -> Result<()> {
1143        let mut batch = Vec::new();
1144        let mut batch_tokens = 0;
1145        let mut embeddings = Vec::new();
1146
1147        let digests = spans
1148            .iter()
1149            .map(|span| span.digest.clone())
1150            .collect::<Vec<_>>();
1151        let embeddings_for_digests = db
1152            .embeddings_for_digests(digests)
1153            .await
1154            .log_err()
1155            .unwrap_or_default();
1156
1157        for span in &*spans {
1158            if embeddings_for_digests.contains_key(&span.digest) {
1159                continue;
1160            };
1161
1162            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
1163                let batch_embeddings = embedding_provider
1164                    .embed_batch(mem::take(&mut batch))
1165                    .await?;
1166                embeddings.extend(batch_embeddings);
1167                batch_tokens = 0;
1168            }
1169
1170            batch_tokens += span.token_count;
1171            batch.push(span.content.clone());
1172        }
1173
1174        if !batch.is_empty() {
1175            let batch_embeddings = embedding_provider
1176                .embed_batch(mem::take(&mut batch))
1177                .await?;
1178
1179            embeddings.extend(batch_embeddings);
1180        }
1181
1182        let mut embeddings = embeddings.into_iter();
1183        for span in spans {
1184            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
1185                Some(embedding.clone())
1186            } else {
1187                embeddings.next()
1188            };
1189            let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
1190            span.embedding = Some(embedding);
1191        }
1192        Ok(())
1193    }
1194}
1195
// Implements gpui's `Entity` trait for `SemanticIndex`, with `()` as its event type.
impl Entity for SemanticIndex {
    type Event = ();
}
1199
1200impl Drop for JobHandle {
1201    fn drop(&mut self) {
1202        if let Some(inner) = Arc::get_mut(&mut self.tx) {
1203            // This is the last instance of the JobHandle (regardless of it's origin - whether it was cloned or not)
1204            if let Some(tx) = inner.upgrade() {
1205                let mut tx = tx.lock();
1206                *tx.borrow_mut() -= 1;
1207            }
1208        }
1209    }
1210}
1211
#[cfg(test)]
mod tests {
    use super::*;

    /// A job is counted once per logical handle: clones share the count, and it is
    /// only released when the last clone is dropped.
    #[test]
    fn test_job_handle() {
        let (job_count_tx, job_count_rx) = watch::channel_with(0);
        let shared_tx = Arc::new(Mutex::new(job_count_tx));

        // Creating the handle accounts for one job.
        let original = JobHandle::new(&shared_tx);
        assert_eq!(*job_count_rx.borrow(), 1);

        // Cloning does not add another job.
        let duplicate = original.clone();
        assert_eq!(*job_count_rx.borrow(), 1);

        // Dropping one of the two clones leaves the job outstanding.
        drop(original);
        assert_eq!(*job_count_rx.borrow(), 1);

        // Dropping the last clone releases it.
        drop(duplicate);
        assert_eq!(*job_count_rx.borrow(), 0);
    }
}