1mod db;
2mod embedding;
3mod modal;
4
5#[cfg(test)]
6mod vector_store_tests;
7
8use anyhow::{anyhow, Result};
9use db::VectorDatabase;
10use embedding::{EmbeddingProvider, OpenAIEmbeddings};
11use futures::{channel::oneshot, Future};
12use gpui::{
13 AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
14 WeakModelHandle,
15};
16use language::{Language, LanguageRegistry};
17use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
18use project::{Fs, Project, WorktreeId};
19use smol::channel;
20use std::{
21 cell::RefCell,
22 cmp::Ordering,
23 collections::HashMap,
24 path::{Path, PathBuf},
25 rc::Rc,
26 sync::Arc,
27 time::{Duration, Instant, SystemTime},
28};
29use tree_sitter::{Parser, QueryCursor};
30use util::{
31 channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
32 http::HttpClient,
33 paths::EMBEDDINGS_DIR,
34 ResultExt,
35};
36use workspace::{Workspace, WorkspaceCreated};
37
38const REINDEXING_DELAY_SECONDS: u64 = 3;
39const EMBEDDINGS_BATCH_SIZE: usize = 150;
40
/// A single embeddable span extracted from a source file (e.g. one item
/// matched by a language's embedding query), located by byte offset.
#[derive(Debug, Clone)]
pub struct Document {
    /// Byte offset of the item's start within the file's text.
    pub offset: usize,
    /// Name captured by the language's embedding query for this item.
    pub name: String,
    /// Embedding vector; left empty until the embedding provider fills it in.
    pub embedding: Vec<f32>,
}
47
/// Wire the vector store into the app: create the `VectorStore` model,
/// index each local project as its workspace is created, and register the
/// semantic-search modal action.
///
/// Does nothing on the Stable release channel — the feature is gated to
/// non-stable builds.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
        return;
    }

    // Keep a separate embeddings database per release channel.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.spawn(move |mut cx| async move {
        let vector_store = VectorStore::new(
            fs,
            db_file_path,
            // Arc::new(embedding::DummyEmbeddings {}),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            // Index each local project when its workspace is created.
            cx.subscribe_global::<WorkspaceCreated, _>({
                let vector_store = vector_store.clone();
                move |event, cx| {
                    let workspace = &event.0;
                    if let Some(workspace) = workspace.upgrade(cx) {
                        let project = workspace.read(cx).project().clone();
                        if project.read(cx).is_local() {
                            vector_store.update(cx, |store, cx| {
                                store.add_project(project, cx).detach();
                            });
                        }
                    }
                }
            })
            .detach();

            // `Toggle` opens the semantic-search modal for the active workspace.
            cx.add_action({
                move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
                    let vector_store = vector_store.clone();
                    workspace.toggle_modal(cx, |workspace, cx| {
                        let project = workspace.project().clone();
                        let workspace = cx.weak_handle();
                        cx.add_view(|cx| {
                            SemanticSearch::new(
                                SemanticSearchDelegate::new(workspace, project, vector_store),
                                cx,
                            )
                        })
                    })
                }
            });

            SemanticSearch::init(cx);
        });

        anyhow::Ok(())
    })
    .detach();
}
115
/// A parsed file ready for (or returned from) embedding: its path, the
/// modification time observed when it was parsed, and the documents
/// extracted from it by the language's embedding query.
#[derive(Debug, Clone)]
pub struct IndexedFile {
    path: PathBuf,
    mtime: SystemTime,
    documents: Vec<Document>,
}
122
/// Model that maintains semantic-search embeddings for local projects.
///
/// Work flows through a pipeline of background tasks connected by channels:
/// `parsing_files_tx` (parse) -> batching -> embedding -> `db_update_tx` (write).
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    /// All database writes (inserts/deletes/worktree lookups) are serialized
    /// onto the single db task through this channel.
    db_update_tx: channel::Sender<DbWrite>,
    /// Files queued for parsing by the per-CPU parser tasks.
    parsing_files_tx: channel::Sender<PendingFile>,
    // The tasks below are held only to keep the pipeline alive for the
    // store's lifetime; they are never polled directly.
    _db_update_task: Task<()>,
    _embed_batch_task: Vec<Task<()>>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    /// Per-project state, keyed by weak handle so a dropped project's entry
    /// can be detected as stale.
    projects: HashMap<WeakModelHandle<Project>, Rc<RefCell<ProjectState>>>,
}
136
/// Per-project bookkeeping for the vector store.
struct ProjectState {
    /// Mapping from each worktree's in-app id to its row id in the database.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    /// Files awaiting re-indexing, keyed by path, with the time at which the
    /// re-index becomes due (modification time + reindexing delay).
    pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
    /// Keeps the project-event subscription alive for this project.
    _subscription: gpui::Subscription,
}
142
143impl ProjectState {
144 fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
145 // If Pending File Already Exists, Replace it with the new one
146 // but keep the old indexing time
147 if let Some(old_file) = self.pending_files.remove(&pending_file.path.clone()) {
148 self.pending_files
149 .insert(pending_file.path.clone(), (pending_file, old_file.1));
150 } else {
151 self.pending_files
152 .insert(pending_file.path.clone(), (pending_file, indexing_time));
153 };
154 }
155
156 fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
157 let mut outstanding_files = vec![];
158 let mut remove_keys = vec![];
159 for key in self.pending_files.keys().into_iter() {
160 if let Some(pending_details) = self.pending_files.get(key) {
161 let (pending_file, index_time) = pending_details;
162 if index_time <= &SystemTime::now() {
163 outstanding_files.push(pending_file.clone());
164 remove_keys.push(key.clone());
165 }
166 }
167 }
168
169 for key in remove_keys.iter() {
170 self.pending_files.remove(key);
171 }
172
173 return outstanding_files;
174 }
175}
176
/// A file queued for parsing/embedding, carrying everything a parser task
/// needs to process it without further project access.
#[derive(Clone, Debug)]
struct PendingFile {
    /// Database row id of the worktree containing this file.
    worktree_db_id: i64,
    /// Worktree-relative path of the file.
    path: PathBuf,
    language: Arc<Language>,
    /// Modification time observed when the file was queued; stored alongside
    /// the indexed documents so unchanged files can be skipped later.
    modified_time: SystemTime,
}
184
/// A single semantic-search hit: the named document and where it lives.
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    /// Byte offset of the matched document within `file_path`.
    pub offset: usize,
    pub file_path: PathBuf,
}
192
/// Requests handled serially by the database task, which owns the
/// `VectorDatabase` connection.
enum DbWrite {
    /// Insert (or replace) an indexed file and its documents.
    InsertFile {
        worktree_id: i64,
        indexed_file: IndexedFile,
    },
    /// Remove a file's rows from a worktree.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Resolve a worktree path to its database id, creating a row if needed;
    /// the result is sent back over `sender`.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
}
207
208impl VectorStore {
    /// Create the store: open (or create) the vector database at
    /// `database_url`, then build the background pipeline
    /// (parse -> batch -> embed -> write) as a set of tasks owned by the
    /// returned model. Errors if the database cannot be opened.
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let database_url = Arc::new(database_url);

        // Open the database on the background pool, creating its parent
        // directory first if necessary.
        let db = cx
            .background()
            .spawn({
                let fs = fs.clone();
                let database_url = database_url.clone();
                async move {
                    if let Some(db_directory) = database_url.parent() {
                        fs.create_dir(db_directory).await.log_err();
                    }

                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
                    anyhow::Ok(db)
                }
            })
            .await?;

        Ok(cx.add_model(|cx| {
            // Pipeline: paths_tx -> embeddings_tx -> db_update_tx

            // db_update_tx/rx: all database writes funnel through this one
            // task, which owns the `VectorDatabase` connection.
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn(async move {
                while let Ok(job) = db_update_rx.recv().await {
                    match job {
                        DbWrite::InsertFile {
                            worktree_id,
                            indexed_file,
                        } => {
                            log::info!("Inserting Data for {:?}", &indexed_file.path);
                            db.insert_file(worktree_id, indexed_file).log_err();
                        }
                        DbWrite::Delete { worktree_id, path } => {
                            db.delete_file(worktree_id, path).log_err();
                        }
                        DbWrite::FindOrCreateWorktree { path, sender } => {
                            let id = db.find_or_create_worktree(&path);
                            sender.send(id).ok();
                        }
                    }
                }
            });

            // embed_batch_tx/rx: each job is a batch of files plus their raw
            // document texts. The task requests embeddings for the whole
            // batch at once, reassembles them into the files' documents, and
            // forwards the completed files to the database task.
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, IndexedFile, Vec<String>)>>();
            let mut _embed_batch_task = Vec::new();
            for _ in 0..1 {
                // NOTE(review): was `cx.background().num_cpus()`; currently a
                // single embedding worker, presumably to limit concurrent API
                // requests — confirm before restoring parallelism.
                let db_update_tx = db_update_tx.clone();
                let embed_batch_rx = embed_batch_rx.clone();
                let embedding_provider = embedding_provider.clone();
                _embed_batch_task.push(cx.background().spawn(async move {
                    while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
                        // Flatten every file's document texts into one list
                        // so the provider sees a single batch.
                        let mut embeddings_queue = embeddings_queue.clone();
                        let mut document_spans = vec![];
                        for (_, _, document_span) in embeddings_queue.clone().into_iter() {
                            document_spans.extend(document_span);
                        }

                        if let Ok(embeddings) = embedding_provider
                            .embed_batch(document_spans.iter().map(|x| &**x).collect())
                            .await
                        {
                            // Walk the flat embedding list back into the
                            // per-file documents: `i` indexes the file in the
                            // queue, `j` the document within that file.
                            let mut i = 0;
                            let mut j = 0;

                            for embedding in embeddings.iter() {
                                // Skip files with no (remaining) documents.
                                while embeddings_queue[i].1.documents.len() == j {
                                    i += 1;
                                    j = 0;
                                }

                                embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
                                j += 1;
                            }

                            for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
                                for document in indexed_file.documents.iter() {
                                    // TODO: Update this so it doesn't panic
                                    assert!(
                                        document.embedding.len() > 0,
                                        "Document Embedding Not Complete"
                                    );
                                }

                                db_update_tx
                                    .send(DbWrite::InsertFile {
                                        worktree_id,
                                        indexed_file,
                                    })
                                    .await
                                    .unwrap();
                            }
                        }
                    }
                }))
            }

            // batch_files_tx/rx: accumulate parsed files until at least
            // EMBEDDINGS_BATCH_SIZE document spans are queued, then hand the
            // batch to the embedding task. The trailing flush only runs once
            // the channel closes (i.e. on shutdown).
            let (batch_files_tx, batch_files_rx) =
                channel::unbounded::<(i64, IndexedFile, Vec<String>)>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];
                while let Ok((worktree_id, indexed_file, document_spans)) =
                    batch_files_rx.recv().await
                {
                    queue_len += &document_spans.len();
                    embeddings_queue.push((worktree_id, indexed_file, document_spans));
                    if queue_len >= EMBEDDINGS_BATCH_SIZE {
                        embed_batch_tx.try_send(embeddings_queue).unwrap();
                        embeddings_queue = vec![];
                        queue_len = 0;
                    }
                }
                if queue_len > 0 {
                    embed_batch_tx.try_send(embeddings_queue).unwrap();
                }
            });

            // parsing_files_tx/rx: one parser task per CPU turns pending
            // files into embeddable documents via tree-sitter.
            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();

            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    // Parser and query cursor are reused across files.
                    let mut parser = Parser::new();
                    let mut cursor = QueryCursor::new();
                    while let Ok(pending_file) = parsing_files_rx.recv().await {
                        log::info!("Parsing File: {:?}", &pending_file.path);
                        if let Some((indexed_file, document_spans)) = Self::index_file(
                            &mut cursor,
                            &mut parser,
                            &fs,
                            pending_file.language,
                            pending_file.path.clone(),
                            pending_file.modified_time,
                        )
                        .await
                        .log_err()
                        {
                            batch_files_tx
                                .try_send((
                                    pending_file.worktree_db_id,
                                    indexed_file,
                                    document_spans,
                                ))
                                .unwrap();
                        }
                    }
                }));
            }

            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                parsing_files_tx,
                _db_update_task,
                _embed_batch_task,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }
390
391 async fn index_file(
392 cursor: &mut QueryCursor,
393 parser: &mut Parser,
394 fs: &Arc<dyn Fs>,
395 language: Arc<Language>,
396 file_path: PathBuf,
397 mtime: SystemTime,
398 ) -> Result<(IndexedFile, Vec<String>)> {
399 let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
400 let embedding_config = grammar
401 .embedding_config
402 .as_ref()
403 .ok_or_else(|| anyhow!("no outline query"))?;
404
405 let content = fs.load(&file_path).await?;
406
407 parser.set_language(grammar.ts_language).unwrap();
408 let tree = parser
409 .parse(&content, None)
410 .ok_or_else(|| anyhow!("parsing failed"))?;
411
412 let mut documents = Vec::new();
413 let mut context_spans = Vec::new();
414 for mat in cursor.matches(
415 &embedding_config.query,
416 tree.root_node(),
417 content.as_bytes(),
418 ) {
419 let mut item_range = None;
420 let mut name_range = None;
421 for capture in mat.captures {
422 if capture.index == embedding_config.item_capture_ix {
423 item_range = Some(capture.node.byte_range());
424 } else if capture.index == embedding_config.name_capture_ix {
425 name_range = Some(capture.node.byte_range());
426 }
427 }
428
429 if let Some((item_range, name_range)) = item_range.zip(name_range) {
430 if let Some((item, name)) =
431 content.get(item_range.clone()).zip(content.get(name_range))
432 {
433 context_spans.push(item.to_string());
434 documents.push(Document {
435 name: name.to_string(),
436 offset: item_range.start,
437 embedding: Vec::new(),
438 });
439 }
440 }
441 }
442
443 return Ok((
444 IndexedFile {
445 path: file_path,
446 mtime,
447 documents,
448 },
449 context_spans,
450 ));
451 }
452
453 fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
454 let (tx, rx) = oneshot::channel();
455 self.db_update_tx
456 .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
457 .unwrap();
458 async move { rx.await? }
459 }
460
461 fn add_project(
462 &mut self,
463 project: ModelHandle<Project>,
464 cx: &mut ModelContext<Self>,
465 ) -> Task<Result<()>> {
466 let worktree_scans_complete = project
467 .read(cx)
468 .worktrees(cx)
469 .map(|worktree| {
470 let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
471 async move {
472 scan_complete.await;
473 }
474 })
475 .collect::<Vec<_>>();
476 let worktree_db_ids = project
477 .read(cx)
478 .worktrees(cx)
479 .map(|worktree| {
480 self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
481 })
482 .collect::<Vec<_>>();
483
484 let fs = self.fs.clone();
485 let language_registry = self.language_registry.clone();
486 let database_url = self.database_url.clone();
487 let db_update_tx = self.db_update_tx.clone();
488 let parsing_files_tx = self.parsing_files_tx.clone();
489
490 cx.spawn(|this, mut cx| async move {
491 let t0 = Instant::now();
492 futures::future::join_all(worktree_scans_complete).await;
493
494 let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
495 log::info!("Worktree Scanning Done in {:?}", t0.elapsed().as_millis());
496
497 if let Some(db_directory) = database_url.parent() {
498 fs.create_dir(db_directory).await.log_err();
499 }
500
501 let worktrees = project.read_with(&cx, |project, cx| {
502 project
503 .worktrees(cx)
504 .map(|worktree| worktree.read(cx).snapshot())
505 .collect::<Vec<_>>()
506 });
507
508 // Here we query the worktree ids, and yet we dont have them elsewhere
509 // We likely want to clean up these datastructures
510 let (mut worktree_file_times, db_ids_by_worktree_id) = cx
511 .background()
512 .spawn({
513 let worktrees = worktrees.clone();
514 async move {
515 let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
516 let mut db_ids_by_worktree_id = HashMap::new();
517 let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
518 HashMap::new();
519 for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
520 let db_id = db_id?;
521 db_ids_by_worktree_id.insert(worktree.id(), db_id);
522 file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
523 }
524 anyhow::Ok((file_times, db_ids_by_worktree_id))
525 }
526 })
527 .await?;
528
529 cx.background()
530 .spawn({
531 let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
532 let db_update_tx = db_update_tx.clone();
533 let language_registry = language_registry.clone();
534 let parsing_files_tx = parsing_files_tx.clone();
535 async move {
536 let t0 = Instant::now();
537 for worktree in worktrees.into_iter() {
538 let mut file_mtimes =
539 worktree_file_times.remove(&worktree.id()).unwrap();
540 for file in worktree.files(false, 0) {
541 let absolute_path = worktree.absolutize(&file.path);
542
543 if let Ok(language) = language_registry
544 .language_for_file(&absolute_path, None)
545 .await
546 {
547 if language
548 .grammar()
549 .and_then(|grammar| grammar.embedding_config.as_ref())
550 .is_none()
551 {
552 continue;
553 }
554
555 let path_buf = file.path.to_path_buf();
556 let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
557 let already_stored = stored_mtime
558 .map_or(false, |existing_mtime| {
559 existing_mtime == file.mtime
560 });
561
562 if !already_stored {
563 parsing_files_tx
564 .try_send(PendingFile {
565 worktree_db_id: db_ids_by_worktree_id
566 [&worktree.id()],
567 path: path_buf,
568 language,
569 modified_time: file.mtime,
570 })
571 .unwrap();
572 }
573 }
574 }
575 for file in file_mtimes.keys() {
576 db_update_tx
577 .try_send(DbWrite::Delete {
578 worktree_id: db_ids_by_worktree_id[&worktree.id()],
579 path: file.to_owned(),
580 })
581 .unwrap();
582 }
583 }
584 log::info!(
585 "Parsing Worktree Completed in {:?}",
586 t0.elapsed().as_millis()
587 );
588 }
589 })
590 .detach();
591
592 // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
593 this.update(&mut cx, |this, cx| {
594 // The below is managing for updated on save
595 // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
596 // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
597 let _subscription = cx.subscribe(&project, |this, project, event, cx| {
598 if let Some(project_state) = this.projects.get(&project.downgrade()) {
599 let mut project_state = project_state.borrow_mut();
600 let worktree_db_ids = project_state.worktree_db_ids.clone();
601
602 if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
603 {
604 // Get Worktree Object
605 let worktree =
606 project.read(cx).worktree_for_id(worktree_id.clone(), cx);
607 if worktree.is_none() {
608 return;
609 }
610 let worktree = worktree.unwrap();
611
612 // Get Database
613 let db_values = {
614 if let Ok(db) =
615 VectorDatabase::new(this.database_url.to_string_lossy().into())
616 {
617 let worktree_db_id: Option<i64> = {
618 let mut found_db_id = None;
619 for (w_id, db_id) in worktree_db_ids.into_iter() {
620 if &w_id == &worktree.read(cx).id() {
621 found_db_id = Some(db_id)
622 }
623 }
624 found_db_id
625 };
626 if worktree_db_id.is_none() {
627 return;
628 }
629 let worktree_db_id = worktree_db_id.unwrap();
630
631 let file_mtimes = db.get_file_mtimes(worktree_db_id);
632 if file_mtimes.is_err() {
633 return;
634 }
635
636 let file_mtimes = file_mtimes.unwrap();
637 Some((file_mtimes, worktree_db_id))
638 } else {
639 return;
640 }
641 };
642
643 if db_values.is_none() {
644 return;
645 }
646
647 let (file_mtimes, worktree_db_id) = db_values.unwrap();
648
649 // Iterate Through Changes
650 let language_registry = this.language_registry.clone();
651 let parsing_files_tx = this.parsing_files_tx.clone();
652
653 smol::block_on(async move {
654 for change in changes.into_iter() {
655 let change_path = change.0.clone();
656 // Skip if git ignored or symlink
657 if let Some(entry) = worktree.read(cx).entry_for_id(change.1) {
658 if entry.is_ignored || entry.is_symlink {
659 continue;
660 } else {
661 log::info!(
662 "Testing for Reindexing: {:?}",
663 &change_path
664 );
665 }
666 };
667
668 if let Ok(language) = language_registry
669 .language_for_file(&change_path.to_path_buf(), None)
670 .await
671 {
672 if language
673 .grammar()
674 .and_then(|grammar| grammar.embedding_config.as_ref())
675 .is_none()
676 {
677 continue;
678 }
679
680 if let Some(modified_time) = {
681 let metadata = change_path.metadata();
682 if metadata.is_err() {
683 None
684 } else {
685 let mtime = metadata.unwrap().modified();
686 if mtime.is_err() {
687 None
688 } else {
689 Some(mtime.unwrap())
690 }
691 }
692 } {
693 let existing_time =
694 file_mtimes.get(&change_path.to_path_buf());
695 let already_stored = existing_time
696 .map_or(false, |existing_time| {
697 &modified_time != existing_time
698 });
699
700 let reindex_time = modified_time
701 + Duration::from_secs(REINDEXING_DELAY_SECONDS);
702
703 if !already_stored {
704 project_state.update_pending_files(
705 PendingFile {
706 path: change_path.to_path_buf(),
707 modified_time,
708 worktree_db_id,
709 language: language.clone(),
710 },
711 reindex_time,
712 );
713
714 for file in project_state.get_outstanding_files() {
715 parsing_files_tx.try_send(file).unwrap();
716 }
717 }
718 }
719 }
720 }
721 });
722 };
723 }
724 });
725
726 this.projects.insert(
727 project.downgrade(),
728 Rc::new(RefCell::new(ProjectState {
729 pending_files: HashMap::new(),
730 worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
731 _subscription,
732 })),
733 );
734 });
735
736 anyhow::Ok(())
737 })
738 }
739
    /// Semantic search over all worktrees of `project`: embed `phrase`, rank
    /// stored documents by dot-product similarity, and return up to `limit`
    /// results. Errors if the project was never added via `add_project`.
    pub fn search(
        &mut self,
        project: ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
            state.borrow()
        } else {
            return Task::ready(Err(anyhow!("project not added")));
        };

        // Restrict the database scan to this project's worktrees.
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                project_state
                    .worktree_db_ids
                    .iter()
                    .find_map(|(id, db_id)| {
                        if *id == worktree_id {
                            Some(*db_id)
                        } else {
                            None
                        }
                    })
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;

                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    // Maintain the top `limit` ids sorted by descending
                    // similarity: binary-search the insertion point (the
                    // comparator is reversed to sort high-to-low), insert,
                    // then truncate back to `limit`.
                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            this.read_with(&cx, |this, _| {
                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
                    state.borrow()
                } else {
                    return Err(anyhow!("project not added"));
                };

                // Map each db worktree id back to its in-app WorktreeId,
                // dropping results whose worktree is no longer known.
                Ok(documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id =
                            project_state
                                .worktree_db_ids
                                .iter()
                                .find_map(|(id, db_id)| {
                                    if *db_id == worktree_db_id {
                                        Some(*id)
                                    } else {
                                        None
                                    }
                                })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect())
            })
        })
    }
836}
837
/// `VectorStore` emits no events; `Entity` is implemented only so it can be
/// registered as a gpui model.
impl Entity for VectorStore {
    type Event = ();
}
841
/// Dot product of two equal-length `f32` vectors, used by `search` as the
/// similarity score between a document embedding and the query embedding.
///
/// # Panics
/// Panics if the two slices differ in length.
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    assert_eq!(vec_a.len(), vec_b.len());
    // Plain zipped multiply-accumulate. This replaces the previous `unsafe`
    // call into `matrixmultiply::sgemm` for a 1xN * Nx1 product: same result,
    // no unsafe, no GEMM setup overhead, and the zipped-iterator form
    // auto-vectorizes in release builds.
    vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum()
}