mod db;
mod embedding;
mod parsing;
mod search;

#[cfg(test)]
mod vector_store_tests;

use anyhow::{anyhow, Result};
use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
use language::{Language, LanguageRegistry};
use parsing::Document;
use project::{Fs, Project};
use smol::channel;
use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
use tree_sitter::{Parser, QueryCursor};
use util::{http::HttpClient, ResultExt, TryFutureExt};
use workspace::{Workspace, WorkspaceCreated};

actions!(semantic_search, [TestSearch]);

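/// Sets up semantic search: creates the global `VectorStore` model, indexes
/// the project of every newly created local workspace, and registers the
/// `TestSearch` action for ad-hoc query testing.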
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    let vector_store = cx.add_model(|cx| {
        VectorStore::new(
            fs,
            VECTOR_DB_URL.to_string(),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
        )
    });

    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();

    cx.add_action({
        let vector_store = vector_store.clone();
        move |workspace: &mut Workspace, _: &TestSearch, cx: &mut ViewContext<Workspace>| {
            let t0 = std::time::Instant::now();
            let task = vector_store.update(cx, |store, cx| {
                store.search("compute embeddings for all of the symbols in the codebase and write them to a database".to_string(), 10, cx)
            });

            cx.spawn(|this, cx| async move {
                let results = task.await?;
                let duration = t0.elapsed();

                println!("search took {:?}", duration);
                println!("results {:?}", results);

                anyhow::Ok(())
            }).detach()
        }
    });
}

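/// A parsed file ready to be persisted: its path, a SHA-1 of its content
/// (used to skip unchanged files on re-index), and the embedded documents
/// extracted from its outline.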
#[derive(Debug)]
pub struct IndexedFile {
    path: PathBuf,
    sha1: FileSha1,
    documents: Vec<Document>,
}

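/// Coordinates scanning project worktrees, parsing files into documents,
/// embedding them, and persisting the results to the vector database.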
struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<str>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
}

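/// A single semantic search hit: the matched document's name and byte offset
/// within the file at `file_path`.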
#[derive(Debug)]
pub struct SearchResult {
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}

impl VectorStore {
    fn new(
        fs: Arc<dyn Fs>,
        database_url: String,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: database_url.into(),
            embedding_provider,
            language_registry,
        }
    }

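    /// Parses `content` with the language's outline query, extracts one
    /// `Document` per captured outline item, and embeds all items in a
    /// single batch.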
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language: Arc<Language>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline_config = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        parser.set_language(grammar.ts_language)?;
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == outline_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == outline_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        // Embed every extracted item in one batch, then attach the
        // embeddings to their documents in order.
        let embeddings = embedding_provider.embed_batch(context_spans).await?;
        for (document, embedding) in documents.iter_mut().zip(embeddings) {
            document.embedding = embedding;
        }

        let sha1 = FileSha1::from_str(content);

        Ok(IndexedFile {
            path: file_path,
            sha1,
            documents,
        })
    }

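    /// Indexes every worktree of a local `project`: waits for worktree scans
    /// to finish, skips files whose stored SHA-1 still matches, and fans
    /// parsing out across background threads while a single task writes the
    /// results to the database.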
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                    log::info!("worktree scan completed");
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|_, cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            // TODO: remove this after fixing the bug in scan_complete
            cx.background()
                .timer(std::time::Duration::from_secs(3))
                .await;

            let db = VectorDatabase::new(&database_url)?;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let worktree_root_paths = worktrees
                .iter()
                .map(|worktree| worktree.abs_path().clone())
                .collect::<Vec<_>>();

            // We query the worktree ids here but don't keep them anywhere
            // else; these data structures likely want to be cleaned up.
            let (db, worktree_hashes, worktree_ids) = cx
                .background()
                .spawn(async move {
                    let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
                    let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
                    for worktree_root_path in worktree_root_paths {
                        let worktree_id =
                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
                        worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
                        hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
                    }
                    anyhow::Ok((db, hashes, worktree_ids))
                })
                .await?;

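            // Fan parsing out across background threads: file contents flow
            // through `paths_tx` to the parsing workers, and parsed results
            // flow back through `indexed_files_tx` to a single writer task.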
            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
                            let file_hashes = &worktree_hashes[&worktree_id];
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    if language.name().as_ref() != "Rust" {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        log::info!("loaded file: {absolute_path:?}");

                                        let path_buf = file.path.to_path_buf();
                                        let already_stored = file_hashes
                                            .get(&path_buf)
                                            .map_or(false, |existing_hash| {
                                                existing_hash.equals(&content)
                                            });

                                        if !already_stored {
                                            log::info!(
                                                "File Changed (Sending to Parse): {:?}",
                                                &path_buf
                                            );
                                            paths_tx
                                                .try_send((
                                                    worktree_id,
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                        }
                    }
                })
                .detach();

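            // A single writer task drains parsed files from the channel and
            // persists them, so only one task touches the database handle.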
            let db_write_task = cx.background().spawn(
                async move {
                    // Write each parsed file's documents to the database as
                    // they arrive from the parsing workers.
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    // ALL OF THE BELOW IS FOR TESTING.
                    // It should be removed once we find an appropriate place
                    // to evaluate our search.

                    // let queries = vec![
                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
                    //     "compute an outline view of all of the symbols in a buffer",
                    //     "scan a directory on the file system and load all of its children into an in-memory snapshot",
                    // ];
                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;

                    // let t2 = Instant::now();
                    // let documents = db.get_documents().unwrap();
                    // let files = db.get_files().unwrap();
                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());

                    // let t1 = Instant::now();
                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
                    // for (idx, embed) in embeddings.into_iter().enumerate() {
                    //     let t0 = Instant::now();
                    //     println!("\nQuery: {:?}", queries[idx]);
                    //     let results = bfs.top_k_search(&embed, 5).await;
                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
                    //     for (id, distance) in results {
                    //         println!("");
                    //         println!("   distance: {:?}", distance);
                    //         println!("   document: {:?}", documents[&id].name);
                    //         println!("   path: {:?}", files[&documents[&id].file_id].relative_path);
                    //     }
                    // }

                    anyhow::Ok(())
                }
                .log_err(),
            );

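            // Spawn one parsing worker per CPU. Each worker owns its own
            // `Parser` and `QueryCursor` and reuses them across files.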
            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            // Close the channel so the database writer's `recv` loop ends.
            drop(indexed_files_tx);

            db_write_task.await;
            anyhow::Ok(())
        })
    }

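    /// Embeds `phrase`, then brute-force scans every stored document and
    /// returns the `limit` best matches by dot-product similarity.
    ///
    /// A minimal usage sketch, mirroring the `TestSearch` action above (the
    /// query string is purely illustrative):
    ///
    /// ```ignore
    /// let task = vector_store.update(cx, |store, cx| {
    ///     store.search("find the vector database".to_string(), 10, cx)
    /// });
    /// ```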
    pub fn search(
        &mut self,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.background().spawn(async move {
            let database = VectorDatabase::new(database_url.as_ref())?;

            let phrase_embedding = embedding_provider
                .embed_batch(vec![&phrase])
                .await?
                .into_iter()
                .next()
                .unwrap();

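            // Maintain the top `limit` hits in descending similarity order:
            // binary-search for each document's insertion point, insert, and
            // truncate back down to `limit`.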
            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
            database.for_each_document(0, |id, embedding| {
                let similarity = dot(&embedding.0, &phrase_embedding);
                let ix = match results.binary_search_by(|(_, s)| {
                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                }) {
                    Ok(ix) => ix,
                    Err(ix) => ix,
                };
                results.insert(ix, (id, similarity));
                results.truncate(limit);
            })?;

            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
            let documents = database.get_documents_by_ids(&ids)?;

            anyhow::Ok(
                documents
                    .into_iter()
                    .map(|(file_path, offset, name)| SearchResult {
                        name,
                        offset,
                        file_path,
                    })
                    .collect(),
            )
        })
    }
}

impl Entity for VectorStore {
    type Event = ();
}

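/// Dot product of two equal-length `f32` slices, computed as a 1×len by
/// len×1 matrix multiply via `matrixmultiply::sgemm`.
///
/// A small sanity check (values chosen purely for illustration):
///
/// ```ignore
/// assert_eq!(dot(&[1.0, 2.0], &[3.0, 4.0]), 11.0);
/// ```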
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        // Treat `vec_a` as a 1×len row vector and `vec_b` as a len×1 column
        // vector; their product is a 1×1 matrix written into `result`.
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}