semantic_index.rs

mod chunking;
mod embedding;

use anyhow::{anyhow, Context as _, Result};
use chunking::{chunk_text, Chunk};
use collections::{Bound, HashMap};
pub use embedding::*;
use fs::Fs;
use futures::stream::StreamExt;
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{
    AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Global, Model, ModelContext,
    Subscription, Task, WeakModel,
};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use project::{Entry, Project, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
    cmp::Ordering,
    future::Future,
    ops::Range,
    path::Path,
    sync::Arc,
    time::{Duration, SystemTime},
};
use util::ResultExt;
use worktree::LocalSnapshot;

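/// The global entry point for semantic search. It owns the LMDB environment
/// (via `heed`) that persists embeddings, the embedding provider, and one
/// `ProjectIndex` per open project.
///
/// A minimal usage sketch (hypothetical call site; assumes you already have an
/// `EmbeddingProvider` implementation and a database path at hand):
///
/// ```ignore
/// let semantic_index = SemanticIndex::new(db_path, embedding_provider, cx).await?;
/// cx.set_global(semantic_index);
/// ```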
pub struct SemanticIndex {
    embedding_provider: Arc<dyn EmbeddingProvider>,
    db_connection: heed::Env,
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}

impl Global for SemanticIndex {}

impl SemanticIndex {
    pub fn new(
        db_path: &Path,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Self>> {
        let db_path = db_path.to_path_buf();
        cx.spawn(|cx| async move {
            let db_connection = cx
                .background_executor()
                .spawn(async move {
                    unsafe {
                        heed::EnvOpenOptions::new()
                            .map_size(1024 * 1024 * 1024)
                            .max_dbs(3000)
                            .open(db_path)
                    }
                })
                .await?;

            Ok(SemanticIndex {
                db_connection,
                embedding_provider,
                project_indices: HashMap::default(),
            })
        })
    }

    pub fn project_index(
        &mut self,
        project: Model<Project>,
        cx: &mut AppContext,
    ) -> Model<ProjectIndex> {
        self.project_indices
            .entry(project.downgrade())
            .or_insert_with(|| {
                cx.new_model(|cx| {
                    ProjectIndex::new(
                        project,
                        self.db_connection.clone(),
                        self.embedding_provider.clone(),
                        cx,
                    )
                })
            })
            .clone()
    }
}

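/// Holds one `WorktreeIndex` per visible local worktree of a project, keeps
/// that set in sync as worktrees are added or removed, and aggregates the
/// worktrees' scanning state into a single `Status` that is emitted as an
/// event whenever it changes.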
pub struct ProjectIndex {
    db_connection: heed::Env,
    project: Model<Project>,
    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    last_status: Status,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    _subscription: Subscription,
}

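/// A worktree's index, which is either still being loaded or loaded and ready
/// to be searched.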
enum WorktreeIndexHandle {
    Loading {
        _task: Task<Result<()>>,
    },
    Loaded {
        index: Model<WorktreeIndex>,
        _subscription: Subscription,
    },
}

impl ProjectIndex {
    fn new(
        project: Model<Project>,
        db_connection: heed::Env,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let language_registry = project.read(cx).languages().clone();
        let fs = project.read(cx).fs().clone();
        let mut this = ProjectIndex {
            db_connection,
            project: project.clone(),
            worktree_indices: HashMap::default(),
            language_registry,
            fs,
            last_status: Status::Idle,
            embedding_provider,
            _subscription: cx.subscribe(&project, Self::handle_project_event),
        };
        this.update_worktree_indices(cx);
        this
    }

    fn handle_project_event(
        &mut self,
        _: Model<Project>,
        event: &project::Event,
        cx: &mut ModelContext<Self>,
    ) {
        match event {
            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                self.update_worktree_indices(cx);
            }
            _ => {}
        }
    }

    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
        let worktrees = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .filter_map(|worktree| {
                if worktree.read(cx).is_local() {
                    Some((worktree.entity_id(), worktree))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        self.worktree_indices
            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
        for (worktree_id, worktree) in worktrees {
            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
                let worktree_index = WorktreeIndex::load(
                    worktree.clone(),
                    self.db_connection.clone(),
                    self.language_registry.clone(),
                    self.fs.clone(),
                    self.embedding_provider.clone(),
                    cx,
                );

                let load_worktree = cx.spawn(|this, mut cx| async move {
                    if let Some(index) = worktree_index.await.log_err() {
                        this.update(&mut cx, |this, cx| {
                            this.worktree_indices.insert(
                                worktree_id,
                                WorktreeIndexHandle::Loaded {
                                    _subscription: cx
                                        .observe(&index, |this, _, cx| this.update_status(cx)),
                                    index,
                                },
                            );
                        })?;
                    } else {
                        this.update(&mut cx, |this, _cx| {
                            this.worktree_indices.remove(&worktree_id)
                        })?;
                    }

                    this.update(&mut cx, |this, cx| this.update_status(cx))
                });

                WorktreeIndexHandle::Loading {
                    _task: load_worktree,
                }
            });
        }

        self.update_status(cx);
    }

    fn update_status(&mut self, cx: &mut ModelContext<Self>) {
        let mut status = Status::Idle;
        for index in self.worktree_indices.values() {
            match index {
                WorktreeIndexHandle::Loading { .. } => {
                    status = Status::Scanning;
                    break;
                }
                WorktreeIndexHandle::Loaded { index, .. } => {
                    if index.read(cx).status == Status::Scanning {
                        status = Status::Scanning;
                        break;
                    }
                }
            }
        }

        if status != self.last_status {
            self.last_status = status;
            cx.emit(status);
        }
    }

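    /// Searches every loaded worktree index concurrently, then merges the
    /// per-worktree results, sorts them by descending score, and keeps the top
    /// `limit` hits.
    ///
    /// A hypothetical call site (the query string and limit are illustrative):
    ///
    /// ```ignore
    /// let task = project_index.read(cx).search("garbage collection", 16, cx);
    /// let results: Vec<SearchResult> = task.await;
    /// ```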
    pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
        let mut worktree_searches = Vec::new();
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                worktree_searches
                    .push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
            }
        }

        cx.spawn(|_| async move {
            let mut results = Vec::new();
            let worktree_searches = futures::future::join_all(worktree_searches).await;

            for worktree_search_results in worktree_searches {
                if let Some(worktree_search_results) = worktree_search_results.log_err() {
                    results.extend(worktree_search_results);
                }
            }

            results
                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
            results.truncate(limit);

            results
        })
    }
}

pub struct SearchResult {
    pub worktree: Model<Worktree>,
    pub path: Arc<Path>,
    pub range: Range<usize>,
    pub score: f32,
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    Idle,
    Scanning,
}

impl EventEmitter<Status> for ProjectIndex {}

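/// The index for a single local worktree: one database of embedded files keyed
/// by path, kept up to date by a background task that re-indexes entries as the
/// worktree reports changes.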
struct WorktreeIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    status: Status,
    _index_entries: Task<Result<()>>,
    _subscription: Subscription,
}

impl WorktreeIndex {
    pub fn load(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let worktree_abs_path = worktree.read(cx).abs_path();
        cx.spawn(|mut cx| async move {
            let db = cx
                .background_executor()
                .spawn({
                    let db_connection = db_connection.clone();
                    async move {
                        let mut txn = db_connection.write_txn()?;
                        let db_name = worktree_abs_path.to_string_lossy();
                        let db = db_connection.create_database(&mut txn, Some(&db_name))?;
                        txn.commit()?;
                        anyhow::Ok(db)
                    }
                })
                .await?;
            cx.new_model(|cx| {
                Self::new(
                    worktree,
                    db_connection,
                    db,
                    language_registry,
                    fs,
                    embedding_provider,
                    cx,
                )
            })
        })
    }

    fn new(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
            if let worktree::Event::UpdatedEntries(update) = event {
                _ = updated_entries_tx.try_send(update.clone());
            }
        });

        Self {
            db_connection,
            db,
            worktree,
            language_registry,
            fs,
            embedding_provider,
            status: Status::Idle,
            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
            _subscription,
        }
    }

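    /// Runs for the lifetime of the index: performs one full pass over the
    /// worktree, then re-indexes incrementally whenever a batch of updated
    /// entries arrives, flipping `status` to `Scanning` around each pass.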
    async fn index_entries(
        this: WeakModel<Self>,
        updated_entries: channel::Receiver<UpdatedEntriesSet>,
        mut cx: AsyncAppContext,
    ) -> Result<()> {
        let index = this.update(&mut cx, |this, cx| {
            cx.notify();
            this.status = Status::Scanning;
            this.index_entries_changed_on_disk(cx)
        })?;
        index.await.log_err();
        this.update(&mut cx, |this, cx| {
            this.status = Status::Idle;
            cx.notify();
        })?;

        while let Ok(updated_entries) = updated_entries.recv().await {
            let index = this.update(&mut cx, |this, cx| {
                cx.notify();
                this.status = Status::Scanning;
                this.index_updated_entries(updated_entries, cx)
            })?;
            index.await.log_err();
            this.update(&mut cx, |this, cx| {
                this.status = Status::Idle;
                cx.notify();
            })?;
        }

        Ok(())
    }

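    /// Builds the indexing pipeline for a full scan: scan entries, chunk the
    /// changed files, embed the chunks, and persist the results. The stages are
    /// connected by bounded channels and driven to completion concurrently via
    /// `try_join!`.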
    fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = self.embed_files(chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries, cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = self.embed_files(chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

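    /// Walks the worktree's files and the database's stored keys together
    /// (relying on both being ordered consistently), sending entries whose
    /// mtime changed, or that were never indexed, to `updated_entries`, and
    /// key ranges that no longer exist on disk to `deleted_entry_ranges`.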
    fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let task = cx.background_executor().spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                let entry_db_key = db_key_for_path(&entry.path);

                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                break;
                            }
                        },
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
                    }
                }

                if entry.mtime != saved_mtime {
                    updated_entries_tx.send(entry.clone()).await?;
                }
            }

            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

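    /// Translates a batch of worktree change events into the same two streams
    /// as `scan_entries`: added or updated paths become entries to re-index,
    /// and removed paths become single-key deletion ranges.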
    fn scan_updated_entries(
        &self,
        worktree: LocalSnapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            updated_entries_tx.send(entry.clone()).await?;
                        }
                    }
                    project::PathChange::Removed => {
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

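    /// Spawns one background worker per CPU; each worker loads an entry's text,
    /// resolves its language to obtain a grammar, and splits the text into
    /// chunks to be embedded.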
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<Entry>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok(entry) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                let Some(text) = fs.load(&entry_abs_path).await.log_err() else {
                                    continue;
                                };
                                let language = language_registry
                                    .language_for_file_path(&entry.path)
                                    .await
                                    .ok();
                                let grammar =
                                    language.as_ref().and_then(|language| language.grammar());
                                let chunked_file = ChunkedFile {
                                    worktree_root: worktree_abs_path.clone(),
                                    chunks: chunk_text(&text, grammar),
                                    entry,
                                    text,
                                };

                                if chunked_files_tx.send(chunked_file).await.is_err() {
                                    return;
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }

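    /// Batches chunked files (up to 512 files or a two-second timeout), embeds
    /// their chunks in provider-sized batches, and reassembles the returned
    /// embeddings into one `EmbeddedFile` per input file.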
    fn embed_files(
        &self,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &AppContext,
    ) -> EmbedFiles {
        let embedding_provider = self.embedding_provider.clone();
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_executor().spawn(async move {
            let mut chunked_file_batches =
                chunked_files.chunks_timeout(512, Duration::from_secs(2));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // Flatten this batch of files into a single list of chunks, embed the
                // chunks in provider-sized batches, and then reassemble the embeddings
                // back into the files they came from.

                let chunks = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect::<Vec<_>>();

                let mut embeddings = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    // todo!("add a retry facility")
                    embeddings.extend(embedding_provider.embed(embedding_batch).await?);
                }

                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let chunk_embeddings = embeddings
                        .by_ref()
                        .take(chunked_file.chunks.len())
                        .collect::<Vec<_>>();
                    let embedded_chunks = chunked_file
                        .chunks
                        .into_iter()
                        .zip(chunk_embeddings)
                        .map(|(chunk, embedding)| EmbeddedChunk { chunk, embedding })
                        .collect();
                    let embedded_file = EmbeddedFile {
                        path: chunked_file.entry.path.clone(),
                        mtime: chunked_file.entry.mtime,
                        chunks: embedded_chunks,
                    };

                    embedded_files_tx.send(embedded_file).await?;
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }

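    /// Applies the deletion ranges produced by scanning, then writes embedded
    /// files to the database in large, batched write transactions.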
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<EmbeddedFile>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            while let Some(deletion_range) = deleted_entry_ranges.next().await {
                let mut txn = db_connection.write_txn()?;
                let start = deletion_range.0.as_ref().map(|start| start.as_str());
                let end = deletion_range.1.as_ref().map(|end| end.as_str());
                log::debug!("deleting embeddings in range {:?}", &(start, end));
                db.delete_range(&mut txn, &(start, end))?;
                txn.commit()?;
            }

            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(embedded_files) = embedded_files.next().await {
                let mut txn = db_connection.write_txn()?;
                for file in embedded_files {
                    log::debug!("saving embedding for file {:?}", file.path);
                    let key = db_key_for_path(&file.path);
                    db.put(&mut txn, &key, &file)?;
                }
                txn.commit()?;
                log::debug!("committed");
            }

            Ok(())
        })
    }

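    /// Embeds the query, then streams every stored chunk out of the database to
    /// a set of per-CPU workers. Each worker scores chunks against the query
    /// embedding with `Embedding::similarity` and keeps its own top `limit`
    /// results, which are merged, sorted, and truncated at the end.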
    fn search(
        &self,
        query: &str,
        limit: usize,
        cx: &AppContext,
    ) -> Task<Result<Vec<SearchResult>>> {
        let (chunks_tx, chunks_rx) = channel::bounded(1024);

        let db_connection = self.db_connection.clone();
        let db = self.db;
        let scan_chunks = cx.background_executor().spawn({
            async move {
                let txn = db_connection
                    .read_txn()
                    .context("failed to create read transaction")?;
                let db_entries = db.iter(&txn).context("failed to iterate database")?;
                for db_entry in db_entries {
                    let (_, db_embedded_file) = db_entry?;
                    for chunk in db_embedded_file.chunks {
                        chunks_tx
                            .send((db_embedded_file.path.clone(), chunk))
                            .await?;
                    }
                }
                anyhow::Ok(())
            }
        });

        let query = query.to_string();
        let embedding_provider = self.embedding_provider.clone();
        let worktree = self.worktree.clone();
        cx.spawn(|cx| async move {
            #[cfg(debug_assertions)]
            let embedding_query_start = std::time::Instant::now();

            let mut query_embeddings = embedding_provider
                .embed(&[TextToEmbed::new(&query)])
                .await?;
            let query_embedding = query_embeddings
                .pop()
                .ok_or_else(|| anyhow!("no embedding for query"))?;
            let mut workers = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                workers.push(Vec::<SearchResult>::new());
            }

            #[cfg(debug_assertions)]
            let search_start = std::time::Instant::now();

            cx.background_executor()
                .scoped(|cx| {
                    for worker_results in workers.iter_mut() {
                        cx.spawn(async {
                            while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
                                let score = embedded_chunk.embedding.similarity(&query_embedding);
                                let ix = match worker_results.binary_search_by(|probe| {
                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
                                }) {
                                    Ok(ix) | Err(ix) => ix,
                                };
                                worker_results.insert(
                                    ix,
                                    SearchResult {
                                        worktree: worktree.clone(),
                                        path: path.clone(),
                                        range: embedded_chunk.chunk.range.clone(),
                                        score,
                                    },
                                );
                                worker_results.truncate(limit);
                            }
                        });
                    }
                })
                .await;
            scan_chunks.await?;

            let mut search_results = Vec::with_capacity(workers.len() * limit);
            for worker_results in workers {
                search_results.extend(worker_results);
            }
            search_results
                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
            search_results.truncate(limit);
            #[cfg(debug_assertions)]
            {
                let search_elapsed = search_start.elapsed();
                log::debug!(
                    "searched {} entries in {:?}",
                    search_results.len(),
                    search_elapsed
                );
                let embedding_query_elapsed = embedding_query_start.elapsed();
                log::debug!("embedding query took {:?}", embedding_query_elapsed);
            }

            Ok(search_results)
        })
    }
}

struct ScanEntries {
    updated_entries: channel::Receiver<Entry>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}

struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}

struct ChunkedFile {
    #[allow(dead_code)]
    pub worktree_root: Arc<Path>,
    pub entry: Entry,
    pub text: String,
    pub chunks: Vec<Chunk>,
}

struct EmbedFiles {
    files: channel::Receiver<EmbeddedFile>,
    task: Task<Result<()>>,
}

#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedFile {
    path: Arc<Path>,
    mtime: Option<SystemTime>,
    chunks: Vec<EmbeddedChunk>,
}

#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedChunk {
    chunk: Chunk,
    embedding: Embedding,
}

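/// Database keys are the worktree-relative path with '/' replaced by NUL,
/// presumably so that the lexicographic key order lines up with the order in
/// which `scan_entries` walks the worktree's entries.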
fn db_key_for_path(path: &Arc<Path>) -> String {
    path.to_string_lossy().replace('/', "\0")
}

#[cfg(test)]
mod tests {
    use super::*;

    use futures::channel::oneshot;
    use futures::{future::BoxFuture, FutureExt};

    use gpui::{Global, TestAppContext};
    use language::language_settings::AllLanguageSettings;
    use project::Project;
    use settings::SettingsStore;
    use std::{future, path::Path, sync::Arc};

    fn init_test(cx: &mut TestAppContext) {
        _ = cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

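    /// A deterministic provider for tests: each text is embedded into two
    /// dimensions that encode whether it mentions "garbage in" and "garbage out".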
    pub struct TestEmbeddingProvider;

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            let embeddings = texts
                .iter()
                .map(|text| {
                    let mut embedding = vec![0f32; 2];
                    // If the text contains "garbage in", push the first dimension high;
                    // otherwise low. The second dimension does the same for "garbage out".
                    if text.text.contains("garbage in") {
                        embedding[0] = 0.9;
                    } else {
                        embedding[0] = -0.9;
                    }

                    if text.text.contains("garbage out") {
                        embedding[1] = 0.9;
                    } else {
                        embedding[1] = -0.9;
                    }

                    Embedding::new(embedding)
                })
                .collect();
            future::ready(Ok(embeddings)).boxed()
        }

        fn batch_size(&self) -> usize {
            16
        }
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = cx
            .update(|cx| {
                let semantic_index = SemanticIndex::new(
                    Path::new(temp_dir.path()),
                    Arc::new(TestEmbeddingProvider),
                    cx,
                );
                semantic_index
            })
            .await
            .unwrap();

        // todo!(): use a fixture
        let project_path = Path::new("./fixture");

        let project = cx
            .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
            .await;

        cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
        });

        let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));

        let (tx, rx) = oneshot::channel();
        let mut tx = Some(tx);
        let subscription = cx.update(|cx| {
            cx.subscribe(&project_index, move |_, event, _| {
                if let Some(tx) = tx.take() {
                    _ = tx.send(*event);
                }
            })
        });

        rx.await.expect("no event emitted");
        drop(subscription);

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query, 4, cx)
            })
            .await;

        assert!(results.len() > 1, "should have found some results");

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find a result scored above 0.9, i.e. a "garbage in, garbage out" match.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "needle.md");

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
                let fs = project.read(cx).fs().clone();
                cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }
}