vector_store.rs

  1mod db;
  2use anyhow::Result;
  3use db::VectorDatabase;
  4use gpui::{AppContext, Entity, ModelContext, ModelHandle};
  5use language::LanguageRegistry;
  6use project::{Fs, Project};
  7use rand::Rng;
  8use smol::channel;
  9use std::{path::PathBuf, sync::Arc, time::Instant};
 10use util::ResultExt;
 11use workspace::WorkspaceCreated;
 12
 13pub fn init(fs: Arc<dyn Fs>, language_registry: Arc<LanguageRegistry>, cx: &mut AppContext) {
 14    let vector_store = cx.add_model(|cx| VectorStore::new(fs, language_registry));
 15
 16    cx.subscribe_global::<WorkspaceCreated, _>({
 17        let vector_store = vector_store.clone();
 18        move |event, cx| {
 19            let workspace = &event.0;
 20            if let Some(workspace) = workspace.upgrade(cx) {
 21                let project = workspace.read(cx).project().clone();
 22                if project.read(cx).is_local() {
 23                    vector_store.update(cx, |store, cx| {
 24                        store.add_project(project, cx);
 25                    });
 26                }
 27            }
 28        }
 29    })
 30    .detach();
 31}
 32
/// A single embedded entry belonging to an [`IndexedFile`].
///
/// The only place in this file that builds a `Document` is the placeholder
/// generator in `VectorStore::index_file`.
#[derive(Debug, sqlx::FromRow)]
struct Document {
    /// Position of the entry inside its file. Presumably a byte offset; TODO confirm
    /// (the placeholder data always uses `0`).
    offset: usize,
    /// Name attached to the entry (the placeholder data uses `"test symbol"`).
    name: String,
    /// Embedding vector for the entry (the placeholder data uses 768 components).
    embedding: Vec<f32>,
}
 39
/// The result of indexing one file.
///
/// Values are produced by `VectorStore::index_file` and handed to
/// `VectorDatabase::insert_file` by the background work in `VectorStore::add_project`.
#[derive(Debug, sqlx::FromRow)]
pub struct IndexedFile {
    /// Path of the indexed file (`add_project` supplies absolutized worktree paths).
    path: PathBuf,
    /// Presumably a SHA-1 digest of the file contents; TODO confirm. The current
    /// placeholder implementation stores a hard-coded string here.
    sha1: String,
    /// The documents produced for this file.
    documents: Vec<Document>,
}
 46
/// NOTE(review): nothing in this file constructs or reads `SearchResult`; it is
/// presumably the shape of a match returned by a future similarity query. Confirm
/// against the query code once it exists.
struct SearchResult {
    /// Path of the file the result refers to.
    path: PathBuf,
    /// Position within that file; mirrors `Document::offset`.
    offset: usize,
    /// Name of the matched entry; mirrors `Document::name`.
    name: String,
    /// Distance score for the match. Presumably smaller means closer; TODO confirm.
    distance: f32,
}
 53
/// Model that indexes the files of projects added via `add_project`.
///
/// It is created once in `init` through `cx.add_model` and only holds the shared
/// handles that the indexing work needs.
struct VectorStore {
    /// File system handle; cloned into the background indexing work.
    fs: Arc<dyn Fs>,
    /// Language registry; cloned into the background indexing work.
    language_registry: Arc<LanguageRegistry>,
}
 58
 59impl VectorStore {
 60    fn new(fs: Arc<dyn Fs>, language_registry: Arc<LanguageRegistry>) -> Self {
 61        Self {
 62            fs,
 63            language_registry,
 64        }
 65    }
 66
 67    async fn index_file(
 68        fs: &Arc<dyn Fs>,
 69        language_registry: &Arc<LanguageRegistry>,
 70        file_path: PathBuf,
 71    ) -> Result<IndexedFile> {
 72        // This is creating dummy documents to test the database writes.
 73        let mut documents = vec![];
 74        let mut rng = rand::thread_rng();
 75        let rand_num_of_documents: u8 = rng.gen_range(0..200);
 76        for _ in 0..rand_num_of_documents {
 77            let doc = Document {
 78                offset: 0,
 79                name: "test symbol".to_string(),
 80                embedding: vec![0.32 as f32; 768],
 81            };
 82            documents.push(doc);
 83        }
 84
 85        return Ok(IndexedFile {
 86            path: file_path,
 87            sha1: "asdfasdfasdf".to_string(),
 88            documents,
 89        });
 90    }
 91
 92    fn add_project(&mut self, project: ModelHandle<Project>, cx: &mut ModelContext<Self>) {
 93        let worktree_scans_complete = project
 94            .read(cx)
 95            .worktrees(cx)
 96            .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete())
 97            .collect::<Vec<_>>();
 98
 99        let fs = self.fs.clone();
100        let language_registry = self.language_registry.clone();
101
102        cx.spawn(|this, cx| async move {
103            futures::future::join_all(worktree_scans_complete).await;
104
105            let worktrees = project.read_with(&cx, |project, cx| {
106                project
107                    .worktrees(cx)
108                    .map(|worktree| worktree.read(cx).snapshot())
109                    .collect::<Vec<_>>()
110            });
111
112            let (paths_tx, paths_rx) = channel::unbounded::<PathBuf>();
113            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
114            cx.background()
115                .spawn(async move {
116                    for worktree in worktrees {
117                        for file in worktree.files(false, 0) {
118                            paths_tx.try_send(worktree.absolutize(&file.path)).unwrap();
119                        }
120                    }
121                })
122                .detach();
123
124            cx.background()
125                .spawn(async move {
126                    // Initialize Database, creates database and tables if not exists
127                    VectorDatabase::initialize_database().await.log_err();
128                    while let Ok(indexed_file) = indexed_files_rx.recv().await {
129                        VectorDatabase::insert_file(indexed_file).await.log_err();
130                    }
131                })
132                .detach();
133
134            cx.background()
135                .scoped(|scope| {
136                    for _ in 0..cx.background().num_cpus() {
137                        scope.spawn(async {
138                            while let Ok(file_path) = paths_rx.recv().await {
139                                if let Some(indexed_file) =
140                                    Self::index_file(&fs, &language_registry, file_path)
141                                        .await
142                                        .log_err()
143                                {
144                                    indexed_files_tx.try_send(indexed_file).unwrap();
145                                }
146                            }
147                        });
148                    }
149                })
150                .await;
151        })
152        .detach();
153    }
154}
155
/// Implements gpui's `Entity` for `VectorStore`, which `init` stores as a model
/// via `cx.add_model`.
impl Entity for VectorStore {
    /// Unit event type; presumably the store emits no meaningful events. Confirm
    /// if subscribers are added later.
    type Event = ();
}