1mod db;
2mod embedding;
3mod modal;
4mod search;
5
6#[cfg(test)]
7mod vector_store_tests;
8
9use anyhow::{anyhow, Result};
10use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
11use embedding::{EmbeddingProvider, OpenAIEmbeddings};
12use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
13use language::{Language, LanguageRegistry};
14use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
15use project::{Fs, Project};
16use smol::channel;
17use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
18use tree_sitter::{Parser, QueryCursor};
19use util::{http::HttpClient, ResultExt, TryFutureExt};
20use workspace::{Workspace, WorkspaceCreated};
21
/// A single outline item (as captured by a language's outline query)
/// extracted from a source file, together with its embedding vector.
#[derive(Debug)]
pub struct Document {
    // Byte offset of the item's start within the file's text.
    pub offset: usize,
    // The item's name as captured by the outline query's name capture.
    pub name: String,
    // Embedding of the item's text; empty until filled in by `index_file`.
    pub embedding: Vec<f32>,
}
28
29pub fn init(
30 fs: Arc<dyn Fs>,
31 http_client: Arc<dyn HttpClient>,
32 language_registry: Arc<LanguageRegistry>,
33 cx: &mut AppContext,
34) {
35 let vector_store = cx.add_model(|cx| {
36 VectorStore::new(
37 fs,
38 VECTOR_DB_URL.to_string(),
39 Arc::new(OpenAIEmbeddings {
40 client: http_client,
41 }),
42 language_registry,
43 )
44 });
45
46 cx.subscribe_global::<WorkspaceCreated, _>({
47 let vector_store = vector_store.clone();
48 move |event, cx| {
49 let workspace = &event.0;
50 if let Some(workspace) = workspace.upgrade(cx) {
51 let project = workspace.read(cx).project().clone();
52 if project.read(cx).is_local() {
53 vector_store.update(cx, |store, cx| {
54 store.add_project(project, cx).detach();
55 });
56 }
57 }
58 }
59 })
60 .detach();
61
62 cx.add_action({
63 move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
64 let vector_store = vector_store.clone();
65 workspace.toggle_modal(cx, |workspace, cx| {
66 let project = workspace.project().clone();
67 let workspace = cx.weak_handle();
68 cx.add_view(|cx| {
69 SemanticSearch::new(
70 SemanticSearchDelegate::new(workspace, project, vector_store),
71 cx,
72 )
73 })
74 })
75 }
76 });
77 SemanticSearch::init(cx);
78 // cx.add_action({
79 // let vector_store = vector_store.clone();
80 // move |workspace: &mut Workspace, _: &TestSearch, cx: &mut ViewContext<Workspace>| {
81 // let t0 = std::time::Instant::now();
82 // let task = vector_store.update(cx, |store, cx| {
83 // store.search("compute embeddings for all of the symbols in the codebase and write them to a database".to_string(), 10, cx)
84 // });
85
86 // cx.spawn(|this, cx| async move {
87 // let results = task.await?;
88 // let duration = t0.elapsed();
89
90 // println!("search took {:?}", duration);
91 // println!("results {:?}", results);
92
93 // anyhow::Ok(())
94 // }).detach()
95 // }
96 // });
97}
98
/// A fully parsed and embedded file, ready to be written to the vector
/// database.
#[derive(Debug)]
pub struct IndexedFile {
    // Worktree-relative path of the file (see `add_project`, which sends
    // `file.path`, not the absolute path).
    path: PathBuf,
    // Hash of the file's contents, used to skip unchanged files on re-index.
    sha1: FileSha1,
    // Outline items extracted from the file, embeddings included.
    documents: Vec<Document>,
}
105
/// Model that indexes project files into a vector database and answers
/// similarity-search queries against the stored embeddings.
pub struct VectorStore {
    // File system handle used to load file contents during indexing.
    fs: Arc<dyn Fs>,
    // Location of the vector database (initialized from `VECTOR_DB_URL`).
    database_url: Arc<str>,
    // Produces embeddings for document text (e.g. `OpenAIEmbeddings`).
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Used to detect each file's language so its outline query can be run.
    language_registry: Arc<LanguageRegistry>,
}
112
/// A single hit returned by [`VectorStore::search`].
#[derive(Debug)]
pub struct SearchResult {
    // Name of the matched outline item.
    pub name: String,
    // Byte offset of the item within its file.
    pub offset: usize,
    // Path of the file containing the item.
    pub file_path: PathBuf,
}
119
120impl VectorStore {
121 fn new(
122 fs: Arc<dyn Fs>,
123 database_url: String,
124 embedding_provider: Arc<dyn EmbeddingProvider>,
125 language_registry: Arc<LanguageRegistry>,
126 ) -> Self {
127 Self {
128 fs,
129 database_url: database_url.into(),
130 embedding_provider,
131 language_registry,
132 }
133 }
134
135 async fn index_file(
136 cursor: &mut QueryCursor,
137 parser: &mut Parser,
138 embedding_provider: &dyn EmbeddingProvider,
139 language: Arc<Language>,
140 file_path: PathBuf,
141 content: String,
142 ) -> Result<IndexedFile> {
143 let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
144 let outline_config = grammar
145 .outline_config
146 .as_ref()
147 .ok_or_else(|| anyhow!("no outline query"))?;
148
149 parser.set_language(grammar.ts_language).unwrap();
150 let tree = parser
151 .parse(&content, None)
152 .ok_or_else(|| anyhow!("parsing failed"))?;
153
154 let mut documents = Vec::new();
155 let mut context_spans = Vec::new();
156 for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
157 let mut item_range = None;
158 let mut name_range = None;
159 for capture in mat.captures {
160 if capture.index == outline_config.item_capture_ix {
161 item_range = Some(capture.node.byte_range());
162 } else if capture.index == outline_config.name_capture_ix {
163 name_range = Some(capture.node.byte_range());
164 }
165 }
166
167 if let Some((item_range, name_range)) = item_range.zip(name_range) {
168 if let Some((item, name)) =
169 content.get(item_range.clone()).zip(content.get(name_range))
170 {
171 context_spans.push(item);
172 documents.push(Document {
173 name: name.to_string(),
174 offset: item_range.start,
175 embedding: Vec::new(),
176 });
177 }
178 }
179 }
180
181 let embeddings = embedding_provider.embed_batch(context_spans).await?;
182 for (document, embedding) in documents.iter_mut().zip(embeddings) {
183 document.embedding = embedding;
184 }
185
186 let sha1 = FileSha1::from_str(content);
187
188 return Ok(IndexedFile {
189 path: file_path,
190 sha1,
191 documents,
192 });
193 }
194
    /// Indexes every local worktree of `project` into the vector database,
    /// skipping files whose stored content hash already matches. Runs as a
    /// three-stage pipeline: a loader task feeds changed files to a pool of
    /// parser tasks, which feed a single database-writer task.
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        // One future per worktree, each resolving when its initial scan is done.
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                    log::info!("worktree scan completed");
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|_, cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            // TODO: remove this after fixing the bug in scan_complete
            cx.background()
                .timer(std::time::Duration::from_secs(3))
                .await;

            let db = VectorDatabase::new(&database_url)?;

            // Snapshot the worktrees once so the rest of the work can run on
            // background threads without touching project state again.
            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let worktree_root_paths = worktrees
                .iter()
                .map(|worktree| worktree.abs_path().clone())
                .collect::<Vec<_>>();

            // Here we query the worktree ids, and yet we dont have them elsewhere
            // We likely want to clean up these datastructures
            // Look up (or create) each worktree's database id, and fetch the
            // stored per-file hashes so unchanged files can be skipped below.
            let (db, worktree_hashes, worktree_ids) = cx
                .background()
                .spawn(async move {
                    let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
                    let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
                    for worktree_root_path in worktree_root_paths {
                        let worktree_id =
                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
                        worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
                        hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
                    }
                    anyhow::Ok((db, hashes, worktree_ids))
                })
                .await?;

            // Pipeline channels: loader -> `paths_tx` -> parser pool ->
            // `indexed_files_tx` -> database writer.
            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            // Loader: walk every file of every worktree, load changed files,
            // and send them off for parsing. `paths_tx` is moved in, so the
            // channel closes when this task finishes.
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
                            let file_hashes = &worktree_hashes[&worktree_id];
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Indexing is currently limited to Rust files.
                                    if language.name().as_ref() != "Rust" {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        log::info!("loaded file: {absolute_path:?}");

                                        // Skip files whose stored hash matches
                                        // their current contents.
                                        let path_buf = file.path.to_path_buf();
                                        let already_stored = file_hashes
                                            .get(&path_buf)
                                            .map_or(false, |existing_hash| {
                                                existing_hash.equals(&content)
                                            });

                                        if !already_stored {
                                            log::info!(
                                                "File Changed (Sending to Parse): {:?}",
                                                &path_buf
                                            );
                                            paths_tx
                                                .try_send((
                                                    worktree_id,
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                        }
                    }
                })
                .detach();

            // Writer: drain parsed files and persist them; runs until the
            // `indexed_files_tx` side is dropped below.
            let db_write_task = cx.background().spawn(
                async move {
                    // Initialize Database, creates database and tables if not exists
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    anyhow::Ok(())
                }
                .log_err(),
            );

            // Parser pool: one task per CPU, each with its own parser and
            // query cursor, draining `paths_rx` until the loader closes it.
            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            // All parsers are done; close the channel so the writer's recv
            // loop terminates, then wait for the final inserts.
            drop(indexed_files_tx);

            db_write_task.await;
            anyhow::Ok(())
        })
    }
388
    /// Embeds `phrase` and returns up to `limit` documents from the vector
    /// database, ranked by dot-product similarity (highest first).
    pub fn search(
        &mut self,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.background().spawn(async move {
            let database = VectorDatabase::new(database_url.as_ref())?;

            // A batch of one phrase yields exactly one embedding.
            let phrase_embedding = embedding_provider
                .embed_batch(vec![&phrase])
                .await?
                .into_iter()
                .next()
                .unwrap();

            // Maintain a descending-sorted top-`limit` list of (id, score).
            // The comparator compares `similarity` (the new score) against
            // each element `s`, which reverses the usual order so higher
            // scores sort first; `truncate` then drops any overflow.
            // NOTE(review): the `0` passed to `for_each_document` looks like a
            // placeholder worktree id — confirm against VectorDatabase.
            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
            database.for_each_document(0, |id, embedding| {
                let similarity = dot(&embedding.0, &phrase_embedding);
                let ix = match results.binary_search_by(|(_, s)| {
                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                }) {
                    Ok(ix) => ix,
                    Err(ix) => ix,
                };
                results.insert(ix, (id, similarity));
                results.truncate(limit);
            })?;

            // Resolve the winning ids back into document metadata.
            // NOTE(review): presumably `get_documents_by_ids` preserves the
            // order of `ids` (i.e. the ranking) — verify in db.rs.
            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
            let documents = database.get_documents_by_ids(&ids)?;

            anyhow::Ok(
                documents
                    .into_iter()
                    .map(|(file_path, offset, name)| SearchResult {
                        name,
                        offset,
                        file_path,
                    })
                    .collect(),
            )
        })
    }
435}
436
impl Entity for VectorStore {
    // The store emits no events; the unit type satisfies gpui's Entity trait.
    type Event = ();
}
440
/// Computes the dot product of two equal-length `f32` vectors.
///
/// # Panics
/// Panics if the slices differ in length.
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    // Checking lengths up front both preserves the original contract and lets
    // the compiler elide per-element bounds checks in the zip below.
    assert_eq!(vec_a.len(), vec_b.len());

    // Plain sequential multiply-accumulate. This replaces the previous
    // `unsafe matrixmultiply::sgemm` call with 1×n×1 matrix shapes, which was
    // heavyweight machinery (and an unsafe pointer cast) for a simple inner
    // product; the safe iterator form accumulates in the same left-to-right
    // order and auto-vectorizes.
    vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum()
}