semantic_index.rs

   1mod db;
   2mod embedding_queue;
   3mod parsing;
   4pub mod semantic_index_settings;
   5
   6#[cfg(test)]
   7mod semantic_index_tests;
   8
   9use crate::semantic_index_settings::SemanticIndexSettings;
  10use ai::embedding::{Embedding, EmbeddingProvider};
  11use ai::providers::open_ai::OpenAIEmbeddingProvider;
  12use anyhow::{anyhow, Context as _, Result};
  13use collections::{BTreeMap, HashMap, HashSet};
  14use db::VectorDatabase;
  15use embedding_queue::{EmbeddingQueue, FileToEmbed};
  16use futures::{future, FutureExt, StreamExt};
  17use gpui::{
  18    AppContext, AsyncAppContext, BorrowWindow, Context, Model, ModelContext, Task, ViewContext,
  19    WeakModel,
  20};
  21use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
  22use lazy_static::lazy_static;
  23use ordered_float::OrderedFloat;
  24use parking_lot::Mutex;
  25use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
  26use postage::watch;
  27use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
  28use settings::Settings;
  29use smol::channel;
  30use std::{
  31    cmp::Reverse,
  32    env,
  33    future::Future,
  34    mem,
  35    ops::Range,
  36    path::{Path, PathBuf},
  37    sync::{Arc, Weak},
  38    time::{Duration, Instant, SystemTime},
  39};
  40use util::paths::PathMatcher;
  41use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
  42use workspace::Workspace;
  43
  44const SEMANTIC_INDEX_VERSION: usize = 11;
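    // BACKGROUND_INDEXING_DELAY delays re-indexing after project entries change;
    // EMBEDDING_QUEUE_FLUSH_TIMEOUT bounds how long parsed spans wait before the
    // embedding queue is flushed when no new files arrive.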
  45const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
  46const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
  47
  48lazy_static! {
  49    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
  50}
  51
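    /// Registers the semantic index settings, re-indexes previously indexed local
    /// projects when their workspace opens, and constructs the global `SemanticIndex`
    /// model backed by the OpenAI embedding provider and the on-disk embeddings
    /// database.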
  52pub fn init(
  53    fs: Arc<dyn Fs>,
  54    http_client: Arc<dyn HttpClient>,
  55    language_registry: Arc<LanguageRegistry>,
  56    cx: &mut AppContext,
  57) {
  58    SemanticIndexSettings::register(cx);
  59
  60    let db_file_path = EMBEDDINGS_DIR
  61        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
  62        .join("embeddings_db");
  63
  64    cx.observe_new_views(
  65        |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
  66            let Some(semantic_index) = SemanticIndex::global(cx) else {
  67                return;
  68            };
  69            let project = workspace.project().clone();
  70
  71            if project.read(cx).is_local() {
  72                cx.app_mut()
  73                    .spawn(|mut cx| async move {
  74                        let previously_indexed = semantic_index
  75                            .update(&mut cx, |index, cx| {
  76                                index.project_previously_indexed(&project, cx)
  77                            })?
  78                            .await?;
  79                        if previously_indexed {
  80                            semantic_index
  81                                .update(&mut cx, |index, cx| index.index_project(project, cx))?
  82                                .await?;
  83                        }
  84                        anyhow::Ok(())
  85                    })
  86                    .detach_and_log_err(cx);
  87            }
  88        },
  89    )
  90    .detach();
  91
  92    cx.spawn(move |cx| async move {
  93        let semantic_index = SemanticIndex::new(
  94            fs,
  95            db_file_path,
  96            Arc::new(OpenAIEmbeddingProvider::new(
  97                http_client,
  98                cx.background_executor().clone(),
  99            )),
 100            language_registry,
 101            cx.clone(),
 102        )
 103        .await?;
 104
 105        cx.update(|cx| cx.set_global(semantic_index.clone()))?;
 106
 107        anyhow::Ok(())
 108    })
 109    .detach();
 110}
 111
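    /// Indexing state reported for a project by `SemanticIndex::status`.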
 112#[derive(Copy, Clone, Debug)]
 113pub enum SemanticIndexStatus {
 114    NotAuthenticated,
 115    NotIndexed,
 116    Indexed,
 117    Indexing {
 118        remaining_files: usize,
 119        rate_limit_expiry: Option<Instant>,
 120    },
 121}
 122
 123pub struct SemanticIndex {
 124    fs: Arc<dyn Fs>,
 125    db: VectorDatabase,
 126    embedding_provider: Arc<dyn EmbeddingProvider>,
 127    language_registry: Arc<LanguageRegistry>,
 128    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
 129    _embedding_task: Task<()>,
 130    _parsing_files_tasks: Vec<Task<()>>,
 131    projects: HashMap<WeakModel<Project>, ProjectState>,
 132}
 133
 134struct ProjectState {
 135    worktrees: HashMap<WorktreeId, WorktreeState>,
 136    pending_file_count_rx: watch::Receiver<usize>,
 137    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
 138    pending_index: usize,
 139    _subscription: gpui::Subscription,
 140    _observe_pending_file_count: Task<()>,
 141}
 142
 143enum WorktreeState {
 144    Registering(RegisteringWorktreeState),
 145    Registered(RegisteredWorktreeState),
 146}
 147
 148impl WorktreeState {
 149    fn is_registered(&self) -> bool {
 150        matches!(self, Self::Registered(_))
 151    }
 152
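        /// Record the new mtime (or deletion) for each changed entry, skipping
        /// directories, symlinks, and ignored or external entries.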
 153    fn paths_changed(
 154        &mut self,
 155        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 156        worktree: &Worktree,
 157    ) {
 158        let changed_paths = match self {
 159            Self::Registering(state) => &mut state.changed_paths,
 160            Self::Registered(state) => &mut state.changed_paths,
 161        };
 162
 163        for (path, entry_id, change) in changes.iter() {
 164            let Some(entry) = worktree.entry_for_id(*entry_id) else {
 165                continue;
 166            };
 167            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
 168                continue;
 169            }
 170            changed_paths.insert(
 171                path.clone(),
 172                ChangedPathInfo {
 173                    mtime: entry.mtime,
 174                    is_deleted: *change == PathChange::Removed,
 175                },
 176            );
 177        }
 178    }
 179}
 180
 181struct RegisteringWorktreeState {
 182    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
 183    done_rx: watch::Receiver<Option<()>>,
 184    _registration: Task<()>,
 185}
 186
 187impl RegisteringWorktreeState {
 188    fn done(&self) -> impl Future<Output = ()> {
 189        let mut done_rx = self.done_rx.clone();
 190        async move {
 191            while let Some(result) = done_rx.next().await {
 192                if result.is_some() {
 193                    break;
 194                }
 195            }
 196        }
 197    }
 198}
 199
 200struct RegisteredWorktreeState {
 201    db_id: i64,
 202    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
 203}
 204
 205struct ChangedPathInfo {
 206    mtime: SystemTime,
 207    is_deleted: bool,
 208}
 209
 210#[derive(Clone)]
 211pub struct JobHandle {
 212    /// The outer Arc counts the clones of a JobHandle instance; when the last
 213    /// handle to a given job is dropped, the pending file counter is decremented exactly once.
 214    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
 215}
 216
 217impl JobHandle {
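        /// Increment the pending file counter; the matching decrement happens in
        /// `Drop` when the last clone of this handle goes away.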
 218    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
 219        *tx.lock().borrow_mut() += 1;
 220        Self {
 221            tx: Arc::new(Arc::downgrade(&tx)),
 222        }
 223    }
 224}
 225
 226impl ProjectState {
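        /// Set up the pending file counter and spawn a task that notifies observers
        /// of the `SemanticIndex` model whenever the count changes.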
 227    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
 228        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
 229        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
 230        Self {
 231            worktrees: Default::default(),
 232            pending_file_count_rx: pending_file_count_rx.clone(),
 233            pending_file_count_tx,
 234            pending_index: 0,
 235            _subscription: subscription,
 236            _observe_pending_file_count: cx.spawn({
 237                let mut pending_file_count_rx = pending_file_count_rx.clone();
 238                |this, mut cx| async move {
 239                    while let Some(_) = pending_file_count_rx.next().await {
 240                        if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
 241                            break;
 242                        }
 243                    }
 244                }
 245            }),
 246        }
 247    }
 248
 249    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
 250        self.worktrees
 251            .iter()
 252            .find_map(|(worktree_id, worktree_state)| match worktree_state {
 253                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
 254                _ => None,
 255            })
 256    }
 257}
 258
 259#[derive(Clone)]
 260pub struct PendingFile {
 261    worktree_db_id: i64,
 262    relative_path: Arc<Path>,
 263    absolute_path: PathBuf,
 264    language: Option<Arc<Language>>,
 265    modified_time: SystemTime,
 266    job_handle: JobHandle,
 267}
 268
 269#[derive(Clone)]
 270pub struct SearchResult {
 271    pub buffer: Model<Buffer>,
 272    pub range: Range<Anchor>,
 273    pub similarity: OrderedFloat<f32>,
 274}
 275
 276impl SemanticIndex {
 277    pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
 278        if cx.has_global::<Model<Self>>() {
 279            Some(cx.global::<Model<SemanticIndex>>().clone())
 280        } else {
 281            None
 282        }
 283    }
 284
 285    pub fn authenticate(&mut self, cx: &mut AppContext) -> bool {
 286        if !self.embedding_provider.has_credentials() {
 287            self.embedding_provider.retrieve_credentials(cx);
 288        }
 291
 292        self.embedding_provider.has_credentials()
 293    }
 294
 295    pub fn is_authenticated(&self) -> bool {
 296        self.embedding_provider.has_credentials()
 297    }
 298
 299    pub fn enabled(cx: &AppContext) -> bool {
 300        SemanticIndexSettings::get_global(cx).enabled
 301    }
 302
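        /// Report whether the project is unauthenticated, not indexed, fully indexed,
        /// or still indexing (with the number of remaining files and any rate-limit
        /// expiry from the embedding provider).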
 303    pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
 304        if !self.is_authenticated() {
 305            return SemanticIndexStatus::NotAuthenticated;
 306        }
 307
 308        if let Some(project_state) = self.projects.get(&project.downgrade()) {
 309            if project_state
 310                .worktrees
 311                .values()
 312                .all(|worktree| worktree.is_registered())
 313                && project_state.pending_index == 0
 314            {
 315                SemanticIndexStatus::Indexed
 316            } else {
 317                SemanticIndexStatus::Indexing {
 318                    remaining_files: project_state.pending_file_count_rx.borrow().clone(),
 319                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
 320                }
 321            }
 322        } else {
 323            SemanticIndexStatus::NotIndexed
 324        }
 325    }
 326
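        /// Open the vector database and construct the `SemanticIndex` model, spawning
        /// a task that persists embedded files to the database and one parsing task
        /// per CPU that turns pending files into spans and feeds the embedding queue,
        /// flushing the queue whenever no file arrives within
        /// `EMBEDDING_QUEUE_FLUSH_TIMEOUT`.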
 327    pub async fn new(
 328        fs: Arc<dyn Fs>,
 329        database_path: PathBuf,
 330        embedding_provider: Arc<dyn EmbeddingProvider>,
 331        language_registry: Arc<LanguageRegistry>,
 332        mut cx: AsyncAppContext,
 333    ) -> Result<Model<Self>> {
 334        let t0 = Instant::now();
 335        let database_path = Arc::from(database_path);
 336        let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
 337            .await?;
 338
 339        log::trace!(
 340            "db initialization took {:?} milliseconds",
 341            t0.elapsed().as_millis()
 342        );
 343
 344        cx.build_model(|cx| {
 345            let t0 = Instant::now();
 346            let embedding_queue =
 347                EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
 348            let _embedding_task = cx.background_executor().spawn({
 349                let embedded_files = embedding_queue.finished_files();
 350                let db = db.clone();
 351                async move {
 352                    while let Ok(file) = embedded_files.recv().await {
 353                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
 354                            .await
 355                            .log_err();
 356                    }
 357                }
 358            });
 359
 360            // Parse files into embeddable spans.
 361            let (parsing_files_tx, parsing_files_rx) =
 362                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
 363            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
 364            let mut _parsing_files_tasks = Vec::new();
 365            for _ in 0..cx.background_executor().num_cpus() {
 366                let fs = fs.clone();
 367                let mut parsing_files_rx = parsing_files_rx.clone();
 368                let embedding_provider = embedding_provider.clone();
 369                let embedding_queue = embedding_queue.clone();
 370                let background = cx.background_executor().clone();
 371                _parsing_files_tasks.push(cx.background_executor().spawn(async move {
 372                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
 373                    loop {
 374                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
 375                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
 376                        futures::select_biased! {
 377                            next_file_to_parse = next_file_to_parse => {
 378                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
 379                                    Self::parse_file(
 380                                        &fs,
 381                                        pending_file,
 382                                        &mut retriever,
 383                                        &embedding_queue,
 384                                        &embeddings_for_digest,
 385                                    )
 386                                    .await
 387                                } else {
 388                                    break;
 389                                }
 390                            },
 391                            _ = timer => {
 392                                embedding_queue.lock().flush();
 393                            }
 394                        }
 395                    }
 396                }));
 397            }
 398
 399            log::trace!(
 400                "semantic index task initialization took {:?} milliseconds",
 401                t0.elapsed().as_millis()
 402            );
 403            Self {
 404                fs,
 405                db,
 406                embedding_provider,
 407                language_registry,
 408                parsing_files_tx,
 409                _embedding_task,
 410                _parsing_files_tasks,
 411                projects: Default::default(),
 412            }
 413        })
 414    }
 415
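        /// Load the file from disk, parse it into spans, attach any previously
        /// computed embeddings by digest, and push the result onto the embedding queue.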
 416    async fn parse_file(
 417        fs: &Arc<dyn Fs>,
 418        pending_file: PendingFile,
 419        retriever: &mut CodeContextRetriever,
 420        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
 421        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
 422    ) {
 423        let Some(language) = pending_file.language else {
 424            return;
 425        };
 426
 427        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
 428            if let Some(mut spans) = retriever
 429                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
 430                .log_err()
 431            {
 432                log::trace!(
 433                    "parsed path {:?}: {} spans",
 434                    pending_file.relative_path,
 435                    spans.len()
 436                );
 437
 438                for span in &mut spans {
 439                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
 440                        span.embedding = Some(embedding.to_owned());
 441                    }
 442                }
 443
 444                embedding_queue.lock().push(FileToEmbed {
 445                    worktree_id: pending_file.worktree_db_id,
 446                    path: pending_file.relative_path,
 447                    mtime: pending_file.modified_time,
 448                    job_handle: pending_file.job_handle,
 449                    spans,
 450                });
 451            }
 452        }
 453    }
 454
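        /// Report whether every worktree of the project was already indexed in a
        /// previous session, according to the embeddings database.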
 455    pub fn project_previously_indexed(
 456        &mut self,
 457        project: &Model<Project>,
 458        cx: &mut ModelContext<Self>,
 459    ) -> Task<Result<bool>> {
 460        let worktrees_indexed_previously = project
 461            .read(cx)
 462            .worktrees()
 463            .map(|worktree| {
 464                self.db
 465                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
 466            })
 467            .collect::<Vec<_>>();
 468        cx.spawn(|_, _cx| async move {
 469            let worktree_indexed_previously =
 470                futures::future::join_all(worktrees_indexed_previously).await;
 471
 472            Ok(worktree_indexed_previously
 473                .iter()
 474                .filter(|worktree| worktree.is_ok())
 475                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
 476        })
 477    }
 478
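        /// Record the changed paths on the worktree's state and, if the worktree is
        /// already registered, schedule a re-index of the project after
        /// `BACKGROUND_INDEXING_DELAY`.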
 479    fn project_entries_changed(
 480        &mut self,
 481        project: Model<Project>,
 482        worktree_id: WorktreeId,
 483        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 484        cx: &mut ModelContext<Self>,
 485    ) {
 486        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
 487            return;
 488        };
 489        let project = project.downgrade();
 490        let Some(project_state) = self.projects.get_mut(&project) else {
 491            return;
 492        };
 493
 494        let worktree = worktree.read(cx);
 495        let worktree_state =
 496            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
 497                worktree_state
 498            } else {
 499                return;
 500            };
 501        worktree_state.paths_changed(changes, worktree);
 502        if let WorktreeState::Registered(_) = worktree_state {
 503            cx.spawn(|this, mut cx| async move {
 504                cx.background_executor()
 505                    .timer(BACKGROUND_INDEXING_DELAY)
 506                    .await;
 507                if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
 508                    this.update(&mut cx, |this, cx| {
 509                        this.index_project(project, cx).detach_and_log_err(cx)
 510                    })?;
 511                }
 512                anyhow::Ok(())
 513            })
 514            .detach_and_log_err(cx);
 515        }
 516    }
 517
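        /// Register a local worktree: once its scan completes, find or create its
        /// database row, diff stored mtimes against the current snapshot to collect
        /// changed and deleted paths, then mark the worktree as registered and kick
        /// off indexing. If registration fails, the worktree is dropped from tracking.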
 518    fn register_worktree(
 519        &mut self,
 520        project: Model<Project>,
 521        worktree: Model<Worktree>,
 522        cx: &mut ModelContext<Self>,
 523    ) {
 524        let project = project.downgrade();
 525        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
 526            project_state
 527        } else {
 528            return;
 529        };
 530        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
 531            worktree
 532        } else {
 533            return;
 534        };
 535        let worktree_abs_path = worktree.abs_path().clone();
 536        let scan_complete = worktree.scan_complete();
 537        let worktree_id = worktree.id();
 538        let db = self.db.clone();
 539        let language_registry = self.language_registry.clone();
 540        let (mut done_tx, done_rx) = watch::channel();
 541        let registration = cx.spawn(|this, mut cx| {
 542            async move {
 543                let register = async {
 544                    scan_complete.await;
 545                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
 546                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
 547                    let worktree = if let Some(project) = project.upgrade() {
 548                        project
 549                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
 550                            .ok()
 551                            .flatten()
 552                            .context("worktree not found")?
 553                    } else {
 554                        return anyhow::Ok(());
 555                    };
 556                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
 557                    let mut changed_paths = cx
 558                        .background_executor()
 559                        .spawn(async move {
 560                            let mut changed_paths = BTreeMap::new();
 561                            for file in worktree.files(false, 0) {
 562                                let absolute_path = worktree.absolutize(&file.path);
 563
 564                                if file.is_external || file.is_ignored || file.is_symlink {
 565                                    continue;
 566                                }
 567
 568                                if let Ok(language) = language_registry
 569                                    .language_for_file(&absolute_path, None)
 570                                    .await
 571                                {
 572                                    // Skip files that can't be parsed into embeddable spans.
 573                                    if !PARSEABLE_ENTIRE_FILE_TYPES
 574                                        .contains(&language.name().as_ref())
 575                                        && language.name().as_ref() != "Markdown"
 576                                        && language
 577                                            .grammar()
 578                                            .and_then(|grammar| grammar.embedding_config.as_ref())
 579                                            .is_none()
 580                                    {
 581                                        continue;
 582                                    }
 583
 584                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
 585                                    let already_stored = stored_mtime
 586                                        .map_or(false, |existing_mtime| {
 587                                            existing_mtime == file.mtime
 588                                        });
 589
 590                                    if !already_stored {
 591                                        changed_paths.insert(
 592                                            file.path.clone(),
 593                                            ChangedPathInfo {
 594                                                mtime: file.mtime,
 595                                                is_deleted: false,
 596                                            },
 597                                        );
 598                                    }
 599                                }
 600                            }
 601
 602                            // Mark database entries that are no longer present in the worktree as deleted.
 603                            for (path, mtime) in file_mtimes {
 604                                changed_paths.insert(
 605                                    path.into(),
 606                                    ChangedPathInfo {
 607                                        mtime,
 608                                        is_deleted: true,
 609                                    },
 610                                );
 611                            }
 612
 613                            anyhow::Ok(changed_paths)
 614                        })
 615                        .await?;
 616                    this.update(&mut cx, |this, cx| {
 617                        let project_state = this
 618                            .projects
 619                            .get_mut(&project)
 620                            .context("project not registered")?;
 621                        let project = project.upgrade().context("project was dropped")?;
 622
 623                        if let Some(WorktreeState::Registering(state)) =
 624                            project_state.worktrees.remove(&worktree_id)
 625                        {
 626                            changed_paths.extend(state.changed_paths);
 627                        }
 628                        project_state.worktrees.insert(
 629                            worktree_id,
 630                            WorktreeState::Registered(RegisteredWorktreeState {
 631                                db_id,
 632                                changed_paths,
 633                            }),
 634                        );
 635                        this.index_project(project, cx).detach_and_log_err(cx);
 636
 637                        anyhow::Ok(())
 638                    })??;
 639
 640                    anyhow::Ok(())
 641                };
 642
 643                if register.await.log_err().is_none() {
 644                    // Stop tracking this worktree if the registration failed.
 645                    this.update(&mut cx, |this, _| {
 646                        this.projects.get_mut(&project).map(|project_state| {
 647                            project_state.worktrees.remove(&worktree_id);
 648                        });
 649                    })
 650                    .ok();
 651                }
 652
 653                *done_tx.borrow_mut() = Some(());
 654            }
 655        });
 656        project_state.worktrees.insert(
 657            worktree_id,
 658            WorktreeState::Registering(RegisteringWorktreeState {
 659                changed_paths: Default::default(),
 660                done_rx,
 661                _registration: registration,
 662            }),
 663        );
 664    }
 665
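        /// Drop state for worktrees that have been removed from the project and
        /// register any new local worktrees.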
 666    fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
 667        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
 668        {
 669            project_state
 670        } else {
 671            return;
 672        };
 673
 674        let mut worktrees = project
 675            .read(cx)
 676            .worktrees()
 677            .filter(|worktree| worktree.read(cx).is_local())
 678            .collect::<Vec<_>>();
 679        let worktree_ids = worktrees
 680            .iter()
 681            .map(|worktree| worktree.read(cx).id())
 682            .collect::<HashSet<_>>();
 683
 684        // Remove worktrees that are no longer present
 685        project_state
 686            .worktrees
 687            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
 688
 689        // Register new worktrees
 690        worktrees.retain(|worktree| {
 691            let worktree_id = worktree.read(cx).id();
 692            !project_state.worktrees.contains_key(&worktree_id)
 693        });
 694        for worktree in worktrees {
 695            self.register_worktree(project.clone(), worktree, cx);
 696        }
 697    }
 698
 699    pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
 700        Some(
 701            self.projects
 702                .get(&project.downgrade())?
 703                .pending_file_count_rx
 704                .clone(),
 705        )
 706    }
 707
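        /// Index the project if needed, embed the query, then search both dirty open
        /// buffers and the on-disk index, preferring the modified-buffer result when a
        /// buffer appears in both, and return the top `limit` results by similarity.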
 708    pub fn search_project(
 709        &mut self,
 710        project: Model<Project>,
 711        query: String,
 712        limit: usize,
 713        includes: Vec<PathMatcher>,
 714        excludes: Vec<PathMatcher>,
 715        cx: &mut ModelContext<Self>,
 716    ) -> Task<Result<Vec<SearchResult>>> {
 717        if query.is_empty() {
 718            return Task::ready(Ok(Vec::new()));
 719        }
 720
 721        let index = self.index_project(project.clone(), cx);
 722        let embedding_provider = self.embedding_provider.clone();
 723
 724        cx.spawn(|this, mut cx| async move {
 725            index.await?;
 726            let t0 = Instant::now();
 727
 728            let query = embedding_provider
 729                .embed_batch(vec![query])
 730                .await?
 731                .pop()
 732                .context("could not embed query")?;
 733            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());
 734
 735            let search_start = Instant::now();
 736            let modified_buffer_results = this.update(&mut cx, |this, cx| {
 737                this.search_modified_buffers(
 738                    &project,
 739                    query.clone(),
 740                    limit,
 741                    &includes,
 742                    &excludes,
 743                    cx,
 744                )
 745            })?;
 746            let file_results = this.update(&mut cx, |this, cx| {
 747                this.search_files(project, query, limit, includes, excludes, cx)
 748            })?;
 749            let (modified_buffer_results, file_results) =
 750                futures::join!(modified_buffer_results, file_results);
 751
 752            // Weave together the results from modified buffers and files.
 753            let mut results = Vec::new();
 754            let mut modified_buffers = HashSet::default();
 755            for result in modified_buffer_results.log_err().unwrap_or_default() {
 756                modified_buffers.insert(result.buffer.clone());
 757                results.push(result);
 758            }
 759            for result in file_results.log_err().unwrap_or_default() {
 760                if !modified_buffers.contains(&result.buffer) {
 761                    results.push(result);
 762                }
 763            }
 764            results.sort_by_key(|result| Reverse(result.similarity));
 765            results.truncate(limit);
 766            log::trace!("Semantic search took {:?}", search_start.elapsed());
 767            Ok(results)
 768        })
 769    }
 770
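        /// Search the on-disk index: collect the database ids of the project's
        /// registered worktrees, run a batched top-k similarity search over the
        /// eligible file ids, then open the matching buffers and convert the stored
        /// byte ranges into anchor ranges.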
 771    pub fn search_files(
 772        &mut self,
 773        project: Model<Project>,
 774        query: Embedding,
 775        limit: usize,
 776        includes: Vec<PathMatcher>,
 777        excludes: Vec<PathMatcher>,
 778        cx: &mut ModelContext<Self>,
 779    ) -> Task<Result<Vec<SearchResult>>> {
 780        let db_path = self.db.path().clone();
 781        let fs = self.fs.clone();
 782        cx.spawn(|this, mut cx| async move {
 783            let database = VectorDatabase::new(
 784                fs.clone(),
 785                db_path.clone(),
 786                cx.background_executor().clone(),
 787            )
 788            .await?;
 789
 790            let worktree_db_ids = this.read_with(&cx, |this, _| {
 791                let project_state = this
 792                    .projects
 793                    .get(&project.downgrade())
 794                    .context("project was not indexed")?;
 795                let worktree_db_ids = project_state
 796                    .worktrees
 797                    .values()
 798                    .filter_map(|worktree| {
 799                        if let WorktreeState::Registered(worktree) = worktree {
 800                            Some(worktree.db_id)
 801                        } else {
 802                            None
 803                        }
 804                    })
 805                    .collect::<Vec<i64>>();
 806                anyhow::Ok(worktree_db_ids)
 807            })??;
 808
 809            let file_ids = database
 810                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
 811                .await?;
 812
 813            let batch_n = cx.background_executor().num_cpus();
 814            let ids_len = file_ids.len();
 815            let minimum_batch_size = 50;
 816
 817            let batch_size = {
 818                let size = ids_len / batch_n;
 819                if size < minimum_batch_size {
 820                    minimum_batch_size
 821                } else {
 822                    size
 823                }
 824            };
 825
 826            let mut batch_results = Vec::new();
 827            for batch in file_ids.chunks(batch_size) {
 828                let batch = batch.to_vec();
 829                let limit = limit.clone();
 830                let fs = fs.clone();
 831                let db_path = db_path.clone();
 832                let query = query.clone();
 833                if let Some(db) =
 834                    VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
 835                        .await
 836                        .log_err()
 837                {
 838                    batch_results.push(async move {
 839                        db.top_k_search(&query, limit, batch.as_slice()).await
 840                    });
 841                }
 842            }
 843
 844            let batch_results = futures::future::join_all(batch_results).await;
 845
 846            let mut results = Vec::new();
 847            for batch_result in batch_results {
 848                if let Ok(batch_result) = batch_result {
 849                    for (id, similarity) in batch_result {
 850                        let ix = match results
 851                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
 852                        {
 853                            Ok(ix) => ix,
 854                            Err(ix) => ix,
 855                        };
 856
 857                        results.insert(ix, (id, similarity));
 858                        results.truncate(limit);
 859                    }
 860                }
 861            }
 862
 863            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
 864            let scores = results
 865                .into_iter()
 866                .map(|(_, score)| score)
 867                .collect::<Vec<_>>();
 868            let spans = database.spans_for_ids(ids.as_slice()).await?;
 869
 870            let mut tasks = Vec::new();
 871            let mut ranges = Vec::new();
 872            let weak_project = project.downgrade();
 873            project.update(&mut cx, |project, cx| {
 874                let this = this.upgrade().context("index was dropped")?;
 875                for (worktree_db_id, file_path, byte_range) in spans {
 876                    let project_state =
 877                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
 878                            state
 879                        } else {
 880                            return Err(anyhow!("project not added"));
 881                        };
 882                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
 883                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
 884                        ranges.push(byte_range);
 885                    }
 886                }
 887
 888                Ok(())
 889            })??;
 890
 891            let buffers = futures::future::join_all(tasks).await;
 892            Ok(buffers
 893                .into_iter()
 894                .zip(ranges)
 895                .zip(scores)
 896                .filter_map(|((buffer, range), similarity)| {
 897                    let buffer = buffer.log_err()?;
 898                    let range = buffer
 899                        .read_with(&cx, |buffer, _| {
 900                            let start = buffer.clip_offset(range.start, Bias::Left);
 901                            let end = buffer.clip_offset(range.end, Bias::Right);
 902                            buffer.anchor_before(start)..buffer.anchor_after(end)
 903                        })
 904                        .log_err()?;
 905                    Some(SearchResult {
 906                        buffer,
 907                        range,
 908                        similarity,
 909                    })
 910                })
 911                .collect())
 912        })
 913    }
 914
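        /// Search dirty open buffers that pass the include/exclude filters, parsing
        /// and embedding their current contents on the fly and ranking the resulting
        /// spans by similarity to the query.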
 915    fn search_modified_buffers(
 916        &self,
 917        project: &Model<Project>,
 918        query: Embedding,
 919        limit: usize,
 920        includes: &[PathMatcher],
 921        excludes: &[PathMatcher],
 922        cx: &mut ModelContext<Self>,
 923    ) -> Task<Result<Vec<SearchResult>>> {
 924        let modified_buffers = project
 925            .read(cx)
 926            .opened_buffers()
 927            .into_iter()
 928            .filter_map(|buffer_handle| {
 929                let buffer = buffer_handle.read(cx);
 930                let snapshot = buffer.snapshot();
 931                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 932                    excludes.iter().any(|matcher| matcher.is_match(&path))
 933                });
 934
 935                let included = if includes.is_empty() {
 936                    true
 937                } else {
 938                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 939                        includes.iter().any(|matcher| matcher.is_match(&path))
 940                    })
 941                };
 942
 943                if buffer.is_dirty() && !excluded && included {
 944                    Some((buffer_handle, snapshot))
 945                } else {
 946                    None
 947                }
 948            })
 949            .collect::<HashMap<_, _>>();
 950
 951        let embedding_provider = self.embedding_provider.clone();
 952        let fs = self.fs.clone();
 953        let db_path = self.db.path().clone();
 954        let background = cx.background_executor().clone();
 955        cx.background_executor().spawn(async move {
 956            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
 957            let mut results = Vec::<SearchResult>::new();
 958
 959            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
 960            for (buffer, snapshot) in modified_buffers {
 961                let language = snapshot
 962                    .language_at(0)
 963                    .cloned()
 964                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
 965                let mut spans = retriever
 966                    .parse_file_with_template(None, &snapshot.text(), language)
 967                    .log_err()
 968                    .unwrap_or_default();
 969                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
 970                    .await
 971                    .log_err()
 972                    .is_some()
 973                {
 974                    for span in spans {
 975                        let similarity = span.embedding.unwrap().similarity(&query);
 976                        let ix = match results
 977                            .binary_search_by_key(&Reverse(similarity), |result| {
 978                                Reverse(result.similarity)
 979                            }) {
 980                            Ok(ix) => ix,
 981                            Err(ix) => ix,
 982                        };
 983
 984                        let range = {
 985                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
 986                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
 987                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
 988                        };
 989
 990                        results.insert(
 991                            ix,
 992                            SearchResult {
 993                                buffer: buffer.clone(),
 994                                range,
 995                                similarity,
 996                            },
 997                        );
 998                        results.truncate(limit);
 999                    }
1000                }
1001            }
1002
1003            Ok(results)
1004        })
1005    }
1006
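        /// Register the project on first use, wait for its worktrees to finish
        /// registering, delete removed files from the database, queue changed files
        /// for parsing and embedding, and resolve once the pending file count drops
        /// back to zero.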
1007    pub fn index_project(
1008        &mut self,
1009        project: Model<Project>,
1010        cx: &mut ModelContext<Self>,
1011    ) -> Task<Result<()>> {
1012        if !self.is_authenticated() {
1013            if !self.authenticate(cx) {
1014                return Task::ready(Err(anyhow!("user is not authenticated")));
1015            }
1016        }
1017
1018        if !self.projects.contains_key(&project.downgrade()) {
1019            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
1020                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
1021                    this.project_worktrees_changed(project.clone(), cx);
1022                }
1023                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
1024                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
1025                }
1026                _ => {}
1027            });
1028            let project_state = ProjectState::new(subscription, cx);
1029            self.projects.insert(project.downgrade(), project_state);
1030            self.project_worktrees_changed(project.clone(), cx);
1031        }
1032        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
1033        project_state.pending_index += 1;
1034        cx.notify();
1035
1036        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
1037        let db = self.db.clone();
1038        let language_registry = self.language_registry.clone();
1039        let parsing_files_tx = self.parsing_files_tx.clone();
1040        let worktree_registration = self.wait_for_worktree_registration(&project, cx);
1041
1042        cx.spawn(|this, mut cx| async move {
1043            worktree_registration.await?;
1044
1045            let mut pending_files = Vec::new();
1046            let mut files_to_delete = Vec::new();
1047            this.update(&mut cx, |this, cx| {
1048                let project_state = this
1049                    .projects
1050                    .get_mut(&project.downgrade())
1051                    .context("project was dropped")?;
1052                let pending_file_count_tx = &project_state.pending_file_count_tx;
1053
1054                project_state
1055                    .worktrees
1056                    .retain(|worktree_id, worktree_state| {
1057                        let worktree = if let Some(worktree) =
1058                            project.read(cx).worktree_for_id(*worktree_id, cx)
1059                        {
1060                            worktree
1061                        } else {
1062                            return false;
1063                        };
1064                        let worktree_state =
1065                            if let WorktreeState::Registered(worktree_state) = worktree_state {
1066                                worktree_state
1067                            } else {
1068                                return true;
1069                            };
1070
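                            // Drain changed_paths: each entry is queued either for
                            // deletion or for re-indexing, and returning false from
                            // the closure removes it from the map.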
1071                        worktree_state.changed_paths.retain(|path, info| {
1072                            if info.is_deleted {
1073                                files_to_delete.push((worktree_state.db_id, path.clone()));
1074                            } else {
1075                                let absolute_path = worktree.read(cx).absolutize(path);
1076                                let job_handle = JobHandle::new(pending_file_count_tx);
1077                                pending_files.push(PendingFile {
1078                                    absolute_path,
1079                                    relative_path: path.clone(),
1080                                    language: None,
1081                                    job_handle,
1082                                    modified_time: info.mtime,
1083                                    worktree_db_id: worktree_state.db_id,
1084                                });
1085                            }
1086
1087                            false
1088                        });
1089                        true
1090                    });
1091
1092                anyhow::Ok(())
1093            })??;
1094
1095            cx.background_executor()
1096                .spawn(async move {
1097                    for (worktree_db_id, path) in files_to_delete {
1098                        db.delete_file(worktree_db_id, path).await.log_err();
1099                    }
1100
1101                    let embeddings_for_digest = {
1102                        let mut files = HashMap::default();
1103                        for pending_file in &pending_files {
1104                            files
1105                                .entry(pending_file.worktree_db_id)
1106                                .or_insert(Vec::new())
1107                                .push(pending_file.relative_path.clone());
1108                        }
1109                        Arc::new(
1110                            db.embeddings_for_files(files)
1111                                .await
1112                                .log_err()
1113                                .unwrap_or_default(),
1114                        )
1115                    };
1116
1117                    for mut pending_file in pending_files {
1118                        if let Ok(language) = language_registry
1119                            .language_for_file(&pending_file.relative_path, None)
1120                            .await
1121                        {
1122                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
1123                                && language.name().as_ref() != "Markdown"
1124                                && language
1125                                    .grammar()
1126                                    .and_then(|grammar| grammar.embedding_config.as_ref())
1127                                    .is_none()
1128                            {
1129                                continue;
1130                            }
1131                            pending_file.language = Some(language);
1132                        }
1133                        parsing_files_tx
1134                            .try_send((embeddings_for_digest.clone(), pending_file))
1135                            .ok();
1136                    }
1137
1138                    // Wait until we're done indexing.
1139                    while let Some(count) = pending_file_count_rx.next().await {
1140                        if count == 0 {
1141                            break;
1142                        }
1143                    }
1144                })
1145                .await;
1146
1147            this.update(&mut cx, |this, cx| {
1148                let project_state = this
1149                    .projects
1150                    .get_mut(&project.downgrade())
1151                    .context("project was dropped")?;
1152                project_state.pending_index -= 1;
1153                cx.notify();
1154                anyhow::Ok(())
1155            })??;
1156
1157            Ok(())
1158        })
1159    }
1160
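        /// Wait until no worktree of the project is still in the `Registering` state.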
1161    fn wait_for_worktree_registration(
1162        &self,
1163        project: &Model<Project>,
1164        cx: &mut ModelContext<Self>,
1165    ) -> Task<Result<()>> {
1166        let project = project.downgrade();
1167        cx.spawn(|this, cx| async move {
1168            loop {
1169                let mut pending_worktrees = Vec::new();
1170                this.upgrade()
1171                    .context("semantic index dropped")?
1172                    .read_with(&cx, |this, _| {
1173                        if let Some(project) = this.projects.get(&project) {
1174                            for worktree in project.worktrees.values() {
1175                                if let WorktreeState::Registering(worktree) = worktree {
1176                                    pending_worktrees.push(worktree.done());
1177                                }
1178                            }
1179                        }
1180                    })?;
1181
1182                if pending_worktrees.is_empty() {
1183                    break;
1184                } else {
1185                    future::join_all(pending_worktrees).await;
1186                }
1187            }
1188            Ok(())
1189        })
1190    }
1191
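        /// Embed any spans whose digests are not already in the database, batching
        /// requests up to the provider's `max_tokens_per_batch`, and attach an
        /// embedding to every span.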
1192    async fn embed_spans(
1193        spans: &mut [Span],
1194        embedding_provider: &dyn EmbeddingProvider,
1195        db: &VectorDatabase,
1196    ) -> Result<()> {
1197        let mut batch = Vec::new();
1198        let mut batch_tokens = 0;
1199        let mut embeddings = Vec::new();
1200
1201        let digests = spans
1202            .iter()
1203            .map(|span| span.digest.clone())
1204            .collect::<Vec<_>>();
1205        let embeddings_for_digests = db
1206            .embeddings_for_digests(digests)
1207            .await
1208            .log_err()
1209            .unwrap_or_default();
1210
1211        for span in &*spans {
1212            if embeddings_for_digests.contains_key(&span.digest) {
1213                continue;
1214            };
1215
1216            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
1217                let batch_embeddings = embedding_provider
1218                    .embed_batch(mem::take(&mut batch))
1219                    .await?;
1220                embeddings.extend(batch_embeddings);
1221                batch_tokens = 0;
1222            }
1223
1224            batch_tokens += span.token_count;
1225            batch.push(span.content.clone());
1226        }
1227
1228        if !batch.is_empty() {
1229            let batch_embeddings = embedding_provider
1230                .embed_batch(mem::take(&mut batch))
1231                .await?;
1232
1233            embeddings.extend(batch_embeddings);
1234        }
1235
1236        let mut embeddings = embeddings.into_iter();
1237        for span in spans {
1238            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
1239                Some(embedding.clone())
1240            } else {
1241                embeddings.next()
1242            };
1243            let embedding = embedding.context("failed to embed spans")?;
1244            span.embedding = Some(embedding);
1245        }
1246        Ok(())
1247    }
1248}
1249
1250impl Drop for JobHandle {
1251    fn drop(&mut self) {
1252        if let Some(inner) = Arc::get_mut(&mut self.tx) {
1253            // This is the last instance of the JobHandle (regardless of it's origin - whether it was cloned or not)
1254            if let Some(tx) = inner.upgrade() {
1255                let mut tx = tx.lock();
1256                *tx.borrow_mut() -= 1;
1257            }
1258        }
1259    }
1260}
1261
1262#[cfg(test)]
1263mod tests {
1264
1265    use super::*;
1266    #[test]
1267    fn test_job_handle() {
1268        let (job_count_tx, job_count_rx) = watch::channel_with(0);
1269        let tx = Arc::new(Mutex::new(job_count_tx));
1270        let job_handle = JobHandle::new(&tx);
1271
1272        assert_eq!(1, *job_count_rx.borrow());
1273        let new_job_handle = job_handle.clone();
1274        assert_eq!(1, *job_count_rx.borrow());
1275        drop(job_handle);
1276        assert_eq!(1, *job_count_rx.borrow());
1277        drop(new_job_handle);
1278        assert_eq!(0, *job_count_rx.borrow());
1279    }
1280}