// vector_store.rs

  1mod db;
  2mod embedding;
  3mod modal;
  4
  5#[cfg(test)]
  6mod vector_store_tests;
  7
  8use anyhow::{anyhow, Result};
  9use db::VectorDatabase;
 10use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 11use futures::{channel::oneshot, Future};
 12use gpui::{
 13    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
 14    WeakModelHandle,
 15};
 16use language::{Language, LanguageRegistry};
 17use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
 18use project::{Fs, Project, WorktreeId};
 19use smol::channel;
 20use std::{
 21    cell::RefCell,
 22    cmp::Ordering,
 23    collections::HashMap,
 24    path::{Path, PathBuf},
 25    rc::Rc,
 26    sync::Arc,
 27    time::{Duration, Instant, SystemTime},
 28};
 29use tree_sitter::{Parser, QueryCursor};
 30use util::{
 31    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
 32    http::HttpClient,
 33    paths::EMBEDDINGS_DIR,
 34    ResultExt,
 35};
 36use workspace::{Workspace, WorkspaceCreated};
 37
// How long to wait after a file is modified before reindexing it, so rapid
// successive saves of the same file are coalesced into one indexing pass.
const REINDEXING_DELAY_SECONDS: u64 = 3;
// Number of document spans accumulated before a batch is sent to the
// embedding provider.
const EMBEDDINGS_BATCH_SIZE: usize = 150;
 40
/// One embeddable span extracted from a source file (e.g. a function or
/// type definition) together with the embedding computed for it.
#[derive(Debug, Clone)]
pub struct Document {
    /// Byte offset of the item's start within the file's text.
    pub offset: usize,
    /// Display name: the optional context capture followed by the item's
    /// name capture, joined with a space.
    pub name: String,
    /// Embedding vector; left empty by the parser and filled in later by
    /// the embedding task.
    pub embedding: Vec<f32>,
}
 47
/// Register semantic search with the application.
///
/// No-op on the Stable release channel. Otherwise spawns a task that builds
/// a [`VectorStore`] backed by an on-disk embeddings database and wires it
/// up: each newly created local-project workspace gets indexed, and the
/// `Toggle` action opens the semantic search modal.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
        return;
    }

    // One database per release channel: <EMBEDDINGS_DIR>/<channel>/embeddings_db.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.spawn(move |mut cx| async move {
        let vector_store = VectorStore::new(
            fs,
            db_file_path,
            // Arc::new(embedding::DummyEmbeddings {}),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            // Start indexing every local project as soon as its workspace
            // is created.
            cx.subscribe_global::<WorkspaceCreated, _>({
                let vector_store = vector_store.clone();
                move |event, cx| {
                    let workspace = &event.0;
                    if let Some(workspace) = workspace.upgrade(cx) {
                        let project = workspace.read(cx).project().clone();
                        if project.read(cx).is_local() {
                            vector_store.update(cx, |store, cx| {
                                store.add_project(project, cx).detach();
                            });
                        }
                    }
                }
            })
            .detach();

            // The `Toggle` action opens the semantic-search modal for the
            // active workspace.
            cx.add_action({
                move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
                    let vector_store = vector_store.clone();
                    workspace.toggle_modal(cx, |workspace, cx| {
                        let project = workspace.project().clone();
                        let workspace = cx.weak_handle();
                        cx.add_view(|cx| {
                            SemanticSearch::new(
                                SemanticSearchDelegate::new(workspace, project, vector_store),
                                cx,
                            )
                        })
                    })
                }
            });

            SemanticSearch::init(cx);
        });

        anyhow::Ok(())
    })
    .detach();
}
115
/// A parsed file ready to be written to the vector database.
#[derive(Debug, Clone)]
pub struct IndexedFile {
    /// Worktree-relative path of the file.
    path: PathBuf,
    /// Modification time of the file when it was parsed; used to decide
    /// later whether the stored index is stale.
    mtime: SystemTime,
    /// The embeddable items extracted from the file.
    documents: Vec<Document>,
}
122
/// gpui model that owns the semantic-indexing pipeline.
///
/// Work flows through channels: `parsing_files_tx` (files to parse) ->
/// batching -> embedding -> `db_update_tx` (writes to the database).
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    /// Path to the on-disk embeddings database.
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    /// Sends write operations to the single database-writer task.
    db_update_tx: channel::Sender<DbWrite>,
    /// Sends files into the parsing stage of the pipeline.
    parsing_files_tx: channel::Sender<PendingFile>,
    // Background tasks are held so they keep running for the model's
    // lifetime; dropping them would cancel the pipeline.
    _db_update_task: Task<()>,
    _embed_batch_task: Vec<Task<()>>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    /// Per-project indexing state, keyed by a weak handle so entries don't
    /// keep dropped projects alive.
    projects: HashMap<WeakModelHandle<Project>, Rc<RefCell<ProjectState>>>,
}
136
/// Indexing state tracked for a single open project.
struct ProjectState {
    /// Mapping from a worktree's id to its row id in the vector database.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    /// Files queued for reindexing, keyed by worktree-relative path, with
    /// the time at which each becomes due.
    pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
    /// Keeps the project-event subscription alive as long as this state.
    _subscription: gpui::Subscription,
}
142
143impl ProjectState {
144    fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
145        // If Pending File Already Exists, Replace it with the new one
146        // but keep the old indexing time
147        if let Some(old_file) = self
148            .pending_files
149            .remove(&pending_file.relative_path.clone())
150        {
151            self.pending_files.insert(
152                pending_file.relative_path.clone(),
153                (pending_file, old_file.1),
154            );
155        } else {
156            self.pending_files.insert(
157                pending_file.relative_path.clone(),
158                (pending_file, indexing_time),
159            );
160        };
161    }
162
163    fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
164        let mut outstanding_files = vec![];
165        let mut remove_keys = vec![];
166        for key in self.pending_files.keys().into_iter() {
167            if let Some(pending_details) = self.pending_files.get(key) {
168                let (pending_file, index_time) = pending_details;
169                if index_time <= &SystemTime::now() {
170                    outstanding_files.push(pending_file.clone());
171                    remove_keys.push(key.clone());
172                }
173            }
174        }
175
176        for key in remove_keys.iter() {
177            self.pending_files.remove(key);
178        }
179
180        return outstanding_files;
181    }
182}
183
/// A file waiting to be parsed and indexed.
#[derive(Clone, Debug)]
struct PendingFile {
    /// Database row id of the worktree containing this file.
    worktree_db_id: i64,
    /// Path relative to the worktree root (stored in the database).
    relative_path: PathBuf,
    /// Absolute filesystem path (used to load the file's contents).
    absolute_path: PathBuf,
    /// Language detected for the file; supplies the embedding query.
    language: Arc<Language>,
    /// Modification time of the file when it was queued.
    modified_time: SystemTime,
}
192
/// A single hit returned from a semantic search.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Worktree containing the matched file.
    pub worktree_id: WorktreeId,
    /// Display name of the matched document (context + item name).
    pub name: String,
    /// Byte offset of the matched item within the file.
    pub offset: usize,
    /// Worktree-relative path of the matched file.
    pub file_path: PathBuf,
}
200
/// Operations sent to the single database-writer task, which owns the
/// database connection and serializes all writes.
enum DbWrite {
    /// Insert (or replace) an indexed file's documents for a worktree.
    InsertFile {
        worktree_id: i64,
        indexed_file: IndexedFile,
    },
    /// Delete all rows for a file that no longer exists in the worktree.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Look up the database id for a worktree root path, creating a row if
    /// none exists; the id is returned through `sender`.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
}
215
216impl VectorStore {
    /// Construct the `VectorStore` model and start its background pipeline.
    ///
    /// The pipeline is a chain of channels and background tasks:
    /// `parsing_files_tx` -> `batch_files_tx` -> `embed_batch_tx` ->
    /// `db_update_tx`. Files are parsed into documents, batched up to
    /// `EMBEDDINGS_BATCH_SIZE` spans, embedded, and finally written to the
    /// `VectorDatabase` at `database_url`.
    ///
    /// # Errors
    /// Fails if the database (or its parent directory) cannot be opened.
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let database_url = Arc::new(database_url);

        // Open (creating the directory if necessary) the embeddings
        // database on the background thread pool.
        let db = cx
            .background()
            .spawn({
                let fs = fs.clone();
                let database_url = database_url.clone();
                async move {
                    if let Some(db_directory) = database_url.parent() {
                        fs.create_dir(db_directory).await.log_err();
                    }

                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
                    anyhow::Ok(db)
                }
            })
            .await?;

        Ok(cx.add_model(|cx| {
            // paths_tx -> embeddings_tx -> db_update_tx

            // db_update_tx/rx: the single writer task owns `db` and
            // serializes all database writes.
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn(async move {
                while let Ok(job) = db_update_rx.recv().await {
                    match job {
                        DbWrite::InsertFile {
                            worktree_id,
                            indexed_file,
                        } => {
                            log::info!("Inserting Data for {:?}", &indexed_file.path);
                            db.insert_file(worktree_id, indexed_file).log_err();
                        }
                        DbWrite::Delete { worktree_id, path } => {
                            db.delete_file(worktree_id, path).log_err();
                        }
                        DbWrite::FindOrCreateWorktree { path, sender } => {
                            let id = db.find_or_create_worktree(&path);
                            sender.send(id).ok();
                        }
                    }
                }
            });

            // embed_tx/rx: embed a batch of files' document spans and
            // forward the completed files to the database writer.
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, IndexedFile, Vec<String>)>>();
            let mut _embed_batch_task = Vec::new();
            // NOTE(review): only one embedding task is spawned; the
            // commented-out expression suggests one-per-CPU was intended.
            for _ in 0..1 {
                //cx.background().num_cpus() {
                let db_update_tx = db_update_tx.clone();
                let embed_batch_rx = embed_batch_rx.clone();
                let embedding_provider = embedding_provider.clone();
                _embed_batch_task.push(cx.background().spawn(async move {
                    while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
                        // Flatten every file's spans into one request.
                        let mut embeddings_queue = embeddings_queue.clone();
                        let mut document_spans = vec![];
                        for (_, _, document_span) in embeddings_queue.clone().into_iter() {
                            document_spans.extend(document_span);
                        }

                        // NOTE(review): if `embed_batch` fails, the whole
                        // batch is silently dropped (no log, no retry).
                        if let Ok(embeddings) = embedding_provider
                            .embed_batch(document_spans.iter().map(|x| &**x).collect())
                            .await
                        {
                            // Distribute the flat list of embeddings back
                            // onto the documents they came from: `i` walks
                            // files, `j` walks documents within file `i`.
                            let mut i = 0;
                            let mut j = 0;

                            for embedding in embeddings.iter() {
                                while embeddings_queue[i].1.documents.len() == j {
                                    i += 1;
                                    j = 0;
                                }

                                embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
                                j += 1;
                            }

                            for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
                                for document in indexed_file.documents.iter() {
                                    // TODO: Update this so it doesn't panic
                                    assert!(
                                        document.embedding.len() > 0,
                                        "Document Embedding Not Complete"
                                    );
                                }

                                db_update_tx
                                    .send(DbWrite::InsertFile {
                                        worktree_id,
                                        indexed_file,
                                    })
                                    .await
                                    .unwrap();
                            }
                        }
                    }
                }))
            }

            // batch_tx/rx: accumulate parsed files until at least
            // EMBEDDINGS_BATCH_SIZE spans are queued, then hand the batch
            // to the embedding stage.
            let (batch_files_tx, batch_files_rx) =
                channel::unbounded::<(i64, IndexedFile, Vec<String>)>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];
                while let Ok((worktree_id, indexed_file, document_spans)) =
                    batch_files_rx.recv().await
                {
                    queue_len += &document_spans.len();
                    embeddings_queue.push((worktree_id, indexed_file, document_spans));
                    if queue_len >= EMBEDDINGS_BATCH_SIZE {
                        embed_batch_tx.try_send(embeddings_queue).unwrap();
                        embeddings_queue = vec![];
                        queue_len = 0;
                    }
                }
                // Flush the final partial batch once the channel closes.
                // NOTE(review): a partial batch is only flushed on channel
                // close; until then small batches sit here indefinitely.
                if queue_len > 0 {
                    embed_batch_tx.try_send(embeddings_queue).unwrap();
                }
            });

            // parsing_files_tx/rx: parse pending files into embeddable
            // documents, one worker per CPU, each with its own parser.
            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();

            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut parser = Parser::new();
                    let mut cursor = QueryCursor::new();
                    while let Ok(pending_file) = parsing_files_rx.recv().await {
                        log::info!("Parsing File: {:?}", &pending_file.relative_path);
                        if let Some((indexed_file, document_spans)) = Self::index_file(
                            &mut cursor,
                            &mut parser,
                            &fs,
                            pending_file.language,
                            pending_file.relative_path.clone(),
                            pending_file.absolute_path.clone(),
                            pending_file.modified_time,
                        )
                        .await
                        .log_err()
                        {
                            batch_files_tx
                                .try_send((
                                    pending_file.worktree_db_id,
                                    indexed_file,
                                    document_spans,
                                ))
                                .unwrap();
                        }
                    }
                }));
            }

            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                parsing_files_tx,
                _db_update_task,
                _embed_batch_task,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }
399
400    async fn index_file(
401        cursor: &mut QueryCursor,
402        parser: &mut Parser,
403        fs: &Arc<dyn Fs>,
404        language: Arc<Language>,
405        relative_file_path: PathBuf,
406        absolute_file_path: PathBuf,
407        mtime: SystemTime,
408    ) -> Result<(IndexedFile, Vec<String>)> {
409        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
410        let embedding_config = grammar
411            .embedding_config
412            .as_ref()
413            .ok_or_else(|| anyhow!("no outline query"))?;
414
415        let content = fs.load(&absolute_file_path).await?;
416
417        parser.set_language(grammar.ts_language).unwrap();
418        let tree = parser
419            .parse(&content, None)
420            .ok_or_else(|| anyhow!("parsing failed"))?;
421
422        let mut documents = Vec::new();
423        let mut context_spans = Vec::new();
424        for mat in cursor.matches(
425            &embedding_config.query,
426            tree.root_node(),
427            content.as_bytes(),
428        ) {
429            let mut item_range = None;
430            let mut name_range = None;
431            let mut context_range = None;
432            for capture in mat.captures {
433                if capture.index == embedding_config.item_capture_ix {
434                    item_range = Some(capture.node.byte_range());
435                } else if capture.index == embedding_config.name_capture_ix {
436                    name_range = Some(capture.node.byte_range());
437                }
438                if let Some(context_capture_ix) = embedding_config.context_capture_ix {
439                    if capture.index == context_capture_ix {
440                        context_range = Some(capture.node.byte_range());
441                    }
442                }
443            }
444
445            if let Some((item_range, name_range)) = item_range.zip(name_range) {
446                let mut context_data = String::new();
447                if let Some(context_range) = context_range {
448                    if let Some(context) = content.get(context_range.clone()) {
449                        context_data.push_str(context);
450                    }
451                }
452
453                if let Some((item, name)) =
454                    content.get(item_range.clone()).zip(content.get(name_range))
455                {
456                    context_spans.push(item.to_string());
457                    documents.push(Document {
458                        name: format!("{} {}", context_data.to_string(), name.to_string()),
459                        offset: item_range.start,
460                        embedding: Vec::new(),
461                    });
462                }
463            }
464        }
465
466        return Ok((
467            IndexedFile {
468                path: relative_file_path,
469                mtime,
470                documents,
471            },
472            context_spans,
473        ));
474    }
475
476    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
477        let (tx, rx) = oneshot::channel();
478        self.db_update_tx
479            .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
480            .unwrap();
481        async move { rx.await? }
482    }
483
    /// Begin indexing `project`.
    ///
    /// Waits for all worktree scans to complete, then queues every
    /// embeddable file whose stored mtime differs from the one on disk,
    /// deletes database rows for files that no longer exist, and subscribes
    /// to worktree-update events so that saved files are reindexed after
    /// `REINDEXING_DELAY_SECONDS`.
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        // Futures that resolve once each local worktree finishes its
        // initial scan.
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();
        // Database ids for each worktree, resolved by the writer task.
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let database_url = self.database_url.clone();
        let db_update_tx = self.db_update_tx.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();

        cx.spawn(|this, mut cx| async move {
            let t0 = Instant::now();
            futures::future::join_all(worktree_scans_complete).await;

            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
            log::info!("Worktree Scanning Done in {:?}", t0.elapsed().as_millis());

            if let Some(db_directory) = database_url.parent() {
                fs.create_dir(db_directory).await.log_err();
            }

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Here we query the worktree ids, and yet we dont have them elsewhere
            // We likely want to clean up these datastructures
            let (mut worktree_file_times, db_ids_by_worktree_id) = cx
                .background()
                .spawn({
                    let worktrees = worktrees.clone();
                    async move {
                        // Read previously stored mtimes for every file via a
                        // fresh connection (the writer task owns the other).
                        let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
                        let mut db_ids_by_worktree_id = HashMap::new();
                        let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
                            HashMap::new();
                        for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                            let db_id = db_id?;
                            db_ids_by_worktree_id.insert(worktree.id(), db_id);
                            file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
                        }
                        anyhow::Ok((file_times, db_ids_by_worktree_id))
                    }
                })
                .await?;

            // Initial scan: send new/changed files to the parsing pipeline,
            // and delete database rows for files that disappeared.
            cx.background()
                .spawn({
                    let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
                    let db_update_tx = db_update_tx.clone();
                    let language_registry = language_registry.clone();
                    let parsing_files_tx = parsing_files_tx.clone();
                    async move {
                        let t0 = Instant::now();
                        for worktree in worktrees.into_iter() {
                            let mut file_mtimes =
                                worktree_file_times.remove(&worktree.id()).unwrap();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Skip languages with no embedding query.
                                    if language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                    {
                                        continue;
                                    }

                                    let path_buf = file.path.to_path_buf();
                                    // `remove` doubles as bookkeeping: any
                                    // entry left in `file_mtimes` afterwards
                                    // belongs to a file deleted from disk.
                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        parsing_files_tx
                                            .try_send(PendingFile {
                                                worktree_db_id: db_ids_by_worktree_id
                                                    [&worktree.id()],
                                                relative_path: path_buf,
                                                absolute_path,
                                                language,
                                                modified_time: file.mtime,
                                            })
                                            .unwrap();
                                    }
                                }
                            }
                            // Whatever remains was deleted on disk; drop its
                            // rows from the database.
                            for file in file_mtimes.keys() {
                                db_update_tx
                                    .try_send(DbWrite::Delete {
                                        worktree_id: db_ids_by_worktree_id[&worktree.id()],
                                        path: file.to_owned(),
                                    })
                                    .unwrap();
                            }
                        }
                        log::info!(
                            "Parsing Worktree Completed in {:?}",
                            t0.elapsed().as_millis()
                        );
                    }
                })
                .detach();

            // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
            this.update(&mut cx, |this, cx| {
                // Reindex-on-save: every worktree-update event re-checks the
                // changed files, and any file whose scheduled time (mtime +
                // REINDEXING_DELAY_SECONDS) has passed is sent off to be
                // indexed again.
                let _subscription = cx.subscribe(&project, |this, project, event, cx| {
                    if let Some(project_state) = this.projects.get(&project.downgrade()) {
                        let mut project_state = project_state.borrow_mut();
                        let worktree_db_ids = project_state.worktree_db_ids.clone();

                        if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
                        {
                            // Get Worktree Object
                            let worktree =
                                project.read(cx).worktree_for_id(worktree_id.clone(), cx);
                            if worktree.is_none() {
                                return;
                            }
                            let worktree = worktree.unwrap();

                            // Open a throwaway DB connection to read this
                            // worktree's stored file mtimes; bail out
                            // silently on any failure.
                            let db_values = {
                                if let Ok(db) =
                                    VectorDatabase::new(this.database_url.to_string_lossy().into())
                                {
                                    let worktree_db_id: Option<i64> = {
                                        let mut found_db_id = None;
                                        for (w_id, db_id) in worktree_db_ids.into_iter() {
                                            if &w_id == &worktree.read(cx).id() {
                                                found_db_id = Some(db_id)
                                            }
                                        }
                                        found_db_id
                                    };
                                    if worktree_db_id.is_none() {
                                        return;
                                    }
                                    let worktree_db_id = worktree_db_id.unwrap();

                                    let file_mtimes = db.get_file_mtimes(worktree_db_id);
                                    if file_mtimes.is_err() {
                                        return;
                                    }

                                    let file_mtimes = file_mtimes.unwrap();
                                    Some((file_mtimes, worktree_db_id))
                                } else {
                                    return;
                                }
                            };

                            if db_values.is_none() {
                                return;
                            }

                            let (file_mtimes, worktree_db_id) = db_values.unwrap();

                            // Iterate Through Changes
                            let language_registry = this.language_registry.clone();
                            let parsing_files_tx = this.parsing_files_tx.clone();

                            // NOTE(review): `block_on` runs async language
                            // lookups synchronously inside an event handler;
                            // a long lookup would stall the main thread.
                            smol::block_on(async move {
                                for change in changes.into_iter() {
                                    let change_path = change.0.clone();
                                    let absolute_path = worktree.read(cx).absolutize(&change_path);
                                    // Skip if git ignored or symlink
                                    if let Some(entry) = worktree.read(cx).entry_for_id(change.1) {
                                        if entry.is_ignored || entry.is_symlink {
                                            continue;
                                        } else {
                                            log::info!(
                                                "Testing for Reindexing: {:?}",
                                                &change_path
                                            );
                                        }
                                    };

                                    if let Ok(language) = language_registry
                                        .language_for_file(&change_path.to_path_buf(), None)
                                        .await
                                    {
                                        // Skip languages with no embedding
                                        // query.
                                        if language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                        {
                                            continue;
                                        }

                                        // NOTE(review): `change_path` is the
                                        // worktree-relative path, so this
                                        // `metadata()` call resolves against
                                        // the process cwd — `absolute_path`
                                        // was likely intended; confirm.
                                        if let Some(modified_time) = {
                                            let metadata = change_path.metadata();
                                            if metadata.is_err() {
                                                None
                                            } else {
                                                let mtime = metadata.unwrap().modified();
                                                if mtime.is_err() {
                                                    None
                                                } else {
                                                    Some(mtime.unwrap())
                                                }
                                            }
                                        } {
                                            let existing_time =
                                                file_mtimes.get(&change_path.to_path_buf());
                                            // NOTE(review): this treats a
                                            // CHANGED mtime as "already
                                            // stored" — inverted relative to
                                            // the initial-scan check above
                                            // (`existing_mtime == file.mtime`);
                                            // verify intent.
                                            let already_stored = existing_time
                                                .map_or(false, |existing_time| {
                                                    &modified_time != existing_time
                                                });

                                            let reindex_time = modified_time
                                                + Duration::from_secs(REINDEXING_DELAY_SECONDS);

                                            if !already_stored {
                                                project_state.update_pending_files(
                                                    PendingFile {
                                                        relative_path: change_path.to_path_buf(),
                                                        absolute_path,
                                                        modified_time,
                                                        worktree_db_id,
                                                        language: language.clone(),
                                                    },
                                                    reindex_time,
                                                );

                                                // Flush everything whose
                                                // scheduled time has passed.
                                                for file in project_state.get_outstanding_files() {
                                                    parsing_files_tx.try_send(file).unwrap();
                                                }
                                            }
                                        }
                                    }
                                }
                            });
                        };
                    }
                });

                this.projects.insert(
                    project.downgrade(),
                    Rc::new(RefCell::new(ProjectState {
                        pending_files: HashMap::new(),
                        worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
                        _subscription,
                    })),
                );
            });

            anyhow::Ok(())
        })
    }
765
766    pub fn search(
767        &mut self,
768        project: ModelHandle<Project>,
769        phrase: String,
770        limit: usize,
771        cx: &mut ModelContext<Self>,
772    ) -> Task<Result<Vec<SearchResult>>> {
773        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
774            state.borrow()
775        } else {
776            return Task::ready(Err(anyhow!("project not added")));
777        };
778
779        let worktree_db_ids = project
780            .read(cx)
781            .worktrees(cx)
782            .filter_map(|worktree| {
783                let worktree_id = worktree.read(cx).id();
784                project_state
785                    .worktree_db_ids
786                    .iter()
787                    .find_map(|(id, db_id)| {
788                        if *id == worktree_id {
789                            Some(*db_id)
790                        } else {
791                            None
792                        }
793                    })
794            })
795            .collect::<Vec<_>>();
796
797        let embedding_provider = self.embedding_provider.clone();
798        let database_url = self.database_url.clone();
799        cx.spawn(|this, cx| async move {
800            let documents = cx
801                .background()
802                .spawn(async move {
803                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;
804
805                    let phrase_embedding = embedding_provider
806                        .embed_batch(vec![&phrase])
807                        .await?
808                        .into_iter()
809                        .next()
810                        .unwrap();
811
812                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
813                    database.for_each_document(&worktree_db_ids, |id, embedding| {
814                        let similarity = dot(&embedding.0, &phrase_embedding);
815                        let ix = match results.binary_search_by(|(_, s)| {
816                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
817                        }) {
818                            Ok(ix) => ix,
819                            Err(ix) => ix,
820                        };
821                        results.insert(ix, (id, similarity));
822                        results.truncate(limit);
823                    })?;
824
825                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
826                    database.get_documents_by_ids(&ids)
827                })
828                .await?;
829
830            this.read_with(&cx, |this, _| {
831                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
832                    state.borrow()
833                } else {
834                    return Err(anyhow!("project not added"));
835                };
836
837                Ok(documents
838                    .into_iter()
839                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
840                        let worktree_id =
841                            project_state
842                                .worktree_db_ids
843                                .iter()
844                                .find_map(|(id, db_id)| {
845                                    if *db_id == worktree_db_id {
846                                        Some(*id)
847                                    } else {
848                                        None
849                                    }
850                                })?;
851                        Some(SearchResult {
852                            worktree_id,
853                            name,
854                            offset,
855                            file_path,
856                        })
857                    })
858                    .collect())
859            })
860        })
861    }
862}
863
// `VectorStore` emits no events; this marker impl only exists so the store
// can live inside gpui's model system (i.e. be held as a `ModelHandle`).
impl Entity for VectorStore {
    type Event = ();
}
867
/// Computes the dot product of two equal-length `f32` vectors, used as the
/// similarity measure between a query embedding and stored embeddings.
///
/// # Panics
///
/// Panics if the two slices differ in length.
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    // Asserting equal lengths up front lets the compiler elide per-element
    // bounds checks and auto-vectorize the loop below, so the previous
    // `unsafe` call into `matrixmultiply::sgemm` (a 1×k · k×1 product with no
    // SAFETY justification) is unnecessary.
    assert_eq!(vec_a.len(), vec_b.len());

    vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum()
}