vector_store.rs

mod db;
mod embedding;
mod modal;
mod search;

#[cfg(test)]
mod vector_store_tests;

use anyhow::{anyhow, Result};
use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use project::{Fs, Project};
use smol::channel;
use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
use tree_sitter::{Parser, QueryCursor};
use util::{http::HttpClient, ResultExt, TryFutureExt};
use workspace::{Workspace, WorkspaceCreated};

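/// A single embeddable span extracted from a source file: the byte offset where the
/// captured item starts, the item's name, and its embedding vector (filled in after
/// the batch embedding call).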
#[derive(Debug)]
pub struct Document {
    pub offset: usize,
    pub name: String,
    pub embedding: Vec<f32>,
}

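/// Registers the `VectorStore` model, indexes local projects as their workspaces are
/// created, and wires up the `Toggle` action that opens the semantic search modal.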
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    let vector_store = cx.add_model(|cx| {
        VectorStore::new(
            fs,
            VECTOR_DB_URL.to_string(),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
        )
    });

    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();

    cx.add_action({
        move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
            let vector_store = vector_store.clone();
            workspace.toggle_modal(cx, |workspace, cx| {
                let project = workspace.project().clone();
                let workspace = cx.weak_handle();
                cx.add_view(|cx| {
                    SemanticSearch::new(
                        SemanticSearchDelegate::new(workspace, project, vector_store),
                        cx,
                    )
                })
            })
        }
    });
    SemanticSearch::init(cx);
    // cx.add_action({
    //     let vector_store = vector_store.clone();
    //     move |workspace: &mut Workspace, _: &TestSearch, cx: &mut ViewContext<Workspace>| {
    //         let t0 = std::time::Instant::now();
    //         let task = vector_store.update(cx, |store, cx| {
    //             store.search("compute embeddings for all of the symbols in the codebase and write them to a database".to_string(), 10, cx)
    //         });

    //         cx.spawn(|this, cx| async move {
    //             let results = task.await?;
    //             let duration = t0.elapsed();

    //             println!("search took {:?}", duration);
    //             println!("results {:?}", results);

    //             anyhow::Ok(())
    //         }).detach()
    //     }
    // });
}

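/// The result of indexing a single file: the file's path, the SHA-1 of its contents
/// (used to skip re-indexing unchanged files), and the documents extracted from it.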
#[derive(Debug)]
pub struct IndexedFile {
    path: PathBuf,
    sha1: FileSha1,
    documents: Vec<Document>,
}

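/// Owns the handles needed to build and query the semantic index: the file system, the
/// vector database location, the embedding provider, and the language registry used to
/// pick a tree-sitter grammar for each file.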
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<str>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
}

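/// A single hit returned by `VectorStore::search`: the matched document's name, its
/// byte offset within the file, and the file's path.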
#[derive(Debug)]
pub struct SearchResult {
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}

impl VectorStore {
    fn new(
        fs: Arc<dyn Fs>,
        database_url: String,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: database_url.into(),
            embedding_provider,
            language_registry,
        }
    }

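    /// Parses `content` with the language's tree-sitter grammar, runs its outline query
    /// to collect item/name capture pairs, embeds each captured item in a single batch,
    /// and returns an `IndexedFile` tagged with the content's SHA-1.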
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language: Arc<Language>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline_config = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        parser.set_language(grammar.ts_language).unwrap();
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == outline_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == outline_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        let embeddings = embedding_provider.embed_batch(context_spans).await?;
        for (document, embedding) in documents.iter_mut().zip(embeddings) {
            document.embedding = embedding;
        }

        let sha1 = FileSha1::from_str(content);

        Ok(IndexedFile {
            path: file_path,
            sha1,
            documents,
        })
    }

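    /// Indexes every worktree in the project: waits for worktree scans to finish, opens
    /// the vector database, compares stored file hashes against current file contents,
    /// fans file parsing out across background threads, and streams the resulting
    /// `IndexedFile`s to a single database-writer task.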
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                    log::info!("worktree scan completed");
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|_, cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            // TODO: remove this after fixing the bug in scan_complete
            cx.background()
                .timer(std::time::Duration::from_secs(3))
                .await;

            let db = VectorDatabase::new(&database_url)?;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let worktree_root_paths = worktrees
                .iter()
                .map(|worktree| worktree.abs_path().clone())
                .collect::<Vec<_>>();

            // Here we query the worktree ids, yet we don't have them available elsewhere.
            // We likely want to clean up these data structures.
            let (db, worktree_hashes, worktree_ids) = cx
                .background()
                .spawn(async move {
                    let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
                    let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
                    for worktree_root_path in worktree_root_paths {
                        let worktree_id =
                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
                        worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
                        hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
                    }
                    anyhow::Ok((db, hashes, worktree_ids))
                })
                .await?;

            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
                            let file_hashes = &worktree_hashes[&worktree_id];
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    if language.name().as_ref() != "Rust" {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        log::info!("loaded file: {absolute_path:?}");

                                        let path_buf = file.path.to_path_buf();
                                        let already_stored = file_hashes
                                            .get(&path_buf)
                                            .map_or(false, |existing_hash| {
                                                existing_hash.equals(&content)
                                            });

                                        if !already_stored {
                                            log::info!(
                                                "File Changed (Sending to Parse): {:?}",
                                                &path_buf
                                            );
                                            paths_tx
                                                .try_send((
                                                    worktree_id,
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                        }
                    }
                })
                .detach();

            let db_write_task = cx.background().spawn(
                async move {
                    // Write each indexed file to the database as it arrives.
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    // All of the below is for testing; it should be removed once we find
                    // an appropriate place to evaluate our search.

                    // let queries = vec![
                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
                    //         "compute an outline view of all of the symbols in a buffer",
                    //         "scan a directory on the file system and load all of its children into an in-memory snapshot",
                    // ];
                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;

                    // let t2 = Instant::now();
                    // let documents = db.get_documents().unwrap();
                    // let files = db.get_files().unwrap();
                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());

                    // let t1 = Instant::now();
                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
                    // for (idx, embed) in embeddings.into_iter().enumerate() {
                    //     let t0 = Instant::now();
                    //     println!("\nQuery: {:?}", queries[idx]);
                    //     let results = bfs.top_k_search(&embed, 5).await;
                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
                    //     for (id, distance) in results {
                    //         println!("");
                    //         println!("   distance: {:?}", distance);
                    //         println!("   document: {:?}", documents[&id].name);
                    //         println!("   path:     {:?}", files[&documents[&id].file_id].relative_path);
                    //     }

                    // }

                    anyhow::Ok(())
                }
                .log_err(),
            );

            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            drop(indexed_files_tx);

            db_write_task.await;
            anyhow::Ok(())
        })
    }

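    /// Embeds `phrase` and scans the stored document embeddings, returning the `limit`
    /// most similar documents by dot-product similarity.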
    pub fn search(
        &mut self,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.background().spawn(async move {
            let database = VectorDatabase::new(database_url.as_ref())?;

            let phrase_embedding = embedding_provider
                .embed_batch(vec![&phrase])
                .await?
                .into_iter()
                .next()
                .unwrap();

            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
            database.for_each_document(0, |id, embedding| {
                let similarity = dot(&embedding.0, &phrase_embedding);
                let ix = match results.binary_search_by(|(_, s)| {
                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                }) {
                    Ok(ix) => ix,
                    Err(ix) => ix,
                };
                results.insert(ix, (id, similarity));
                results.truncate(limit);
            })?;

            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
            let documents = database.get_documents_by_ids(&ids)?;

            anyhow::Ok(
                documents
                    .into_iter()
                    .map(|(file_path, offset, name)| SearchResult {
                        name,
                        offset,
                        file_path,
                    })
                    .collect(),
            )
        })
    }
}

impl Entity for VectorStore {
    type Event = ();
}

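/// Computes the dot product of two equal-length vectors by treating them as a 1×len and
/// a len×1 matrix and multiplying them with `matrixmultiply::sgemm`.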
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}