semantic_index.rs

  1mod chunking;
  2mod embedding;
  3
  4use anyhow::{anyhow, Context as _, Result};
  5use chunking::{chunk_text, Chunk};
  6use collections::{Bound, HashMap};
  7pub use embedding::*;
  8use fs::Fs;
  9use futures::stream::StreamExt;
 10use futures_batch::ChunksTimeoutStreamExt;
 11use gpui::{
 12    AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Global, Model, ModelContext,
 13    Subscription, Task, WeakModel,
 14};
 15use heed::types::{SerdeBincode, Str};
 16use language::LanguageRegistry;
 17use project::{Entry, Project, UpdatedEntriesSet, Worktree};
 18use serde::{Deserialize, Serialize};
 19use smol::channel;
 20use std::{
 21    cmp::Ordering,
 22    future::Future,
 23    ops::Range,
 24    path::{Path, PathBuf},
 25    sync::Arc,
 26    time::{Duration, SystemTime},
 27};
 28use util::ResultExt;
 29use worktree::LocalSnapshot;
 30
/// The global vector-embedding index shared by all projects.
///
/// Owns the LMDB (heed) environment backing the on-disk embedding store, the
/// provider used to compute embeddings, and one cached `ProjectIndex` per
/// project.
pub struct SemanticIndex {
    embedding_provider: Arc<dyn EmbeddingProvider>,
    db_connection: heed::Env,
    // Keyed by weak handles so a dropped project doesn't keep its index alive.
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}

impl Global for SemanticIndex {}
 38
impl SemanticIndex {
    /// Opens (creating the directory if necessary) the embedding database at
    /// `db_path` and returns a new index using `embedding_provider`.
    ///
    /// The environment is opened on the background executor because directory
    /// creation and LMDB setup are blocking filesystem operations.
    pub async fn new(
        db_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AsyncAppContext,
    ) -> Result<Self> {
        let db_connection = cx
            .background_executor()
            .spawn(async move {
                std::fs::create_dir_all(&db_path)?;
                // SAFETY: heed marks `open` unsafe because the same
                // environment path must not be opened twice within one
                // process; callers are expected to create a single global
                // `SemanticIndex` — TODO confirm no other open of this path.
                unsafe {
                    heed::EnvOpenOptions::new()
                        .map_size(1024 * 1024 * 1024) // 1 GiB max map size
                        .max_dbs(3000) // one named database per worktree
                        .open(db_path)
                }
            })
            .await
            .context("opening database connection")?;

        Ok(SemanticIndex {
            db_connection,
            embedding_provider,
            project_indices: HashMap::default(),
        })
    }

    /// Returns the `ProjectIndex` for `project`, creating and caching it on
    /// first access.
    pub fn project_index(
        &mut self,
        project: Model<Project>,
        cx: &mut AppContext,
    ) -> Model<ProjectIndex> {
        self.project_indices
            .entry(project.downgrade())
            .or_insert_with(|| {
                cx.new_model(|cx| {
                    ProjectIndex::new(
                        project,
                        self.db_connection.clone(),
                        self.embedding_provider.clone(),
                        cx,
                    )
                })
            })
            .clone()
    }
}
 86
/// Indexes all visible local worktrees of a single project, maintaining one
/// `WorktreeIndex` per worktree and aggregating their scanning status.
pub struct ProjectIndex {
    db_connection: heed::Env,
    project: Model<Project>,
    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    /// The most recently computed aggregate status (see `update_status`).
    pub last_status: Status,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Keeps the project-event subscription alive for this index's lifetime.
    _subscription: Subscription,
}
 97
/// A per-worktree index as tracked by `ProjectIndex`: either still loading,
/// or loaded and observed for status changes.
enum WorktreeIndexHandle {
    Loading {
        // Held so the in-flight load task stays alive.
        _task: Task<Result<()>>,
    },
    Loaded {
        index: Model<WorktreeIndex>,
        // Held so status changes on the index keep notifying the project index.
        _subscription: Subscription,
    },
}
107
impl ProjectIndex {
    /// Creates an index for `project` and immediately starts indexing its
    /// visible local worktrees. Subscribes to project events so the worktree
    /// set stays in sync as worktrees are added or removed.
    fn new(
        project: Model<Project>,
        db_connection: heed::Env,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let language_registry = project.read(cx).languages().clone();
        let fs = project.read(cx).fs().clone();
        let mut this = ProjectIndex {
            db_connection,
            project: project.clone(),
            worktree_indices: HashMap::default(),
            language_registry,
            fs,
            last_status: Status::Idle,
            embedding_provider,
            _subscription: cx.subscribe(&project, Self::handle_project_event),
        };
        this.update_worktree_indices(cx);
        this
    }

    /// Reacts to worktree additions/removals by re-syncing the set of
    /// worktree indices; all other project events are ignored.
    fn handle_project_event(
        &mut self,
        _: Model<Project>,
        event: &project::Event,
        cx: &mut ModelContext<Self>,
    ) {
        match event {
            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                self.update_worktree_indices(cx);
            }
            _ => {}
        }
    }

    /// Brings `worktree_indices` in sync with the project's current visible
    /// local worktrees: drops indices for vanished worktrees and starts
    /// loading indices for new ones.
    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
        // Only local worktrees are indexed; remote worktrees are skipped.
        let worktrees = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .filter_map(|worktree| {
                if worktree.read(cx).is_local() {
                    Some((worktree.entity_id(), worktree))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        self.worktree_indices
            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
        for (worktree_id, worktree) in worktrees {
            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
                let worktree_index = WorktreeIndex::load(
                    worktree.clone(),
                    self.db_connection.clone(),
                    self.language_registry.clone(),
                    self.fs.clone(),
                    self.embedding_provider.clone(),
                    cx,
                );

                // Once loading finishes, swap the Loading handle for a Loaded
                // one (observing the index for status changes), or remove the
                // handle entirely if loading failed.
                let load_worktree = cx.spawn(|this, mut cx| async move {
                    if let Some(index) = worktree_index.await.log_err() {
                        this.update(&mut cx, |this, cx| {
                            this.worktree_indices.insert(
                                worktree_id,
                                WorktreeIndexHandle::Loaded {
                                    _subscription: cx
                                        .observe(&index, |this, _, cx| this.update_status(cx)),
                                    index,
                                },
                            );
                        })?;
                    } else {
                        this.update(&mut cx, |this, _cx| {
                            this.worktree_indices.remove(&worktree_id)
                        })?;
                    }

                    this.update(&mut cx, |this, cx| this.update_status(cx))
                });

                WorktreeIndexHandle::Loading {
                    _task: load_worktree,
                }
            });
        }

        self.update_status(cx);
    }

    /// Recomputes the aggregate status: `Scanning` if any worktree index is
    /// loading or scanning, otherwise `Idle`. Emits the new status as an
    /// event only when it changed.
    fn update_status(&mut self, cx: &mut ModelContext<Self>) {
        let mut status = Status::Idle;
        for index in self.worktree_indices.values() {
            match index {
                WorktreeIndexHandle::Loading { .. } => {
                    status = Status::Scanning;
                    break;
                }
                WorktreeIndexHandle::Loaded { index, .. } => {
                    if index.read(cx).status == Status::Scanning {
                        status = Status::Scanning;
                        break;
                    }
                }
            }
        }

        if status != self.last_status {
            self.last_status = status;
            cx.emit(status);
        }
    }

    /// Searches all loaded worktree indices concurrently, each with the full
    /// `limit`, then merges the results sorted by descending score and
    /// truncates to `limit`. Worktrees still loading are skipped.
    pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
        let mut worktree_searches = Vec::new();
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                worktree_searches
                    .push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
            }
        }

        cx.spawn(|_| async move {
            let mut results = Vec::new();
            let worktree_searches = futures::future::join_all(worktree_searches).await;

            for worktree_search_results in worktree_searches {
                // A failed per-worktree search is logged and skipped rather
                // than failing the whole query.
                if let Some(worktree_search_results) = worktree_search_results.log_err() {
                    results.extend(worktree_search_results);
                }
            }

            results
                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
            results.truncate(limit);

            results
        })
    }
}
252
/// A single semantic-search hit: a chunk of a file in a worktree, scored by
/// similarity to the query embedding (higher scores sort first).
pub struct SearchResult {
    pub worktree: Model<Worktree>,
    pub path: Arc<Path>,
    /// Byte range of the matching chunk within the file's text.
    pub range: Range<usize>,
    pub score: f32,
}
259
/// Indexing state reported by `WorktreeIndex` and aggregated by
/// `ProjectIndex`, which also emits it as an event on transitions.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    Idle,
    Scanning,
}

impl EventEmitter<Status> for ProjectIndex {}
267
/// Maintains the embedding database for a single local worktree and serves
/// similarity searches against it.
struct WorktreeIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    // Named LMDB database for this worktree: path-derived key -> EmbeddedFile.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    status: Status,
    // Long-running task that performs the initial scan and then processes
    // incremental worktree updates; held so it isn't dropped.
    _index_entries: Task<Result<()>>,
    // Keeps the worktree-event subscription alive.
    _subscription: Subscription,
}
279
impl WorktreeIndex {
    /// Creates (or opens) this worktree's named database — keyed by the
    /// worktree's absolute path — and then constructs the index model.
    pub fn load(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let worktree_abs_path = worktree.read(cx).abs_path();
        cx.spawn(|mut cx| async move {
            // Database creation requires a write transaction, so run it on
            // the background executor.
            let db = cx
                .background_executor()
                .spawn({
                    let db_connection = db_connection.clone();
                    async move {
                        let mut txn = db_connection.write_txn()?;
                        let db_name = worktree_abs_path.to_string_lossy();
                        let db = db_connection.create_database(&mut txn, Some(&db_name))?;
                        txn.commit()?;
                        anyhow::Ok(db)
                    }
                })
                .await?;
            cx.new_model(|cx| {
                Self::new(
                    worktree,
                    db_connection,
                    db,
                    language_registry,
                    fs,
                    embedding_provider,
                    cx,
                )
            })
        })
    }

    fn new(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        // Forward worktree entry updates into a channel consumed by the
        // long-running `index_entries` task. The send error (receiver gone)
        // is deliberately ignored.
        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
            if let worktree::Event::UpdatedEntries(update) = event {
                _ = updated_entries_tx.try_send(update.clone());
            }
        });

        Self {
            db_connection,
            db,
            worktree,
            language_registry,
            fs,
            embedding_provider,
            status: Status::Idle,
            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
            _subscription,
        }
    }

    /// Runs for the lifetime of the index: performs one full disk scan, then
    /// re-indexes incrementally for every batch of updated entries, flipping
    /// `status` to `Scanning`/`Idle` (and notifying observers) around each
    /// pass. Errors from an individual pass are logged, not fatal.
    async fn index_entries(
        this: WeakModel<Self>,
        updated_entries: channel::Receiver<UpdatedEntriesSet>,
        mut cx: AsyncAppContext,
    ) -> Result<()> {
        let index = this.update(&mut cx, |this, cx| {
            cx.notify();
            this.status = Status::Scanning;
            this.index_entries_changed_on_disk(cx)
        })?;
        index.await.log_err();
        this.update(&mut cx, |this, cx| {
            this.status = Status::Idle;
            cx.notify();
        })?;

        while let Ok(updated_entries) = updated_entries.recv().await {
            let index = this.update(&mut cx, |this, cx| {
                cx.notify();
                this.status = Status::Scanning;
                this.index_updated_entries(updated_entries, cx)
            })?;
            index.await.log_err();
            this.update(&mut cx, |this, cx| {
                this.status = Status::Idle;
                cx.notify();
            })?;
        }

        Ok(())
    }

    /// Full re-index: wires up the scan → chunk → embed → persist pipeline
    /// over the current worktree snapshot. The stages run concurrently,
    /// connected by bounded channels; the returned future resolves when all
    /// stages finish (or any stage fails).
    fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = self.embed_files(chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    /// Incremental re-index: same pipeline as
    /// `index_entries_changed_on_disk`, but the scan stage is driven by an
    /// explicit set of changed entries instead of a full database merge.
    fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = self.embed_files(chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    /// Pipeline stage 1 (full scan): merges the sorted database iterator with
    /// the worktree's file list to find files whose mtime changed (sent to
    /// `updated_entries`) and database keys with no corresponding worktree
    /// file (sent as contiguous key ranges to `deleted_entry_ranges`).
    ///
    /// The merge assumes the database key order produced by
    /// `db_key_for_path` matches the traversal order of `worktree.files` —
    /// presumably guaranteed by the '\0' separator; TODO confirm.
    fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let task = cx.background_executor().spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            // Accumulates a run of consecutive stale db keys so they can be
            // deleted with a single range operation.
            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                let entry_db_key = db_key_for_path(&entry.path);

                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                // Key exists in the db but not on disk:
                                // extend (or start) the pending deletion run.
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                // Flush any pending deletion run before
                                // recording this file's stored mtime.
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                break;
                            }
                        },
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
                    }
                }

                // Re-index only when the on-disk mtime differs from (or is
                // missing in) the stored record.
                if entry.mtime != saved_mtime {
                    updated_entries_tx.send(entry.clone()).await?;
                }
            }

            // Everything remaining in the db sorts after the last worktree
            // file, so delete from the first leftover key to the end.
            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

    /// Pipeline stage 1 (incremental): translates an `UpdatedEntriesSet`
    /// into updated file entries and single-key deletion ranges.
    fn scan_updated_entries(
        &self,
        worktree: LocalSnapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            if entry.is_file() {
                                updated_entries_tx.send(entry.clone()).await?;
                            }
                        }
                    }
                    project::PathChange::Removed => {
                        // A removed path becomes a degenerate [key, key] range.
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

    /// Pipeline stage 2: loads each updated file's text and splits it into
    /// chunks, fanning the work out across one worker per CPU. Unreadable
    /// files are logged and skipped.
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<Entry>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok(entry) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                let Some(text) = fs
                                    .load(&entry_abs_path)
                                    .await
                                    .with_context(|| {
                                        format!("failed to read path {entry_abs_path:?}")
                                    })
                                    .log_err()
                                else {
                                    continue;
                                };
                                // Use the file's language grammar (when
                                // available) to produce syntax-aware chunks.
                                let language = language_registry
                                    .language_for_file_path(&entry.path)
                                    .await
                                    .ok();
                                let grammar =
                                    language.as_ref().and_then(|language| language.grammar());
                                let chunked_file = ChunkedFile {
                                    worktree_root: worktree_abs_path.clone(),
                                    chunks: chunk_text(&text, grammar),
                                    entry,
                                    text,
                                };

                                // Receiver dropped means the pipeline is shutting
                                // down; stop this worker quietly.
                                if chunked_files_tx.send(chunked_file).await.is_err() {
                                    return;
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }

    /// Pipeline stage 3: batches chunked files (up to 512 files or a
    /// 2-second timeout), embeds all their chunks through the provider in
    /// provider-sized batches, then reassembles embeddings back onto the
    /// files they came from.
    fn embed_files(
        &self,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &AppContext,
    ) -> EmbedFiles {
        let embedding_provider = self.embedding_provider.clone();
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_executor().spawn(async move {
            let mut chunked_file_batches =
                chunked_files.chunks_timeout(512, Duration::from_secs(2));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // View the batch of files as a vec of chunks
                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
                // Once those are done, reassemble it back into which files they belong to

                let chunks = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect::<Vec<_>>();

                let mut embeddings = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    embeddings.extend(embedding_provider.embed(embedding_batch).await?);
                }

                // Embeddings come back in chunk order, so each file takes the
                // next `chunks.len()` embeddings off the front.
                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let chunk_embeddings = embeddings
                        .by_ref()
                        .take(chunked_file.chunks.len())
                        .collect::<Vec<_>>();
                    let embedded_chunks = chunked_file
                        .chunks
                        .into_iter()
                        .zip(chunk_embeddings)
                        .map(|(chunk, embedding)| EmbeddedChunk { chunk, embedding })
                        .collect();
                    let embedded_file = EmbeddedFile {
                        path: chunked_file.entry.path.clone(),
                        mtime: chunked_file.entry.mtime,
                        chunks: embedded_chunks,
                    };

                    embedded_files_tx.send(embedded_file).await?;
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }

    /// Pipeline stage 4: applies deletions, then writes embedded files to the
    /// database in batches (up to 4096 files or a 2-second timeout), one
    /// write transaction per batch.
    ///
    /// NOTE(review): all deletion ranges are drained before any embedded file
    /// is written; the deletion channel only closes when the scan stage
    /// finishes. With bounded channels this looks like it could stall under a
    /// very large backlog — confirm the channel capacities make that
    /// unreachable in practice.
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<EmbeddedFile>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            while let Some(deletion_range) = deleted_entry_ranges.next().await {
                let mut txn = db_connection.write_txn()?;
                let start = deletion_range.0.as_ref().map(|start| start.as_str());
                let end = deletion_range.1.as_ref().map(|end| end.as_str());
                log::debug!("deleting embeddings in range {:?}", &(start, end));
                db.delete_range(&mut txn, &(start, end))?;
                txn.commit()?;
            }

            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(embedded_files) = embedded_files.next().await {
                let mut txn = db_connection.write_txn()?;
                for file in embedded_files {
                    log::debug!("saving embedding for file {:?}", file.path);
                    let key = db_key_for_path(&file.path);
                    db.put(&mut txn, &key, &file)?;
                }
                txn.commit()?;
                log::debug!("committed");
            }

            Ok(())
        })
    }

    /// Searches this worktree's stored chunks for the `limit` chunks most
    /// similar to `query`.
    ///
    /// One background task streams every stored chunk out of the database
    /// while one worker per CPU consumes the stream, scoring chunks against
    /// the query embedding and keeping a per-worker top-`limit` list; the
    /// per-worker lists are then merged, sorted by descending score, and
    /// truncated.
    fn search(
        &self,
        query: &str,
        limit: usize,
        cx: &AppContext,
    ) -> Task<Result<Vec<SearchResult>>> {
        let (chunks_tx, chunks_rx) = channel::bounded(1024);

        let db_connection = self.db_connection.clone();
        let db = self.db;
        let scan_chunks = cx.background_executor().spawn({
            async move {
                let txn = db_connection
                    .read_txn()
                    .context("failed to create read transaction")?;
                let db_entries = db.iter(&txn).context("failed to iterate database")?;
                for db_entry in db_entries {
                    let (_key, db_embedded_file) = db_entry?;
                    for chunk in db_embedded_file.chunks {
                        chunks_tx
                            .send((db_embedded_file.path.clone(), chunk))
                            .await?;
                    }
                }
                anyhow::Ok(())
            }
        });

        let query = query.to_string();
        let embedding_provider = self.embedding_provider.clone();
        let worktree = self.worktree.clone();
        cx.spawn(|cx| async move {
            #[cfg(debug_assertions)]
            let embedding_query_start = std::time::Instant::now();
            log::info!("Searching for {query}");

            let mut query_embeddings = embedding_provider
                .embed(&[TextToEmbed::new(&query)])
                .await?;
            let query_embedding = query_embeddings
                .pop()
                .ok_or_else(|| anyhow!("no embedding for query"))?;
            let mut workers = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                workers.push(Vec::<SearchResult>::new());
            }

            #[cfg(debug_assertions)]
            let search_start = std::time::Instant::now();

            cx.background_executor()
                .scoped(|cx| {
                    for worker_results in workers.iter_mut() {
                        cx.spawn(async {
                            while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
                                let score = embedded_chunk.embedding.similarity(&query_embedding);
                                // Insert in descending-score order, keeping at
                                // most `limit` results per worker.
                                let ix = match worker_results.binary_search_by(|probe| {
                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
                                }) {
                                    Ok(ix) | Err(ix) => ix,
                                };
                                worker_results.insert(
                                    ix,
                                    SearchResult {
                                        worktree: worktree.clone(),
                                        path: path.clone(),
                                        range: embedded_chunk.chunk.range.clone(),
                                        score,
                                    },
                                );
                                worker_results.truncate(limit);
                            }
                        });
                    }
                })
                .await;
            scan_chunks.await?;

            let mut search_results = Vec::with_capacity(workers.len() * limit);
            for worker_results in workers {
                search_results.extend(worker_results);
            }
            search_results
                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
            search_results.truncate(limit);
            #[cfg(debug_assertions)]
            {
                let search_elapsed = search_start.elapsed();
                log::debug!(
                    "searched {} entries in {:?}",
                    search_results.len(),
                    search_elapsed
                );
                let embedding_query_elapsed = embedding_query_start.elapsed();
                log::debug!("embedding query took {:?}", embedding_query_elapsed);
            }

            Ok(search_results)
        })
    }
}
778
/// Output of the scan stage: changed entries, key ranges to delete, and the
/// task driving the scan.
struct ScanEntries {
    updated_entries: channel::Receiver<Entry>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}

/// Output of the chunking stage: chunked files and the task producing them.
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}

/// A file's text split into chunks, ready for embedding.
struct ChunkedFile {
    #[allow(dead_code)]
    pub worktree_root: Arc<Path>,
    pub entry: Entry,
    pub text: String,
    // Chunk ranges index into `text`.
    pub chunks: Vec<Chunk>,
}

/// Output of the embedding stage: embedded files and the task producing them.
struct EmbedFiles {
    files: channel::Receiver<EmbeddedFile>,
    task: Task<Result<()>>,
}

/// The database record for one indexed file, serialized with bincode.
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedFile {
    path: Arc<Path>,
    /// The file's mtime when it was embedded; compared against the on-disk
    /// mtime to decide whether re-indexing is needed.
    mtime: Option<SystemTime>,
    chunks: Vec<EmbeddedChunk>,
}

/// One chunk of a file together with its embedding vector.
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedChunk {
    chunk: Chunk,
    embedding: Embedding,
}
815
/// Converts a worktree-relative path into its embeddings-database key.
///
/// Slashes are replaced with NUL bytes — presumably so that keys sort in the
/// same order as the worktree traverses paths ('\0' sorts before any other
/// path byte); the merge in `scan_entries` depends on this ordering — TODO
/// confirm against `worktree.files` ordering.
///
/// Takes `&Path` instead of the needlessly specific `&Arc<Path>`; every
/// existing call site (`&entry.path`, `path`, `&file.path`) still compiles
/// via deref coercion from `Arc<Path>`, and the function now also accepts
/// plain `&Path`/`&PathBuf` arguments.
fn db_key_for_path(path: &Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}
819
820#[cfg(test)]
821mod tests {
822    use super::*;
823
824    use futures::channel::oneshot;
825    use futures::{future::BoxFuture, FutureExt};
826
827    use gpui::{Global, TestAppContext};
828    use language::language_settings::AllLanguageSettings;
829    use project::Project;
830    use settings::SettingsStore;
831    use std::{future, path::Path, sync::Arc};
832
833    fn init_test(cx: &mut TestAppContext) {
834        _ = cx.update(|cx| {
835            let store = SettingsStore::test(cx);
836            cx.set_global(store);
837            language::init(cx);
838            Project::init_settings(cx);
839            SettingsStore::update(cx, |store, cx| {
840                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
841            });
842        });
843    }
844
845    pub struct TestEmbeddingProvider;
846
847    impl EmbeddingProvider for TestEmbeddingProvider {
848        fn embed<'a>(
849            &'a self,
850            texts: &'a [TextToEmbed<'a>],
851        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
852            let embeddings = texts
853                .iter()
854                .map(|text| {
855                    let mut embedding = vec![0f32; 2];
856                    // if the text contains garbage, give it a 1 in the first dimension
857                    if text.text.contains("garbage in") {
858                        embedding[0] = 0.9;
859                    } else {
860                        embedding[0] = -0.9;
861                    }
862
863                    if text.text.contains("garbage out") {
864                        embedding[1] = 0.9;
865                    } else {
866                        embedding[1] = -0.9;
867                    }
868
869                    Embedding::new(embedding)
870                })
871                .collect();
872            future::ready(Ok(embeddings)).boxed()
873        }
874
875        fn batch_size(&self) -> usize {
876            16
877        }
878    }
879
    /// End-to-end test: index the `./fixture` directory with the fake
    /// embedding provider, wait for the index to emit its first event, then
    /// verify that searching for "garbage in, garbage out" surfaces
    /// `needle.md` and that the reported range covers the matching text.
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        // Fresh temp dir per run so database state never leaks between tests.
        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticIndex::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let project_path = Path::new("./fixture");

        let project = cx
            .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
            .await;

        cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
        });

        let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));

        // Wait for the project index to emit its first event before searching.
        // NOTE(review): this assumes the first event means the index is ready
        // to query — confirm the event's semantics.
        let (tx, rx) = oneshot::channel();
        let mut tx = Some(tx);
        let subscription = cx.update(|cx| {
            cx.subscribe(&project_index, move |_, event, _| {
                // The oneshot sender may only fire once; take it out of the
                // Option so later events are ignored.
                if let Some(tx) = tx.take() {
                    _ = tx.send(*event);
                }
            })
        });

        rx.await.expect("no event emitted");
        drop(subscription);

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query, 4, cx)
            })
            .await;

        assert!(results.len() > 1, "should have found some results");

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find a result whose score exceeds 0.9. With `TestEmbeddingProvider`
        // that should single out text containing both marker phrases.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "needle.md");

        // Load the matched file from disk so we can check the returned range
        // against the actual file contents.
        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
                let fs = project.read(cx).fs().clone();
                cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        // The result's range should slice out text containing the query phrase.
        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }
957}