semantic_index.rs

   1mod chunking;
   2mod embedding;
   3mod project_index_debug_view;
   4
   5use anyhow::{anyhow, Context as _, Result};
   6use chunking::{chunk_text, Chunk};
   7use collections::{Bound, HashMap, HashSet};
   8pub use embedding::*;
   9use fs::Fs;
  10use futures::stream::StreamExt;
  11use futures_batch::ChunksTimeoutStreamExt;
  12use gpui::{
  13    AppContext, AsyncAppContext, BorrowAppContext, Context, Entity, EntityId, EventEmitter, Global,
  14    Model, ModelContext, Subscription, Task, WeakModel,
  15};
  16use heed::types::{SerdeBincode, Str};
  17use language::LanguageRegistry;
  18use parking_lot::Mutex;
  19use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree, WorktreeId};
  20use serde::{Deserialize, Serialize};
  21use smol::channel;
  22use std::{
  23    cmp::Ordering,
  24    future::Future,
  25    iter,
  26    num::NonZeroUsize,
  27    ops::Range,
  28    path::{Path, PathBuf},
  29    sync::{Arc, Weak},
  30    time::{Duration, SystemTime},
  31};
  32use util::ResultExt;
  33use worktree::LocalSnapshot;
  34
  35pub use project_index_debug_view::ProjectIndexDebugView;
  36
/// Application-wide owner of the embedding database and of one
/// [`ProjectIndex`] per open project. Installed as a gpui global
/// (see the `Global` impl below).
pub struct SemanticIndex {
    // Provider used to embed both indexed chunks and search queries.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Shared heed (LMDB) environment; each worktree gets its own named
    // database inside it (see `WorktreeIndex::load`).
    db_connection: heed::Env,
    // Cache of project indices, keyed by a weak project handle so entries
    // can be removed when the project is released (`project_index`).
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}

impl Global for SemanticIndex {}
  44
impl SemanticIndex {
    /// Opens (creating the directory if needed) the embedding database at
    /// `db_path` and returns an empty `SemanticIndex`.
    ///
    /// The environment is opened on the background executor because it
    /// performs blocking filesystem I/O.
    pub async fn new(
        db_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AsyncAppContext,
    ) -> Result<Self> {
        let db_connection = cx
            .background_executor()
            .spawn(async move {
                std::fs::create_dir_all(&db_path)?;
                // SAFETY: heed marks `open` unsafe because the memory-mapped
                // file must not be concurrently truncated/replaced by another
                // process. We assume this app owns `db_path` exclusively —
                // NOTE(review): confirm no other process opens this path.
                unsafe {
                    heed::EnvOpenOptions::new()
                        // 1 GiB maximum map size for the whole environment.
                        .map_size(1024 * 1024 * 1024)
                        // One named database per worktree, so allow many.
                        .max_dbs(3000)
                        .open(db_path)
                }
            })
            .await
            .context("opening database connection")?;

        Ok(SemanticIndex {
            db_connection,
            embedding_provider,
            project_indices: HashMap::default(),
        })
    }

    /// Returns the index for `project`, creating (and caching) it on first
    /// request. Also registers a release callback on the project that evicts
    /// the cache entry from the global when the project is dropped.
    pub fn project_index(
        &mut self,
        project: Model<Project>,
        cx: &mut AppContext,
    ) -> Model<ProjectIndex> {
        let project_weak = project.downgrade();
        project.update(cx, move |_, cx| {
            cx.on_release(move |_, cx| {
                // The global may already be gone during shutdown.
                if cx.has_global::<SemanticIndex>() {
                    cx.update_global::<SemanticIndex, _>(|this, _| {
                        this.project_indices.remove(&project_weak);
                    })
                }
            })
            .detach();
        });

        self.project_indices
            .entry(project.downgrade())
            .or_insert_with(|| {
                cx.new_model(|cx| {
                    ProjectIndex::new(
                        project,
                        self.db_connection.clone(),
                        self.embedding_provider.clone(),
                        cx,
                    )
                })
            })
            .clone()
    }
}
 104
/// Semantic index for a single project: one [`WorktreeIndex`] per visible
/// local worktree, plus aggregate status tracking and search.
pub struct ProjectIndex {
    db_connection: heed::Env,
    // Weak handle; the index must not keep the project alive.
    project: WeakModel<Project>,
    // Keyed by worktree entity id; entries may still be loading.
    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    // Last status that was emitted; used to suppress duplicate events.
    last_status: Status,
    // Worktree indices ping this channel to request a status recompute.
    status_tx: channel::Sender<()>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Drives `update_status` whenever `status_tx` is signaled.
    _maintain_status: Task<()>,
    // Subscription to project events (worktree added/removed).
    _subscription: Subscription,
}

/// A worktree index that is either still being loaded or ready for use.
enum WorktreeIndexHandle {
    Loading { _task: Task<Result<()>> },
    Loaded { index: Model<WorktreeIndex> },
}
 122
impl ProjectIndex {
    /// Creates the index and immediately starts indexing the project's
    /// current visible local worktrees.
    fn new(
        project: Model<Project>,
        db_connection: heed::Env,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let language_registry = project.read(cx).languages().clone();
        let fs = project.read(cx).fs().clone();
        let (status_tx, mut status_rx) = channel::unbounded();
        let mut this = ProjectIndex {
            db_connection,
            project: project.downgrade(),
            worktree_indices: HashMap::default(),
            language_registry,
            fs,
            status_tx,
            last_status: Status::Idle,
            embedding_provider,
            _subscription: cx.subscribe(&project, Self::handle_project_event),
            // Recompute the aggregate status whenever any worktree index
            // signals progress; stop once this model has been released.
            _maintain_status: cx.spawn(|this, mut cx| async move {
                while status_rx.next().await.is_some() {
                    if this
                        .update(&mut cx, |this, cx| this.update_status(cx))
                        .is_err()
                    {
                        break;
                    }
                }
            }),
        };
        this.update_worktree_indices(cx);
        this
    }

    /// The most recently computed aggregate indexing status.
    pub fn status(&self) -> Status {
        self.last_status
    }

    pub fn project(&self) -> WeakModel<Project> {
        self.project.clone()
    }

    pub fn fs(&self) -> Arc<dyn Fs> {
        self.fs.clone()
    }

    /// Re-syncs the worktree index set when worktrees are added or removed.
    fn handle_project_event(
        &mut self,
        _: Model<Project>,
        event: &project::Event,
        cx: &mut ModelContext<Self>,
    ) {
        match event {
            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                self.update_worktree_indices(cx);
            }
            _ => {}
        }
    }

    /// Reconciles `worktree_indices` with the project's current visible
    /// local worktrees: drops indices for removed worktrees and starts
    /// loading an index for each new one.
    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
        let Some(project) = self.project.upgrade() else {
            return;
        };

        // Only local worktrees are indexed.
        let worktrees = project
            .read(cx)
            .visible_worktrees(cx)
            .filter_map(|worktree| {
                if worktree.read(cx).is_local() {
                    Some((worktree.entity_id(), worktree))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        self.worktree_indices
            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
        for (worktree_id, worktree) in worktrees {
            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
                let worktree_index = WorktreeIndex::load(
                    worktree.clone(),
                    self.db_connection.clone(),
                    self.language_registry.clone(),
                    self.fs.clone(),
                    self.status_tx.clone(),
                    self.embedding_provider.clone(),
                    cx,
                );

                // Once loading finishes, replace the `Loading` handle with
                // `Loaded`, or drop the entry entirely on failure.
                let load_worktree = cx.spawn(|this, mut cx| async move {
                    if let Some(worktree_index) = worktree_index.await.log_err() {
                        this.update(&mut cx, |this, _| {
                            this.worktree_indices.insert(
                                worktree_id,
                                WorktreeIndexHandle::Loaded {
                                    index: worktree_index,
                                },
                            );
                        })?;
                    } else {
                        this.update(&mut cx, |this, _cx| {
                            this.worktree_indices.remove(&worktree_id)
                        })?;
                    }

                    this.update(&mut cx, |this, cx| this.update_status(cx))
                });

                WorktreeIndexHandle::Loading {
                    _task: load_worktree,
                }
            });
        }

        self.update_status(cx);
    }

    /// Derives the aggregate [`Status`] from the worktree indices and emits
    /// it as an event if it changed since the last emission.
    fn update_status(&mut self, cx: &mut ModelContext<Self>) {
        let mut indexing_count = 0;
        let mut any_loading = false;

        for index in self.worktree_indices.values_mut() {
            match index {
                WorktreeIndexHandle::Loading { .. } => {
                    // One loading worktree is enough to report `Loading`.
                    any_loading = true;
                    break;
                }
                WorktreeIndexHandle::Loaded { index, .. } => {
                    indexing_count += index.read(cx).entry_ids_being_indexed.len();
                }
            }
        }

        let status = if any_loading {
            Status::Loading
        } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
            Status::Scanning { remaining_count }
        } else {
            Status::Idle
        };

        if status != self.last_status {
            self.last_status = status;
            cx.emit(status);
        }
    }

    /// Embeds `query` and returns the `limit` most similar chunks across all
    /// loaded worktree indices, ordered by descending similarity.
    ///
    /// Each worktree's database is streamed on the background executor into
    /// a bounded channel, and one worker per CPU keeps a sorted top-`limit`
    /// list; the per-worker lists are merged at the end.
    pub fn search(
        &self,
        query: String,
        limit: usize,
        cx: &AppContext,
    ) -> Task<Result<Vec<SearchResult>>> {
        let (chunks_tx, chunks_rx) = channel::bounded(1024);
        let mut worktree_scan_tasks = Vec::new();
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                let chunks_tx = chunks_tx.clone();
                index.read_with(cx, |index, cx| {
                    let worktree_id = index.worktree.read(cx).id();
                    let db_connection = index.db_connection.clone();
                    let db = index.db;
                    worktree_scan_tasks.push(cx.background_executor().spawn({
                        async move {
                            let txn = db_connection
                                .read_txn()
                                .context("failed to create read transaction")?;
                            let db_entries = db.iter(&txn).context("failed to iterate database")?;
                            for db_entry in db_entries {
                                let (_key, db_embedded_file) = db_entry?;
                                for chunk in db_embedded_file.chunks {
                                    chunks_tx
                                        .send((worktree_id, db_embedded_file.path.clone(), chunk))
                                        .await?;
                                }
                            }
                            anyhow::Ok(())
                        }
                    }));
                })
            }
        }
        // Drop our sender so the workers' `recv` loops terminate once all
        // scan tasks have finished sending.
        drop(chunks_tx);

        let project = self.project.clone();
        let embedding_provider = self.embedding_provider.clone();
        cx.spawn(|cx| async move {
            #[cfg(debug_assertions)]
            let embedding_query_start = std::time::Instant::now();
            log::info!("Searching for {query}");

            let query_embeddings = embedding_provider
                .embed(&[TextToEmbed::new(&query)])
                .await?;
            let query_embedding = query_embeddings
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("no embedding for query"))?;

            // One result accumulator per worker so no locking is needed.
            let mut results_by_worker = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                results_by_worker.push(Vec::<WorktreeSearchResult>::new());
            }

            #[cfg(debug_assertions)]
            let search_start = std::time::Instant::now();

            cx.background_executor()
                .scoped(|cx| {
                    for results in results_by_worker.iter_mut() {
                        cx.spawn(async {
                            while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
                                let score = chunk.embedding.similarity(&query_embedding);
                                // Keep `results` sorted by descending score;
                                // the comparator is inverted for that reason.
                                let ix = match results.binary_search_by(|probe| {
                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
                                }) {
                                    Ok(ix) | Err(ix) => ix,
                                };
                                results.insert(
                                    ix,
                                    WorktreeSearchResult {
                                        worktree_id,
                                        path: path.clone(),
                                        range: chunk.chunk.range.clone(),
                                        score,
                                    },
                                );
                                results.truncate(limit);
                            }
                        });
                    }
                })
                .await;

            // Surface any error from the database scans.
            futures::future::try_join_all(worktree_scan_tasks).await?;

            project.read_with(&cx, |project, cx| {
                let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
                for worker_results in results_by_worker {
                    // Skip results whose worktree has since been removed.
                    search_results.extend(worker_results.into_iter().filter_map(|result| {
                        Some(SearchResult {
                            worktree: project.worktree_for_id(result.worktree_id, cx)?,
                            path: result.path,
                            range: result.range,
                            score: result.score,
                        })
                    }));
                }
                search_results.sort_unstable_by(|a, b| {
                    b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
                });
                search_results.truncate(limit);

                #[cfg(debug_assertions)]
                {
                    let search_elapsed = search_start.elapsed();
                    log::debug!(
                        "searched {} entries in {:?}",
                        search_results.len(),
                        search_elapsed
                    );
                    let embedding_query_elapsed = embedding_query_start.elapsed();
                    log::debug!("embedding query took {:?}", embedding_query_elapsed);
                }

                search_results
            })
        })
    }

    /// Total number of indexed paths across all loaded worktree indices.
    #[cfg(test)]
    pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
        let mut result = 0;
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                result += index.read(cx).path_count()?;
            }
        }
        Ok(result)
    }

    /// The loaded index for `worktree_id`, if any.
    pub(crate) fn worktree_index(
        &self,
        worktree_id: WorktreeId,
        cx: &AppContext,
    ) -> Option<Model<WorktreeIndex>> {
        for index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = index {
                if index.read(cx).worktree.read(cx).id() == worktree_id {
                    return Some(index.clone());
                }
            }
        }
        None
    }

    /// All loaded worktree indices, ordered by worktree id.
    pub(crate) fn worktree_indices(&self, cx: &AppContext) -> Vec<Model<WorktreeIndex>> {
        let mut result = self
            .worktree_indices
            .values()
            .filter_map(|index| {
                if let WorktreeIndexHandle::Loaded { index, .. } = index {
                    Some(index.clone())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        result.sort_by_key(|index| index.read(cx).worktree.read(cx).id());
        result
    }
}
 438
/// A search hit resolved against a live worktree handle, as returned by
/// [`ProjectIndex::search`].
pub struct SearchResult {
    pub worktree: Model<Worktree>,
    // Path of the file within the worktree.
    pub path: Arc<Path>,
    // Byte range of the matching chunk within the file's text.
    pub range: Range<usize>,
    // Similarity score of the chunk to the query (higher is better).
    pub score: f32,
}
 445
/// An intermediate search hit identified only by worktree id; converted into
/// a [`SearchResult`] once resolved against the project.
pub struct WorktreeSearchResult {
    pub worktree_id: WorktreeId,
    pub path: Arc<Path>,
    pub range: Range<usize>,
    pub score: f32,
}
 452
/// Aggregate indexing state of a [`ProjectIndex`], emitted as a gpui event
/// whenever it changes (see `ProjectIndex::update_status`).
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum Status {
    /// No indexing work in progress.
    Idle,
    /// At least one worktree index is still loading.
    Loading,
    /// All indices loaded; `remaining_count` entries are still being indexed.
    Scanning { remaining_count: NonZeroUsize },
}

impl EventEmitter<Status> for ProjectIndex {}
 461
/// Embedding index for a single local worktree, backed by one named database
/// inside the shared heed environment.
struct WorktreeIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    // Maps db keys (derived from file paths) to embedded file records.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Entries currently in flight; its size feeds `ProjectIndex::update_status`.
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
    // Long-running task that performs the initial scan and then reacts to
    // worktree updates (see `index_entries`).
    _index_entries: Task<Result<()>>,
    _subscription: Subscription,
}
 473
 474impl WorktreeIndex {
    /// Opens (creating if necessary) this worktree's named database — keyed
    /// by the worktree's absolute path — and constructs the index model.
    ///
    /// Database creation runs in a write transaction on the background
    /// executor because it performs blocking I/O.
    pub fn load(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        status_tx: channel::Sender<()>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let worktree_abs_path = worktree.read(cx).abs_path();
        cx.spawn(|mut cx| async move {
            let db = cx
                .background_executor()
                .spawn({
                    let db_connection = db_connection.clone();
                    async move {
                        let mut txn = db_connection.write_txn()?;
                        // The database name is the worktree's absolute path.
                        let db_name = worktree_abs_path.to_string_lossy();
                        let db = db_connection.create_database(&mut txn, Some(&db_name))?;
                        txn.commit()?;
                        anyhow::Ok(db)
                    }
                })
                .await?;
            cx.new_model(|cx| {
                Self::new(
                    worktree,
                    db_connection,
                    db,
                    status_tx,
                    language_registry,
                    fs,
                    embedding_provider,
                    cx,
                )
            })
        })
    }
 513
    /// Builds the index and spawns the long-lived `index_entries` task.
    ///
    /// Worktree entry updates are forwarded through an unbounded channel so
    /// the gpui subscription callback never blocks.
    #[allow(clippy::too_many_arguments)]
    fn new(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        status: channel::Sender<()>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
            if let worktree::Event::UpdatedEntries(update) = event {
                // Best effort: if the receiver is gone, drop the update.
                _ = updated_entries_tx.try_send(update.clone());
            }
        });

        Self {
            db_connection,
            db,
            worktree,
            language_registry,
            fs,
            embedding_provider,
            entry_ids_being_indexed: Arc::new(IndexingEntrySet::new(status)),
            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
            _subscription,
        }
    }
 544
 545    async fn index_entries(
 546        this: WeakModel<Self>,
 547        updated_entries: channel::Receiver<UpdatedEntriesSet>,
 548        mut cx: AsyncAppContext,
 549    ) -> Result<()> {
 550        let index = this.update(&mut cx, |this, cx| this.index_entries_changed_on_disk(cx))?;
 551        index.await.log_err();
 552
 553        while let Ok(updated_entries) = updated_entries.recv().await {
 554            let index = this.update(&mut cx, |this, cx| {
 555                this.index_updated_entries(updated_entries, cx)
 556            })?;
 557            index.await.log_err();
 558        }
 559
 560        Ok(())
 561    }
 562
    /// Full reconciliation pass: scans every file in the worktree snapshot,
    /// chunks and embeds the changed ones, and persists embeddings while
    /// deleting stale database rows. The four pipeline stages are connected
    /// by channels and run concurrently via `try_join!`.
    fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }
 575
    /// Incremental counterpart to `index_entries_changed_on_disk`: runs the
    /// same chunk → embed → persist pipeline, but only over the entries in
    /// `updated_entries` rather than the whole worktree.
    fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }
 592
    /// Performs a sorted merge between the worktree's files and the rows in
    /// this worktree's database, producing two channels:
    ///
    /// * `updated_entries` — entries whose mtime differs from the stored one
    ///   (or which have no row yet) and therefore need re-embedding;
    /// * `deleted_entry_ranges` — key ranges for rows whose files no longer
    ///   exist, coalesced into contiguous runs to minimize deletions.
    ///
    /// Relies on both `worktree.files(...)` and the db iterator yielding
    /// keys in the same sort order — NOTE(review): `db_key_for_path` must
    /// preserve that ordering; confirm against its definition.
    fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            // Run of consecutive db keys with no corresponding file on disk.
            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                let entry_db_key = db_key_for_path(&entry.path);

                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                // Row precedes the current file: its file was
                                // deleted. Extend (or start) the deletion run.
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                // Found the row for this file; flush any
                                // pending deletion run first.
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                // No row yet for this file; leave the db
                                // cursor where it is.
                                break;
                            }
                        },
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
                    }
                }

                // No stored mtime, or it changed: queue for re-indexing.
                if entry.mtime != saved_mtime {
                    let handle = entries_being_indexed.insert(entry.id);
                    updated_entries_tx.send((entry.clone(), handle)).await?;
                }
            }

            // Any rows left after the last file belong to deleted files.
            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }
 670
 671    fn scan_updated_entries(
 672        &self,
 673        worktree: LocalSnapshot,
 674        updated_entries: UpdatedEntriesSet,
 675        cx: &AppContext,
 676    ) -> ScanEntries {
 677        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
 678        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
 679        let entries_being_indexed = self.entry_ids_being_indexed.clone();
 680        let task = cx.background_executor().spawn(async move {
 681            for (path, entry_id, status) in updated_entries.iter() {
 682                match status {
 683                    project::PathChange::Added
 684                    | project::PathChange::Updated
 685                    | project::PathChange::AddedOrUpdated => {
 686                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
 687                            if entry.is_file() {
 688                                let handle = entries_being_indexed.insert(entry.id);
 689                                updated_entries_tx.send((entry.clone(), handle)).await?;
 690                            }
 691                        }
 692                    }
 693                    project::PathChange::Removed => {
 694                        let db_path = db_key_for_path(path);
 695                        deleted_entry_ranges_tx
 696                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
 697                            .await?;
 698                    }
 699                    project::PathChange::Loaded => {
 700                        // Do nothing.
 701                    }
 702                }
 703            }
 704
 705            Ok(())
 706        });
 707
 708        ScanEntries {
 709            updated_entries: updated_entries_rx,
 710            deleted_entry_ranges: deleted_entry_ranges_rx,
 711            task,
 712        }
 713    }
 714
    /// Loads and chunks changed files, fanning the work out over one worker
    /// per CPU. Each worker reads a file from `entries`, loads its text,
    /// resolves its language, splits it with `chunk_text`, and forwards the
    /// resulting `ChunkedFile` downstream.
    ///
    /// Files that fail to load are logged and skipped rather than aborting
    /// the pipeline.
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok((entry, handle)) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                let Some(text) = fs
                                    .load(&entry_abs_path)
                                    .await
                                    .with_context(|| {
                                        format!("failed to read path {entry_abs_path:?}")
                                    })
                                    .log_err()
                                else {
                                    continue;
                                };
                                let language = language_registry
                                    .language_for_file_path(&entry.path)
                                    .await
                                    .ok();
                                let chunked_file = ChunkedFile {
                                    chunks: chunk_text(&text, language.as_ref(), &entry.path),
                                    handle,
                                    path: entry.path,
                                    mtime: entry.mtime,
                                    text,
                                };

                                // Receiver gone means the pipeline was torn
                                // down; stop this worker quietly.
                                if chunked_files_tx.send(chunked_file).await.is_err() {
                                    return;
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }
 769
 770    fn embed_files(
 771        embedding_provider: Arc<dyn EmbeddingProvider>,
 772        chunked_files: channel::Receiver<ChunkedFile>,
 773        cx: &AppContext,
 774    ) -> EmbedFiles {
 775        let embedding_provider = embedding_provider.clone();
 776        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
 777        let task = cx.background_executor().spawn(async move {
 778            let mut chunked_file_batches =
 779                chunked_files.chunks_timeout(512, Duration::from_secs(2));
 780            while let Some(chunked_files) = chunked_file_batches.next().await {
 781                // View the batch of files as a vec of chunks
 782                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
 783                // Once those are done, reassemble them back into the files in which they belong
 784                // If any embeddings fail for a file, the entire file is discarded
 785
 786                let chunks: Vec<TextToEmbed> = chunked_files
 787                    .iter()
 788                    .flat_map(|file| {
 789                        file.chunks.iter().map(|chunk| TextToEmbed {
 790                            text: &file.text[chunk.range.clone()],
 791                            digest: chunk.digest,
 792                        })
 793                    })
 794                    .collect::<Vec<_>>();
 795
 796                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
 797                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
 798                    if let Some(batch_embeddings) =
 799                        embedding_provider.embed(embedding_batch).await.log_err()
 800                    {
 801                        if batch_embeddings.len() == embedding_batch.len() {
 802                            embeddings.extend(batch_embeddings.into_iter().map(Some));
 803                            continue;
 804                        }
 805                        log::error!(
 806                            "embedding provider returned unexpected embedding count {}, expected {}",
 807                            batch_embeddings.len(), embedding_batch.len()
 808                        );
 809                    }
 810
 811                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
 812                }
 813
 814                let mut embeddings = embeddings.into_iter();
 815                for chunked_file in chunked_files {
 816                    let mut embedded_file = EmbeddedFile {
 817                        path: chunked_file.path,
 818                        mtime: chunked_file.mtime,
 819                        chunks: Vec::new(),
 820                    };
 821
 822                    let mut embedded_all_chunks = true;
 823                    for (chunk, embedding) in
 824                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
 825                    {
 826                        if let Some(embedding) = embedding {
 827                            embedded_file
 828                                .chunks
 829                                .push(EmbeddedChunk { chunk, embedding });
 830                        } else {
 831                            embedded_all_chunks = false;
 832                        }
 833                    }
 834
 835                    if embedded_all_chunks {
 836                        embedded_files_tx
 837                            .send((embedded_file, chunked_file.handle))
 838                            .await?;
 839                    }
 840                }
 841            }
 842            Ok(())
 843        });
 844
 845        EmbedFiles {
 846            files: embedded_files_rx,
 847            task,
 848        }
 849    }
 850
    /// Persists index updates to the database: first purges deleted entries,
    /// then writes newly embedded files in batches.
    ///
    /// Deletions are applied one transaction per key range; embedded files
    /// are committed in batches of up to 4096 (or whatever arrives within a
    /// 2-second window), one transaction per batch.
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            // NOTE(review): this loop drains the deletions channel to
            // completion before any embeddings are written, so it relies on
            // the producer closing `deleted_entry_ranges` — confirm senders
            // always close it, or embeddings would never be persisted.
            while let Some(deletion_range) = deleted_entry_ranges.next().await {
                let mut txn = db_connection.write_txn()?;
                let start = deletion_range.0.as_ref().map(|start| start.as_str());
                let end = deletion_range.1.as_ref().map(|end| end.as_str());
                log::debug!("deleting embeddings in range {:?}", &(start, end));
                db.delete_range(&mut txn, &(start, end))?;
                txn.commit()?;
            }

            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(embedded_files) = embedded_files.next().await {
                let mut txn = db_connection.write_txn()?;
                for (file, _) in &embedded_files {
                    log::debug!("saving embedding for file {:?}", file.path);
                    let key = db_key_for_path(&file.path);
                    db.put(&mut txn, &key, file)?;
                }
                txn.commit()?;

                // Drop the batch eagerly so each file's `IndexingEntryHandle`
                // is released (removing it from the in-progress set) as soon
                // as the transaction has committed.
                drop(embedded_files);
                log::debug!("committed");
            }

            Ok(())
        })
    }
 886
 887    fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
 888        let connection = self.db_connection.clone();
 889        let db = self.db;
 890        cx.background_executor().spawn(async move {
 891            let tx = connection
 892                .read_txn()
 893                .context("failed to create read transaction")?;
 894            let result = db
 895                .iter(&tx)?
 896                .map(|entry| Ok(entry?.1.path.clone()))
 897                .collect::<Result<Vec<Arc<Path>>>>();
 898            drop(tx);
 899            result
 900        })
 901    }
 902
 903    fn chunks_for_path(
 904        &self,
 905        path: Arc<Path>,
 906        cx: &AppContext,
 907    ) -> Task<Result<Vec<EmbeddedChunk>>> {
 908        let connection = self.db_connection.clone();
 909        let db = self.db;
 910        cx.background_executor().spawn(async move {
 911            let tx = connection
 912                .read_txn()
 913                .context("failed to create read transaction")?;
 914            Ok(db
 915                .get(&tx, &db_key_for_path(&path))?
 916                .ok_or_else(|| anyhow!("no such path"))?
 917                .chunks
 918                .clone())
 919        })
 920    }
 921
 922    #[cfg(test)]
 923    fn path_count(&self) -> Result<u64> {
 924        let txn = self
 925            .db_connection
 926            .read_txn()
 927            .context("failed to create read transaction")?;
 928        Ok(self.db.len(&txn)?)
 929    }
 930}
 931
/// Streams produced by scanning a worktree for entries that need indexing.
struct ScanEntries {
    /// Entries to (re)index, each paired with a handle that marks it as
    /// in-progress until indexing completes.
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    /// Database key ranges of deleted entries, consumed by
    /// `persist_embeddings` to purge stale rows.
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    /// Background task feeding the two channels above.
    task: Task<Result<()>>,
}
 937
/// Output of the chunking stage: a stream of chunked files plus the task
/// driving it.
struct ChunkFiles {
    /// Files that have been split into chunks, ready for embedding.
    files: channel::Receiver<ChunkedFile>,
    /// Background task producing `files`; resolves when chunking finishes.
    task: Task<Result<()>>,
}
 942
/// A file that has been split into chunks but not yet embedded.
struct ChunkedFile {
    /// Worktree-relative path of the file.
    pub path: Arc<Path>,
    /// Modification time of the file, if known.
    pub mtime: Option<SystemTime>,
    /// Keeps the entry marked as in-progress until embedding completes.
    pub handle: IndexingEntryHandle,
    /// Full text of the file; chunk ranges index into this string.
    pub text: String,
    /// Byte ranges into `text`, each with a content digest.
    pub chunks: Vec<Chunk>,
}
 950
/// Output of the embedding stage (see `embed_files`).
struct EmbedFiles {
    /// Fully embedded files, each paired with its indexing handle.
    files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    /// Background task producing `files`; resolves when embedding finishes.
    task: Task<Result<()>>,
}
 955
/// A file whose chunks have all been embedded. This is the value type
/// persisted in the heed database (bincode-serialized), so field order
/// must not change.
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedFile {
    /// Worktree-relative path of the file.
    path: Arc<Path>,
    /// Modification time of the file, if known.
    mtime: Option<SystemTime>,
    /// The file's chunks together with their embeddings.
    chunks: Vec<EmbeddedChunk>,
}
 962
/// A single chunk of a file together with its embedding. Serialized as part
/// of `EmbeddedFile`, so field order must not change.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct EmbeddedChunk {
    chunk: Chunk,
    embedding: Embedding,
}
 968
/// The set of entries that are currently being indexed.
struct IndexingEntrySet {
    /// Ids of the entries whose indexing is in progress.
    entry_ids: Mutex<HashSet<ProjectEntryId>>,
    /// Notifies observers whenever the set changes.
    tx: channel::Sender<()>,
}
 974
/// When dropped, removes the entry from the set of entries that are being indexed.
// NOTE(review): the handle is `Clone`, but the `Drop` impl removes the entry
// when ANY clone is dropped — confirm clones never outlive the first drop.
#[derive(Clone)]
struct IndexingEntryHandle {
    /// The entry this handle keeps marked as in-progress.
    entry_id: ProjectEntryId,
    /// Weak reference to the owning set; if the set is gone, drop is a no-op.
    set: Weak<IndexingEntrySet>,
}
 981
 982impl IndexingEntrySet {
 983    fn new(tx: channel::Sender<()>) -> Self {
 984        Self {
 985            entry_ids: Default::default(),
 986            tx,
 987        }
 988    }
 989
 990    fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
 991        self.entry_ids.lock().insert(entry_id);
 992        self.tx.send_blocking(()).ok();
 993        IndexingEntryHandle {
 994            entry_id,
 995            set: Arc::downgrade(self),
 996        }
 997    }
 998
 999    pub fn len(&self) -> usize {
1000        self.entry_ids.lock().len()
1001    }
1002}
1003
1004impl Drop for IndexingEntryHandle {
1005    fn drop(&mut self) {
1006        if let Some(set) = self.set.upgrade() {
1007            set.tx.send_blocking(()).ok();
1008            set.entry_ids.lock().remove(&self.entry_id);
1009        }
1010    }
1011}
1012
/// Converts a worktree-relative path into a database key, replacing `/`
/// separators with NUL bytes — presumably so all entries under a directory
/// form a contiguous key range (`\0` sorts before any printable character),
/// which `delete_range` relies on.
///
/// Takes `&Path` rather than `&Arc<Path>`; existing `&Arc<Path>` call sites
/// continue to work through deref coercion.
fn db_key_for_path(path: &Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}
1016
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, FutureExt};
    use gpui::TestAppContext;
    use language::language_settings::AllLanguageSettings;
    use project::Project;
    use settings::SettingsStore;
    use std::{future, path::Path, sync::Arc};

    /// Installs the settings, language, and project globals the semantic
    /// index depends on.
    fn init_test(cx: &mut TestAppContext) {
        _ = cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

    /// An embedding provider whose behavior is supplied by the test through
    /// a closure, with a configurable batch size.
    pub struct TestEmbeddingProvider {
        batch_size: usize,
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticIndex::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // if the text contains garbage, give it a 0.9 in the first dimension
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let project_path = Path::new("./fixture");

        let project = cx
            .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
            .await;

        cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
        });

        let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));

        // Wait until the initial scan has indexed at least one path.
        while project_index
            .read_with(cx, |index, cx| index.path_count(cx))
            .unwrap()
            == 0
        {
            project_index.next_event(cx).await;
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query.into(), 4, cx)
            })
            .await
            .unwrap();

        assert!(results.len() > 1, "should have found some results");

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find a result whose score is greater than 0.9
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "needle.md");

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_executor()
                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                // Embed a string as the count of each letter it contains.
                // (Previously `'a'..'z'`, which silently never counted 'z';
                // harmless for the assertions below, which use this same
                // closure on both sides, but clearly unintended.)
                Ok(Embedding::new(
                    ('a'..='z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| WorktreeIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        // test1.md contains 'g' in one of its chunks, so the whole file is
        // discarded; only test2.md should make it through.
        let mut embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
            embedded_files.push(embedded_file);
        }

        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }
}