1mod db;
2mod embedding;
3mod modal;
4
5#[cfg(test)]
6mod vector_store_tests;
7
8use anyhow::{anyhow, Result};
9use db::VectorDatabase;
10use embedding::{EmbeddingProvider, OpenAIEmbeddings};
11use futures::{channel::oneshot, Future};
12use gpui::{
13 AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
14 WeakModelHandle,
15};
16use language::{Language, LanguageRegistry};
17use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
18use project::{Fs, Project, WorktreeId};
19use smol::channel;
20use std::{
21 cell::RefCell,
22 cmp::Ordering,
23 collections::HashMap,
24 path::{Path, PathBuf},
25 rc::Rc,
26 sync::Arc,
27 time::{Duration, Instant, SystemTime},
28};
29use tree_sitter::{Parser, QueryCursor};
30use util::{
31 channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
32 http::HttpClient,
33 paths::EMBEDDINGS_DIR,
34 ResultExt,
35};
36use workspace::{Workspace, WorkspaceCreated};
37
/// Delay between a file's modification and its re-indexing, so that rapid
/// successive saves of the same file coalesce into one indexing pass.
const REINDEXING_DELAY_SECONDS: u64 = 3;
/// Number of document spans accumulated before a batch is flushed to the
/// embedding provider.
const EMBEDDINGS_BATCH_SIZE: usize = 150;
40
/// One embeddable item (e.g. a definition) extracted from a source file.
#[derive(Debug, Clone)]
pub struct Document {
    /// Byte offset of the item's start within the file.
    pub offset: usize,
    /// Display name: the optional context capture text followed by the name
    /// capture text.
    pub name: String,
    /// Embedding vector; left empty by parsing and filled in later by the
    /// embedding batch task.
    pub embedding: Vec<f32>,
}
47
/// Wire the semantic-search subsystem into the app: build a [`VectorStore`],
/// index each local project as its workspace is created, and register the
/// search modal and its `Toggle` action.
///
/// Does nothing on the Stable release channel — the feature is gated to
/// non-stable builds.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
        return;
    }

    // Each release channel keeps its own index directory.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.spawn(move |mut cx| async move {
        let vector_store = VectorStore::new(
            fs,
            db_file_path,
            // Arc::new(embedding::DummyEmbeddings {}),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            // Start indexing every local project as soon as its workspace
            // appears.
            cx.subscribe_global::<WorkspaceCreated, _>({
                let vector_store = vector_store.clone();
                move |event, cx| {
                    let workspace = &event.0;
                    if let Some(workspace) = workspace.upgrade(cx) {
                        let project = workspace.read(cx).project().clone();
                        if project.read(cx).is_local() {
                            vector_store.update(cx, |store, cx| {
                                store.add_project(project, cx).detach();
                            });
                        }
                    }
                }
            })
            .detach();

            // The `Toggle` action opens the semantic-search modal.
            cx.add_action({
                move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
                    let vector_store = vector_store.clone();
                    workspace.toggle_modal(cx, |workspace, cx| {
                        let project = workspace.project().clone();
                        let workspace = cx.weak_handle();
                        cx.add_view(|cx| {
                            SemanticSearch::new(
                                SemanticSearchDelegate::new(workspace, project, vector_store),
                                cx,
                            )
                        })
                    })
                }
            });

            SemanticSearch::init(cx);
        });

        anyhow::Ok(())
    })
    .detach();
}
115
/// A parsed file together with the documents extracted from it, ready to be
/// embedded and written to the database.
#[derive(Debug, Clone)]
pub struct IndexedFile {
    // Path of the file as handed to `index_file` (the pending file's path).
    path: PathBuf,
    // Modification time recorded when the file was queued for indexing.
    mtime: SystemTime,
    // Documents extracted by the embedding query; embeddings may still be
    // empty until the embedding batch task fills them in.
    documents: Vec<Document>,
}
122
/// Model that owns the semantic-index pipeline:
/// parsing_files -> batch_files -> embed_batch -> db_update.
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    // Filesystem path of the SQLite vector database.
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Sends writes (insert/delete/find-or-create) to the single db task.
    db_update_tx: channel::Sender<DbWrite>,
    // Sends files to the parser tasks at the head of the pipeline.
    parsing_files_tx: channel::Sender<PendingFile>,
    // Background tasks held only so they stay alive with the store.
    _db_update_task: Task<()>,
    _embed_batch_task: Vec<Task<()>>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    // Per-project indexing state, keyed by weak project handle.
    projects: HashMap<WeakModelHandle<Project>, Rc<RefCell<ProjectState>>>,
}
136
/// Indexing state the store tracks for one registered project.
struct ProjectState {
    // Maps each of the project's worktree ids to its database row id.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    // Files waiting to be re-indexed, keyed by path, with the time at which
    // each becomes due.
    pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
    // Keeps the project-event subscription alive for this project's lifetime.
    _subscription: gpui::Subscription,
}
142
143impl ProjectState {
144 fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
145 // If Pending File Already Exists, Replace it with the new one
146 // but keep the old indexing time
147 if let Some(old_file) = self.pending_files.remove(&pending_file.path.clone()) {
148 self.pending_files
149 .insert(pending_file.path.clone(), (pending_file, old_file.1));
150 } else {
151 self.pending_files
152 .insert(pending_file.path.clone(), (pending_file, indexing_time));
153 };
154 }
155
156 fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
157 let mut outstanding_files = vec![];
158 let mut remove_keys = vec![];
159 for key in self.pending_files.keys().into_iter() {
160 if let Some(pending_details) = self.pending_files.get(key) {
161 let (pending_file, index_time) = pending_details;
162 if index_time <= &SystemTime::now() {
163 outstanding_files.push(pending_file.clone());
164 remove_keys.push(key.clone());
165 }
166 }
167 }
168
169 for key in remove_keys.iter() {
170 self.pending_files.remove(key);
171 }
172
173 return outstanding_files;
174 }
175}
176
/// A file queued for parsing/indexing.
#[derive(Clone, Debug)]
struct PendingFile {
    // Database id of the worktree containing the file.
    worktree_db_id: i64,
    // Path of the file (as supplied by the worktree scan or update event).
    path: PathBuf,
    language: Arc<Language>,
    // Modification time observed when the file was queued.
    modified_time: SystemTime,
}
184
/// One hit returned from [`VectorStore::search`].
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Worktree containing the matching document.
    pub worktree_id: WorktreeId,
    /// The document's display name (context + name captures).
    pub name: String,
    /// Byte offset of the document within its file.
    pub offset: usize,
    /// Path of the file containing the document.
    pub file_path: PathBuf,
}
192
/// Jobs processed serially by the single database-writer task in
/// `VectorStore::new`.
enum DbWrite {
    // Insert (or replace) an indexed file's documents for a worktree.
    InsertFile {
        worktree_id: i64,
        indexed_file: IndexedFile,
    },
    // Delete a file's rows from a worktree.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    // Look up (creating if absent) the db id for the worktree at `path`,
    // replying on `sender`.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
}
207
208impl VectorStore {
    /// Build the `VectorStore` model and start its background pipeline.
    ///
    /// Opens (creating parent directories as needed) the `VectorDatabase` at
    /// `database_url`, then wires up four channel-connected stages of
    /// background tasks: file parsing, span batching, batch embedding, and
    /// database writing.
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let database_url = Arc::new(database_url);

        // Open the database on the background executor; failure here aborts
        // construction.
        let db = cx
            .background()
            .spawn({
                let fs = fs.clone();
                let database_url = database_url.clone();
                async move {
                    if let Some(db_directory) = database_url.parent() {
                        fs.create_dir(db_directory).await.log_err();
                    }

                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
                    anyhow::Ok(db)
                }
            })
            .await?;

        Ok(cx.add_model(|cx| {
            // paths_tx -> embeddings_tx -> db_update_tx

            //db_update_tx/rx: Updating Database
            // A single task owns `db`, so all writes are serialized through
            // this channel.
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn(async move {
                while let Ok(job) = db_update_rx.recv().await {
                    match job {
                        DbWrite::InsertFile {
                            worktree_id,
                            indexed_file,
                        } => {
                            log::info!("Inserting Data for {:?}", &indexed_file.path);
                            db.insert_file(worktree_id, indexed_file).log_err();
                        }
                        DbWrite::Delete { worktree_id, path } => {
                            db.delete_file(worktree_id, path).log_err();
                        }
                        DbWrite::FindOrCreateWorktree { path, sender } => {
                            let id = db.find_or_create_worktree(&path);
                            sender.send(id).ok();
                        }
                    }
                }
            });

            // embed_tx/rx: Embed Batch and Send to Database
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, IndexedFile, Vec<String>)>>();
            let mut _embed_batch_task = Vec::new();
            for _ in 0..1 {
                //cx.background().num_cpus() {
                let db_update_tx = db_update_tx.clone();
                let embed_batch_rx = embed_batch_rx.clone();
                let embedding_provider = embedding_provider.clone();
                _embed_batch_task.push(cx.background().spawn(async move {
                    while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
                        // Construct Batch
                        // Flatten every file's spans into one request so the
                        // provider embeds the whole batch at once.
                        let mut embeddings_queue = embeddings_queue.clone();
                        let mut document_spans = vec![];
                        for (_, _, document_span) in embeddings_queue.clone().into_iter() {
                            document_spans.extend(document_span);
                        }

                        if let Ok(embeddings) = embedding_provider
                            .embed_batch(document_spans.iter().map(|x| &**x).collect())
                            .await
                        {
                            // Walk the flat embedding list back onto the
                            // per-file documents: `i` indexes the file within
                            // the queue and `j` the document within that file.
                            // This pairing relies on the provider returning
                            // exactly one embedding per span, in order.
                            let mut i = 0;
                            let mut j = 0;

                            for embedding in embeddings.iter() {
                                while embeddings_queue[i].1.documents.len() == j {
                                    i += 1;
                                    j = 0;
                                }

                                embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
                                j += 1;
                            }

                            for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
                                for document in indexed_file.documents.iter() {
                                    // TODO: Update this so it doesn't panic
                                    assert!(
                                        document.embedding.len() > 0,
                                        "Document Embedding Not Complete"
                                    );
                                }

                                db_update_tx
                                    .send(DbWrite::InsertFile {
                                        worktree_id,
                                        indexed_file,
                                    })
                                    .await
                                    .unwrap();
                            }
                        }
                    }
                }))
            }

            // batch_tx/rx: Batch Files to Send for Embeddings
            // Accumulates parsed files until EMBEDDINGS_BATCH_SIZE spans are
            // queued, then hands the whole batch to the embedding task.
            let (batch_files_tx, batch_files_rx) =
                channel::unbounded::<(i64, IndexedFile, Vec<String>)>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];
                while let Ok((worktree_id, indexed_file, document_spans)) =
                    batch_files_rx.recv().await
                {
                    queue_len += &document_spans.len();
                    embeddings_queue.push((worktree_id, indexed_file, document_spans));
                    if queue_len >= EMBEDDINGS_BATCH_SIZE {
                        embed_batch_tx.try_send(embeddings_queue).unwrap();
                        embeddings_queue = vec![];
                        queue_len = 0;
                    }
                }
                // Flush the partial batch once the sender side closes.
                if queue_len > 0 {
                    embed_batch_tx.try_send(embeddings_queue).unwrap();
                }
            });

            // parsing_files_tx/rx: Parsing Files to Embeddable Documents
            // One parsing task per CPU; each owns its own tree-sitter state.
            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();

            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut parser = Parser::new();
                    let mut cursor = QueryCursor::new();
                    while let Ok(pending_file) = parsing_files_rx.recv().await {
                        log::info!("Parsing File: {:?}", &pending_file.path);
                        if let Some((indexed_file, document_spans)) = Self::index_file(
                            &mut cursor,
                            &mut parser,
                            &fs,
                            pending_file.language,
                            pending_file.path.clone(),
                            pending_file.modified_time,
                        )
                        .await
                        .log_err()
                        {
                            batch_files_tx
                                .try_send((
                                    pending_file.worktree_db_id,
                                    indexed_file,
                                    document_spans,
                                ))
                                .unwrap();
                        }
                    }
                }));
            }

            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                parsing_files_tx,
                _db_update_task,
                _embed_batch_task,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }
390
391 async fn index_file(
392 cursor: &mut QueryCursor,
393 parser: &mut Parser,
394 fs: &Arc<dyn Fs>,
395 language: Arc<Language>,
396 file_path: PathBuf,
397 mtime: SystemTime,
398 ) -> Result<(IndexedFile, Vec<String>)> {
399 let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
400 let embedding_config = grammar
401 .embedding_config
402 .as_ref()
403 .ok_or_else(|| anyhow!("no outline query"))?;
404
405 let content = fs.load(&file_path).await?;
406
407 parser.set_language(grammar.ts_language).unwrap();
408 let tree = parser
409 .parse(&content, None)
410 .ok_or_else(|| anyhow!("parsing failed"))?;
411
412 let mut documents = Vec::new();
413 let mut context_spans = Vec::new();
414 for mat in cursor.matches(
415 &embedding_config.query,
416 tree.root_node(),
417 content.as_bytes(),
418 ) {
419 let mut item_range = None;
420 let mut name_range = None;
421 let mut context_range = None;
422 for capture in mat.captures {
423 if capture.index == embedding_config.item_capture_ix {
424 item_range = Some(capture.node.byte_range());
425 } else if capture.index == embedding_config.name_capture_ix {
426 name_range = Some(capture.node.byte_range());
427 }
428 if let Some(context_capture_ix) = embedding_config.context_capture_ix {
429 if capture.index == context_capture_ix {
430 context_range = Some(capture.node.byte_range());
431 }
432 }
433 }
434
435 if let Some((item_range, name_range)) = item_range.zip(name_range) {
436 let mut context_data = String::new();
437 if let Some(context_range) = context_range {
438 if let Some(context) = content.get(context_range.clone()) {
439 context_data.push_str(context);
440 }
441 }
442
443 if let Some((item, name)) =
444 content.get(item_range.clone()).zip(content.get(name_range))
445 {
446 context_spans.push(item.to_string());
447 documents.push(Document {
448 name: format!("{} {}", context_data.to_string(), name.to_string()),
449 offset: item_range.start,
450 embedding: Vec::new(),
451 });
452 }
453 }
454 }
455
456 return Ok((
457 IndexedFile {
458 path: file_path,
459 mtime,
460 documents,
461 },
462 context_spans,
463 ));
464 }
465
466 fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
467 let (tx, rx) = oneshot::channel();
468 self.db_update_tx
469 .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
470 .unwrap();
471 async move { rx.await? }
472 }
473
474 fn add_project(
475 &mut self,
476 project: ModelHandle<Project>,
477 cx: &mut ModelContext<Self>,
478 ) -> Task<Result<()>> {
479 let worktree_scans_complete = project
480 .read(cx)
481 .worktrees(cx)
482 .map(|worktree| {
483 let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
484 async move {
485 scan_complete.await;
486 }
487 })
488 .collect::<Vec<_>>();
489 let worktree_db_ids = project
490 .read(cx)
491 .worktrees(cx)
492 .map(|worktree| {
493 self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
494 })
495 .collect::<Vec<_>>();
496
497 let fs = self.fs.clone();
498 let language_registry = self.language_registry.clone();
499 let database_url = self.database_url.clone();
500 let db_update_tx = self.db_update_tx.clone();
501 let parsing_files_tx = self.parsing_files_tx.clone();
502
503 cx.spawn(|this, mut cx| async move {
504 let t0 = Instant::now();
505 futures::future::join_all(worktree_scans_complete).await;
506
507 let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
508 log::info!("Worktree Scanning Done in {:?}", t0.elapsed().as_millis());
509
510 if let Some(db_directory) = database_url.parent() {
511 fs.create_dir(db_directory).await.log_err();
512 }
513
514 let worktrees = project.read_with(&cx, |project, cx| {
515 project
516 .worktrees(cx)
517 .map(|worktree| worktree.read(cx).snapshot())
518 .collect::<Vec<_>>()
519 });
520
521 // Here we query the worktree ids, and yet we dont have them elsewhere
522 // We likely want to clean up these datastructures
523 let (mut worktree_file_times, db_ids_by_worktree_id) = cx
524 .background()
525 .spawn({
526 let worktrees = worktrees.clone();
527 async move {
528 let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
529 let mut db_ids_by_worktree_id = HashMap::new();
530 let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
531 HashMap::new();
532 for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
533 let db_id = db_id?;
534 db_ids_by_worktree_id.insert(worktree.id(), db_id);
535 file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
536 }
537 anyhow::Ok((file_times, db_ids_by_worktree_id))
538 }
539 })
540 .await?;
541
542 cx.background()
543 .spawn({
544 let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
545 let db_update_tx = db_update_tx.clone();
546 let language_registry = language_registry.clone();
547 let parsing_files_tx = parsing_files_tx.clone();
548 async move {
549 let t0 = Instant::now();
550 for worktree in worktrees.into_iter() {
551 let mut file_mtimes =
552 worktree_file_times.remove(&worktree.id()).unwrap();
553 for file in worktree.files(false, 0) {
554 let absolute_path = worktree.absolutize(&file.path);
555
556 if let Ok(language) = language_registry
557 .language_for_file(&absolute_path, None)
558 .await
559 {
560 if language
561 .grammar()
562 .and_then(|grammar| grammar.embedding_config.as_ref())
563 .is_none()
564 {
565 continue;
566 }
567
568 let path_buf = file.path.to_path_buf();
569 let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
570 let already_stored = stored_mtime
571 .map_or(false, |existing_mtime| {
572 existing_mtime == file.mtime
573 });
574
575 if !already_stored {
576 parsing_files_tx
577 .try_send(PendingFile {
578 worktree_db_id: db_ids_by_worktree_id
579 [&worktree.id()],
580 path: path_buf,
581 language,
582 modified_time: file.mtime,
583 })
584 .unwrap();
585 }
586 }
587 }
588 for file in file_mtimes.keys() {
589 db_update_tx
590 .try_send(DbWrite::Delete {
591 worktree_id: db_ids_by_worktree_id[&worktree.id()],
592 path: file.to_owned(),
593 })
594 .unwrap();
595 }
596 }
597 log::info!(
598 "Parsing Worktree Completed in {:?}",
599 t0.elapsed().as_millis()
600 );
601 }
602 })
603 .detach();
604
605 // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
606 this.update(&mut cx, |this, cx| {
607 // The below is managing for updated on save
608 // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
609 // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
610 let _subscription = cx.subscribe(&project, |this, project, event, cx| {
611 if let Some(project_state) = this.projects.get(&project.downgrade()) {
612 let mut project_state = project_state.borrow_mut();
613 let worktree_db_ids = project_state.worktree_db_ids.clone();
614
615 if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
616 {
617 // Get Worktree Object
618 let worktree =
619 project.read(cx).worktree_for_id(worktree_id.clone(), cx);
620 if worktree.is_none() {
621 return;
622 }
623 let worktree = worktree.unwrap();
624
625 // Get Database
626 let db_values = {
627 if let Ok(db) =
628 VectorDatabase::new(this.database_url.to_string_lossy().into())
629 {
630 let worktree_db_id: Option<i64> = {
631 let mut found_db_id = None;
632 for (w_id, db_id) in worktree_db_ids.into_iter() {
633 if &w_id == &worktree.read(cx).id() {
634 found_db_id = Some(db_id)
635 }
636 }
637 found_db_id
638 };
639 if worktree_db_id.is_none() {
640 return;
641 }
642 let worktree_db_id = worktree_db_id.unwrap();
643
644 let file_mtimes = db.get_file_mtimes(worktree_db_id);
645 if file_mtimes.is_err() {
646 return;
647 }
648
649 let file_mtimes = file_mtimes.unwrap();
650 Some((file_mtimes, worktree_db_id))
651 } else {
652 return;
653 }
654 };
655
656 if db_values.is_none() {
657 return;
658 }
659
660 let (file_mtimes, worktree_db_id) = db_values.unwrap();
661
662 // Iterate Through Changes
663 let language_registry = this.language_registry.clone();
664 let parsing_files_tx = this.parsing_files_tx.clone();
665
666 smol::block_on(async move {
667 for change in changes.into_iter() {
668 let change_path = change.0.clone();
669 // Skip if git ignored or symlink
670 if let Some(entry) = worktree.read(cx).entry_for_id(change.1) {
671 if entry.is_ignored || entry.is_symlink {
672 continue;
673 } else {
674 log::info!(
675 "Testing for Reindexing: {:?}",
676 &change_path
677 );
678 }
679 };
680
681 if let Ok(language) = language_registry
682 .language_for_file(&change_path.to_path_buf(), None)
683 .await
684 {
685 if language
686 .grammar()
687 .and_then(|grammar| grammar.embedding_config.as_ref())
688 .is_none()
689 {
690 continue;
691 }
692
693 if let Some(modified_time) = {
694 let metadata = change_path.metadata();
695 if metadata.is_err() {
696 None
697 } else {
698 let mtime = metadata.unwrap().modified();
699 if mtime.is_err() {
700 None
701 } else {
702 Some(mtime.unwrap())
703 }
704 }
705 } {
706 let existing_time =
707 file_mtimes.get(&change_path.to_path_buf());
708 let already_stored = existing_time
709 .map_or(false, |existing_time| {
710 &modified_time != existing_time
711 });
712
713 let reindex_time = modified_time
714 + Duration::from_secs(REINDEXING_DELAY_SECONDS);
715
716 if !already_stored {
717 project_state.update_pending_files(
718 PendingFile {
719 path: change_path.to_path_buf(),
720 modified_time,
721 worktree_db_id,
722 language: language.clone(),
723 },
724 reindex_time,
725 );
726
727 for file in project_state.get_outstanding_files() {
728 parsing_files_tx.try_send(file).unwrap();
729 }
730 }
731 }
732 }
733 }
734 });
735 };
736 }
737 });
738
739 this.projects.insert(
740 project.downgrade(),
741 Rc::new(RefCell::new(ProjectState {
742 pending_files: HashMap::new(),
743 worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
744 _subscription,
745 })),
746 );
747 });
748
749 anyhow::Ok(())
750 })
751 }
752
    /// Semantic search across all worktrees of `project`.
    ///
    /// Embeds `phrase`, scans every stored document embedding for the
    /// project's worktrees, and returns up to `limit` results ranked by
    /// dot-product similarity. Fails if the project was never registered via
    /// `add_project`.
    pub fn search(
        &mut self,
        project: ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
            state.borrow()
        } else {
            return Task::ready(Err(anyhow!("project not added")));
        };

        // Translate the project's worktree ids into database ids.
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                project_state
                    .worktree_db_ids
                    .iter()
                    .find_map(|(id, db_id)| {
                        if *id == worktree_id {
                            Some(*db_id)
                        } else {
                            None
                        }
                    })
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;

                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    // Keep the `limit` best hits in a vec sorted by
                    // descending similarity: the comparator's operands are
                    // deliberately reversed, and truncate drops the worst
                    // entry after each insertion.
                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            // Map database worktree ids back to project worktree ids;
            // documents from worktrees no longer in the project are dropped.
            this.read_with(&cx, |this, _| {
                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
                    state.borrow()
                } else {
                    return Err(anyhow!("project not added"));
                };

                Ok(documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id =
                            project_state
                                .worktree_db_ids
                                .iter()
                                .find_map(|(id, db_id)| {
                                    if *db_id == worktree_db_id {
                                        Some(*id)
                                    } else {
                                        None
                                    }
                                })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect())
            })
        })
    }
849}
850
// `VectorStore` emits no events; implementing `Entity` just lets it live in
// a `ModelHandle`.
impl Entity for VectorStore {
    type Event = ();
}
854
/// Dot product of two equal-length `f32` slices, computed via a 1×len by
/// len×1 matrix multiply so the SIMD-accelerated `sgemm` kernel is used.
///
/// # Panics
/// Panics if the slices differ in length.
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    // SAFETY: sgemm computes a 1 x len times len x 1 product. Both input
    // pointers are valid for `len` reads (lengths asserted equal above), the
    // strides describe a single row/column, and the output pointer refers to
    // one live, writable f32 (`result`).
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}