semantic_index.rs

mod db;
mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;

#[cfg(test)]
mod semantic_index_tests;

use crate::semantic_index_settings::SemanticIndexSettings;
use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
use smol::channel;
use std::{
    cmp::Reverse,
    env,
    future::Future,
    mem,
    ops::Range,
    path::{Path, PathBuf},
    sync::{Arc, Weak},
    time::{Duration, Instant, SystemTime},
};
use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
use workspace::WorkspaceCreated;

const SEMANTIC_INDEX_VERSION: usize = 11;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}

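/// Registers the semantic index settings, constructs the global [`SemanticIndex`]
/// in the background, and re-indexes previously indexed local projects whenever a
/// workspace is created.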
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    settings::register::<SemanticIndexSettings>(cx);

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.subscribe_global::<WorkspaceCreated, _>({
        move |event, cx| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    cx.spawn(|mut cx| async move {
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                }
            }
        }
    })
    .detach();

    cx.spawn(move |mut cx| async move {
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            cx.set_global(semantic_index.clone());
        });

        anyhow::Ok(())
    })
    .detach();
}

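/// The indexing state of a project, as reported by [`SemanticIndex::status`].
///
/// A minimal sketch of how a caller might consume the status (illustrative only;
/// the surrounding handles and UI reactions are assumed, not defined here):
///
/// ```ignore
/// match semantic_index.read(cx).status(&project) {
///     SemanticIndexStatus::NotAuthenticated => { /* prompt for an API key */ }
///     SemanticIndexStatus::NotIndexed => { /* offer to index the project */ }
///     SemanticIndexStatus::Indexing { remaining_files, rate_limit_expiry } => {
///         /* show progress, optionally surfacing the rate limit expiry */
///     }
///     SemanticIndexStatus::Indexed => { /* ready to search */ }
/// }
/// ```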
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    NotAuthenticated,
    NotIndexed,
    Indexed,
    Indexing {
        remaining_files: usize,
        rate_limit_expiry: Option<Instant>,
    },
}

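/// The semantic index model: owns the vector database connection, the embedding
/// provider, and the background tasks that parse files into spans and persist
/// their embeddings, plus per-project indexing state.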
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    _embedding_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}

struct ProjectState {
    worktrees: HashMap<WorktreeId, WorktreeState>,
    pending_file_count_rx: watch::Receiver<usize>,
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    pending_index: usize,
    _subscription: gpui::Subscription,
    _observe_pending_file_count: Task<()>,
}

enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}

impl WorktreeState {
    fn is_registered(&self) -> bool {
        matches!(self, Self::Registered(_))
    }

    fn paths_changed(
        &mut self,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        worktree: &Worktree,
    ) {
        let changed_paths = match self {
            Self::Registering(state) => &mut state.changed_paths,
            Self::Registered(state) => &mut state.changed_paths,
        };

        for (path, entry_id, change) in changes.iter() {
            let Some(entry) = worktree.entry_for_id(*entry_id) else {
                continue;
            };
            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
                continue;
            }
            changed_paths.insert(
                path.clone(),
                ChangedPathInfo {
                    mtime: entry.mtime,
                    is_deleted: *change == PathChange::Removed,
                },
            );
        }
    }
}

struct RegisteringWorktreeState {
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    done_rx: watch::Receiver<Option<()>>,
    _registration: Task<()>,
}

impl RegisteringWorktreeState {
    fn done(&self) -> impl Future<Output = ()> {
        let mut done_rx = self.done_rx.clone();
        async move {
            while let Some(result) = done_rx.next().await {
                if result.is_some() {
                    break;
                }
            }
        }
    }
}

struct RegisteredWorktreeState {
    db_id: i64,
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}

struct ChangedPathInfo {
    mtime: SystemTime,
    is_deleted: bool,
}

#[derive(Clone)]
pub struct JobHandle {
    /// The outer `Arc` counts the clones of a `JobHandle`; when the last handle
    /// to a given job is dropped, the pending file count is decremented exactly once.
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}

impl JobHandle {
    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
        *tx.lock().borrow_mut() += 1;
        Self {
            tx: Arc::new(Arc::downgrade(&tx)),
        }
    }
}

impl ProjectState {
    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
        Self {
            worktrees: Default::default(),
            pending_file_count_rx: pending_file_count_rx.clone(),
            pending_file_count_tx,
            pending_index: 0,
            _subscription: subscription,
            _observe_pending_file_count: cx.spawn_weak({
                let mut pending_file_count_rx = pending_file_count_rx.clone();
                |this, mut cx| async move {
                    while let Some(_) = pending_file_count_rx.next().await {
                        if let Some(this) = this.upgrade(&cx) {
                            this.update(&mut cx, |_, cx| cx.notify());
                        }
                    }
                }
            }),
        }
    }

    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktrees
            .iter()
            .find_map(|(worktree_id, worktree_state)| match worktree_state {
                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
                _ => None,
            })
    }
}

#[derive(Clone)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    job_handle: JobHandle,
}

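/// A single semantic search hit: the buffer containing the match, the anchored
/// range of the matching span, and the span's similarity score against the
/// query embedding.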
#[derive(Clone)]
pub struct SearchResult {
    pub buffer: ModelHandle<Buffer>,
    pub range: Range<Anchor>,
    pub similarity: OrderedFloat<f32>,
}

impl SemanticIndex {
    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
        if cx.has_global::<ModelHandle<Self>>() {
            Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
        } else {
            None
        }
    }

    pub fn enabled(cx: &AppContext) -> bool {
        settings::get::<SemanticIndexSettings>(cx).enabled
    }

    pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
        if !self.embedding_provider.is_authenticated() {
            return SemanticIndexStatus::NotAuthenticated;
        }

        if let Some(project_state) = self.projects.get(&project.downgrade()) {
            if project_state
                .worktrees
                .values()
                .all(|worktree| worktree.is_registered())
                && project_state.pending_index == 0
            {
                SemanticIndexStatus::Indexed
            } else {
                SemanticIndexStatus::Indexing {
                    remaining_files: project_state.pending_file_count_rx.borrow().clone(),
                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
                }
            }
        } else {
            SemanticIndexStatus::NotIndexed
        }
    }

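    /// Opens (or creates) the vector database at `database_path` and constructs
    /// the index, spawning one background task that persists embedded files and
    /// one parsing worker per CPU that turns pending files into embeddable spans.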
    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        Ok(cx.add_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
            let _embedding_task = cx.background().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background().clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        }))
    }

    async fn parse_file(
        fs: &Arc<dyn Fs>,
        pending_file: PendingFile,
        retriever: &mut CodeContextRetriever,
        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
    ) {
        let Some(language) = pending_file.language else {
            return;
        };

        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
            if let Some(mut spans) = retriever
                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
                .log_err()
            {
                log::trace!(
                    "parsed path {:?}: {} spans",
                    pending_file.relative_path,
                    spans.len()
                );

                for span in &mut spans {
                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
                        span.embedding = Some(embedding.to_owned());
                    }
                }

                embedding_queue.lock().push(FileToEmbed {
                    worktree_id: pending_file.worktree_db_id,
                    path: pending_file.relative_path,
                    mtime: pending_file.modified_time,
                    job_handle: pending_file.job_handle,
                    spans,
                });
            }
        }
    }

    pub fn project_previously_indexed(
        &mut self,
        project: &ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<bool>> {
        let worktrees_indexed_previously = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.db
                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
            })
            .collect::<Vec<_>>();
        cx.spawn(|_, _cx| async move {
            let worktree_indexed_previously =
                futures::future::join_all(worktrees_indexed_previously).await;

            Ok(worktree_indexed_previously
                .iter()
                .filter(|worktree| worktree.is_ok())
                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
        })
    }

    fn project_entries_changed(
        &mut self,
        project: ModelHandle<Project>,
        worktree_id: WorktreeId,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        cx: &mut ModelContext<Self>,
    ) {
        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
            return;
        };
        let project = project.downgrade();
        let Some(project_state) = self.projects.get_mut(&project) else {
            return;
        };

        let worktree = worktree.read(cx);
        let worktree_state =
            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
                worktree_state
            } else {
                return;
            };
        worktree_state.paths_changed(changes, worktree);
        if let WorktreeState::Registered(_) = worktree_state {
            cx.spawn_weak(|this, mut cx| async move {
                cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
                if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
                    this.update(&mut cx, |this, cx| {
                        this.index_project(project, cx).detach_and_log_err(cx)
                    });
                }
            })
            .detach();
        }
    }

    fn register_worktree(
        &mut self,
        project: ModelHandle<Project>,
        worktree: ModelHandle<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade(&cx) {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok_or_else(|| anyhow!("worktree not found"))?
                    } else {
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
                    let mut changed_paths = cx
                        .background()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Skip files that can't be parsed into embeddable spans:
                                    // the file must be whole-file parseable, Markdown, or have
                                    // a grammar with an embedding config.
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && &language.name().as_ref() != &"Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from the database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .ok_or_else(|| anyhow!("project not registered"))?;
                        let project = project
                            .upgrade(cx)
                            .ok_or_else(|| anyhow!("project was dropped"))?;

                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })?;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        this.projects.get_mut(&project).map(|project_state| {
                            project_state.worktrees.remove(&worktree_id);
                        });
                    })
                }

                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }

    fn project_worktrees_changed(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) {
        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
        {
            project_state
        } else {
            return;
        };

        let mut worktrees = project
            .read(cx)
            .worktrees(cx)
            .filter(|worktree| worktree.read(cx).is_local())
            .collect::<Vec<_>>();
        let worktree_ids = worktrees
            .iter()
            .map(|worktree| worktree.read(cx).id())
            .collect::<HashSet<_>>();

        // Remove worktrees that are no longer present.
        project_state
            .worktrees
            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));

        // Register new worktrees.
        worktrees.retain(|worktree| {
            let worktree_id = worktree.read(cx).id();
            !project_state.worktrees.contains_key(&worktree_id)
        });
        for worktree in worktrees {
            self.register_worktree(project.clone(), worktree, cx);
        }
    }

    pub fn pending_file_count(
        &self,
        project: &ModelHandle<Project>,
    ) -> Option<watch::Receiver<usize>> {
        Some(
            self.projects
                .get(&project.downgrade())?
                .pending_file_count_rx
                .clone(),
        )
    }

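    /// Performs a semantic search across the project, combining results from
    /// dirty open buffers with results from the on-disk index, deduplicated and
    /// sorted by similarity.
    ///
    /// A minimal usage sketch, assuming an async gpui context, a handle to the
    /// global index, and a local project; the query string and limit are
    /// illustrative:
    ///
    /// ```ignore
    /// let search = semantic_index.update(&mut cx, |index, cx| {
    ///     index.search_project(project, "flush the embedding queue".into(), 10, vec![], vec![], cx)
    /// });
    /// let results = search.await?;
    /// ```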
    pub fn search_project(
        &mut self,
        project: ModelHandle<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            index.await?;
            let t0 = Instant::now();
            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .ok_or_else(|| anyhow!("could not embed query"))?;
            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());

            let search_start = Instant::now();
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            });
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            });
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("Semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }

    pub fn search_files(
        &mut self,
        project: ModelHandle<Project>,
        query: Embedding,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let db_path = self.db.path().clone();
        let fs = self.fs.clone();
        cx.spawn(|this, mut cx| async move {
            let database =
                VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;

            let worktree_db_ids = this.read_with(&cx, |this, _| {
                let project_state = this
                    .projects
                    .get(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was not indexed"))?;
                let worktree_db_ids = project_state
                    .worktrees
                    .values()
                    .filter_map(|worktree| {
                        if let WorktreeState::Registered(worktree) = worktree {
                            Some(worktree.db_id)
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<i64>>();
                anyhow::Ok(worktree_db_ids)
            })?;

            let file_ids = database
                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
                .await?;

            let batch_n = cx.background().num_cpus();
            let ids_len = file_ids.len();
            let minimum_batch_size = 50;

            let batch_size = {
                let size = ids_len / batch_n;
                if size < minimum_batch_size {
                    minimum_batch_size
                } else {
                    size
                }
            };

            let mut batch_results = Vec::new();
            for batch in file_ids.chunks(batch_size) {
                let batch = batch.to_vec();
                let limit = limit.clone();
                let fs = fs.clone();
                let db_path = db_path.clone();
                let query = query.clone();
                if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
                    .await
                    .log_err()
                {
                    batch_results.push(async move {
                        db.top_k_search(&query, limit, batch.as_slice()).await
                    });
                }
            }

            let batch_results = futures::future::join_all(batch_results).await;

            let mut results = Vec::new();
            for batch_result in batch_results {
                if let Ok(batch_result) = batch_result {
                    for (id, similarity) in batch_result {
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
                        {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    }
                }
            }

            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
            let scores = results
                .into_iter()
                .map(|(_, score)| score)
                .collect::<Vec<_>>();
            let spans = database.spans_for_ids(ids.as_slice()).await?;

            let mut tasks = Vec::new();
            let mut ranges = Vec::new();
            let weak_project = project.downgrade();
            project.update(&mut cx, |project, cx| {
                for (worktree_db_id, file_path, byte_range) in spans {
                    let project_state =
                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
                            state
                        } else {
                            return Err(anyhow!("project not added"));
                        };
                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
                        ranges.push(byte_range);
                    }
                }

                Ok(())
            })?;

            let buffers = futures::future::join_all(tasks).await;
            Ok(buffers
                .into_iter()
                .zip(ranges)
                .zip(scores)
                .filter_map(|((buffer, range), similarity)| {
                    let buffer = buffer.log_err()?;
                    let range = buffer.read_with(&cx, |buffer, _| {
                        let start = buffer.clip_offset(range.start, Bias::Left);
                        let end = buffer.clip_offset(range.end, Bias::Right);
                        buffer.anchor_before(start)..buffer.anchor_after(end)
                    });
                    Some(SearchResult {
                        buffer,
                        range,
                        similarity,
                    })
                })
                .collect())
        })
    }

    fn search_modified_buffers(
        &self,
        project: &ModelHandle<Project>,
        query: Embedding,
        limit: usize,
        includes: &[PathMatcher],
        excludes: &[PathMatcher],
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let modified_buffers = project
            .read(cx)
            .opened_buffers(cx)
            .into_iter()
            .filter_map(|buffer_handle| {
                let buffer = buffer_handle.read(cx);
                let snapshot = buffer.snapshot();
                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                    excludes.iter().any(|matcher| matcher.is_match(&path))
                });

                let included = if includes.is_empty() {
                    true
                } else {
                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                        includes.iter().any(|matcher| matcher.is_match(&path))
                    })
                };

                if buffer.is_dirty() && !excluded && included {
                    Some((buffer_handle, snapshot))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        let embedding_provider = self.embedding_provider.clone();
        let fs = self.fs.clone();
        let db_path = self.db.path().clone();
        let background = cx.background().clone();
        cx.background().spawn(async move {
            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
            let mut results = Vec::<SearchResult>::new();

            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
            for (buffer, snapshot) in modified_buffers {
                let language = snapshot
                    .language_at(0)
                    .cloned()
                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
                let mut spans = retriever
                    .parse_file_with_template(None, &snapshot.text(), language)
                    .log_err()
                    .unwrap_or_default();
                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
                    .await
                    .log_err()
                    .is_some()
                {
                    for span in spans {
                        let similarity = span.embedding.unwrap().similarity(&query);
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |result| {
                                Reverse(result.similarity)
                            }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        let range = {
                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
                        };

                        results.insert(
                            ix,
                            SearchResult {
                                buffer: buffer.clone(),
                                range,
                                similarity,
                            },
                        );
                        results.truncate(limit);
                    }
                }
            }

            Ok(results)
        })
    }

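    /// Registers the project if it isn't tracked yet, waits for its worktrees to
    /// finish registering, enqueues changed files for parsing and embedding, and
    /// resolves once the pending file count drops back to zero.
    ///
    /// A sketch of a call site, mirroring the call made in `init` (illustrative
    /// only):
    ///
    /// ```ignore
    /// semantic_index
    ///     .update(&mut cx, |index, cx| index.index_project(project, cx))
    ///     .await?;
    /// ```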
    pub fn index_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if !self.embedding_provider.is_authenticated() {
            return Task::ready(Err(anyhow!("user is not authenticated")));
        }

        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            worktree_registration.await?;

            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was dropped"))?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        let worktree = if let Some(worktree) =
                            project.read(cx).worktree_for_id(*worktree_id, cx)
                        {
                            worktree
                        } else {
                            return false;
                        };
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        worktree_state.changed_paths.retain(|path, info| {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else {
                                let absolute_path = worktree.read(cx).absolutize(path);
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }

                            false
                        });
                        true
                    });

                anyhow::Ok(())
            })?;

            cx.background()
                .spawn(async move {
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file(&pending_file.relative_path, None)
                            .await
                        {
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && &language.name().as_ref() != &"Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was dropped"))?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })?;

            Ok(())
        })
    }

    fn wait_for_worktree_registration(
        &self,
        project: &ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let project = project.downgrade();
        cx.spawn_weak(|this, cx| async move {
            loop {
                let mut pending_worktrees = Vec::new();
                this.upgrade(&cx)
                    .ok_or_else(|| anyhow!("semantic index dropped"))?
                    .read_with(&cx, |this, _| {
                        if let Some(project) = this.projects.get(&project) {
                            for worktree in project.worktrees.values() {
                                if let WorktreeState::Registering(worktree) = worktree {
                                    pending_worktrees.push(worktree.done());
                                }
                            }
                        }
                    });

                if pending_worktrees.is_empty() {
                    break;
                } else {
                    future::join_all(pending_worktrees).await;
                }
            }
            Ok(())
        })
    }

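    /// Embeds any spans whose digests are not already cached in the database,
    /// batching requests so that each batch stays within the provider's
    /// `max_tokens_per_batch`, then assigns an embedding (cached or freshly
    /// computed) to every span.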
    async fn embed_spans(
        spans: &mut [Span],
        embedding_provider: &dyn EmbeddingProvider,
        db: &VectorDatabase,
    ) -> Result<()> {
        let mut batch = Vec::new();
        let mut batch_tokens = 0;
        let mut embeddings = Vec::new();

        let digests = spans
            .iter()
            .map(|span| span.digest.clone())
            .collect::<Vec<_>>();
        let embeddings_for_digests = db
            .embeddings_for_digests(digests)
            .await
            .log_err()
            .unwrap_or_default();

        for span in &*spans {
            if embeddings_for_digests.contains_key(&span.digest) {
                continue;
            };

            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                let batch_embeddings = embedding_provider
                    .embed_batch(mem::take(&mut batch))
                    .await?;
                embeddings.extend(batch_embeddings);
                batch_tokens = 0;
            }

            batch_tokens += span.token_count;
            batch.push(span.content.clone());
        }

        if !batch.is_empty() {
            let batch_embeddings = embedding_provider
                .embed_batch(mem::take(&mut batch))
                .await?;

            embeddings.extend(batch_embeddings);
        }

        let mut embeddings = embeddings.into_iter();
        for span in spans {
            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
                Some(embedding.clone())
            } else {
                embeddings.next()
            };
            let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
            span.embedding = Some(embedding);
        }
        Ok(())
    }
}

impl Entity for SemanticIndex {
    type Event = ();
}

impl Drop for JobHandle {
    fn drop(&mut self) {
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last remaining handle to this job (whether or not it was ever
            // cloned), so decrement the pending file count once.
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_job_handle() {
        let (job_count_tx, job_count_rx) = watch::channel_with(0);
        let tx = Arc::new(Mutex::new(job_count_tx));
        let job_handle = JobHandle::new(&tx);

        assert_eq!(1, *job_count_rx.borrow());
        let new_job_handle = job_handle.clone();
        assert_eq!(1, *job_count_rx.borrow());
        drop(job_handle);
        assert_eq!(1, *job_count_rx.borrow());
        drop(new_job_handle);
        assert_eq!(0, *job_count_rx.borrow());
    }
}