// vector_store.rs

  1mod db;
  2mod embedding;
  3mod modal;
  4
  5#[cfg(test)]
  6mod vector_store_tests;
  7
  8use anyhow::{anyhow, Result};
  9use db::VectorDatabase;
 10use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 11use futures::{channel::oneshot, Future};
 12use gpui::{
 13    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
 14    WeakModelHandle,
 15};
 16use language::{Language, LanguageRegistry};
 17use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
 18use project::{Fs, Project, WorktreeId};
 19use smol::channel;
 20use std::{
 21    cmp::Ordering,
 22    collections::HashMap,
 23    path::{Path, PathBuf},
 24    sync::Arc,
 25    time::{Instant, SystemTime},
 26};
 27use tree_sitter::{Parser, QueryCursor};
 28use util::{
 29    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
 30    http::HttpClient,
 31    paths::EMBEDDINGS_DIR,
 32    ResultExt,
 33};
 34use workspace::{Workspace, WorkspaceCreated};
 35
// Minimum age, in seconds, of a file's stored mtime before a change event
// triggers reindexing of that file (see the project-event subscription in
// `add_project`).
const REINDEXING_DELAY: u64 = 30;
// Number of pending document spans that triggers one batched embedding request.
const EMBEDDINGS_BATCH_SIZE: usize = 25;
 38
/// One embeddable item (e.g. a named declaration) extracted from a source file
/// by the language's embedding query.
#[derive(Debug, Clone)]
pub struct Document {
    // Byte offset of the item within its file's text.
    pub offset: usize,
    // Name captured by the embedding query's name capture.
    pub name: String,
    // Embedding vector; empty until the embedding provider has filled it in.
    pub embedding: Vec<f32>,
}
 45
/// Wire the vector store into the application: construct the store, start
/// indexing each local project as its workspace is created, and register the
/// semantic-search modal plus its `Toggle` action.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    // Semantic search is disabled on the stable release channel.
    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
        return;
    }

    // One embeddings database per release channel, so dev/preview data stays
    // separate from other channels.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.spawn(move |mut cx| async move {
        let vector_store = VectorStore::new(
            fs,
            db_file_path,
            // Arc::new(embedding::DummyEmbeddings {}),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            // Index every local project as soon as its workspace is created.
            cx.subscribe_global::<WorkspaceCreated, _>({
                let vector_store = vector_store.clone();
                move |event, cx| {
                    let workspace = &event.0;
                    if let Some(workspace) = workspace.upgrade(cx) {
                        let project = workspace.read(cx).project().clone();
                        if project.read(cx).is_local() {
                            vector_store.update(cx, |store, cx| {
                                store.add_project(project, cx).detach();
                            });
                        }
                    }
                }
            })
            .detach();

            // `Toggle` opens the semantic-search modal for the active workspace.
            cx.add_action({
                move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
                    let vector_store = vector_store.clone();
                    workspace.toggle_modal(cx, |workspace, cx| {
                        let project = workspace.project().clone();
                        let workspace = cx.weak_handle();
                        cx.add_view(|cx| {
                            SemanticSearch::new(
                                SemanticSearchDelegate::new(workspace, project, vector_store),
                                cx,
                            )
                        })
                    })
                }
            });

            SemanticSearch::init(cx);
        });

        anyhow::Ok(())
    })
    .detach();
}
113
/// A file that has been parsed into embeddable documents. The stored `mtime`
/// lets re-scans skip files whose on-disk modification time is unchanged.
#[derive(Debug, Clone)]
pub struct IndexedFile {
    // Path of the file as submitted to the indexing pipeline.
    path: PathBuf,
    // Modification time of the file at the moment it was indexed.
    mtime: SystemTime,
    // Documents extracted from the file by the embedding query.
    documents: Vec<Document>,
}
120
/// Coordinates parsing, embedding, and persistence of project files, and
/// serves similarity searches over the stored embeddings.
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    // Filesystem path of the embeddings database file.
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Feeds the single database-writer task; all DB mutations go through here.
    db_update_tx: channel::Sender<DbWrite>,
    // Feeds the file-indexing task: (worktree db id, path, language, mtime).
    paths_tx: channel::Sender<(i64, PathBuf, Arc<Language>, SystemTime)>,
    // Background tasks retained only to keep the pipeline alive for the
    // store's lifetime.
    _db_update_task: Task<()>,
    _paths_update_task: Task<()>,
    _embeddings_update_task: Task<()>,
    // Per-project worktree-id mappings and change subscriptions.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
133
/// Per-project bookkeeping kept while a project is registered with the store.
struct ProjectState {
    // Maps each of the project's worktrees to its database id.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    // Keeps the project-event subscription alive so changed files are
    // reindexed; dropped (and thus unsubscribed) with this state.
    _subscription: gpui::Subscription,
}
138
/// A single semantic-search hit, locating a named document within a worktree.
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    // Name of the matched document (as captured at indexing time).
    pub name: String,
    // Byte offset of the document within the file.
    pub offset: usize,
    pub file_path: PathBuf,
}
146
/// Jobs processed serially by the database-writer task in `VectorStore::new`.
enum DbWrite {
    // Insert an indexed file's documents for the given worktree.
    InsertFile {
        worktree_id: i64,
        indexed_file: IndexedFile,
    },
    // Remove a file's documents from the given worktree.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    // Look up the database id for a worktree root path, creating a row if
    // absent; the result is reported back over `sender`.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
}
161
162impl VectorStore {
163    async fn new(
164        fs: Arc<dyn Fs>,
165        database_url: PathBuf,
166        embedding_provider: Arc<dyn EmbeddingProvider>,
167        language_registry: Arc<LanguageRegistry>,
168        mut cx: AsyncAppContext,
169    ) -> Result<ModelHandle<Self>> {
170        let database_url = Arc::new(database_url);
171
172        let db = cx
173            .background()
174            .spawn({
175                let fs = fs.clone();
176                let database_url = database_url.clone();
177                async move {
178                    if let Some(db_directory) = database_url.parent() {
179                        fs.create_dir(db_directory).await.log_err();
180                    }
181
182                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
183                    anyhow::Ok(db)
184                }
185            })
186            .await?;
187
188        Ok(cx.add_model(|cx| {
189            // paths_tx -> embeddings_tx -> db_update_tx
190
191            let (db_update_tx, db_update_rx) = channel::unbounded();
192            let (paths_tx, paths_rx) =
193                channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>();
194            let (embeddings_tx, embeddings_rx) =
195                channel::unbounded::<(i64, IndexedFile, Vec<String>)>();
196
197            let _db_update_task = cx.background().spawn(async move {
198                while let Ok(job) = db_update_rx.recv().await {
199                    match job {
200                        DbWrite::InsertFile {
201                            worktree_id,
202                            indexed_file,
203                        } => {
204                            db.insert_file(worktree_id, indexed_file).log_err();
205                        }
206                        DbWrite::Delete { worktree_id, path } => {
207                            db.delete_file(worktree_id, path).log_err();
208                        }
209                        DbWrite::FindOrCreateWorktree { path, sender } => {
210                            let id = db.find_or_create_worktree(&path);
211                            sender.send(id).ok();
212                        }
213                    }
214                }
215            });
216
217            async fn embed_batch(
218                embeddings_queue: Vec<(i64, IndexedFile, Vec<String>)>,
219                embedding_provider: &Arc<dyn EmbeddingProvider>,
220                db_update_tx: channel::Sender<DbWrite>,
221            ) -> Result<()> {
222                let mut embeddings_queue = embeddings_queue.clone();
223
224                let mut document_spans = vec![];
225                for (_, _, document_span) in embeddings_queue.clone().into_iter() {
226                    document_spans.extend(document_span);
227                }
228
229                let mut embeddings = embedding_provider
230                    .embed_batch(document_spans.iter().map(|x| &**x).collect())
231                    .await?;
232
233                // This assumes the embeddings are returned in order
234                let t0 = Instant::now();
235                let mut i = 0;
236                let mut j = 0;
237                while let Some(embedding) = embeddings.pop() {
238                    // This has to accomodate for multiple indexed_files in a row without documents
239                    while embeddings_queue[i].1.documents.len() == j {
240                        i += 1;
241                        j = 0;
242                    }
243
244                    embeddings_queue[i].1.documents[j].embedding = embedding;
245                    j += 1;
246                }
247
248                for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
249                    // TODO: Update this so it doesnt panic
250                    for document in indexed_file.documents.iter() {
251                        assert!(
252                            document.embedding.len() > 0,
253                            "Document Embedding not Complete"
254                        );
255                    }
256
257                    db_update_tx
258                        .send(DbWrite::InsertFile {
259                            worktree_id,
260                            indexed_file,
261                        })
262                        .await
263                        .unwrap();
264                }
265
266                anyhow::Ok(())
267            }
268
269            let embedding_provider_clone = embedding_provider.clone();
270
271            let db_update_tx_clone = db_update_tx.clone();
272            let _embeddings_update_task = cx.background().spawn(async move {
273                let mut queue_len = 0;
274                let mut embeddings_queue = vec![];
275                let mut request_count = 0;
276                while let Ok((worktree_id, indexed_file, document_spans)) =
277                    embeddings_rx.recv().await
278                {
279                    queue_len += &document_spans.len();
280                    embeddings_queue.push((worktree_id, indexed_file, document_spans));
281
282                    if queue_len >= EMBEDDINGS_BATCH_SIZE {
283                        let _ = embed_batch(
284                            embeddings_queue,
285                            &embedding_provider_clone,
286                            db_update_tx_clone.clone(),
287                        )
288                        .await;
289
290                        embeddings_queue = vec![];
291                        queue_len = 0;
292
293                        request_count += 1;
294                    }
295                }
296
297                if queue_len > 0 {
298                    let _ = embed_batch(
299                        embeddings_queue,
300                        &embedding_provider_clone,
301                        db_update_tx_clone.clone(),
302                    )
303                    .await;
304                    request_count += 1;
305                }
306            });
307
308            let fs_clone = fs.clone();
309
310            let _paths_update_task = cx.background().spawn(async move {
311                let mut parser = Parser::new();
312                let mut cursor = QueryCursor::new();
313                while let Ok((worktree_id, file_path, language, mtime)) = paths_rx.recv().await {
314                    if let Some((indexed_file, document_spans)) = Self::index_file(
315                        &mut cursor,
316                        &mut parser,
317                        &fs_clone,
318                        language,
319                        file_path.clone(),
320                        mtime,
321                    )
322                    .await
323                    .log_err()
324                    {
325                        embeddings_tx
326                            .try_send((worktree_id, indexed_file, document_spans))
327                            .unwrap();
328                    }
329                }
330            });
331
332            Self {
333                fs,
334                database_url,
335                db_update_tx,
336                paths_tx,
337                embedding_provider,
338                language_registry,
339                projects: HashMap::new(),
340                _db_update_task,
341                _paths_update_task,
342                _embeddings_update_task,
343            }
344        }))
345    }
346
347    async fn index_file(
348        cursor: &mut QueryCursor,
349        parser: &mut Parser,
350        fs: &Arc<dyn Fs>,
351        language: Arc<Language>,
352        file_path: PathBuf,
353        mtime: SystemTime,
354    ) -> Result<(IndexedFile, Vec<String>)> {
355        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
356        let embedding_config = grammar
357            .embedding_config
358            .as_ref()
359            .ok_or_else(|| anyhow!("no outline query"))?;
360
361        let content = fs.load(&file_path).await?;
362
363        parser.set_language(grammar.ts_language).unwrap();
364        let tree = parser
365            .parse(&content, None)
366            .ok_or_else(|| anyhow!("parsing failed"))?;
367
368        let mut documents = Vec::new();
369        let mut context_spans = Vec::new();
370        for mat in cursor.matches(
371            &embedding_config.query,
372            tree.root_node(),
373            content.as_bytes(),
374        ) {
375            let mut item_range = None;
376            let mut name_range = None;
377            for capture in mat.captures {
378                if capture.index == embedding_config.item_capture_ix {
379                    item_range = Some(capture.node.byte_range());
380                } else if capture.index == embedding_config.name_capture_ix {
381                    name_range = Some(capture.node.byte_range());
382                }
383            }
384
385            if let Some((item_range, name_range)) = item_range.zip(name_range) {
386                if let Some((item, name)) =
387                    content.get(item_range.clone()).zip(content.get(name_range))
388                {
389                    context_spans.push(item.to_string());
390                    documents.push(Document {
391                        name: name.to_string(),
392                        offset: item_range.start,
393                        embedding: Vec::new(),
394                    });
395                }
396            }
397        }
398
399        return Ok((
400            IndexedFile {
401                path: file_path,
402                mtime,
403                documents,
404            },
405            context_spans,
406        ));
407    }
408
409    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
410        let (tx, rx) = oneshot::channel();
411        self.db_update_tx
412            .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
413            .unwrap();
414        async move { rx.await? }
415    }
416
    /// Register `project` with the store: record each worktree's database id,
    /// queue indexing for files whose stored mtime is missing or stale,
    /// delete database entries for files no longer on disk, and subscribe to
    /// project events so changed files are reindexed after saves.
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        // NOTE(review): `as_local().unwrap()` assumes every worktree is
        // local; `init` only registers local projects, which upholds this.
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let database_url = self.database_url.clone();
        let db_update_tx = self.db_update_tx.clone();
        let paths_tx = self.paths_tx.clone();

        cx.spawn(|this, mut cx| async move {
            // Wait for all worktrees to finish their initial scan so the
            // snapshots taken below are complete.
            futures::future::join_all(worktree_scans_complete).await;

            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;

            if let Some(db_directory) = database_url.parent() {
                fs.create_dir(db_directory).await.log_err();
            }

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Here we query the worktree ids, and yet we dont have them elsewhere
            // We likely want to clean up these datastructures
            //
            // Load, per worktree: its database id and the stored mtimes of
            // every file previously indexed for it.
            let (mut worktree_file_times, db_ids_by_worktree_id) = cx
                .background()
                .spawn({
                    let worktrees = worktrees.clone();
                    async move {
                        let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
                        let mut db_ids_by_worktree_id = HashMap::new();
                        let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
                            HashMap::new();
                        for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                            let db_id = db_id?;
                            db_ids_by_worktree_id.insert(worktree.id(), db_id);
                            file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
                        }
                        anyhow::Ok((file_times, db_ids_by_worktree_id))
                    }
                })
                .await?;

            // Initial sync: queue indexing for new/changed files, then delete
            // database rows for files that no longer exist in the worktree.
            cx.background()
                .spawn({
                    let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
                    let db_update_tx = db_update_tx.clone();
                    let language_registry = language_registry.clone();
                    let paths_tx = paths_tx.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let mut file_mtimes =
                                worktree_file_times.remove(&worktree.id()).unwrap();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Only languages with an embedding query
                                    // can be indexed.
                                    if language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                    {
                                        continue;
                                    }

                                    let path_buf = file.path.to_path_buf();
                                    // `remove` doubles as the "seen" marker:
                                    // entries left in `file_mtimes` afterward
                                    // were deleted on disk.
                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        paths_tx
                                            .try_send((
                                                db_ids_by_worktree_id[&worktree.id()],
                                                path_buf,
                                                language,
                                                file.mtime,
                                            ))
                                            .unwrap();
                                    }
                                }
                            }
                            for file in file_mtimes.keys() {
                                db_update_tx
                                    .try_send(DbWrite::Delete {
                                        worktree_id: db_ids_by_worktree_id[&worktree.id()],
                                        path: file.to_owned(),
                                    })
                                    .unwrap();
                            }
                        }
                    }
                })
                .detach();

            this.update(&mut cx, |this, cx| {
                // The below is managing for updated on save
                // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
                // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
                let _subscription = cx.subscribe(&project, |this, project, event, _cx| {
                    if let Some(project_state) = this.projects.get(&project.downgrade()) {
                        let worktree_db_ids = project_state.worktree_db_ids.clone();

                        if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
                        {
                            // Iterate through changes
                            let language_registry = this.language_registry.clone();

                            // NOTE(review): opening the database and the
                            // `smol::block_on` below run synchronously inside
                            // this event callback; consider moving the work
                            // to a background task.
                            let db =
                                VectorDatabase::new(this.database_url.to_string_lossy().into());
                            if db.is_err() {
                                return;
                            }
                            let db = db.unwrap();

                            // Linear scan mapping this event's worktree id to
                            // its database id.
                            let worktree_db_id: Option<i64> = {
                                let mut found_db_id = None;
                                for (w_id, db_id) in worktree_db_ids.into_iter() {
                                    if &w_id == worktree_id {
                                        found_db_id = Some(db_id);
                                    }
                                }

                                found_db_id
                            };

                            if worktree_db_id.is_none() {
                                return;
                            }
                            let worktree_db_id = worktree_db_id.unwrap();

                            let file_mtimes = db.get_file_mtimes(worktree_db_id);
                            if file_mtimes.is_err() {
                                return;
                            }

                            let file_mtimes = file_mtimes.unwrap();
                            let paths_tx = this.paths_tx.clone();

                            smol::block_on(async move {
                                for change in changes.into_iter() {
                                    let change_path = change.0.clone();
                                    log::info!("Change: {:?}", &change_path);
                                    if let Ok(language) = language_registry
                                        .language_for_file(&change_path.to_path_buf(), None)
                                        .await
                                    {
                                        if language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                        {
                                            continue;
                                        }

                                        // TODO: Make this a bit more defensive
                                        // NOTE(review): `metadata()` panics if
                                        // the changed file was deleted between
                                        // the event and this callback.
                                        let modified_time =
                                            change_path.metadata().unwrap().modified().unwrap();
                                        let existing_time =
                                            file_mtimes.get(&change_path.to_path_buf());
                                        // Reindex only when the mtime differs
                                        // AND the stored entry is older than
                                        // REINDEXING_DELAY seconds.
                                        let already_stored =
                                            existing_time.map_or(false, |existing_time| {
                                                if &modified_time != existing_time
                                                    && existing_time.elapsed().unwrap().as_secs()
                                                        > REINDEXING_DELAY
                                                {
                                                    false
                                                } else {
                                                    true
                                                }
                                            });

                                        if !already_stored {
                                            log::info!("Need to reindex: {:?}", &change_path);
                                            paths_tx
                                                .try_send((
                                                    worktree_db_id,
                                                    change_path.to_path_buf(),
                                                    language,
                                                    modified_time,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            })
                        }
                    }
                });

                this.projects.insert(
                    project.downgrade(),
                    ProjectState {
                        worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
                        _subscription,
                    },
                );
            });

            anyhow::Ok(())
        })
    }
647
    /// Find the `limit` indexed documents most similar to `phrase`, ranked by
    /// dot product between the phrase embedding and each stored embedding.
    ///
    /// Errors if the project was never registered via `add_project` (or was
    /// dropped while the search ran).
    pub fn search(
        &mut self,
        project: ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
            state
        } else {
            return Task::ready(Err(anyhow!("project not added")));
        };

        // Translate the project's worktrees to their database ids; worktrees
        // without a recorded id are silently skipped.
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                project_state
                    .worktree_db_ids
                    .iter()
                    .find_map(|(id, db_id)| {
                        if *id == worktree_id {
                            Some(*db_id)
                        } else {
                            None
                        }
                    })
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;

                    // Embed the query phrase with the same provider used for
                    // the indexed documents.
                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    // Maintain the top `limit` hits as a vec kept sorted in
                    // DESCENDING similarity (the reversed comparator below),
                    // truncating after each insertion.
                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            this.read_with(&cx, |this, _| {
                // Re-check registration: the project may have been removed
                // while the query ran on the background thread.
                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
                    state
                } else {
                    return Err(anyhow!("project not added"));
                };

                // Map database worktree ids back to in-memory worktree ids,
                // dropping hits whose worktree is no longer known.
                Ok(documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id =
                            project_state
                                .worktree_db_ids
                                .iter()
                                .find_map(|(id, db_id)| {
                                    if *db_id == worktree_db_id {
                                        Some(*id)
                                    } else {
                                        None
                                    }
                                })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect())
            })
        })
    }
744}
745
impl Entity for VectorStore {
    // The store emits no events of its own.
    type Event = ();
}
749
/// Dot product of two equal-length `f32` vectors, used as the similarity
/// measure between embeddings.
///
/// # Panics
/// Panics if the slices differ in length.
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    // Asserting equal lengths up front both preserves the original contract
    // and lets the compiler elide per-element bounds checks in the zip below.
    assert_eq!(vec_a.len(), vec_b.len());

    // A plain zipped fold replaces the previous unsafe 1x1 `sgemm` call: for
    // a vector-vector product the BLAS kernel offers no benefit, and this
    // form needs no `unsafe` (and no SAFETY argument).
    vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum()
}