vector_store.rs

mod db;
mod embedding;
mod parsing;
mod search;

#[cfg(test)]
mod vector_store_tests;

use anyhow::{anyhow, Result};
use db::{VectorDatabase, VECTOR_DB_URL};
use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
use language::LanguageRegistry;
use parsing::Document;
use project::{Fs, Project};
use search::{BruteForceSearch, VectorSearch};
use smol::channel;
use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant};
use tree_sitter::{Parser, QueryCursor};
use util::{http::HttpClient, ResultExt, TryFutureExt};
use workspace::WorkspaceCreated;
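/// Creates the global `VectorStore` model and begins indexing the local project
/// of each workspace that is created afterwards.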
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    let vector_store = cx.add_model(|cx| {
        VectorStore::new(
            fs,
            VECTOR_DB_URL.to_string(),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
        )
    });

    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();
}
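/// A file that has been parsed and embedded: the extracted outline items live in
/// `documents`. The `sha1` field is not yet populated.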
#[derive(Debug)]
pub struct IndexedFile {
    path: PathBuf,
    sha1: String,
    documents: Vec<Document>,
}

// struct SearchResult {
//     path: PathBuf,
//     offset: usize,
//     name: String,
//     distance: f32,
// }
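/// Model that owns the vector database location, the embedding provider, and the
/// language registry used to index projects and answer semantic searches.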
struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<str>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
}
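/// A single semantic search hit: the name of the matched outline item, its byte
/// offset, and the path of the file it was found in.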
pub struct SearchResult {
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}

impl VectorStore {
    fn new(
        fs: Arc<dyn Fs>,
        database_url: String,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: database_url.into(),
            embedding_provider,
            language_registry,
        }
    }
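    /// Parses `content` with the language's outline query, embeds every captured
    /// item through `embedding_provider`, and returns the resulting `IndexedFile`.
    /// Only Rust files are supported for now.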
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language_registry: &Arc<LanguageRegistry>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        let language = language_registry
            .language_for_file(&file_path, None)
            .await?;

        if language.name().as_ref() != "Rust" {
            Err(anyhow!("unsupported language"))?;
        }

        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline_config = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        parser.set_language(grammar.ts_language).unwrap();
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == outline_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == outline_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        let embeddings = embedding_provider.embed_batch(context_spans).await?;
        for (document, embedding) in documents.iter_mut().zip(embeddings) {
            document.embedding = embedding;
        }

        Ok(IndexedFile {
            path: file_path,
            sha1: String::new(),
            documents,
        })
    }
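    /// Waits for all of the project's worktree scans to complete, then loads,
    /// parses, and embeds every file on background threads and writes the
    /// results to the vector database.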
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete())
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|_, cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let db = VectorDatabase::new(&database_url)?;
            let worktree_root_paths = worktrees
                .iter()
                .map(|worktree| worktree.abs_path().clone())
                .collect::<Vec<_>>();
            let (db, file_hashes) = cx
                .background()
                .spawn(async move {
                    let mut hashes = Vec::new();
                    for worktree_root_path in worktree_root_paths {
                        let worktree_id =
                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
                        hashes.push((worktree_id, db.get_file_hashes(worktree_id)?));
                    }
                    anyhow::Ok((db, hashes))
                })
                .await?;
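            // Indexing pipeline: one task loads file contents into `paths_tx`, a pool of
            // background workers parses and embeds them, and a single writer task persists
            // each `IndexedFile` it receives on `indexed_files_rx`.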
            let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);
                                if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                    paths_tx.try_send((0, absolute_path, content)).unwrap();
                                }
                            }
                        }
                    }
                })
                .detach();

            let db_write_task = cx.background().spawn(
                async move {
                    // Persist each indexed file as the parsing/embedding workers produce it.
                    while let Ok(indexed_file) = indexed_files_rx.recv().await {
                        db.insert_file(indexed_file).log_err();
                    }

                    // ALL OF THE BELOW IS FOR TESTING,
                    // This should be removed once we find an appropriate place to evaluate our search.
                    // let queries = vec![
                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
                    //         "compute an outline view of all of the symbols in a buffer",
                    //         "scan a directory on the file system and load all of its children into an in-memory snapshot",
                    // ];
                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;

                    // let t2 = Instant::now();
                    // let documents = db.get_documents().unwrap();
                    // let files = db.get_files().unwrap();
                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());

                    // let t1 = Instant::now();
                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
                    // for (idx, embed) in embeddings.into_iter().enumerate() {
                    //     let t0 = Instant::now();
                    //     println!("\nQuery: {:?}", queries[idx]);
                    //     let results = bfs.top_k_search(&embed, 5).await;
                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
                    //     for (id, distance) in results {
                    //         println!("");
                    //         println!("   distance: {:?}", distance);
                    //         println!("   document: {:?}", documents[&id].name);
                    //         println!("   path:     {:?}", files[&documents[&id].file_id].relative_path);
                    //     }

                    // }

                    anyhow::Ok(())
                }
                .log_err(),
            );
            // NOTE: indexing currently uses `DummyEmbeddings` rather than the injected
            // `embedding_provider`, so no real embedding requests are made yet.
            let provider = DummyEmbeddings {};
            // let provider = OpenAIEmbeddings { client };
            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    &provider,
                                    &language_registry,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx.try_send(indexed_file).unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            drop(indexed_files_tx);

            db_write_task.await;
            anyhow::Ok(())
        })
    }
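    /// Embeds `phrase` and ranks the stored documents by dot-product similarity.
    /// Hydrating the ranked ids into `SearchResult`s is still unimplemented, so an
    /// empty vector is returned for now.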
    pub fn search(
        &mut self,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let database = VectorDatabase::new(database_url.as_ref())?;

            let phrase_embedding = embedding_provider
                .embed_batch(vec![phrase.as_str()])
                .await?
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("no embedding returned for search phrase"))?;

            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
            database.for_each_document(0, |id, embedding| {
                // Keep the `limit` most similar documents, ordered from most to
                // least similar by dot product with the phrase embedding.
                let similarity = dot(&embedding.0, &phrase_embedding);
                let ix = match results.binary_search_by(|(_, s)| {
                    similarity.partial_cmp(s).unwrap_or(Ordering::Equal)
                }) {
                    Ok(ix) => ix,
                    Err(ix) => ix,
                };

                results.insert(ix, (id, similarity));
                results.truncate(limit);
            })?;

            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
            // let documents = database.get_documents_by_ids(ids)?;

            // let search_provider = cx
            //     .background()
            //     .spawn(async move { BruteForceSearch::load(&database) })
            //     .await?;

            // let results = search_provider.top_k_search(&embedding, limit))

            anyhow::Ok(vec![])
        })
    }
}

impl Entity for VectorStore {
    type Event = ();
}
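/// Dot product of two equal-length vectors, computed as a 1×len by len×1 matrix
/// product via `matrixmultiply::sgemm`.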
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}
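// A minimal sanity check for `dot`, sketched here as an illustration: it assumes
// only the standard test harness and compares the sgemm-based result against a
// naive loop. The module name is arbitrary and independent of `vector_store_tests`.
#[cfg(test)]
mod dot_tests {
    use super::dot;

    #[test]
    fn dot_matches_naive_implementation() {
        let a = vec![1.0_f32, 2.0, 3.0, 4.0];
        let b = vec![0.5_f32, -1.0, 2.0, 0.25];
        let naive: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        assert!((dot(&a, &b) - naive).abs() < 1e-6);
    }
}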