vector_store.rs

mod db;
mod embedding;
mod modal;

#[cfg(test)]
mod vector_store_tests;

use anyhow::{anyhow, Result};
use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use project::{Fs, Project, WorktreeId};
use smol::channel;
use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
use tree_sitter::{Parser, QueryCursor};
use util::{http::HttpClient, ResultExt, TryFutureExt};
use workspace::{Workspace, WorkspaceCreated};

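/// One outline item extracted from a source file: the item's name, its byte offset
/// within the file, and the embedding computed for its text.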
#[derive(Debug)]
pub struct Document {
    pub offset: usize,
    pub name: String,
    pub embedding: Vec<f32>,
}

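/// Wires semantic search into the app: creates the `VectorStore` model, indexes the
/// project of each newly created local workspace, and registers the search modal.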
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    let vector_store = cx.add_model(|_| {
        VectorStore::new(
            fs,
            VECTOR_DB_URL.to_string(),
            // Arc::new(DummyEmbeddings {}),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
        )
    });

    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();

    cx.add_action({
        move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
            let vector_store = vector_store.clone();
            workspace.toggle_modal(cx, |workspace, cx| {
                let project = workspace.project().clone();
                let workspace = cx.weak_handle();
                cx.add_view(|cx| {
                    SemanticSearch::new(
                        SemanticSearchDelegate::new(workspace, project, vector_store),
                        cx,
                    )
                })
            })
        }
    });

    SemanticSearch::init(cx);
}

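/// A parsed file ready to be written to the vector database, including a SHA-1 of its
/// content so unchanged files can be skipped on the next indexing pass.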
#[derive(Debug)]
pub struct IndexedFile {
    path: PathBuf,
    sha1: FileSha1,
    documents: Vec<Document>,
}

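/// Indexes project files into a vector database and answers similarity queries against
/// it. `worktree_db_ids` maps each indexed worktree to its row id in that database.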
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<str>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    worktree_db_ids: Vec<(WorktreeId, i64)>,
}

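/// A single match returned by `VectorStore::search`, identifying a named document by
/// worktree, file path, and byte offset.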
#[derive(Debug)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}

impl VectorStore {
    fn new(
        fs: Arc<dyn Fs>,
        database_url: String,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: database_url.into(),
            embedding_provider,
            language_registry,
            worktree_db_ids: Vec::new(),
        }
    }

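    /// Parses `content` with tree-sitter, extracts one `Document` per outline item via
    /// the language's outline query, embeds all items in a single batch, and hashes the
    /// content so the file can be skipped next time if it is unchanged.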
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language: Arc<Language>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline_config = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        parser.set_language(grammar.ts_language).unwrap();
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == outline_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == outline_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        if !documents.is_empty() {
            let embeddings = embedding_provider.embed_batch(context_spans).await?;
            for (document, embedding) in documents.iter_mut().zip(embeddings) {
                document.embedding = embedding;
            }
        }

        let sha1 = FileSha1::from_str(content);

        return Ok(IndexedFile {
            path: file_path,
            sha1,
            documents,
        });
    }

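    /// Indexes a local project: waits for worktree scans to complete, finds or creates a
    /// database entry for each worktree, then parses new or changed files on background
    /// threads and writes the resulting `IndexedFile`s to the database.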
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                    log::info!("worktree scan completed");
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            // TODO: remove this after fixing the bug in scan_complete
            cx.background()
                .timer(std::time::Duration::from_secs(3))
                .await;

            let db = VectorDatabase::new(&database_url)?;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Here we query the worktree ids, and yet we don't have them elsewhere.
            // We likely want to clean up these data structures.
            let (db, worktree_hashes, worktree_db_ids) = cx
                .background()
                .spawn({
                    let worktrees = worktrees.clone();
                    async move {
                        let mut worktree_db_ids: HashMap<WorktreeId, i64> = HashMap::new();
                        let mut hashes: HashMap<WorktreeId, HashMap<PathBuf, FileSha1>> =
                            HashMap::new();
                        for worktree in worktrees {
                            let worktree_db_id =
                                db.find_or_create_worktree(worktree.abs_path().as_ref())?;
                            worktree_db_ids.insert(worktree.id(), worktree_db_id);
                            hashes.insert(worktree.id(), db.get_file_hashes(worktree_db_id)?);
                        }
                        anyhow::Ok((db, hashes, worktree_db_ids))
                    }
                })
                .await?;

            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    let worktree_db_ids = worktree_db_ids.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let file_hashes = &worktree_hashes[&worktree.id()];
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    if language.name().as_ref() != "Rust" {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        log::info!("loaded file: {absolute_path:?}");

                                        let path_buf = file.path.to_path_buf();
                                        let already_stored = file_hashes
                                            .get(&path_buf)
                                            .map_or(false, |existing_hash| {
                                                existing_hash.equals(&content)
                                            });

                                        if !already_stored {
                                            log::info!(
                                                "File Changed (Sending to Parse): {:?}",
                                                &path_buf
                                            );
                                            paths_tx
                                                .try_send((
                                                    worktree_db_ids[&worktree.id()],
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                        }
                    }
                })
                .detach();

            let db_write_task = cx.background().spawn(
                async move {
                    // Write each indexed file to the database as it finishes parsing.
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    // ALL OF THE BELOW IS FOR TESTING,
                    // This should be removed once we find an appropriate place to evaluate our search.

                    // let queries = vec![
                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
                    //         "compute an outline view of all of the symbols in a buffer",
                    //         "scan a directory on the file system and load all of its children into an in-memory snapshot",
                    // ];
                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;

                    // let t2 = Instant::now();
                    // let documents = db.get_documents().unwrap();
                    // let files = db.get_files().unwrap();
                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());

                    // let t1 = Instant::now();
                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
                    // for (idx, embed) in embeddings.into_iter().enumerate() {
                    //     let t0 = Instant::now();
                    //     println!("\nQuery: {:?}", queries[idx]);
                    //     let results = bfs.top_k_search(&embed, 5).await;
                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
                    //     for (id, distance) in results {
                    //         println!("");
                    //         println!("   distance: {:?}", distance);
                    //         println!("   document: {:?}", documents[&id].name);
                    //         println!("   path:     {:?}", files[&documents[&id].file_id].relative_path);
                    //     }

                    // }

                    anyhow::Ok(())
                }
                .log_err(),
            );

            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            drop(indexed_files_tx);

            db_write_task.await;

            this.update(&mut cx, |this, _| {
                this.worktree_db_ids.extend(worktree_db_ids);
            });

            anyhow::Ok(())
        })
    }

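    /// Embeds `phrase` and returns the `limit` documents whose stored embeddings have the
    /// highest dot product with it, restricted to the project's indexed worktrees.
    ///
    /// A hypothetical call site (a sketch; the surrounding `vector_store` and `project`
    /// handles and the query string are assumptions):
    /// ```ignore
    /// let search = vector_store.update(cx, |store, cx| {
    ///     store.search(&project, "parse a file into outline items".into(), 10, cx)
    /// });
    /// // The returned `Task` is awaited (or detached) by the caller to obtain the results.
    /// ```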
    pub fn search(
        &mut self,
        project: &ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project = project.read(cx);
        let worktree_db_ids = project
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                self.worktree_db_ids.iter().find_map(|(id, db_id)| {
                    if *id == worktree_id {
                        Some(*db_id)
                    } else {
                        None
                    }
                })
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.as_ref())?;

                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            let results = this.read_with(&cx, |this, _| {
                documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| {
                            if *db_id == worktree_db_id {
                                Some(*id)
                            } else {
                                None
                            }
                        })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect()
            });

            anyhow::Ok(results)
        })
    }
}

impl Entity for VectorStore {
    type Event = ();
}

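// Dot product of two equal-length vectors, computed as a 1 x len by len x 1 matrix
// multiply via `matrixmultiply::sgemm`.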
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}