vector_store.rs

mod db;
mod embedding;
mod modal;

#[cfg(test)]
mod vector_store_tests;

use anyhow::{anyhow, Result};
use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use project::{Fs, Project, WorktreeId};
use smol::channel;
use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
use tree_sitter::{Parser, QueryCursor};
use util::{http::HttpClient, ResultExt, TryFutureExt};
use workspace::{Workspace, WorkspaceCreated};

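/// A single span extracted from a source file by the language's outline query:
/// the item's name, its byte offset within the file, and its embedding vector.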
#[derive(Debug)]
pub struct Document {
    pub offset: usize,
    pub name: String,
    pub embedding: Vec<f32>,
}

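/// Wires the vector store into the app: creates the `VectorStore` model, starts
/// indexing each local project when its workspace is created, and registers the
/// `SemanticSearch` modal together with its `Toggle` action.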
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    let vector_store = cx.add_model(|_| {
        VectorStore::new(
            fs,
            VECTOR_DB_URL.to_string(),
            // Arc::new(DummyEmbeddings {}),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
        )
    });

    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();

    cx.add_action({
        move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
            let vector_store = vector_store.clone();
            workspace.toggle_modal(cx, |workspace, cx| {
                let project = workspace.project().clone();
                let workspace = cx.weak_handle();
                cx.add_view(|cx| {
                    SemanticSearch::new(
                        SemanticSearchDelegate::new(workspace, project, vector_store),
                        cx,
                    )
                })
            })
        }
    });

    SemanticSearch::init(cx);
}

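/// The output of indexing a single file: its worktree-relative path, a SHA-1 of
/// its content used to skip unchanged files on later scans, and the embedded
/// documents extracted from it.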
#[derive(Debug)]
pub struct IndexedFile {
    path: PathBuf,
    sha1: FileSha1,
    documents: Vec<Document>,
}

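/// Model that owns the vector database location, the embedding provider, and the
/// mapping between in-memory `WorktreeId`s and their database ids.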
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<str>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    worktree_db_ids: Vec<(WorktreeId, i64)>,
}

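/// A single semantic search hit: the worktree and file it was found in, the
/// matched item's name, and its byte offset within that file.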
#[derive(Debug)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}

impl VectorStore {
    fn new(
        fs: Arc<dyn Fs>,
        database_url: String,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: database_url.into(),
            embedding_provider,
            language_registry,
            worktree_db_ids: Vec::new(),
        }
    }

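    /// Parses `content` with tree-sitter, runs the language's outline query to
    /// collect item/name capture pairs, embeds the captured item text in a single
    /// batch, and returns the file's path, content hash, and embedded documents.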
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language: Arc<Language>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline_config = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        parser.set_language(grammar.ts_language).unwrap();
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == outline_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == outline_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        if !documents.is_empty() {
            let embeddings = embedding_provider.embed_batch(context_spans).await?;
            for (document, embedding) in documents.iter_mut().zip(embeddings) {
                document.embedding = embedding;
            }
        }

        let sha1 = FileSha1::from_str(content);

        Ok(IndexedFile {
            path: file_path,
            sha1,
            documents,
        })
    }

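    /// Indexes the worktrees of a local project: waits for worktree scans to finish,
    /// looks up (or creates) each worktree's database id and stored file hashes, then
    /// streams changed Rust files through a pool of background parse/embed workers
    /// and writes the resulting `IndexedFile`s back to the database.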
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                    log::info!("worktree scan completed");
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            // TODO: remove this after fixing the bug in scan_complete
            cx.background()
                .timer(std::time::Duration::from_secs(3))
                .await;

            let db = VectorDatabase::new(&database_url)?;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Here we query the worktree ids, yet we don't have them elsewhere.
            // We likely want to clean up these data structures.
            let (db, worktree_hashes, worktree_db_ids) = cx
                .background()
                .spawn({
                    let worktrees = worktrees.clone();
                    async move {
                        let mut worktree_db_ids: HashMap<WorktreeId, i64> = HashMap::new();
                        let mut hashes: HashMap<WorktreeId, HashMap<PathBuf, FileSha1>> =
                            HashMap::new();
                        for worktree in worktrees {
                            let worktree_db_id =
                                db.find_or_create_worktree(worktree.abs_path().as_ref())?;
                            worktree_db_ids.insert(worktree.id(), worktree_db_id);
                            hashes.insert(worktree.id(), db.get_file_hashes(worktree_db_id)?);
                        }
                        anyhow::Ok((db, hashes, worktree_db_ids))
                    }
                })
                .await?;

            // paths_tx carries changed files to the parse/embed workers below;
            // indexed_files_tx carries their output to the database writer.
            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    let worktree_db_ids = worktree_db_ids.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let file_hashes = &worktree_hashes[&worktree.id()];
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    if language.name().as_ref() != "Rust" {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        log::info!("loaded file: {absolute_path:?}");

                                        let path_buf = file.path.to_path_buf();
                                        let already_stored = file_hashes
                                            .get(&path_buf)
                                            .map_or(false, |existing_hash| {
                                                existing_hash.equals(&content)
                                            });

                                        if !already_stored {
                                            log::info!(
                                                "File Changed (Sending to Parse): {:?}",
                                                &path_buf
                                            );
                                            paths_tx
                                                .try_send((
                                                    worktree_db_ids[&worktree.id()],
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                        }
                    }
                })
                .detach();

            // Persist indexed files to the database as they arrive.
            let db_write_task = cx.background().spawn(
                async move {
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    anyhow::Ok(())
                }
                .log_err(),
            );

            // Parse and embed changed files with one worker per CPU on the background executor.
            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            drop(indexed_files_tx);

            db_write_task.await;

            this.update(&mut cx, |this, _| {
                this.worktree_db_ids.extend(worktree_db_ids);
            });

            anyhow::Ok(())
        })
    }

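    /// Embeds `phrase`, scans the stored document embeddings for the project's
    /// worktrees, keeps the `limit` most similar documents by dot product, and maps
    /// them back to in-memory worktree ids and file paths as `SearchResult`s.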
    pub fn search(
        &mut self,
        project: &ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project = project.read(cx);
        let worktree_db_ids = project
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                self.worktree_db_ids.iter().find_map(|(id, db_id)| {
                    if *id == worktree_id {
                        Some(*db_id)
                    } else {
                        None
                    }
                })
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.as_ref())?;

                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    // Keep the `limit` most similar documents, ordered by descending similarity.
                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            let results = this.read_with(&cx, |this, _| {
                documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| {
                            if *db_id == worktree_db_id {
                                Some(*id)
                            } else {
                                None
                            }
                        })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect()
            });

            anyhow::Ok(results)
        })
    }
}

impl Entity for VectorStore {
    type Event = ();
}

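/// Dot product of two equal-length vectors, computed by multiplying a 1×len row
/// vector with a len×1 column vector via `matrixmultiply::sgemm`.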
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}
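
// A minimal sanity check for `dot`: a sketch only, with a hypothetical module name;
// the module's real tests live in `vector_store_tests`.
#[cfg(test)]
mod dot_sketch {
    use super::dot;

    #[test]
    fn dot_matches_manual_sum() {
        // 1*4 + 2*5 + 3*6 = 32, and every partial sum is exactly representable in f32.
        assert_eq!(dot(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]), 32.0);
    }
}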