mod db;
mod embedding;
mod parsing;
mod search;

#[cfg(test)]
mod vector_store_tests;

use anyhow::{anyhow, Result};
use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
use language::{Language, LanguageRegistry};
use parsing::Document;
use project::{Fs, Project};
use smol::channel;
use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
use tree_sitter::{Parser, QueryCursor};
use util::{http::HttpClient, ResultExt, TryFutureExt};
use workspace::WorkspaceCreated;

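/// Creates the global `VectorStore` model and begins indexing the project of
/// every local workspace as it is created.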
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    let vector_store = cx.add_model(|cx| {
        VectorStore::new(
            fs,
            VECTOR_DB_URL.to_string(),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
        )
    });

    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();
}

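/// A file that has been parsed into embeddable documents, along with the
/// SHA-1 of its contents at the time it was indexed.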
#[derive(Debug)]
pub struct IndexedFile {
    path: PathBuf,
    sha1: FileSha1,
    documents: Vec<Document>,
}

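/// Maintains a vector database of document embeddings for the files of
/// registered projects.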
struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<str>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
}

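/// A single match produced by [`VectorStore::search`].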
pub struct SearchResult {
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}

impl VectorStore {
    fn new(
        fs: Arc<dyn Fs>,
        database_url: String,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: database_url.into(),
            embedding_provider,
            language_registry,
        }
    }

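    /// Parses a file with tree-sitter, extracts one document per outline item
    /// (functions, types, etc.), and embeds all of the items in a single batch.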
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language: Arc<Language>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline_config = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        parser.set_language(grammar.ts_language)?;
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == outline_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == outline_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        let embeddings = embedding_provider.embed_batch(context_spans).await?;
        for (document, embedding) in documents.iter_mut().zip(embeddings) {
            document.embedding = embedding;
        }

        let sha1 = FileSha1::from_str(content);

        Ok(IndexedFile {
            path: file_path,
            sha1,
            documents,
        })
    }

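    /// Waits for the project's worktrees to finish scanning, then indexes any
    /// files whose contents have changed since they were last stored.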
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                    log::info!("worktree scan completed");
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|_, cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            // TODO: remove this after fixing the bug in scan_complete
            cx.background()
                .timer(std::time::Duration::from_secs(3))
                .await;

            let db = VectorDatabase::new(&database_url)?;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let worktree_root_paths = worktrees
                .iter()
                .map(|worktree| worktree.abs_path().clone())
                .collect::<Vec<_>>();

            // Here we query the worktree ids, and yet we don't have them elsewhere.
            // We likely want to clean up these data structures.
            let (db, worktree_hashes, worktree_ids) = cx
                .background()
                .spawn(async move {
                    let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
                    let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
                    for worktree_root_path in worktree_root_paths {
                        let worktree_id =
                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
                        worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
                        hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
                    }
                    anyhow::Ok((db, hashes, worktree_ids))
                })
                .await?;

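            // The indexing pipeline fans out over two channels: one task
            // streams changed files into `paths_tx`, a pool of background
            // workers parses and embeds them, and a single writer task
            // persists the results to the database.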
            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
                            let file_hashes = &worktree_hashes[&worktree_id];
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
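                                    // For now, only Rust files are indexed.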
                                    if language.name().as_ref() != "Rust" {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        log::info!("loaded file: {absolute_path:?}");

                                        let path_buf = file.path.to_path_buf();
                                        let already_stored = file_hashes
                                            .get(&path_buf)
                                            .map_or(false, |existing_hash| {
                                                existing_hash.equals(&content)
                                            });

                                        if !already_stored {
                                            log::info!(
                                                "File Changed (Sending to Parse): {:?}",
                                                &path_buf
                                            );
                                            paths_tx
                                                .try_send((
                                                    worktree_id,
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                        }
                    }
                })
                .detach();

            let db_write_task = cx.background().spawn(
                async move {
                    // Write each indexed file to the database as it arrives.
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    // ALL OF THE BELOW IS FOR TESTING;
                    // it should be removed once we find an appropriate place to evaluate our search.

                    // let queries = vec![
                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
                    //     "compute an outline view of all of the symbols in a buffer",
                    //     "scan a directory on the file system and load all of its children into an in-memory snapshot",
                    // ];
                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;

                    // let t2 = Instant::now();
                    // let documents = db.get_documents().unwrap();
                    // let files = db.get_files().unwrap();
                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());

                    // let t1 = Instant::now();
                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
                    // for (idx, embed) in embeddings.into_iter().enumerate() {
                    //     let t0 = Instant::now();
                    //     println!("\nQuery: {:?}", queries[idx]);
                    //     let results = bfs.top_k_search(&embed, 5).await;
                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
                    //     for (id, distance) in results {
                    //         println!("");
                    //         println!("   distance: {:?}", distance);
                    //         println!("   document: {:?}", documents[&id].name);
                    //         println!("   path:     {:?}", files[&documents[&id].file_id].relative_path);
                    //     }

                    // }

                    anyhow::Ok(())
                }
                .log_err(),
            );

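            // Parse and embed in parallel: one worker per CPU, each reusing
            // its own `Parser` and `QueryCursor` across files.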
            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
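            // Dropping the sender closes the channel, allowing db_write_task to finish.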
            drop(indexed_files_tx);

            db_write_task.await;
            anyhow::Ok(())
        })
    }

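    /// Embeds `phrase` and returns the `limit` documents whose embeddings
    /// have the highest dot-product similarity to it.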
    pub fn search(
        &mut self,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.background().spawn(async move {
            let database = VectorDatabase::new(database_url.as_ref())?;

            let phrase_embedding = embedding_provider
                .embed_batch(vec![&phrase])
                .await?
                .into_iter()
                .next()
                .unwrap();

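            // Maintain `results` in descending order of similarity, keeping
            // only the `limit` best matches seen so far.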
            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
            database.for_each_document(0, |id, embedding| {
                let similarity = dot(&embedding.0, &phrase_embedding);
                let ix = match results.binary_search_by(|(_, s)| {
                    similarity.partial_cmp(s).unwrap_or(Ordering::Equal)
                }) {
                    Ok(ix) => ix,
                    Err(ix) => ix,
                };
                results.insert(ix, (id, similarity));
                results.truncate(limit);
            })?;

375
376 let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
377 let documents = database.get_documents_by_ids(&ids)?;
378
379 anyhow::Ok(
380 documents
381 .into_iter()
382 .map(|(file_path, offset, name)| SearchResult {
383 name,
384 offset,
385 file_path,
386 })
387 .collect(),
388 )
389 })
390 }
391}
392
393impl Entity for VectorStore {
394 type Event = ();
395}
396
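/// Computes the dot product of two equal-length vectors by treating them as a
/// 1×n and an n×1 matrix and delegating to `matrixmultiply::sgemm`.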
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
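    // Safety: both pointers are valid for `len` reads and `result` for one
    // write, as required by `sgemm` for a (1 × len) · (len × 1) product.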
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}