semantic_index.rs

   1mod db;
   2mod embedding_queue;
   3mod parsing;
   4pub mod semantic_index_settings;
   5
   6#[cfg(test)]
   7mod semantic_index_tests;
   8
   9use crate::semantic_index_settings::SemanticIndexSettings;
  10use ai::embedding::{Embedding, EmbeddingProvider};
  11use ai::providers::open_ai::{OpenAiEmbeddingProvider, OPEN_AI_API_URL};
  12use anyhow::{anyhow, Context as _, Result};
  13use collections::{BTreeMap, HashMap, HashSet};
  14use db::VectorDatabase;
  15use embedding_queue::{EmbeddingQueue, FileToEmbed};
  16use futures::{future, FutureExt, StreamExt};
  17use gpui::{
  18    AppContext, AsyncAppContext, BorrowWindow, Context, Global, Model, ModelContext, Task,
  19    ViewContext, WeakModel,
  20};
  21use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
  22use lazy_static::lazy_static;
  23use ordered_float::OrderedFloat;
  24use parking_lot::Mutex;
  25use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
  26use postage::watch;
  27use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
  28use release_channel::ReleaseChannel;
  29use settings::Settings;
  30use smol::channel;
  31use std::{
  32    cmp::Reverse,
  33    env,
  34    future::Future,
  35    mem,
  36    ops::Range,
  37    path::{Path, PathBuf},
  38    sync::{Arc, Weak},
  39    time::{Duration, Instant, SystemTime},
  40};
  41use util::paths::PathMatcher;
  42use util::{http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
  43use workspace::Workspace;
  44
// Version stamp for the on-disk index; presumably bumped whenever the schema
// or parsing output changes so stale databases get rebuilt — TODO confirm in db.rs.
const SEMANTIC_INDEX_VERSION: usize = 11;
// Debounce before re-indexing a project after filesystem changes are observed
// (see `project_entries_changed`).
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
// How long a parser task idles with no incoming files before flushing the
// embedding queue (see the `select_biased!` loop in `SemanticIndex::new`).
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
  48
// OpenAI API key read from the environment on first access; `None` when unset.
lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}
  52
/// Initializes the semantic index subsystem: registers its settings, installs a
/// workspace observer that re-indexes previously indexed local projects, and
/// spawns a task that constructs the global [`SemanticIndex`].
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    SemanticIndexSettings::register(cx);

    // The database lives under a per-release-channel directory so dev/stable
    // builds do not share an index.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(ReleaseChannel::global(cx).dev_name()))
        .join("embeddings_db");

    cx.observe_new_views(
        |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let project = workspace.project().clone();

            // Only local projects are auto-indexed.
            if project.read(cx).is_local() {
                cx.app_mut()
                    .spawn(|mut cx| async move {
                        // Re-index automatically only if this project was
                        // indexed in a previous session.
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })?
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))?
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
            }
        },
    )
    .detach();

    cx.spawn(move |cx| async move {
        let embedding_provider = OpenAiEmbeddingProvider::new(
            // TODO: We should read it from config, but I'm not sure whether to reuse `openai_api_url` in assistant settings or not
            OPEN_AI_API_URL.to_string(),
            http_client,
            cx.background_executor().clone(),
        )
        .await;
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(embedding_provider),
            language_registry,
            cx.clone(),
        )
        .await?;

        // Publish the handle so `SemanticIndex::global` can find it.
        cx.update(|cx| cx.set_global(GlobalSemanticIndex(semantic_index.clone())))?;

        anyhow::Ok(())
    })
    .detach();
}
 116
/// Indexing state reported for a given project (see [`SemanticIndex::status`]).
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    /// The embedding provider has no credentials.
    NotAuthenticated,
    /// The project is not known to the index.
    NotIndexed,
    /// All worktrees are registered and no indexing pass is pending.
    Indexed,
    /// Indexing work is in progress.
    Indexing {
        /// Files still waiting to be embedded.
        remaining_files: usize,
        /// When the provider's rate limit lifts, if currently throttled.
        rate_limit_expiry: Option<Instant>,
    },
}
 127
/// Coordinates parsing, embedding, and persisting of project files for
/// semantic search.
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Feeds (cached embeddings by digest, file) pairs to the parser tasks.
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    // Drains finished files from the embedding queue into the database.
    _embedding_task: Task<()>,
    // One parser task per CPU, spawned in `SemanticIndex::new`.
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModel<Project>, ProjectState>,
}
 138
/// Newtype storing the app-wide [`SemanticIndex`] handle in GPUI global state.
struct GlobalSemanticIndex(Model<SemanticIndex>);

impl Global for GlobalSemanticIndex {}
 142
/// Per-project bookkeeping held by [`SemanticIndex`].
struct ProjectState {
    worktrees: HashMap<WorktreeId, WorktreeState>,
    // Live count of files pending embedding; shared with `JobHandle`s.
    pending_file_count_rx: watch::Receiver<usize>,
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    // Count of in-flight indexing passes — incremented outside this view,
    // TODO confirm. Zero (with all worktrees registered) reports `Indexed`.
    pending_index: usize,
    _subscription: gpui::Subscription,
    // Task that re-notifies observers whenever the pending count changes.
    _observe_pending_file_count: Task<()>,
}
 151
/// Lifecycle of a worktree in the index: `Registering` while its database row
/// and initial scan are being set up, then `Registered`.
enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}
 156
 157impl WorktreeState {
 158    fn is_registered(&self) -> bool {
 159        matches!(self, Self::Registered(_))
 160    }
 161
 162    fn paths_changed(
 163        &mut self,
 164        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 165        worktree: &Worktree,
 166    ) {
 167        let changed_paths = match self {
 168            Self::Registering(state) => &mut state.changed_paths,
 169            Self::Registered(state) => &mut state.changed_paths,
 170        };
 171
 172        for (path, entry_id, change) in changes.iter() {
 173            let Some(entry) = worktree.entry_for_id(*entry_id) else {
 174                continue;
 175            };
 176            let Some(mtime) = entry.mtime else {
 177                continue;
 178            };
 179            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
 180                continue;
 181            }
 182            changed_paths.insert(
 183                path.clone(),
 184                ChangedPathInfo {
 185                    mtime,
 186                    is_deleted: *change == PathChange::Removed,
 187                },
 188            );
 189        }
 190    }
 191}
 192
/// State for a worktree whose initial registration is still in progress.
struct RegisteringWorktreeState {
    // Changes observed while registering; merged into the registered state
    // once registration completes (see `register_worktree`).
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    // Receives `Some(())` when the registration task finishes, success or not.
    done_rx: watch::Receiver<Option<()>>,
    _registration: Task<()>,
}
 198
 199impl RegisteringWorktreeState {
 200    fn done(&self) -> impl Future<Output = ()> {
 201        let mut done_rx = self.done_rx.clone();
 202        async move {
 203            while let Some(result) = done_rx.next().await {
 204                if result.is_some() {
 205                    break;
 206                }
 207            }
 208        }
 209    }
 210}
 211
/// State for a worktree that has a row in the vector database.
struct RegisteredWorktreeState {
    // Primary key of this worktree in the vector database.
    db_id: i64,
    // Paths changed since the last indexing pass.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}
 216
/// Metadata recorded for a changed path awaiting (re)indexing.
struct ChangedPathInfo {
    mtime: SystemTime,
    // True when the path was removed from the worktree.
    is_deleted: bool,
}
 221
/// Handle representing one file in flight through the embedding pipeline.
#[derive(Clone)]
pub struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    // NOTE(review): the decrement presumably happens in a `Drop` impl outside
    // this view — confirm before relying on it.
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}
 228
 229impl JobHandle {
 230    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
 231        *tx.lock().borrow_mut() += 1;
 232        Self {
 233            tx: Arc::new(Arc::downgrade(&tx)),
 234        }
 235    }
 236}
 237
 238impl ProjectState {
 239    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
 240        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
 241        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
 242        Self {
 243            worktrees: Default::default(),
 244            pending_file_count_rx: pending_file_count_rx.clone(),
 245            pending_file_count_tx,
 246            pending_index: 0,
 247            _subscription: subscription,
 248            _observe_pending_file_count: cx.spawn({
 249                let mut pending_file_count_rx = pending_file_count_rx.clone();
 250                |this, mut cx| async move {
 251                    while let Some(_) = pending_file_count_rx.next().await {
 252                        if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
 253                            break;
 254                        }
 255                    }
 256                }
 257            }),
 258        }
 259    }
 260
 261    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
 262        self.worktrees
 263            .iter()
 264            .find_map(|(worktree_id, worktree_state)| match worktree_state {
 265                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
 266                _ => None,
 267            })
 268    }
 269}
 270
/// A file queued for parsing and embedding.
#[derive(Clone)]
pub struct PendingFile {
    // Database id of the worktree this file belongs to.
    worktree_db_id: i64,
    // Path relative to the worktree root.
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    // Detected language, if any; files without one are skipped by `parse_file`.
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    // Keeps the project's pending-file counter accurate while in flight.
    job_handle: JobHandle,
}
 280
/// A single semantic-search hit: a span of a buffer plus its similarity score.
#[derive(Clone)]
pub struct SearchResult {
    pub buffer: Model<Buffer>,
    pub range: Range<Anchor>,
    /// Similarity to the query embedding; results are sorted descending on it.
    pub similarity: OrderedFloat<f32>,
}
 287
 288impl SemanticIndex {
 289    pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
 290        cx.try_global::<GlobalSemanticIndex>()
 291            .map(|semantic_index| semantic_index.0.clone())
 292    }
 293
 294    pub fn authenticate(&mut self, cx: &mut AppContext) -> Task<bool> {
 295        if !self.embedding_provider.has_credentials() {
 296            let embedding_provider = self.embedding_provider.clone();
 297            cx.spawn(|cx| async move {
 298                if let Some(retrieve_credentials) = cx
 299                    .update(|cx| embedding_provider.retrieve_credentials(cx))
 300                    .log_err()
 301                {
 302                    retrieve_credentials.await;
 303                }
 304
 305                embedding_provider.has_credentials()
 306            })
 307        } else {
 308            Task::ready(true)
 309        }
 310    }
 311
    /// True when the embedding provider already holds usable credentials.
    pub fn is_authenticated(&self) -> bool {
        self.embedding_provider.has_credentials()
    }
 315
    /// Whether semantic indexing is enabled in the user's settings.
    pub fn enabled(cx: &AppContext) -> bool {
        SemanticIndexSettings::get_global(cx).enabled
    }
 319
 320    pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
 321        if !self.is_authenticated() {
 322            return SemanticIndexStatus::NotAuthenticated;
 323        }
 324
 325        if let Some(project_state) = self.projects.get(&project.downgrade()) {
 326            if project_state
 327                .worktrees
 328                .values()
 329                .all(|worktree| worktree.is_registered())
 330                && project_state.pending_index == 0
 331            {
 332                SemanticIndexStatus::Indexed
 333            } else {
 334                SemanticIndexStatus::Indexing {
 335                    remaining_files: *project_state.pending_file_count_rx.borrow(),
 336                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
 337                }
 338            }
 339        } else {
 340            SemanticIndexStatus::NotIndexed
 341        }
 342    }
 343
    /// Opens (or creates) the vector database at `database_path` and builds the
    /// [`SemanticIndex`] model, spawning its background workers: one task that
    /// persists fully embedded files, and one parser task per CPU that turns
    /// pending files into embeddable spans.
    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<Model<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
            .await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        cx.new_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
            // Persist each file as soon as its embeddings are complete.
            let _embedding_task = cx.background_executor().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background_executor().clone();
                _parsing_files_tasks.push(cx.background_executor().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        // Prefer parsing newly arrived files; if none arrive
                        // before the timeout, flush whatever is queued.
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    // Channel closed: the index was dropped.
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        })
    }
 432
 433    async fn parse_file(
 434        fs: &Arc<dyn Fs>,
 435        pending_file: PendingFile,
 436        retriever: &mut CodeContextRetriever,
 437        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
 438        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
 439    ) {
 440        let Some(language) = pending_file.language else {
 441            return;
 442        };
 443
 444        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
 445            if let Some(mut spans) = retriever
 446                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
 447                .log_err()
 448            {
 449                log::trace!(
 450                    "parsed path {:?}: {} spans",
 451                    pending_file.relative_path,
 452                    spans.len()
 453                );
 454
 455                for span in &mut spans {
 456                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
 457                        span.embedding = Some(embedding.to_owned());
 458                    }
 459                }
 460
 461                embedding_queue.lock().push(FileToEmbed {
 462                    worktree_id: pending_file.worktree_db_id,
 463                    path: pending_file.relative_path,
 464                    mtime: pending_file.modified_time,
 465                    job_handle: pending_file.job_handle,
 466                    spans,
 467                });
 468            }
 469        }
 470    }
 471
 472    pub fn project_previously_indexed(
 473        &mut self,
 474        project: &Model<Project>,
 475        cx: &mut ModelContext<Self>,
 476    ) -> Task<Result<bool>> {
 477        let worktrees_indexed_previously = project
 478            .read(cx)
 479            .worktrees()
 480            .map(|worktree| {
 481                self.db
 482                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
 483            })
 484            .collect::<Vec<_>>();
 485        cx.spawn(|_, _cx| async move {
 486            let worktree_indexed_previously =
 487                futures::future::join_all(worktrees_indexed_previously).await;
 488
 489            Ok(worktree_indexed_previously
 490                .iter()
 491                .filter(|worktree| worktree.is_ok())
 492                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
 493        })
 494    }
 495
 496    fn project_entries_changed(
 497        &mut self,
 498        project: Model<Project>,
 499        worktree_id: WorktreeId,
 500        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 501        cx: &mut ModelContext<Self>,
 502    ) {
 503        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) else {
 504            return;
 505        };
 506        let project = project.downgrade();
 507        let Some(project_state) = self.projects.get_mut(&project) else {
 508            return;
 509        };
 510
 511        let worktree = worktree.read(cx);
 512        let worktree_state =
 513            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
 514                worktree_state
 515            } else {
 516                return;
 517            };
 518        worktree_state.paths_changed(changes, worktree);
 519        if let WorktreeState::Registered(_) = worktree_state {
 520            cx.spawn(|this, mut cx| async move {
 521                cx.background_executor()
 522                    .timer(BACKGROUND_INDEXING_DELAY)
 523                    .await;
 524                if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
 525                    this.update(&mut cx, |this, cx| {
 526                        this.index_project(project, cx).detach_and_log_err(cx)
 527                    })?;
 528                }
 529                anyhow::Ok(())
 530            })
 531            .detach_and_log_err(cx);
 532        }
 533    }
 534
    /// Registers a local worktree with the index: finds or creates its
    /// database row, diffs the on-disk files against stored mtimes to build
    /// the initial changed-path set, then flips the worktree's state from
    /// `Registering` to `Registered` and kicks off an indexing pass. On
    /// failure, the worktree is dropped from tracking.
    fn register_worktree(
        &mut self,
        project: Model<Project>,
        worktree: Model<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        // Only local worktrees can be scanned and indexed.
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    // Wait for the initial filesystem scan before diffing.
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade() {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok()
                            .flatten()
                            .context("worktree not found")?
                    } else {
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
                    // Diff the snapshot against stored mtimes on a background
                    // thread to find files needing (re)indexing.
                    let mut changed_paths = cx
                        .background_executor()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path)?;

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file_path(&absolute_path)
                                    .await
                                {
                                    // Test if file is valid parseable file
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && &language.name().as_ref() != &"Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }
                                    let Some(new_mtime) = file.mtime else {
                                        continue;
                                    };

                                    // `remove` so leftovers afterwards are the
                                    // files that no longer exist on disk.
                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime == Some(new_mtime);

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: new_mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .context("project not registered")?;
                        let project = project.upgrade().context("project was dropped")?;

                        // Merge any changes that arrived while registering.
                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })??;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        if let Some(project_state) = this.projects.get_mut(&project) {
                            project_state.worktrees.remove(&worktree_id);
                        }
                    })
                    .ok();
                }

                // Signal completion to anyone awaiting `RegisteringWorktreeState::done`.
                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }
 682
 683    fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
 684        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
 685        {
 686            project_state
 687        } else {
 688            return;
 689        };
 690
 691        let mut worktrees = project
 692            .read(cx)
 693            .worktrees()
 694            .filter(|worktree| worktree.read(cx).is_local())
 695            .collect::<Vec<_>>();
 696        let worktree_ids = worktrees
 697            .iter()
 698            .map(|worktree| worktree.read(cx).id())
 699            .collect::<HashSet<_>>();
 700
 701        // Remove worktrees that are no longer present
 702        project_state
 703            .worktrees
 704            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
 705
 706        // Register new worktrees
 707        worktrees.retain(|worktree| {
 708            let worktree_id = worktree.read(cx).id();
 709            !project_state.worktrees.contains_key(&worktree_id)
 710        });
 711        for worktree in worktrees {
 712            self.register_worktree(project.clone(), worktree, cx);
 713        }
 714    }
 715
 716    pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
 717        Some(
 718            self.projects
 719                .get(&project.downgrade())?
 720                .pending_file_count_rx
 721                .clone(),
 722        )
 723    }
 724
    /// Searches the project's indexed files and open (modified) buffers for
    /// spans semantically similar to `query`, returning at most `limit`
    /// results sorted by descending similarity. An empty query yields no
    /// results without touching the index.
    pub fn search_project(
        &mut self,
        project: Model<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        // Bring the index up to date before searching it.
        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            index.await?;
            let t0 = Instant::now();

            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .context("could not embed query")?;
            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());

            let search_start = Instant::now();
            // Kick off both searches, then await them concurrently below.
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            })?;
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            })?;
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files.
            // Buffer results win: a file result whose buffer is also modified
            // in memory is dropped as stale.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("Semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }
 787
 788    pub fn search_files(
 789        &mut self,
 790        project: Model<Project>,
 791        query: Embedding,
 792        limit: usize,
 793        includes: Vec<PathMatcher>,
 794        excludes: Vec<PathMatcher>,
 795        cx: &mut ModelContext<Self>,
 796    ) -> Task<Result<Vec<SearchResult>>> {
 797        let db_path = self.db.path().clone();
 798        let fs = self.fs.clone();
 799        cx.spawn(|this, mut cx| async move {
 800            let database = VectorDatabase::new(
 801                fs.clone(),
 802                db_path.clone(),
 803                cx.background_executor().clone(),
 804            )
 805            .await?;
 806
 807            let worktree_db_ids = this.read_with(&cx, |this, _| {
 808                let project_state = this
 809                    .projects
 810                    .get(&project.downgrade())
 811                    .context("project was not indexed")?;
 812                let worktree_db_ids = project_state
 813                    .worktrees
 814                    .values()
 815                    .filter_map(|worktree| {
 816                        if let WorktreeState::Registered(worktree) = worktree {
 817                            Some(worktree.db_id)
 818                        } else {
 819                            None
 820                        }
 821                    })
 822                    .collect::<Vec<i64>>();
 823                anyhow::Ok(worktree_db_ids)
 824            })??;
 825
 826            let file_ids = database
 827                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
 828                .await?;
 829
 830            let batch_n = cx.background_executor().num_cpus();
 831            let ids_len = file_ids.clone().len();
 832            let minimum_batch_size = 50;
 833
 834            let batch_size = {
 835                let size = ids_len / batch_n;
 836                if size < minimum_batch_size {
 837                    minimum_batch_size
 838                } else {
 839                    size
 840                }
 841            };
 842
 843            let mut batch_results = Vec::new();
 844            for batch in file_ids.chunks(batch_size) {
 845                let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
 846                let fs = fs.clone();
 847                let db_path = db_path.clone();
 848                let query = query.clone();
 849                if let Some(db) =
 850                    VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
 851                        .await
 852                        .log_err()
 853                {
 854                    batch_results.push(async move {
 855                        db.top_k_search(&query, limit, batch.as_slice()).await
 856                    });
 857                }
 858            }
 859
 860            let batch_results = futures::future::join_all(batch_results).await;
 861
 862            let mut results = Vec::new();
 863            for batch_result in batch_results {
 864                if batch_result.is_ok() {
 865                    for (id, similarity) in batch_result.unwrap() {
 866                        let ix = match results
 867                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
 868                        {
 869                            Ok(ix) => ix,
 870                            Err(ix) => ix,
 871                        };
 872
 873                        results.insert(ix, (id, similarity));
 874                        results.truncate(limit);
 875                    }
 876                }
 877            }
 878
 879            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
 880            let scores = results
 881                .into_iter()
 882                .map(|(_, score)| score)
 883                .collect::<Vec<_>>();
 884            let spans = database.spans_for_ids(ids.as_slice()).await?;
 885
 886            let mut tasks = Vec::new();
 887            let mut ranges = Vec::new();
 888            let weak_project = project.downgrade();
 889            project.update(&mut cx, |project, cx| {
 890                let this = this.upgrade().context("index was dropped")?;
 891                for (worktree_db_id, file_path, byte_range) in spans {
 892                    let project_state =
 893                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
 894                            state
 895                        } else {
 896                            return Err(anyhow!("project not added"));
 897                        };
 898                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
 899                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
 900                        ranges.push(byte_range);
 901                    }
 902                }
 903
 904                Ok(())
 905            })??;
 906
 907            let buffers = futures::future::join_all(tasks).await;
 908            Ok(buffers
 909                .into_iter()
 910                .zip(ranges)
 911                .zip(scores)
 912                .filter_map(|((buffer, range), similarity)| {
 913                    let buffer = buffer.log_err()?;
 914                    let range = buffer
 915                        .read_with(&cx, |buffer, _| {
 916                            let start = buffer.clip_offset(range.start, Bias::Left);
 917                            let end = buffer.clip_offset(range.end, Bias::Right);
 918                            buffer.anchor_before(start)..buffer.anchor_after(end)
 919                        })
 920                        .log_err()?;
 921                    Some(SearchResult {
 922                        buffer,
 923                        range,
 924                        similarity,
 925                    })
 926                })
 927                .collect())
 928        })
 929    }
 930
 931    fn search_modified_buffers(
 932        &self,
 933        project: &Model<Project>,
 934        query: Embedding,
 935        limit: usize,
 936        includes: &[PathMatcher],
 937        excludes: &[PathMatcher],
 938        cx: &mut ModelContext<Self>,
 939    ) -> Task<Result<Vec<SearchResult>>> {
 940        let modified_buffers = project
 941            .read(cx)
 942            .opened_buffers()
 943            .into_iter()
 944            .filter_map(|buffer_handle| {
 945                let buffer = buffer_handle.read(cx);
 946                let snapshot = buffer.snapshot();
 947                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 948                    excludes.iter().any(|matcher| matcher.is_match(&path))
 949                });
 950
 951                let included = if includes.len() == 0 {
 952                    true
 953                } else {
 954                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 955                        includes.iter().any(|matcher| matcher.is_match(&path))
 956                    })
 957                };
 958
 959                if buffer.is_dirty() && !excluded && included {
 960                    Some((buffer_handle, snapshot))
 961                } else {
 962                    None
 963                }
 964            })
 965            .collect::<HashMap<_, _>>();
 966
 967        let embedding_provider = self.embedding_provider.clone();
 968        let fs = self.fs.clone();
 969        let db_path = self.db.path().clone();
 970        let background = cx.background_executor().clone();
 971        cx.background_executor().spawn(async move {
 972            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
 973            let mut results = Vec::<SearchResult>::new();
 974
 975            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
 976            for (buffer, snapshot) in modified_buffers {
 977                let language = snapshot
 978                    .language_at(0)
 979                    .cloned()
 980                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
 981                let mut spans = retriever
 982                    .parse_file_with_template(None, &snapshot.text(), language)
 983                    .log_err()
 984                    .unwrap_or_default();
 985                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
 986                    .await
 987                    .log_err()
 988                    .is_some()
 989                {
 990                    for span in spans {
 991                        let similarity = span.embedding.unwrap().similarity(&query);
 992                        let ix = match results
 993                            .binary_search_by_key(&Reverse(similarity), |result| {
 994                                Reverse(result.similarity)
 995                            }) {
 996                            Ok(ix) => ix,
 997                            Err(ix) => ix,
 998                        };
 999
1000                        let range = {
1001                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
1002                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
1003                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
1004                        };
1005
1006                        results.insert(
1007                            ix,
1008                            SearchResult {
1009                                buffer: buffer.clone(),
1010                                range,
1011                                similarity,
1012                            },
1013                        );
1014                        results.truncate(limit);
1015                    }
1016                }
1017            }
1018
1019            Ok(results)
1020        })
1021    }
1022
1023    pub fn index_project(
1024        &mut self,
1025        project: Model<Project>,
1026        cx: &mut ModelContext<Self>,
1027    ) -> Task<Result<()>> {
1028        if self.is_authenticated() {
1029            self.index_project_internal(project, cx)
1030        } else {
1031            let authenticate = self.authenticate(cx);
1032            cx.spawn(|this, mut cx| async move {
1033                if authenticate.await {
1034                    this.update(&mut cx, |this, cx| this.index_project_internal(project, cx))?
1035                        .await
1036                } else {
1037                    Err(anyhow!("user is not authenticated"))
1038                }
1039            })
1040        }
1041    }
1042
    /// Schedules (re-)indexing of all changed files in `project`'s registered
    /// worktrees, lazily creating the project's index state and change
    /// subscriptions on first use. Resolves once the embedding queue has
    /// drained all files enqueued by this call.
    fn index_project_internal(
        &mut self,
        project: Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        // First time we see this project: subscribe to worktree and entry
        // changes so the index stays up to date, and register its state.
        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        // Safe: the entry was inserted above if it was missing.
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        // Bump the in-flight index count; decremented when this task finishes.
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            // Don't start until all worktrees have finished registering with
            // the database.
            worktree_registration.await?;

            // Drain the accumulated per-worktree change lists into a set of
            // files to (re-)embed and a set of files to delete from the index.
            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                // Drop state for worktrees the project no longer contains;
                // keep (but skip) worktrees that are still registering.
                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        let worktree = if let Some(worktree) =
                            project.read(cx).worktree_for_id(*worktree_id, cx)
                        {
                            worktree
                        } else {
                            return false;
                        };
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        for (path, info) in &worktree_state.changed_paths {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else if let Ok(absolute_path) = worktree.read(cx).absolutize(path) {
                                // Each JobHandle increments the pending-file
                                // counter; dropping its last clone decrements.
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }
                        }
                        // Changes are now owned by this indexing pass.
                        worktree_state.changed_paths.clear();
                        true
                    });

                anyhow::Ok(())
            })??;

            cx.background_executor()
                .spawn(async move {
                    // Remove deleted files from the index first.
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    // Prefetch existing embeddings for the pending files,
                    // keyed by worktree, so unchanged spans can be reused.
                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file_path(&pending_file.relative_path)
                            .await
                        {
                            // Skip languages we can't embed: not a whole-file
                            // parseable type, not Markdown, and no
                            // tree-sitter embedding config in the grammar.
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && &language.name().as_ref() != &"Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        // Best-effort send; a closed channel just drops the file.
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            // Balance the increment performed before spawning.
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })??;

            Ok(())
        })
    }
1188
1189    fn wait_for_worktree_registration(
1190        &self,
1191        project: &Model<Project>,
1192        cx: &mut ModelContext<Self>,
1193    ) -> Task<Result<()>> {
1194        let project = project.downgrade();
1195        cx.spawn(|this, cx| async move {
1196            loop {
1197                let mut pending_worktrees = Vec::new();
1198                this.upgrade()
1199                    .context("semantic index dropped")?
1200                    .read_with(&cx, |this, _| {
1201                        if let Some(project) = this.projects.get(&project) {
1202                            for worktree in project.worktrees.values() {
1203                                if let WorktreeState::Registering(worktree) = worktree {
1204                                    pending_worktrees.push(worktree.done());
1205                                }
1206                            }
1207                        }
1208                    })?;
1209
1210                if pending_worktrees.is_empty() {
1211                    break;
1212                } else {
1213                    future::join_all(pending_worktrees).await;
1214                }
1215            }
1216            Ok(())
1217        })
1218    }
1219
1220    async fn embed_spans(
1221        spans: &mut [Span],
1222        embedding_provider: &dyn EmbeddingProvider,
1223        db: &VectorDatabase,
1224    ) -> Result<()> {
1225        let mut batch = Vec::new();
1226        let mut batch_tokens = 0;
1227        let mut embeddings = Vec::new();
1228
1229        let digests = spans
1230            .iter()
1231            .map(|span| span.digest.clone())
1232            .collect::<Vec<_>>();
1233        let embeddings_for_digests = db
1234            .embeddings_for_digests(digests)
1235            .await
1236            .log_err()
1237            .unwrap_or_default();
1238
1239        for span in &*spans {
1240            if embeddings_for_digests.contains_key(&span.digest) {
1241                continue;
1242            };
1243
1244            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
1245                let batch_embeddings = embedding_provider
1246                    .embed_batch(mem::take(&mut batch))
1247                    .await?;
1248                embeddings.extend(batch_embeddings);
1249                batch_tokens = 0;
1250            }
1251
1252            batch_tokens += span.token_count;
1253            batch.push(span.content.clone());
1254        }
1255
1256        if !batch.is_empty() {
1257            let batch_embeddings = embedding_provider
1258                .embed_batch(mem::take(&mut batch))
1259                .await?;
1260
1261            embeddings.extend(batch_embeddings);
1262        }
1263
1264        let mut embeddings = embeddings.into_iter();
1265        for span in spans {
1266            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
1267                Some(embedding.clone())
1268            } else {
1269                embeddings.next()
1270            };
1271            let embedding = embedding.context("failed to embed spans")?;
1272            span.embedding = Some(embedding);
1273        }
1274        Ok(())
1275    }
1276}
1277
impl Drop for JobHandle {
    /// Decrements the shared pending-file counter when the final clone of
    /// this handle is dropped.
    fn drop(&mut self) {
        // `Arc::get_mut` only succeeds when this is the sole strong reference,
        // i.e. every clone of the handle has already been dropped.
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not)
            // The inner Weak may have outlived the counter; skip silently if so.
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}
1289
#[cfg(test)]
mod tests {

    use super::*;

    /// The pending-job counter must drop only when the *last* clone of a
    /// `JobHandle` is dropped.
    #[test]
    fn test_job_handle() {
        let (count_tx, count_rx) = watch::channel_with(0);
        let count_tx = Arc::new(Mutex::new(count_tx));

        // Creating a handle bumps the counter to 1.
        let first_handle = JobHandle::new(&count_tx);
        assert_eq!(1, *count_rx.borrow());

        // Cloning shares the same job, so the counter is unchanged.
        let second_handle = first_handle.clone();
        assert_eq!(1, *count_rx.borrow());

        // Dropping one of two clones leaves the job outstanding.
        drop(first_handle);
        assert_eq!(1, *count_rx.borrow());

        // Dropping the final clone releases the job.
        drop(second_handle);
        assert_eq!(0, *count_rx.borrow());
    }
}