1mod db;
2mod embedding;
3mod modal;
4
5#[cfg(test)]
6mod vector_store_tests;
7
8use anyhow::{anyhow, Result};
9use db::VectorDatabase;
10use embedding::{EmbeddingProvider, OpenAIEmbeddings};
11use futures::{channel::oneshot, Future};
12use gpui::{
13 AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
14 WeakModelHandle,
15};
16use language::{Language, LanguageRegistry};
17use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
18use project::{Fs, Project, WorktreeId};
19use smol::channel;
20use std::{
21 cmp::Ordering,
22 collections::HashMap,
23 path::{Path, PathBuf},
24 sync::Arc,
25 time::{Instant, SystemTime},
26};
27use tree_sitter::{Parser, QueryCursor};
28use util::{
29 channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
30 http::HttpClient,
31 paths::EMBEDDINGS_DIR,
32 ResultExt,
33};
34use workspace::{Workspace, WorkspaceCreated};
35
// Minimum age, in seconds, a stored mtime must reach before a changed file is
// queued for re-indexing (see the worktree-update subscription in `add_project`).
const REINDEXING_DELAY: u64 = 30;
// Number of document spans to accumulate before issuing one batched embedding
// request (see the embeddings task spawned in `VectorStore::new`).
const EMBEDDINGS_BATCH_SIZE: usize = 25;
38
/// One embeddable item (a named span) extracted from a source file by the
/// language's embedding query.
#[derive(Debug, Clone)]
pub struct Document {
    // Byte offset of the item within its file.
    pub offset: usize,
    // Name captured by the embedding query's name capture.
    pub name: String,
    // Embedding vector; empty until filled in by the embedding provider.
    pub embedding: Vec<f32>,
}
45
/// Wires semantic search into the application: builds the global
/// `VectorStore` model, indexes every local project whose workspace is
/// created, and registers the `Toggle` action that opens the search modal.
///
/// Disabled entirely on the Stable release channel.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
        return;
    }

    // The database file is namespaced by release channel name so different
    // channels don't share index state.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.spawn(move |mut cx| async move {
        let vector_store = VectorStore::new(
            fs,
            db_file_path,
            // Swap in for offline testing:
            // Arc::new(embedding::DummyEmbeddings {}),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            // Index every local project as soon as its workspace is created.
            cx.subscribe_global::<WorkspaceCreated, _>({
                let vector_store = vector_store.clone();
                move |event, cx| {
                    let workspace = &event.0;
                    if let Some(workspace) = workspace.upgrade(cx) {
                        let project = workspace.read(cx).project().clone();
                        if project.read(cx).is_local() {
                            vector_store.update(cx, |store, cx| {
                                store.add_project(project, cx).detach();
                            });
                        }
                    }
                }
            })
            .detach();

            // `Toggle` opens the semantic search modal for the active workspace.
            cx.add_action({
                move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
                    let vector_store = vector_store.clone();
                    workspace.toggle_modal(cx, |workspace, cx| {
                        let project = workspace.project().clone();
                        let workspace = cx.weak_handle();
                        cx.add_view(|cx| {
                            SemanticSearch::new(
                                SemanticSearchDelegate::new(workspace, project, vector_store),
                                cx,
                            )
                        })
                    })
                }
            });

            SemanticSearch::init(cx);
        });

        anyhow::Ok(())
    })
    .detach();
}
113
/// Parse results for a single file, ready to be embedded and persisted.
#[derive(Debug, Clone)]
pub struct IndexedFile {
    // Path of the file (as queued via `paths_tx`).
    path: PathBuf,
    // Modification time of the file when it was queued for indexing.
    mtime: SystemTime,
    // Documents extracted by the language's embedding query.
    documents: Vec<Document>,
}
120
/// Owns the indexing pipeline: channels feeding three long-lived background
/// tasks (path parsing -> embedding -> database writes) plus per-project
/// bookkeeping.
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    // Path to the on-disk vector database file.
    database_url: Arc<PathBuf>,
    // Produces embeddings for document spans (OpenAI in production; see `init`).
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Sends write jobs to the dedicated database-writer task.
    db_update_tx: channel::Sender<DbWrite>,
    // Sends (worktree db id, path, language, mtime) tuples to the parser task.
    paths_tx: channel::Sender<(i64, PathBuf, Arc<Language>, SystemTime)>,
    // Held so the pipeline tasks live exactly as long as this store does.
    _db_update_task: Task<()>,
    _paths_update_task: Task<()>,
    _embeddings_update_task: Task<()>,
    // State for each project registered via `add_project`.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
133
/// Per-project bookkeeping kept while a project is registered with the store.
struct ProjectState {
    // Maps each of the project's worktree ids to its database row id.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    // Keeps the worktree-updated subscription from `add_project` alive.
    _subscription: gpui::Subscription,
}
138
/// A single hit returned by `VectorStore::search`.
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    // Name of the matched document (symbol).
    pub name: String,
    // Byte offset of the document within its file.
    pub offset: usize,
    // Path of the file containing the document.
    pub file_path: PathBuf,
}
146
/// Jobs processed serially by the database-writer task spawned in
/// `VectorStore::new`.
enum DbWrite {
    /// Persist an indexed file's documents under the given worktree row.
    InsertFile {
        worktree_id: i64,
        indexed_file: IndexedFile,
    },
    /// Remove the stored rows for `path` within the given worktree.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Look up (creating if needed) the database id for a worktree root
    /// path, replying on `sender`.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
}
161
162impl VectorStore {
163 async fn new(
164 fs: Arc<dyn Fs>,
165 database_url: PathBuf,
166 embedding_provider: Arc<dyn EmbeddingProvider>,
167 language_registry: Arc<LanguageRegistry>,
168 mut cx: AsyncAppContext,
169 ) -> Result<ModelHandle<Self>> {
170 let database_url = Arc::new(database_url);
171
172 let db = cx
173 .background()
174 .spawn({
175 let fs = fs.clone();
176 let database_url = database_url.clone();
177 async move {
178 if let Some(db_directory) = database_url.parent() {
179 fs.create_dir(db_directory).await.log_err();
180 }
181
182 let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
183 anyhow::Ok(db)
184 }
185 })
186 .await?;
187
188 Ok(cx.add_model(|cx| {
189 // paths_tx -> embeddings_tx -> db_update_tx
190
191 let (db_update_tx, db_update_rx) = channel::unbounded();
192 let (paths_tx, paths_rx) =
193 channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>();
194 let (embeddings_tx, embeddings_rx) =
195 channel::unbounded::<(i64, IndexedFile, Vec<String>)>();
196
197 let _db_update_task = cx.background().spawn(async move {
198 while let Ok(job) = db_update_rx.recv().await {
199 match job {
200 DbWrite::InsertFile {
201 worktree_id,
202 indexed_file,
203 } => {
204 db.insert_file(worktree_id, indexed_file).log_err();
205 }
206 DbWrite::Delete { worktree_id, path } => {
207 db.delete_file(worktree_id, path).log_err();
208 }
209 DbWrite::FindOrCreateWorktree { path, sender } => {
210 let id = db.find_or_create_worktree(&path);
211 sender.send(id).ok();
212 }
213 }
214 }
215 });
216
217 async fn embed_batch(
218 embeddings_queue: Vec<(i64, IndexedFile, Vec<String>)>,
219 embedding_provider: &Arc<dyn EmbeddingProvider>,
220 db_update_tx: channel::Sender<DbWrite>,
221 ) -> Result<()> {
222 let mut embeddings_queue = embeddings_queue.clone();
223
224 let mut document_spans = vec![];
225 for (_, _, document_span) in embeddings_queue.clone().into_iter() {
226 document_spans.extend(document_span);
227 }
228
229 let mut embeddings = embedding_provider
230 .embed_batch(document_spans.iter().map(|x| &**x).collect())
231 .await?;
232
233 // This assumes the embeddings are returned in order
234 let t0 = Instant::now();
235 let mut i = 0;
236 let mut j = 0;
237 while let Some(embedding) = embeddings.pop() {
238 // This has to accomodate for multiple indexed_files in a row without documents
239 while embeddings_queue[i].1.documents.len() == j {
240 i += 1;
241 j = 0;
242 }
243
244 embeddings_queue[i].1.documents[j].embedding = embedding;
245 j += 1;
246 }
247
248 for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
249 // TODO: Update this so it doesnt panic
250 for document in indexed_file.documents.iter() {
251 assert!(
252 document.embedding.len() > 0,
253 "Document Embedding not Complete"
254 );
255 }
256
257 db_update_tx
258 .send(DbWrite::InsertFile {
259 worktree_id,
260 indexed_file,
261 })
262 .await
263 .unwrap();
264 }
265
266 anyhow::Ok(())
267 }
268
269 let embedding_provider_clone = embedding_provider.clone();
270
271 let db_update_tx_clone = db_update_tx.clone();
272 let _embeddings_update_task = cx.background().spawn(async move {
273 let mut queue_len = 0;
274 let mut embeddings_queue = vec![];
275 let mut request_count = 0;
276 while let Ok((worktree_id, indexed_file, document_spans)) =
277 embeddings_rx.recv().await
278 {
279 queue_len += &document_spans.len();
280 embeddings_queue.push((worktree_id, indexed_file, document_spans));
281
282 if queue_len >= EMBEDDINGS_BATCH_SIZE {
283 let _ = embed_batch(
284 embeddings_queue,
285 &embedding_provider_clone,
286 db_update_tx_clone.clone(),
287 )
288 .await;
289
290 embeddings_queue = vec![];
291 queue_len = 0;
292
293 request_count += 1;
294 }
295 }
296
297 if queue_len > 0 {
298 let _ = embed_batch(
299 embeddings_queue,
300 &embedding_provider_clone,
301 db_update_tx_clone.clone(),
302 )
303 .await;
304 request_count += 1;
305 }
306 });
307
308 let fs_clone = fs.clone();
309
310 let _paths_update_task = cx.background().spawn(async move {
311 let mut parser = Parser::new();
312 let mut cursor = QueryCursor::new();
313 while let Ok((worktree_id, file_path, language, mtime)) = paths_rx.recv().await {
314 if let Some((indexed_file, document_spans)) = Self::index_file(
315 &mut cursor,
316 &mut parser,
317 &fs_clone,
318 language,
319 file_path.clone(),
320 mtime,
321 )
322 .await
323 .log_err()
324 {
325 embeddings_tx
326 .try_send((worktree_id, indexed_file, document_spans))
327 .unwrap();
328 }
329 }
330 });
331
332 Self {
333 fs,
334 database_url,
335 db_update_tx,
336 paths_tx,
337 embedding_provider,
338 language_registry,
339 projects: HashMap::new(),
340 _db_update_task,
341 _paths_update_task,
342 _embeddings_update_task,
343 }
344 }))
345 }
346
347 async fn index_file(
348 cursor: &mut QueryCursor,
349 parser: &mut Parser,
350 fs: &Arc<dyn Fs>,
351 language: Arc<Language>,
352 file_path: PathBuf,
353 mtime: SystemTime,
354 ) -> Result<(IndexedFile, Vec<String>)> {
355 let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
356 let embedding_config = grammar
357 .embedding_config
358 .as_ref()
359 .ok_or_else(|| anyhow!("no outline query"))?;
360
361 let content = fs.load(&file_path).await?;
362
363 parser.set_language(grammar.ts_language).unwrap();
364 let tree = parser
365 .parse(&content, None)
366 .ok_or_else(|| anyhow!("parsing failed"))?;
367
368 let mut documents = Vec::new();
369 let mut context_spans = Vec::new();
370 for mat in cursor.matches(
371 &embedding_config.query,
372 tree.root_node(),
373 content.as_bytes(),
374 ) {
375 let mut item_range = None;
376 let mut name_range = None;
377 for capture in mat.captures {
378 if capture.index == embedding_config.item_capture_ix {
379 item_range = Some(capture.node.byte_range());
380 } else if capture.index == embedding_config.name_capture_ix {
381 name_range = Some(capture.node.byte_range());
382 }
383 }
384
385 if let Some((item_range, name_range)) = item_range.zip(name_range) {
386 if let Some((item, name)) =
387 content.get(item_range.clone()).zip(content.get(name_range))
388 {
389 context_spans.push(item.to_string());
390 documents.push(Document {
391 name: name.to_string(),
392 offset: item_range.start,
393 embedding: Vec::new(),
394 });
395 }
396 }
397 }
398
399 return Ok((
400 IndexedFile {
401 path: file_path,
402 mtime,
403 documents,
404 },
405 context_spans,
406 ));
407 }
408
409 fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
410 let (tx, rx) = oneshot::channel();
411 self.db_update_tx
412 .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
413 .unwrap();
414 async move { rx.await? }
415 }
416
    /// Registers `project` with the store and kicks off an initial indexing
    /// pass over all of its local worktrees.
    ///
    /// Once every worktree's scan completes, files whose mtimes differ from
    /// what the database recorded are queued for (re)indexing, and database
    /// rows for files no longer on disk are deleted. A subscription then
    /// re-queues files as worktree entries change while the project is open.
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        // Futures resolving when each local worktree's initial scan finishes.
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();
        // Database ids for each worktree, resolved via the db-writer task.
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let database_url = self.database_url.clone();
        let db_update_tx = self.db_update_tx.clone();
        let paths_tx = self.paths_tx.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;

            if let Some(db_directory) = database_url.parent() {
                fs.create_dir(db_directory).await.log_err();
            }

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Load each worktree's previously-stored file mtimes so unchanged
            // files can be skipped below.
            // NOTE(review): the worktree ids queried here are not reused
            // elsewhere; these data structures could likely be consolidated.
            let (mut worktree_file_times, db_ids_by_worktree_id) = cx
                .background()
                .spawn({
                    let worktrees = worktrees.clone();
                    async move {
                        let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
                        let mut db_ids_by_worktree_id = HashMap::new();
                        let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
                            HashMap::new();
                        for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                            let db_id = db_id?;
                            db_ids_by_worktree_id.insert(worktree.id(), db_id);
                            file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
                        }
                        anyhow::Ok((file_times, db_ids_by_worktree_id))
                    }
                })
                .await?;

            // Initial pass: queue new/changed files for parsing and delete
            // rows for files that no longer exist on disk.
            cx.background()
                .spawn({
                    let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
                    let db_update_tx = db_update_tx.clone();
                    let language_registry = language_registry.clone();
                    let paths_tx = paths_tx.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let mut file_mtimes =
                                worktree_file_times.remove(&worktree.id()).unwrap();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Only languages with an embedding query
                                    // can be indexed.
                                    if language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                    {
                                        continue;
                                    }

                                    let path_buf = file.path.to_path_buf();
                                    // `remove` (not `get`) so the leftovers in
                                    // `file_mtimes` mark deleted files below.
                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        paths_tx
                                            .try_send((
                                                db_ids_by_worktree_id[&worktree.id()],
                                                path_buf,
                                                language,
                                                file.mtime,
                                            ))
                                            .unwrap();
                                    }
                                }
                            }
                            // Anything still in `file_mtimes` was in the
                            // database but not on disk: delete those rows.
                            for file in file_mtimes.keys() {
                                db_update_tx
                                    .try_send(DbWrite::Delete {
                                        worktree_id: db_ids_by_worktree_id[&worktree.id()],
                                        path: file.to_owned(),
                                    })
                                    .unwrap();
                            }
                        }
                    }
                })
                .detach();

            this.update(&mut cx, |this, cx| {
                // Keep the index fresh on save: each time worktree entries
                // update, re-queue changed files whose stored mtime is older
                // than REINDEXING_DELAY seconds.
                let _subscription = cx.subscribe(&project, |this, project, event, _cx| {
                    if let Some(project_state) = this.projects.get(&project.downgrade()) {
                        let worktree_db_ids = project_state.worktree_db_ids.clone();

                        if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
                        {
                            let language_registry = this.language_registry.clone();

                            // Open a fresh connection; bail quietly if the
                            // database can't be opened right now.
                            let db =
                                VectorDatabase::new(this.database_url.to_string_lossy().into());
                            if db.is_err() {
                                return;
                            }
                            let db = db.unwrap();

                            // Map the gpui worktree id to its database id.
                            let worktree_db_id: Option<i64> = {
                                let mut found_db_id = None;
                                for (w_id, db_id) in worktree_db_ids.into_iter() {
                                    if &w_id == worktree_id {
                                        found_db_id = Some(db_id);
                                    }
                                }

                                found_db_id
                            };

                            if worktree_db_id.is_none() {
                                return;
                            }
                            let worktree_db_id = worktree_db_id.unwrap();

                            let file_mtimes = db.get_file_mtimes(worktree_db_id);
                            if file_mtimes.is_err() {
                                return;
                            }

                            let file_mtimes = file_mtimes.unwrap();
                            let paths_tx = this.paths_tx.clone();

                            // NOTE(review): `block_on` here blocks the thread
                            // running this subscription while languages are
                            // resolved; consider moving to a background task.
                            smol::block_on(async move {
                                for change in changes.into_iter() {
                                    let change_path = change.0.clone();
                                    log::info!("Change: {:?}", &change_path);
                                    if let Ok(language) = language_registry
                                        .language_for_file(&change_path.to_path_buf(), None)
                                        .await
                                    {
                                        if language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                        {
                                            continue;
                                        }

                                        // TODO: Make this a bit more defensive
                                        // (metadata/modified can fail, e.g.
                                        // for files deleted mid-update).
                                        let modified_time =
                                            change_path.metadata().unwrap().modified().unwrap();
                                        let existing_time =
                                            file_mtimes.get(&change_path.to_path_buf());
                                        // Re-index only when the mtime differs
                                        // AND the stored one has aged past
                                        // REINDEXING_DELAY (debounce).
                                        let already_stored =
                                            existing_time.map_or(false, |existing_time| {
                                                if &modified_time != existing_time
                                                    && existing_time.elapsed().unwrap().as_secs()
                                                        > REINDEXING_DELAY
                                                {
                                                    false
                                                } else {
                                                    true
                                                }
                                            });

                                        if !already_stored {
                                            log::info!("Need to reindex: {:?}", &change_path);
                                            paths_tx
                                                .try_send((
                                                    worktree_db_id,
                                                    change_path.to_path_buf(),
                                                    language,
                                                    modified_time,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            })
                        }
                    }
                });

                this.projects.insert(
                    project.downgrade(),
                    ProjectState {
                        worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
                        _subscription,
                    },
                );
            });

            anyhow::Ok(())
        })
    }
647
    /// Searches the indexed worktrees of `project` for documents similar to
    /// `phrase`, returning up to `limit` results ranked by dot-product
    /// similarity between the phrase embedding and each stored embedding.
    ///
    /// Fails immediately if the project was never registered via
    /// `add_project`.
    pub fn search(
        &mut self,
        project: ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
            state
        } else {
            return Task::ready(Err(anyhow!("project not added")));
        };

        // Database ids for the project's worktrees; worktrees without a
        // recorded id are silently skipped.
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                project_state
                    .worktree_db_ids
                    .iter()
                    .find_map(|(id, db_id)| {
                        if *id == worktree_id {
                            Some(*db_id)
                        } else {
                            None
                        }
                    })
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;

                    // Embed the query phrase once (single-element batch).
                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    // Scan every stored embedding, keeping a best-first list
                    // of at most `limit` (document id, similarity) pairs via
                    // binary-search insertion + truncate.
                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    // Resolve the surviving ids into document rows.
                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            this.read_with(&cx, |this, _| {
                // Re-check registration: the project may have been dropped
                // while the background query was running.
                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
                    state
                } else {
                    return Err(anyhow!("project not added"));
                };

                // Map each row's database worktree id back to its gpui
                // WorktreeId; rows with no mapping are dropped.
                Ok(documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id =
                            project_state
                                .worktree_db_ids
                                .iter()
                                .find_map(|(id, db_id)| {
                                    if *db_id == worktree_db_id {
                                        Some(*id)
                                    } else {
                                        None
                                    }
                                })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect())
            })
        })
    }
744}
745
impl Entity for VectorStore {
    // Unit event type: the store satisfies gpui's `Entity` trait but emits
    // no meaningful events.
    type Event = ();
}
749
/// Dot product of two equal-length `f32` slices.
///
/// A plain zip/multiply/sum is safe, auto-vectorizes well, and avoids
/// invoking a general matrix-multiply routine (with its `unsafe` raw-pointer
/// setup) for what is just a 1×n·n×1 product.
///
/// # Panics
///
/// Panics if the slices have different lengths.
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    assert_eq!(vec_a.len(), vec_b.len());

    vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum()
}