semantic_index.rs

mod db;
mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;

#[cfg(test)]
mod semantic_index_tests;

use crate::semantic_index_settings::SemanticIndexSettings;
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::providers::open_ai::OpenAiEmbeddingProvider;
use anyhow::{anyhow, Context as _, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt};
use gpui::{
    AppContext, AsyncAppContext, BorrowWindow, Context, Global, Model, ModelContext, Task,
    ViewContext, WeakModel,
};
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
use release_channel::ReleaseChannel;
use settings::Settings;
use smol::channel;
use std::{
    cmp::Reverse,
    env,
    future::Future,
    mem,
    ops::Range,
    path::{Path, PathBuf},
    sync::{Arc, Weak},
    time::{Duration, Instant, SystemTime},
};
use util::paths::PathMatcher;
use util::{http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
use workspace::Workspace;

const SEMANTIC_INDEX_VERSION: usize = 11;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}

pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    SemanticIndexSettings::register(cx);

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(ReleaseChannel::global(cx).dev_name()))
        .join("embeddings_db");

    cx.observe_new_views(
        |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let project = workspace.project().clone();

            if project.read(cx).is_local() {
                cx.app_mut()
                    .spawn(|mut cx| async move {
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })?
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))?
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
            }
        },
    )
    .detach();

    cx.spawn(move |cx| async move {
        let embedding_provider =
            OpenAiEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(embedding_provider),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| cx.set_global(GlobalSemanticIndex(semantic_index.clone())))?;

        anyhow::Ok(())
    })
    .detach();
}

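/// Indexing status for a single project, as reported by [`SemanticIndex::status`].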
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    NotAuthenticated,
    NotIndexed,
    Indexed,
    Indexing {
        remaining_files: usize,
        rate_limit_expiry: Option<Instant>,
    },
}

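/// Maintains embeddings of project files in a local [`VectorDatabase`] and serves
/// semantic searches against them.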
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    _embedding_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModel<Project>, ProjectState>,
}

struct GlobalSemanticIndex(Model<SemanticIndex>);

impl Global for GlobalSemanticIndex {}

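/// Per-project bookkeeping: the registration state of each worktree plus counters
/// used to track in-flight indexing work.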
struct ProjectState {
    worktrees: HashMap<WorktreeId, WorktreeState>,
    pending_file_count_rx: watch::Receiver<usize>,
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    pending_index: usize,
    _subscription: gpui::Subscription,
    _observe_pending_file_count: Task<()>,
}

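/// A worktree is `Registering` until its initial scan has been reconciled with the
/// database, after which it becomes `Registered` and carries its database id.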
enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}

impl WorktreeState {
    fn is_registered(&self) -> bool {
        matches!(self, Self::Registered(_))
    }

    fn paths_changed(
        &mut self,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        worktree: &Worktree,
    ) {
        let changed_paths = match self {
            Self::Registering(state) => &mut state.changed_paths,
            Self::Registered(state) => &mut state.changed_paths,
        };

        for (path, entry_id, change) in changes.iter() {
            let Some(entry) = worktree.entry_for_id(*entry_id) else {
                continue;
            };
            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
                continue;
            }
            changed_paths.insert(
                path.clone(),
                ChangedPathInfo {
                    mtime: entry.mtime,
                    is_deleted: *change == PathChange::Removed,
                },
            );
        }
    }
}

struct RegisteringWorktreeState {
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    done_rx: watch::Receiver<Option<()>>,
    _registration: Task<()>,
}

impl RegisteringWorktreeState {
    fn done(&self) -> impl Future<Output = ()> {
        let mut done_rx = self.done_rx.clone();
        async move {
            while let Some(result) = done_rx.next().await {
                if result.is_some() {
                    break;
                }
            }
        }
    }
}

struct RegisteredWorktreeState {
    db_id: i64,
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}

struct ChangedPathInfo {
    mtime: SystemTime,
    is_deleted: bool,
}

#[derive(Clone)]
pub struct JobHandle {
    /// The outer `Arc` counts the clones of a `JobHandle`; when the last handle to a
    /// given job is dropped, the pending-file counter is decremented exactly once.
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}

impl JobHandle {
    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
        *tx.lock().borrow_mut() += 1;
        Self {
            tx: Arc::new(Arc::downgrade(&tx)),
        }
    }
}

impl ProjectState {
    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
        Self {
            worktrees: Default::default(),
            pending_file_count_rx: pending_file_count_rx.clone(),
            pending_file_count_tx,
            pending_index: 0,
            _subscription: subscription,
            _observe_pending_file_count: cx.spawn({
                let mut pending_file_count_rx = pending_file_count_rx.clone();
                |this, mut cx| async move {
                    while let Some(_) = pending_file_count_rx.next().await {
                        if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
                            break;
                        }
                    }
                }
            }),
        }
    }

    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktrees
            .iter()
            .find_map(|(worktree_id, worktree_state)| match worktree_state {
                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
                _ => None,
            })
    }
}

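/// A file that has been scheduled for parsing and embedding.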
#[derive(Clone)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    job_handle: JobHandle,
}

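/// A single semantic search hit: a span of a buffer together with its similarity to the query.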
#[derive(Clone)]
pub struct SearchResult {
    pub buffer: Model<Buffer>,
    pub range: Range<Anchor>,
    pub similarity: OrderedFloat<f32>,
}

impl SemanticIndex {
    pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
        cx.try_global::<GlobalSemanticIndex>()
            .map(|semantic_index| semantic_index.0.clone())
    }

    pub fn authenticate(&mut self, cx: &mut AppContext) -> Task<bool> {
        if !self.embedding_provider.has_credentials() {
            let embedding_provider = self.embedding_provider.clone();
            cx.spawn(|cx| async move {
                if let Some(retrieve_credentials) = cx
                    .update(|cx| embedding_provider.retrieve_credentials(cx))
                    .log_err()
                {
                    retrieve_credentials.await;
                }

                embedding_provider.has_credentials()
            })
        } else {
            Task::ready(true)
        }
    }

    pub fn is_authenticated(&self) -> bool {
        self.embedding_provider.has_credentials()
    }

    pub fn enabled(cx: &AppContext) -> bool {
        SemanticIndexSettings::get_global(cx).enabled
    }

    pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
        if !self.is_authenticated() {
            return SemanticIndexStatus::NotAuthenticated;
        }

        if let Some(project_state) = self.projects.get(&project.downgrade()) {
            if project_state
                .worktrees
                .values()
                .all(|worktree| worktree.is_registered())
                && project_state.pending_index == 0
            {
                SemanticIndexStatus::Indexed
            } else {
                SemanticIndexStatus::Indexing {
                    remaining_files: project_state.pending_file_count_rx.borrow().clone(),
                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
                }
            }
        } else {
            SemanticIndexStatus::NotIndexed
        }
    }

    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<Model<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
            .await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        cx.new_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
            let _embedding_task = cx.background_executor().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background_executor().clone();
                _parsing_files_tasks.push(cx.background_executor().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        })
    }

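    /// Parses a pending file into spans, reuses any embeddings already computed for
    /// matching digests, and pushes the result onto the embedding queue.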
    async fn parse_file(
        fs: &Arc<dyn Fs>,
        pending_file: PendingFile,
        retriever: &mut CodeContextRetriever,
        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
    ) {
        let Some(language) = pending_file.language else {
            return;
        };

        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
            if let Some(mut spans) = retriever
                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
                .log_err()
            {
                log::trace!(
                    "parsed path {:?}: {} spans",
                    pending_file.relative_path,
                    spans.len()
                );

                for span in &mut spans {
                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
                        span.embedding = Some(embedding.to_owned());
                    }
                }

                embedding_queue.lock().push(FileToEmbed {
                    worktree_id: pending_file.worktree_db_id,
                    path: pending_file.relative_path,
                    mtime: pending_file.modified_time,
                    job_handle: pending_file.job_handle,
                    spans,
                });
            }
        }
    }

    pub fn project_previously_indexed(
        &mut self,
        project: &Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<bool>> {
        let worktrees_indexed_previously = project
            .read(cx)
            .worktrees()
            .map(|worktree| {
                self.db
                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
            })
            .collect::<Vec<_>>();
        cx.spawn(|_, _cx| async move {
            let worktree_indexed_previously =
                futures::future::join_all(worktrees_indexed_previously).await;

            Ok(worktree_indexed_previously
                .iter()
                .filter(|worktree| worktree.is_ok())
                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
        })
    }

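    /// Records changed paths for a worktree and, if it is already registered, schedules
    /// a re-index after `BACKGROUND_INDEXING_DELAY`.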
    fn project_entries_changed(
        &mut self,
        project: Model<Project>,
        worktree_id: WorktreeId,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        cx: &mut ModelContext<Self>,
    ) {
        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
            return;
        };
        let project = project.downgrade();
        let Some(project_state) = self.projects.get_mut(&project) else {
            return;
        };

        let worktree = worktree.read(cx);
        let worktree_state =
            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
                worktree_state
            } else {
                return;
            };
        worktree_state.paths_changed(changes, worktree);
        if let WorktreeState::Registered(_) = worktree_state {
            cx.spawn(|this, mut cx| async move {
                cx.background_executor()
                    .timer(BACKGROUND_INDEXING_DELAY)
                    .await;
                if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
                    this.update(&mut cx, |this, cx| {
                        this.index_project(project, cx).detach_and_log_err(cx)
                    })?;
                }
                anyhow::Ok(())
            })
            .detach_and_log_err(cx);
        }
    }

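    /// Registers a local worktree: waits for its scan to complete, reconciles its files
    /// with the database, and then kicks off indexing for the project.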
    fn register_worktree(
        &mut self,
        project: Model<Project>,
        worktree: Model<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade() {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok()
                            .flatten()
                            .context("worktree not found")?
                    } else {
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
                    let mut changed_paths = cx
                        .background_executor()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path)?;

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Skip files whose language can't be parsed into embeddable spans.
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && &language.name().as_ref() != &"Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Mark database entries for files that are no longer in the worktree as deleted.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .context("project not registered")?;
                        let project = project.upgrade().context("project was dropped")?;

                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })??;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        this.projects.get_mut(&project).map(|project_state| {
                            project_state.worktrees.remove(&worktree_id);
                        });
                    })
                    .ok();
                }

                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }

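    /// Keeps the set of tracked worktrees in sync with the project, registering new
    /// local worktrees and dropping state for removed ones.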
    fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
        {
            project_state
        } else {
            return;
        };

        let mut worktrees = project
            .read(cx)
            .worktrees()
            .filter(|worktree| worktree.read(cx).is_local())
            .collect::<Vec<_>>();
        let worktree_ids = worktrees
            .iter()
            .map(|worktree| worktree.read(cx).id())
            .collect::<HashSet<_>>();

        // Remove worktrees that are no longer present
        project_state
            .worktrees
            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));

        // Register new worktrees
        worktrees.retain(|worktree| {
            let worktree_id = worktree.read(cx).id();
            !project_state.worktrees.contains_key(&worktree_id)
        });
        for worktree in worktrees {
            self.register_worktree(project.clone(), worktree, cx);
        }
    }

    pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
        Some(
            self.projects
                .get(&project.downgrade())?
                .pending_file_count_rx
                .clone(),
        )
    }

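    /// Embeds the query and searches both the on-disk index and any dirty open buffers,
    /// merging the two result sets by similarity.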
    pub fn search_project(
        &mut self,
        project: Model<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            index.await?;
            let t0 = Instant::now();

            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .context("could not embed query")?;
            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());

            let search_start = Instant::now();
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            })?;
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            })?;
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("Semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }

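    /// Searches the indexed files of a project by running a top-k similarity search over
    /// batches of file ids in parallel.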
    pub fn search_files(
        &mut self,
        project: Model<Project>,
        query: Embedding,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let db_path = self.db.path().clone();
        let fs = self.fs.clone();
        cx.spawn(|this, mut cx| async move {
            let database = VectorDatabase::new(
                fs.clone(),
                db_path.clone(),
                cx.background_executor().clone(),
            )
            .await?;

            let worktree_db_ids = this.read_with(&cx, |this, _| {
                let project_state = this
                    .projects
                    .get(&project.downgrade())
                    .context("project was not indexed")?;
                let worktree_db_ids = project_state
                    .worktrees
                    .values()
                    .filter_map(|worktree| {
                        if let WorktreeState::Registered(worktree) = worktree {
                            Some(worktree.db_id)
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<i64>>();
                anyhow::Ok(worktree_db_ids)
            })??;

            let file_ids = database
                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
                .await?;

            let batch_n = cx.background_executor().num_cpus();
            let ids_len = file_ids.len();
            let minimum_batch_size = 50;

            let batch_size = {
                let size = ids_len / batch_n;
                if size < minimum_batch_size {
                    minimum_batch_size
                } else {
                    size
                }
            };

            let mut batch_results = Vec::new();
            for batch in file_ids.chunks(batch_size) {
                let batch = batch.to_vec();
                let limit = limit.clone();
                let fs = fs.clone();
                let db_path = db_path.clone();
                let query = query.clone();
                if let Some(db) =
                    VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
                        .await
                        .log_err()
                {
                    batch_results.push(async move {
                        db.top_k_search(&query, limit, batch.as_slice()).await
                    });
                }
            }

            let batch_results = futures::future::join_all(batch_results).await;

            let mut results = Vec::new();
            for batch_result in batch_results {
                if batch_result.is_ok() {
                    for (id, similarity) in batch_result.unwrap() {
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
                        {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    }
                }
            }

            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
            let scores = results
                .into_iter()
                .map(|(_, score)| score)
                .collect::<Vec<_>>();
            let spans = database.spans_for_ids(ids.as_slice()).await?;

            let mut tasks = Vec::new();
            let mut ranges = Vec::new();
            let weak_project = project.downgrade();
            project.update(&mut cx, |project, cx| {
                let this = this.upgrade().context("index was dropped")?;
                for (worktree_db_id, file_path, byte_range) in spans {
                    let project_state =
                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
                            state
                        } else {
                            return Err(anyhow!("project not added"));
                        };
                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
                        ranges.push(byte_range);
                    }
                }

                Ok(())
            })??;

            let buffers = futures::future::join_all(tasks).await;
            Ok(buffers
                .into_iter()
                .zip(ranges)
                .zip(scores)
                .filter_map(|((buffer, range), similarity)| {
                    let buffer = buffer.log_err()?;
                    let range = buffer
                        .read_with(&cx, |buffer, _| {
                            let start = buffer.clip_offset(range.start, Bias::Left);
                            let end = buffer.clip_offset(range.end, Bias::Right);
                            buffer.anchor_before(start)..buffer.anchor_after(end)
                        })
                        .log_err()?;
                    Some(SearchResult {
                        buffer,
                        range,
                        similarity,
                    })
                })
                .collect())
        })
    }

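    /// Searches dirty open buffers by re-parsing and embedding them on the fly, since
    /// their contents may differ from what was indexed on disk.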
    fn search_modified_buffers(
        &self,
        project: &Model<Project>,
        query: Embedding,
        limit: usize,
        includes: &[PathMatcher],
        excludes: &[PathMatcher],
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let modified_buffers = project
            .read(cx)
            .opened_buffers()
            .into_iter()
            .filter_map(|buffer_handle| {
                let buffer = buffer_handle.read(cx);
                let snapshot = buffer.snapshot();
                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                    excludes.iter().any(|matcher| matcher.is_match(&path))
                });

                let included = if includes.is_empty() {
                    true
                } else {
                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                        includes.iter().any(|matcher| matcher.is_match(&path))
                    })
                };

                if buffer.is_dirty() && !excluded && included {
                    Some((buffer_handle, snapshot))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        let embedding_provider = self.embedding_provider.clone();
        let fs = self.fs.clone();
        let db_path = self.db.path().clone();
        let background = cx.background_executor().clone();
        cx.background_executor().spawn(async move {
            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
            let mut results = Vec::<SearchResult>::new();

            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
            for (buffer, snapshot) in modified_buffers {
                let language = snapshot
                    .language_at(0)
                    .cloned()
                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
                let mut spans = retriever
                    .parse_file_with_template(None, &snapshot.text(), language)
                    .log_err()
                    .unwrap_or_default();
                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
                    .await
                    .log_err()
                    .is_some()
                {
                    for span in spans {
                        let similarity = span.embedding.unwrap().similarity(&query);
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |result| {
                                Reverse(result.similarity)
                            }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        let range = {
                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
                        };

                        results.insert(
                            ix,
                            SearchResult {
                                buffer: buffer.clone(),
                                range,
                                similarity,
                            },
                        );
                        results.truncate(limit);
                    }
                }
            }

            Ok(results)
        })
    }

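    /// Indexes a project, authenticating with the embedding provider first if necessary.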
    pub fn index_project(
        &mut self,
        project: Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if self.is_authenticated() {
            self.index_project_internal(project, cx)
        } else {
            let authenticate = self.authenticate(cx);
            cx.spawn(|this, mut cx| async move {
                if authenticate.await {
                    this.update(&mut cx, |this, cx| this.index_project_internal(project, cx))?
                        .await
                } else {
                    Err(anyhow!("user is not authenticated"))
                }
            })
        }
    }

    fn index_project_internal(
        &mut self,
        project: Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            worktree_registration.await?;

            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        let worktree = if let Some(worktree) =
                            project.read(cx).worktree_for_id(*worktree_id, cx)
                        {
                            worktree
                        } else {
                            return false;
                        };
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        for (path, info) in &worktree_state.changed_paths {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else if let Ok(absolute_path) = worktree.read(cx).absolutize(path) {
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }
                        }
                        worktree_state.changed_paths.clear();
                        true
                    });

                anyhow::Ok(())
            })??;

            cx.background_executor()
                .spawn(async move {
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file(&pending_file.relative_path, None)
                            .await
                        {
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && &language.name().as_ref() != &"Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })??;

            Ok(())
        })
    }

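    /// Resolves once every worktree of the project has finished registering.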
    fn wait_for_worktree_registration(
        &self,
        project: &Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let project = project.downgrade();
        cx.spawn(|this, cx| async move {
            loop {
                let mut pending_worktrees = Vec::new();
                this.upgrade()
                    .context("semantic index dropped")?
                    .read_with(&cx, |this, _| {
                        if let Some(project) = this.projects.get(&project) {
                            for worktree in project.worktrees.values() {
                                if let WorktreeState::Registering(worktree) = worktree {
                                    pending_worktrees.push(worktree.done());
                                }
                            }
                        }
                    })?;

                if pending_worktrees.is_empty() {
                    break;
                } else {
                    future::join_all(pending_worktrees).await;
                }
            }
            Ok(())
        })
    }

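    /// Embeds the given spans, reusing embeddings already stored for identical digests
    /// and batching requests up to the provider's token limit.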
    async fn embed_spans(
        spans: &mut [Span],
        embedding_provider: &dyn EmbeddingProvider,
        db: &VectorDatabase,
    ) -> Result<()> {
        let mut batch = Vec::new();
        let mut batch_tokens = 0;
        let mut embeddings = Vec::new();

        let digests = spans
            .iter()
            .map(|span| span.digest.clone())
            .collect::<Vec<_>>();
        let embeddings_for_digests = db
            .embeddings_for_digests(digests)
            .await
            .log_err()
            .unwrap_or_default();

        for span in &*spans {
            if embeddings_for_digests.contains_key(&span.digest) {
                continue;
            };

            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                let batch_embeddings = embedding_provider
                    .embed_batch(mem::take(&mut batch))
                    .await?;
                embeddings.extend(batch_embeddings);
                batch_tokens = 0;
            }

            batch_tokens += span.token_count;
            batch.push(span.content.clone());
        }

        if !batch.is_empty() {
            let batch_embeddings = embedding_provider
                .embed_batch(mem::take(&mut batch))
                .await?;

            embeddings.extend(batch_embeddings);
        }

        let mut embeddings = embeddings.into_iter();
        for span in spans {
            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
                Some(embedding.clone())
            } else {
                embeddings.next()
            };
            let embedding = embedding.context("failed to embed spans")?;
            span.embedding = Some(embedding);
        }
        Ok(())
    }
}

impl Drop for JobHandle {
    fn drop(&mut self) {
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last remaining handle to the job, regardless of how many times
            // it was cloned, so decrement the pending-file count.
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    #[test]
    fn test_job_handle() {
        let (job_count_tx, job_count_rx) = watch::channel_with(0);
        let tx = Arc::new(Mutex::new(job_count_tx));
        let job_handle = JobHandle::new(&tx);

        assert_eq!(1, *job_count_rx.borrow());
        let new_job_handle = job_handle.clone();
        assert_eq!(1, *job_count_rx.borrow());
        drop(job_handle);
        assert_eq!(1, *job_count_rx.borrow());
        drop(new_job_handle);
        assert_eq!(0, *job_count_rx.borrow());
    }
}