vector_store.rs

  1mod db;
  2mod embedding;
  3mod parsing;
  4mod search;
  5
  6#[cfg(test)]
  7mod vector_store_tests;
  8
  9use anyhow::{anyhow, Result};
 10use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
 11use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 12use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
 13use language::{Language, LanguageRegistry};
 14use parsing::Document;
 15use project::{Fs, Project};
 16use smol::channel;
 17use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
 18use tree_sitter::{Parser, QueryCursor};
 19use util::{http::HttpClient, ResultExt, TryFutureExt};
 20use workspace::{Workspace, WorkspaceCreated};
 21
// Declares the `TestSearch` action in the `semantic_search` namespace; it is
// bound in `init` below to run a canned timing query against the index.
actions!(semantic_search, [TestSearch]);
 23
 24pub fn init(
 25    fs: Arc<dyn Fs>,
 26    http_client: Arc<dyn HttpClient>,
 27    language_registry: Arc<LanguageRegistry>,
 28    cx: &mut AppContext,
 29) {
 30    let vector_store = cx.add_model(|cx| {
 31        VectorStore::new(
 32            fs,
 33            VECTOR_DB_URL.to_string(),
 34            Arc::new(OpenAIEmbeddings {
 35                client: http_client,
 36            }),
 37            language_registry,
 38        )
 39    });
 40
 41    cx.subscribe_global::<WorkspaceCreated, _>({
 42        let vector_store = vector_store.clone();
 43        move |event, cx| {
 44            let workspace = &event.0;
 45            if let Some(workspace) = workspace.upgrade(cx) {
 46                let project = workspace.read(cx).project().clone();
 47                if project.read(cx).is_local() {
 48                    vector_store.update(cx, |store, cx| {
 49                        store.add_project(project, cx).detach();
 50                    });
 51                }
 52            }
 53        }
 54    })
 55    .detach();
 56
 57    cx.add_action({
 58        let vector_store = vector_store.clone();
 59        move |workspace: &mut Workspace, _: &TestSearch, cx: &mut ViewContext<Workspace>| {
 60            let t0 = std::time::Instant::now();
 61            let task = vector_store.update(cx, |store, cx| {
 62                store.search("compute embeddings for all of the symbols in the codebase and write them to a database".to_string(), 10, cx)
 63            });
 64
 65            cx.spawn(|this, cx| async move {
 66                let results = task.await?;
 67                let duration = t0.elapsed();
 68
 69                println!("search took {:?}", duration);
 70                println!("results {:?}", results);
 71
 72                anyhow::Ok(())
 73            }).detach()
 74        }
 75    });
 76}
 77
/// The parsed and embedded representation of a single source file, ready to
/// be written to the vector database.
#[derive(Debug)]
pub struct IndexedFile {
    /// Worktree-relative path of the indexed file.
    path: PathBuf,
    /// SHA-1 of the file contents at index time; compared against stored
    /// hashes to skip re-indexing unchanged files.
    sha1: FileSha1,
    /// One entry per outline item captured in the file, each carrying its
    /// embedding vector.
    documents: Vec<Document>,
}
 84
/// Model that indexes project files into a vector database and answers
/// similarity-search queries against it.
struct VectorStore {
    fs: Arc<dyn Fs>,
    /// Location string handed to `VectorDatabase::new` whenever a
    /// connection is needed.
    database_url: Arc<str>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
}
 91
/// A single semantic-search hit returned by `VectorStore::search`.
#[derive(Debug)]
pub struct SearchResult {
    /// Name of the matched document (the outline item's captured name).
    pub name: String,
    /// Byte offset of the matched item within its file.
    pub offset: usize,
    pub file_path: PathBuf,
}
 98
impl VectorStore {
    /// Creates a store. This only captures configuration; no database
    /// connection is opened until a project is added or a search runs.
    fn new(
        fs: Arc<dyn Fs>,
        database_url: String,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: database_url.into(),
            embedding_provider,
            language_registry,
        }
    }

    /// Parses `content` with the language's tree-sitter outline query,
    /// embeds every captured outline item in one batch, and packages the
    /// result (with a content SHA-1) as an `IndexedFile`.
    ///
    /// Errors if the language has no grammar or outline query, or if
    /// parsing fails.
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language: Arc<Language>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline_config = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        parser.set_language(grammar.ts_language).unwrap();
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        // One `Document` (plus the item's source text, kept for embedding)
        // per outline match that captured both an item range and a name
        // range.
        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == outline_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == outline_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                // `get` rather than indexing: silently skip matches whose
                // byte ranges don't land on valid char boundaries.
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        // Placeholder; filled in after the batch call below.
                        embedding: Vec::new(),
                    });
                }
            }
        }

        // Embed every span in a single provider call; `zip` re-pairs each
        // embedding with its document because both vectors were pushed in
        // the same order above.
        let embeddings = embedding_provider.embed_batch(context_spans).await?;
        for (document, embedding) in documents.iter_mut().zip(embeddings) {
            document.embedding = embedding;
        }

        // Content hash, used on later runs to skip unchanged files.
        let sha1 = FileSha1::from_str(content);

        return Ok(IndexedFile {
            path: file_path,
            sha1,
            documents,
        });
    }

    /// Indexes every worktree of `project` (assumed local) into the vector
    /// database.
    ///
    /// Pipeline: wait for worktree scans → resolve worktree ids and stored
    /// file hashes on the background pool → a walker task streams changed
    /// files into `paths_tx` → a pool of parser tasks (one per CPU) turns
    /// them into `IndexedFile`s → a single writer task persists them.
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        // One future per worktree that resolves when its FS scan completes.
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                    log::info!("worktree scan completed");
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|_, cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            // TODO: remove this after fixing the bug in scan_complete
            cx.background()
                .timer(std::time::Duration::from_secs(3))
                .await;

            let db = VectorDatabase::new(&database_url)?;

            // Immutable snapshots so the walker task can run off-thread
            // without holding project models.
            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let worktree_root_paths = worktrees
                .iter()
                .map(|worktree| worktree.abs_path().clone())
                .collect::<Vec<_>>();

            // Here we query the worktree ids, and yet we dont have them elsewhere
            // We likely want to clean up these datastructures
            // (`db` is moved into the background task and handed back.)
            let (db, worktree_hashes, worktree_ids) = cx
                .background()
                .spawn(async move {
                    let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
                    let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
                    for worktree_root_path in worktree_root_paths {
                        let worktree_id =
                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
                        worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
                        hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
                    }
                    anyhow::Ok((db, hashes, worktree_ids))
                })
                .await?;

            // Work queues: walker → parsers, parsers → writer. Each stage
            // signals completion by closing its sender (walker task ending
            // drops `paths_tx`; `indexed_files_tx` is dropped explicitly
            // below once the parser pool drains).
            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            // Walker: loads each file, skipping non-Rust files and files
            // whose stored hash still matches the current content.
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
                            let file_hashes = &worktree_hashes[&worktree_id];
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Only Rust files are indexed for now.
                                    if language.name().as_ref() != "Rust" {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        log::info!("loaded file: {absolute_path:?}");

                                        let path_buf = file.path.to_path_buf();
                                        let already_stored = file_hashes
                                            .get(&path_buf)
                                            .map_or(false, |existing_hash| {
                                                existing_hash.equals(&content)
                                            });

                                        if !already_stored {
                                            log::info!(
                                                "File Changed (Sending to Parse): {:?}",
                                                &path_buf
                                            );
                                            // Unbounded channel: try_send only
                                            // fails if all receivers are gone.
                                            paths_tx
                                                .try_send((
                                                    worktree_id,
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                        }
                    }
                })
                .detach();

            // Writer: persists each indexed file until the channel closes;
            // insert errors are logged, not propagated.
            let db_write_task = cx.background().spawn(
                async move {
                    // Initialize Database, creates database and tables if not exists
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    // ALL OF THE BELOW IS FOR TESTING,
                    // This should be removed as we find and appropriate place for evaluate our search.

                    // let queries = vec![
                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
                    //         "compute an outline view of all of the symbols in a buffer",
                    //         "scan a directory on the file system and load all of its children into an in-memory snapshot",
                    // ];
                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;

                    // let t2 = Instant::now();
                    // let documents = db.get_documents().unwrap();
                    // let files = db.get_files().unwrap();
                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());

                    // let t1 = Instant::now();
                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
                    // for (idx, embed) in embeddings.into_iter().enumerate() {
                    //     let t0 = Instant::now();
                    //     println!("\nQuery: {:?}", queries[idx]);
                    //     let results = bfs.top_k_search(&embed, 5).await;
                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
                    //     for (id, distance) in results {
                    //         println!("");
                    //         println!("   distance: {:?}", distance);
                    //         println!("   document: {:?}", documents[&id].name);
                    //         println!("   path:     {:?}", files[&documents[&id].file_id].relative_path);
                    //     }

                    // }

                    anyhow::Ok(())
                }
                .log_err(),
            );

            // Parser pool: one task per CPU, each with its own Parser and
            // QueryCursor, draining `paths_rx` until the walker closes it.
            // `scoped(...)` awaits all of them before continuing.
            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            // All parsers are done; closing the sender lets the writer's
            // recv loop terminate.
            drop(indexed_files_tx);

            db_write_task.await;
            anyhow::Ok(())
        })
    }

    /// Embeds `phrase` and returns up to `limit` stored documents ranked by
    /// dot-product similarity (highest first).
    pub fn search(
        &mut self,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.background().spawn(async move {
            let database = VectorDatabase::new(database_url.as_ref())?;

            // Single-item batch; the sole returned embedding is the query's.
            let phrase_embedding = embedding_provider
                .embed_batch(vec![&phrase])
                .await?
                .into_iter()
                .next()
                .unwrap();

            // `results` is kept sorted by *descending* similarity: the
            // comparator is deliberately reversed (query similarity compared
            // against each element), and truncation after every insert keeps
            // only the best `limit` entries.
            // NOTE(review): the first argument to `for_each_document` is
            // hard-coded to 0 — presumably a worktree id; confirm against
            // VectorDatabase.
            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
            database.for_each_document(0, |id, embedding| {
                let similarity = dot(&embedding.0, &phrase_embedding);
                let ix = match results.binary_search_by(|(_, s)| {
                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                }) {
                    Ok(ix) => ix,
                    Err(ix) => ix,
                };
                results.insert(ix, (id, similarity));
                results.truncate(limit);
            })?;

            // NOTE(review): `get_documents_by_ids` may not return rows in
            // the ranked order of `ids` — verify if callers rely on ranking.
            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
            let documents = database.get_documents_by_ids(&ids)?;

            anyhow::Ok(
                documents
                    .into_iter()
                    .map(|(file_path, offset, name)| SearchResult {
                        name,
                        offset,
                        file_path,
                    })
                    .collect(),
            )
        })
    }
}
415
// Makes VectorStore usable as a gpui model; it emits no events.
impl Entity for VectorStore {
    type Event = ();
}
419
/// Computes the dot product of two equal-length `f32` vectors.
///
/// Replaces the previous `unsafe` 1×len×1 `matrixmultiply::sgemm` call,
/// which performed the same sequential multiply-accumulate; the safe
/// iterator form needs no FFI-style pointer/stride bookkeeping and no
/// `// SAFETY:` obligations.
///
/// # Panics
/// Panics if the slices have different lengths (same contract as before).
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    assert_eq!(vec_a.len(), vec_b.len());
    vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum()
}