mod db;
mod embedding;
mod search;

#[cfg(test)]
mod vector_store_tests;

use anyhow::{anyhow, Result};
use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
use language::{Language, LanguageRegistry};
use project::{Fs, Project};
use smol::channel;
use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
use tree_sitter::{Parser, QueryCursor};
use util::{http::HttpClient, ResultExt, TryFutureExt};
use workspace::{Workspace, WorkspaceCreated};

actions!(semantic_search, [TestSearch]);

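/// A single symbol extracted from a source file: its name, its byte offset in
/// the file, and the embedding computed from its text.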
#[derive(Debug)]
pub struct Document {
    pub offset: usize,
    pub name: String,
    pub embedding: Vec<f32>,
}

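/// Wires semantic search into the app: creates the global `VectorStore` model,
/// indexes each local project as its workspace is created, and registers the
/// `TestSearch` action for ad-hoc query timing.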
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    let vector_store = cx.add_model(|cx| {
        VectorStore::new(
            fs,
            VECTOR_DB_URL.to_string(),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
        )
    });

    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();

    cx.add_action({
        let vector_store = vector_store.clone();
        move |_: &mut Workspace, _: &TestSearch, cx: &mut ViewContext<Workspace>| {
            let t0 = std::time::Instant::now();
            let task = vector_store.update(cx, |store, cx| {
                store.search("compute embeddings for all of the symbols in the codebase and write them to a database".to_string(), 10, cx)
            });

            cx.spawn(|_, _| async move {
                let results = task.await?;
                let duration = t0.elapsed();

                println!("search took {:?}", duration);
                println!("results {:?}", results);

                anyhow::Ok(())
            })
            .detach()
        }
    });
}

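/// A file that has been parsed and embedded: its path, the SHA-1 of its
/// contents (used to skip re-indexing unchanged files), and the documents
/// extracted from it.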
#[derive(Debug)]
pub struct IndexedFile {
    path: PathBuf,
    sha1: FileSha1,
    documents: Vec<Document>,
}

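/// Coordinates indexing and search: walks project worktrees, parses files with
/// tree-sitter, embeds the symbols they contain, and persists the results to
/// the vector database.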
struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<str>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
}

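/// A search hit: the matched document's name, its offset, and the path of the
/// file containing it.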
#[derive(Debug)]
pub struct SearchResult {
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}

impl VectorStore {
    fn new(
        fs: Arc<dyn Fs>,
        database_url: String,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: database_url.into(),
            embedding_provider,
            language_registry,
        }
    }

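    /// Parses `content` with the language's tree-sitter grammar, extracts one
    /// document per item matched by the language's outline query, and embeds
    /// all of the extracted spans in a single batch.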
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language: Arc<Language>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline_config = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        parser.set_language(grammar.ts_language).unwrap();
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == outline_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == outline_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        let embeddings = embedding_provider.embed_batch(context_spans).await?;
        for (document, embedding) in documents.iter_mut().zip(embeddings) {
            document.embedding = embedding;
        }

        let sha1 = FileSha1::from_str(content);

        Ok(IndexedFile {
            path: file_path,
            sha1,
            documents,
        })
    }

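    /// Indexes every local worktree of `project`: waits for worktree scans to
    /// complete, diffs file hashes against the database to find changed files,
    /// fans their contents out to one parsing worker per CPU, and streams the
    /// indexed files to a single database-writer task.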
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                    log::info!("worktree scan completed");
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|_, cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            // TODO: remove this after fixing the bug in scan_complete
            cx.background()
                .timer(std::time::Duration::from_secs(3))
                .await;

            let db = VectorDatabase::new(&database_url)?;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let worktree_root_paths = worktrees
                .iter()
                .map(|worktree| worktree.abs_path().clone())
                .collect::<Vec<_>>();

            // We query the worktree ids here, yet we don't have them elsewhere.
            // We likely want to clean up these data structures.
            let (db, worktree_hashes, worktree_ids) = cx
                .background()
                .spawn(async move {
                    let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
                    let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
                    for worktree_root_path in worktree_root_paths {
                        let worktree_id =
                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
                        worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
                        hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
                    }
                    anyhow::Ok((db, hashes, worktree_ids))
                })
                .await?;

            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
                            let file_hashes = &worktree_hashes[&worktree_id];
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    if language.name().as_ref() != "Rust" {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        log::info!("loaded file: {absolute_path:?}");

                                        let path_buf = file.path.to_path_buf();
                                        let already_stored = file_hashes
                                            .get(&path_buf)
                                            .map_or(false, |existing_hash| {
                                                existing_hash.equals(&content)
                                            });

                                        if !already_stored {
                                            log::info!(
                                                "File Changed (Sending to Parse): {:?}",
                                                &path_buf
                                            );
                                            paths_tx
                                                .try_send((
                                                    worktree_id,
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                        }
                    }
                })
                .detach();

            let db_write_task = cx.background().spawn(
                async move {
                    // Write each indexed file to the database as it arrives.
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    // ALL OF THE BELOW IS FOR TESTING.
                    // This should be removed once we find an appropriate place to evaluate our search.

                    // let queries = vec![
                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
                    //     "compute an outline view of all of the symbols in a buffer",
                    //     "scan a directory on the file system and load all of its children into an in-memory snapshot",
                    // ];
                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;

                    // let t2 = Instant::now();
                    // let documents = db.get_documents().unwrap();
                    // let files = db.get_files().unwrap();
                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());

                    // let t1 = Instant::now();
                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
                    // for (idx, embed) in embeddings.into_iter().enumerate() {
                    //     let t0 = Instant::now();
                    //     println!("\nQuery: {:?}", queries[idx]);
                    //     let results = bfs.top_k_search(&embed, 5).await;
                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
                    //     for (id, distance) in results {
                    //         println!("");
                    //         println!("   distance: {:?}", distance);
                    //         println!("   document: {:?}", documents[&id].name);
                    //         println!("   path: {:?}", files[&documents[&id].file_id].relative_path);
                    //     }
                    // }

                    anyhow::Ok(())
                }
                .log_err(),
            );

336
337 cx.background()
338 .scoped(|scope| {
339 for _ in 0..cx.background().num_cpus() {
340 scope.spawn(async {
341 let mut parser = Parser::new();
342 let mut cursor = QueryCursor::new();
343 while let Ok((worktree_id, file_path, content, language)) =
344 paths_rx.recv().await
345 {
346 if let Some(indexed_file) = Self::index_file(
347 &mut cursor,
348 &mut parser,
349 embedding_provider.as_ref(),
350 language,
351 file_path,
352 content,
353 )
354 .await
355 .log_err()
356 {
357 indexed_files_tx
358 .try_send((worktree_id, indexed_file))
359 .unwrap();
360 }
361 }
362 });
363 }
364 })
365 .await;
366 drop(indexed_files_tx);
367
368 db_write_task.await;
369 anyhow::Ok(())
370 })
371 }
372
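    /// Embeds `phrase` and brute-force scans every stored document embedding,
    /// keeping the `limit` best matches by dot product.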
    pub fn search(
        &mut self,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.background().spawn(async move {
            let database = VectorDatabase::new(database_url.as_ref())?;

            let phrase_embedding = embedding_provider
                .embed_batch(vec![&phrase])
                .await?
                .into_iter()
                .next()
                .unwrap();

            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
            database.for_each_document(0, |id, embedding| {
                let similarity = dot(&embedding.0, &phrase_embedding);
                // Keep `results` sorted by descending similarity, capped at `limit`.
                let ix = match results.binary_search_by(|(_, s)| {
                    similarity.partial_cmp(s).unwrap_or(Ordering::Equal)
                }) {
                    Ok(ix) => ix,
                    Err(ix) => ix,
                };
                results.insert(ix, (id, similarity));
                results.truncate(limit);
            })?;

            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
            let documents = database.get_documents_by_ids(&ids)?;

            anyhow::Ok(
                documents
                    .into_iter()
                    .map(|(file_path, offset, name)| SearchResult {
                        name,
                        offset,
                        file_path,
                    })
                    .collect(),
            )
        })
    }
}

impl Entity for VectorStore {
    type Event = ();
}

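/// Computes the dot product of two equal-length vectors by treating them as a
/// 1×len by len×1 matrix product and delegating to `matrixmultiply::sgemm`.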
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}