semantic_index.rs

   1mod db;
   2mod embedding_queue;
   3mod parsing;
   4pub mod semantic_index_settings;
   5
   6#[cfg(test)]
   7mod semantic_index_tests;
   8
   9use crate::semantic_index_settings::SemanticIndexSettings;
  10use ai::auth::ProviderCredential;
  11use ai::embedding::{Embedding, EmbeddingProvider};
  12use ai::providers::open_ai::OpenAIEmbeddingProvider;
  13use anyhow::{anyhow, Result};
  14use collections::{BTreeMap, HashMap, HashSet};
  15use db::VectorDatabase;
  16use embedding_queue::{EmbeddingQueue, FileToEmbed};
  17use futures::{future, FutureExt, StreamExt};
  18use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
  19use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
  20use lazy_static::lazy_static;
  21use ordered_float::OrderedFloat;
  22use parking_lot::Mutex;
  23use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
  24use postage::watch;
  25use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
  26use smol::channel;
  27use std::{
  28    cmp::Reverse,
  29    env,
  30    future::Future,
  31    mem,
  32    ops::Range,
  33    path::{Path, PathBuf},
  34    sync::{Arc, Weak},
  35    time::{Duration, Instant, SystemTime},
  36};
  37use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
  38use workspace::WorkspaceCreated;
  39
/// Version number of the semantic index.
/// NOTE(review): not referenced in the visible part of this file; presumably
/// used to invalidate stored data when the format changes — confirm at the use site.
const SEMANTIC_INDEX_VERSION: usize = 11;
/// Delay between a registered worktree reporting changed entries and the
/// follow-up `index_project` call scheduled by `project_entries_changed`.
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
/// How long a parsing task waits for another file to arrive before it flushes
/// the embedding queue (see the parsing loop in `SemanticIndex::new`).
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);

lazy_static! {
    // Value of the `OPENAI_API_KEY` environment variable, read once on first
    // access; `None` when `env::var` returns an error.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}
  47
  48pub fn init(
  49    fs: Arc<dyn Fs>,
  50    http_client: Arc<dyn HttpClient>,
  51    language_registry: Arc<LanguageRegistry>,
  52    cx: &mut AppContext,
  53) {
  54    settings::register::<SemanticIndexSettings>(cx);
  55
  56    let db_file_path = EMBEDDINGS_DIR
  57        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
  58        .join("embeddings_db");
  59
  60    cx.subscribe_global::<WorkspaceCreated, _>({
  61        move |event, cx| {
  62            let Some(semantic_index) = SemanticIndex::global(cx) else {
  63                return;
  64            };
  65            let workspace = &event.0;
  66            if let Some(workspace) = workspace.upgrade(cx) {
  67                let project = workspace.read(cx).project().clone();
  68                if project.read(cx).is_local() {
  69                    cx.spawn(|mut cx| async move {
  70                        let previously_indexed = semantic_index
  71                            .update(&mut cx, |index, cx| {
  72                                index.project_previously_indexed(&project, cx)
  73                            })
  74                            .await?;
  75                        if previously_indexed {
  76                            semantic_index
  77                                .update(&mut cx, |index, cx| index.index_project(project, cx))
  78                                .await?;
  79                        }
  80                        anyhow::Ok(())
  81                    })
  82                    .detach_and_log_err(cx);
  83                }
  84            }
  85        }
  86    })
  87    .detach();
  88
  89    cx.spawn(move |mut cx| async move {
  90        let semantic_index = SemanticIndex::new(
  91            fs,
  92            db_file_path,
  93            Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
  94            language_registry,
  95            cx.clone(),
  96        )
  97        .await?;
  98
  99        cx.update(|cx| {
 100            cx.set_global(semantic_index.clone());
 101        });
 102
 103        anyhow::Ok(())
 104    })
 105    .detach();
 106}
 107
/// Indexing state of a project, as reported by `SemanticIndex::status`.
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    /// The index holds no usable provider credential.
    NotAuthenticated,
    /// The project has no state tracked by the index.
    NotIndexed,
    /// Every tracked worktree is registered and no indexing pass is pending.
    Indexed,
    /// At least one worktree is still registering, or an indexing pass is pending.
    Indexing {
        /// Current value of the project's pending-file counter.
        remaining_files: usize,
        /// Value reported by the embedding provider's `rate_limit_expiration()`.
        rate_limit_expiry: Option<Instant>,
    },
}
 118
/// Model that owns the vector database, the embedding pipeline, and the
/// per-project indexing state.
pub struct SemanticIndex {
    /// File system used to load file contents and to open the database.
    fs: Arc<dyn Fs>,
    /// Database that stores worktrees, files, and their embedded spans.
    db: VectorDatabase,
    /// Provider used to embed spans and search queries.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Used to detect the language of each file while registering a worktree.
    language_registry: Arc<LanguageRegistry>,
    /// Sends files, paired with a map of already-known embeddings keyed by
    /// span digest, to the background parsing tasks.
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    /// Background task that writes finished files from the embedding queue to the database.
    _embedding_task: Task<()>,
    /// Background parsing tasks (one per CPU reported by the background executor).
    _parsing_files_tasks: Vec<Task<()>>,
    /// Indexing state for each tracked project, keyed by a weak project handle.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
    /// Credential currently used with the embedding provider.
    provider_credential: ProviderCredential,
    /// Queue shared with the parsing tasks; files are pushed here for embedding.
    embedding_queue: Arc<Mutex<EmbeddingQueue>>,
}
 131
/// Indexing state tracked for a single project.
struct ProjectState {
    /// State of each tracked worktree of the project.
    worktrees: HashMap<WorktreeId, WorktreeState>,
    /// Receiving side of the pending-file counter; cloned out to observers.
    pending_file_count_rx: watch::Receiver<usize>,
    /// Sending side of the pending-file counter; `JobHandle::new` increments it.
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    /// `status` reports `Indexing` while this is non-zero.
    /// NOTE(review): presumably the number of in-flight `index_project` calls — confirm.
    pending_index: usize,
    /// Subscription handed to `ProjectState::new`, kept alive for the lifetime of this state.
    _subscription: gpui::Subscription,
    /// Task that notifies `SemanticIndex` observers whenever the pending-file counter changes.
    _observe_pending_file_count: Task<()>,
}
 140
/// Registration state of a worktree tracked by the index.
enum WorktreeState {
    /// The registration task for this worktree is still running.
    Registering(RegisteringWorktreeState),
    /// Registration finished and the worktree has a database id.
    Registered(RegisteredWorktreeState),
}
 145
 146impl WorktreeState {
 147    fn is_registered(&self) -> bool {
 148        matches!(self, Self::Registered(_))
 149    }
 150
 151    fn paths_changed(
 152        &mut self,
 153        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 154        worktree: &Worktree,
 155    ) {
 156        let changed_paths = match self {
 157            Self::Registering(state) => &mut state.changed_paths,
 158            Self::Registered(state) => &mut state.changed_paths,
 159        };
 160
 161        for (path, entry_id, change) in changes.iter() {
 162            let Some(entry) = worktree.entry_for_id(*entry_id) else {
 163                continue;
 164            };
 165            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
 166                continue;
 167            }
 168            changed_paths.insert(
 169                path.clone(),
 170                ChangedPathInfo {
 171                    mtime: entry.mtime,
 172                    is_deleted: *change == PathChange::Removed,
 173                },
 174            );
 175        }
 176    }
 177}
 178
/// State of a worktree whose registration task has not finished yet.
struct RegisteringWorktreeState {
    /// Changes reported while registration is in progress; merged into the
    /// registered state once registration completes.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    /// Becomes `Some(())` when the registration task has finished.
    done_rx: watch::Receiver<Option<()>>,
    /// The registration task itself, held here alongside the state it belongs to.
    _registration: Task<()>,
}
 184
 185impl RegisteringWorktreeState {
 186    fn done(&self) -> impl Future<Output = ()> {
 187        let mut done_rx = self.done_rx.clone();
 188        async move {
 189            while let Some(result) = done_rx.next().await {
 190                if result.is_some() {
 191                    break;
 192                }
 193            }
 194        }
 195    }
 196}
 197
/// State of a worktree that has been registered in the database.
struct RegisteredWorktreeState {
    /// Id of the worktree's row, as returned by `find_or_create_worktree`.
    db_id: i64,
    /// Paths that changed since they were last recorded, keyed by path.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}
 202
/// What is known about a path that needs to be (re)indexed or removed.
struct ChangedPathInfo {
    /// Modification time recorded for the path when the change was noted.
    mtime: SystemTime,
    /// `true` when the path should be removed from the index rather than re-embedded.
    is_deleted: bool,
}
 207
/// Handle representing one pending file; creating it bumps the project's
/// pending-file counter (see `JobHandle::new`).
#[derive(Clone)]
pub struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    /// NOTE(review): the decrement is not in the visible part of this file — confirm
    /// it lives in a `Drop` impl elsewhere.
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}
 214
 215impl JobHandle {
 216    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
 217        *tx.lock().borrow_mut() += 1;
 218        Self {
 219            tx: Arc::new(Arc::downgrade(&tx)),
 220        }
 221    }
 222}
 223
 224impl ProjectState {
 225    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
 226        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
 227        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
 228        Self {
 229            worktrees: Default::default(),
 230            pending_file_count_rx: pending_file_count_rx.clone(),
 231            pending_file_count_tx,
 232            pending_index: 0,
 233            _subscription: subscription,
 234            _observe_pending_file_count: cx.spawn_weak({
 235                let mut pending_file_count_rx = pending_file_count_rx.clone();
 236                |this, mut cx| async move {
 237                    while let Some(_) = pending_file_count_rx.next().await {
 238                        if let Some(this) = this.upgrade(&cx) {
 239                            this.update(&mut cx, |_, cx| cx.notify());
 240                        }
 241                    }
 242                }
 243            }),
 244        }
 245    }
 246
 247    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
 248        self.worktrees
 249            .iter()
 250            .find_map(|(worktree_id, worktree_state)| match worktree_state {
 251                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
 252                _ => None,
 253            })
 254    }
 255}
 256
/// A file waiting to be parsed into spans and queued for embedding.
#[derive(Clone)]
pub struct PendingFile {
    /// Database id of the worktree that contains the file.
    worktree_db_id: i64,
    /// Path of the file as stored with its spans; also passed to the parser.
    relative_path: Arc<Path>,
    /// Path used to load the file's contents from the file system.
    absolute_path: PathBuf,
    /// Language of the file; `parse_file` does nothing when this is `None`.
    language: Option<Arc<Language>>,
    /// Modification time stored alongside the file's spans.
    modified_time: SystemTime,
    /// Job handle that is passed along to the embedding queue with the file.
    job_handle: JobHandle,
}
 266
/// A single match returned by a semantic search.
#[derive(Clone)]
pub struct SearchResult {
    /// Buffer that contains the match.
    pub buffer: ModelHandle<Buffer>,
    /// Anchor range of the match within `buffer`.
    pub range: Range<Anchor>,
    /// Similarity score; `search_project` sorts results by this, highest first.
    pub similarity: OrderedFloat<f32>,
}
 273
 274impl SemanticIndex {
 275    pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
 276        if cx.has_global::<ModelHandle<Self>>() {
 277            Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
 278        } else {
 279            None
 280        }
 281    }
 282
 283    pub fn authenticate(&mut self, cx: &AppContext) -> bool {
 284        let existing_credential = self.provider_credential.clone();
 285        let credential = match existing_credential {
 286            ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx),
 287            _ => existing_credential,
 288        };
 289
 290        self.provider_credential = credential.clone();
 291        self.embedding_queue.lock().set_credential(credential);
 292        self.is_authenticated()
 293    }
 294
 295    pub fn is_authenticated(&self) -> bool {
 296        let credential = &self.provider_credential;
 297        match credential {
 298            &ProviderCredential::Credentials { .. } => true,
 299            &ProviderCredential::NotNeeded => true,
 300            _ => false,
 301        }
 302    }
 303
 304    pub fn enabled(cx: &AppContext) -> bool {
 305        settings::get::<SemanticIndexSettings>(cx).enabled
 306    }
 307
 308    pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
 309        if !self.is_authenticated() {
 310            return SemanticIndexStatus::NotAuthenticated;
 311        }
 312
 313        if let Some(project_state) = self.projects.get(&project.downgrade()) {
 314            if project_state
 315                .worktrees
 316                .values()
 317                .all(|worktree| worktree.is_registered())
 318                && project_state.pending_index == 0
 319            {
 320                SemanticIndexStatus::Indexed
 321            } else {
 322                SemanticIndexStatus::Indexing {
 323                    remaining_files: project_state.pending_file_count_rx.borrow().clone(),
 324                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
 325                }
 326            }
 327        } else {
 328            SemanticIndexStatus::NotIndexed
 329        }
 330    }
 331
    /// Opens the vector database at `database_path` and creates the
    /// `SemanticIndex` model together with its background tasks.
    ///
    /// The model owns:
    /// - one task that stores every file finished by the embedding queue in the database, and
    /// - one parsing task per CPU that turns pending files into spans and pushes
    ///   them onto the embedding queue.
    ///
    /// The index starts out with `ProviderCredential::NoCredentials`.
    ///
    /// # Errors
    /// Returns an error if the database cannot be opened.
    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        Ok(cx.add_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials);
            // Persist each file as soon as the queue reports it finished;
            // insertion errors are logged and the loop keeps going.
            let _embedding_task = cx.background().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            // One parsing task per CPU, all reading from the same channel.
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background().clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        // Each iteration either parses the next file or, if none
                        // arrives within the flush timeout, flushes the queue.
                        // The task exits when the channel is closed.
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
                provider_credential: ProviderCredential::NoCredentials,
                embedding_queue
            }
        }))
    }
 421
 422    async fn parse_file(
 423        fs: &Arc<dyn Fs>,
 424        pending_file: PendingFile,
 425        retriever: &mut CodeContextRetriever,
 426        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
 427        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
 428    ) {
 429        let Some(language) = pending_file.language else {
 430            return;
 431        };
 432
 433        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
 434            if let Some(mut spans) = retriever
 435                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
 436                .log_err()
 437            {
 438                log::trace!(
 439                    "parsed path {:?}: {} spans",
 440                    pending_file.relative_path,
 441                    spans.len()
 442                );
 443
 444                for span in &mut spans {
 445                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
 446                        span.embedding = Some(embedding.to_owned());
 447                    }
 448                }
 449
 450                embedding_queue.lock().push(FileToEmbed {
 451                    worktree_id: pending_file.worktree_db_id,
 452                    path: pending_file.relative_path,
 453                    mtime: pending_file.modified_time,
 454                    job_handle: pending_file.job_handle,
 455                    spans,
 456                });
 457            }
 458        }
 459    }
 460
 461    pub fn project_previously_indexed(
 462        &mut self,
 463        project: &ModelHandle<Project>,
 464        cx: &mut ModelContext<Self>,
 465    ) -> Task<Result<bool>> {
 466        let worktrees_indexed_previously = project
 467            .read(cx)
 468            .worktrees(cx)
 469            .map(|worktree| {
 470                self.db
 471                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
 472            })
 473            .collect::<Vec<_>>();
 474        cx.spawn(|_, _cx| async move {
 475            let worktree_indexed_previously =
 476                futures::future::join_all(worktrees_indexed_previously).await;
 477
 478            Ok(worktree_indexed_previously
 479                .iter()
 480                .filter(|worktree| worktree.is_ok())
 481                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
 482        })
 483    }
 484
 485    fn project_entries_changed(
 486        &mut self,
 487        project: ModelHandle<Project>,
 488        worktree_id: WorktreeId,
 489        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 490        cx: &mut ModelContext<Self>,
 491    ) {
 492        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
 493            return;
 494        };
 495        let project = project.downgrade();
 496        let Some(project_state) = self.projects.get_mut(&project) else {
 497            return;
 498        };
 499
 500        let worktree = worktree.read(cx);
 501        let worktree_state =
 502            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
 503                worktree_state
 504            } else {
 505                return;
 506            };
 507        worktree_state.paths_changed(changes, worktree);
 508        if let WorktreeState::Registered(_) = worktree_state {
 509            cx.spawn_weak(|this, mut cx| async move {
 510                cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
 511                if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
 512                    this.update(&mut cx, |this, cx| {
 513                        this.index_project(project, cx).detach_and_log_err(cx)
 514                    });
 515                }
 516            })
 517            .detach();
 518        }
 519    }
 520
    /// Starts registering a local worktree of a tracked project.
    ///
    /// The worktree is immediately stored in the `Registering` state. A spawned
    /// task then waits for the worktree scan to complete, finds or creates the
    /// worktree's database row, and compares the files in the worktree with the
    /// modification times stored in the database to work out which paths need
    /// (re)indexing or removal. On success the worktree becomes `Registered`
    /// and `index_project` is invoked; on failure the worktree is removed from
    /// the project state.
    ///
    /// Does nothing if the project is not tracked or the worktree is not local.
    fn register_worktree(
        &mut self,
        project: ModelHandle<Project>,
        worktree: ModelHandle<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        // `done_tx` is set to `Some(())` when the registration task finishes,
        // whether it succeeded or not.
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    // Stored modification times; entries are removed below as
                    // their files are seen, so whatever is left is stale.
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade(&cx) {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok_or_else(|| anyhow!("worktree not found"))?
                    } else {
                        // The project was dropped; there is nothing left to register.
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
                    // Diff the snapshot against the database on a background task.
                    let mut changed_paths = cx
                        .background()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            // NOTE(review): the meaning of the `files(false, 0)` arguments
                            // is not visible here — confirm against `Worktree`.
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                // Files whose language cannot be determined are skipped.
                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Test if file is valid parseable file
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && &language.name().as_ref() != &"Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    // Only files that are new, or whose mtime differs
                                    // from the stored one, need indexing.
                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .ok_or_else(|| anyhow!("project not registered"))?;
                        let project = project
                            .upgrade(cx)
                            .ok_or_else(|| anyhow!("project was dropped"))?;

                        // Changes reported while registering are applied on top of
                        // the diff computed above, replacing entries for the same path.
                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })?;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        this.projects.get_mut(&project).map(|project_state| {
                            project_state.worktrees.remove(&worktree_id);
                        });
                    })
                }

                *done_tx.borrow_mut() = Some(());
            }
        });
        // Track the worktree as registering until the task above replaces or removes it.
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }
 667
 668    fn project_worktrees_changed(
 669        &mut self,
 670        project: ModelHandle<Project>,
 671        cx: &mut ModelContext<Self>,
 672    ) {
 673        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
 674        {
 675            project_state
 676        } else {
 677            return;
 678        };
 679
 680        let mut worktrees = project
 681            .read(cx)
 682            .worktrees(cx)
 683            .filter(|worktree| worktree.read(cx).is_local())
 684            .collect::<Vec<_>>();
 685        let worktree_ids = worktrees
 686            .iter()
 687            .map(|worktree| worktree.read(cx).id())
 688            .collect::<HashSet<_>>();
 689
 690        // Remove worktrees that are no longer present
 691        project_state
 692            .worktrees
 693            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
 694
 695        // Register new worktrees
 696        worktrees.retain(|worktree| {
 697            let worktree_id = worktree.read(cx).id();
 698            !project_state.worktrees.contains_key(&worktree_id)
 699        });
 700        for worktree in worktrees {
 701            self.register_worktree(project.clone(), worktree, cx);
 702        }
 703    }
 704
 705    pub fn pending_file_count(
 706        &self,
 707        project: &ModelHandle<Project>,
 708    ) -> Option<watch::Receiver<usize>> {
 709        Some(
 710            self.projects
 711                .get(&project.downgrade())?
 712                .pending_file_count_rx
 713                .clone(),
 714        )
 715    }
 716
 717    pub fn search_project(
 718        &mut self,
 719        project: ModelHandle<Project>,
 720        query: String,
 721        limit: usize,
 722        includes: Vec<PathMatcher>,
 723        excludes: Vec<PathMatcher>,
 724        cx: &mut ModelContext<Self>,
 725    ) -> Task<Result<Vec<SearchResult>>> {
 726        if query.is_empty() {
 727            return Task::ready(Ok(Vec::new()));
 728        }
 729
 730        let index = self.index_project(project.clone(), cx);
 731        let embedding_provider = self.embedding_provider.clone();
 732        let credential = self.provider_credential.clone();
 733
 734        cx.spawn(|this, mut cx| async move {
 735            index.await?;
 736            let t0 = Instant::now();
 737
 738            let query = embedding_provider
 739                .embed_batch(vec![query], credential)
 740                .await?
 741                .pop()
 742                .ok_or_else(|| anyhow!("could not embed query"))?;
 743            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());
 744
 745            let search_start = Instant::now();
 746            let modified_buffer_results = this.update(&mut cx, |this, cx| {
 747                this.search_modified_buffers(
 748                    &project,
 749                    query.clone(),
 750                    limit,
 751                    &includes,
 752                    &excludes,
 753                    cx,
 754                )
 755            });
 756            let file_results = this.update(&mut cx, |this, cx| {
 757                this.search_files(project, query, limit, includes, excludes, cx)
 758            });
 759            let (modified_buffer_results, file_results) =
 760                futures::join!(modified_buffer_results, file_results);
 761
 762            // Weave together the results from modified buffers and files.
 763            let mut results = Vec::new();
 764            let mut modified_buffers = HashSet::default();
 765            for result in modified_buffer_results.log_err().unwrap_or_default() {
 766                modified_buffers.insert(result.buffer.clone());
 767                results.push(result);
 768            }
 769            for result in file_results.log_err().unwrap_or_default() {
 770                if !modified_buffers.contains(&result.buffer) {
 771                    results.push(result);
 772                }
 773            }
 774            results.sort_by_key(|result| Reverse(result.similarity));
 775            results.truncate(limit);
 776            log::trace!("Semantic search took {:?}", search_start.elapsed());
 777            Ok(results)
 778        })
 779    }
 780
 781    pub fn search_files(
 782        &mut self,
 783        project: ModelHandle<Project>,
 784        query: Embedding,
 785        limit: usize,
 786        includes: Vec<PathMatcher>,
 787        excludes: Vec<PathMatcher>,
 788        cx: &mut ModelContext<Self>,
 789    ) -> Task<Result<Vec<SearchResult>>> {
 790        let db_path = self.db.path().clone();
 791        let fs = self.fs.clone();
 792        cx.spawn(|this, mut cx| async move {
 793            let database =
 794                VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
 795
 796            let worktree_db_ids = this.read_with(&cx, |this, _| {
 797                let project_state = this
 798                    .projects
 799                    .get(&project.downgrade())
 800                    .ok_or_else(|| anyhow!("project was not indexed"))?;
 801                let worktree_db_ids = project_state
 802                    .worktrees
 803                    .values()
 804                    .filter_map(|worktree| {
 805                        if let WorktreeState::Registered(worktree) = worktree {
 806                            Some(worktree.db_id)
 807                        } else {
 808                            None
 809                        }
 810                    })
 811                    .collect::<Vec<i64>>();
 812                anyhow::Ok(worktree_db_ids)
 813            })?;
 814
 815            let file_ids = database
 816                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
 817                .await?;
 818
 819            let batch_n = cx.background().num_cpus();
 820            let ids_len = file_ids.clone().len();
 821            let minimum_batch_size = 50;
 822
 823            let batch_size = {
 824                let size = ids_len / batch_n;
 825                if size < minimum_batch_size {
 826                    minimum_batch_size
 827                } else {
 828                    size
 829                }
 830            };
 831
 832            let mut batch_results = Vec::new();
 833            for batch in file_ids.chunks(batch_size) {
 834                let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
 835                let limit = limit.clone();
 836                let fs = fs.clone();
 837                let db_path = db_path.clone();
 838                let query = query.clone();
 839                if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
 840                    .await
 841                    .log_err()
 842                {
 843                    batch_results.push(async move {
 844                        db.top_k_search(&query, limit, batch.as_slice()).await
 845                    });
 846                }
 847            }
 848
 849            let batch_results = futures::future::join_all(batch_results).await;
 850
 851            let mut results = Vec::new();
 852            for batch_result in batch_results {
 853                if batch_result.is_ok() {
 854                    for (id, similarity) in batch_result.unwrap() {
 855                        let ix = match results
 856                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
 857                        {
 858                            Ok(ix) => ix,
 859                            Err(ix) => ix,
 860                        };
 861
 862                        results.insert(ix, (id, similarity));
 863                        results.truncate(limit);
 864                    }
 865                }
 866            }
 867
 868            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
 869            let scores = results
 870                .into_iter()
 871                .map(|(_, score)| score)
 872                .collect::<Vec<_>>();
 873            let spans = database.spans_for_ids(ids.as_slice()).await?;
 874
 875            let mut tasks = Vec::new();
 876            let mut ranges = Vec::new();
 877            let weak_project = project.downgrade();
 878            project.update(&mut cx, |project, cx| {
 879                for (worktree_db_id, file_path, byte_range) in spans {
 880                    let project_state =
 881                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
 882                            state
 883                        } else {
 884                            return Err(anyhow!("project not added"));
 885                        };
 886                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
 887                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
 888                        ranges.push(byte_range);
 889                    }
 890                }
 891
 892                Ok(())
 893            })?;
 894
 895            let buffers = futures::future::join_all(tasks).await;
 896            Ok(buffers
 897                .into_iter()
 898                .zip(ranges)
 899                .zip(scores)
 900                .filter_map(|((buffer, range), similarity)| {
 901                    let buffer = buffer.log_err()?;
 902                    let range = buffer.read_with(&cx, |buffer, _| {
 903                        let start = buffer.clip_offset(range.start, Bias::Left);
 904                        let end = buffer.clip_offset(range.end, Bias::Right);
 905                        buffer.anchor_before(start)..buffer.anchor_after(end)
 906                    });
 907                    Some(SearchResult {
 908                        buffer,
 909                        range,
 910                        similarity,
 911                    })
 912                })
 913                .collect())
 914        })
 915    }
 916
 917    fn search_modified_buffers(
 918        &self,
 919        project: &ModelHandle<Project>,
 920        query: Embedding,
 921        limit: usize,
 922        includes: &[PathMatcher],
 923        excludes: &[PathMatcher],
 924        cx: &mut ModelContext<Self>,
 925    ) -> Task<Result<Vec<SearchResult>>> {
 926        let modified_buffers = project
 927            .read(cx)
 928            .opened_buffers(cx)
 929            .into_iter()
 930            .filter_map(|buffer_handle| {
 931                let buffer = buffer_handle.read(cx);
 932                let snapshot = buffer.snapshot();
 933                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 934                    excludes.iter().any(|matcher| matcher.is_match(&path))
 935                });
 936
 937                let included = if includes.len() == 0 {
 938                    true
 939                } else {
 940                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 941                        includes.iter().any(|matcher| matcher.is_match(&path))
 942                    })
 943                };
 944
 945                if buffer.is_dirty() && !excluded && included {
 946                    Some((buffer_handle, snapshot))
 947                } else {
 948                    None
 949                }
 950            })
 951            .collect::<HashMap<_, _>>();
 952
 953        let embedding_provider = self.embedding_provider.clone();
 954        let fs = self.fs.clone();
 955        let db_path = self.db.path().clone();
 956        let background = cx.background().clone();
 957        let credential = self.provider_credential.clone();
 958        cx.background().spawn(async move {
 959            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
 960            let mut results = Vec::<SearchResult>::new();
 961
 962            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
 963            for (buffer, snapshot) in modified_buffers {
 964                let language = snapshot
 965                    .language_at(0)
 966                    .cloned()
 967                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
 968                let mut spans = retriever
 969                    .parse_file_with_template(None, &snapshot.text(), language)
 970                    .log_err()
 971                    .unwrap_or_default();
 972                if Self::embed_spans(
 973                    &mut spans,
 974                    embedding_provider.as_ref(),
 975                    &db,
 976                    credential.clone(),
 977                )
 978                .await
 979                .log_err()
 980                .is_some()
 981                {
 982                    for span in spans {
 983                        let similarity = span.embedding.unwrap().similarity(&query);
 984                        let ix = match results
 985                            .binary_search_by_key(&Reverse(similarity), |result| {
 986                                Reverse(result.similarity)
 987                            }) {
 988                            Ok(ix) => ix,
 989                            Err(ix) => ix,
 990                        };
 991
 992                        let range = {
 993                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
 994                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
 995                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
 996                        };
 997
 998                        results.insert(
 999                            ix,
1000                            SearchResult {
1001                                buffer: buffer.clone(),
1002                                range,
1003                                similarity,
1004                            },
1005                        );
1006                        results.truncate(limit);
1007                    }
1008                }
1009            }
1010
1011            Ok(results)
1012        })
1013    }
1014
1015    pub fn index_project(
1016        &mut self,
1017        project: ModelHandle<Project>,
1018        cx: &mut ModelContext<Self>,
1019    ) -> Task<Result<()>> {
1020        if !self.is_authenticated() {
1021            if !self.authenticate(cx) {
1022                return Task::ready(Err(anyhow!("user is not authenticated")));
1023            }
1024        }
1025
1026        if !self.projects.contains_key(&project.downgrade()) {
1027            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
1028                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
1029                    this.project_worktrees_changed(project.clone(), cx);
1030                }
1031                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
1032                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
1033                }
1034                _ => {}
1035            });
1036            let project_state = ProjectState::new(subscription, cx);
1037            self.projects.insert(project.downgrade(), project_state);
1038            self.project_worktrees_changed(project.clone(), cx);
1039        }
1040        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
1041        project_state.pending_index += 1;
1042        cx.notify();
1043
1044        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
1045        let db = self.db.clone();
1046        let language_registry = self.language_registry.clone();
1047        let parsing_files_tx = self.parsing_files_tx.clone();
1048        let worktree_registration = self.wait_for_worktree_registration(&project, cx);
1049
1050        cx.spawn(|this, mut cx| async move {
1051            worktree_registration.await?;
1052
1053            let mut pending_files = Vec::new();
1054            let mut files_to_delete = Vec::new();
1055            this.update(&mut cx, |this, cx| {
1056                let project_state = this
1057                    .projects
1058                    .get_mut(&project.downgrade())
1059                    .ok_or_else(|| anyhow!("project was dropped"))?;
1060                let pending_file_count_tx = &project_state.pending_file_count_tx;
1061
1062                project_state
1063                    .worktrees
1064                    .retain(|worktree_id, worktree_state| {
1065                        let worktree = if let Some(worktree) =
1066                            project.read(cx).worktree_for_id(*worktree_id, cx)
1067                        {
1068                            worktree
1069                        } else {
1070                            return false;
1071                        };
1072                        let worktree_state =
1073                            if let WorktreeState::Registered(worktree_state) = worktree_state {
1074                                worktree_state
1075                            } else {
1076                                return true;
1077                            };
1078
1079                        worktree_state.changed_paths.retain(|path, info| {
1080                            if info.is_deleted {
1081                                files_to_delete.push((worktree_state.db_id, path.clone()));
1082                            } else {
1083                                let absolute_path = worktree.read(cx).absolutize(path);
1084                                let job_handle = JobHandle::new(pending_file_count_tx);
1085                                pending_files.push(PendingFile {
1086                                    absolute_path,
1087                                    relative_path: path.clone(),
1088                                    language: None,
1089                                    job_handle,
1090                                    modified_time: info.mtime,
1091                                    worktree_db_id: worktree_state.db_id,
1092                                });
1093                            }
1094
1095                            false
1096                        });
1097                        true
1098                    });
1099
1100                anyhow::Ok(())
1101            })?;
1102
1103            cx.background()
1104                .spawn(async move {
1105                    for (worktree_db_id, path) in files_to_delete {
1106                        db.delete_file(worktree_db_id, path).await.log_err();
1107                    }
1108
1109                    let embeddings_for_digest = {
1110                        let mut files = HashMap::default();
1111                        for pending_file in &pending_files {
1112                            files
1113                                .entry(pending_file.worktree_db_id)
1114                                .or_insert(Vec::new())
1115                                .push(pending_file.relative_path.clone());
1116                        }
1117                        Arc::new(
1118                            db.embeddings_for_files(files)
1119                                .await
1120                                .log_err()
1121                                .unwrap_or_default(),
1122                        )
1123                    };
1124
1125                    for mut pending_file in pending_files {
1126                        if let Ok(language) = language_registry
1127                            .language_for_file(&pending_file.relative_path, None)
1128                            .await
1129                        {
1130                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
1131                                && &language.name().as_ref() != &"Markdown"
1132                                && language
1133                                    .grammar()
1134                                    .and_then(|grammar| grammar.embedding_config.as_ref())
1135                                    .is_none()
1136                            {
1137                                continue;
1138                            }
1139                            pending_file.language = Some(language);
1140                        }
1141                        parsing_files_tx
1142                            .try_send((embeddings_for_digest.clone(), pending_file))
1143                            .ok();
1144                    }
1145
1146                    // Wait until we're done indexing.
1147                    while let Some(count) = pending_file_count_rx.next().await {
1148                        if count == 0 {
1149                            break;
1150                        }
1151                    }
1152                })
1153                .await;
1154
1155            this.update(&mut cx, |this, cx| {
1156                let project_state = this
1157                    .projects
1158                    .get_mut(&project.downgrade())
1159                    .ok_or_else(|| anyhow!("project was dropped"))?;
1160                project_state.pending_index -= 1;
1161                cx.notify();
1162                anyhow::Ok(())
1163            })?;
1164
1165            Ok(())
1166        })
1167    }
1168
1169    fn wait_for_worktree_registration(
1170        &self,
1171        project: &ModelHandle<Project>,
1172        cx: &mut ModelContext<Self>,
1173    ) -> Task<Result<()>> {
1174        let project = project.downgrade();
1175        cx.spawn_weak(|this, cx| async move {
1176            loop {
1177                let mut pending_worktrees = Vec::new();
1178                this.upgrade(&cx)
1179                    .ok_or_else(|| anyhow!("semantic index dropped"))?
1180                    .read_with(&cx, |this, _| {
1181                        if let Some(project) = this.projects.get(&project) {
1182                            for worktree in project.worktrees.values() {
1183                                if let WorktreeState::Registering(worktree) = worktree {
1184                                    pending_worktrees.push(worktree.done());
1185                                }
1186                            }
1187                        }
1188                    });
1189
1190                if pending_worktrees.is_empty() {
1191                    break;
1192                } else {
1193                    future::join_all(pending_worktrees).await;
1194                }
1195            }
1196            Ok(())
1197        })
1198    }
1199
1200    async fn embed_spans(
1201        spans: &mut [Span],
1202        embedding_provider: &dyn EmbeddingProvider,
1203        db: &VectorDatabase,
1204        credential: ProviderCredential,
1205    ) -> Result<()> {
1206        let mut batch = Vec::new();
1207        let mut batch_tokens = 0;
1208        let mut embeddings = Vec::new();
1209
1210        let digests = spans
1211            .iter()
1212            .map(|span| span.digest.clone())
1213            .collect::<Vec<_>>();
1214        let embeddings_for_digests = db
1215            .embeddings_for_digests(digests)
1216            .await
1217            .log_err()
1218            .unwrap_or_default();
1219
1220        for span in &*spans {
1221            if embeddings_for_digests.contains_key(&span.digest) {
1222                continue;
1223            };
1224
1225            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
1226                let batch_embeddings = embedding_provider
1227                    .embed_batch(mem::take(&mut batch), credential.clone())
1228                    .await?;
1229                embeddings.extend(batch_embeddings);
1230                batch_tokens = 0;
1231            }
1232
1233            batch_tokens += span.token_count;
1234            batch.push(span.content.clone());
1235        }
1236
1237        if !batch.is_empty() {
1238            let batch_embeddings = embedding_provider
1239                .embed_batch(mem::take(&mut batch), credential)
1240                .await?;
1241
1242            embeddings.extend(batch_embeddings);
1243        }
1244
1245        let mut embeddings = embeddings.into_iter();
1246        for span in spans {
1247            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
1248                Some(embedding.clone())
1249            } else {
1250                embeddings.next()
1251            };
1252            let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
1253            span.embedding = Some(embedding);
1254        }
1255        Ok(())
1256    }
1257}
1258
// Implements gpui's `Entity` trait for `SemanticIndex`. The associated event type is the
// unit type, i.e. the index declares no event payload of its own.
1259impl Entity for SemanticIndex {
1260    type Event = ();
1261}
1262
1263impl Drop for JobHandle {
1264    fn drop(&mut self) {
1265        if let Some(inner) = Arc::get_mut(&mut self.tx) {
1266            // This is the last instance of the JobHandle (regardless of it's origin - whether it was cloned or not)
1267            if let Some(tx) = inner.upgrade() {
1268                let mut tx = tx.lock();
1269                *tx.borrow_mut() -= 1;
1270            }
1271        }
1272    }
1273}
1274
#[cfg(test)]
mod tests {
    use super::*;

    /// A job stays counted until every clone of its handle has been dropped.
    #[test]
    fn test_job_handle() {
        let (count_tx, count_rx) = watch::channel_with(0);
        let shared_tx = Arc::new(Mutex::new(count_tx));
        let original = JobHandle::new(&shared_tx);
        let jobs_in_flight = || *count_rx.borrow();

        // Creating the handle registers exactly one job.
        assert_eq!(1, jobs_in_flight());

        // Cloning does not register another job.
        let copy = original.clone();
        assert_eq!(1, jobs_in_flight());

        // Dropping one of the two handles keeps the job alive.
        drop(original);
        assert_eq!(1, jobs_in_flight());

        // Dropping the last handle releases the job.
        drop(copy);
        assert_eq!(0, jobs_in_flight());
    }
}