// vector_store.rs

  1mod db;
  2mod embedding;
  3mod parsing;
  4mod search;
  5
  6#[cfg(test)]
  7mod vector_store_tests;
  8
  9use anyhow::{anyhow, Result};
 10use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
 11use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 12use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
 13use language::{Language, LanguageRegistry};
 14use parsing::Document;
 15use project::{Fs, Project};
 16use smol::channel;
 17use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
 18use tree_sitter::{Parser, QueryCursor};
 19use util::{http::HttpClient, ResultExt, TryFutureExt};
 20use workspace::WorkspaceCreated;
 21
 22pub fn init(
 23    fs: Arc<dyn Fs>,
 24    http_client: Arc<dyn HttpClient>,
 25    language_registry: Arc<LanguageRegistry>,
 26    cx: &mut AppContext,
 27) {
 28    let vector_store = cx.add_model(|cx| {
 29        VectorStore::new(
 30            fs,
 31            VECTOR_DB_URL.to_string(),
 32            Arc::new(OpenAIEmbeddings {
 33                client: http_client,
 34            }),
 35            language_registry,
 36        )
 37    });
 38
 39    cx.subscribe_global::<WorkspaceCreated, _>({
 40        let vector_store = vector_store.clone();
 41        move |event, cx| {
 42            let workspace = &event.0;
 43            if let Some(workspace) = workspace.upgrade(cx) {
 44                let project = workspace.read(cx).project().clone();
 45                if project.read(cx).is_local() {
 46                    vector_store.update(cx, |store, cx| {
 47                        store.add_project(project, cx).detach();
 48                    });
 49                }
 50            }
 51        }
 52    })
 53    .detach();
 54}
 55
/// The result of parsing and embedding a single source file.
#[derive(Debug)]
pub struct IndexedFile {
    // Path of the file as passed to `index_file` (worktree-relative in
    // `add_project` — confirm against callers).
    path: PathBuf,
    // Hash of the file contents, used to skip re-indexing unchanged files.
    sha1: FileSha1,
    // One document per outline item, each carrying its embedding.
    documents: Vec<Document>,
}
 62
/// Model that indexes project files into a vector database and serves
/// similarity searches over the resulting embeddings.
struct VectorStore {
    fs: Arc<dyn Fs>,
    // Connection string / path handed to `VectorDatabase::new`.
    database_url: Arc<str>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
}
 69
/// A single hit returned by `VectorStore::search`.
pub struct SearchResult {
    // Name of the matched outline item (e.g. a function or type name).
    pub name: String,
    // Byte offset of the item within its file.
    pub offset: usize,
    pub file_path: PathBuf,
}
 75
 76impl VectorStore {
 77    fn new(
 78        fs: Arc<dyn Fs>,
 79        database_url: String,
 80        embedding_provider: Arc<dyn EmbeddingProvider>,
 81        language_registry: Arc<LanguageRegistry>,
 82    ) -> Self {
 83        Self {
 84            fs,
 85            database_url: database_url.into(),
 86            embedding_provider,
 87            language_registry,
 88        }
 89    }
 90
 91    async fn index_file(
 92        cursor: &mut QueryCursor,
 93        parser: &mut Parser,
 94        embedding_provider: &dyn EmbeddingProvider,
 95        language: Arc<Language>,
 96        file_path: PathBuf,
 97        content: String,
 98    ) -> Result<IndexedFile> {
 99        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
100        let outline_config = grammar
101            .outline_config
102            .as_ref()
103            .ok_or_else(|| anyhow!("no outline query"))?;
104
105        parser.set_language(grammar.ts_language).unwrap();
106        let tree = parser
107            .parse(&content, None)
108            .ok_or_else(|| anyhow!("parsing failed"))?;
109
110        let mut documents = Vec::new();
111        let mut context_spans = Vec::new();
112        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
113            let mut item_range = None;
114            let mut name_range = None;
115            for capture in mat.captures {
116                if capture.index == outline_config.item_capture_ix {
117                    item_range = Some(capture.node.byte_range());
118                } else if capture.index == outline_config.name_capture_ix {
119                    name_range = Some(capture.node.byte_range());
120                }
121            }
122
123            if let Some((item_range, name_range)) = item_range.zip(name_range) {
124                if let Some((item, name)) =
125                    content.get(item_range.clone()).zip(content.get(name_range))
126                {
127                    context_spans.push(item);
128                    documents.push(Document {
129                        name: name.to_string(),
130                        offset: item_range.start,
131                        embedding: Vec::new(),
132                    });
133                }
134            }
135        }
136
137        let embeddings = embedding_provider.embed_batch(context_spans).await?;
138        for (document, embedding) in documents.iter_mut().zip(embeddings) {
139            document.embedding = embedding;
140        }
141
142        let sha1 = FileSha1::from_str(content);
143
144        return Ok(IndexedFile {
145            path: file_path,
146            sha1,
147            documents,
148        });
149    }
150
    /// Indexes all Rust files in the project's local worktrees: waits for the
    /// initial worktree scans, skips files whose stored hash matches their
    /// current contents, parses and embeds the rest on background threads,
    /// and writes the results into the vector database.
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        // One future per worktree, resolving when its initial scan finishes.
        // `as_local().unwrap()` is safe here only because callers gate on
        // `project.is_local()` (see `init`).
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                    log::info!("worktree scan completed");
                }
            })
            .collect::<Vec<_>>();

        // Clone everything the async task needs so it is 'static.
        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|_, cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            // TODO: remove this after fixing the bug in scan_complete
            cx.background()
                .timer(std::time::Duration::from_secs(3))
                .await;

            let db = VectorDatabase::new(&database_url)?;

            // Take immutable snapshots of every worktree to iterate files
            // without holding model locks.
            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let worktree_root_paths = worktrees
                .iter()
                .map(|worktree| worktree.abs_path().clone())
                .collect::<Vec<_>>();

            // Here we query the worktree ids, and yet we don't have them elsewhere.
            // We likely want to clean up these datastructures.
            // Map each worktree root to its database id and load the stored
            // per-file hashes so unchanged files can be skipped below.
            let (db, worktree_hashes, worktree_ids) = cx
                .background()
                .spawn(async move {
                    let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
                    let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
                    for worktree_root_path in worktree_root_paths {
                        let worktree_id =
                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
                        worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
                        hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
                    }
                    anyhow::Ok((db, hashes, worktree_ids))
                })
                .await?;

            // Pipeline: producer task -> paths_tx -> parser pool ->
            // indexed_files_tx -> db writer task.
            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            // Producer: walk every file in every worktree, load changed Rust
            // files, and feed them to the parser pool. `paths_tx` is moved in,
            // so the channel closes when this task finishes.
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
                            let file_hashes = &worktree_hashes[&worktree_id];
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Only Rust files are indexed for now.
                                    if language.name().as_ref() != "Rust" {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        log::info!("loaded file: {absolute_path:?}");

                                        // Skip files whose stored hash already
                                        // matches the current contents.
                                        let path_buf = file.path.to_path_buf();
                                        let already_stored = file_hashes
                                            .get(&path_buf)
                                            .map_or(false, |existing_hash| {
                                                existing_hash.equals(&content)
                                            });

                                        if !already_stored {
                                            log::info!(
                                                "File Changed (Sending to Parse): {:?}",
                                                &path_buf
                                            );
                                            // Unbounded channel: try_send only
                                            // fails if the receiver is gone.
                                            paths_tx
                                                .try_send((
                                                    worktree_id,
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                        }
                    }
                })
                .detach();

            // Writer: persist indexed files until the last sender is dropped.
            let db_write_task = cx.background().spawn(
                async move {
                    // Initialize Database, creates database and tables if not exists
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    // ALL OF THE BELOW IS FOR TESTING,
                    // This should be removed as we find and appropriate place for evaluate our search.

                    // let queries = vec![
                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
                    //         "compute an outline view of all of the symbols in a buffer",
                    //         "scan a directory on the file system and load all of its children into an in-memory snapshot",
                    // ];
                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;

                    // let t2 = Instant::now();
                    // let documents = db.get_documents().unwrap();
                    // let files = db.get_files().unwrap();
                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());

                    // let t1 = Instant::now();
                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
                    // for (idx, embed) in embeddings.into_iter().enumerate() {
                    //     let t0 = Instant::now();
                    //     println!("\nQuery: {:?}", queries[idx]);
                    //     let results = bfs.top_k_search(&embed, 5).await;
                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
                    //     for (id, distance) in results {
                    //         println!("");
                    //         println!("   distance: {:?}", distance);
                    //         println!("   document: {:?}", documents[&id].name);
                    //         println!("   path:     {:?}", files[&documents[&id].file_id].relative_path);
                    //     }

                    // }

                    anyhow::Ok(())
                }
                .log_err(),
            );

            // Parser pool: one worker per CPU, each with its own Parser and
            // QueryCursor, draining `paths_rx` until the producer drops
            // `paths_tx`.
            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            // All parser workers are done; close the channel so the writer's
            // recv loop terminates.
            drop(indexed_files_tx);

            // NOTE(review): the writer's result is discarded here (errors were
            // already logged via `log_err`); confirm that's intentional.
            db_write_task.await;
            anyhow::Ok(())
        })
    }
344
345    pub fn search(
346        &mut self,
347        phrase: String,
348        limit: usize,
349        cx: &mut ModelContext<Self>,
350    ) -> Task<Result<Vec<SearchResult>>> {
351        let embedding_provider = self.embedding_provider.clone();
352        let database_url = self.database_url.clone();
353        cx.background().spawn(async move {
354            let database = VectorDatabase::new(database_url.as_ref())?;
355
356            let phrase_embedding = embedding_provider
357                .embed_batch(vec![&phrase])
358                .await?
359                .into_iter()
360                .next()
361                .unwrap();
362
363            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
364            database.for_each_document(0, |id, embedding| {
365                let similarity = dot(&embedding.0, &phrase_embedding);
366                let ix = match results.binary_search_by(|(_, s)| {
367                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
368                }) {
369                    Ok(ix) => ix,
370                    Err(ix) => ix,
371                };
372                results.insert(ix, (id, similarity));
373                results.truncate(limit);
374            })?;
375
376            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
377            let documents = database.get_documents_by_ids(&ids)?;
378
379            anyhow::Ok(
380                documents
381                    .into_iter()
382                    .map(|(file_path, offset, name)| SearchResult {
383                        name,
384                        offset,
385                        file_path,
386                    })
387                    .collect(),
388            )
389        })
390    }
391}
392
// `VectorStore` emits no events; `Entity` is implemented only so it can be
// held in a gpui model handle.
impl Entity for VectorStore {
    type Event = ();
}
396
/// Computes the dot product of two equal-length `f32` slices by expressing it
/// as a 1×len by len×1 matrix multiply (presumably to get the vectorized
/// sgemm kernel — confirm this outperforms a plain iterator sum).
///
/// # Panics
/// Panics if the slices differ in length.
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    // SAFETY: the assert above guarantees both pointers are valid for `len`
    // consecutive reads, and `result` is a valid, writable 1×1 output cell.
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}