mod db;
mod embedding;
mod modal;

#[cfg(test)]
mod vector_store_tests;

use anyhow::{anyhow, Result};
use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use futures::{channel::oneshot, Future};
use gpui::{
    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
    WeakModelHandle,
};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use project::{Fs, Project, WorktreeId};
use smol::channel;
use std::{
    cmp::Ordering,
    collections::HashMap,
    path::{Path, PathBuf},
    sync::Arc,
    time::{Instant, SystemTime},
};
use tree_sitter::{Parser, QueryCursor};
use util::{
    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
    http::HttpClient,
    paths::EMBEDDINGS_DIR,
    ResultExt,
};
use workspace::{Workspace, WorkspaceCreated};

// Minimum age, in seconds, that a file's stored mtime must have before a
// change re-triggers indexing (see the update-on-save handler in `add_project`).
const REINDEXING_DELAY: u64 = 30;
// Number of document spans to accumulate before sending a batch to the
// embedding provider.
const EMBEDDINGS_BATCH_SIZE: usize = 150;

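/// A single embeddable span extracted from a source file: the item's byte
/// offset and name, plus its embedding vector (left empty until the embedding
/// stage fills it in).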
#[derive(Debug, Clone)]
pub struct Document {
    pub offset: usize,
    pub name: String,
    pub embedding: Vec<f32>,
}

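/// Registers semantic search: connects the `VectorStore` to newly created
/// workspaces and installs the `Toggle` action for the search modal. This is
/// a no-op on the stable release channel.
///
/// A minimal sketch of a call site, assuming the caller already holds the
/// usual app-wide handles (the names here are illustrative, not part of this
/// crate):
///
/// ```ignore
/// vector_store::init(fs.clone(), http_client.clone(), languages.clone(), cx);
/// ```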
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
        return;
    }

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.spawn(move |mut cx| async move {
        let vector_store = VectorStore::new(
            fs,
            db_file_path,
            // Swap in `embedding::DummyEmbeddings {}` here to develop without
            // calling the OpenAI API.
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            cx.subscribe_global::<WorkspaceCreated, _>({
                let vector_store = vector_store.clone();
                move |event, cx| {
                    let workspace = &event.0;
                    if let Some(workspace) = workspace.upgrade(cx) {
                        let project = workspace.read(cx).project().clone();
                        if project.read(cx).is_local() {
                            vector_store.update(cx, |store, cx| {
                                store.add_project(project, cx).detach();
                            });
                        }
                    }
                }
            })
            .detach();

            cx.add_action({
                move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
                    let vector_store = vector_store.clone();
                    workspace.toggle_modal(cx, |workspace, cx| {
                        let project = workspace.project().clone();
                        let workspace = cx.weak_handle();
                        cx.add_view(|cx| {
                            SemanticSearch::new(
                                SemanticSearchDelegate::new(workspace, project, vector_store),
                                cx,
                            )
                        })
                    })
                }
            });

            SemanticSearch::init(cx);
        });

        anyhow::Ok(())
    })
    .detach();
}

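/// A parsed file ready for storage: its path, the mtime observed at parse
/// time, and the documents extracted from it.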
#[derive(Debug, Clone)]
pub struct IndexedFile {
    path: PathBuf,
    mtime: SystemTime,
    documents: Vec<Document>,
}

pub struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    db_update_tx: channel::Sender<DbWrite>,
    parsing_files_tx: channel::Sender<(i64, PathBuf, Arc<Language>, SystemTime)>,
    _db_update_task: Task<()>,
    _embed_batch_tasks: Vec<Task<()>>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}

struct ProjectState {
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    _subscription: gpui::Subscription,
}

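/// A search hit, resolved from database ids back to a worktree-relative
/// location.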
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}

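/// A write request for the single background task that owns the database
/// connection. `FindOrCreateWorktree` reports its result back over a oneshot
/// channel; the other variants are fire-and-forget.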
enum DbWrite {
    InsertFile {
        worktree_id: i64,
        indexed_file: IndexedFile,
    },
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
}

impl VectorStore {
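    /// Opens (or creates) the vector database and spawns the background
    /// pipeline:
    ///
    /// parsing_files_tx -> batch_files_tx -> embed_batch_tx -> db_update_tx
    ///
    /// Files are parsed into documents, batched until roughly
    /// `EMBEDDINGS_BATCH_SIZE` spans accumulate, embedded, and finally
    /// written to the database by a single writer task.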
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let database_url = Arc::new(database_url);

        let db = cx
            .background()
            .spawn({
                let fs = fs.clone();
                let database_url = database_url.clone();
                async move {
                    if let Some(db_directory) = database_url.parent() {
                        fs.create_dir(db_directory).await.log_err();
                    }

                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
                    anyhow::Ok(db)
                }
            })
            .await?;

        Ok(cx.add_model(|cx| {
            // db_update_tx/rx: a single task owns the database connection and
            // applies all writes in order.
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn(async move {
                while let Ok(job) = db_update_rx.recv().await {
                    match job {
                        DbWrite::InsertFile {
                            worktree_id,
                            indexed_file,
                        } => {
                            log::info!("Inserting Data for {:?}", &indexed_file.path);
                            db.insert_file(worktree_id, indexed_file).log_err();
                        }
                        DbWrite::Delete { worktree_id, path } => {
                            db.delete_file(worktree_id, path).log_err();
                        }
                        DbWrite::FindOrCreateWorktree { path, sender } => {
                            let id = db.find_or_create_worktree(&path);
                            sender.send(id).ok();
                        }
                    }
                }
            });

            // embed_batch_tx/rx: embed each batch of document spans, then send
            // the completed files on to the database task.
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, IndexedFile, Vec<String>)>>();
            let mut _embed_batch_tasks = Vec::new();
            // A single worker for now; this could grow toward
            // `cx.background().num_cpus()` once the embedding provider
            // tolerates more concurrent requests.
            for _ in 0..1 {
                let db_update_tx = db_update_tx.clone();
                let embed_batch_rx = embed_batch_rx.clone();
                let embedding_provider = embedding_provider.clone();
                _embed_batch_tasks.push(cx.background().spawn(async move {
                    while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
                        // Flatten every file's spans into a single request.
                        let mut document_spans = vec![];
                        for (_, _, document_span) in embeddings_queue.iter() {
                            document_spans.extend(document_span.iter().cloned());
                        }

                        if let Ok(embeddings) = embedding_provider
                            .embed_batch(document_spans.iter().map(|x| &**x).collect())
                            .await
                        {
                            // Walk the flat list of embeddings back onto the
                            // queued files: `i` indexes the file, `j` the
                            // document within that file.
                            let mut i = 0;
                            let mut j = 0;

                            for embedding in embeddings.iter() {
                                while embeddings_queue[i].1.documents.len() == j {
                                    i += 1;
                                    j = 0;
                                }

                                embeddings_queue[i].1.documents[j].embedding =
                                    embedding.to_owned();
                                j += 1;
                            }

                            for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
                                for document in indexed_file.documents.iter() {
                                    // TODO: Surface this as an error instead of
                                    // panicking.
                                    assert!(
                                        !document.embedding.is_empty(),
                                        "Document Embedding Not Complete"
                                    );
                                }

                                db_update_tx
                                    .send(DbWrite::InsertFile {
                                        worktree_id,
                                        indexed_file,
                                    })
                                    .await
                                    .unwrap();
                            }
                        }
                    }
                }))
            }

            // batch_files_tx/rx: accumulate parsed files until roughly
            // EMBEDDINGS_BATCH_SIZE spans are queued, then ship the batch off
            // for embedding.
            let (batch_files_tx, batch_files_rx) =
                channel::unbounded::<(i64, IndexedFile, Vec<String>)>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];
                while let Ok((worktree_id, indexed_file, document_spans)) =
                    batch_files_rx.recv().await
                {
                    queue_len += document_spans.len();
                    embeddings_queue.push((worktree_id, indexed_file, document_spans));
                    if queue_len >= EMBEDDINGS_BATCH_SIZE {
                        embed_batch_tx.try_send(embeddings_queue).unwrap();
                        embeddings_queue = vec![];
                        queue_len = 0;
                    }
                }
                // Flush any partial batch once the sender side shuts down.
                if queue_len > 0 {
                    embed_batch_tx.try_send(embeddings_queue).unwrap();
                }
            });

            // parsing_files_tx/rx: parse files into embeddable documents, with
            // one worker per CPU.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>();

            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut parser = Parser::new();
                    let mut cursor = QueryCursor::new();
                    while let Ok((worktree_id, file_path, language, mtime)) =
                        parsing_files_rx.recv().await
                    {
                        log::info!("Parsing File: {:?}", &file_path);
                        if let Some((indexed_file, document_spans)) = Self::index_file(
                            &mut cursor,
                            &mut parser,
                            &fs,
                            language,
                            file_path.clone(),
                            mtime,
                        )
                        .await
                        .log_err()
                        {
                            batch_files_tx
                                .try_send((worktree_id, indexed_file, document_spans))
                                .unwrap();
                        }
                    }
                }));
            }

            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                parsing_files_tx,
                _db_update_task,
                _embed_batch_tasks,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }

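    /// Loads a file, parses it with tree-sitter, and runs the language's
    /// embedding query over the tree. Each match with both an item and a name
    /// capture becomes a `Document` (embedding left empty); the matched item
    /// text is returned alongside as the span to embed.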
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        fs: &Arc<dyn Fs>,
        language: Arc<Language>,
        file_path: PathBuf,
        mtime: SystemTime,
    ) -> Result<(IndexedFile, Vec<String>)> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let embedding_config = grammar
            .embedding_config
            .as_ref()
            .ok_or_else(|| anyhow!("no embedding query"))?;

        let content = fs.load(&file_path).await?;

        parser.set_language(grammar.ts_language)?;
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(
            &embedding_config.query,
            tree.root_node(),
            content.as_bytes(),
        ) {
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == embedding_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == embedding_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item.to_string());
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        Ok((
            IndexedFile {
                path: file_path,
                mtime,
                documents,
            },
            context_spans,
        ))
    }

    /// Resolves (or creates) the database id for a worktree root path by
    /// sending a request to the database task and awaiting the answer on a
    /// oneshot channel.
    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
        let (tx, rx) = oneshot::channel();
        self.db_update_tx
            .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
            .unwrap();
        async move { rx.await? }
    }

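    /// Indexes a project: waits for its worktrees to finish scanning, diffs
    /// each worktree's files against the mtimes stored in the database
    /// (parsing new or changed files, deleting rows for removed ones), and
    /// subscribes to worktree updates so saves trigger reindexing after
    /// `REINDEXING_DELAY` seconds.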
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let database_url = self.database_url.clone();
        let db_update_tx = self.db_update_tx.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();

        cx.spawn(|this, mut cx| async move {
            let t0 = Instant::now();
            futures::future::join_all(worktree_scans_complete).await;

            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
            log::info!("Worktree scanning done in {:?}", t0.elapsed());

            if let Some(db_directory) = database_url.parent() {
                fs.create_dir(db_directory).await.log_err();
            }

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // The worktree ids are queried here but not kept anywhere else;
            // these data structures could likely be consolidated.
            let (mut worktree_file_times, db_ids_by_worktree_id) = cx
                .background()
                .spawn({
                    let worktrees = worktrees.clone();
                    async move {
                        let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
                        let mut db_ids_by_worktree_id = HashMap::new();
                        let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
                            HashMap::new();
                        for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                            let db_id = db_id?;
                            db_ids_by_worktree_id.insert(worktree.id(), db_id);
                            file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
                        }
                        anyhow::Ok((file_times, db_ids_by_worktree_id))
                    }
                })
                .await?;

            cx.background()
                .spawn({
                    let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
                    let db_update_tx = db_update_tx.clone();
                    let language_registry = language_registry.clone();
                    let parsing_files_tx = parsing_files_tx.clone();
                    async move {
                        let t0 = Instant::now();
                        for worktree in worktrees.into_iter() {
                            let mut file_mtimes =
                                worktree_file_times.remove(&worktree.id()).unwrap();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Skip languages without an embedding query.
                                    if language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                    {
                                        continue;
                                    }

                                    // Re-parse the file unless the stored mtime
                                    // matches the one on disk.
                                    let path_buf = file.path.to_path_buf();
                                    let stored_mtime = file_mtimes.remove(&path_buf);
                                    if stored_mtime != Some(file.mtime) {
                                        parsing_files_tx
                                            .try_send((
                                                db_ids_by_worktree_id[&worktree.id()],
                                                path_buf,
                                                language,
                                                file.mtime,
                                            ))
                                            .unwrap();
                                    }
                                }
                            }
                            // Anything left in `file_mtimes` no longer exists on
                            // disk; delete its rows.
                            for file in file_mtimes.keys() {
                                db_update_tx
                                    .try_send(DbWrite::Delete {
                                        worktree_id: db_ids_by_worktree_id[&worktree.id()],
                                        path: file.to_owned(),
                                    })
                                    .unwrap();
                            }
                        }
                        log::info!("Parsing worktree completed in {:?}", t0.elapsed());
                    }
                })
                .detach();

            this.update(&mut cx, |this, cx| {
                // Keep the index fresh as files are saved: whenever worktree
                // entries change, any file whose mtime differs from the stored
                // one and whose stored mtime is older than REINDEXING_DELAY
                // seconds is re-sent for indexing.
                let _subscription = cx.subscribe(&project, |this, project, event, _cx| {
                    if let Some(project_state) = this.projects.get(&project.downgrade()) {
                        let worktree_db_ids = project_state.worktree_db_ids.clone();

                        if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
                        {
                            let language_registry = this.language_registry.clone();

                            let Ok(db) =
                                VectorDatabase::new(this.database_url.to_string_lossy().into())
                            else {
                                return;
                            };

                            let Some(worktree_db_id) = worktree_db_ids
                                .into_iter()
                                .find_map(|(id, db_id)| (id == *worktree_id).then_some(db_id))
                            else {
                                return;
                            };

                            let Ok(file_mtimes) = db.get_file_mtimes(worktree_db_id) else {
                                return;
                            };

                            let parsing_files_tx = this.parsing_files_tx.clone();

                            smol::block_on(async move {
                                for change in changes.into_iter() {
                                    let change_path = change.0.clone();
                                    log::info!("Change: {:?}", &change_path);
                                    if let Ok(language) = language_registry
                                        .language_for_file(&change_path.to_path_buf(), None)
                                        .await
                                    {
                                        if language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                        {
                                            continue;
                                        }

                                        // Skip files we can't stat rather than
                                        // panicking on missing metadata.
                                        let Ok(modified_time) = change_path
                                            .metadata()
                                            .and_then(|metadata| metadata.modified())
                                        else {
                                            continue;
                                        };

                                        // A file counts as already stored if its
                                        // mtime is unchanged, or if the stored
                                        // mtime is too recent for reindexing to
                                        // kick in yet.
                                        let existing_time =
                                            file_mtimes.get(&change_path.to_path_buf());
                                        let already_stored =
                                            existing_time.map_or(false, |existing_time| {
                                                &modified_time == existing_time
                                                    || existing_time.elapsed().map_or(
                                                        true,
                                                        |elapsed| {
                                                            elapsed.as_secs() <= REINDEXING_DELAY
                                                        },
                                                    )
                                            });

                                        if !already_stored {
                                            log::info!("Need to reindex: {:?}", &change_path);
                                            parsing_files_tx
                                                .try_send((
                                                    worktree_db_id,
                                                    change_path.to_path_buf(),
                                                    language,
                                                    modified_time,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            })
                        }
                    }
                });

                this.projects.insert(
                    project.downgrade(),
                    ProjectState {
                        worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
                        _subscription,
                    },
                );
            });

            anyhow::Ok(())
        })
    }

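    /// Embeds `phrase`, scans every stored document embedding for the
    /// project's worktrees, and returns the `limit` nearest results by dot
    /// product.
    ///
    /// A minimal usage sketch (hypothetical call site; the handles are
    /// illustrative):
    ///
    /// ```ignore
    /// let results = vector_store.update(cx, |store, cx| {
    ///     store.search(project.clone(), "load a file from disk".into(), 10, cx)
    /// });
    /// ```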
    pub fn search(
        &mut self,
        project: ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
            state
        } else {
            return Task::ready(Err(anyhow!("project not added")));
        };

        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                project_state
                    .worktree_db_ids
                    .iter()
                    .find_map(|(id, db_id)| {
                        if *id == worktree_id {
                            Some(*db_id)
                        } else {
                            None
                        }
                    })
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;

                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .ok_or_else(|| anyhow!("embedding provider returned no embedding"))?;

                    // Maintain the top `limit` hits; the comparator is reversed
                    // so `results` stays sorted by descending similarity.
                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            this.read_with(&cx, |this, _| {
                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
                    state
                } else {
                    return Err(anyhow!("project not added"));
                };

                // Map database worktree ids back to in-app worktree ids.
                Ok(documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id =
                            project_state
                                .worktree_db_ids
                                .iter()
                                .find_map(|(id, db_id)| {
                                    if *db_id == worktree_db_id {
                                        Some(*id)
                                    } else {
                                        None
                                    }
                                })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect())
            })
        })
    }
}

impl Entity for VectorStore {
    type Event = ();
}

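/// Dot product of two equal-length vectors, computed as a 1 × len by len × 1
/// matrix multiply so the hand-vectorized `matrixmultiply::sgemm` kernel does
/// the work. Panics if the lengths differ.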
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}
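
// For reference, a scalar sketch of what the sgemm call above computes. This
// hypothetical `dot_naive` is not used by the crate; it exists only to make
// the 1 × len by len × 1 formulation concrete.
#[allow(dead_code)]
fn dot_naive(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    assert_eq!(vec_a.len(), vec_b.len());
    // Sum over i of a[i] * b[i].
    vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum()
}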