semantic_index.rs

   1mod db;
   2mod embedding_queue;
   3mod parsing;
   4pub mod semantic_index_settings;
   5
   6#[cfg(test)]
   7mod semantic_index_tests;
   8
   9use crate::semantic_index_settings::SemanticIndexSettings;
  10use ai::embedding::{Embedding, EmbeddingProvider};
  11use ai::providers::open_ai::{OpenAiEmbeddingProvider, OPEN_AI_API_URL};
  12use anyhow::{anyhow, Context as _, Result};
  13use collections::{BTreeMap, HashMap, HashSet};
  14use db::VectorDatabase;
  15use embedding_queue::{EmbeddingQueue, FileToEmbed};
  16use futures::{future, FutureExt, StreamExt};
  17use gpui::{
  18    AppContext, AsyncAppContext, BorrowWindow, Context, Global, Model, ModelContext, Task,
  19    ViewContext, WeakModel,
  20};
  21use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
  22use lazy_static::lazy_static;
  23use ordered_float::OrderedFloat;
  24use parking_lot::Mutex;
  25use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
  26use postage::watch;
  27use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
  28use release_channel::ReleaseChannel;
  29use settings::Settings;
  30use smol::channel;
  31use std::{
  32    cmp::Reverse,
  33    env,
  34    future::Future,
  35    mem,
  36    ops::Range,
  37    path::{Path, PathBuf},
  38    sync::{Arc, Weak},
  39    time::{Duration, Instant, SystemTime},
  40};
  41use util::paths::PathMatcher;
  42use util::{http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
  43use workspace::Workspace;
  44
// Version tag for the on-disk index; presumably bumped when the stored format
// changes so stale databases get rebuilt — TODO confirm against the db module.
const SEMANTIC_INDEX_VERSION: usize = 11;
// Delay between a file change notification and the background re-index pass.
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
// Idle timeout after which parser tasks flush the embedding queue.
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
  48
lazy_static! {
    // OpenAI API key read from the environment once on first access;
    // `None` when the variable is unset.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}
  52
/// Initializes the semantic index subsystem: registers its settings, installs
/// a workspace observer that re-indexes previously indexed local projects on
/// open, and asynchronously constructs the global [`SemanticIndex`].
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    SemanticIndexSettings::register(cx);

    // Keep a separate embeddings database per release channel (dev/stable/etc.).
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(ReleaseChannel::global(cx).dev_name()))
        .join("embeddings_db");

    cx.observe_new_views(
        |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
            // The global index may not exist yet (it is created asynchronously below).
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let project = workspace.project().clone();

            // Only local projects are indexed; remote ones are skipped.
            if project.read(cx).is_local() {
                cx.app_mut()
                    .spawn(|mut cx| async move {
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })?
                            .await?;
                        // Re-index automatically only if the user indexed this
                        // project before; never start a fresh index unprompted.
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))?
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
            }
        },
    )
    .detach();

    cx.spawn(move |cx| async move {
        let embedding_provider = OpenAiEmbeddingProvider::new(
            // TODO: We should read it from config, but I'm not sure whether to reuse `openai_api_url` in assistant settings or not
            OPEN_AI_API_URL.to_string(),
            http_client,
            cx.background_executor().clone(),
        )
        .await;
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(embedding_provider),
            language_registry,
            cx.clone(),
        )
        .await?;

        // Publish the index as a GPUI global so `SemanticIndex::global` works.
        cx.update(|cx| cx.set_global(GlobalSemanticIndex(semantic_index.clone())))?;

        anyhow::Ok(())
    })
    .detach();
}
 116
/// High-level indexing state of a project, as reported by [`SemanticIndex::status`].
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    /// The embedding provider has no credentials yet.
    NotAuthenticated,
    /// The project is not tracked by the index.
    NotIndexed,
    /// All worktrees are registered and no files are pending.
    Indexed,
    /// Indexing is in progress.
    Indexing {
        /// Number of files still waiting to be embedded.
        remaining_files: usize,
        /// When the provider's rate limit lifts, if currently throttled.
        rate_limit_expiry: Option<Instant>,
    },
}
 127
/// Coordinates parsing, embedding, and storage of project source code for
/// semantic search. Owns the vector database connection and the background
/// tasks that feed it.
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Sends files (plus any cached embeddings keyed by span digest) to the
    // parser tasks spawned in `new`.
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    // Background task persisting embedded files to the database; kept alive by ownership.
    _embedding_task: Task<()>,
    // One parser task per CPU; kept alive by ownership.
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModel<Project>, ProjectState>,
}
 138
/// Newtype storing the app-wide [`SemanticIndex`] handle in GPUI global state.
struct GlobalSemanticIndex(Model<SemanticIndex>);

impl Global for GlobalSemanticIndex {}
 142
/// Per-project bookkeeping held by [`SemanticIndex::projects`].
struct ProjectState {
    // State of each worktree, keyed by worktree id.
    worktrees: HashMap<WorktreeId, WorktreeState>,
    // Watch pair tracking the number of files still awaiting embedding.
    pending_file_count_rx: watch::Receiver<usize>,
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    // Count of in-flight index passes — presumably incremented/decremented by
    // `index_project` (not visible in this chunk); TODO confirm.
    pending_index: usize,
    // Keeps the project event subscription alive.
    _subscription: gpui::Subscription,
    // Task that calls `cx.notify` whenever the pending-file count changes.
    _observe_pending_file_count: Task<()>,
}
 151
/// Lifecycle of a worktree within the index: first `Registering` (initial
/// scan/diff in flight), then `Registered` once it has a database id.
enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}
 156
 157impl WorktreeState {
 158    fn is_registered(&self) -> bool {
 159        matches!(self, Self::Registered(_))
 160    }
 161
 162    fn paths_changed(
 163        &mut self,
 164        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 165        worktree: &Worktree,
 166    ) {
 167        let changed_paths = match self {
 168            Self::Registering(state) => &mut state.changed_paths,
 169            Self::Registered(state) => &mut state.changed_paths,
 170        };
 171
 172        for (path, entry_id, change) in changes.iter() {
 173            let Some(entry) = worktree.entry_for_id(*entry_id) else {
 174                continue;
 175            };
 176            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
 177                continue;
 178            }
 179            changed_paths.insert(
 180                path.clone(),
 181                ChangedPathInfo {
 182                    mtime: entry.mtime,
 183                    is_deleted: *change == PathChange::Removed,
 184                },
 185            );
 186        }
 187    }
 188}
 189
/// State for a worktree whose initial registration scan is still running.
struct RegisteringWorktreeState {
    // Changes reported while registering; merged into the registered state on completion.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    // Receives `Some(())` when registration finishes (see `done`).
    done_rx: watch::Receiver<Option<()>>,
    // Keeps the registration task alive.
    _registration: Task<()>,
}
 195
 196impl RegisteringWorktreeState {
 197    fn done(&self) -> impl Future<Output = ()> {
 198        let mut done_rx = self.done_rx.clone();
 199        async move {
 200            while let Some(result) = done_rx.next().await {
 201                if result.is_some() {
 202                    break;
 203                }
 204            }
 205        }
 206    }
 207}
 208
/// State for a worktree that has a row in the vector database.
struct RegisteredWorktreeState {
    // Database id of the worktree row.
    db_id: i64,
    // Paths changed since the last index pass, awaiting re-embedding or deletion.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}
 213
/// Metadata recorded for a changed path pending (re-)indexing.
struct ChangedPathInfo {
    // Modification time at the moment the change was observed.
    mtime: SystemTime,
    // True when the path was removed and should be deleted from the index.
    is_deleted: bool,
}
 218
/// Cloneable handle tied to one pending embedding job; used to track how many
/// files are still in flight for a project.
#[derive(Clone)]
pub struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}
 225
impl JobHandle {
    /// Creates a handle for one pending job, incrementing the project's
    /// pending-file counter. Per the docs on `tx`, the matching decrement
    /// happens when the last clone is dropped (Drop impl not visible in
    /// this chunk).
    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
        // Bump the pending count exactly once per job, not per clone.
        *tx.lock().borrow_mut() += 1;
        Self {
            // Weak reference: the handle must not keep the counter alive
            // after its project state is gone.
            tx: Arc::new(Arc::downgrade(&tx)),
        }
    }
}
 234
 235impl ProjectState {
 236    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
 237        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
 238        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
 239        Self {
 240            worktrees: Default::default(),
 241            pending_file_count_rx: pending_file_count_rx.clone(),
 242            pending_file_count_tx,
 243            pending_index: 0,
 244            _subscription: subscription,
 245            _observe_pending_file_count: cx.spawn({
 246                let mut pending_file_count_rx = pending_file_count_rx.clone();
 247                |this, mut cx| async move {
 248                    while let Some(_) = pending_file_count_rx.next().await {
 249                        if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
 250                            break;
 251                        }
 252                    }
 253                }
 254            }),
 255        }
 256    }
 257
 258    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
 259        self.worktrees
 260            .iter()
 261            .find_map(|(worktree_id, worktree_state)| match worktree_state {
 262                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
 263                _ => None,
 264            })
 265    }
 266}
 267
/// A file queued for parsing and embedding.
#[derive(Clone)]
pub struct PendingFile {
    // Database id of the worktree the file belongs to.
    worktree_db_id: i64,
    // Path relative to the worktree root.
    relative_path: Arc<Path>,
    // Absolute path used to load the file contents from disk.
    absolute_path: PathBuf,
    // Detected language; files with `None` are skipped by `parse_file`.
    language: Option<Arc<Language>>,
    // Modification time recorded when the file was queued.
    modified_time: SystemTime,
    // Tracks this job in the project's pending-file counter.
    job_handle: JobHandle,
}
 277
/// One semantic search hit: a span of a buffer and its similarity score.
#[derive(Clone)]
pub struct SearchResult {
    pub buffer: Model<Buffer>,
    pub range: Range<Anchor>,
    // Cosine-style similarity to the query — higher is a better match;
    // exact metric is defined by the embedding provider (not visible here).
    pub similarity: OrderedFloat<f32>,
}
 284
 285impl SemanticIndex {
 286    pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
 287        cx.try_global::<GlobalSemanticIndex>()
 288            .map(|semantic_index| semantic_index.0.clone())
 289    }
 290
 291    pub fn authenticate(&mut self, cx: &mut AppContext) -> Task<bool> {
 292        if !self.embedding_provider.has_credentials() {
 293            let embedding_provider = self.embedding_provider.clone();
 294            cx.spawn(|cx| async move {
 295                if let Some(retrieve_credentials) = cx
 296                    .update(|cx| embedding_provider.retrieve_credentials(cx))
 297                    .log_err()
 298                {
 299                    retrieve_credentials.await;
 300                }
 301
 302                embedding_provider.has_credentials()
 303            })
 304        } else {
 305            Task::ready(true)
 306        }
 307    }
 308
    /// Whether the embedding provider currently has credentials.
    pub fn is_authenticated(&self) -> bool {
        self.embedding_provider.has_credentials()
    }
 312
    /// Whether semantic indexing is enabled in the user's settings.
    pub fn enabled(cx: &AppContext) -> bool {
        SemanticIndexSettings::get_global(cx).enabled
    }
 316
 317    pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
 318        if !self.is_authenticated() {
 319            return SemanticIndexStatus::NotAuthenticated;
 320        }
 321
 322        if let Some(project_state) = self.projects.get(&project.downgrade()) {
 323            if project_state
 324                .worktrees
 325                .values()
 326                .all(|worktree| worktree.is_registered())
 327                && project_state.pending_index == 0
 328            {
 329                SemanticIndexStatus::Indexed
 330            } else {
 331                SemanticIndexStatus::Indexing {
 332                    remaining_files: *project_state.pending_file_count_rx.borrow(),
 333                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
 334                }
 335            }
 336        } else {
 337            SemanticIndexStatus::NotIndexed
 338        }
 339    }
 340
    /// Opens (or creates) the vector database at `database_path` and builds a
    /// `SemanticIndex` model with its background machinery:
    /// one task persisting finished embeddings to the database, and one
    /// parser task per CPU feeding the embedding queue.
    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<Model<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
            .await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        cx.new_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
            // Drain the queue's output: write each embedded file to the database.
            let _embedding_task = cx.background_executor().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            // One parser task per CPU; all share the same receiver, so files
            // are load-balanced across tasks.
            for _ in 0..cx.background_executor().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background_executor().clone();
                _parsing_files_tasks.push(cx.background_executor().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        // Prefer parsing incoming files; if none arrive before
                        // the timeout, flush whatever is queued so partial
                        // batches still get embedded.
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    // Channel closed: the index was dropped.
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        })
    }
 429
 430    async fn parse_file(
 431        fs: &Arc<dyn Fs>,
 432        pending_file: PendingFile,
 433        retriever: &mut CodeContextRetriever,
 434        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
 435        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
 436    ) {
 437        let Some(language) = pending_file.language else {
 438            return;
 439        };
 440
 441        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
 442            if let Some(mut spans) = retriever
 443                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
 444                .log_err()
 445            {
 446                log::trace!(
 447                    "parsed path {:?}: {} spans",
 448                    pending_file.relative_path,
 449                    spans.len()
 450                );
 451
 452                for span in &mut spans {
 453                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
 454                        span.embedding = Some(embedding.to_owned());
 455                    }
 456                }
 457
 458                embedding_queue.lock().push(FileToEmbed {
 459                    worktree_id: pending_file.worktree_db_id,
 460                    path: pending_file.relative_path,
 461                    mtime: pending_file.modified_time,
 462                    job_handle: pending_file.job_handle,
 463                    spans,
 464                });
 465            }
 466        }
 467    }
 468
 469    pub fn project_previously_indexed(
 470        &mut self,
 471        project: &Model<Project>,
 472        cx: &mut ModelContext<Self>,
 473    ) -> Task<Result<bool>> {
 474        let worktrees_indexed_previously = project
 475            .read(cx)
 476            .worktrees()
 477            .map(|worktree| {
 478                self.db
 479                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
 480            })
 481            .collect::<Vec<_>>();
 482        cx.spawn(|_, _cx| async move {
 483            let worktree_indexed_previously =
 484                futures::future::join_all(worktrees_indexed_previously).await;
 485
 486            Ok(worktree_indexed_previously
 487                .iter()
 488                .filter(|worktree| worktree.is_ok())
 489                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
 490        })
 491    }
 492
 493    fn project_entries_changed(
 494        &mut self,
 495        project: Model<Project>,
 496        worktree_id: WorktreeId,
 497        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 498        cx: &mut ModelContext<Self>,
 499    ) {
 500        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) else {
 501            return;
 502        };
 503        let project = project.downgrade();
 504        let Some(project_state) = self.projects.get_mut(&project) else {
 505            return;
 506        };
 507
 508        let worktree = worktree.read(cx);
 509        let worktree_state =
 510            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
 511                worktree_state
 512            } else {
 513                return;
 514            };
 515        worktree_state.paths_changed(changes, worktree);
 516        if let WorktreeState::Registered(_) = worktree_state {
 517            cx.spawn(|this, mut cx| async move {
 518                cx.background_executor()
 519                    .timer(BACKGROUND_INDEXING_DELAY)
 520                    .await;
 521                if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
 522                    this.update(&mut cx, |this, cx| {
 523                        this.index_project(project, cx).detach_and_log_err(cx)
 524                    })?;
 525                }
 526                anyhow::Ok(())
 527            })
 528            .detach_and_log_err(cx);
 529        }
 530    }
 531
    /// Begins tracking a local worktree for `project`: waits for its initial
    /// scan, diffs its files against the mtimes stored in the database, then
    /// transitions it from `Registering` to `Registered` and kicks off an
    /// index pass. On failure the worktree is untracked again.
    fn register_worktree(
        &mut self,
        project: Model<Project>,
        worktree: Model<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        // Only local worktrees are indexed.
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    // Wait for the worktree's initial filesystem scan to finish.
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade() {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok()
                            .flatten()
                            .context("worktree not found")?
                    } else {
                        // Project was dropped before registration finished.
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
                    // Diff worktree contents against stored mtimes on a
                    // background thread; entries left in `file_mtimes`
                    // afterwards no longer exist and are marked deleted.
                    let mut changed_paths = cx
                        .background_executor()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path)?;

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Test if file is valid parseable file
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && &language.name().as_ref() != &"Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    // Unchanged mtime means the stored embedding is still valid.
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .context("project not registered")?;
                        let project = project.upgrade().context("project was dropped")?;

                        // Merge in any changes reported while we were registering.
                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })??;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        if let Some(project_state) = this.projects.get_mut(&project) {
                            project_state.worktrees.remove(&worktree_id);
                        }
                    })
                    .ok();
                }

                // Signal `RegisteringWorktreeState::done` waiters, success or not.
                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }
 679
 680    fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
 681        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
 682        {
 683            project_state
 684        } else {
 685            return;
 686        };
 687
 688        let mut worktrees = project
 689            .read(cx)
 690            .worktrees()
 691            .filter(|worktree| worktree.read(cx).is_local())
 692            .collect::<Vec<_>>();
 693        let worktree_ids = worktrees
 694            .iter()
 695            .map(|worktree| worktree.read(cx).id())
 696            .collect::<HashSet<_>>();
 697
 698        // Remove worktrees that are no longer present
 699        project_state
 700            .worktrees
 701            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
 702
 703        // Register new worktrees
 704        worktrees.retain(|worktree| {
 705            let worktree_id = worktree.read(cx).id();
 706            !project_state.worktrees.contains_key(&worktree_id)
 707        });
 708        for worktree in worktrees {
 709            self.register_worktree(project.clone(), worktree, cx);
 710        }
 711    }
 712
 713    pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
 714        Some(
 715            self.projects
 716                .get(&project.downgrade())?
 717                .pending_file_count_rx
 718                .clone(),
 719        )
 720    }
 721
    /// Semantically searches the project for `query`, returning up to `limit`
    /// results ordered by descending similarity. Ensures the index is up to
    /// date first, then searches open modified buffers and on-disk files
    /// concurrently, preferring modified-buffer results when both cover the
    /// same buffer. An empty query returns no results immediately.
    pub fn search_project(
        &mut self,
        project: Model<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            // Wait for indexing so the search runs against fresh embeddings.
            index.await?;
            let t0 = Instant::now();

            // Embed the query so it can be compared against span embeddings.
            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .context("could not embed query")?;
            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());

            let search_start = Instant::now();
            // Kick off both searches so they run concurrently below.
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            })?;
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            })?;
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            // On-disk results are stale for buffers with unsaved edits; skip those.
            for result in file_results.log_err().unwrap_or_default() {
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("Semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }
 784
 785    pub fn search_files(
 786        &mut self,
 787        project: Model<Project>,
 788        query: Embedding,
 789        limit: usize,
 790        includes: Vec<PathMatcher>,
 791        excludes: Vec<PathMatcher>,
 792        cx: &mut ModelContext<Self>,
 793    ) -> Task<Result<Vec<SearchResult>>> {
 794        let db_path = self.db.path().clone();
 795        let fs = self.fs.clone();
 796        cx.spawn(|this, mut cx| async move {
 797            let database = VectorDatabase::new(
 798                fs.clone(),
 799                db_path.clone(),
 800                cx.background_executor().clone(),
 801            )
 802            .await?;
 803
 804            let worktree_db_ids = this.read_with(&cx, |this, _| {
 805                let project_state = this
 806                    .projects
 807                    .get(&project.downgrade())
 808                    .context("project was not indexed")?;
 809                let worktree_db_ids = project_state
 810                    .worktrees
 811                    .values()
 812                    .filter_map(|worktree| {
 813                        if let WorktreeState::Registered(worktree) = worktree {
 814                            Some(worktree.db_id)
 815                        } else {
 816                            None
 817                        }
 818                    })
 819                    .collect::<Vec<i64>>();
 820                anyhow::Ok(worktree_db_ids)
 821            })??;
 822
 823            let file_ids = database
 824                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
 825                .await?;
 826
 827            let batch_n = cx.background_executor().num_cpus();
 828            let ids_len = file_ids.clone().len();
 829            let minimum_batch_size = 50;
 830
 831            let batch_size = {
 832                let size = ids_len / batch_n;
 833                if size < minimum_batch_size {
 834                    minimum_batch_size
 835                } else {
 836                    size
 837                }
 838            };
 839
 840            let mut batch_results = Vec::new();
 841            for batch in file_ids.chunks(batch_size) {
 842                let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
 843                let fs = fs.clone();
 844                let db_path = db_path.clone();
 845                let query = query.clone();
 846                if let Some(db) =
 847                    VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
 848                        .await
 849                        .log_err()
 850                {
 851                    batch_results.push(async move {
 852                        db.top_k_search(&query, limit, batch.as_slice()).await
 853                    });
 854                }
 855            }
 856
 857            let batch_results = futures::future::join_all(batch_results).await;
 858
 859            let mut results = Vec::new();
 860            for batch_result in batch_results {
 861                if batch_result.is_ok() {
 862                    for (id, similarity) in batch_result.unwrap() {
 863                        let ix = match results
 864                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
 865                        {
 866                            Ok(ix) => ix,
 867                            Err(ix) => ix,
 868                        };
 869
 870                        results.insert(ix, (id, similarity));
 871                        results.truncate(limit);
 872                    }
 873                }
 874            }
 875
 876            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
 877            let scores = results
 878                .into_iter()
 879                .map(|(_, score)| score)
 880                .collect::<Vec<_>>();
 881            let spans = database.spans_for_ids(ids.as_slice()).await?;
 882
 883            let mut tasks = Vec::new();
 884            let mut ranges = Vec::new();
 885            let weak_project = project.downgrade();
 886            project.update(&mut cx, |project, cx| {
 887                let this = this.upgrade().context("index was dropped")?;
 888                for (worktree_db_id, file_path, byte_range) in spans {
 889                    let project_state =
 890                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
 891                            state
 892                        } else {
 893                            return Err(anyhow!("project not added"));
 894                        };
 895                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
 896                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
 897                        ranges.push(byte_range);
 898                    }
 899                }
 900
 901                Ok(())
 902            })??;
 903
 904            let buffers = futures::future::join_all(tasks).await;
 905            Ok(buffers
 906                .into_iter()
 907                .zip(ranges)
 908                .zip(scores)
 909                .filter_map(|((buffer, range), similarity)| {
 910                    let buffer = buffer.log_err()?;
 911                    let range = buffer
 912                        .read_with(&cx, |buffer, _| {
 913                            let start = buffer.clip_offset(range.start, Bias::Left);
 914                            let end = buffer.clip_offset(range.end, Bias::Right);
 915                            buffer.anchor_before(start)..buffer.anchor_after(end)
 916                        })
 917                        .log_err()?;
 918                    Some(SearchResult {
 919                        buffer,
 920                        range,
 921                        similarity,
 922                    })
 923                })
 924                .collect())
 925        })
 926    }
 927
 928    fn search_modified_buffers(
 929        &self,
 930        project: &Model<Project>,
 931        query: Embedding,
 932        limit: usize,
 933        includes: &[PathMatcher],
 934        excludes: &[PathMatcher],
 935        cx: &mut ModelContext<Self>,
 936    ) -> Task<Result<Vec<SearchResult>>> {
 937        let modified_buffers = project
 938            .read(cx)
 939            .opened_buffers()
 940            .into_iter()
 941            .filter_map(|buffer_handle| {
 942                let buffer = buffer_handle.read(cx);
 943                let snapshot = buffer.snapshot();
 944                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 945                    excludes.iter().any(|matcher| matcher.is_match(&path))
 946                });
 947
 948                let included = if includes.len() == 0 {
 949                    true
 950                } else {
 951                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 952                        includes.iter().any(|matcher| matcher.is_match(&path))
 953                    })
 954                };
 955
 956                if buffer.is_dirty() && !excluded && included {
 957                    Some((buffer_handle, snapshot))
 958                } else {
 959                    None
 960                }
 961            })
 962            .collect::<HashMap<_, _>>();
 963
 964        let embedding_provider = self.embedding_provider.clone();
 965        let fs = self.fs.clone();
 966        let db_path = self.db.path().clone();
 967        let background = cx.background_executor().clone();
 968        cx.background_executor().spawn(async move {
 969            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
 970            let mut results = Vec::<SearchResult>::new();
 971
 972            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
 973            for (buffer, snapshot) in modified_buffers {
 974                let language = snapshot
 975                    .language_at(0)
 976                    .cloned()
 977                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
 978                let mut spans = retriever
 979                    .parse_file_with_template(None, &snapshot.text(), language)
 980                    .log_err()
 981                    .unwrap_or_default();
 982                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
 983                    .await
 984                    .log_err()
 985                    .is_some()
 986                {
 987                    for span in spans {
 988                        let similarity = span.embedding.unwrap().similarity(&query);
 989                        let ix = match results
 990                            .binary_search_by_key(&Reverse(similarity), |result| {
 991                                Reverse(result.similarity)
 992                            }) {
 993                            Ok(ix) => ix,
 994                            Err(ix) => ix,
 995                        };
 996
 997                        let range = {
 998                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
 999                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
1000                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
1001                        };
1002
1003                        results.insert(
1004                            ix,
1005                            SearchResult {
1006                                buffer: buffer.clone(),
1007                                range,
1008                                similarity,
1009                            },
1010                        );
1011                        results.truncate(limit);
1012                    }
1013                }
1014            }
1015
1016            Ok(results)
1017        })
1018    }
1019
1020    pub fn index_project(
1021        &mut self,
1022        project: Model<Project>,
1023        cx: &mut ModelContext<Self>,
1024    ) -> Task<Result<()>> {
1025        if self.is_authenticated() {
1026            self.index_project_internal(project, cx)
1027        } else {
1028            let authenticate = self.authenticate(cx);
1029            cx.spawn(|this, mut cx| async move {
1030                if authenticate.await {
1031                    this.update(&mut cx, |this, cx| this.index_project_internal(project, cx))?
1032                        .await
1033                } else {
1034                    Err(anyhow!("user is not authenticated"))
1035                }
1036            })
1037        }
1038    }
1039
    /// Registers `project` with the index (on first call) and runs one
    /// indexing pass: deleted files are removed from the database and
    /// changed files are queued for parsing and embedding.
    ///
    /// Assumes the embedding provider is already authenticated — callers go
    /// through [`Self::index_project`], which handles authentication. The
    /// returned task resolves once the pending-file counter for this project
    /// drains back to zero.
    fn index_project_internal(
        &mut self,
        project: Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        // First-time registration: subscribe to worktree changes so the
        // index stays current as the project evolves, and seed the initial
        // worktree state.
        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        // The entry was just inserted above if missing, so unwrap is safe.
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            // Don't scan until every worktree has finished registering with
            // the database.
            worktree_registration.await?;

            // Drain each registered worktree's accumulated `changed_paths`
            // into two work lists: deletions and files to (re)embed.
            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        // Drop state for worktrees the project no longer has.
                        let worktree = if let Some(worktree) =
                            project.read(cx).worktree_for_id(*worktree_id, cx)
                        {
                            worktree
                        } else {
                            return false;
                        };
                        // Keep (but skip) worktrees that are still
                        // registering; their changes are handled next pass.
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        for (path, info) in &worktree_state.changed_paths {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else if let Ok(absolute_path) = worktree.read(cx).absolutize(path) {
                                // JobHandle increments the pending-file
                                // counter now; its Drop decrements it when
                                // the file finishes embedding.
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }
                        }
                        worktree_state.changed_paths.clear();
                        true
                    });

                anyhow::Ok(())
            })??;

            cx.background_executor()
                .spawn(async move {
                    // Apply deletions first; failures are logged, not fatal.
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    // Prefetch already-stored embeddings for the pending
                    // files, grouped by worktree, so unchanged spans can be
                    // reused instead of re-embedded.
                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file(&pending_file.relative_path, None)
                            .await
                        {
                            // Skip languages we can't chunk: not in the
                            // parse-whole-file list, not Markdown, and no
                            // embedding grammar config.
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && &language.name().as_ref() != &"Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        // Best-effort send: if the parsing channel is gone
                        // or full, the file is dropped from this pass.
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            // Mark this indexing pass as finished.
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })??;

            Ok(())
        })
    }
1185
1186    fn wait_for_worktree_registration(
1187        &self,
1188        project: &Model<Project>,
1189        cx: &mut ModelContext<Self>,
1190    ) -> Task<Result<()>> {
1191        let project = project.downgrade();
1192        cx.spawn(|this, cx| async move {
1193            loop {
1194                let mut pending_worktrees = Vec::new();
1195                this.upgrade()
1196                    .context("semantic index dropped")?
1197                    .read_with(&cx, |this, _| {
1198                        if let Some(project) = this.projects.get(&project) {
1199                            for worktree in project.worktrees.values() {
1200                                if let WorktreeState::Registering(worktree) = worktree {
1201                                    pending_worktrees.push(worktree.done());
1202                                }
1203                            }
1204                        }
1205                    })?;
1206
1207                if pending_worktrees.is_empty() {
1208                    break;
1209                } else {
1210                    future::join_all(pending_worktrees).await;
1211                }
1212            }
1213            Ok(())
1214        })
1215    }
1216
1217    async fn embed_spans(
1218        spans: &mut [Span],
1219        embedding_provider: &dyn EmbeddingProvider,
1220        db: &VectorDatabase,
1221    ) -> Result<()> {
1222        let mut batch = Vec::new();
1223        let mut batch_tokens = 0;
1224        let mut embeddings = Vec::new();
1225
1226        let digests = spans
1227            .iter()
1228            .map(|span| span.digest.clone())
1229            .collect::<Vec<_>>();
1230        let embeddings_for_digests = db
1231            .embeddings_for_digests(digests)
1232            .await
1233            .log_err()
1234            .unwrap_or_default();
1235
1236        for span in &*spans {
1237            if embeddings_for_digests.contains_key(&span.digest) {
1238                continue;
1239            };
1240
1241            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
1242                let batch_embeddings = embedding_provider
1243                    .embed_batch(mem::take(&mut batch))
1244                    .await?;
1245                embeddings.extend(batch_embeddings);
1246                batch_tokens = 0;
1247            }
1248
1249            batch_tokens += span.token_count;
1250            batch.push(span.content.clone());
1251        }
1252
1253        if !batch.is_empty() {
1254            let batch_embeddings = embedding_provider
1255                .embed_batch(mem::take(&mut batch))
1256                .await?;
1257
1258            embeddings.extend(batch_embeddings);
1259        }
1260
1261        let mut embeddings = embeddings.into_iter();
1262        for span in spans {
1263            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
1264                Some(embedding.clone())
1265            } else {
1266                embeddings.next()
1267            };
1268            let embedding = embedding.context("failed to embed spans")?;
1269            span.embedding = Some(embedding);
1270        }
1271        Ok(())
1272    }
1273}
1274
1275impl Drop for JobHandle {
1276    fn drop(&mut self) {
1277        if let Some(inner) = Arc::get_mut(&mut self.tx) {
1278            // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not)
1279            if let Some(tx) = inner.upgrade() {
1280                let mut tx = tx.lock();
1281                *tx.borrow_mut() -= 1;
1282            }
1283        }
1284    }
1285}
1286
#[cfg(test)]
mod tests {
    use super::*;

    /// Cloning a JobHandle must not bump the pending-file count; only
    /// dropping the final clone decrements it.
    #[test]
    fn test_job_handle() {
        let (count_tx, count_rx) = watch::channel_with(0);
        let count_tx = Arc::new(Mutex::new(count_tx));

        let first_handle = JobHandle::new(&count_tx);
        assert_eq!(*count_rx.borrow(), 1);

        let second_handle = first_handle.clone();
        assert_eq!(*count_rx.borrow(), 1);

        drop(first_handle);
        assert_eq!(*count_rx.borrow(), 1);

        drop(second_handle);
        assert_eq!(*count_rx.borrow(), 0);
    }
}