semantic_index.rs

   1mod db;
   2mod embedding_queue;
   3mod parsing;
   4pub mod semantic_index_settings;
   5
   6#[cfg(test)]
   7mod semantic_index_tests;
   8
   9use crate::semantic_index_settings::SemanticIndexSettings;
  10use ai::embedding::{Embedding, EmbeddingProvider};
  11use ai::providers::open_ai::{OpenAiEmbeddingProvider, OPEN_AI_API_URL};
  12use anyhow::{anyhow, Context as _, Result};
  13use collections::{BTreeMap, HashMap, HashSet};
  14use db::VectorDatabase;
  15use embedding_queue::{EmbeddingQueue, FileToEmbed};
  16use futures::{future, FutureExt, StreamExt};
  17use gpui::{
  18    AppContext, AsyncAppContext, BorrowWindow, Context, Global, Model, ModelContext, Task,
  19    ViewContext, WeakModel,
  20};
  21use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
  22use lazy_static::lazy_static;
  23use ordered_float::OrderedFloat;
  24use parking_lot::Mutex;
  25use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
  26use postage::watch;
  27use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
  28use release_channel::ReleaseChannel;
  29use settings::Settings;
  30use smol::channel;
  31use std::{
  32    cmp::Reverse,
  33    env,
  34    future::Future,
  35    mem,
  36    ops::Range,
  37    path::{Path, PathBuf},
  38    sync::{Arc, Weak},
  39    time::{Duration, Instant, SystemTime},
  40};
  41use util::paths::PathMatcher;
  42use util::{http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
  43use workspace::Workspace;
  44
// Version stamp for the on-disk embeddings database.
// NOTE(review): presumably bumped to invalidate/rebuild stale indexes — the
// consumer of this constant lives outside this chunk; confirm.
const SEMANTIC_INDEX_VERSION: usize = 11;
// Debounce applied after file changes before a background re-index kicks in.
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
// How long a parsing task waits for new work before flushing the embedding queue.
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
  48
lazy_static! {
    // OpenAI API key, read from the environment once on first access.
    // NOTE(review): not referenced within this chunk — used elsewhere; confirm.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}
  52
/// Registers semantic-index settings, arranges for previously-indexed local
/// projects to be re-indexed when a workspace opens, and asynchronously
/// constructs the global `SemanticIndex` backed by an on-disk database.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    SemanticIndexSettings::register(cx);

    // Keep one database per release channel so dev/stable indexes don't clash.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(ReleaseChannel::global(cx).dev_name()))
        .join("embeddings_db");

    cx.observe_new_views(
        |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let project = workspace.project().clone();

            // Only local projects are indexed; eagerly re-index a project
            // that was already indexed in a previous session.
            if project.read(cx).is_local() {
                cx.app_mut()
                    .spawn(|mut cx| async move {
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })?
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))?
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
            }
        },
    )
    .detach();

    cx.spawn(move |cx| async move {
        let embedding_provider = OpenAiEmbeddingProvider::new(
            // TODO: We should read it from config, but I'm not sure whether to reuse `openai_api_url` in assistant settings or not
            OPEN_AI_API_URL.to_string(),
            http_client,
            cx.background_executor().clone(),
        )
        .await;
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(embedding_provider),
            language_registry,
            cx.clone(),
        )
        .await?;

        // Publish the index as a global so views created later can find it.
        cx.update(|cx| cx.set_global(GlobalSemanticIndex(semantic_index.clone())))?;

        anyhow::Ok(())
    })
    .detach();
}
 116
/// High-level indexing state reported for a single project (see `status`).
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    /// The embedding provider has no credentials yet.
    NotAuthenticated,
    /// The project is not tracked by the semantic index.
    NotIndexed,
    /// Every worktree is registered and nothing is pending.
    Indexed,
    /// Indexing is in progress.
    Indexing {
        /// Files still waiting to be parsed/embedded.
        remaining_files: usize,
        /// When the provider's rate limit lifts, if currently throttled.
        rate_limit_expiry: Option<Instant>,
    },
}
 127
/// Coordinates semantic indexing and search: parses project files into
/// spans, embeds them via the configured provider, and persists embeddings
/// in a local vector database.
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Feeds (cached digest→embedding map, file) pairs to the parsing tasks.
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    // Drains finished files from the embedding queue into the database.
    _embedding_task: Task<()>,
    // One parsing task per CPU; see `SemanticIndex::new`.
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModel<Project>, ProjectState>,
}
 138
/// Newtype storing the app-wide `SemanticIndex` model as a gpui global
/// (accessed via `SemanticIndex::global`).
struct GlobalSemanticIndex(Model<SemanticIndex>);

impl Global for GlobalSemanticIndex {}
 142
/// Per-project bookkeeping held by `SemanticIndex`.
struct ProjectState {
    // Registration state for each of the project's worktrees.
    worktrees: HashMap<WorktreeId, WorktreeState>,
    // Read by observers to display indexing progress.
    pending_file_count_rx: watch::Receiver<usize>,
    // Shared with `JobHandle`s, which bump the count per queued file.
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    // NOTE(review): appears to count in-flight `index_project` calls; the
    // increment/decrement sites are outside this chunk — confirm.
    pending_index: usize,
    _subscription: gpui::Subscription,
    // Task that re-notifies the model whenever the pending count changes.
    _observe_pending_file_count: Task<()>,
}
 151
/// Lifecycle of a worktree within the index: `Registering` while the initial
/// scan/database diff runs, then `Registered` once it has a database id.
enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}
 156
 157impl WorktreeState {
 158    fn is_registered(&self) -> bool {
 159        matches!(self, Self::Registered(_))
 160    }
 161
 162    fn paths_changed(
 163        &mut self,
 164        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 165        worktree: &Worktree,
 166    ) {
 167        let changed_paths = match self {
 168            Self::Registering(state) => &mut state.changed_paths,
 169            Self::Registered(state) => &mut state.changed_paths,
 170        };
 171
 172        for (path, entry_id, change) in changes.iter() {
 173            let Some(entry) = worktree.entry_for_id(*entry_id) else {
 174                continue;
 175            };
 176            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
 177                continue;
 178            }
 179            changed_paths.insert(
 180                path.clone(),
 181                ChangedPathInfo {
 182                    mtime: entry.mtime,
 183                    is_deleted: *change == PathChange::Removed,
 184                },
 185            );
 186        }
 187    }
 188}
 189
/// State for a worktree whose registration task is still running.
struct RegisteringWorktreeState {
    // Changes reported while registering; merged into the registered state
    // when registration completes (see `register_worktree`).
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    // Receives `Some(())` when the registration task finishes.
    done_rx: watch::Receiver<Option<()>>,
    // Keeps the registration task alive.
    _registration: Task<()>,
}
 195
 196impl RegisteringWorktreeState {
 197    fn done(&self) -> impl Future<Output = ()> {
 198        let mut done_rx = self.done_rx.clone();
 199        async move {
 200            while let Some(result) = done_rx.next().await {
 201                if result.is_some() {
 202                    break;
 203                }
 204            }
 205        }
 206    }
 207}
 208
/// State for a worktree that has completed registration.
struct RegisteredWorktreeState {
    // Row id of this worktree in the vector database.
    db_id: i64,
    // Paths changed since the last indexing pass.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}
 213
/// A pending change to a single file path.
struct ChangedPathInfo {
    // Modification time of the entry when the change was observed.
    mtime: SystemTime,
    // True when the file was removed from the worktree.
    is_deleted: bool,
}
 218
/// Handle to one in-flight indexing job; clones share the same job.
#[derive(Clone)]
pub struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}
 225
 226impl JobHandle {
 227    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
 228        *tx.lock().borrow_mut() += 1;
 229        Self {
 230            tx: Arc::new(Arc::downgrade(&tx)),
 231        }
 232    }
 233}
 234
 235impl ProjectState {
 236    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
 237        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
 238        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
 239        Self {
 240            worktrees: Default::default(),
 241            pending_file_count_rx: pending_file_count_rx.clone(),
 242            pending_file_count_tx,
 243            pending_index: 0,
 244            _subscription: subscription,
 245            _observe_pending_file_count: cx.spawn({
 246                let mut pending_file_count_rx = pending_file_count_rx.clone();
 247                |this, mut cx| async move {
 248                    while let Some(_) = pending_file_count_rx.next().await {
 249                        if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
 250                            break;
 251                        }
 252                    }
 253                }
 254            }),
 255        }
 256    }
 257
 258    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
 259        self.worktrees
 260            .iter()
 261            .find_map(|(worktree_id, worktree_state)| match worktree_state {
 262                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
 263                _ => None,
 264            })
 265    }
 266}
 267
/// A file queued for parsing and embedding.
#[derive(Clone)]
pub struct PendingFile {
    // Database id of the worktree containing this file.
    worktree_db_id: i64,
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    // `None` when no language matched; such files are skipped by `parse_file`.
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    // Keeps the pending-file counter accurate while the job is alive.
    job_handle: JobHandle,
}
 277
/// A single semantic-search hit: a buffer region plus its similarity score.
#[derive(Clone)]
pub struct SearchResult {
    pub buffer: Model<Buffer>,
    pub range: Range<Anchor>,
    /// Ranking score; `search_project` sorts descending, so higher is better.
    pub similarity: OrderedFloat<f32>,
}
 284
 285impl SemanticIndex {
 286    pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
 287        cx.try_global::<GlobalSemanticIndex>()
 288            .map(|semantic_index| semantic_index.0.clone())
 289    }
 290
 291    pub fn authenticate(&mut self, cx: &mut AppContext) -> Task<bool> {
 292        if !self.embedding_provider.has_credentials() {
 293            let embedding_provider = self.embedding_provider.clone();
 294            cx.spawn(|cx| async move {
 295                if let Some(retrieve_credentials) = cx
 296                    .update(|cx| embedding_provider.retrieve_credentials(cx))
 297                    .log_err()
 298                {
 299                    retrieve_credentials.await;
 300                }
 301
 302                embedding_provider.has_credentials()
 303            })
 304        } else {
 305            Task::ready(true)
 306        }
 307    }
 308
    /// Whether the embedding provider currently has credentials.
    pub fn is_authenticated(&self) -> bool {
        self.embedding_provider.has_credentials()
    }
 312
    /// Whether semantic indexing is enabled in settings.
    pub fn enabled(cx: &AppContext) -> bool {
        SemanticIndexSettings::get_global(cx).enabled
    }
 316
 317    pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
 318        if !self.is_authenticated() {
 319            return SemanticIndexStatus::NotAuthenticated;
 320        }
 321
 322        if let Some(project_state) = self.projects.get(&project.downgrade()) {
 323            if project_state
 324                .worktrees
 325                .values()
 326                .all(|worktree| worktree.is_registered())
 327                && project_state.pending_index == 0
 328            {
 329                SemanticIndexStatus::Indexed
 330            } else {
 331                SemanticIndexStatus::Indexing {
 332                    remaining_files: project_state.pending_file_count_rx.borrow().clone(),
 333                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
 334                }
 335            }
 336        } else {
 337            SemanticIndexStatus::NotIndexed
 338        }
 339    }
 340
    /// Opens (or creates) the vector database at `database_path` and builds
    /// the `SemanticIndex` model, spawning its long-lived background tasks:
    /// one writer that persists embedded files, and one parsing task per CPU.
    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<Model<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
            .await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        cx.new_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
            // Persist files into the database as the queue finishes them.
            let _embedding_task = cx.background_executor().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background_executor().clone();
                _parsing_files_tasks.push(cx.background_executor().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        // Either parse the next incoming file, or — if none
                        // arrives within the timeout — flush the embedding
                        // queue so partially-filled batches still get sent.
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    // Channel closed: the index is gone.
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        })
    }
 429
    /// Loads `pending_file` from disk, splits it into embeddable spans,
    /// reuses cached embeddings whose span digest matches, and pushes the
    /// result onto the embedding queue. Files with no detected language are
    /// skipped; read/parse failures are logged and dropped.
    async fn parse_file(
        fs: &Arc<dyn Fs>,
        pending_file: PendingFile,
        retriever: &mut CodeContextRetriever,
        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
    ) {
        let Some(language) = pending_file.language else {
            return;
        };

        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
            if let Some(mut spans) = retriever
                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
                .log_err()
            {
                log::trace!(
                    "parsed path {:?}: {} spans",
                    pending_file.relative_path,
                    spans.len()
                );

                // Attach cached embeddings for unchanged spans so only new
                // content has to go to the embedding provider.
                for span in &mut spans {
                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
                        span.embedding = Some(embedding.to_owned());
                    }
                }

                embedding_queue.lock().push(FileToEmbed {
                    worktree_id: pending_file.worktree_db_id,
                    path: pending_file.relative_path,
                    mtime: pending_file.modified_time,
                    job_handle: pending_file.job_handle,
                    spans,
                });
            }
        }
    }
 468
 469    pub fn project_previously_indexed(
 470        &mut self,
 471        project: &Model<Project>,
 472        cx: &mut ModelContext<Self>,
 473    ) -> Task<Result<bool>> {
 474        let worktrees_indexed_previously = project
 475            .read(cx)
 476            .worktrees()
 477            .map(|worktree| {
 478                self.db
 479                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
 480            })
 481            .collect::<Vec<_>>();
 482        cx.spawn(|_, _cx| async move {
 483            let worktree_indexed_previously =
 484                futures::future::join_all(worktrees_indexed_previously).await;
 485
 486            Ok(worktree_indexed_previously
 487                .iter()
 488                .filter(|worktree| worktree.is_ok())
 489                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
 490        })
 491    }
 492
    /// Handles filesystem change events for one worktree of a project:
    /// records the changed paths and, once the worktree is registered,
    /// schedules a delayed background re-index.
    fn project_entries_changed(
        &mut self,
        project: Model<Project>,
        worktree_id: WorktreeId,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        cx: &mut ModelContext<Self>,
    ) {
        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
            return;
        };
        let project = project.downgrade();
        let Some(project_state) = self.projects.get_mut(&project) else {
            return;
        };

        let worktree = worktree.read(cx);
        let worktree_state =
            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
                worktree_state
            } else {
                return;
            };
        worktree_state.paths_changed(changes, worktree);
        // While still registering, the changes are folded in when
        // registration finishes; once registered, debounce a re-index.
        if let WorktreeState::Registered(_) = worktree_state {
            cx.spawn(|this, mut cx| async move {
                cx.background_executor()
                    .timer(BACKGROUND_INDEXING_DELAY)
                    .await;
                if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
                    this.update(&mut cx, |this, cx| {
                        this.index_project(project, cx).detach_and_log_err(cx)
                    })?;
                }
                anyhow::Ok(())
            })
            .detach_and_log_err(cx);
        }
    }
 531
    /// Registers one local worktree: waits for its initial scan, diffs its
    /// files against stored mtimes in the database, records which paths need
    /// (re-)indexing, and finally kicks off `index_project`. The worktree is
    /// tracked as `Registering` until the spawned task completes.
    fn register_worktree(
        &mut self,
        project: Model<Project>,
        worktree: Model<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        // Only local worktrees can be indexed.
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        // `done_tx` wakes `RegisteringWorktreeState::done` futures once
        // registration finishes (successfully or not).
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade() {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok()
                            .flatten()
                            .context("worktree not found")?
                    } else {
                        // Project dropped before registration completed.
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
                    // On a background thread, compute the set of changed
                    // paths: files new or modified since the last index, plus
                    // database entries whose files no longer exist.
                    let mut changed_paths = cx
                        .background_executor()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path)?;

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Test if file is valid parseable file
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && &language.name().as_ref() != &"Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    // A file is up to date only if its mtime
                                    // matches what we stored last time.
                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .context("project not registered")?;
                        let project = project.upgrade().context("project was dropped")?;

                        // Carry over any changes reported while we were still
                        // registering, then flip the state to Registered.
                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })??;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        this.projects.get_mut(&project).map(|project_state| {
                            project_state.worktrees.remove(&worktree_id);
                        });
                    })
                    .ok();
                }

                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }
 679
 680    fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
 681        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
 682        {
 683            project_state
 684        } else {
 685            return;
 686        };
 687
 688        let mut worktrees = project
 689            .read(cx)
 690            .worktrees()
 691            .filter(|worktree| worktree.read(cx).is_local())
 692            .collect::<Vec<_>>();
 693        let worktree_ids = worktrees
 694            .iter()
 695            .map(|worktree| worktree.read(cx).id())
 696            .collect::<HashSet<_>>();
 697
 698        // Remove worktrees that are no longer present
 699        project_state
 700            .worktrees
 701            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
 702
 703        // Register new worktrees
 704        worktrees.retain(|worktree| {
 705            let worktree_id = worktree.read(cx).id();
 706            !project_state.worktrees.contains_key(&worktree_id)
 707        });
 708        for worktree in worktrees {
 709            self.register_worktree(project.clone(), worktree, cx);
 710        }
 711    }
 712
 713    pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
 714        Some(
 715            self.projects
 716                .get(&project.downgrade())?
 717                .pending_file_count_rx
 718                .clone(),
 719        )
 720    }
 721
    /// Performs a semantic search over `project`: ensures the index is up to
    /// date, embeds `query`, searches both modified open buffers and on-disk
    /// files concurrently, then merges, ranks, and truncates to `limit`.
    pub fn search_project(
        &mut self,
        project: Model<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            // Wait for indexing to finish before querying.
            index.await?;
            let t0 = Instant::now();

            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .context("could not embed query")?;
            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());

            let search_start = Instant::now();
            // Kick off both searches, then await them concurrently below.
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            })?;
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            })?;
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files,
            // preferring the buffer result when both cover the same buffer.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            // Highest similarity first.
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("Semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }
 784
 785    pub fn search_files(
 786        &mut self,
 787        project: Model<Project>,
 788        query: Embedding,
 789        limit: usize,
 790        includes: Vec<PathMatcher>,
 791        excludes: Vec<PathMatcher>,
 792        cx: &mut ModelContext<Self>,
 793    ) -> Task<Result<Vec<SearchResult>>> {
 794        let db_path = self.db.path().clone();
 795        let fs = self.fs.clone();
 796        cx.spawn(|this, mut cx| async move {
 797            let database = VectorDatabase::new(
 798                fs.clone(),
 799                db_path.clone(),
 800                cx.background_executor().clone(),
 801            )
 802            .await?;
 803
 804            let worktree_db_ids = this.read_with(&cx, |this, _| {
 805                let project_state = this
 806                    .projects
 807                    .get(&project.downgrade())
 808                    .context("project was not indexed")?;
 809                let worktree_db_ids = project_state
 810                    .worktrees
 811                    .values()
 812                    .filter_map(|worktree| {
 813                        if let WorktreeState::Registered(worktree) = worktree {
 814                            Some(worktree.db_id)
 815                        } else {
 816                            None
 817                        }
 818                    })
 819                    .collect::<Vec<i64>>();
 820                anyhow::Ok(worktree_db_ids)
 821            })??;
 822
 823            let file_ids = database
 824                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
 825                .await?;
 826
 827            let batch_n = cx.background_executor().num_cpus();
 828            let ids_len = file_ids.clone().len();
 829            let minimum_batch_size = 50;
 830
 831            let batch_size = {
 832                let size = ids_len / batch_n;
 833                if size < minimum_batch_size {
 834                    minimum_batch_size
 835                } else {
 836                    size
 837                }
 838            };
 839
 840            let mut batch_results = Vec::new();
 841            for batch in file_ids.chunks(batch_size) {
 842                let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
 843                let limit = limit.clone();
 844                let fs = fs.clone();
 845                let db_path = db_path.clone();
 846                let query = query.clone();
 847                if let Some(db) =
 848                    VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
 849                        .await
 850                        .log_err()
 851                {
 852                    batch_results.push(async move {
 853                        db.top_k_search(&query, limit, batch.as_slice()).await
 854                    });
 855                }
 856            }
 857
 858            let batch_results = futures::future::join_all(batch_results).await;
 859
 860            let mut results = Vec::new();
 861            for batch_result in batch_results {
 862                if batch_result.is_ok() {
 863                    for (id, similarity) in batch_result.unwrap() {
 864                        let ix = match results
 865                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
 866                        {
 867                            Ok(ix) => ix,
 868                            Err(ix) => ix,
 869                        };
 870
 871                        results.insert(ix, (id, similarity));
 872                        results.truncate(limit);
 873                    }
 874                }
 875            }
 876
 877            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
 878            let scores = results
 879                .into_iter()
 880                .map(|(_, score)| score)
 881                .collect::<Vec<_>>();
 882            let spans = database.spans_for_ids(ids.as_slice()).await?;
 883
 884            let mut tasks = Vec::new();
 885            let mut ranges = Vec::new();
 886            let weak_project = project.downgrade();
 887            project.update(&mut cx, |project, cx| {
 888                let this = this.upgrade().context("index was dropped")?;
 889                for (worktree_db_id, file_path, byte_range) in spans {
 890                    let project_state =
 891                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
 892                            state
 893                        } else {
 894                            return Err(anyhow!("project not added"));
 895                        };
 896                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
 897                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
 898                        ranges.push(byte_range);
 899                    }
 900                }
 901
 902                Ok(())
 903            })??;
 904
 905            let buffers = futures::future::join_all(tasks).await;
 906            Ok(buffers
 907                .into_iter()
 908                .zip(ranges)
 909                .zip(scores)
 910                .filter_map(|((buffer, range), similarity)| {
 911                    let buffer = buffer.log_err()?;
 912                    let range = buffer
 913                        .read_with(&cx, |buffer, _| {
 914                            let start = buffer.clip_offset(range.start, Bias::Left);
 915                            let end = buffer.clip_offset(range.end, Bias::Right);
 916                            buffer.anchor_before(start)..buffer.anchor_after(end)
 917                        })
 918                        .log_err()?;
 919                    Some(SearchResult {
 920                        buffer,
 921                        range,
 922                        similarity,
 923                    })
 924                })
 925                .collect())
 926        })
 927    }
 928
 929    fn search_modified_buffers(
 930        &self,
 931        project: &Model<Project>,
 932        query: Embedding,
 933        limit: usize,
 934        includes: &[PathMatcher],
 935        excludes: &[PathMatcher],
 936        cx: &mut ModelContext<Self>,
 937    ) -> Task<Result<Vec<SearchResult>>> {
 938        let modified_buffers = project
 939            .read(cx)
 940            .opened_buffers()
 941            .into_iter()
 942            .filter_map(|buffer_handle| {
 943                let buffer = buffer_handle.read(cx);
 944                let snapshot = buffer.snapshot();
 945                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 946                    excludes.iter().any(|matcher| matcher.is_match(&path))
 947                });
 948
 949                let included = if includes.len() == 0 {
 950                    true
 951                } else {
 952                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 953                        includes.iter().any(|matcher| matcher.is_match(&path))
 954                    })
 955                };
 956
 957                if buffer.is_dirty() && !excluded && included {
 958                    Some((buffer_handle, snapshot))
 959                } else {
 960                    None
 961                }
 962            })
 963            .collect::<HashMap<_, _>>();
 964
 965        let embedding_provider = self.embedding_provider.clone();
 966        let fs = self.fs.clone();
 967        let db_path = self.db.path().clone();
 968        let background = cx.background_executor().clone();
 969        cx.background_executor().spawn(async move {
 970            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
 971            let mut results = Vec::<SearchResult>::new();
 972
 973            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
 974            for (buffer, snapshot) in modified_buffers {
 975                let language = snapshot
 976                    .language_at(0)
 977                    .cloned()
 978                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
 979                let mut spans = retriever
 980                    .parse_file_with_template(None, &snapshot.text(), language)
 981                    .log_err()
 982                    .unwrap_or_default();
 983                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
 984                    .await
 985                    .log_err()
 986                    .is_some()
 987                {
 988                    for span in spans {
 989                        let similarity = span.embedding.unwrap().similarity(&query);
 990                        let ix = match results
 991                            .binary_search_by_key(&Reverse(similarity), |result| {
 992                                Reverse(result.similarity)
 993                            }) {
 994                            Ok(ix) => ix,
 995                            Err(ix) => ix,
 996                        };
 997
 998                        let range = {
 999                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
1000                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
1001                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
1002                        };
1003
1004                        results.insert(
1005                            ix,
1006                            SearchResult {
1007                                buffer: buffer.clone(),
1008                                range,
1009                                similarity,
1010                            },
1011                        );
1012                        results.truncate(limit);
1013                    }
1014                }
1015            }
1016
1017            Ok(results)
1018        })
1019    }
1020
1021    pub fn index_project(
1022        &mut self,
1023        project: Model<Project>,
1024        cx: &mut ModelContext<Self>,
1025    ) -> Task<Result<()>> {
1026        if self.is_authenticated() {
1027            self.index_project_internal(project, cx)
1028        } else {
1029            let authenticate = self.authenticate(cx);
1030            cx.spawn(|this, mut cx| async move {
1031                if authenticate.await {
1032                    this.update(&mut cx, |this, cx| this.index_project_internal(project, cx))?
1033                        .await
1034                } else {
1035                    Err(anyhow!("user is not authenticated"))
1036                }
1037            })
1038        }
1039    }
1040
1041    fn index_project_internal(
1042        &mut self,
1043        project: Model<Project>,
1044        cx: &mut ModelContext<Self>,
1045    ) -> Task<Result<()>> {
1046        if !self.projects.contains_key(&project.downgrade()) {
1047            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
1048                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
1049                    this.project_worktrees_changed(project.clone(), cx);
1050                }
1051                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
1052                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
1053                }
1054                _ => {}
1055            });
1056            let project_state = ProjectState::new(subscription, cx);
1057            self.projects.insert(project.downgrade(), project_state);
1058            self.project_worktrees_changed(project.clone(), cx);
1059        }
1060        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
1061        project_state.pending_index += 1;
1062        cx.notify();
1063
1064        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
1065        let db = self.db.clone();
1066        let language_registry = self.language_registry.clone();
1067        let parsing_files_tx = self.parsing_files_tx.clone();
1068        let worktree_registration = self.wait_for_worktree_registration(&project, cx);
1069
1070        cx.spawn(|this, mut cx| async move {
1071            worktree_registration.await?;
1072
1073            let mut pending_files = Vec::new();
1074            let mut files_to_delete = Vec::new();
1075            this.update(&mut cx, |this, cx| {
1076                let project_state = this
1077                    .projects
1078                    .get_mut(&project.downgrade())
1079                    .context("project was dropped")?;
1080                let pending_file_count_tx = &project_state.pending_file_count_tx;
1081
1082                project_state
1083                    .worktrees
1084                    .retain(|worktree_id, worktree_state| {
1085                        let worktree = if let Some(worktree) =
1086                            project.read(cx).worktree_for_id(*worktree_id, cx)
1087                        {
1088                            worktree
1089                        } else {
1090                            return false;
1091                        };
1092                        let worktree_state =
1093                            if let WorktreeState::Registered(worktree_state) = worktree_state {
1094                                worktree_state
1095                            } else {
1096                                return true;
1097                            };
1098
1099                        for (path, info) in &worktree_state.changed_paths {
1100                            if info.is_deleted {
1101                                files_to_delete.push((worktree_state.db_id, path.clone()));
1102                            } else if let Ok(absolute_path) = worktree.read(cx).absolutize(path) {
1103                                let job_handle = JobHandle::new(pending_file_count_tx);
1104                                pending_files.push(PendingFile {
1105                                    absolute_path,
1106                                    relative_path: path.clone(),
1107                                    language: None,
1108                                    job_handle,
1109                                    modified_time: info.mtime,
1110                                    worktree_db_id: worktree_state.db_id,
1111                                });
1112                            }
1113                        }
1114                        worktree_state.changed_paths.clear();
1115                        true
1116                    });
1117
1118                anyhow::Ok(())
1119            })??;
1120
1121            cx.background_executor()
1122                .spawn(async move {
1123                    for (worktree_db_id, path) in files_to_delete {
1124                        db.delete_file(worktree_db_id, path).await.log_err();
1125                    }
1126
1127                    let embeddings_for_digest = {
1128                        let mut files = HashMap::default();
1129                        for pending_file in &pending_files {
1130                            files
1131                                .entry(pending_file.worktree_db_id)
1132                                .or_insert(Vec::new())
1133                                .push(pending_file.relative_path.clone());
1134                        }
1135                        Arc::new(
1136                            db.embeddings_for_files(files)
1137                                .await
1138                                .log_err()
1139                                .unwrap_or_default(),
1140                        )
1141                    };
1142
1143                    for mut pending_file in pending_files {
1144                        if let Ok(language) = language_registry
1145                            .language_for_file(&pending_file.relative_path, None)
1146                            .await
1147                        {
1148                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
1149                                && &language.name().as_ref() != &"Markdown"
1150                                && language
1151                                    .grammar()
1152                                    .and_then(|grammar| grammar.embedding_config.as_ref())
1153                                    .is_none()
1154                            {
1155                                continue;
1156                            }
1157                            pending_file.language = Some(language);
1158                        }
1159                        parsing_files_tx
1160                            .try_send((embeddings_for_digest.clone(), pending_file))
1161                            .ok();
1162                    }
1163
1164                    // Wait until we're done indexing.
1165                    while let Some(count) = pending_file_count_rx.next().await {
1166                        if count == 0 {
1167                            break;
1168                        }
1169                    }
1170                })
1171                .await;
1172
1173            this.update(&mut cx, |this, cx| {
1174                let project_state = this
1175                    .projects
1176                    .get_mut(&project.downgrade())
1177                    .context("project was dropped")?;
1178                project_state.pending_index -= 1;
1179                cx.notify();
1180                anyhow::Ok(())
1181            })??;
1182
1183            Ok(())
1184        })
1185    }
1186
1187    fn wait_for_worktree_registration(
1188        &self,
1189        project: &Model<Project>,
1190        cx: &mut ModelContext<Self>,
1191    ) -> Task<Result<()>> {
1192        let project = project.downgrade();
1193        cx.spawn(|this, cx| async move {
1194            loop {
1195                let mut pending_worktrees = Vec::new();
1196                this.upgrade()
1197                    .context("semantic index dropped")?
1198                    .read_with(&cx, |this, _| {
1199                        if let Some(project) = this.projects.get(&project) {
1200                            for worktree in project.worktrees.values() {
1201                                if let WorktreeState::Registering(worktree) = worktree {
1202                                    pending_worktrees.push(worktree.done());
1203                                }
1204                            }
1205                        }
1206                    })?;
1207
1208                if pending_worktrees.is_empty() {
1209                    break;
1210                } else {
1211                    future::join_all(pending_worktrees).await;
1212                }
1213            }
1214            Ok(())
1215        })
1216    }
1217
1218    async fn embed_spans(
1219        spans: &mut [Span],
1220        embedding_provider: &dyn EmbeddingProvider,
1221        db: &VectorDatabase,
1222    ) -> Result<()> {
1223        let mut batch = Vec::new();
1224        let mut batch_tokens = 0;
1225        let mut embeddings = Vec::new();
1226
1227        let digests = spans
1228            .iter()
1229            .map(|span| span.digest.clone())
1230            .collect::<Vec<_>>();
1231        let embeddings_for_digests = db
1232            .embeddings_for_digests(digests)
1233            .await
1234            .log_err()
1235            .unwrap_or_default();
1236
1237        for span in &*spans {
1238            if embeddings_for_digests.contains_key(&span.digest) {
1239                continue;
1240            };
1241
1242            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
1243                let batch_embeddings = embedding_provider
1244                    .embed_batch(mem::take(&mut batch))
1245                    .await?;
1246                embeddings.extend(batch_embeddings);
1247                batch_tokens = 0;
1248            }
1249
1250            batch_tokens += span.token_count;
1251            batch.push(span.content.clone());
1252        }
1253
1254        if !batch.is_empty() {
1255            let batch_embeddings = embedding_provider
1256                .embed_batch(mem::take(&mut batch))
1257                .await?;
1258
1259            embeddings.extend(batch_embeddings);
1260        }
1261
1262        let mut embeddings = embeddings.into_iter();
1263        for span in spans {
1264            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
1265                Some(embedding.clone())
1266            } else {
1267                embeddings.next()
1268            };
1269            let embedding = embedding.context("failed to embed spans")?;
1270            span.embedding = Some(embedding);
1271        }
1272        Ok(())
1273    }
1274}
1275
impl Drop for JobHandle {
    // Decrements the pending-file counter exactly once per logical job,
    // when the final clone of the handle is dropped.
    fn drop(&mut self) {
        // `Arc::get_mut` only returns `Some` when this is the sole strong
        // reference to the inner `Weak`, i.e. every other clone of the
        // handle has already been dropped.
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not)
            // The upgrade fails if the counter's owner has already been
            // dropped; in that case there is nothing left to decrement.
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}
1287
#[cfg(test)]
mod tests {

    use super::*;

    /// Verifies that the pending-job counter behaves like a refcount: it is
    /// incremented when a `JobHandle` is created and decremented only when
    /// the last clone of that handle is dropped.
    #[test]
    fn test_job_handle() {
        let (count_tx, count_rx) = watch::channel_with(0);
        let count_tx = Arc::new(Mutex::new(count_tx));

        // Creating a handle increments the counter...
        let first_handle = JobHandle::new(&count_tx);
        assert_eq!(1, *count_rx.borrow());

        // ...but cloning it does not.
        let second_handle = first_handle.clone();
        assert_eq!(1, *count_rx.borrow());

        // Dropping one of two clones leaves the counter untouched; dropping
        // the last clone decrements it back to zero.
        drop(first_handle);
        assert_eq!(1, *count_rx.borrow());
        drop(second_handle);
        assert_eq!(0, *count_rx.borrow());
    }
}