semantic_index.rs

mod chunking;
mod embedding;
mod project_index_debug_view;

use anyhow::{anyhow, Context as _, Result};
use chunking::{chunk_text, Chunk};
use collections::{Bound, HashMap, HashSet};
pub use embedding::*;
use fs::Fs;
use futures::{future::Shared, stream::StreamExt, FutureExt};
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{
    AppContext, AsyncAppContext, BorrowAppContext, Context, Entity, EntityId, EventEmitter, Global,
    Model, ModelContext, Subscription, Task, WeakModel,
};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use parking_lot::Mutex;
use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree, WorktreeId};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
    cmp::Ordering,
    future::Future,
    iter,
    num::NonZeroUsize,
    ops::Range,
    path::{Path, PathBuf},
    sync::{Arc, Weak},
    time::{Duration, SystemTime},
};
use util::ResultExt;
use worktree::Snapshot;

pub use project_index_debug_view::ProjectIndexDebugView;

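/// A global that owns the embedding database connection and hands out a
/// [`ProjectIndex`] for each open project.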
pub struct SemanticIndex {
    embedding_provider: Arc<dyn EmbeddingProvider>,
    db_connection: heed::Env,
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}

impl Global for SemanticIndex {}

impl SemanticIndex {
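    /// Opens (creating it if necessary) the heed database environment at
    /// `db_path` on the background executor and constructs the index.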
    pub async fn new(
        db_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AsyncAppContext,
    ) -> Result<Self> {
        let db_connection = cx
            .background_executor()
            .spawn(async move {
                std::fs::create_dir_all(&db_path)?;
                unsafe {
                    heed::EnvOpenOptions::new()
                        .map_size(1024 * 1024 * 1024)
                        .max_dbs(3000)
                        .open(db_path)
                }
            })
            .await
            .context("opening database connection")?;

        Ok(SemanticIndex {
            db_connection,
            embedding_provider,
            project_indices: HashMap::default(),
        })
    }

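    /// Returns the [`ProjectIndex`] for the given project, creating it on first
    /// use and removing it when the project entity is released.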
    pub fn project_index(
        &mut self,
        project: Model<Project>,
        cx: &mut AppContext,
    ) -> Model<ProjectIndex> {
        let project_weak = project.downgrade();
        project.update(cx, move |_, cx| {
            cx.on_release(move |_, cx| {
                if cx.has_global::<SemanticIndex>() {
                    cx.update_global::<SemanticIndex, _>(|this, _| {
                        this.project_indices.remove(&project_weak);
                    })
                }
            })
            .detach();
        });

        self.project_indices
            .entry(project.downgrade())
            .or_insert_with(|| {
                cx.new_model(|cx| {
                    ProjectIndex::new(
                        project,
                        self.db_connection.clone(),
                        self.embedding_provider.clone(),
                        cx,
                    )
                })
            })
            .clone()
    }
}

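/// Tracks an embedding index for every visible, local worktree of a single
/// project and reports an aggregate indexing [`Status`].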
pub struct ProjectIndex {
    db_connection: heed::Env,
    project: WeakModel<Project>,
    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    last_status: Status,
    status_tx: channel::Sender<()>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    _maintain_status: Task<()>,
    _subscription: Subscription,
}

#[derive(Clone)]
enum WorktreeIndexHandle {
    Loading {
        index: Shared<Task<Result<Model<WorktreeIndex>, Arc<anyhow::Error>>>>,
    },
    Loaded {
        index: Model<WorktreeIndex>,
    },
}

impl ProjectIndex {
    fn new(
        project: Model<Project>,
        db_connection: heed::Env,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let language_registry = project.read(cx).languages().clone();
        let fs = project.read(cx).fs().clone();
        let (status_tx, mut status_rx) = channel::unbounded();
        let mut this = ProjectIndex {
            db_connection,
            project: project.downgrade(),
            worktree_indices: HashMap::default(),
            language_registry,
            fs,
            status_tx,
            last_status: Status::Idle,
            embedding_provider,
            _subscription: cx.subscribe(&project, Self::handle_project_event),
            _maintain_status: cx.spawn(|this, mut cx| async move {
                while status_rx.next().await.is_some() {
                    if this
                        .update(&mut cx, |this, cx| this.update_status(cx))
                        .is_err()
                    {
                        break;
                    }
                }
            }),
        };
        this.update_worktree_indices(cx);
        this
    }

    pub fn status(&self) -> Status {
        self.last_status
    }

    pub fn project(&self) -> WeakModel<Project> {
        self.project.clone()
    }

    pub fn fs(&self) -> Arc<dyn Fs> {
        self.fs.clone()
    }

    fn handle_project_event(
        &mut self,
        _: Model<Project>,
        event: &project::Event,
        cx: &mut ModelContext<Self>,
    ) {
        match event {
            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                self.update_worktree_indices(cx);
            }
            _ => {}
        }
    }

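    /// Reconciles `worktree_indices` with the project's current set of visible,
    /// local worktrees: stale indices are dropped and missing ones are loaded
    /// asynchronously.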
    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
        let Some(project) = self.project.upgrade() else {
            return;
        };

        let worktrees = project
            .read(cx)
            .visible_worktrees(cx)
            .filter_map(|worktree| {
                if worktree.read(cx).is_local() {
                    Some((worktree.entity_id(), worktree))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        self.worktree_indices
            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
        for (worktree_id, worktree) in worktrees {
            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
                let worktree_index = WorktreeIndex::load(
                    worktree.clone(),
                    self.db_connection.clone(),
                    self.language_registry.clone(),
                    self.fs.clone(),
                    self.status_tx.clone(),
                    self.embedding_provider.clone(),
                    cx,
                );

                let load_worktree = cx.spawn(|this, mut cx| async move {
                    let result = match worktree_index.await {
                        Ok(worktree_index) => {
                            this.update(&mut cx, |this, _| {
                                this.worktree_indices.insert(
                                    worktree_id,
                                    WorktreeIndexHandle::Loaded {
                                        index: worktree_index.clone(),
                                    },
                                );
                            })?;
                            Ok(worktree_index)
                        }
                        Err(error) => {
                            this.update(&mut cx, |this, _cx| {
                                this.worktree_indices.remove(&worktree_id)
                            })?;
                            Err(Arc::new(error))
                        }
                    };

                    this.update(&mut cx, |this, cx| this.update_status(cx))?;

                    result
                });

                WorktreeIndexHandle::Loading {
                    index: load_worktree.shared(),
                }
            });
        }

        self.update_status(cx);
    }

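    /// Recomputes the aggregate status: `Loading` if any worktree index is still
    /// loading, `Scanning` with the number of entries currently being indexed,
    /// otherwise `Idle`. The new status is emitted only when it changes.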
    fn update_status(&mut self, cx: &mut ModelContext<Self>) {
        let mut indexing_count = 0;
        let mut any_loading = false;

        for index in self.worktree_indices.values_mut() {
            match index {
                WorktreeIndexHandle::Loading { .. } => {
                    any_loading = true;
                    break;
                }
                WorktreeIndexHandle::Loaded { index, .. } => {
                    indexing_count += index.read(cx).entry_ids_being_indexed.len();
                }
            }
        }

        let status = if any_loading {
            Status::Loading
        } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
            Status::Scanning { remaining_count }
        } else {
            Status::Idle
        };

        if status != self.last_status {
            self.last_status = status;
            cx.emit(status);
        }
    }

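    /// Performs a semantic search across all worktree indices: embeds the query,
    /// streams every stored chunk out of the database, scores chunks by
    /// similarity on one worker per CPU, and merges the per-worker results into
    /// the top `limit` matches.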
    pub fn search(
        &self,
        query: String,
        limit: usize,
        cx: &AppContext,
    ) -> Task<Result<Vec<SearchResult>>> {
        let (chunks_tx, chunks_rx) = channel::bounded(1024);
        let mut worktree_scan_tasks = Vec::new();
        for worktree_index in self.worktree_indices.values() {
            let worktree_index = worktree_index.clone();
            let chunks_tx = chunks_tx.clone();
            worktree_scan_tasks.push(cx.spawn(|cx| async move {
                let index = match worktree_index {
                    WorktreeIndexHandle::Loading { index } => {
                        index.clone().await.map_err(|error| anyhow!(error))?
                    }
                    WorktreeIndexHandle::Loaded { index } => index.clone(),
                };

                index
                    .read_with(&cx, |index, cx| {
                        let worktree_id = index.worktree.read(cx).id();
                        let db_connection = index.db_connection.clone();
                        let db = index.db;
                        cx.background_executor().spawn(async move {
                            let txn = db_connection
                                .read_txn()
                                .context("failed to create read transaction")?;
                            let db_entries = db.iter(&txn).context("failed to iterate database")?;
                            for db_entry in db_entries {
                                let (_key, db_embedded_file) = db_entry?;
                                for chunk in db_embedded_file.chunks {
                                    chunks_tx
                                        .send((worktree_id, db_embedded_file.path.clone(), chunk))
                                        .await?;
                                }
                            }
                            anyhow::Ok(())
                        })
                    })?
                    .await
            }));
        }
        drop(chunks_tx);

        let project = self.project.clone();
        let embedding_provider = self.embedding_provider.clone();
        cx.spawn(|cx| async move {
            #[cfg(debug_assertions)]
            let embedding_query_start = std::time::Instant::now();
            log::info!("Searching for {query}");

            let query_embeddings = embedding_provider
                .embed(&[TextToEmbed::new(&query)])
                .await?;
            let query_embedding = query_embeddings
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("no embedding for query"))?;

            let mut results_by_worker = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                results_by_worker.push(Vec::<WorktreeSearchResult>::new());
            }

            #[cfg(debug_assertions)]
            let search_start = std::time::Instant::now();

            cx.background_executor()
                .scoped(|cx| {
                    for results in results_by_worker.iter_mut() {
                        cx.spawn(async {
                            while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
                                let score = chunk.embedding.similarity(&query_embedding);
                                let ix = match results.binary_search_by(|probe| {
                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
                                }) {
                                    Ok(ix) | Err(ix) => ix,
                                };
                                results.insert(
                                    ix,
                                    WorktreeSearchResult {
                                        worktree_id,
                                        path: path.clone(),
                                        range: chunk.chunk.range.clone(),
                                        score,
                                    },
                                );
                                results.truncate(limit);
                            }
                        });
                    }
                })
                .await;

            for scan_task in futures::future::join_all(worktree_scan_tasks).await {
                scan_task.log_err();
            }

            project.read_with(&cx, |project, cx| {
                let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
                for worker_results in results_by_worker {
                    search_results.extend(worker_results.into_iter().filter_map(|result| {
                        Some(SearchResult {
                            worktree: project.worktree_for_id(result.worktree_id, cx)?,
                            path: result.path,
                            range: result.range,
                            score: result.score,
                        })
                    }));
                }
                search_results.sort_unstable_by(|a, b| {
                    b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
                });
                search_results.truncate(limit);

                #[cfg(debug_assertions)]
                {
                    let search_elapsed = search_start.elapsed();
                    log::debug!(
                        "searched {} entries in {:?}",
                        search_results.len(),
                        search_elapsed
                    );
                    let embedding_query_elapsed = embedding_query_start.elapsed();
                    log::debug!("embedding query took {:?}", embedding_query_elapsed);
                }

                search_results
            })
        })
    }

    #[cfg(test)]
    pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
        let mut result = 0;
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                result += index.read(cx).path_count()?;
            }
        }
        Ok(result)
    }

    pub(crate) fn worktree_index(
        &self,
        worktree_id: WorktreeId,
        cx: &AppContext,
    ) -> Option<Model<WorktreeIndex>> {
        for index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = index {
                if index.read(cx).worktree.read(cx).id() == worktree_id {
                    return Some(index.clone());
                }
            }
        }
        None
    }

    pub(crate) fn worktree_indices(&self, cx: &AppContext) -> Vec<Model<WorktreeIndex>> {
        let mut result = self
            .worktree_indices
            .values()
            .filter_map(|index| {
                if let WorktreeIndexHandle::Loaded { index, .. } = index {
                    Some(index.clone())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        result.sort_by_key(|index| index.read(cx).worktree.read(cx).id());
        result
    }
}

pub struct SearchResult {
    pub worktree: Model<Worktree>,
    pub path: Arc<Path>,
    pub range: Range<usize>,
    pub score: f32,
}

pub struct WorktreeSearchResult {
    pub worktree_id: WorktreeId,
    pub path: Arc<Path>,
    pub range: Range<usize>,
    pub score: f32,
}

#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum Status {
    Idle,
    Loading,
    Scanning { remaining_count: NonZeroUsize },
}

impl EventEmitter<Status> for ProjectIndex {}

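/// The embedding index for a single worktree, backed by a heed database that
/// maps file paths to their embedded chunks.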
struct WorktreeIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
    _index_entries: Task<Result<()>>,
    _subscription: Subscription,
}

impl WorktreeIndex {
    pub fn load(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        status_tx: channel::Sender<()>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let worktree_abs_path = worktree.read(cx).abs_path();
        cx.spawn(|mut cx| async move {
            let db = cx
                .background_executor()
                .spawn({
                    let db_connection = db_connection.clone();
                    async move {
                        let mut txn = db_connection.write_txn()?;
                        let db_name = worktree_abs_path.to_string_lossy();
                        let db = db_connection.create_database(&mut txn, Some(&db_name))?;
                        txn.commit()?;
                        anyhow::Ok(db)
                    }
                })
                .await?;
            cx.new_model(|cx| {
                Self::new(
                    worktree,
                    db_connection,
                    db,
                    status_tx,
                    language_registry,
                    fs,
                    embedding_provider,
                    cx,
                )
            })
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn new(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        status: channel::Sender<()>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
            if let worktree::Event::UpdatedEntries(update) = event {
                _ = updated_entries_tx.try_send(update.clone());
            }
        });

        Self {
            db_connection,
            db,
            worktree,
            language_registry,
            fs,
            embedding_provider,
            entry_ids_being_indexed: Arc::new(IndexingEntrySet::new(status)),
            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
            _subscription,
        }
    }

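    /// Drives indexing for the lifetime of the worktree: first reconciles the
    /// database with the current contents of the worktree on disk, then processes
    /// batches of updated entries as they arrive.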
    async fn index_entries(
        this: WeakModel<Self>,
        updated_entries: channel::Receiver<UpdatedEntriesSet>,
        mut cx: AsyncAppContext,
    ) -> Result<()> {
        let index = this.update(&mut cx, |this, cx| this.index_entries_changed_on_disk(cx))?;
        index.await.log_err();

        while let Ok(updated_entries) = updated_entries.recv().await {
            let index = this.update(&mut cx, |this, cx| {
                this.index_updated_entries(updated_entries, cx)
            })?;
            index.await.log_err();
        }

        Ok(())
    }

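    /// Builds the indexing pipeline for a full scan: scan entries, chunk changed
    /// files, embed the chunks, and persist embeddings and deletions. The four
    /// stages run concurrently, connected by bounded channels.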
    fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree, cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

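    /// Compares the worktree's files against the stored database entries in key
    /// order: database keys with no corresponding file are batched into deletion
    /// ranges, and files whose mtime differs from the stored entry are sent on
    /// for re-chunking and re-embedding.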
    fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                let entry_db_key = db_key_for_path(&entry.path);

                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                break;
                            }
                        },
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
                    }
                }

                if entry.mtime != saved_mtime {
                    let handle = entries_being_indexed.insert(entry.id);
                    updated_entries_tx.send((entry.clone(), handle)).await?;
                }
            }

            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

    fn scan_updated_entries(
        &self,
        worktree: Snapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            if entry.is_file() {
                                let handle = entries_being_indexed.insert(entry.id);
                                updated_entries_tx.send((entry.clone(), handle)).await?;
                            }
                        }
                    }
                    project::PathChange::Removed => {
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

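    /// Spawns one background worker per CPU that loads each updated file,
    /// resolves its language, and splits it into chunks for embedding.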
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok((entry, handle)) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                let Some(text) = fs
                                    .load(&entry_abs_path)
                                    .await
                                    .with_context(|| {
                                        format!("failed to read path {entry_abs_path:?}")
                                    })
                                    .log_err()
                                else {
                                    continue;
                                };
                                let language = language_registry
                                    .language_for_file_path(&entry.path)
                                    .await
                                    .ok();
                                let chunked_file = ChunkedFile {
                                    chunks: chunk_text(&text, language.as_ref(), &entry.path),
                                    handle,
                                    path: entry.path,
                                    mtime: entry.mtime,
                                    text,
                                };

                                if chunked_files_tx.send(chunked_file).await.is_err() {
                                    return;
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }

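    /// Batches chunked files (up to 512 files or a 2 second timeout), embeds
    /// their chunks in provider-sized batches, and emits an [`EmbeddedFile`] for
    /// every file whose chunks all embedded successfully.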
    fn embed_files(
        embedding_provider: Arc<dyn EmbeddingProvider>,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &AppContext,
    ) -> EmbedFiles {
        let embedding_provider = embedding_provider.clone();
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_executor().spawn(async move {
            let mut chunked_file_batches =
                chunked_files.chunks_timeout(512, Duration::from_secs(2));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // Flatten the batch of files into a single list of chunks, embed those
                // chunks in provider-sized batches, then reassemble the embeddings into
                // the files they belong to. If any chunk of a file fails to embed, the
                // entire file is discarded.

                let chunks: Vec<TextToEmbed> = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect::<Vec<_>>();

                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    if let Some(batch_embeddings) =
                        embedding_provider.embed(embedding_batch).await.log_err()
                    {
                        if batch_embeddings.len() == embedding_batch.len() {
                            embeddings.extend(batch_embeddings.into_iter().map(Some));
                            continue;
                        }
                        log::error!(
                            "embedding provider returned unexpected embedding count {}, expected {}",
                            batch_embeddings.len(), embedding_batch.len()
                        );
                    }

                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
                }

                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let mut embedded_file = EmbeddedFile {
                        path: chunked_file.path,
                        mtime: chunked_file.mtime,
                        chunks: Vec::new(),
                    };

                    let mut embedded_all_chunks = true;
                    for (chunk, embedding) in
                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
                    {
                        if let Some(embedding) = embedding {
                            embedded_file
                                .chunks
                                .push(EmbeddedChunk { chunk, embedding });
                        } else {
                            embedded_all_chunks = false;
                        }
                    }

                    if embedded_all_chunks {
                        embedded_files_tx
                            .send((embedded_file, chunked_file.handle))
                            .await?;
                    }
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }

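    /// Writes the index to the database: first applies the deletion ranges
    /// produced by scanning, then saves embedded files in batches, committing a
    /// write transaction per batch.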
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            while let Some(deletion_range) = deleted_entry_ranges.next().await {
                let mut txn = db_connection.write_txn()?;
                let start = deletion_range.0.as_ref().map(|start| start.as_str());
                let end = deletion_range.1.as_ref().map(|end| end.as_str());
                log::debug!("deleting embeddings in range {:?}", &(start, end));
                db.delete_range(&mut txn, &(start, end))?;
                txn.commit()?;
            }

            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(embedded_files) = embedded_files.next().await {
                let mut txn = db_connection.write_txn()?;
                for (file, _) in &embedded_files {
                    log::debug!("saving embedding for file {:?}", file.path);
                    let key = db_key_for_path(&file.path);
                    db.put(&mut txn, &key, file)?;
                }
                txn.commit()?;

                drop(embedded_files);
                log::debug!("committed");
            }

            Ok(())
        })
    }

    fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            let result = db
                .iter(&tx)?
                .map(|entry| Ok(entry?.1.path.clone()))
                .collect::<Result<Vec<Arc<Path>>>>();
            drop(tx);
            result
        })
    }

    fn chunks_for_path(
        &self,
        path: Arc<Path>,
        cx: &AppContext,
    ) -> Task<Result<Vec<EmbeddedChunk>>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            Ok(db
                .get(&tx, &db_key_for_path(&path))?
                .ok_or_else(|| anyhow!("no such path"))?
                .chunks
                .clone())
        })
    }

    #[cfg(test)]
    fn path_count(&self) -> Result<u64> {
        let txn = self
            .db_connection
            .read_txn()
            .context("failed to create read transaction")?;
        Ok(self.db.len(&txn)?)
    }
}

struct ScanEntries {
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}

struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}

struct ChunkedFile {
    pub path: Arc<Path>,
    pub mtime: Option<SystemTime>,
    pub handle: IndexingEntryHandle,
    pub text: String,
    pub chunks: Vec<Chunk>,
}

struct EmbedFiles {
    files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    task: Task<Result<()>>,
}

#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedFile {
    path: Arc<Path>,
    mtime: Option<SystemTime>,
    chunks: Vec<EmbeddedChunk>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
struct EmbeddedChunk {
    chunk: Chunk,
    embedding: Embedding,
}

/// The set of entries that are currently being indexed.
struct IndexingEntrySet {
    entry_ids: Mutex<HashSet<ProjectEntryId>>,
    tx: channel::Sender<()>,
}

/// When dropped, removes the entry from the set of entries that are being indexed.
#[derive(Clone)]
struct IndexingEntryHandle {
    entry_id: ProjectEntryId,
    set: Weak<IndexingEntrySet>,
}

impl IndexingEntrySet {
    fn new(tx: channel::Sender<()>) -> Self {
        Self {
            entry_ids: Default::default(),
            tx,
        }
    }

    fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
        self.entry_ids.lock().insert(entry_id);
        self.tx.send_blocking(()).ok();
        IndexingEntryHandle {
            entry_id,
            set: Arc::downgrade(self),
        }
    }

    pub fn len(&self) -> usize {
        self.entry_ids.lock().len()
    }
}

impl Drop for IndexingEntryHandle {
    fn drop(&mut self) {
        if let Some(set) = self.set.upgrade() {
            set.tx.send_blocking(()).ok();
            set.entry_ids.lock().remove(&self.entry_id);
        }
    }
}

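/// Replaces path separators with `'\0'` so that the lexicographic order of
/// database keys matches the order in which the worktree yields its files,
/// which the merge in `scan_entries` appears to rely on.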
fn db_key_for_path(path: &Arc<Path>) -> String {
    path.to_string_lossy().replace('/', "\0")
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, FutureExt};
    use gpui::TestAppContext;
    use language::language_settings::AllLanguageSettings;
    use project::Project;
    use settings::SettingsStore;
    use std::{future, path::Path, sync::Arc};

    fn init_test(cx: &mut TestAppContext) {
        _ = cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

    pub struct TestEmbeddingProvider {
        batch_size: usize,
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticIndex::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // If the text contains "garbage in", give it a high score in the first dimension.
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let project_path = Path::new("./fixture");

        let project = cx
            .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
            .await;

        cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
        });

        let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));

        while project_index
            .read_with(cx, |index, cx| index.path_count(cx))
            .unwrap()
            == 0
        {
            project_index.next_event(cx).await;
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query.into(), 4, cx)
            })
            .await
            .unwrap();

        assert!(results.len() > 1, "should have found some results");

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find a result with a score greater than 0.9.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "needle.md");

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_executor()
                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                Ok(Embedding::new(
                    ('a'..='z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| WorktreeIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        let mut embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
            embedded_files.push(embedded_file);
        }

        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }
}