vector_store.rs

  1mod db;
  2mod embedding;
  3mod modal;
  4
  5#[cfg(test)]
  6mod vector_store_tests;
  7
  8use anyhow::{anyhow, Result};
  9use db::VectorDatabase;
 10use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 11use futures::{channel::oneshot, Future};
 12use gpui::{
 13    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
 14    WeakModelHandle,
 15};
 16use language::{Language, LanguageRegistry};
 17use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
 18use project::{Fs, Project, WorktreeId};
 19use smol::channel;
 20use std::{
 21    cell::RefCell,
 22    cmp::Ordering,
 23    collections::HashMap,
 24    path::{Path, PathBuf},
 25    rc::Rc,
 26    sync::Arc,
 27    time::{Duration, Instant, SystemTime},
 28};
 29use tree_sitter::{Parser, QueryCursor};
 30use util::{
 31    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
 32    http::HttpClient,
 33    paths::EMBEDDINGS_DIR,
 34    ResultExt,
 35};
 36use workspace::{Workspace, WorkspaceCreated};
 37
 38const REINDEXING_DELAY_SECONDS: u64 = 3;
 39const EMBEDDINGS_BATCH_SIZE: usize = 150;
 40
 41#[derive(Debug, Clone)]
 42pub struct Document {
 43    pub offset: usize,
 44    pub name: String,
 45    pub embedding: Vec<f32>,
 46}
 47
 48pub fn init(
 49    fs: Arc<dyn Fs>,
 50    http_client: Arc<dyn HttpClient>,
 51    language_registry: Arc<LanguageRegistry>,
 52    cx: &mut AppContext,
 53) {
 54    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
 55        return;
 56    }
 57
 58    let db_file_path = EMBEDDINGS_DIR
 59        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
 60        .join("embeddings_db");
 61
 62    cx.spawn(move |mut cx| async move {
 63        let vector_store = VectorStore::new(
 64            fs,
 65            db_file_path,
 66            // Arc::new(embedding::DummyEmbeddings {}),
 67            Arc::new(OpenAIEmbeddings {
 68                client: http_client,
 69            }),
 70            language_registry,
 71            cx.clone(),
 72        )
 73        .await?;
 74
 75        cx.update(|cx| {
 76            cx.subscribe_global::<WorkspaceCreated, _>({
 77                let vector_store = vector_store.clone();
 78                move |event, cx| {
 79                    let workspace = &event.0;
 80                    if let Some(workspace) = workspace.upgrade(cx) {
 81                        let project = workspace.read(cx).project().clone();
 82                        if project.read(cx).is_local() {
 83                            vector_store.update(cx, |store, cx| {
 84                                store.add_project(project, cx).detach();
 85                            });
 86                        }
 87                    }
 88                }
 89            })
 90            .detach();
 91
 92            cx.add_action({
 93                move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
 94                    let vector_store = vector_store.clone();
 95                    workspace.toggle_modal(cx, |workspace, cx| {
 96                        let project = workspace.project().clone();
 97                        let workspace = cx.weak_handle();
 98                        cx.add_view(|cx| {
 99                            SemanticSearch::new(
100                                SemanticSearchDelegate::new(workspace, project, vector_store),
101                                cx,
102                            )
103                        })
104                    })
105                }
106            });
107
108            SemanticSearch::init(cx);
109        });
110
111        anyhow::Ok(())
112    })
113    .detach();
114}
115
116#[derive(Debug, Clone)]
117pub struct IndexedFile {
118    path: PathBuf,
119    mtime: SystemTime,
120    documents: Vec<Document>,
121}
122
123pub struct VectorStore {
124    fs: Arc<dyn Fs>,
125    database_url: Arc<PathBuf>,
126    embedding_provider: Arc<dyn EmbeddingProvider>,
127    language_registry: Arc<LanguageRegistry>,
128    db_update_tx: channel::Sender<DbWrite>,
129    parsing_files_tx: channel::Sender<PendingFile>,
130    _db_update_task: Task<()>,
131    _embed_batch_task: Vec<Task<()>>,
132    _batch_files_task: Task<()>,
133    _parsing_files_tasks: Vec<Task<()>>,
134    projects: HashMap<WeakModelHandle<Project>, Rc<RefCell<ProjectState>>>,
135}
136
137struct ProjectState {
138    worktree_db_ids: Vec<(WorktreeId, i64)>,
139    pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
140    _subscription: gpui::Subscription,
141}
142
143impl ProjectState {
144    fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
145        // If Pending File Already Exists, Replace it with the new one
146        // but keep the old indexing time
147        if let Some(old_file) = self.pending_files.remove(&pending_file.path.clone()) {
148            self.pending_files
149                .insert(pending_file.path.clone(), (pending_file, old_file.1));
150        } else {
151            self.pending_files
152                .insert(pending_file.path.clone(), (pending_file, indexing_time));
153        };
154    }
155
156    fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
157        let mut outstanding_files = vec![];
158        let mut remove_keys = vec![];
159        for key in self.pending_files.keys().into_iter() {
160            if let Some(pending_details) = self.pending_files.get(key) {
161                let (pending_file, index_time) = pending_details;
162                if index_time <= &SystemTime::now() {
163                    outstanding_files.push(pending_file.clone());
164                    remove_keys.push(key.clone());
165                }
166            }
167        }
168
169        for key in remove_keys.iter() {
170            self.pending_files.remove(key);
171        }
172
173        return outstanding_files;
174    }
175}
176
177#[derive(Clone, Debug)]
178struct PendingFile {
179    worktree_db_id: i64,
180    path: PathBuf,
181    language: Arc<Language>,
182    modified_time: SystemTime,
183}
184
185#[derive(Debug, Clone)]
186pub struct SearchResult {
187    pub worktree_id: WorktreeId,
188    pub name: String,
189    pub offset: usize,
190    pub file_path: PathBuf,
191}
192
193enum DbWrite {
194    InsertFile {
195        worktree_id: i64,
196        indexed_file: IndexedFile,
197    },
198    Delete {
199        worktree_id: i64,
200        path: PathBuf,
201    },
202    FindOrCreateWorktree {
203        path: PathBuf,
204        sender: oneshot::Sender<Result<i64>>,
205    },
206}
207
208impl VectorStore {
209    async fn new(
210        fs: Arc<dyn Fs>,
211        database_url: PathBuf,
212        embedding_provider: Arc<dyn EmbeddingProvider>,
213        language_registry: Arc<LanguageRegistry>,
214        mut cx: AsyncAppContext,
215    ) -> Result<ModelHandle<Self>> {
216        let database_url = Arc::new(database_url);
217
218        let db = cx
219            .background()
220            .spawn({
221                let fs = fs.clone();
222                let database_url = database_url.clone();
223                async move {
224                    if let Some(db_directory) = database_url.parent() {
225                        fs.create_dir(db_directory).await.log_err();
226                    }
227
228                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
229                    anyhow::Ok(db)
230                }
231            })
232            .await?;
233
234        Ok(cx.add_model(|cx| {
235            // paths_tx -> embeddings_tx -> db_update_tx
236
237            //db_update_tx/rx: Updating Database
238            let (db_update_tx, db_update_rx) = channel::unbounded();
239            let _db_update_task = cx.background().spawn(async move {
240                while let Ok(job) = db_update_rx.recv().await {
241                    match job {
242                        DbWrite::InsertFile {
243                            worktree_id,
244                            indexed_file,
245                        } => {
246                            log::info!("Inserting Data for {:?}", &indexed_file.path);
247                            db.insert_file(worktree_id, indexed_file).log_err();
248                        }
249                        DbWrite::Delete { worktree_id, path } => {
250                            db.delete_file(worktree_id, path).log_err();
251                        }
252                        DbWrite::FindOrCreateWorktree { path, sender } => {
253                            let id = db.find_or_create_worktree(&path);
254                            sender.send(id).ok();
255                        }
256                    }
257                }
258            });
259
260            // embed_tx/rx: Embed Batch and Send to Database
261            let (embed_batch_tx, embed_batch_rx) =
262                channel::unbounded::<Vec<(i64, IndexedFile, Vec<String>)>>();
263            let mut _embed_batch_task = Vec::new();
264            for _ in 0..1 {
265                //cx.background().num_cpus() {
266                let db_update_tx = db_update_tx.clone();
267                let embed_batch_rx = embed_batch_rx.clone();
268                let embedding_provider = embedding_provider.clone();
269                _embed_batch_task.push(cx.background().spawn(async move {
270                    while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
271                        // Construct Batch
272                        let mut embeddings_queue = embeddings_queue.clone();
273                        let mut document_spans = vec![];
274                        for (_, _, document_span) in embeddings_queue.clone().into_iter() {
275                            document_spans.extend(document_span);
276                        }
277
278                        if let Ok(embeddings) = embedding_provider
279                            .embed_batch(document_spans.iter().map(|x| &**x).collect())
280                            .await
281                        {
282                            let mut i = 0;
283                            let mut j = 0;
284
285                            for embedding in embeddings.iter() {
286                                while embeddings_queue[i].1.documents.len() == j {
287                                    i += 1;
288                                    j = 0;
289                                }
290
291                                embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
292                                j += 1;
293                            }
294
295                            for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
296                                for document in indexed_file.documents.iter() {
297                                    // TODO: Update this so it doesn't panic
298                                    assert!(
299                                        document.embedding.len() > 0,
300                                        "Document Embedding Not Complete"
301                                    );
302                                }
303
304                                db_update_tx
305                                    .send(DbWrite::InsertFile {
306                                        worktree_id,
307                                        indexed_file,
308                                    })
309                                    .await
310                                    .unwrap();
311                            }
312                        }
313                    }
314                }))
315            }
316
317            // batch_tx/rx: Batch Files to Send for Embeddings
318            let (batch_files_tx, batch_files_rx) =
319                channel::unbounded::<(i64, IndexedFile, Vec<String>)>();
320            let _batch_files_task = cx.background().spawn(async move {
321                let mut queue_len = 0;
322                let mut embeddings_queue = vec![];
323                while let Ok((worktree_id, indexed_file, document_spans)) =
324                    batch_files_rx.recv().await
325                {
326                    queue_len += &document_spans.len();
327                    embeddings_queue.push((worktree_id, indexed_file, document_spans));
328                    if queue_len >= EMBEDDINGS_BATCH_SIZE {
329                        embed_batch_tx.try_send(embeddings_queue).unwrap();
330                        embeddings_queue = vec![];
331                        queue_len = 0;
332                    }
333                }
334                if queue_len > 0 {
335                    embed_batch_tx.try_send(embeddings_queue).unwrap();
336                }
337            });
338
339            // parsing_files_tx/rx: Parsing Files to Embeddable Documents
340            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
341
342            let mut _parsing_files_tasks = Vec::new();
343            for _ in 0..cx.background().num_cpus() {
344                let fs = fs.clone();
345                let parsing_files_rx = parsing_files_rx.clone();
346                let batch_files_tx = batch_files_tx.clone();
347                _parsing_files_tasks.push(cx.background().spawn(async move {
348                    let mut parser = Parser::new();
349                    let mut cursor = QueryCursor::new();
350                    while let Ok(pending_file) = parsing_files_rx.recv().await {
351                        log::info!("Parsing File: {:?}", &pending_file.path);
352                        if let Some((indexed_file, document_spans)) = Self::index_file(
353                            &mut cursor,
354                            &mut parser,
355                            &fs,
356                            pending_file.language,
357                            pending_file.path.clone(),
358                            pending_file.modified_time,
359                        )
360                        .await
361                        .log_err()
362                        {
363                            batch_files_tx
364                                .try_send((
365                                    pending_file.worktree_db_id,
366                                    indexed_file,
367                                    document_spans,
368                                ))
369                                .unwrap();
370                        }
371                    }
372                }));
373            }
374
375            Self {
376                fs,
377                database_url,
378                embedding_provider,
379                language_registry,
380                db_update_tx,
381                parsing_files_tx,
382                _db_update_task,
383                _embed_batch_task,
384                _batch_files_task,
385                _parsing_files_tasks,
386                projects: HashMap::new(),
387            }
388        }))
389    }
390
391    async fn index_file(
392        cursor: &mut QueryCursor,
393        parser: &mut Parser,
394        fs: &Arc<dyn Fs>,
395        language: Arc<Language>,
396        file_path: PathBuf,
397        mtime: SystemTime,
398    ) -> Result<(IndexedFile, Vec<String>)> {
399        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
400        let embedding_config = grammar
401            .embedding_config
402            .as_ref()
403            .ok_or_else(|| anyhow!("no outline query"))?;
404
405        let content = fs.load(&file_path).await?;
406
407        parser.set_language(grammar.ts_language).unwrap();
408        let tree = parser
409            .parse(&content, None)
410            .ok_or_else(|| anyhow!("parsing failed"))?;
411
412        let mut documents = Vec::new();
413        let mut context_spans = Vec::new();
414        for mat in cursor.matches(
415            &embedding_config.query,
416            tree.root_node(),
417            content.as_bytes(),
418        ) {
419            let mut item_range = None;
420            let mut name_range = None;
421            for capture in mat.captures {
422                if capture.index == embedding_config.item_capture_ix {
423                    item_range = Some(capture.node.byte_range());
424                } else if capture.index == embedding_config.name_capture_ix {
425                    name_range = Some(capture.node.byte_range());
426                }
427            }
428
429            if let Some((item_range, name_range)) = item_range.zip(name_range) {
430                if let Some((item, name)) =
431                    content.get(item_range.clone()).zip(content.get(name_range))
432                {
433                    context_spans.push(item.to_string());
434                    documents.push(Document {
435                        name: name.to_string(),
436                        offset: item_range.start,
437                        embedding: Vec::new(),
438                    });
439                }
440            }
441        }
442
443        return Ok((
444            IndexedFile {
445                path: file_path,
446                mtime,
447                documents,
448            },
449            context_spans,
450        ));
451    }
452
453    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
454        let (tx, rx) = oneshot::channel();
455        self.db_update_tx
456            .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
457            .unwrap();
458        async move { rx.await? }
459    }
460
461    fn add_project(
462        &mut self,
463        project: ModelHandle<Project>,
464        cx: &mut ModelContext<Self>,
465    ) -> Task<Result<()>> {
466        let worktree_scans_complete = project
467            .read(cx)
468            .worktrees(cx)
469            .map(|worktree| {
470                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
471                async move {
472                    scan_complete.await;
473                }
474            })
475            .collect::<Vec<_>>();
476        let worktree_db_ids = project
477            .read(cx)
478            .worktrees(cx)
479            .map(|worktree| {
480                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
481            })
482            .collect::<Vec<_>>();
483
484        let fs = self.fs.clone();
485        let language_registry = self.language_registry.clone();
486        let database_url = self.database_url.clone();
487        let db_update_tx = self.db_update_tx.clone();
488        let parsing_files_tx = self.parsing_files_tx.clone();
489
490        cx.spawn(|this, mut cx| async move {
491            let t0 = Instant::now();
492            futures::future::join_all(worktree_scans_complete).await;
493
494            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
495            log::info!("Worktree Scanning Done in {:?}", t0.elapsed().as_millis());
496
497            if let Some(db_directory) = database_url.parent() {
498                fs.create_dir(db_directory).await.log_err();
499            }
500
501            let worktrees = project.read_with(&cx, |project, cx| {
502                project
503                    .worktrees(cx)
504                    .map(|worktree| worktree.read(cx).snapshot())
505                    .collect::<Vec<_>>()
506            });
507
508            // Here we query the worktree ids, and yet we dont have them elsewhere
509            // We likely want to clean up these datastructures
510            let (mut worktree_file_times, db_ids_by_worktree_id) = cx
511                .background()
512                .spawn({
513                    let worktrees = worktrees.clone();
514                    async move {
515                        let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
516                        let mut db_ids_by_worktree_id = HashMap::new();
517                        let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
518                            HashMap::new();
519                        for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
520                            let db_id = db_id?;
521                            db_ids_by_worktree_id.insert(worktree.id(), db_id);
522                            file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
523                        }
524                        anyhow::Ok((file_times, db_ids_by_worktree_id))
525                    }
526                })
527                .await?;
528
529            cx.background()
530                .spawn({
531                    let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
532                    let db_update_tx = db_update_tx.clone();
533                    let language_registry = language_registry.clone();
534                    let parsing_files_tx = parsing_files_tx.clone();
535                    async move {
536                        let t0 = Instant::now();
537                        for worktree in worktrees.into_iter() {
538                            let mut file_mtimes =
539                                worktree_file_times.remove(&worktree.id()).unwrap();
540                            for file in worktree.files(false, 0) {
541                                let absolute_path = worktree.absolutize(&file.path);
542
543                                if let Ok(language) = language_registry
544                                    .language_for_file(&absolute_path, None)
545                                    .await
546                                {
547                                    if language
548                                        .grammar()
549                                        .and_then(|grammar| grammar.embedding_config.as_ref())
550                                        .is_none()
551                                    {
552                                        continue;
553                                    }
554
555                                    let path_buf = file.path.to_path_buf();
556                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
557                                    let already_stored = stored_mtime
558                                        .map_or(false, |existing_mtime| {
559                                            existing_mtime == file.mtime
560                                        });
561
562                                    if !already_stored {
563                                        parsing_files_tx
564                                            .try_send(PendingFile {
565                                                worktree_db_id: db_ids_by_worktree_id
566                                                    [&worktree.id()],
567                                                path: path_buf,
568                                                language,
569                                                modified_time: file.mtime,
570                                            })
571                                            .unwrap();
572                                    }
573                                }
574                            }
575                            for file in file_mtimes.keys() {
576                                db_update_tx
577                                    .try_send(DbWrite::Delete {
578                                        worktree_id: db_ids_by_worktree_id[&worktree.id()],
579                                        path: file.to_owned(),
580                                    })
581                                    .unwrap();
582                            }
583                        }
584                        log::info!(
585                            "Parsing Worktree Completed in {:?}",
586                            t0.elapsed().as_millis()
587                        );
588                    }
589                })
590                .detach();
591
592            // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
593            this.update(&mut cx, |this, cx| {
594                // The below is managing for updated on save
595                // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
596                // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
597                let _subscription = cx.subscribe(&project, |this, project, event, cx| {
598                    if let Some(project_state) = this.projects.get(&project.downgrade()) {
599                        let mut project_state = project_state.borrow_mut();
600                        let worktree_db_ids = project_state.worktree_db_ids.clone();
601
602                        if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
603                        {
604                            // Get Worktree Object
605                            let worktree =
606                                project.read(cx).worktree_for_id(worktree_id.clone(), cx);
607                            if worktree.is_none() {
608                                return;
609                            }
610                            let worktree = worktree.unwrap();
611
612                            // Get Database
613                            let db_values = {
614                                if let Ok(db) =
615                                    VectorDatabase::new(this.database_url.to_string_lossy().into())
616                                {
617                                    let worktree_db_id: Option<i64> = {
618                                        let mut found_db_id = None;
619                                        for (w_id, db_id) in worktree_db_ids.into_iter() {
620                                            if &w_id == &worktree.read(cx).id() {
621                                                found_db_id = Some(db_id)
622                                            }
623                                        }
624                                        found_db_id
625                                    };
626                                    if worktree_db_id.is_none() {
627                                        return;
628                                    }
629                                    let worktree_db_id = worktree_db_id.unwrap();
630
631                                    let file_mtimes = db.get_file_mtimes(worktree_db_id);
632                                    if file_mtimes.is_err() {
633                                        return;
634                                    }
635
636                                    let file_mtimes = file_mtimes.unwrap();
637                                    Some((file_mtimes, worktree_db_id))
638                                } else {
639                                    return;
640                                }
641                            };
642
643                            if db_values.is_none() {
644                                return;
645                            }
646
647                            let (file_mtimes, worktree_db_id) = db_values.unwrap();
648
649                            // Iterate Through Changes
650                            let language_registry = this.language_registry.clone();
651                            let parsing_files_tx = this.parsing_files_tx.clone();
652
653                            smol::block_on(async move {
654                                for change in changes.into_iter() {
655                                    let change_path = change.0.clone();
656                                    // Skip if git ignored or symlink
657                                    if let Some(entry) = worktree.read(cx).entry_for_id(change.1) {
658                                        if entry.is_ignored || entry.is_symlink {
659                                            continue;
660                                        } else {
661                                            log::info!(
662                                                "Testing for Reindexing: {:?}",
663                                                &change_path
664                                            );
665                                        }
666                                    };
667
668                                    if let Ok(language) = language_registry
669                                        .language_for_file(&change_path.to_path_buf(), None)
670                                        .await
671                                    {
672                                        if language
673                                            .grammar()
674                                            .and_then(|grammar| grammar.embedding_config.as_ref())
675                                            .is_none()
676                                        {
677                                            continue;
678                                        }
679
680                                        if let Some(modified_time) = {
681                                            let metadata = change_path.metadata();
682                                            if metadata.is_err() {
683                                                None
684                                            } else {
685                                                let mtime = metadata.unwrap().modified();
686                                                if mtime.is_err() {
687                                                    None
688                                                } else {
689                                                    Some(mtime.unwrap())
690                                                }
691                                            }
692                                        } {
693                                            let existing_time =
694                                                file_mtimes.get(&change_path.to_path_buf());
695                                            let already_stored = existing_time
696                                                .map_or(false, |existing_time| {
697                                                    &modified_time != existing_time
698                                                });
699
700                                            let reindex_time = modified_time
701                                                + Duration::from_secs(REINDEXING_DELAY_SECONDS);
702
703                                            if !already_stored {
704                                                project_state.update_pending_files(
705                                                    PendingFile {
706                                                        path: change_path.to_path_buf(),
707                                                        modified_time,
708                                                        worktree_db_id,
709                                                        language: language.clone(),
710                                                    },
711                                                    reindex_time,
712                                                );
713
714                                                for file in project_state.get_outstanding_files() {
715                                                    parsing_files_tx.try_send(file).unwrap();
716                                                }
717                                            }
718                                        }
719                                    }
720                                }
721                            });
722                        };
723                    }
724                });
725
726                this.projects.insert(
727                    project.downgrade(),
728                    Rc::new(RefCell::new(ProjectState {
729                        pending_files: HashMap::new(),
730                        worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
731                        _subscription,
732                    })),
733                );
734            });
735
736            anyhow::Ok(())
737        })
738    }
739
740    pub fn search(
741        &mut self,
742        project: ModelHandle<Project>,
743        phrase: String,
744        limit: usize,
745        cx: &mut ModelContext<Self>,
746    ) -> Task<Result<Vec<SearchResult>>> {
747        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
748            state.borrow()
749        } else {
750            return Task::ready(Err(anyhow!("project not added")));
751        };
752
753        let worktree_db_ids = project
754            .read(cx)
755            .worktrees(cx)
756            .filter_map(|worktree| {
757                let worktree_id = worktree.read(cx).id();
758                project_state
759                    .worktree_db_ids
760                    .iter()
761                    .find_map(|(id, db_id)| {
762                        if *id == worktree_id {
763                            Some(*db_id)
764                        } else {
765                            None
766                        }
767                    })
768            })
769            .collect::<Vec<_>>();
770
771        let embedding_provider = self.embedding_provider.clone();
772        let database_url = self.database_url.clone();
773        cx.spawn(|this, cx| async move {
774            let documents = cx
775                .background()
776                .spawn(async move {
777                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;
778
779                    let phrase_embedding = embedding_provider
780                        .embed_batch(vec![&phrase])
781                        .await?
782                        .into_iter()
783                        .next()
784                        .unwrap();
785
786                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
787                    database.for_each_document(&worktree_db_ids, |id, embedding| {
788                        let similarity = dot(&embedding.0, &phrase_embedding);
789                        let ix = match results.binary_search_by(|(_, s)| {
790                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
791                        }) {
792                            Ok(ix) => ix,
793                            Err(ix) => ix,
794                        };
795                        results.insert(ix, (id, similarity));
796                        results.truncate(limit);
797                    })?;
798
799                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
800                    database.get_documents_by_ids(&ids)
801                })
802                .await?;
803
804            this.read_with(&cx, |this, _| {
805                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
806                    state.borrow()
807                } else {
808                    return Err(anyhow!("project not added"));
809                };
810
811                Ok(documents
812                    .into_iter()
813                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
814                        let worktree_id =
815                            project_state
816                                .worktree_db_ids
817                                .iter()
818                                .find_map(|(id, db_id)| {
819                                    if *db_id == worktree_db_id {
820                                        Some(*id)
821                                    } else {
822                                        None
823                                    }
824                                })?;
825                        Some(SearchResult {
826                            worktree_id,
827                            name,
828                            offset,
829                            file_path,
830                        })
831                    })
832                    .collect())
833            })
834        })
835    }
836}
837
/// Makes `VectorStore` usable as a gpui entity. Its associated event type is `()`,
/// so events carry no payload.
impl Entity for VectorStore {
    type Event = ();
}
841
842fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
843    let len = vec_a.len();
844    assert_eq!(len, vec_b.len());
845
846    let mut result = 0.0;
847    unsafe {
848        matrixmultiply::sgemm(
849            1,
850            len,
851            1,
852            1.0,
853            vec_a.as_ptr(),
854            len as isize,
855            1,
856            vec_b.as_ptr(),
857            1,
858            len as isize,
859            0.0,
860            &mut result as *mut f32,
861            1,
862            1,
863        );
864    }
865    result
866}