semantic_index.rs

   1mod db;
   2mod embedding_queue;
   3mod parsing;
   4pub mod semantic_index_settings;
   5
   6#[cfg(test)]
   7mod semantic_index_tests;
   8
   9use crate::semantic_index_settings::SemanticIndexSettings;
  10use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
  11use anyhow::{anyhow, Result};
  12use collections::{BTreeMap, HashMap, HashSet};
  13use db::VectorDatabase;
  14use embedding_queue::{EmbeddingQueue, FileToEmbed};
  15use futures::{future, FutureExt, StreamExt};
  16use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
  17use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
  18use lazy_static::lazy_static;
  19use ordered_float::OrderedFloat;
  20use parking_lot::Mutex;
  21use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
  22use postage::watch;
  23use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
  24use smol::channel;
  25use std::{
  26    cmp::Reverse,
  27    env,
  28    future::Future,
  29    mem,
  30    ops::Range,
  31    path::{Path, PathBuf},
  32    sync::{Arc, Weak},
  33    time::{Duration, Instant, SystemTime},
  34};
  35use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
  36use workspace::WorkspaceCreated;
  37
  38const SEMANTIC_INDEX_VERSION: usize = 11;
  39const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
  40const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
  41
  42lazy_static! {
  43    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
  44}
  45
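    /// Registers the semantic index settings, wires up a [`WorkspaceCreated`] subscription that
    /// re-indexes previously indexed local projects when their workspaces open, and asynchronously
    /// constructs the global [`SemanticIndex`] model.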
  46pub fn init(
  47    fs: Arc<dyn Fs>,
  48    http_client: Arc<dyn HttpClient>,
  49    language_registry: Arc<LanguageRegistry>,
  50    cx: &mut AppContext,
  51) {
  52    settings::register::<SemanticIndexSettings>(cx);
  53
  54    let db_file_path = EMBEDDINGS_DIR
  55        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
  56        .join("embeddings_db");
  57
  58    cx.subscribe_global::<WorkspaceCreated, _>({
  59        move |event, cx| {
  60            let Some(semantic_index) = SemanticIndex::global(cx) else {
  61                return;
  62            };
  63            let workspace = &event.0;
  64            if let Some(workspace) = workspace.upgrade(cx) {
  65                let project = workspace.read(cx).project().clone();
  66                if project.read(cx).is_local() {
  67                    cx.spawn(|mut cx| async move {
  68                        let previously_indexed = semantic_index
  69                            .update(&mut cx, |index, cx| {
  70                                index.project_previously_indexed(&project, cx)
  71                            })
  72                            .await?;
  73                        if previously_indexed {
  74                            semantic_index
  75                                .update(&mut cx, |index, cx| index.index_project(project, cx))
  76                                .await?;
  77                        }
  78                        anyhow::Ok(())
  79                    })
  80                    .detach_and_log_err(cx);
  81                }
  82            }
  83        }
  84    })
  85    .detach();
  86
  87    cx.spawn(move |mut cx| async move {
  88        let semantic_index = SemanticIndex::new(
  89            fs,
  90            db_file_path,
  91            Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
  92            language_registry,
  93            cx.clone(),
  94        )
  95        .await?;
  96
  97        cx.update(|cx| {
  98            cx.set_global(semantic_index.clone());
  99        });
 100
 101        anyhow::Ok(())
 102    })
 103    .detach();
 104}
 105
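    /// The indexing state of a single project, as reported to the UI.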
 106#[derive(Copy, Clone, Debug)]
 107pub enum SemanticIndexStatus {
 108    NotAuthenticated,
 109    NotIndexed,
 110    Indexed,
 111    Indexing {
 112        remaining_files: usize,
 113        rate_limit_expiry: Option<Instant>,
 114    },
 115}
 116
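    /// Maintains embeddings of project files in a local [`VectorDatabase`] and answers
    /// natural-language (semantic) search queries against them.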
 117pub struct SemanticIndex {
 118    fs: Arc<dyn Fs>,
 119    db: VectorDatabase,
 120    embedding_provider: Arc<dyn EmbeddingProvider>,
 121    language_registry: Arc<LanguageRegistry>,
 122    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
 123    _embedding_task: Task<()>,
 124    _parsing_files_tasks: Vec<Task<()>>,
 125    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 126    api_key: Option<String>,
 127    embedding_queue: Arc<Mutex<EmbeddingQueue>>,
 128}
 129
 130struct ProjectState {
 131    worktrees: HashMap<WorktreeId, WorktreeState>,
 132    pending_file_count_rx: watch::Receiver<usize>,
 133    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
 134    pending_index: usize,
 135    _subscription: gpui::Subscription,
 136    _observe_pending_file_count: Task<()>,
 137}
 138
 139enum WorktreeState {
 140    Registering(RegisteringWorktreeState),
 141    Registered(RegisteredWorktreeState),
 142}
 143
 144impl WorktreeState {
 145    fn is_registered(&self) -> bool {
 146        matches!(self, Self::Registered(_))
 147    }
 148
 149    fn paths_changed(
 150        &mut self,
 151        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 152        worktree: &Worktree,
 153    ) {
 154        let changed_paths = match self {
 155            Self::Registering(state) => &mut state.changed_paths,
 156            Self::Registered(state) => &mut state.changed_paths,
 157        };
 158
 159        for (path, entry_id, change) in changes.iter() {
 160            let Some(entry) = worktree.entry_for_id(*entry_id) else {
 161                continue;
 162            };
 163            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
 164                continue;
 165            }
 166            changed_paths.insert(
 167                path.clone(),
 168                ChangedPathInfo {
 169                    mtime: entry.mtime,
 170                    is_deleted: *change == PathChange::Removed,
 171                },
 172            );
 173        }
 174    }
 175}
 176
 177struct RegisteringWorktreeState {
 178    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
 179    done_rx: watch::Receiver<Option<()>>,
 180    _registration: Task<()>,
 181}
 182
 183impl RegisteringWorktreeState {
 184    fn done(&self) -> impl Future<Output = ()> {
 185        let mut done_rx = self.done_rx.clone();
 186        async move {
 187            while let Some(result) = done_rx.next().await {
 188                if result.is_some() {
 189                    break;
 190                }
 191            }
 192        }
 193    }
 194}
 195
 196struct RegisteredWorktreeState {
 197    db_id: i64,
 198    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
 199}
 200
 201struct ChangedPathInfo {
 202    mtime: SystemTime,
 203    is_deleted: bool,
 204}
 205
 206#[derive(Clone)]
 207pub struct JobHandle {
 208    /// The outer Arc counts the live clones of a JobHandle for a given job; when the last
 209    /// clone is dropped, the project's pending file count is decremented exactly once.
 210    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
 211}
 212
 213impl JobHandle {
 214    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
 215        *tx.lock().borrow_mut() += 1;
 216        Self {
 217            tx: Arc::new(Arc::downgrade(&tx)),
 218        }
 219    }
 220}
 221
 222impl ProjectState {
 223    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
 224        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
 225        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
 226        Self {
 227            worktrees: Default::default(),
 228            pending_file_count_rx: pending_file_count_rx.clone(),
 229            pending_file_count_tx,
 230            pending_index: 0,
 231            _subscription: subscription,
 232            _observe_pending_file_count: cx.spawn_weak({
 233                let mut pending_file_count_rx = pending_file_count_rx.clone();
 234                |this, mut cx| async move {
 235                    while let Some(_) = pending_file_count_rx.next().await {
 236                        if let Some(this) = this.upgrade(&cx) {
 237                            this.update(&mut cx, |_, cx| cx.notify());
 238                        }
 239                    }
 240                }
 241            }),
 242        }
 243    }
 244
 245    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
 246        self.worktrees
 247            .iter()
 248            .find_map(|(worktree_id, worktree_state)| match worktree_state {
 249                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
 250                _ => None,
 251            })
 252    }
 253}
 254
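    /// A file that has been scheduled for parsing and embedding.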
 255#[derive(Clone)]
 256pub struct PendingFile {
 257    worktree_db_id: i64,
 258    relative_path: Arc<Path>,
 259    absolute_path: PathBuf,
 260    language: Option<Arc<Language>>,
 261    modified_time: SystemTime,
 262    job_handle: JobHandle,
 263}
 264
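    /// A single search hit: a range of a buffer along with its similarity to the query embedding.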
 265#[derive(Clone)]
 266pub struct SearchResult {
 267    pub buffer: ModelHandle<Buffer>,
 268    pub range: Range<Anchor>,
 269    pub similarity: OrderedFloat<f32>,
 270}
 271
 272impl SemanticIndex {
 273    pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
 274        if cx.has_global::<ModelHandle<Self>>() {
 275            Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
 276        } else {
 277            None
 278        }
 279    }
 280
 281    pub fn authenticate(&mut self, cx: &AppContext) {
 282        if self.api_key.is_none() {
 283            self.api_key = self.embedding_provider.retrieve_credentials(cx);
 284
 285            self.embedding_queue
 286                .lock()
 287                .set_api_key(self.api_key.clone());
 288        }
 289    }
 290
 291    pub fn is_authenticated(&self) -> bool {
 292        self.api_key.is_some()
 293    }
 294
 295    pub fn enabled(cx: &AppContext) -> bool {
 296        settings::get::<SemanticIndexSettings>(cx).enabled
 297    }
 298
 299    pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
 300        if !self.is_authenticated() {
 301            return SemanticIndexStatus::NotAuthenticated;
 302        }
 303
 304        if let Some(project_state) = self.projects.get(&project.downgrade()) {
 305            if project_state
 306                .worktrees
 307                .values()
 308                .all(|worktree| worktree.is_registered())
 309                && project_state.pending_index == 0
 310            {
 311                SemanticIndexStatus::Indexed
 312            } else {
 313                SemanticIndexStatus::Indexing {
 314                    remaining_files: *project_state.pending_file_count_rx.borrow(),
 315                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
 316                }
 317            }
 318        } else {
 319            SemanticIndexStatus::NotIndexed
 320        }
 321    }
 322
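        /// Opens (or creates) the vector database at `database_path`, then spawns a background
        /// task that writes finished embeddings to the database and one parsing task per CPU
        /// that turns pending files into embeddable spans.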
 323    pub async fn new(
 324        fs: Arc<dyn Fs>,
 325        database_path: PathBuf,
 326        embedding_provider: Arc<dyn EmbeddingProvider>,
 327        language_registry: Arc<LanguageRegistry>,
 328        mut cx: AsyncAppContext,
 329    ) -> Result<ModelHandle<Self>> {
 330        let t0 = Instant::now();
 331        let database_path = Arc::from(database_path);
 332        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;
 333
 334        log::trace!(
 335            "db initialization took {:?} milliseconds",
 336            t0.elapsed().as_millis()
 337        );
 338
 339        Ok(cx.add_model(|cx| {
 340            let t0 = Instant::now();
 341            let embedding_queue =
 342                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
 343            let _embedding_task = cx.background().spawn({
 344                let embedded_files = embedding_queue.finished_files();
 345                let db = db.clone();
 346                async move {
 347                    while let Ok(file) = embedded_files.recv().await {
 348                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
 349                            .await
 350                            .log_err();
 351                    }
 352                }
 353            });
 354
 355            // Parse files into embeddable spans.
 356            let (parsing_files_tx, parsing_files_rx) =
 357                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
 358            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
 359            let mut _parsing_files_tasks = Vec::new();
 360            for _ in 0..cx.background().num_cpus() {
 361                let fs = fs.clone();
 362                let mut parsing_files_rx = parsing_files_rx.clone();
 363                let embedding_provider = embedding_provider.clone();
 364                let embedding_queue = embedding_queue.clone();
 365                let background = cx.background().clone();
 366                _parsing_files_tasks.push(cx.background().spawn(async move {
 367                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
 368                    loop {
 369                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
 370                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
 371                        futures::select_biased! {
 372                            next_file_to_parse = next_file_to_parse => {
 373                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
 374                                    Self::parse_file(
 375                                        &fs,
 376                                        pending_file,
 377                                        &mut retriever,
 378                                        &embedding_queue,
 379                                        &embeddings_for_digest,
 380                                    )
 381                                    .await
 382                                } else {
 383                                    break;
 384                                }
 385                            },
 386                            _ = timer => {
 387                                embedding_queue.lock().flush();
 388                            }
 389                        }
 390                    }
 391                }));
 392            }
 393
 394            log::trace!(
 395                "semantic index task initialization took {:?} milliseconds",
 396                t0.elapsed().as_millis()
 397            );
 398            Self {
 399                fs,
 400                db,
 401                embedding_provider,
 402                language_registry,
 403                parsing_files_tx,
 404                _embedding_task,
 405                _parsing_files_tasks,
 406                projects: Default::default(),
 407                api_key: None,
 408                embedding_queue
 409            }
 410        }))
 411    }
 412
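        /// Loads a pending file from disk, parses it into spans, reuses embeddings already known
        /// for identical span digests, and pushes the file onto the embedding queue.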
 413    async fn parse_file(
 414        fs: &Arc<dyn Fs>,
 415        pending_file: PendingFile,
 416        retriever: &mut CodeContextRetriever,
 417        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
 418        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
 419    ) {
 420        let Some(language) = pending_file.language else {
 421            return;
 422        };
 423
 424        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
 425            if let Some(mut spans) = retriever
 426                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
 427                .log_err()
 428            {
 429                log::trace!(
 430                    "parsed path {:?}: {} spans",
 431                    pending_file.relative_path,
 432                    spans.len()
 433                );
 434
 435                for span in &mut spans {
 436                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
 437                        span.embedding = Some(embedding.to_owned());
 438                    }
 439                }
 440
 441                embedding_queue.lock().push(FileToEmbed {
 442                    worktree_id: pending_file.worktree_db_id,
 443                    path: pending_file.relative_path,
 444                    mtime: pending_file.modified_time,
 445                    job_handle: pending_file.job_handle,
 446                    spans,
 447                });
 448            }
 449        }
 450    }
 451
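        /// Reports whether the project's worktrees were already indexed in an earlier session.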
 452    pub fn project_previously_indexed(
 453        &mut self,
 454        project: &ModelHandle<Project>,
 455        cx: &mut ModelContext<Self>,
 456    ) -> Task<Result<bool>> {
 457        let worktrees_indexed_previously = project
 458            .read(cx)
 459            .worktrees(cx)
 460            .map(|worktree| {
 461                self.db
 462                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
 463            })
 464            .collect::<Vec<_>>();
 465        cx.spawn(|_, _cx| async move {
 466            let worktree_indexed_previously =
 467                futures::future::join_all(worktrees_indexed_previously).await;
 468
 469            Ok(worktree_indexed_previously
 470                .iter()
 471                .filter(|worktree| worktree.is_ok())
 472                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
 473        })
 474    }
 475
 476    fn project_entries_changed(
 477        &mut self,
 478        project: ModelHandle<Project>,
 479        worktree_id: WorktreeId,
 480        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 481        cx: &mut ModelContext<Self>,
 482    ) {
 483        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
 484            return;
 485        };
 486        let project = project.downgrade();
 487        let Some(project_state) = self.projects.get_mut(&project) else {
 488            return;
 489        };
 490
 491        let worktree = worktree.read(cx);
 492        let worktree_state =
 493            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
 494                worktree_state
 495            } else {
 496                return;
 497            };
 498        worktree_state.paths_changed(changes, worktree);
 499        if let WorktreeState::Registered(_) = worktree_state {
 500            cx.spawn_weak(|this, mut cx| async move {
 501                cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
 502                if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
 503                    this.update(&mut cx, |this, cx| {
 504                        this.index_project(project, cx).detach_and_log_err(cx)
 505                    });
 506                }
 507            })
 508            .detach();
 509        }
 510    }
 511
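        /// Starts tracking a local worktree: once its initial scan completes, the worktree is
        /// looked up (or created) in the database, files whose mtimes differ from the stored
        /// values are recorded as changed, and an indexing pass for the project is scheduled.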
 512    fn register_worktree(
 513        &mut self,
 514        project: ModelHandle<Project>,
 515        worktree: ModelHandle<Worktree>,
 516        cx: &mut ModelContext<Self>,
 517    ) {
 518        let project = project.downgrade();
 519        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
 520            project_state
 521        } else {
 522            return;
 523        };
 524        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
 525            worktree
 526        } else {
 527            return;
 528        };
 529        let worktree_abs_path = worktree.abs_path().clone();
 530        let scan_complete = worktree.scan_complete();
 531        let worktree_id = worktree.id();
 532        let db = self.db.clone();
 533        let language_registry = self.language_registry.clone();
 534        let (mut done_tx, done_rx) = watch::channel();
 535        let registration = cx.spawn(|this, mut cx| {
 536            async move {
 537                let register = async {
 538                    scan_complete.await;
 539                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
 540                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
 541                    let worktree = if let Some(project) = project.upgrade(&cx) {
 542                        project
 543                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
 544                            .ok_or_else(|| anyhow!("worktree not found"))?
 545                    } else {
 546                        return anyhow::Ok(());
 547                    };
 548                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
 549                    let mut changed_paths = cx
 550                        .background()
 551                        .spawn(async move {
 552                            let mut changed_paths = BTreeMap::new();
 553                            for file in worktree.files(false, 0) {
 554                                let absolute_path = worktree.absolutize(&file.path);
 555
 556                                if file.is_external || file.is_ignored || file.is_symlink {
 557                                    continue;
 558                                }
 559
 560                                if let Ok(language) = language_registry
 561                                    .language_for_file(&absolute_path, None)
 562                                    .await
 563                                {
 564                                    // Only index files we can embed: whole-file types, Markdown, or languages whose grammar has an embedding config.
 565                                    if !PARSEABLE_ENTIRE_FILE_TYPES
 566                                        .contains(&language.name().as_ref())
 567                                        && &language.name().as_ref() != &"Markdown"
 568                                        && language
 569                                            .grammar()
 570                                            .and_then(|grammar| grammar.embedding_config.as_ref())
 571                                            .is_none()
 572                                    {
 573                                        continue;
 574                                    }
 575
 576                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
 577                                    let already_stored = stored_mtime
 578                                        .map_or(false, |existing_mtime| {
 579                                            existing_mtime == file.mtime
 580                                        });
 581
 582                                    if !already_stored {
 583                                        changed_paths.insert(
 584                                            file.path.clone(),
 585                                            ChangedPathInfo {
 586                                                mtime: file.mtime,
 587                                                is_deleted: false,
 588                                            },
 589                                        );
 590                                    }
 591                                }
 592                            }
 593
 594                            // Clean up database entries for files that are no longer in the worktree.
 595                            for (path, mtime) in file_mtimes {
 596                                changed_paths.insert(
 597                                    path.into(),
 598                                    ChangedPathInfo {
 599                                        mtime,
 600                                        is_deleted: true,
 601                                    },
 602                                );
 603                            }
 604
 605                            anyhow::Ok(changed_paths)
 606                        })
 607                        .await?;
 608                    this.update(&mut cx, |this, cx| {
 609                        let project_state = this
 610                            .projects
 611                            .get_mut(&project)
 612                            .ok_or_else(|| anyhow!("project not registered"))?;
 613                        let project = project
 614                            .upgrade(cx)
 615                            .ok_or_else(|| anyhow!("project was dropped"))?;
 616
 617                        if let Some(WorktreeState::Registering(state)) =
 618                            project_state.worktrees.remove(&worktree_id)
 619                        {
 620                            changed_paths.extend(state.changed_paths);
 621                        }
 622                        project_state.worktrees.insert(
 623                            worktree_id,
 624                            WorktreeState::Registered(RegisteredWorktreeState {
 625                                db_id,
 626                                changed_paths,
 627                            }),
 628                        );
 629                        this.index_project(project, cx).detach_and_log_err(cx);
 630
 631                        anyhow::Ok(())
 632                    })?;
 633
 634                    anyhow::Ok(())
 635                };
 636
 637                if register.await.log_err().is_none() {
 638                    // Stop tracking this worktree if the registration failed.
 639                    this.update(&mut cx, |this, _| {
 640                        if let Some(project_state) = this.projects.get_mut(&project) {
 641                            project_state.worktrees.remove(&worktree_id);
 642                        }
 643                    })
 644                }
 645
 646                *done_tx.borrow_mut() = Some(());
 647            }
 648        });
 649        project_state.worktrees.insert(
 650            worktree_id,
 651            WorktreeState::Registering(RegisteringWorktreeState {
 652                changed_paths: Default::default(),
 653                done_rx,
 654                _registration: registration,
 655            }),
 656        );
 657    }
 658
 659    fn project_worktrees_changed(
 660        &mut self,
 661        project: ModelHandle<Project>,
 662        cx: &mut ModelContext<Self>,
 663    ) {
 664        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
 665        {
 666            project_state
 667        } else {
 668            return;
 669        };
 670
 671        let mut worktrees = project
 672            .read(cx)
 673            .worktrees(cx)
 674            .filter(|worktree| worktree.read(cx).is_local())
 675            .collect::<Vec<_>>();
 676        let worktree_ids = worktrees
 677            .iter()
 678            .map(|worktree| worktree.read(cx).id())
 679            .collect::<HashSet<_>>();
 680
 681        // Remove worktrees that are no longer present
 682        project_state
 683            .worktrees
 684            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
 685
 686        // Register new worktrees
 687        worktrees.retain(|worktree| {
 688            let worktree_id = worktree.read(cx).id();
 689            !project_state.worktrees.contains_key(&worktree_id)
 690        });
 691        for worktree in worktrees {
 692            self.register_worktree(project.clone(), worktree, cx);
 693        }
 694    }
 695
 696    pub fn pending_file_count(
 697        &self,
 698        project: &ModelHandle<Project>,
 699    ) -> Option<watch::Receiver<usize>> {
 700        Some(
 701            self.projects
 702                .get(&project.downgrade())?
 703                .pending_file_count_rx
 704                .clone(),
 705        )
 706    }
 707
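        /// Embeds `query` and semantically searches the project, returning up to `limit` results
        /// drawn from both indexed files on disk and dirty open buffers.
        ///
        /// Illustrative call site (hypothetical, not a doc test):
        /// ```ignore
        /// let results = semantic_index.update(cx, |index, cx| {
        ///     index.search_project(project, "read a file from disk".into(), 10, vec![], vec![], cx)
        /// });
        /// ```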
 708    pub fn search_project(
 709        &mut self,
 710        project: ModelHandle<Project>,
 711        query: String,
 712        limit: usize,
 713        includes: Vec<PathMatcher>,
 714        excludes: Vec<PathMatcher>,
 715        cx: &mut ModelContext<Self>,
 716    ) -> Task<Result<Vec<SearchResult>>> {
 717        if query.is_empty() {
 718            return Task::ready(Ok(Vec::new()));
 719        }
 720
 721        let index = self.index_project(project.clone(), cx);
 722        let embedding_provider = self.embedding_provider.clone();
 723        let api_key = self.api_key.clone();
 724
 725        cx.spawn(|this, mut cx| async move {
 726            index.await?;
 727            let t0 = Instant::now();
 728            let query = embedding_provider
 729                .embed_batch(vec![query], api_key)
 730                .await?
 731                .pop()
 732                .ok_or_else(|| anyhow!("could not embed query"))?;
 733            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());
 734
 735            let search_start = Instant::now();
 736            let modified_buffer_results = this.update(&mut cx, |this, cx| {
 737                this.search_modified_buffers(
 738                    &project,
 739                    query.clone(),
 740                    limit,
 741                    &includes,
 742                    &excludes,
 743                    cx,
 744                )
 745            });
 746            let file_results = this.update(&mut cx, |this, cx| {
 747                this.search_files(project, query, limit, includes, excludes, cx)
 748            });
 749            let (modified_buffer_results, file_results) =
 750                futures::join!(modified_buffer_results, file_results);
 751
 752            // Weave together the results from modified buffers and files.
 753            let mut results = Vec::new();
 754            let mut modified_buffers = HashSet::default();
 755            for result in modified_buffer_results.log_err().unwrap_or_default() {
 756                modified_buffers.insert(result.buffer.clone());
 757                results.push(result);
 758            }
 759            for result in file_results.log_err().unwrap_or_default() {
 760                if !modified_buffers.contains(&result.buffer) {
 761                    results.push(result);
 762                }
 763            }
 764            results.sort_by_key(|result| Reverse(result.similarity));
 765            results.truncate(limit);
 766            log::trace!("Semantic search took {:?}", search_start.elapsed());
 767            Ok(results)
 768        })
 769    }
 770
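        /// Searches the indexed, on-disk files of the project by comparing the query embedding
        /// against stored span embeddings, then opens buffers for the best-matching spans.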
 771    pub fn search_files(
 772        &mut self,
 773        project: ModelHandle<Project>,
 774        query: Embedding,
 775        limit: usize,
 776        includes: Vec<PathMatcher>,
 777        excludes: Vec<PathMatcher>,
 778        cx: &mut ModelContext<Self>,
 779    ) -> Task<Result<Vec<SearchResult>>> {
 780        let db_path = self.db.path().clone();
 781        let fs = self.fs.clone();
 782        cx.spawn(|this, mut cx| async move {
 783            let database =
 784                VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
 785
 786            let worktree_db_ids = this.read_with(&cx, |this, _| {
 787                let project_state = this
 788                    .projects
 789                    .get(&project.downgrade())
 790                    .ok_or_else(|| anyhow!("project was not indexed"))?;
 791                let worktree_db_ids = project_state
 792                    .worktrees
 793                    .values()
 794                    .filter_map(|worktree| {
 795                        if let WorktreeState::Registered(worktree) = worktree {
 796                            Some(worktree.db_id)
 797                        } else {
 798                            None
 799                        }
 800                    })
 801                    .collect::<Vec<i64>>();
 802                anyhow::Ok(worktree_db_ids)
 803            })?;
 804
 805            let file_ids = database
 806                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
 807                .await?;
 808
 809            let batch_n = cx.background().num_cpus();
 810            let ids_len = file_ids.len();
 811            let minimum_batch_size = 50;
 812
 813            let batch_size = {
 814                let size = ids_len / batch_n;
 815                if size < minimum_batch_size {
 816                    minimum_batch_size
 817                } else {
 818                    size
 819                }
 820            };
 821
 822            let mut batch_results = Vec::new();
 823            for batch in file_ids.chunks(batch_size) {
 824                let batch = batch.to_vec();
 825                let limit = limit.clone();
 826                let fs = fs.clone();
 827                let db_path = db_path.clone();
 828                let query = query.clone();
 829                if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
 830                    .await
 831                    .log_err()
 832                {
 833                    batch_results.push(async move {
 834                        db.top_k_search(&query, limit, batch.as_slice()).await
 835                    });
 836                }
 837            }
 838
 839            let batch_results = futures::future::join_all(batch_results).await;
 840
 841            let mut results = Vec::new();
 842            for batch_result in batch_results {
 843                if let Ok(batch_result) = batch_result {
 844                    for (id, similarity) in batch_result {
 845                        let ix = match results
 846                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
 847                        {
 848                            Ok(ix) => ix,
 849                            Err(ix) => ix,
 850                        };
 851
 852                        results.insert(ix, (id, similarity));
 853                        results.truncate(limit);
 854                    }
 855                }
 856            }
 857
 858            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
 859            let scores = results
 860                .into_iter()
 861                .map(|(_, score)| score)
 862                .collect::<Vec<_>>();
 863            let spans = database.spans_for_ids(ids.as_slice()).await?;
 864
 865            let mut tasks = Vec::new();
 866            let mut ranges = Vec::new();
 867            let weak_project = project.downgrade();
 868            project.update(&mut cx, |project, cx| {
 869                for (worktree_db_id, file_path, byte_range) in spans {
 870                    let project_state =
 871                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
 872                            state
 873                        } else {
 874                            return Err(anyhow!("project not added"));
 875                        };
 876                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
 877                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
 878                        ranges.push(byte_range);
 879                    }
 880                }
 881
 882                Ok(())
 883            })?;
 884
 885            let buffers = futures::future::join_all(tasks).await;
 886            Ok(buffers
 887                .into_iter()
 888                .zip(ranges)
 889                .zip(scores)
 890                .filter_map(|((buffer, range), similarity)| {
 891                    let buffer = buffer.log_err()?;
 892                    let range = buffer.read_with(&cx, |buffer, _| {
 893                        let start = buffer.clip_offset(range.start, Bias::Left);
 894                        let end = buffer.clip_offset(range.end, Bias::Right);
 895                        buffer.anchor_before(start)..buffer.anchor_after(end)
 896                    });
 897                    Some(SearchResult {
 898                        buffer,
 899                        range,
 900                        similarity,
 901                    })
 902                })
 903                .collect())
 904        })
 905    }
 906
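        /// Searches dirty open buffers by re-parsing and re-embedding their current contents,
        /// since the embeddings stored for their on-disk versions may be stale.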
 907    fn search_modified_buffers(
 908        &self,
 909        project: &ModelHandle<Project>,
 910        query: Embedding,
 911        limit: usize,
 912        includes: &[PathMatcher],
 913        excludes: &[PathMatcher],
 914        cx: &mut ModelContext<Self>,
 915    ) -> Task<Result<Vec<SearchResult>>> {
 916        let modified_buffers = project
 917            .read(cx)
 918            .opened_buffers(cx)
 919            .into_iter()
 920            .filter_map(|buffer_handle| {
 921                let buffer = buffer_handle.read(cx);
 922                let snapshot = buffer.snapshot();
 923                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 924                    excludes.iter().any(|matcher| matcher.is_match(&path))
 925                });
 926
 927                let included = if includes.is_empty() {
 928                    true
 929                } else {
 930                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 931                        includes.iter().any(|matcher| matcher.is_match(&path))
 932                    })
 933                };
 934
 935                if buffer.is_dirty() && !excluded && included {
 936                    Some((buffer_handle, snapshot))
 937                } else {
 938                    None
 939                }
 940            })
 941            .collect::<HashMap<_, _>>();
 942
 943        let embedding_provider = self.embedding_provider.clone();
 944        let fs = self.fs.clone();
 945        let db_path = self.db.path().clone();
 946        let background = cx.background().clone();
 947        let api_key = self.api_key.clone();
 948        cx.background().spawn(async move {
 949            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
 950            let mut results = Vec::<SearchResult>::new();
 951
 952            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
 953            for (buffer, snapshot) in modified_buffers {
 954                let language = snapshot
 955                    .language_at(0)
 956                    .cloned()
 957                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
 958                let mut spans = retriever
 959                    .parse_file_with_template(None, &snapshot.text(), language)
 960                    .log_err()
 961                    .unwrap_or_default();
 962                if Self::embed_spans(
 963                    &mut spans,
 964                    embedding_provider.as_ref(),
 965                    &db,
 966                    api_key.clone(),
 967                )
 968                .await
 969                .log_err()
 970                .is_some()
 971                {
 972                    for span in spans {
 973                        let similarity = span.embedding.unwrap().similarity(&query);
 974                        let ix = match results
 975                            .binary_search_by_key(&Reverse(similarity), |result| {
 976                                Reverse(result.similarity)
 977                            }) {
 978                            Ok(ix) => ix,
 979                            Err(ix) => ix,
 980                        };
 981
 982                        let range = {
 983                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
 984                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
 985                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
 986                        };
 987
 988                        results.insert(
 989                            ix,
 990                            SearchResult {
 991                                buffer: buffer.clone(),
 992                                range,
 993                                similarity,
 994                            },
 995                        );
 996                        results.truncate(limit);
 997                    }
 998                }
 999            }
1000
1001            Ok(results)
1002        })
1003    }
1004
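        /// Brings the project's index up to date: registers the project if needed, waits for
        /// worktree registration, deletes rows for removed files, enqueues changed files for
        /// parsing and embedding, and resolves once the pending file count returns to zero.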
1005    pub fn index_project(
1006        &mut self,
1007        project: ModelHandle<Project>,
1008        cx: &mut ModelContext<Self>,
1009    ) -> Task<Result<()>> {
1010        if self.api_key.is_none() {
1011            self.authenticate(cx);
1012            if self.api_key.is_none() {
1013                return Task::ready(Err(anyhow!("user is not authenticated")));
1014            }
1015        }
1016
1017        if !self.projects.contains_key(&project.downgrade()) {
1018            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
1019                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
1020                    this.project_worktrees_changed(project.clone(), cx);
1021                }
1022                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
1023                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
1024                }
1025                _ => {}
1026            });
1027            let project_state = ProjectState::new(subscription, cx);
1028            self.projects.insert(project.downgrade(), project_state);
1029            self.project_worktrees_changed(project.clone(), cx);
1030        }
1031        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
1032        project_state.pending_index += 1;
1033        cx.notify();
1034
1035        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
1036        let db = self.db.clone();
1037        let language_registry = self.language_registry.clone();
1038        let parsing_files_tx = self.parsing_files_tx.clone();
1039        let worktree_registration = self.wait_for_worktree_registration(&project, cx);
1040
1041        cx.spawn(|this, mut cx| async move {
1042            worktree_registration.await?;
1043
1044            let mut pending_files = Vec::new();
1045            let mut files_to_delete = Vec::new();
1046            this.update(&mut cx, |this, cx| {
1047                let project_state = this
1048                    .projects
1049                    .get_mut(&project.downgrade())
1050                    .ok_or_else(|| anyhow!("project was dropped"))?;
1051                let pending_file_count_tx = &project_state.pending_file_count_tx;
1052
1053                project_state
1054                    .worktrees
1055                    .retain(|worktree_id, worktree_state| {
1056                        let worktree = if let Some(worktree) =
1057                            project.read(cx).worktree_for_id(*worktree_id, cx)
1058                        {
1059                            worktree
1060                        } else {
1061                            return false;
1062                        };
1063                        let worktree_state =
1064                            if let WorktreeState::Registered(worktree_state) = worktree_state {
1065                                worktree_state
1066                            } else {
1067                                return true;
1068                            };
1069
1070                        worktree_state.changed_paths.retain(|path, info| {
1071                            if info.is_deleted {
1072                                files_to_delete.push((worktree_state.db_id, path.clone()));
1073                            } else {
1074                                let absolute_path = worktree.read(cx).absolutize(path);
1075                                let job_handle = JobHandle::new(pending_file_count_tx);
1076                                pending_files.push(PendingFile {
1077                                    absolute_path,
1078                                    relative_path: path.clone(),
1079                                    language: None,
1080                                    job_handle,
1081                                    modified_time: info.mtime,
1082                                    worktree_db_id: worktree_state.db_id,
1083                                });
1084                            }
1085
1086                            false
1087                        });
1088                        true
1089                    });
1090
1091                anyhow::Ok(())
1092            })?;
1093
1094            cx.background()
1095                .spawn(async move {
1096                    for (worktree_db_id, path) in files_to_delete {
1097                        db.delete_file(worktree_db_id, path).await.log_err();
1098                    }
1099
1100                    let embeddings_for_digest = {
1101                        let mut files = HashMap::default();
1102                        for pending_file in &pending_files {
1103                            files
1104                                .entry(pending_file.worktree_db_id)
1105                                .or_insert(Vec::new())
1106                                .push(pending_file.relative_path.clone());
1107                        }
1108                        Arc::new(
1109                            db.embeddings_for_files(files)
1110                                .await
1111                                .log_err()
1112                                .unwrap_or_default(),
1113                        )
1114                    };
1115
1116                    for mut pending_file in pending_files {
1117                        if let Ok(language) = language_registry
1118                            .language_for_file(&pending_file.relative_path, None)
1119                            .await
1120                        {
1121                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
1122                                && &language.name().as_ref() != &"Markdown"
1123                                && language
1124                                    .grammar()
1125                                    .and_then(|grammar| grammar.embedding_config.as_ref())
1126                                    .is_none()
1127                            {
1128                                continue;
1129                            }
1130                            pending_file.language = Some(language);
1131                        }
1132                        parsing_files_tx
1133                            .try_send((embeddings_for_digest.clone(), pending_file))
1134                            .ok();
1135                    }
1136
1137                    // Wait until we're done indexing.
1138                    while let Some(count) = pending_file_count_rx.next().await {
1139                        if count == 0 {
1140                            break;
1141                        }
1142                    }
1143                })
1144                .await;
1145
1146            this.update(&mut cx, |this, cx| {
1147                let project_state = this
1148                    .projects
1149                    .get_mut(&project.downgrade())
1150                    .ok_or_else(|| anyhow!("project was dropped"))?;
1151                project_state.pending_index -= 1;
1152                cx.notify();
1153                anyhow::Ok(())
1154            })?;
1155
1156            Ok(())
1157        })
1158    }
1159
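        /// Resolves once every worktree of the project has finished registering with the database.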
1160    fn wait_for_worktree_registration(
1161        &self,
1162        project: &ModelHandle<Project>,
1163        cx: &mut ModelContext<Self>,
1164    ) -> Task<Result<()>> {
1165        let project = project.downgrade();
1166        cx.spawn_weak(|this, cx| async move {
1167            loop {
1168                let mut pending_worktrees = Vec::new();
1169                this.upgrade(&cx)
1170                    .ok_or_else(|| anyhow!("semantic index dropped"))?
1171                    .read_with(&cx, |this, _| {
1172                        if let Some(project) = this.projects.get(&project) {
1173                            for worktree in project.worktrees.values() {
1174                                if let WorktreeState::Registering(worktree) = worktree {
1175                                    pending_worktrees.push(worktree.done());
1176                                }
1177                            }
1178                        }
1179                    });
1180
1181                if pending_worktrees.is_empty() {
1182                    break;
1183                } else {
1184                    future::join_all(pending_worktrees).await;
1185                }
1186            }
1187            Ok(())
1188        })
1189    }
1190
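        /// Embeds the given spans, reusing embeddings already stored for identical digests and
        /// batching the remaining spans up to the provider's per-batch token limit.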
1191    async fn embed_spans(
1192        spans: &mut [Span],
1193        embedding_provider: &dyn EmbeddingProvider,
1194        db: &VectorDatabase,
1195        api_key: Option<String>,
1196    ) -> Result<()> {
1197        let mut batch = Vec::new();
1198        let mut batch_tokens = 0;
1199        let mut embeddings = Vec::new();
1200
1201        let digests = spans
1202            .iter()
1203            .map(|span| span.digest.clone())
1204            .collect::<Vec<_>>();
1205        let embeddings_for_digests = db
1206            .embeddings_for_digests(digests)
1207            .await
1208            .log_err()
1209            .unwrap_or_default();
1210
1211        for span in &*spans {
1212            if embeddings_for_digests.contains_key(&span.digest) {
1213                continue;
1214            };
1215
1216            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
1217                let batch_embeddings = embedding_provider
1218                    .embed_batch(mem::take(&mut batch), api_key.clone())
1219                    .await?;
1220                embeddings.extend(batch_embeddings);
1221                batch_tokens = 0;
1222            }
1223
1224            batch_tokens += span.token_count;
1225            batch.push(span.content.clone());
1226        }
1227
1228        if !batch.is_empty() {
1229            let batch_embeddings = embedding_provider
1230                .embed_batch(mem::take(&mut batch), api_key)
1231                .await?;
1232
1233            embeddings.extend(batch_embeddings);
1234        }
1235
1236        let mut embeddings = embeddings.into_iter();
1237        for span in spans {
1238            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
1239                Some(embedding.clone())
1240            } else {
1241                embeddings.next()
1242            };
1243            let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
1244            span.embedding = Some(embedding);
1245        }
1246        Ok(())
1247    }
1248}
1249
1250impl Entity for SemanticIndex {
1251    type Event = ();
1252}
1253
1254impl Drop for JobHandle {
1255    fn drop(&mut self) {
1256        if let Some(inner) = Arc::get_mut(&mut self.tx) {
1257            // This is the last remaining handle for this job, whether it was the original or a clone.
1258            if let Some(tx) = inner.upgrade() {
1259                let mut tx = tx.lock();
1260                *tx.borrow_mut() -= 1;
1261            }
1262        }
1263    }
1264}
1265
1266#[cfg(test)]
1267mod tests {
1268
1269    use super::*;
1270    #[test]
1271    fn test_job_handle() {
1272        let (job_count_tx, job_count_rx) = watch::channel_with(0);
1273        let tx = Arc::new(Mutex::new(job_count_tx));
1274        let job_handle = JobHandle::new(&tx);
1275
1276        assert_eq!(1, *job_count_rx.borrow());
1277        let new_job_handle = job_handle.clone();
1278        assert_eq!(1, *job_count_rx.borrow());
1279        drop(job_handle);
1280        assert_eq!(1, *job_count_rx.borrow());
1281        drop(new_job_handle);
1282        assert_eq!(0, *job_count_rx.borrow());
1283    }
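
        // A companion check to `test_job_handle` above: separate `JobHandle::new` calls count as
        // separate jobs, while clones of a single handle share one count.
        #[test]
        fn test_job_handle_multiple_jobs() {
            let (job_count_tx, job_count_rx) = watch::channel_with(0);
            let tx = Arc::new(Mutex::new(job_count_tx));

            let first = JobHandle::new(&tx);
            let second = JobHandle::new(&tx);
            assert_eq!(2, *job_count_rx.borrow());

            drop(first);
            assert_eq!(1, *job_count_rx.borrow());
            drop(second);
            assert_eq!(0, *job_count_rx.borrow());
        }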
1284}