vector_store.rs

  1mod db;
  2mod embedding;
  3mod modal;
  4
  5#[cfg(test)]
  6mod vector_store_tests;
  7
  8use anyhow::{anyhow, Result};
  9use db::VectorDatabase;
 10use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 11use futures::{channel::oneshot, Future};
 12use gpui::{
 13    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
 14    WeakModelHandle,
 15};
 16use language::{Language, LanguageRegistry};
 17use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
 18use project::{Fs, Project, WorktreeId};
 19use smol::channel;
 20use std::{
 21    cmp::Ordering,
 22    collections::HashMap,
 23    path::{Path, PathBuf},
 24    sync::Arc,
 25    time::SystemTime,
 26};
 27use tree_sitter::{Parser, QueryCursor};
 28use util::{
 29    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
 30    http::HttpClient,
 31    paths::EMBEDDINGS_DIR,
 32    ResultExt,
 33};
 34use workspace::{Workspace, WorkspaceCreated};
 35
/// A single embeddable item (e.g. a definition) extracted from a source file.
#[derive(Debug)]
pub struct Document {
    /// Byte offset of the start of the item within the file's contents.
    pub offset: usize,
    /// Source text of the item's name capture.
    pub name: String,
    /// Embedding of the item's source text; empty until the embedding provider fills it in.
    pub embedding: Vec<f32>,
}
 42
 43pub fn init(
 44    fs: Arc<dyn Fs>,
 45    http_client: Arc<dyn HttpClient>,
 46    language_registry: Arc<LanguageRegistry>,
 47    cx: &mut AppContext,
 48) {
 49    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
 50        return;
 51    }
 52
 53    let db_file_path = EMBEDDINGS_DIR
 54        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
 55        .join("embeddings_db");
 56
 57    cx.spawn(move |mut cx| async move {
 58        let vector_store = VectorStore::new(
 59            fs,
 60            db_file_path,
 61            Arc::new(embedding::DummyEmbeddings {}),
 62            // Arc::new(OpenAIEmbeddings {
 63            //     client: http_client,
 64            // }),
 65            language_registry,
 66            cx.clone(),
 67        )
 68        .await?;
 69
 70        cx.update(|cx| {
 71            cx.subscribe_global::<WorkspaceCreated, _>({
 72                let vector_store = vector_store.clone();
 73                move |event, cx| {
 74                    let workspace = &event.0;
 75                    if let Some(workspace) = workspace.upgrade(cx) {
 76                        let project = workspace.read(cx).project().clone();
 77                        if project.read(cx).is_local() {
 78                            vector_store.update(cx, |store, cx| {
 79                                store.add_project(project, cx).detach();
 80                            });
 81                        }
 82                    }
 83                }
 84            })
 85            .detach();
 86
 87            cx.add_action({
 88                move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
 89                    let vector_store = vector_store.clone();
 90                    workspace.toggle_modal(cx, |workspace, cx| {
 91                        let project = workspace.project().clone();
 92                        let workspace = cx.weak_handle();
 93                        cx.add_view(|cx| {
 94                            SemanticSearch::new(
 95                                SemanticSearchDelegate::new(workspace, project, vector_store),
 96                                cx,
 97                            )
 98                        })
 99                    })
100                }
101            });
102
103            SemanticSearch::init(cx);
104        });
105
106        anyhow::Ok(())
107    })
108    .detach();
109}
110
/// The outcome of indexing one file: where it lives, the modification time that was
/// indexed, and the embeddable documents extracted from it.
#[derive(Debug)]
pub struct IndexedFile {
    /// Path the file is recorded under in the database.
    path: PathBuf,
    /// Modification time of the file when it was indexed; compared against the file's
    /// current mtime to skip unchanged files (presumably persisted by `insert_file` — confirm).
    mtime: SystemTime,
    /// Items extracted from the file, each with its embedding.
    documents: Vec<Document>,
}
117
/// Maintains a database of embeddings for items in local projects and answers semantic
/// searches against it.
pub struct VectorStore {
    /// File system used to create the database directory and load source files.
    fs: Arc<dyn Fs>,
    /// Path of the embeddings database file.
    database_url: Arc<PathBuf>,
    /// Produces embeddings for indexed items and for search phrases.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Used to choose a language (and its embedding query) for each file.
    language_registry: Arc<LanguageRegistry>,
    /// Queue of database operations, applied in order by `_db_update_task`.
    db_update_tx: channel::Sender<DbWrite>,
    /// Background task that drains `db_update_tx`; held so it lives as long as the store.
    _db_update_task: Task<()>,
    /// Indexing state for each added project, keyed by a weak handle to the project.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
127
/// Indexing state for one project that has been added to the store.
struct ProjectState {
    /// Each project worktree id paired with that worktree's id in the database.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    /// Keeps the subscription to the project's events alive while the state is stored.
    _subscription: gpui::Subscription,
}
132
/// One match returned by `VectorStore::search`.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Project worktree containing the matched file.
    pub worktree_id: WorktreeId,
    /// Name of the matched item.
    pub name: String,
    /// Byte offset of the matched item within its file.
    pub offset: usize,
    /// Path of the file as stored in the database (presumably relative to the worktree
    /// root — confirm against `get_documents_by_ids`).
    pub file_path: PathBuf,
}
140
/// Operations sent to the database task spawned by `VectorStore::new`, which applies them
/// one at a time in the order they are received.
enum DbWrite {
    /// Store `indexed_file` and its documents under database worktree `worktree_id`.
    InsertFile {
        worktree_id: i64,
        indexed_file: IndexedFile,
    },
    /// Remove the file at `path` from database worktree `worktree_id`.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Look up (or create) the database id of the worktree rooted at `path` and reply with
    /// the result on `sender`.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
}
155
156impl VectorStore {
157    async fn new(
158        fs: Arc<dyn Fs>,
159        database_url: PathBuf,
160        embedding_provider: Arc<dyn EmbeddingProvider>,
161        language_registry: Arc<LanguageRegistry>,
162        mut cx: AsyncAppContext,
163    ) -> Result<ModelHandle<Self>> {
164        let database_url = Arc::new(database_url);
165
166        let db = cx
167            .background()
168            .spawn({
169                let fs = fs.clone();
170                let database_url = database_url.clone();
171                async move {
172                    if let Some(db_directory) = database_url.parent() {
173                        fs.create_dir(db_directory).await.log_err();
174                    }
175
176                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
177                    anyhow::Ok(db)
178                }
179            })
180            .await?;
181
182        Ok(cx.add_model(|cx| {
183            let (db_update_tx, db_update_rx) = channel::unbounded();
184            let _db_update_task = cx.background().spawn(async move {
185                while let Ok(job) = db_update_rx.recv().await {
186                    match job {
187                        DbWrite::InsertFile {
188                            worktree_id,
189                            indexed_file,
190                        } => {
191                            log::info!("Inserting File: {:?}", &indexed_file.path);
192                            db.insert_file(worktree_id, indexed_file).log_err();
193                        }
194                        DbWrite::Delete { worktree_id, path } => {
195                            log::info!("Deleting File: {:?}", &path);
196                            db.delete_file(worktree_id, path).log_err();
197                        }
198                        DbWrite::FindOrCreateWorktree { path, sender } => {
199                            let id = db.find_or_create_worktree(&path);
200                            sender.send(id).ok();
201                        }
202                    }
203                }
204            });
205
206            Self {
207                fs,
208                database_url,
209                db_update_tx,
210                embedding_provider,
211                language_registry,
212                projects: HashMap::new(),
213                _db_update_task,
214            }
215        }))
216    }
217
218    async fn index_file(
219        cursor: &mut QueryCursor,
220        parser: &mut Parser,
221        embedding_provider: &dyn EmbeddingProvider,
222        fs: &Arc<dyn Fs>,
223        language: Arc<Language>,
224        file_path: PathBuf,
225        mtime: SystemTime,
226    ) -> Result<IndexedFile> {
227        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
228        let embedding_config = grammar
229            .embedding_config
230            .as_ref()
231            .ok_or_else(|| anyhow!("no outline query"))?;
232
233        let content = fs.load(&file_path).await?;
234
235        parser.set_language(grammar.ts_language).unwrap();
236        let tree = parser
237            .parse(&content, None)
238            .ok_or_else(|| anyhow!("parsing failed"))?;
239
240        let mut documents = Vec::new();
241        let mut context_spans = Vec::new();
242        for mat in cursor.matches(
243            &embedding_config.query,
244            tree.root_node(),
245            content.as_bytes(),
246        ) {
247            let mut item_range = None;
248            let mut name_range = None;
249            for capture in mat.captures {
250                if capture.index == embedding_config.item_capture_ix {
251                    item_range = Some(capture.node.byte_range());
252                } else if capture.index == embedding_config.name_capture_ix {
253                    name_range = Some(capture.node.byte_range());
254                }
255            }
256
257            if let Some((item_range, name_range)) = item_range.zip(name_range) {
258                if let Some((item, name)) =
259                    content.get(item_range.clone()).zip(content.get(name_range))
260                {
261                    context_spans.push(item);
262                    documents.push(Document {
263                        name: name.to_string(),
264                        offset: item_range.start,
265                        embedding: Vec::new(),
266                    });
267                }
268            }
269        }
270
271        if !documents.is_empty() {
272            let embeddings = embedding_provider.embed_batch(context_spans).await?;
273            for (document, embedding) in documents.iter_mut().zip(embeddings) {
274                document.embedding = embedding;
275            }
276        }
277
278        return Ok(IndexedFile {
279            path: file_path,
280            mtime,
281            documents,
282        });
283    }
284
285    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
286        let (tx, rx) = oneshot::channel();
287        self.db_update_tx
288            .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
289            .unwrap();
290        async move { rx.await? }
291    }
292
293    fn add_project(
294        &mut self,
295        project: ModelHandle<Project>,
296        cx: &mut ModelContext<Self>,
297    ) -> Task<Result<()>> {
298        let worktree_scans_complete = project
299            .read(cx)
300            .worktrees(cx)
301            .map(|worktree| {
302                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
303                async move {
304                    scan_complete.await;
305                }
306            })
307            .collect::<Vec<_>>();
308        let worktree_db_ids = project
309            .read(cx)
310            .worktrees(cx)
311            .map(|worktree| {
312                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
313            })
314            .collect::<Vec<_>>();
315
316        let fs = self.fs.clone();
317        let language_registry = self.language_registry.clone();
318        let embedding_provider = self.embedding_provider.clone();
319        let database_url = self.database_url.clone();
320        let db_update_tx = self.db_update_tx.clone();
321
322        cx.spawn(|this, mut cx| async move {
323            futures::future::join_all(worktree_scans_complete).await;
324
325            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
326
327            if let Some(db_directory) = database_url.parent() {
328                fs.create_dir(db_directory).await.log_err();
329            }
330
331            let worktrees = project.read_with(&cx, |project, cx| {
332                project
333                    .worktrees(cx)
334                    .map(|worktree| worktree.read(cx).snapshot())
335                    .collect::<Vec<_>>()
336            });
337
338            // Here we query the worktree ids, and yet we dont have them elsewhere
339            // We likely want to clean up these datastructures
340            let (mut worktree_file_times, db_ids_by_worktree_id) = cx
341                .background()
342                .spawn({
343                    let worktrees = worktrees.clone();
344                    async move {
345                        let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
346                        let mut db_ids_by_worktree_id = HashMap::new();
347                        let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
348                            HashMap::new();
349                        for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
350                            let db_id = db_id?;
351                            db_ids_by_worktree_id.insert(worktree.id(), db_id);
352                            file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
353                        }
354                        anyhow::Ok((file_times, db_ids_by_worktree_id))
355                    }
356                })
357                .await?;
358
359            let (paths_tx, paths_rx) =
360                channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>();
361            cx.background()
362                .spawn({
363                    let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
364                    let db_update_tx = db_update_tx.clone();
365                    let language_registry = language_registry.clone();
366                    let paths_tx = paths_tx.clone();
367                    async move {
368                        for worktree in worktrees.into_iter() {
369                            let mut file_mtimes =
370                                worktree_file_times.remove(&worktree.id()).unwrap();
371                            for file in worktree.files(false, 0) {
372                                let absolute_path = worktree.absolutize(&file.path);
373
374                                if let Ok(language) = language_registry
375                                    .language_for_file(&absolute_path, None)
376                                    .await
377                                {
378                                    if language
379                                        .grammar()
380                                        .and_then(|grammar| grammar.embedding_config.as_ref())
381                                        .is_none()
382                                    {
383                                        continue;
384                                    }
385
386                                    let path_buf = file.path.to_path_buf();
387                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
388                                    let already_stored = stored_mtime
389                                        .map_or(false, |existing_mtime| {
390                                            existing_mtime == file.mtime
391                                        });
392
393                                    if !already_stored {
394                                        paths_tx
395                                            .try_send((
396                                                db_ids_by_worktree_id[&worktree.id()],
397                                                path_buf,
398                                                language,
399                                                file.mtime,
400                                            ))
401                                            .unwrap();
402                                    }
403                                }
404                            }
405                            for file in file_mtimes.keys() {
406                                db_update_tx
407                                    .try_send(DbWrite::Delete {
408                                        worktree_id: db_ids_by_worktree_id[&worktree.id()],
409                                        path: file.to_owned(),
410                                    })
411                                    .unwrap();
412                            }
413                        }
414                    }
415                })
416                .detach();
417
418            cx.background()
419                .scoped(|scope| {
420                    for _ in 0..cx.background().num_cpus() {
421                        scope.spawn(async {
422                            let mut parser = Parser::new();
423                            let mut cursor = QueryCursor::new();
424                            while let Ok((worktree_id, file_path, language, mtime)) =
425                                paths_rx.recv().await
426                            {
427                                if let Some(indexed_file) = Self::index_file(
428                                    &mut cursor,
429                                    &mut parser,
430                                    embedding_provider.as_ref(),
431                                    &fs,
432                                    language,
433                                    file_path,
434                                    mtime,
435                                )
436                                .await
437                                .log_err()
438                                {
439                                    db_update_tx
440                                        .try_send(DbWrite::InsertFile {
441                                            worktree_id,
442                                            indexed_file,
443                                        })
444                                        .unwrap();
445                                }
446                            }
447                        });
448                    }
449                })
450                .await;
451
452            this.update(&mut cx, |this, cx| {
453                let _subscription = cx.subscribe(&project, |this, project, event, cx| {
454                    if let Some(project_state) = this.projects.get(&project.downgrade()) {
455                        let worktree_db_ids = project_state.worktree_db_ids.clone();
456
457                        if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
458                        {
459                            // Iterate through changes
460                            let language_registry = this.language_registry.clone();
461
462                            let db =
463                                VectorDatabase::new(this.database_url.to_string_lossy().into());
464                            if db.is_err() {
465                                return;
466                            }
467                            let db = db.unwrap();
468
469                            let worktree_db_id: Option<i64> = {
470                                let mut found_db_id = None;
471                                for (w_id, db_id) in worktree_db_ids.into_iter() {
472                                    if &w_id == worktree_id {
473                                        found_db_id = Some(db_id);
474                                    }
475                                }
476
477                                found_db_id
478                            };
479
480                            if worktree_db_id.is_none() {
481                                return;
482                            }
483                            let worktree_db_id = worktree_db_id.unwrap();
484
485                            let file_mtimes = db.get_file_mtimes(worktree_db_id);
486                            if file_mtimes.is_err() {
487                                return;
488                            }
489
490                            let file_mtimes = file_mtimes.unwrap();
491
492                            smol::block_on(async move {
493                                for change in changes.into_iter() {
494                                    let change_path = change.0.clone();
495                                    log::info!("Change: {:?}", &change_path);
496                                    if let Ok(language) = language_registry
497                                        .language_for_file(&change_path.to_path_buf(), None)
498                                        .await
499                                    {
500                                        if language
501                                            .grammar()
502                                            .and_then(|grammar| grammar.embedding_config.as_ref())
503                                            .is_none()
504                                        {
505                                            continue;
506                                        }
507                                        log::info!("Language found: {:?}: ", language.name());
508
509                                        // TODO: Make this a bit more defensive
510                                        let modified_time =
511                                            change_path.metadata().unwrap().modified().unwrap();
512                                        let existing_time =
513                                            file_mtimes.get(&change_path.to_path_buf());
514                                        let already_stored =
515                                            existing_time.map_or(false, |existing_time| {
516                                                if &modified_time != existing_time
517                                                    && existing_time.elapsed().unwrap().as_secs()
518                                                        > 30
519                                                {
520                                                    false
521                                                } else {
522                                                    true
523                                                }
524                                            });
525
526                                        if !already_stored {
527                                            log::info!("Need to reindex: {:?}", &change_path);
528                                            // paths_tx
529                                            //     .try_send((
530                                            //         worktree_db_id,
531                                            //         change_path.to_path_buf(),
532                                            //         language,
533                                            //         modified_time,
534                                            //     ))
535                                            //     .unwrap();
536                                        }
537                                    }
538                                }
539                            })
540                        }
541                    }
542                });
543
544                this.projects.insert(
545                    project.downgrade(),
546                    ProjectState {
547                        worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
548                        _subscription,
549                    },
550                );
551            });
552
553            log::info!("Semantic Indexing Complete!");
554
555            anyhow::Ok(())
556        })
557    }
558
559    pub fn search(
560        &mut self,
561        project: ModelHandle<Project>,
562        phrase: String,
563        limit: usize,
564        cx: &mut ModelContext<Self>,
565    ) -> Task<Result<Vec<SearchResult>>> {
566        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
567            state
568        } else {
569            return Task::ready(Err(anyhow!("project not added")));
570        };
571
572        let worktree_db_ids = project
573            .read(cx)
574            .worktrees(cx)
575            .filter_map(|worktree| {
576                let worktree_id = worktree.read(cx).id();
577                project_state
578                    .worktree_db_ids
579                    .iter()
580                    .find_map(|(id, db_id)| {
581                        if *id == worktree_id {
582                            Some(*db_id)
583                        } else {
584                            None
585                        }
586                    })
587            })
588            .collect::<Vec<_>>();
589
590        log::info!("Searching for: {:?}", phrase);
591
592        let embedding_provider = self.embedding_provider.clone();
593        let database_url = self.database_url.clone();
594        cx.spawn(|this, cx| async move {
595            let documents = cx
596                .background()
597                .spawn(async move {
598                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;
599
600                    let phrase_embedding = embedding_provider
601                        .embed_batch(vec![&phrase])
602                        .await?
603                        .into_iter()
604                        .next()
605                        .unwrap();
606
607                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
608                    database.for_each_document(&worktree_db_ids, |id, embedding| {
609                        let similarity = dot(&embedding.0, &phrase_embedding);
610                        let ix = match results.binary_search_by(|(_, s)| {
611                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
612                        }) {
613                            Ok(ix) => ix,
614                            Err(ix) => ix,
615                        };
616                        results.insert(ix, (id, similarity));
617                        results.truncate(limit);
618                    })?;
619
620                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
621                    database.get_documents_by_ids(&ids)
622                })
623                .await?;
624
625            this.read_with(&cx, |this, _| {
626                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
627                    state
628                } else {
629                    return Err(anyhow!("project not added"));
630                };
631
632                Ok(documents
633                    .into_iter()
634                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
635                        let worktree_id =
636                            project_state
637                                .worktree_db_ids
638                                .iter()
639                                .find_map(|(id, db_id)| {
640                                    if *db_id == worktree_db_id {
641                                        Some(*id)
642                                    } else {
643                                        None
644                                    }
645                                })?;
646                        Some(SearchResult {
647                            worktree_id,
648                            name,
649                            offset,
650                            file_path,
651                        })
652                    })
653                    .collect())
654            })
655        })
656    }
657}
658
// Makes `VectorStore` a gpui model; it emits no events of its own.
impl Entity for VectorStore {
    type Event = ();
}
662
663fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
664    let len = vec_a.len();
665    assert_eq!(len, vec_b.len());
666
667    let mut result = 0.0;
668    unsafe {
669        matrixmultiply::sgemm(
670            1,
671            len,
672            1,
673            1.0,
674            vec_a.as_ptr(),
675            len as isize,
676            1,
677            vec_b.as_ptr(),
678            1,
679            len as isize,
680            0.0,
681            &mut result as *mut f32,
682            1,
683            1,
684        );
685    }
686    result
687}