1mod db;
2mod embedding;
3mod modal;
4
5#[cfg(test)]
6mod vector_store_tests;
7
8use anyhow::{anyhow, Result};
9use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
10use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
11use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
12use language::{Language, LanguageRegistry};
13use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
14use project::{Fs, Project, WorktreeId};
15use smol::channel;
16use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
17use tree_sitter::{Parser, QueryCursor};
18use util::{http::HttpClient, ResultExt, TryFutureExt};
19use workspace::{Workspace, WorkspaceCreated};
20
/// A single embeddable span extracted from a source file: one outline item
/// (e.g. a function or type declaration) together with the embedding
/// computed for its text.
#[derive(Debug)]
pub struct Document {
    /// Byte offset of the item's start within the file content.
    pub offset: usize,
    /// The item's name, as captured by the language's outline query.
    pub name: String,
    /// Embedding vector for the item's text; empty until filled in by the
    /// embedding provider (see `VectorStore::index_file`).
    pub embedding: Vec<f32>,
}
27
/// Wires the vector store into the application:
/// creates the `VectorStore` model, indexes every local project as its
/// workspace is created, registers the `Toggle` action that opens the
/// semantic-search modal, and initializes the modal itself.
pub fn init(
    fs: Arc<dyn Fs>,
    // NOTE(review): `http_client` is currently unused because the OpenAI
    // provider below is commented out in favor of `DummyEmbeddings`.
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    let vector_store = cx.add_model(|cx| {
        VectorStore::new(
            fs,
            VECTOR_DB_URL.to_string(),
            // Arc::new(OpenAIEmbeddings {
            //     client: http_client,
            // }),
            Arc::new(DummyEmbeddings {}),
            language_registry,
        )
    });

    // Kick off indexing for every local project as soon as its workspace
    // window is created. Remote projects are skipped: their files are not on
    // this machine's filesystem.
    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        // Detach: indexing runs in the background; errors are
                        // not surfaced to the workspace.
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();

    // `Toggle` opens the semantic-search modal backed by this store.
    cx.add_action({
        move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
            let vector_store = vector_store.clone();
            workspace.toggle_modal(cx, |workspace, cx| {
                let project = workspace.project().clone();
                let workspace = cx.weak_handle();
                cx.add_view(|cx| {
                    SemanticSearch::new(
                        SemanticSearchDelegate::new(workspace, project, vector_store),
                        cx,
                    )
                })
            })
        }
    });
    SemanticSearch::init(cx);
}
79
/// A fully parsed file ready to be persisted to the vector database.
#[derive(Debug)]
pub struct IndexedFile {
    // Worktree-relative path of the file.
    path: PathBuf,
    // SHA-1 of the file content at index time; used to skip re-indexing
    // unchanged files on subsequent scans.
    sha1: FileSha1,
    // Outline items extracted from the file, with embeddings populated.
    documents: Vec<Document>,
}
86
/// GPUI model that indexes project files into a vector database and answers
/// semantic-search queries against it.
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    // Path/URL of the SQLite vector database.
    database_url: Arc<str>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Maps each indexed worktree's in-memory id to its row id in the
    // database; populated by `add_project`, consulted by `search`.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
}
94
/// One hit returned by `VectorStore::search`: a named outline item at a byte
/// offset within a file of some worktree.
#[derive(Debug)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    /// Byte offset of the matched item within the file.
    pub offset: usize,
    pub file_path: PathBuf,
}
102
impl VectorStore {
    /// Creates a store that persists to the SQLite database at `database_url`
    /// and embeds documents via `embedding_provider`. No worktrees are
    /// indexed until `add_project` is called.
    fn new(
        fs: Arc<dyn Fs>,
        database_url: String,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: database_url.into(),
            embedding_provider,
            language_registry,
            worktree_db_ids: Vec::new(),
        }
    }

    /// Parses `content` with the language's tree-sitter outline query,
    /// extracts one `Document` per outline item, and embeds all item texts in
    /// a single batch.
    ///
    /// `cursor` and `parser` are borrowed so worker tasks can reuse them
    /// across many files (see `add_project`). Fails if the language has no
    /// grammar or outline query, or if parsing fails.
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language: Arc<Language>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline_config = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        parser.set_language(grammar.ts_language).unwrap();
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        // `documents[i]` corresponds to `context_spans[i]`; the spans are the
        // raw item texts sent to the embedding provider in one batch below.
        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == outline_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == outline_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            // Only keep matches that captured both the item and its name,
            // and whose byte ranges fall on valid UTF-8 boundaries
            // (`str::get` returns None otherwise).
            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        // Placeholder; filled in from the batch result below.
                        embedding: Vec::new(),
                    });
                }
            }
        }

        if !documents.is_empty() {
            // One batched request for all items in the file; the provider is
            // assumed to return embeddings in input order.
            let embeddings = embedding_provider.embed_batch(context_spans).await?;
            for (document, embedding) in documents.iter_mut().zip(embeddings) {
                document.embedding = embedding;
            }
        }

        // Consumes `content`; nothing below needs the file text anymore.
        let sha1 = FileSha1::from_str(content);

        return Ok(IndexedFile {
            path: file_path,
            sha1,
            documents,
        });
    }

    /// Indexes every worktree of `project` into the vector database.
    ///
    /// Pipeline, once all worktree scans complete:
    /// 1. a detached background task walks each worktree and sends changed
    ///    Rust files down `paths_tx`;
    /// 2. one parser task per CPU drains `paths_rx`, runs `index_file`, and
    ///    forwards results on `indexed_files_tx`;
    /// 3. a single writer task drains `indexed_files_rx` into the database.
    /// Channel drops drive shutdown: the scan task dropping `paths_tx` ends
    /// the parsers; the explicit `drop(indexed_files_tx)` after the parsers
    /// finish ends the writer.
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                    log::info!("worktree scan completed");
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            // TODO: remove this after fixing the bug in scan_complete
            cx.background()
                .timer(std::time::Duration::from_secs(3))
                .await;

            let db = VectorDatabase::new(&database_url)?;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Here we query the worktree ids, and yet we dont have them elsewhere
            // We likely want to clean up these datastructures
            //
            // For each worktree: look up (or create) its database row, and
            // load the stored file hashes so unchanged files can be skipped.
            let (db, worktree_hashes, worktree_db_ids) = cx
                .background()
                .spawn({
                    let worktrees = worktrees.clone();
                    async move {
                        let mut worktree_db_ids: HashMap<WorktreeId, i64> = HashMap::new();
                        let mut hashes: HashMap<WorktreeId, HashMap<PathBuf, FileSha1>> =
                            HashMap::new();
                        for worktree in worktrees {
                            let worktree_db_id =
                                db.find_or_create_worktree(worktree.abs_path().as_ref())?;
                            worktree_db_ids.insert(worktree.id(), worktree_db_id);
                            hashes.insert(worktree.id(), db.get_file_hashes(worktree_db_id)?);
                        }
                        anyhow::Ok((db, hashes, worktree_db_ids))
                    }
                })
                .await?;

            // Stage 1 -> 2: (worktree db id, rel path, content, language).
            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            // Stage 2 -> 3: parsed files keyed by worktree db id.
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            // Stage 1: scan worktrees and enqueue files that need (re)indexing.
            // Detached; it owns `paths_tx` and closes it when the scan ends.
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    let worktree_db_ids = worktree_db_ids.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let file_hashes = &worktree_hashes[&worktree.id()];
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // NOTE(review): indexing is currently
                                    // restricted to Rust files only.
                                    if language.name().as_ref() != "Rust" {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        log::info!("loaded file: {absolute_path:?}");

                                        // Skip files whose stored hash still
                                        // matches the current content.
                                        let path_buf = file.path.to_path_buf();
                                        let already_stored = file_hashes
                                            .get(&path_buf)
                                            .map_or(false, |existing_hash| {
                                                existing_hash.equals(&content)
                                            });

                                        if !already_stored {
                                            log::info!(
                                                "File Changed (Sending to Parse): {:?}",
                                                &path_buf
                                            );
                                            // Unbounded channel: try_send only
                                            // fails if the receivers are gone.
                                            paths_tx
                                                .try_send((
                                                    worktree_db_ids[&worktree.id()],
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                        }
                    }
                })
                .detach();

            // Stage 3: single writer; runs until `indexed_files_tx` is dropped.
            let db_write_task = cx.background().spawn(
                async move {
                    // Initialize Database, creates database and tables if not exists
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    // ALL OF THE BELOW IS FOR TESTING,
                    // This should be removed as we find and appropriate place for evaluate our search.

                    // let queries = vec![
                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
                    //     "compute an outline view of all of the symbols in a buffer",
                    //     "scan a directory on the file system and load all of its children into an in-memory snapshot",
                    // ];
                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;

                    // let t2 = Instant::now();
                    // let documents = db.get_documents().unwrap();
                    // let files = db.get_files().unwrap();
                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());

                    // let t1 = Instant::now();
                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
                    // for (idx, embed) in embeddings.into_iter().enumerate() {
                    //     let t0 = Instant::now();
                    //     println!("\nQuery: {:?}", queries[idx]);
                    //     let results = bfs.top_k_search(&embed, 5).await;
                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
                    //     for (id, distance) in results {
                    //         println!("");
                    //         println!("   distance: {:?}", distance);
                    //         println!("   document: {:?}", documents[&id].name);
                    //         println!("   path:     {:?}", files[&documents[&id].file_id].relative_path);
                    //     }

                    // }

                    anyhow::Ok(())
                }
                .log_err(),
            );

            // Stage 2: one parser per CPU, each with its own Parser and
            // QueryCursor. The scoped await returns once `paths_rx` closes
            // (i.e. the scan task above has finished and dropped `paths_tx`).
            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            // Close the writer's channel so `db_write_task` can finish.
            drop(indexed_files_tx);

            db_write_task.await;

            // Record the worktree-id -> db-id mapping for `search`.
            // NOTE(review): indexing the same project twice would append
            // duplicate entries here — verify callers only index once.
            this.update(&mut cx, |this, _| {
                this.worktree_db_ids.extend(worktree_db_ids);
            });

            anyhow::Ok(())
        })
    }

    /// Embeds `phrase` and returns the `limit` most similar documents across
    /// the project's indexed worktrees, ranked by dot-product similarity
    /// (highest first). Worktrees that have not been indexed yet are skipped.
    pub fn search(
        &mut self,
        project: &ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        // Translate the project's worktree ids into database row ids using
        // the mapping recorded by `add_project`.
        let project = project.read(cx);
        let worktree_db_ids = project
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                self.worktree_db_ids.iter().find_map(|(id, db_id)| {
                    if *id == worktree_id {
                        Some(*db_id)
                    } else {
                        None
                    }
                })
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.as_ref())?;

                    // Single-element batch; assumes the provider returns one
                    // embedding per input.
                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    // Top-k selection: keep `results` sorted by descending
                    // similarity (the comparator's arguments are reversed on
                    // purpose), insert each candidate at its rank, and trim
                    // back to `limit`. Capacity is limit + 1 because insert
                    // happens before truncate.
                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            // Map database worktree ids back to in-memory worktree ids;
            // documents from unknown worktrees are dropped.
            let results = this.read_with(&cx, |this, _| {
                documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| {
                            if *db_id == worktree_db_id {
                                Some(*id)
                            } else {
                                None
                            }
                        })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect()
            });

            anyhow::Ok(results)
        })
    }
}
459
// `VectorStore` is a GPUI model; it emits no events.
impl Entity for VectorStore {
    type Event = ();
}
463
/// Computes the dot product of two equal-length `f32` vectors.
///
/// This is the similarity measure used by `VectorStore::search` to rank
/// document embeddings against the query embedding (raw dot product, not
/// cosine similarity — magnitudes are not normalized here).
///
/// # Panics
///
/// Panics if the slices have different lengths.
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    // The length check both preserves the original contract and lets the
    // compiler elide per-element bounds checks in the loop below.
    assert_eq!(vec_a.len(), vec_b.len());

    // A 1×k by k×1 GEMM is just a dot product: the safe zip/sum form replaces
    // the previous `unsafe matrixmultiply::sgemm` call (which had no SAFETY
    // comment), auto-vectorizes well, and drops the external-crate dependency
    // for this trivial operation.
    vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum()
}