vector_store.rs

mod db;
mod embedding;
mod modal;

#[cfg(test)]
mod vector_store_tests;

use anyhow::{anyhow, Result};
use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use project::{Fs, Project, WorktreeId};
use smol::channel;
use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
use tree_sitter::{Parser, QueryCursor};
use util::{http::HttpClient, ResultExt, TryFutureExt};
use workspace::{Workspace, WorkspaceCreated};

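/// A single embeddable span extracted from a file: the byte offset of the
/// outline item, its name, and the embedding vector computed for it.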
#[derive(Debug)]
pub struct Document {
    pub offset: usize,
    pub name: String,
    pub embedding: Vec<f32>,
}

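/// Registers the `VectorStore` model, starts indexing each local project as
/// its workspace is created, and wires up the semantic search modal's
/// `Toggle` action.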
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    let vector_store = cx.add_model(|cx| {
        VectorStore::new(
            fs,
            VECTOR_DB_URL.to_string(),
            // Arc::new(OpenAIEmbeddings {
            //     client: http_client,
            // }),
            Arc::new(DummyEmbeddings {}),
            language_registry,
        )
    });

    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();

    cx.add_action({
        move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
            let vector_store = vector_store.clone();
            workspace.toggle_modal(cx, |workspace, cx| {
                let project = workspace.project().clone();
                let workspace = cx.weak_handle();
                cx.add_view(|cx| {
                    SemanticSearch::new(
                        SemanticSearchDelegate::new(workspace, project, vector_store),
                        cx,
                    )
                })
            })
        }
    });
    SemanticSearch::init(cx);
}

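/// The result of indexing one file: its worktree-relative path, the SHA-1 of
/// its contents (used to skip files that have not changed), and the embedded
/// documents extracted from it.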
#[derive(Debug)]
pub struct IndexedFile {
    path: PathBuf,
    sha1: FileSha1,
    documents: Vec<Document>,
}

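/// Model that indexes project worktrees into the vector database and answers
/// semantic search queries against it.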
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<str>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    worktree_db_ids: Vec<(WorktreeId, i64)>,
}

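/// A single semantic search hit, resolved back to a worktree, file path, and
/// byte offset within that file.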
#[derive(Debug)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}

impl VectorStore {
    fn new(
        fs: Arc<dyn Fs>,
        database_url: String,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: database_url.into(),
            embedding_provider,
            language_registry,
            worktree_db_ids: Vec::new(),
        }
    }

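    /// Parses a file with its language's outline query, extracts one document
    /// per captured outline item, embeds the captured spans in a single
    /// batch, and returns an `IndexedFile` containing the documents and the
    /// content's SHA-1.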
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language: Arc<Language>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline_config = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        parser.set_language(grammar.ts_language).unwrap();
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == outline_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == outline_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        if !documents.is_empty() {
            let embeddings = embedding_provider.embed_batch(context_spans).await?;
            for (document, embedding) in documents.iter_mut().zip(embeddings) {
                document.embedding = embedding;
            }
        }

        let sha1 = FileSha1::from_str(content);

        Ok(IndexedFile {
            path: file_path,
            sha1,
            documents,
        })
    }

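    /// Indexes all of a project's local worktrees. Once the worktree scans
    /// complete, this looks up (or creates) a database id per worktree, loads
    /// the stored file hashes, and streams changed files (currently Rust
    /// files only) through a pool of background parser tasks; the resulting
    /// `IndexedFile`s are written back to the database by a dedicated writer
    /// task.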
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                    log::info!("worktree scan completed");
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            // TODO: remove this after fixing the bug in scan_complete
            cx.background()
                .timer(std::time::Duration::from_secs(3))
                .await;

            let db = VectorDatabase::new(&database_url)?;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Here we query the worktree ids, and yet we don't have them elsewhere.
            // We likely want to clean up these data structures.
            let (db, worktree_hashes, worktree_db_ids) = cx
                .background()
                .spawn({
                    let worktrees = worktrees.clone();
                    async move {
                        let mut worktree_db_ids: HashMap<WorktreeId, i64> = HashMap::new();
                        let mut hashes: HashMap<WorktreeId, HashMap<PathBuf, FileSha1>> =
                            HashMap::new();
                        for worktree in worktrees {
                            let worktree_db_id =
                                db.find_or_create_worktree(worktree.abs_path().as_ref())?;
                            worktree_db_ids.insert(worktree.id(), worktree_db_id);
                            hashes.insert(worktree.id(), db.get_file_hashes(worktree_db_id)?);
                        }
                        anyhow::Ok((db, hashes, worktree_db_ids))
                    }
                })
                .await?;

            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    let worktree_db_ids = worktree_db_ids.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let file_hashes = &worktree_hashes[&worktree.id()];
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    if language.name().as_ref() != "Rust" {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        log::info!("loaded file: {absolute_path:?}");

                                        let path_buf = file.path.to_path_buf();
                                        let already_stored = file_hashes
                                            .get(&path_buf)
                                            .map_or(false, |existing_hash| {
                                                existing_hash.equals(&content)
                                            });

                                        if !already_stored {
                                            log::info!(
                                                "File Changed (Sending to Parse): {:?}",
                                                &path_buf
                                            );
                                            paths_tx
                                                .try_send((
                                                    worktree_db_ids[&worktree.id()],
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                        }
                    }
                })
                .detach();

            let db_write_task = cx.background().spawn(
                async move {
                    // Write each indexed file to the database as it arrives.
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    // ALL OF THE BELOW IS FOR TESTING;
                    // it should be removed once we find an appropriate place to evaluate our search.

                    // let queries = vec![
                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
                    //         "compute an outline view of all of the symbols in a buffer",
                    //         "scan a directory on the file system and load all of its children into an in-memory snapshot",
                    // ];
                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;

                    // let t2 = Instant::now();
                    // let documents = db.get_documents().unwrap();
                    // let files = db.get_files().unwrap();
                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());

                    // let t1 = Instant::now();
                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
                    // for (idx, embed) in embeddings.into_iter().enumerate() {
                    //     let t0 = Instant::now();
                    //     println!("\nQuery: {:?}", queries[idx]);
                    //     let results = bfs.top_k_search(&embed, 5).await;
                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
                    //     for (id, distance) in results {
                    //         println!("");
                    //         println!("   distance: {:?}", distance);
                    //         println!("   document: {:?}", documents[&id].name);
                    //         println!("   path:     {:?}", files[&documents[&id].file_id].relative_path);
                    //     }

                    // }

                    anyhow::Ok(())
                }
                .log_err(),
            );

            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            drop(indexed_files_tx);

            db_write_task.await;

            this.update(&mut cx, |this, _| {
                this.worktree_db_ids.extend(worktree_db_ids);
            });

            anyhow::Ok(())
        })
    }

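    /// Embeds the query phrase, brute-force scans every stored document
    /// embedding for the project's worktrees, keeps the `limit` most similar
    /// by dot product, and maps the winners back into `SearchResult`s.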
    pub fn search(
        &mut self,
        project: &ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project = project.read(cx);
        let worktree_db_ids = project
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                self.worktree_db_ids.iter().find_map(|(id, db_id)| {
                    if *id == worktree_id {
                        Some(*db_id)
                    } else {
                        None
                    }
                })
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.as_ref())?;

                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            let results = this.read_with(&cx, |this, _| {
                documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| {
                            if *db_id == worktree_db_id {
                                Some(*id)
                            } else {
                                None
                            }
                        })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect()
            });

            anyhow::Ok(results)
        })
    }
}

impl Entity for VectorStore {
    type Event = ();
}

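/// Dot product of two equal-length vectors, computed as a 1 x len by len x 1
/// matrix multiplication via `matrixmultiply::sgemm`; equivalent to
/// `vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum::<f32>()`.
///
/// ```ignore
/// assert_eq!(dot(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]), 32.0);
/// ```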
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}