semantic_index.rs

mod db;
mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;

#[cfg(test)]
mod semantic_index_tests;

use crate::semantic_index_settings::SemanticIndexSettings;
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Context as _, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt};
use gpui::{
    AppContext, AsyncAppContext, BorrowWindow, Context, Model, ModelContext, Task, ViewContext,
    WeakModel,
};
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
use settings::Settings;
use smol::channel;
use std::{
    cmp::Reverse,
    env,
    future::Future,
    mem,
    ops::Range,
    path::{Path, PathBuf},
    sync::{Arc, Weak},
    time::{Duration, Instant, SystemTime},
};
use util::paths::PathMatcher;
use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
use workspace::Workspace;

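/// Version stamp for the on-disk index; bumping it is intended to invalidate embedding
/// databases written by older versions (the `db` module is expected to check it when
/// opening the database).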
const SEMANTIC_INDEX_VERSION: usize = 11;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}

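/// Registers the semantic index settings, re-indexes previously indexed local projects when
/// a workspace opens, and spawns a task that constructs the global [`SemanticIndex`] backed
/// by the embeddings database for the current release channel.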
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    SemanticIndexSettings::register(cx);

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.observe_new_views(
        |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let project = workspace.project().clone();

            if project.read(cx).is_local() {
                cx.app_mut()
                    .spawn(|mut cx| async move {
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })?
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))?
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
            }
        },
    )
    .detach();

    cx.spawn(move |cx| async move {
        let embedding_provider =
            OpenAIEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(embedding_provider),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| cx.set_global(semantic_index.clone()))?;

        anyhow::Ok(())
    })
    .detach();
}

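/// The indexing state reported for a project, as surfaced to the status UI.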
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    NotAuthenticated,
    NotIndexed,
    Indexed,
    Indexing {
        remaining_files: usize,
        rate_limit_expiry: Option<Instant>,
    },
}

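/// Maintains per-project vector indexes of code spans: it watches worktrees for changes,
/// parses changed files into spans on background threads, embeds them via the configured
/// [`EmbeddingProvider`], and persists the results in a [`VectorDatabase`].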
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    _embedding_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModel<Project>, ProjectState>,
}

struct ProjectState {
    worktrees: HashMap<WorktreeId, WorktreeState>,
    pending_file_count_rx: watch::Receiver<usize>,
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    pending_index: usize,
    _subscription: gpui::Subscription,
    _observe_pending_file_count: Task<()>,
}

enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}

impl WorktreeState {
    fn is_registered(&self) -> bool {
        matches!(self, Self::Registered(_))
    }

    fn paths_changed(
        &mut self,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        worktree: &Worktree,
    ) {
        let changed_paths = match self {
            Self::Registering(state) => &mut state.changed_paths,
            Self::Registered(state) => &mut state.changed_paths,
        };

        for (path, entry_id, change) in changes.iter() {
            let Some(entry) = worktree.entry_for_id(*entry_id) else {
                continue;
            };
            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
                continue;
            }
            changed_paths.insert(
                path.clone(),
                ChangedPathInfo {
                    mtime: entry.mtime,
                    is_deleted: *change == PathChange::Removed,
                },
            );
        }
    }
}

struct RegisteringWorktreeState {
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    done_rx: watch::Receiver<Option<()>>,
    _registration: Task<()>,
}

impl RegisteringWorktreeState {
    fn done(&self) -> impl Future<Output = ()> {
        let mut done_rx = self.done_rx.clone();
        async move {
            while let Some(result) = done_rx.next().await {
                if result.is_some() {
                    break;
                }
            }
        }
    }
}

struct RegisteredWorktreeState {
    db_id: i64,
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}

struct ChangedPathInfo {
    mtime: SystemTime,
    is_deleted: bool,
}

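/// A handle representing one pending file in a project's indexing queue. Cloning the handle
/// does not increment the pending-file count; the count is decremented exactly once, when
/// the last clone is dropped.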
#[derive(Clone)]
pub struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}

impl JobHandle {
    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
        *tx.lock().borrow_mut() += 1;
        Self {
            tx: Arc::new(Arc::downgrade(&tx)),
        }
    }
}

impl ProjectState {
    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
        Self {
            worktrees: Default::default(),
            pending_file_count_rx: pending_file_count_rx.clone(),
            pending_file_count_tx,
            pending_index: 0,
            _subscription: subscription,
            _observe_pending_file_count: cx.spawn({
                let mut pending_file_count_rx = pending_file_count_rx.clone();
                |this, mut cx| async move {
                    while let Some(_) = pending_file_count_rx.next().await {
                        if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
                            break;
                        }
                    }
                }
            }),
        }
    }

    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktrees
            .iter()
            .find_map(|(worktree_id, worktree_state)| match worktree_state {
                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
                _ => None,
            })
    }
}

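/// A file that has been queued for parsing and embedding.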
#[derive(Clone)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    job_handle: JobHandle,
}

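/// A single semantic search hit: an anchored range within a buffer, along with its
/// similarity score relative to the query embedding.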
#[derive(Clone)]
pub struct SearchResult {
    pub buffer: Model<Buffer>,
    pub range: Range<Anchor>,
    pub similarity: OrderedFloat<f32>,
}

impl SemanticIndex {
    pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
        cx.try_global::<Model<Self>>()
            .map(|semantic_index| semantic_index.clone())
    }

    pub fn authenticate(&mut self, cx: &mut AppContext) -> Task<bool> {
        if !self.embedding_provider.has_credentials() {
            let embedding_provider = self.embedding_provider.clone();
            cx.spawn(|cx| async move {
                if let Some(retrieve_credentials) = cx
                    .update(|cx| embedding_provider.retrieve_credentials(cx))
                    .log_err()
                {
                    retrieve_credentials.await;
                }

                embedding_provider.has_credentials()
            })
        } else {
            Task::ready(true)
        }
    }

    pub fn is_authenticated(&self) -> bool {
        self.embedding_provider.has_credentials()
    }

    pub fn enabled(cx: &AppContext) -> bool {
        SemanticIndexSettings::get_global(cx).enabled
    }

    pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
        if !self.is_authenticated() {
            return SemanticIndexStatus::NotAuthenticated;
        }

        if let Some(project_state) = self.projects.get(&project.downgrade()) {
            if project_state
                .worktrees
                .values()
                .all(|worktree| worktree.is_registered())
                && project_state.pending_index == 0
            {
                SemanticIndexStatus::Indexed
            } else {
                SemanticIndexStatus::Indexing {
                    remaining_files: project_state.pending_file_count_rx.borrow().clone(),
                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
                }
            }
        } else {
            SemanticIndexStatus::NotIndexed
        }
    }

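    /// Opens the vector database at `database_path` and constructs the index model, spawning
    /// one background task that writes embedded files to the database and one parsing task
    /// per CPU that turns pending files into embeddable spans.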
    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<Model<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
            .await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        cx.new_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
            let _embedding_task = cx.background_executor().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background_executor().clone();
                _parsing_files_tasks.push(cx.background_executor().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        })
    }

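    /// Loads a pending file from disk, parses it into spans, reuses any embeddings already
    /// computed for identical span digests, and pushes the result onto the embedding queue.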
    async fn parse_file(
        fs: &Arc<dyn Fs>,
        pending_file: PendingFile,
        retriever: &mut CodeContextRetriever,
        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
    ) {
        let Some(language) = pending_file.language else {
            return;
        };

        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
            if let Some(mut spans) = retriever
                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
                .log_err()
            {
                log::trace!(
                    "parsed path {:?}: {} spans",
                    pending_file.relative_path,
                    spans.len()
                );

                for span in &mut spans {
                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
                        span.embedding = Some(embedding.to_owned());
                    }
                }

                embedding_queue.lock().push(FileToEmbed {
                    worktree_id: pending_file.worktree_db_id,
                    path: pending_file.relative_path,
                    mtime: pending_file.modified_time,
                    job_handle: pending_file.job_handle,
                    spans,
                });
            }
        }
    }

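    /// Reports whether this project's worktrees have all been indexed before, which `init`
    /// uses to decide whether to re-index automatically when a workspace opens.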
    pub fn project_previously_indexed(
        &mut self,
        project: &Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<bool>> {
        let worktrees_indexed_previously = project
            .read(cx)
            .worktrees()
            .map(|worktree| {
                self.db
                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
            })
            .collect::<Vec<_>>();
        cx.spawn(|_, _cx| async move {
            let worktree_indexed_previously =
                futures::future::join_all(worktrees_indexed_previously).await;

            Ok(worktree_indexed_previously
                .iter()
                .filter(|worktree| worktree.is_ok())
                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
        })
    }

    fn project_entries_changed(
        &mut self,
        project: Model<Project>,
        worktree_id: WorktreeId,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        cx: &mut ModelContext<Self>,
    ) {
        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
            return;
        };
        let project = project.downgrade();
        let Some(project_state) = self.projects.get_mut(&project) else {
            return;
        };

        let worktree = worktree.read(cx);
        let worktree_state =
            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
                worktree_state
            } else {
                return;
            };
        worktree_state.paths_changed(changes, worktree);
        if let WorktreeState::Registered(_) = worktree_state {
            cx.spawn(|this, mut cx| async move {
                cx.background_executor()
                    .timer(BACKGROUND_INDEXING_DELAY)
                    .await;
                if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
                    this.update(&mut cx, |this, cx| {
                        this.index_project(project, cx).detach_and_log_err(cx)
                    })?;
                }
                anyhow::Ok(())
            })
            .detach_and_log_err(cx);
        }
    }

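    /// Begins tracking a local worktree: once its initial scan completes, the worktree is
    /// assigned a database id, file mtimes are compared against what is stored in the
    /// database to work out which paths need (re)indexing or deletion, and an indexing pass
    /// is kicked off.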
    fn register_worktree(
        &mut self,
        project: Model<Project>,
        worktree: Model<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade() {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok()
                            .flatten()
                            .context("worktree not found")?
                    } else {
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
                    let mut changed_paths = cx
                        .background_executor()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path)?;

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Skip languages we can't embed: not a whole-file type,
                                    // not Markdown, and no embedding config in the grammar.
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && &language.name().as_ref() != &"Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .context("project not registered")?;
                        let project = project.upgrade().context("project was dropped")?;

                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })??;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        this.projects.get_mut(&project).map(|project_state| {
                            project_state.worktrees.remove(&worktree_id);
                        });
                    })
                    .ok();
                }

                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }

    fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
        {
            project_state
        } else {
            return;
        };

        let mut worktrees = project
            .read(cx)
            .worktrees()
            .filter(|worktree| worktree.read(cx).is_local())
            .collect::<Vec<_>>();
        let worktree_ids = worktrees
            .iter()
            .map(|worktree| worktree.read(cx).id())
            .collect::<HashSet<_>>();

        // Remove worktrees that are no longer present
        project_state
            .worktrees
            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));

        // Register new worktrees
        worktrees.retain(|worktree| {
            let worktree_id = worktree.read(cx).id();
            !project_state.worktrees.contains_key(&worktree_id)
        });
        for worktree in worktrees {
            self.register_worktree(project.clone(), worktree, cx);
        }
    }

    pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
        Some(
            self.projects
                .get(&project.downgrade())?
                .pending_file_count_rx
                .clone(),
        )
    }

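    /// Indexes the project if needed, embeds `query`, and returns up to `limit` results drawn
    /// from both the on-disk index and any dirty open buffers, filtered by the `includes` and
    /// `excludes` path matchers.
    ///
    /// A rough sketch of how a caller might drive a search (identifiers such as
    /// `semantic_index` and `project` are illustrative, not part of this API):
    ///
    /// ```ignore
    /// // Somewhere with access to an update context and an async scope:
    /// let search = semantic_index.update(cx, |index, cx| {
    ///     index.search_project(project.clone(), "open a buffer".into(), 10, vec![], vec![], cx)
    /// });
    /// let results = search.await?; // Vec<SearchResult>, sorted by descending similarity
    /// ```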
    pub fn search_project(
        &mut self,
        project: Model<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            index.await?;
            let t0 = Instant::now();

            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .context("could not embed query")?;
            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());

            let search_start = Instant::now();
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            })?;
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            })?;
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("Semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }

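    /// Searches the persisted index for the project's registered worktrees, running the
    /// top-k similarity search over batches of file ids concurrently and resolving the
    /// winning spans to anchored ranges in freshly opened buffers.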
    pub fn search_files(
        &mut self,
        project: Model<Project>,
        query: Embedding,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let db_path = self.db.path().clone();
        let fs = self.fs.clone();
        cx.spawn(|this, mut cx| async move {
            let database = VectorDatabase::new(
                fs.clone(),
                db_path.clone(),
                cx.background_executor().clone(),
            )
            .await?;

            let worktree_db_ids = this.read_with(&cx, |this, _| {
                let project_state = this
                    .projects
                    .get(&project.downgrade())
                    .context("project was not indexed")?;
                let worktree_db_ids = project_state
                    .worktrees
                    .values()
                    .filter_map(|worktree| {
                        if let WorktreeState::Registered(worktree) = worktree {
                            Some(worktree.db_id)
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<i64>>();
                anyhow::Ok(worktree_db_ids)
            })??;

            let file_ids = database
                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
                .await?;

            let batch_n = cx.background_executor().num_cpus();
            let ids_len = file_ids.len();
            let minimum_batch_size = 50;

            let batch_size = {
                let size = ids_len / batch_n;
                if size < minimum_batch_size {
                    minimum_batch_size
                } else {
                    size
                }
            };

            let mut batch_results = Vec::new();
            for batch in file_ids.chunks(batch_size) {
                let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
                let limit = limit.clone();
                let fs = fs.clone();
                let db_path = db_path.clone();
                let query = query.clone();
                if let Some(db) =
                    VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
                        .await
                        .log_err()
                {
                    batch_results.push(async move {
                        db.top_k_search(&query, limit, batch.as_slice()).await
                    });
                }
            }

            let batch_results = futures::future::join_all(batch_results).await;

            let mut results = Vec::new();
            for batch_result in batch_results {
                if batch_result.is_ok() {
                    for (id, similarity) in batch_result.unwrap() {
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
                        {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    }
                }
            }

            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
            let scores = results
                .into_iter()
                .map(|(_, score)| score)
                .collect::<Vec<_>>();
            let spans = database.spans_for_ids(ids.as_slice()).await?;

            let mut tasks = Vec::new();
            let mut ranges = Vec::new();
            let weak_project = project.downgrade();
            project.update(&mut cx, |project, cx| {
                let this = this.upgrade().context("index was dropped")?;
                for (worktree_db_id, file_path, byte_range) in spans {
                    let project_state =
                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
                            state
                        } else {
                            return Err(anyhow!("project not added"));
                        };
                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
                        ranges.push(byte_range);
                    }
                }

                Ok(())
            })??;

            let buffers = futures::future::join_all(tasks).await;
            Ok(buffers
                .into_iter()
                .zip(ranges)
                .zip(scores)
                .filter_map(|((buffer, range), similarity)| {
                    let buffer = buffer.log_err()?;
                    let range = buffer
                        .read_with(&cx, |buffer, _| {
                            let start = buffer.clip_offset(range.start, Bias::Left);
                            let end = buffer.clip_offset(range.end, Bias::Right);
                            buffer.anchor_before(start)..buffer.anchor_after(end)
                        })
                        .log_err()?;
                    Some(SearchResult {
                        buffer,
                        range,
                        similarity,
                    })
                })
                .collect())
        })
    }

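    /// Searches buffers that are open and dirty, embedding their current contents on the fly
    /// so that unsaved edits are reflected in search results.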
    fn search_modified_buffers(
        &self,
        project: &Model<Project>,
        query: Embedding,
        limit: usize,
        includes: &[PathMatcher],
        excludes: &[PathMatcher],
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let modified_buffers = project
            .read(cx)
            .opened_buffers()
            .into_iter()
            .filter_map(|buffer_handle| {
                let buffer = buffer_handle.read(cx);
                let snapshot = buffer.snapshot();
                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                    excludes.iter().any(|matcher| matcher.is_match(&path))
                });

                let included = if includes.is_empty() {
                    true
                } else {
                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                        includes.iter().any(|matcher| matcher.is_match(&path))
                    })
                };

                if buffer.is_dirty() && !excluded && included {
                    Some((buffer_handle, snapshot))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        let embedding_provider = self.embedding_provider.clone();
        let fs = self.fs.clone();
        let db_path = self.db.path().clone();
        let background = cx.background_executor().clone();
        cx.background_executor().spawn(async move {
            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
            let mut results = Vec::<SearchResult>::new();

            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
            for (buffer, snapshot) in modified_buffers {
                let language = snapshot
                    .language_at(0)
                    .cloned()
                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
                let mut spans = retriever
                    .parse_file_with_template(None, &snapshot.text(), language)
                    .log_err()
                    .unwrap_or_default();
                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
                    .await
                    .log_err()
                    .is_some()
                {
                    for span in spans {
                        let similarity = span.embedding.unwrap().similarity(&query);
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |result| {
                                Reverse(result.similarity)
                            }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        let range = {
                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
                        };

                        results.insert(
                            ix,
                            SearchResult {
                                buffer: buffer.clone(),
                                range,
                                similarity,
                            },
                        );
                        results.truncate(limit);
                    }
                }
            }

            Ok(results)
        })
    }

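    /// Ensures the user is authenticated with the embedding provider and then indexes the
    /// project, returning a task that resolves once all pending files have been embedded.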
    pub fn index_project(
        &mut self,
        project: Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if self.is_authenticated() {
            self.index_project_internal(project, cx)
        } else {
            let authenticate = self.authenticate(cx);
            cx.spawn(|this, mut cx| async move {
                if authenticate.await {
                    this.update(&mut cx, |this, cx| this.index_project_internal(project, cx))?
                        .await
                } else {
                    Err(anyhow!("user is not authenticated"))
                }
            })
        }
    }

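    /// Registers the project (subscribing to worktree events) if necessary, waits for all
    /// worktrees to finish registering, and then feeds every changed path into the parsing
    /// pipeline, deleting removed files from the database along the way.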
    fn index_project_internal(
        &mut self,
        project: Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            worktree_registration.await?;

            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        let worktree = if let Some(worktree) =
                            project.read(cx).worktree_for_id(*worktree_id, cx)
                        {
                            worktree
                        } else {
                            return false;
                        };
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        for (path, info) in &worktree_state.changed_paths {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else if let Ok(absolute_path) = worktree.read(cx).absolutize(path) {
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }
                        }
                        worktree_state.changed_paths.clear();
                        true
                    });

                anyhow::Ok(())
            })??;

            cx.background_executor()
                .spawn(async move {
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file(&pending_file.relative_path, None)
                            .await
                        {
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && &language.name().as_ref() != &"Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })??;

            Ok(())
        })
    }

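    /// Resolves once every worktree in the project has finished registering (or has been
    /// dropped), so that indexing never races with worktree registration.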
    fn wait_for_worktree_registration(
        &self,
        project: &Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let project = project.downgrade();
        cx.spawn(|this, cx| async move {
            loop {
                let mut pending_worktrees = Vec::new();
                this.upgrade()
                    .context("semantic index dropped")?
                    .read_with(&cx, |this, _| {
                        if let Some(project) = this.projects.get(&project) {
                            for worktree in project.worktrees.values() {
                                if let WorktreeState::Registering(worktree) = worktree {
                                    pending_worktrees.push(worktree.done());
                                }
                            }
                        }
                    })?;

                if pending_worktrees.is_empty() {
                    break;
                } else {
                    future::join_all(pending_worktrees).await;
                }
            }
            Ok(())
        })
    }

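    /// Embeds any spans whose digests are not already present in the database, batching
    /// requests so that each call to the provider stays under its per-batch token limit, and
    /// then attaches an embedding to every span.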
    async fn embed_spans(
        spans: &mut [Span],
        embedding_provider: &dyn EmbeddingProvider,
        db: &VectorDatabase,
    ) -> Result<()> {
        let mut batch = Vec::new();
        let mut batch_tokens = 0;
        let mut embeddings = Vec::new();

        let digests = spans
            .iter()
            .map(|span| span.digest.clone())
            .collect::<Vec<_>>();
        let embeddings_for_digests = db
            .embeddings_for_digests(digests)
            .await
            .log_err()
            .unwrap_or_default();

        for span in &*spans {
            if embeddings_for_digests.contains_key(&span.digest) {
                continue;
            };

            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                let batch_embeddings = embedding_provider
                    .embed_batch(mem::take(&mut batch))
                    .await?;
                embeddings.extend(batch_embeddings);
                batch_tokens = 0;
            }

            batch_tokens += span.token_count;
            batch.push(span.content.clone());
        }

        if !batch.is_empty() {
            let batch_embeddings = embedding_provider
                .embed_batch(mem::take(&mut batch))
                .await?;

            embeddings.extend(batch_embeddings);
        }

        let mut embeddings = embeddings.into_iter();
        for span in spans {
            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
                Some(embedding.clone())
            } else {
                embeddings.next()
            };
            let embedding = embedding.context("failed to embed spans")?;
            span.embedding = Some(embedding);
        }
        Ok(())
    }
}

impl Drop for JobHandle {
    fn drop(&mut self) {
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not)
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_job_handle() {
        let (job_count_tx, job_count_rx) = watch::channel_with(0);
        let tx = Arc::new(Mutex::new(job_count_tx));
        let job_handle = JobHandle::new(&tx);

        assert_eq!(1, *job_count_rx.borrow());
        let new_job_handle = job_handle.clone();
        assert_eq!(1, *job_count_rx.borrow());
        drop(job_handle);
        assert_eq!(1, *job_count_rx.borrow());
        drop(new_job_handle);
        assert_eq!(0, *job_count_rx.borrow());
    }
}