1mod db;
2mod embedding;
3mod modal;
4
5#[cfg(test)]
6mod vector_store_tests;
7
8use anyhow::{anyhow, Result};
9use db::VectorDatabase;
10use embedding::{EmbeddingProvider, OpenAIEmbeddings};
11use futures::{channel::oneshot, Future};
12use gpui::{
13 AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
14 WeakModelHandle,
15};
16use language::{Language, LanguageRegistry};
17use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
18use project::{Fs, Project, WorktreeId};
19use smol::channel;
20use std::{
21 cmp::Ordering,
22 collections::HashMap,
23 path::{Path, PathBuf},
24 sync::Arc,
25 time::SystemTime,
26};
27use tree_sitter::{Parser, QueryCursor};
28use util::{
29 channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
30 http::HttpClient,
31 paths::EMBEDDINGS_DIR,
32 ResultExt,
33};
34use workspace::{Workspace, WorkspaceCreated};
35
/// A single embeddable span extracted from a source file (e.g. a function
/// or type definition located by the language's embedding query).
#[derive(Debug)]
pub struct Document {
    /// Byte offset of the span's start within the file's text.
    pub offset: usize,
    /// Name captured for the span (the item's identifier text).
    pub name: String,
    /// Embedding vector for the span; left empty until the provider fills it.
    pub embedding: Vec<f32>,
}
42
/// Wires semantic search into the app: constructs the [`VectorStore`],
/// indexes each local project as its workspace is created, and registers
/// the search modal's `Toggle` action.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    // Feature is gated off on the stable release channel.
    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
        return;
    }

    // Keep a separate embeddings database per release channel so different
    // builds don't clobber each other's indexes.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.spawn(move |mut cx| async move {
        let vector_store = VectorStore::new(
            fs,
            db_file_path,
            // NOTE(review): the dummy provider is active; the OpenAI provider
            // below is kept for when real embeddings are switched on.
            Arc::new(embedding::DummyEmbeddings {}),
            // Arc::new(OpenAIEmbeddings {
            //     client: http_client,
            // }),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            // Start indexing every local project whose workspace is created.
            cx.subscribe_global::<WorkspaceCreated, _>({
                let vector_store = vector_store.clone();
                move |event, cx| {
                    let workspace = &event.0;
                    if let Some(workspace) = workspace.upgrade(cx) {
                        let project = workspace.read(cx).project().clone();
                        if project.read(cx).is_local() {
                            vector_store.update(cx, |store, cx| {
                                store.add_project(project, cx).detach();
                            });
                        }
                    }
                }
            })
            .detach();

            // `Toggle` opens the semantic-search modal for the workspace.
            cx.add_action({
                move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
                    let vector_store = vector_store.clone();
                    workspace.toggle_modal(cx, |workspace, cx| {
                        let project = workspace.project().clone();
                        let workspace = cx.weak_handle();
                        cx.add_view(|cx| {
                            SemanticSearch::new(
                                SemanticSearchDelegate::new(workspace, project, vector_store),
                                cx,
                            )
                        })
                    })
                }
            });

            SemanticSearch::init(cx);
        });

        anyhow::Ok(())
    })
    .detach();
}
110
/// The parsed-and-embedded contents of one file, ready to be written to the
/// vector database.
#[derive(Debug)]
pub struct IndexedFile {
    // Worktree-relative path (the non-absolutized `file.path` — see
    // `add_project`, which queues `file.path.to_path_buf()`).
    path: PathBuf,
    // Modification time at index time; compared later to skip unchanged files.
    mtime: SystemTime,
    // One entry per embeddable span found in the file.
    documents: Vec<Document>,
}
117
/// Central model that indexes project files into a vector database and
/// serves similarity searches over the stored embeddings.
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    // Filesystem path of the vector database file.
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // All database mutations are funneled through this channel to the single
    // writer task below, serializing writes.
    db_update_tx: channel::Sender<DbWrite>,
    // Background task that owns the write connection; held so it lives as
    // long as the store.
    _db_update_task: Task<()>,
    // Per-project indexing state, keyed by weak handle so a dropped project
    // doesn't keep its state alive.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
127
/// Book-keeping for a project registered with the store.
struct ProjectState {
    // Maps each of the project's worktrees to its row id in the database.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    // Keeps the project-event subscription (file-change detection) alive.
    _subscription: gpui::Subscription,
}
132
/// A single semantic-search hit, resolved back to a project worktree.
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    /// Name of the matched item.
    pub name: String,
    /// Byte offset of the matched item within the file.
    pub offset: usize,
    /// Path of the file containing the match.
    pub file_path: PathBuf,
}
140
/// Jobs processed serially by the database writer task spawned in
/// [`VectorStore::new`].
enum DbWrite {
    // Write an indexed file's documents for the given worktree row.
    InsertFile {
        worktree_id: i64,
        indexed_file: IndexedFile,
    },
    // Remove a file (and its documents) from the given worktree row.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    // Look up — creating if absent — the database id for a worktree path,
    // replying on `sender`.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
}
155
156impl VectorStore {
    /// Opens (creating if necessary) the vector database at `database_url`
    /// and constructs the `VectorStore` model. A dedicated background task
    /// takes ownership of the database handle and applies all writes sent
    /// over `db_update_tx` in order.
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let database_url = Arc::new(database_url);

        // Open the database on the background executor. Directory creation
        // failure is only logged: if it actually mattered, opening the
        // database will surface the real error.
        let db = cx
            .background()
            .spawn({
                let fs = fs.clone();
                let database_url = database_url.clone();
                async move {
                    if let Some(db_directory) = database_url.parent() {
                        fs.create_dir(db_directory).await.log_err();
                    }

                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
                    anyhow::Ok(db)
                }
            })
            .await?;

        Ok(cx.add_model(|cx| {
            // Single-consumer write loop: `db` moves into this task, so all
            // mutations are applied serially in send order. The loop ends
            // when all senders are dropped.
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn(async move {
                while let Ok(job) = db_update_rx.recv().await {
                    match job {
                        DbWrite::InsertFile {
                            worktree_id,
                            indexed_file,
                        } => {
                            log::info!("Inserting File: {:?}", &indexed_file.path);
                            db.insert_file(worktree_id, indexed_file).log_err();
                        }
                        DbWrite::Delete { worktree_id, path } => {
                            log::info!("Deleting File: {:?}", &path);
                            db.delete_file(worktree_id, path).log_err();
                        }
                        DbWrite::FindOrCreateWorktree { path, sender } => {
                            let id = db.find_or_create_worktree(&path);
                            // The requester may have given up; ignore send errors.
                            sender.send(id).ok();
                        }
                    }
                }
            });

            Self {
                fs,
                database_url,
                db_update_tx,
                embedding_provider,
                language_registry,
                projects: HashMap::new(),
                _db_update_task,
            }
        }))
    }
217
218 async fn index_file(
219 cursor: &mut QueryCursor,
220 parser: &mut Parser,
221 embedding_provider: &dyn EmbeddingProvider,
222 fs: &Arc<dyn Fs>,
223 language: Arc<Language>,
224 file_path: PathBuf,
225 mtime: SystemTime,
226 ) -> Result<IndexedFile> {
227 let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
228 let embedding_config = grammar
229 .embedding_config
230 .as_ref()
231 .ok_or_else(|| anyhow!("no outline query"))?;
232
233 let content = fs.load(&file_path).await?;
234
235 parser.set_language(grammar.ts_language).unwrap();
236 let tree = parser
237 .parse(&content, None)
238 .ok_or_else(|| anyhow!("parsing failed"))?;
239
240 let mut documents = Vec::new();
241 let mut context_spans = Vec::new();
242 for mat in cursor.matches(
243 &embedding_config.query,
244 tree.root_node(),
245 content.as_bytes(),
246 ) {
247 let mut item_range = None;
248 let mut name_range = None;
249 for capture in mat.captures {
250 if capture.index == embedding_config.item_capture_ix {
251 item_range = Some(capture.node.byte_range());
252 } else if capture.index == embedding_config.name_capture_ix {
253 name_range = Some(capture.node.byte_range());
254 }
255 }
256
257 if let Some((item_range, name_range)) = item_range.zip(name_range) {
258 if let Some((item, name)) =
259 content.get(item_range.clone()).zip(content.get(name_range))
260 {
261 context_spans.push(item);
262 documents.push(Document {
263 name: name.to_string(),
264 offset: item_range.start,
265 embedding: Vec::new(),
266 });
267 }
268 }
269 }
270
271 if !documents.is_empty() {
272 let embeddings = embedding_provider.embed_batch(context_spans).await?;
273 for (document, embedding) in documents.iter_mut().zip(embeddings) {
274 document.embedding = embedding;
275 }
276 }
277
278 return Ok(IndexedFile {
279 path: file_path,
280 mtime,
281 documents,
282 });
283 }
284
285 fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
286 let (tx, rx) = oneshot::channel();
287 self.db_update_tx
288 .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
289 .unwrap();
290 async move { rx.await? }
291 }
292
293 fn add_project(
294 &mut self,
295 project: ModelHandle<Project>,
296 cx: &mut ModelContext<Self>,
297 ) -> Task<Result<()>> {
298 let worktree_scans_complete = project
299 .read(cx)
300 .worktrees(cx)
301 .map(|worktree| {
302 let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
303 async move {
304 scan_complete.await;
305 }
306 })
307 .collect::<Vec<_>>();
308 let worktree_db_ids = project
309 .read(cx)
310 .worktrees(cx)
311 .map(|worktree| {
312 self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
313 })
314 .collect::<Vec<_>>();
315
316 let fs = self.fs.clone();
317 let language_registry = self.language_registry.clone();
318 let embedding_provider = self.embedding_provider.clone();
319 let database_url = self.database_url.clone();
320 let db_update_tx = self.db_update_tx.clone();
321
322 cx.spawn(|this, mut cx| async move {
323 futures::future::join_all(worktree_scans_complete).await;
324
325 let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
326
327 if let Some(db_directory) = database_url.parent() {
328 fs.create_dir(db_directory).await.log_err();
329 }
330
331 let worktrees = project.read_with(&cx, |project, cx| {
332 project
333 .worktrees(cx)
334 .map(|worktree| worktree.read(cx).snapshot())
335 .collect::<Vec<_>>()
336 });
337
338 // Here we query the worktree ids, and yet we dont have them elsewhere
339 // We likely want to clean up these datastructures
340 let (mut worktree_file_times, db_ids_by_worktree_id) = cx
341 .background()
342 .spawn({
343 let worktrees = worktrees.clone();
344 async move {
345 let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
346 let mut db_ids_by_worktree_id = HashMap::new();
347 let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
348 HashMap::new();
349 for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
350 let db_id = db_id?;
351 db_ids_by_worktree_id.insert(worktree.id(), db_id);
352 file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
353 }
354 anyhow::Ok((file_times, db_ids_by_worktree_id))
355 }
356 })
357 .await?;
358
359 let (paths_tx, paths_rx) =
360 channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>();
361 cx.background()
362 .spawn({
363 let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
364 let db_update_tx = db_update_tx.clone();
365 let language_registry = language_registry.clone();
366 let paths_tx = paths_tx.clone();
367 async move {
368 for worktree in worktrees.into_iter() {
369 let mut file_mtimes =
370 worktree_file_times.remove(&worktree.id()).unwrap();
371 for file in worktree.files(false, 0) {
372 let absolute_path = worktree.absolutize(&file.path);
373
374 if let Ok(language) = language_registry
375 .language_for_file(&absolute_path, None)
376 .await
377 {
378 if language
379 .grammar()
380 .and_then(|grammar| grammar.embedding_config.as_ref())
381 .is_none()
382 {
383 continue;
384 }
385
386 let path_buf = file.path.to_path_buf();
387 let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
388 let already_stored = stored_mtime
389 .map_or(false, |existing_mtime| {
390 existing_mtime == file.mtime
391 });
392
393 if !already_stored {
394 paths_tx
395 .try_send((
396 db_ids_by_worktree_id[&worktree.id()],
397 path_buf,
398 language,
399 file.mtime,
400 ))
401 .unwrap();
402 }
403 }
404 }
405 for file in file_mtimes.keys() {
406 db_update_tx
407 .try_send(DbWrite::Delete {
408 worktree_id: db_ids_by_worktree_id[&worktree.id()],
409 path: file.to_owned(),
410 })
411 .unwrap();
412 }
413 }
414 }
415 })
416 .detach();
417
418 cx.background()
419 .scoped(|scope| {
420 for _ in 0..cx.background().num_cpus() {
421 scope.spawn(async {
422 let mut parser = Parser::new();
423 let mut cursor = QueryCursor::new();
424 while let Ok((worktree_id, file_path, language, mtime)) =
425 paths_rx.recv().await
426 {
427 if let Some(indexed_file) = Self::index_file(
428 &mut cursor,
429 &mut parser,
430 embedding_provider.as_ref(),
431 &fs,
432 language,
433 file_path,
434 mtime,
435 )
436 .await
437 .log_err()
438 {
439 db_update_tx
440 .try_send(DbWrite::InsertFile {
441 worktree_id,
442 indexed_file,
443 })
444 .unwrap();
445 }
446 }
447 });
448 }
449 })
450 .await;
451
452 this.update(&mut cx, |this, cx| {
453 let _subscription = cx.subscribe(&project, |this, project, event, cx| {
454 if let Some(project_state) = this.projects.get(&project.downgrade()) {
455 let worktree_db_ids = project_state.worktree_db_ids.clone();
456
457 if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
458 {
459 // Iterate through changes
460 let language_registry = this.language_registry.clone();
461
462 let db =
463 VectorDatabase::new(this.database_url.to_string_lossy().into());
464 if db.is_err() {
465 return;
466 }
467 let db = db.unwrap();
468
469 let worktree_db_id: Option<i64> = {
470 let mut found_db_id = None;
471 for (w_id, db_id) in worktree_db_ids.into_iter() {
472 if &w_id == worktree_id {
473 found_db_id = Some(db_id);
474 }
475 }
476
477 found_db_id
478 };
479
480 if worktree_db_id.is_none() {
481 return;
482 }
483 let worktree_db_id = worktree_db_id.unwrap();
484
485 let file_mtimes = db.get_file_mtimes(worktree_db_id);
486 if file_mtimes.is_err() {
487 return;
488 }
489
490 let file_mtimes = file_mtimes.unwrap();
491
492 smol::block_on(async move {
493 for change in changes.into_iter() {
494 let change_path = change.0.clone();
495 log::info!("Change: {:?}", &change_path);
496 if let Ok(language) = language_registry
497 .language_for_file(&change_path.to_path_buf(), None)
498 .await
499 {
500 if language
501 .grammar()
502 .and_then(|grammar| grammar.embedding_config.as_ref())
503 .is_none()
504 {
505 continue;
506 }
507 log::info!("Language found: {:?}: ", language.name());
508
509 // TODO: Make this a bit more defensive
510 let modified_time =
511 change_path.metadata().unwrap().modified().unwrap();
512 let existing_time =
513 file_mtimes.get(&change_path.to_path_buf());
514 let already_stored =
515 existing_time.map_or(false, |existing_time| {
516 if &modified_time != existing_time
517 && existing_time.elapsed().unwrap().as_secs()
518 > 30
519 {
520 false
521 } else {
522 true
523 }
524 });
525
526 if !already_stored {
527 log::info!("Need to reindex: {:?}", &change_path);
528 // paths_tx
529 // .try_send((
530 // worktree_db_id,
531 // change_path.to_path_buf(),
532 // language,
533 // modified_time,
534 // ))
535 // .unwrap();
536 }
537 }
538 }
539 })
540 }
541 }
542 });
543
544 this.projects.insert(
545 project.downgrade(),
546 ProjectState {
547 worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
548 _subscription,
549 },
550 );
551 });
552
553 log::info!("Semantic Indexing Complete!");
554
555 anyhow::Ok(())
556 })
557 }
558
    /// Embeds `phrase` and returns up to `limit` results across the
    /// project's indexed worktrees, most similar first (dot-product
    /// similarity). Fails if the project was never added via `add_project`.
    pub fn search(
        &mut self,
        project: ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
            state
        } else {
            return Task::ready(Err(anyhow!("project not added")));
        };

        // Collect the database id of each worktree that has been indexed;
        // worktrees without an id are silently skipped.
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                project_state
                    .worktree_db_ids
                    .iter()
                    .find_map(|(id, db_id)| {
                        if *id == worktree_id {
                            Some(*db_id)
                        } else {
                            None
                        }
                    })
            })
            .collect::<Vec<_>>();

        log::info!("Searching for: {:?}", phrase);

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;

                    // Embed the query phrase as a batch of one.
                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    // Maintain the top `limit` hits in a descending-sorted
                    // list. Note the comparator compares the *target*
                    // similarity against each element (the reverse of the
                    // usual binary_search contract), which is what produces
                    // descending order; truncate then keeps the best ones.
                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            // Translate database worktree ids back to project worktree ids;
            // hits from worktrees that disappeared meanwhile are dropped.
            this.read_with(&cx, |this, _| {
                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
                    state
                } else {
                    return Err(anyhow!("project not added"));
                };

                Ok(documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id =
                            project_state
                                .worktree_db_ids
                                .iter()
                                .find_map(|(id, db_id)| {
                                    if *db_id == worktree_db_id {
                                        Some(*id)
                                    } else {
                                        None
                                    }
                                })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect())
            })
        })
    }
657}
658
/// `VectorStore` emits no events; the `Entity` impl only lets it live as a
/// gpui model.
impl Entity for VectorStore {
    type Event = ();
}
662
/// Dot product of two equal-length f32 slices, computed as a 1×len by
/// len×1 matrix multiply via `matrixmultiply::sgemm` (SIMD-accelerated).
///
/// # Panics
/// Panics if the slices differ in length.
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    // SAFETY: both pointers are valid for `len` consecutive f32 reads (the
    // assert above guarantees equal lengths), `result` is valid for a single
    // f32 write, and the strides describe `a` as a 1×len row-major matrix,
    // `b` as len×1, and `c` as 1×1 — all within the slices' bounds.
    unsafe {
        matrixmultiply::sgemm(
            1,   // m: rows of a / c
            len, // k: cols of a, rows of b
            1,   // n: cols of b / c
            1.0, // alpha
            vec_a.as_ptr(),
            len as isize, // row stride of a
            1,            // col stride of a
            vec_b.as_ptr(),
            1,            // row stride of b
            len as isize, // col stride of b
            0.0,          // beta (overwrite c)
            &mut result as *mut f32,
            1, // row stride of c
            1, // col stride of c
        );
    }
    result
}