//! vector_store.rs

  1mod db;
  2mod embedding;
  3mod search;
  4
  5#[cfg(test)]
  6mod vector_store_tests;
  7
  8use anyhow::{anyhow, Result};
  9use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
 10use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 11use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
 12use language::{Language, LanguageRegistry};
 13use project::{Fs, Project};
 14use smol::channel;
 15use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
 16use tree_sitter::{Parser, QueryCursor};
 17use util::{http::HttpClient, ResultExt, TryFutureExt};
 18use workspace::{Workspace, WorkspaceCreated};
 19
// Declares the `semantic_search::TestSearch` action, bound below in `init` to
// run a hard-coded debug search against the vector database.
actions!(semantic_search, [TestSearch]);
 21
/// A single embeddable item (symbol/outline entry) extracted from a file.
#[derive(Debug)]
pub struct Document {
    // Byte offset of the item's start within the file content.
    pub offset: usize,
    // Name captured by the language's outline query.
    pub name: String,
    // Embedding vector; left empty at construction and filled in after
    // `embed_batch` returns (see `VectorStore::index_file`).
    pub embedding: Vec<f32>,
}
 28
/// Wires the vector store into the app: creates the `VectorStore` model,
/// starts indexing whenever a local-project workspace is created, and
/// registers the `TestSearch` debug action.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    let vector_store = cx.add_model(|cx| {
        VectorStore::new(
            fs,
            VECTOR_DB_URL.to_string(),
            // Embeddings are computed via the OpenAI API using the app's HTTP client.
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
        )
    });

    // Kick off indexing for each newly created workspace's project, but only
    // for local projects (remote projects have no local filesystem to scan).
    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        // Fire-and-forget: errors are surfaced inside the task.
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();

    // Debug action: run a canned query against the index and print timing +
    // results to stdout.
    cx.add_action({
        let vector_store = vector_store.clone();
        move |workspace: &mut Workspace, _: &TestSearch, cx: &mut ViewContext<Workspace>| {
            let t0 = std::time::Instant::now();
            let task = vector_store.update(cx, |store, cx| {
                store.search("compute embeddings for all of the symbols in the codebase and write them to a database".to_string(), 10, cx)
            });

            cx.spawn(|this, cx| async move {
                let results = task.await?;
                let duration = t0.elapsed();

                println!("search took {:?}", duration);
                println!("results {:?}", results);

                anyhow::Ok(())
            }).detach()
        }
    });
}
 82
/// A file that has been parsed and embedded, ready to be written to the
/// vector database.
#[derive(Debug)]
pub struct IndexedFile {
    // Worktree-relative path of the file.
    path: PathBuf,
    // Content hash used to skip re-indexing unchanged files.
    sha1: FileSha1,
    // One `Document` per outline item extracted from the file.
    documents: Vec<Document>,
}
 89
/// Model that owns the semantic-indexing pipeline: filesystem access, the
/// embedding provider, the language registry (for outline queries), and the
/// location of the SQLite-backed vector database.
struct VectorStore {
    fs: Arc<dyn Fs>,
    // Connection string/path handed to `VectorDatabase::new` on each task.
    database_url: Arc<str>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
}
 96
/// A single hit returned by `VectorStore::search`.
#[derive(Debug)]
pub struct SearchResult {
    // Name of the matched document (symbol).
    pub name: String,
    // Byte offset of the match within its file.
    pub offset: usize,
    pub file_path: PathBuf,
}
103
impl VectorStore {
    /// Creates a store; no I/O happens until `add_project` or `search`.
    fn new(
        fs: Arc<dyn Fs>,
        database_url: String,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: database_url.into(),
            embedding_provider,
            language_registry,
        }
    }

    /// Parses `content` with tree-sitter, extracts outline items via the
    /// grammar's outline query, embeds all item texts in one batch, and
    /// returns the resulting `IndexedFile`.
    ///
    /// Errors if the language has no grammar or outline query, or if parsing
    /// or embedding fails.
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language: Arc<Language>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline_config = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        // set_language only fails on a tree-sitter version mismatch, which
        // would be a build-time bug — hence the unwrap.
        parser.set_language(grammar.ts_language).unwrap();
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        // `documents[i]` corresponds to `context_spans[i]`; the parallel vecs
        // are zipped back together after the batch embedding call below.
        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == outline_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == outline_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            // Only matches with both an item and a name capture are indexed;
            // `content.get` guards against ranges that are not valid UTF-8
            // boundaries.
            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        // One provider round-trip for the whole file; results come back in
        // the same order the spans were pushed.
        let embeddings = embedding_provider.embed_batch(context_spans).await?;
        for (document, embedding) in documents.iter_mut().zip(embeddings) {
            document.embedding = embedding;
        }

        let sha1 = FileSha1::from_str(content);

        return Ok(IndexedFile {
            path: file_path,
            sha1,
            documents,
        });
    }

    /// Indexes every file of every worktree in `project`:
    /// 1. waits for worktree scans to complete,
    /// 2. looks up (or creates) worktree rows and loads stored file hashes,
    /// 3. streams changed files through a channel to a pool of parser tasks,
    /// 4. writes the resulting `IndexedFile`s to the database.
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                    log::info!("worktree scan completed");
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|_, cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            // TODO: remove this after fixing the bug in scan_complete
            cx.background()
                .timer(std::time::Duration::from_secs(3))
                .await;

            let db = VectorDatabase::new(&database_url)?;

            // Snapshot the worktrees so the background tasks below can walk
            // them without touching the foreground model again.
            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let worktree_root_paths = worktrees
                .iter()
                .map(|worktree| worktree.abs_path().clone())
                .collect::<Vec<_>>();

            // Here we query the worktree ids, and yet we dont have them elsewhere
            // We likely want to clean up these datastructures
            // The `db` handle is moved into the background task and handed
            // back via the tuple so the writer task below can keep using it.
            let (db, worktree_hashes, worktree_ids) = cx
                .background()
                .spawn(async move {
                    let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
                    let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
                    for worktree_root_path in worktree_root_paths {
                        let worktree_id =
                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
                        worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
                        hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
                    }
                    anyhow::Ok((db, hashes, worktree_ids))
                })
                .await?;

            // Producer -> parser pool -> DB writer pipeline. Unbounded
            // channels: the producer never blocks, and `try_send` cannot fail
            // while a receiver is alive.
            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            // Producer: walk every file, skip non-Rust and unchanged files,
            // and send (worktree, path, content, language) to the parsers.
            // `paths_tx` is moved in and dropped when this task finishes,
            // which lets the parser loops below terminate.
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
                            let file_hashes = &worktree_hashes[&worktree_id];
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // NOTE(review): indexing is restricted to
                                    // Rust files for now.
                                    if language.name().as_ref() != "Rust" {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        log::info!("loaded file: {absolute_path:?}");

                                        // Skip files whose stored hash still
                                        // matches the current content.
                                        let path_buf = file.path.to_path_buf();
                                        let already_stored = file_hashes
                                            .get(&path_buf)
                                            .map_or(false, |existing_hash| {
                                                existing_hash.equals(&content)
                                            });

                                        if !already_stored {
                                            log::info!(
                                                "File Changed (Sending to Parse): {:?}",
                                                &path_buf
                                            );
                                            paths_tx
                                                .try_send((
                                                    worktree_id,
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                        }
                    }
                })
                .detach();

            // Writer: drains indexed files into the DB until the channel is
            // closed (all senders dropped).
            let db_write_task = cx.background().spawn(
                async move {
                    // Initialize Database, creates database and tables if not exists
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    // ALL OF THE BELOW IS FOR TESTING,
                    // This should be removed as we find and appropriate place for evaluate our search.

                    // let queries = vec![
                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
                    //         "compute an outline view of all of the symbols in a buffer",
                    //         "scan a directory on the file system and load all of its children into an in-memory snapshot",
                    // ];
                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;

                    // let t2 = Instant::now();
                    // let documents = db.get_documents().unwrap();
                    // let files = db.get_files().unwrap();
                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());

                    // let t1 = Instant::now();
                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
                    // for (idx, embed) in embeddings.into_iter().enumerate() {
                    //     let t0 = Instant::now();
                    //     println!("\nQuery: {:?}", queries[idx]);
                    //     let results = bfs.top_k_search(&embed, 5).await;
                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
                    //     for (id, distance) in results {
                    //         println!("");
                    //         println!("   distance: {:?}", distance);
                    //         println!("   document: {:?}", documents[&id].name);
                    //         println!("   path:     {:?}", files[&documents[&id].file_id].relative_path);
                    //     }

                    // }

                    anyhow::Ok(())
                }
                .log_err(),
            );

            // Parser pool: one task per CPU, each with its own tree-sitter
            // parser/cursor. The scope completes when `paths_rx` closes.
            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            // Dropping the last sender closes the channel so the writer's
            // recv loop ends; then wait for it to flush remaining inserts.
            drop(indexed_files_tx);

            db_write_task.await;
            anyhow::Ok(())
        })
    }

    /// Embeds `phrase` and returns the `limit` documents with the highest
    /// dot-product similarity, via a brute-force scan of the database.
    pub fn search(
        &mut self,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.background().spawn(async move {
            let database = VectorDatabase::new(database_url.as_ref())?;

            // Single-element batch; the unwrap is safe because the provider
            // returns one embedding per input span.
            let phrase_embedding = embedding_provider
                .embed_batch(vec![&phrase])
                .await?
                .into_iter()
                .next()
                .unwrap();

            // Maintain a descending-by-similarity top-`limit` list. The
            // comparator's arguments are deliberately reversed
            // (query vs element) to get descending order from binary_search.
            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
            database.for_each_document(0, |id, embedding| {
                let similarity = dot(&embedding.0, &phrase_embedding);
                let ix = match results.binary_search_by(|(_, s)| {
                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                }) {
                    Ok(ix) => ix,
                    Err(ix) => ix,
                };
                results.insert(ix, (id, similarity));
                results.truncate(limit);
            })?;

            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
            let documents = database.get_documents_by_ids(&ids)?;

            anyhow::Ok(
                documents
                    .into_iter()
                    .map(|(file_path, offset, name)| SearchResult {
                        name,
                        offset,
                        file_path,
                    })
                    .collect(),
            )
        })
    }
}
420
// `VectorStore` is a GPUI model; it currently emits no events.
impl Entity for VectorStore {
    type Event = ();
}
424
/// Computes the dot product of two equal-length `f32` slices.
///
/// # Panics
/// Panics if the slices have different lengths.
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    // Asserting equal lengths up front also lets the compiler elide
    // per-element bounds checks in the zipped loop below.
    assert_eq!(vec_a.len(), vec_b.len());

    // The previous implementation called `matrixmultiply::sgemm` on a
    // 1×len by len×1 product, which is just a dot product wrapped in an
    // `unsafe` block (with no SAFETY justification). The safe zip/sum form
    // computes the same value, allocates nothing, and auto-vectorizes.
    vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum()
}