mod db;
mod embedding;
mod modal;

#[cfg(test)]
mod vector_store_tests;

use anyhow::{anyhow, Result};
use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use futures::{channel::oneshot, Future};
use gpui::{
    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
    WeakModelHandle,
};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use project::{Fs, Project, WorktreeId};
use smol::channel;
use std::{
    cell::RefCell,
    cmp::Ordering,
    collections::HashMap,
    path::{Path, PathBuf},
    rc::Rc,
    sync::Arc,
    time::{Duration, Instant, SystemTime},
};
use tree_sitter::{Parser, QueryCursor};
use util::{
    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
    http::HttpClient,
    paths::EMBEDDINGS_DIR,
    ResultExt,
};
use workspace::{Workspace, WorkspaceCreated};

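/// Seconds to wait after a file is modified before reindexing it.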
const REINDEXING_DELAY_SECONDS: u64 = 3;
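/// Number of document spans to accumulate before requesting a batch of embeddings.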
const EMBEDDINGS_BATCH_SIZE: usize = 150;

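/// An embeddable item extracted from a file: its byte offset in the file, a
/// display name, and (once computed) its embedding vector.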
#[derive(Debug, Clone)]
pub struct Document {
    pub offset: usize,
    pub name: String,
    pub embedding: Vec<f32>,
}

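/// Initializes semantic search: builds the `VectorStore`, indexes local
/// projects as their workspaces open, and registers the search modal. Skipped
/// entirely on the stable release channel.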
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
        return;
    }

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.spawn(move |mut cx| async move {
        let vector_store = VectorStore::new(
            fs,
            db_file_path,
            // Arc::new(embedding::DummyEmbeddings {}),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            cx.subscribe_global::<WorkspaceCreated, _>({
                let vector_store = vector_store.clone();
                move |event, cx| {
                    let workspace = &event.0;
                    if let Some(workspace) = workspace.upgrade(cx) {
                        let project = workspace.read(cx).project().clone();
                        if project.read(cx).is_local() {
                            vector_store.update(cx, |store, cx| {
                                store.add_project(project, cx).detach();
                            });
                        }
                    }
                }
            })
            .detach();

            cx.add_action({
                move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
                    let vector_store = vector_store.clone();
                    workspace.toggle_modal(cx, |workspace, cx| {
                        let project = workspace.project().clone();
                        let workspace = cx.weak_handle();
                        cx.add_view(|cx| {
                            SemanticSearch::new(
                                SemanticSearchDelegate::new(workspace, project, vector_store),
                                cx,
                            )
                        })
                    })
                }
            });

            SemanticSearch::init(cx);
        });

        anyhow::Ok(())
    })
    .detach();
}

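/// A parsed file ready to be persisted: its worktree-relative path, its
/// modification time, and the documents extracted from it.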
#[derive(Debug, Clone)]
pub struct IndexedFile {
    path: PathBuf,
    mtime: SystemTime,
    documents: Vec<Document>,
}

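/// Maintains the embeddings index. Files move through a pipeline of background
/// tasks (parsing -> batching -> embedding -> database writes) connected by the
/// channels below.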
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    db_update_tx: channel::Sender<DbWrite>,
    parsing_files_tx: channel::Sender<PendingFile>,
    _db_update_task: Task<()>,
    _embed_batch_task: Vec<Task<()>>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModelHandle<Project>, Rc<RefCell<ProjectState>>>,
}

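/// Per-project state: the database ids of the project's worktrees and the
/// files awaiting reindexing.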
struct ProjectState {
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
    _subscription: gpui::Subscription,
}

impl ProjectState {
    fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
        // If the file is already pending, replace it with the new one, but keep
        // the old indexing time.
        let indexing_time = self
            .pending_files
            .remove(&pending_file.relative_path)
            .map_or(indexing_time, |(_, old_time)| old_time);
        self.pending_files.insert(
            pending_file.relative_path.clone(),
            (pending_file, indexing_time),
        );
    }

    fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
        let now = SystemTime::now();
        let mut outstanding_files = vec![];
        let mut remove_keys = vec![];
        for (key, (pending_file, index_time)) in self.pending_files.iter() {
            if *index_time <= now {
                outstanding_files.push(pending_file.clone());
                remove_keys.push(key.clone());
            }
        }

        for key in remove_keys.iter() {
            self.pending_files.remove(key);
        }

        outstanding_files
    }
}

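/// A file queued for parsing: where it lives, its language, and when it was
/// last modified.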
#[derive(Clone, Debug)]
struct PendingFile {
    worktree_db_id: i64,
    relative_path: PathBuf,
    absolute_path: PathBuf,
    language: Arc<Language>,
    modified_time: SystemTime,
}

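/// A semantic search hit: the named document and its location in the project.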
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}

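/// Write operations handled serially by the database task.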
enum DbWrite {
    InsertFile {
        worktree_id: i64,
        indexed_file: IndexedFile,
    },
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
}

impl VectorStore {
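    /// Opens (or creates) the embeddings database, then spawns the background
    /// pipeline: a database writer, the embedding batch tasks, a file batcher,
    /// and one parsing task per CPU.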
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let database_url = Arc::new(database_url);

        let db = cx
            .background()
            .spawn({
                let fs = fs.clone();
                let database_url = database_url.clone();
                async move {
                    if let Some(db_directory) = database_url.parent() {
                        fs.create_dir(db_directory).await.log_err();
                    }

                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
                    anyhow::Ok(db)
                }
            })
            .await?;

        Ok(cx.add_model(|cx| {
            // Files flow through the pipeline below:
            // parsing_files_tx -> batch_files_tx -> embed_batch_tx -> db_update_tx

            // db_update_tx/rx: apply inserts, deletes, and worktree lookups to the database
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn(async move {
                while let Ok(job) = db_update_rx.recv().await {
                    match job {
                        DbWrite::InsertFile {
                            worktree_id,
                            indexed_file,
                        } => {
                            log::info!("Inserting Data for {:?}", &indexed_file.path);
                            db.insert_file(worktree_id, indexed_file).log_err();
                        }
                        DbWrite::Delete { worktree_id, path } => {
                            db.delete_file(worktree_id, path).log_err();
                        }
                        DbWrite::FindOrCreateWorktree { path, sender } => {
                            let id = db.find_or_create_worktree(&path);
                            sender.send(id).ok();
                        }
                    }
                }
            });

            // embed_batch_tx/rx: embed batches of document spans and forward the results to the database
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, IndexedFile, Vec<String>)>>();
            let mut _embed_batch_task = Vec::new();
            // A single embedding task for now; this could be scaled up to
            // cx.background().num_cpus() workers.
            for _ in 0..1 {
                let db_update_tx = db_update_tx.clone();
                let embed_batch_rx = embed_batch_rx.clone();
                let embedding_provider = embedding_provider.clone();
                _embed_batch_task.push(cx.background().spawn(async move {
                    while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
                        // Flatten every file's spans into a single request.
                        let mut document_spans = vec![];
                        for (_, _, document_span) in embeddings_queue.iter() {
                            document_spans.extend(document_span.iter().cloned());
                        }

                        if let Ok(embeddings) = embedding_provider
                            .embed_batch(document_spans.iter().map(|x| &**x).collect())
                            .await
                        {
                            // Distribute the flat list of embeddings back onto the
                            // queued files: `i` walks files, `j` walks their documents.
                            let mut i = 0;
                            let mut j = 0;

                            for embedding in embeddings.iter() {
                                while embeddings_queue[i].1.documents.len() == j {
                                    i += 1;
                                    j = 0;
                                }

                                embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
                                j += 1;
                            }

                            for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
                                for document in indexed_file.documents.iter() {
                                    // TODO: Update this so it doesn't panic.
                                    assert!(
                                        !document.embedding.is_empty(),
                                        "document embedding not complete"
                                    );
                                }

                                db_update_tx
                                    .send(DbWrite::InsertFile {
                                        worktree_id,
                                        indexed_file,
                                    })
                                    .await
                                    .unwrap();
                            }
                        }
                    }
                }))
            }

            // batch_files_tx/rx: group parsed files into embedding batches
            let (batch_files_tx, batch_files_rx) =
                channel::unbounded::<(i64, IndexedFile, Vec<String>)>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];
                while let Ok((worktree_id, indexed_file, document_spans)) =
                    batch_files_rx.recv().await
                {
                    queue_len += document_spans.len();
                    embeddings_queue.push((worktree_id, indexed_file, document_spans));
                    if queue_len >= EMBEDDINGS_BATCH_SIZE {
                        embed_batch_tx.try_send(embeddings_queue).unwrap();
                        embeddings_queue = vec![];
                        queue_len = 0;
                    }
                }
                // Flush any partial batch once the sender side closes.
                if queue_len > 0 {
                    embed_batch_tx.try_send(embeddings_queue).unwrap();
                }
            });

            // parsing_files_tx/rx: parse pending files into embeddable documents
            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();

            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut parser = Parser::new();
                    let mut cursor = QueryCursor::new();
                    while let Ok(pending_file) = parsing_files_rx.recv().await {
                        log::info!("Parsing File: {:?}", &pending_file.relative_path);
                        if let Some((indexed_file, document_spans)) = Self::index_file(
                            &mut cursor,
                            &mut parser,
                            &fs,
                            pending_file.language,
                            pending_file.relative_path.clone(),
                            pending_file.absolute_path.clone(),
                            pending_file.modified_time,
                        )
                        .await
                        .log_err()
                        {
                            batch_files_tx
                                .try_send((
                                    pending_file.worktree_db_id,
                                    indexed_file,
                                    document_spans,
                                ))
                                .unwrap();
                        }
                    }
                }));
            }

            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                parsing_files_tx,
                _db_update_task,
                _embed_batch_task,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }

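    /// Parses a file with tree-sitter, runs the language's embedding query over
    /// the syntax tree, and returns the extracted documents along with the source
    /// text of each matched item (the text that gets embedded).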
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        fs: &Arc<dyn Fs>,
        language: Arc<Language>,
        relative_file_path: PathBuf,
        absolute_file_path: PathBuf,
        mtime: SystemTime,
    ) -> Result<(IndexedFile, Vec<String>)> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let embedding_config = grammar
            .embedding_config
            .as_ref()
            .ok_or_else(|| anyhow!("no embedding query"))?;

        let content = fs.load(&absolute_file_path).await?;

        parser.set_language(grammar.ts_language).unwrap();
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(
            &embedding_config.query,
            tree.root_node(),
            content.as_bytes(),
        ) {
            let mut item_range = None;
            let mut name_range = None;
            let mut context_range = None;
            for capture in mat.captures {
                if capture.index == embedding_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == embedding_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
                if let Some(context_capture_ix) = embedding_config.context_capture_ix {
                    if capture.index == context_capture_ix {
                        context_range = Some(capture.node.byte_range());
                    }
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                let mut context_data = String::new();
                if let Some(context_range) = context_range {
                    if let Some(context) = content.get(context_range.clone()) {
                        context_data.push_str(context);
                    }
                }

                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item.to_string());
                    documents.push(Document {
                        name: format!("{} {}", context_data, name),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        Ok((
            IndexedFile {
                path: relative_file_path,
                mtime,
                documents,
            },
            context_spans,
        ))
    }

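    /// Asks the database task for the id of the worktree rooted at `path`,
    /// creating a row for it if none exists yet.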
    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
        let (tx, rx) = oneshot::channel();
        self.db_update_tx
            .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
            .unwrap();
        async move { rx.await? }
    }

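    /// Starts indexing a project: waits for its worktrees to finish scanning,
    /// queues files whose mtimes differ from what the database has stored,
    /// deletes rows for files that no longer exist, and subscribes to worktree
    /// updates so saved files are reindexed after `REINDEXING_DELAY_SECONDS`.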
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let database_url = self.database_url.clone();
        let db_update_tx = self.db_update_tx.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();

        cx.spawn(|this, mut cx| async move {
            let t0 = Instant::now();
            futures::future::join_all(worktree_scans_complete).await;

            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
            log::info!("Worktree Scanning Done in {:?}", t0.elapsed().as_millis());

            if let Some(db_directory) = database_url.parent() {
                fs.create_dir(db_directory).await.log_err();
            }

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Here we query the worktree ids, yet we don't retain them elsewhere;
            // these data structures could likely be cleaned up.
            let (mut worktree_file_times, db_ids_by_worktree_id) = cx
                .background()
                .spawn({
                    let worktrees = worktrees.clone();
                    async move {
                        let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
                        let mut db_ids_by_worktree_id = HashMap::new();
                        let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
                            HashMap::new();
                        for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                            let db_id = db_id?;
                            db_ids_by_worktree_id.insert(worktree.id(), db_id);
                            file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
                        }
                        anyhow::Ok((file_times, db_ids_by_worktree_id))
                    }
                })
                .await?;

            cx.background()
                .spawn({
                    let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
                    let db_update_tx = db_update_tx.clone();
                    let language_registry = language_registry.clone();
                    let parsing_files_tx = parsing_files_tx.clone();
                    async move {
                        let t0 = Instant::now();
                        for worktree in worktrees.into_iter() {
                            let mut file_mtimes =
                                worktree_file_times.remove(&worktree.id()).unwrap();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Skip languages without an embedding query.
                                    if language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                    {
                                        continue;
                                    }

                                    // Queue the file unless the database already has it
                                    // with the same mtime.
                                    let path_buf = file.path.to_path_buf();
                                    let stored_mtime = file_mtimes.remove(&path_buf);
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        parsing_files_tx
                                            .try_send(PendingFile {
                                                worktree_db_id: db_ids_by_worktree_id
                                                    [&worktree.id()],
                                                relative_path: path_buf,
                                                absolute_path,
                                                language,
                                                modified_time: file.mtime,
                                            })
                                            .unwrap();
                                    }
                                }
                            }
                            // Any mtimes left over belong to files that are gone from
                            // disk; delete their rows.
                            for file in file_mtimes.keys() {
                                db_update_tx
                                    .try_send(DbWrite::Delete {
                                        worktree_id: db_ids_by_worktree_id[&worktree.id()],
                                        path: file.to_owned(),
                                    })
                                    .unwrap();
                            }
                        }
                        log::info!(
                            "Parsing Worktree Completed in {:?}",
                            t0.elapsed().as_millis()
                        );
                    }
                })
                .detach();

            this.update(&mut cx, |this, cx| {
                // Reindex on save: when a worktree reports changed entries, each
                // changed file is scheduled for reindexing REINDEXING_DELAY_SECONDS
                // after its modification time, and any file whose scheduled time
                // has passed is sent off to be parsed.
                let _subscription = cx.subscribe(&project, |this, project, event, cx| {
                    if let Some(project_state) = this.projects.get(&project.downgrade()) {
                        let mut project_state = project_state.borrow_mut();
                        let worktree_db_ids = project_state.worktree_db_ids.clone();

                        if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) =
                            event
                        {
                            // Look up the worktree for this event.
                            let worktree =
                                match project.read(cx).worktree_for_id(*worktree_id, cx) {
                                    Some(worktree) => worktree,
                                    None => return,
                                };

                            // Open the database and fetch this worktree's id and the
                            // mtimes it has stored for the worktree's files.
                            let db = match VectorDatabase::new(
                                this.database_url.to_string_lossy().into(),
                            ) {
                                Ok(db) => db,
                                Err(_) => return,
                            };
                            let worktree_id = worktree.read(cx).id();
                            let worktree_db_id = match worktree_db_ids
                                .iter()
                                .find_map(|(id, db_id)| (*id == worktree_id).then(|| *db_id))
                            {
                                Some(db_id) => db_id,
                                None => return,
                            };
                            let file_mtimes = match db.get_file_mtimes(worktree_db_id) {
                                Ok(file_mtimes) => file_mtimes,
                                Err(_) => return,
                            };

                            // Iterate through the changed entries.
                            let language_registry = this.language_registry.clone();
                            let parsing_files_tx = this.parsing_files_tx.clone();

                            smol::block_on(async move {
                                for change in changes.into_iter() {
                                    let change_path = change.0.clone();
                                    let absolute_path = worktree.read(cx).absolutize(&change_path);
                                    // Skip entries that are git-ignored or symlinks.
                                    if let Some(entry) = worktree.read(cx).entry_for_id(change.1) {
                                        if entry.is_ignored || entry.is_symlink {
                                            continue;
                                        } else {
                                            log::info!(
                                                "Testing for Reindexing: {:?}",
                                                &change_path
                                            );
                                        }
                                    };

                                    if let Ok(language) = language_registry
                                        .language_for_file(&change_path.to_path_buf(), None)
                                        .await
                                    {
                                        // Skip languages without an embedding query.
                                        if language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                        {
                                            continue;
                                        }

                                        // Stat the file on disk for its current mtime.
                                        if let Ok(modified_time) = absolute_path
                                            .metadata()
                                            .and_then(|metadata| metadata.modified())
                                        {
                                            let existing_time =
                                                file_mtimes.get(&change_path.to_path_buf());
                                            let already_stored =
                                                existing_time.map_or(false, |existing_time| {
                                                    &modified_time == existing_time
                                                });

                                            let reindex_time = modified_time
                                                + Duration::from_secs(REINDEXING_DELAY_SECONDS);

                                            if !already_stored {
                                                project_state.update_pending_files(
                                                    PendingFile {
                                                        relative_path: change_path.to_path_buf(),
                                                        absolute_path,
                                                        modified_time,
                                                        worktree_db_id,
                                                        language: language.clone(),
                                                    },
                                                    reindex_time,
                                                );

                                                for file in project_state.get_outstanding_files() {
                                                    parsing_files_tx.try_send(file).unwrap();
                                                }
                                            }
                                        }
                                    }
                                }
                            });
                        };
                    }
                });

                this.projects.insert(
                    project.downgrade(),
                    Rc::new(RefCell::new(ProjectState {
                        pending_files: HashMap::new(),
                        worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
                        _subscription,
                    })),
                );
            });

            anyhow::Ok(())
        })
    }

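    /// Embeds `phrase` and returns the `limit` documents in the project whose
    /// embeddings have the highest dot-product similarity to it.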
    pub fn search(
        &mut self,
        project: ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
            state.borrow()
        } else {
            return Task::ready(Err(anyhow!("project not added")));
        };

        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                project_state
                    .worktree_db_ids
                    .iter()
                    .find_map(|(id, db_id)| {
                        if *id == worktree_id {
                            Some(*db_id)
                        } else {
                            None
                        }
                    })
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;

                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    // Keep the top `limit` documents, sorted by descending similarity.
                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            this.read_with(&cx, |this, _| {
                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
                    state.borrow()
                } else {
                    return Err(anyhow!("project not added"));
                };

                Ok(documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id = project_state
                            .worktree_db_ids
                            .iter()
                            .find_map(|(id, db_id)| {
                                if *db_id == worktree_db_id {
                                    Some(*id)
                                } else {
                                    None
                                }
                            })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect())
            })
        })
    }
}

impl Entity for VectorStore {
    type Event = ();
}

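/// Dot product of two equal-length vectors, computed as a 1×n by n×1 matrix
/// product via `matrixmultiply::sgemm`.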
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}