vector_store.rs

  1mod db;
  2mod embedding;
  3mod modal;
  4
  5#[cfg(test)]
  6mod vector_store_tests;
  7
  8use anyhow::{anyhow, Result};
  9use db::{FileSha1, VectorDatabase};
 10use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 11use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
 12use language::{Language, LanguageRegistry};
 13use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
 14use project::{Fs, Project, WorktreeId};
 15use smol::channel;
 16use std::{
 17    cmp::Ordering,
 18    collections::{HashMap, HashSet},
 19    path::{Path, PathBuf},
 20    sync::Arc,
 21};
 22use tree_sitter::{Parser, QueryCursor};
 23use util::{
 24    channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt, TryFutureExt,
 25};
 26use workspace::{Workspace, WorkspaceCreated};
 27
/// A single named span (e.g. a function or type definition) extracted from a
/// source file by a grammar's embedding query, plus its embedding vector.
#[derive(Debug)]
pub struct Document {
    // Byte offset of the item's start within the file content.
    pub offset: usize,
    // Text captured by the query's name capture (e.g. the item's identifier).
    pub name: String,
    // Embedding of the item's text; left empty until `index_file` fills it in
    // from the batched embedding-provider response.
    pub embedding: Vec<f32>,
}
 34
/// Wires the semantic-search subsystem into the application:
/// - creates the global `VectorStore` model backed by an on-disk embeddings DB,
/// - starts indexing each local project as its workspace is created,
/// - registers the `Toggle` action that opens the semantic search modal.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    // Keep a separate database per release channel so stable/preview/dev
    // builds don't share one embeddings file.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    let vector_store = cx.add_model(|_| {
        VectorStore::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
        )
    });

    // Kick off indexing for every newly created workspace whose project is
    // local (remote projects have no local files to embed).
    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        // Indexing runs in the background; detach the task so
                        // it outlives this event handler.
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();

    // `Toggle` opens the semantic-search modal for the active workspace.
    cx.add_action({
        move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
            let vector_store = vector_store.clone();
            workspace.toggle_modal(cx, |workspace, cx| {
                let project = workspace.project().clone();
                let workspace = cx.weak_handle();
                cx.add_view(|cx| {
                    SemanticSearch::new(
                        SemanticSearchDelegate::new(workspace, project, vector_store),
                        cx,
                    )
                })
            })
        }
    });

    SemanticSearch::init(cx);
}
 90
/// The result of indexing one file: its worktree-relative path, a hash of the
/// content that was indexed (used to skip unchanged files on later passes),
/// and the embedded documents extracted from it.
#[derive(Debug)]
pub struct IndexedFile {
    path: PathBuf,
    sha1: FileSha1,
    documents: Vec<Document>,
}
 97
/// Global model that owns semantic indexing and search across projects.
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    // Path to the SQLite-backed embeddings database file.
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Maps each indexed worktree's in-app id to its row id in the database;
    // appended to by `add_project`, read by `search`.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
}
105
/// One semantic-search hit: the document's name and byte offset within
/// `file_path`, resolved back to an in-app worktree id.
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}
113
114impl VectorStore {
115    fn new(
116        fs: Arc<dyn Fs>,
117        database_url: PathBuf,
118        embedding_provider: Arc<dyn EmbeddingProvider>,
119        language_registry: Arc<LanguageRegistry>,
120    ) -> Self {
121        Self {
122            fs,
123            database_url: Arc::new(database_url),
124            embedding_provider,
125            language_registry,
126            worktree_db_ids: Vec::new(),
127        }
128    }
129
130    async fn index_file(
131        cursor: &mut QueryCursor,
132        parser: &mut Parser,
133        embedding_provider: &dyn EmbeddingProvider,
134        language: Arc<Language>,
135        file_path: PathBuf,
136        content: String,
137    ) -> Result<IndexedFile> {
138        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
139        let embedding_config = grammar
140            .embedding_config
141            .as_ref()
142            .ok_or_else(|| anyhow!("no outline query"))?;
143
144        parser.set_language(grammar.ts_language).unwrap();
145        let tree = parser
146            .parse(&content, None)
147            .ok_or_else(|| anyhow!("parsing failed"))?;
148
149        let mut documents = Vec::new();
150        let mut context_spans = Vec::new();
151        for mat in cursor.matches(
152            &embedding_config.query,
153            tree.root_node(),
154            content.as_bytes(),
155        ) {
156            let mut item_range = None;
157            let mut name_range = None;
158            for capture in mat.captures {
159                if capture.index == embedding_config.item_capture_ix {
160                    item_range = Some(capture.node.byte_range());
161                } else if capture.index == embedding_config.name_capture_ix {
162                    name_range = Some(capture.node.byte_range());
163                }
164            }
165
166            if let Some((item_range, name_range)) = item_range.zip(name_range) {
167                if let Some((item, name)) =
168                    content.get(item_range.clone()).zip(content.get(name_range))
169                {
170                    context_spans.push(item);
171                    documents.push(Document {
172                        name: name.to_string(),
173                        offset: item_range.start,
174                        embedding: Vec::new(),
175                    });
176                }
177            }
178        }
179
180        if !documents.is_empty() {
181            let embeddings = embedding_provider.embed_batch(context_spans).await?;
182            for (document, embedding) in documents.iter_mut().zip(embeddings) {
183                document.embedding = embedding;
184            }
185        }
186
187        let sha1 = FileSha1::from_str(content);
188
189        return Ok(IndexedFile {
190            path: file_path,
191            sha1,
192            documents,
193        });
194    }
195
    /// Indexes every local worktree of `project`: waits for the initial file
    /// scans to finish, diffs each worktree's files against the hashes already
    /// stored in the embeddings database, re-embeds new or changed files on
    /// background threads, deletes rows for files that no longer exist, and
    /// finally records each worktree's database id for later searches.
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        // Futures that resolve once each worktree's initial filesystem scan is
        // done; indexing must not start before the file lists are complete.
        // NOTE(review): `as_local().unwrap()` assumes the caller only passes
        // local projects (init() checks `is_local()` before calling this).
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            // Ensure the database's parent directory exists; creation failure
            // is logged and surfaces below when opening the database.
            if let Some(db_directory) = database_url.parent() {
                fs.create_dir(db_directory).await.log_err();
            }
            let db = VectorDatabase::new(database_url.to_string_lossy().into())?;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Resolve each worktree to its database row id and load the stored
            // per-file hashes, so we can diff against the current files.
            // TODO: `worktree_db_ids` duplicates state later stored on `self`;
            // these data structures could be consolidated.
            let (db, worktree_hashes, worktree_db_ids) = cx
                .background()
                .spawn({
                    let worktrees = worktrees.clone();
                    async move {
                        let mut worktree_db_ids: HashMap<WorktreeId, i64> = HashMap::new();
                        let mut hashes: HashMap<WorktreeId, HashMap<PathBuf, FileSha1>> =
                            HashMap::new();
                        for worktree in worktrees {
                            let worktree_db_id =
                                db.find_or_create_worktree(worktree.abs_path().as_ref())?;
                            worktree_db_ids.insert(worktree.id(), worktree_db_id);
                            hashes.insert(worktree.id(), db.get_file_hashes(worktree_db_id)?);
                        }
                        anyhow::Ok((db, hashes, worktree_db_ids))
                    }
                })
                .await?;

            // Pipeline: the scanner task feeds (db id, path, content, language)
            // into `paths_tx`; parser workers consume it and feed `indexed_files_tx`;
            // the db task drains `indexed_files_rx` then `delete_paths_rx`.
            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (delete_paths_tx, delete_paths_rx) = channel::unbounded::<(i64, PathBuf)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    let worktree_db_ids = worktree_db_ids.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let file_hashes = &worktree_hashes[&worktree.id()];
                            // Paths still in this set at the end of the loop
                            // exist in the database but not on disk anymore,
                            // and are queued for deletion.
                            let mut files_included =
                                file_hashes.keys().collect::<HashSet<&PathBuf>>();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Skip languages without an embedding query;
                                    // index_file could not extract anything.
                                    if language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                    {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        let path_buf = file.path.to_path_buf();
                                        // A file is "already stored" only when a
                                        // row exists AND its hash matches the
                                        // current content; either way it is no
                                        // longer a deletion candidate.
                                        let already_stored = file_hashes.get(&path_buf).map_or(
                                            false,
                                            |existing_hash| {
                                                files_included.remove(&path_buf);
                                                existing_hash.equals(&content)
                                            },
                                        );

                                        if !already_stored {
                                            // try_send cannot fail on an open
                                            // unbounded channel.
                                            paths_tx
                                                .try_send((
                                                    worktree_db_ids[&worktree.id()],
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                            for file in files_included {
                                delete_paths_tx
                                    .try_send((worktree_db_ids[&worktree.id()], file.to_owned()))
                                    .unwrap();
                            }
                        }
                    }
                })
                .detach();

            let db_update_task = cx.background().spawn(
                async move {
                    // Insert all newly indexed files. This loop ends once every
                    // `indexed_files_tx` sender has been dropped.
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        log::info!("Inserting File: {:?}", &indexed_file.path);
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    // Then remove rows for files that disappeared from disk.
                    while let Ok((worktree_id, delete_path)) = delete_paths_rx.recv().await {
                        log::info!("Deleting File: {:?}", &delete_path);
                        db.delete_file(worktree_id, delete_path).log_err();
                    }

                    anyhow::Ok(())
                }
                .log_err(),
            );

            // One parser worker per CPU; each reuses its Parser/QueryCursor
            // across files. The scope ends when `paths_rx` is exhausted.
            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            // Close the channel so the db task's insert loop can terminate.
            drop(indexed_files_tx);

            db_update_task.await;

            // Record the worktree -> db-id mapping for `search` to use.
            this.update(&mut cx, |this, _| {
                this.worktree_db_ids.extend(worktree_db_ids);
            });

            log::info!("Semantic Indexing Complete!");

            anyhow::Ok(())
        })
    }
375
    /// Embeds `phrase` and returns the `limit` most similar documents (by dot
    /// product) across all of `project`'s indexed worktrees, mapped back to
    /// in-app worktree ids. Worktrees that were never indexed are skipped.
    pub fn search(
        &mut self,
        project: &ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        // Translate the project's worktrees into database row ids, dropping
        // any that aren't present in `worktree_db_ids` (not yet indexed).
        let project = project.read(cx);
        let worktree_db_ids = project
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                self.worktree_db_ids.iter().find_map(|(id, db_id)| {
                    if *id == worktree_id {
                        Some(*db_id)
                    } else {
                        None
                    }
                })
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    // Each search opens its own database connection; the
                    // indexing task holds the other one.
                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;

                    // A single-element batch yields exactly one embedding.
                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    // Maintain the top `limit` matches in a vector sorted by
                    // descending similarity. The comparator's arguments are
                    // deliberately reversed (`similarity.partial_cmp(&s)`) to
                    // produce that descending order; `insert` + `truncate`
                    // keeps only the best `limit` entries.
                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            // Map database worktree ids back to in-app WorktreeIds; documents
            // whose worktree is no longer tracked are silently dropped.
            let results = this.read_with(&cx, |this, _| {
                documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| {
                            if *db_id == worktree_db_id {
                                Some(*id)
                            } else {
                                None
                            }
                        })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect()
            });

            anyhow::Ok(results)
        })
    }
455}
456
/// `VectorStore` participates in gpui's model system but emits no events.
impl Entity for VectorStore {
    type Event = ();
}
460
/// Computes the dot product of two equal-length `f32` vectors.
///
/// # Panics
/// Panics if the two slices differ in length.
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    // The length assertion both documents the contract and lets the compiler
    // elide per-element bounds checks in the zipped loop below.
    assert_eq!(vec_a.len(), vec_b.len());

    // A plain multiply-accumulate over the zipped slices replaces the original
    // unsafe 1×n·n×1 `matrixmultiply::sgemm` call: same sequential
    // accumulation, no `unsafe`, and it vectorizes under optimization.
    vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum()
}