1mod db;
2mod embedding;
3mod parsing;
4mod search;
5
6#[cfg(test)]
7mod vector_store_tests;
8
9use anyhow::{anyhow, Result};
10use db::{VectorDatabase, VECTOR_DB_URL};
11use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
12use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
13use language::LanguageRegistry;
14use parsing::Document;
15use project::{Fs, Project};
16use search::{BruteForceSearch, VectorSearch};
17use smol::channel;
18use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant};
19use tree_sitter::{Parser, QueryCursor};
20use util::{http::HttpClient, ResultExt, TryFutureExt};
21use workspace::WorkspaceCreated;
22
23pub fn init(
24 fs: Arc<dyn Fs>,
25 http_client: Arc<dyn HttpClient>,
26 language_registry: Arc<LanguageRegistry>,
27 cx: &mut AppContext,
28) {
29 let vector_store = cx.add_model(|cx| {
30 VectorStore::new(
31 fs,
32 VECTOR_DB_URL.to_string(),
33 Arc::new(OpenAIEmbeddings {
34 client: http_client,
35 }),
36 language_registry,
37 )
38 });
39
40 cx.subscribe_global::<WorkspaceCreated, _>({
41 let vector_store = vector_store.clone();
42 move |event, cx| {
43 let workspace = &event.0;
44 if let Some(workspace) = workspace.upgrade(cx) {
45 let project = workspace.read(cx).project().clone();
46 if project.read(cx).is_local() {
47 vector_store.update(cx, |store, cx| {
48 store.add_project(project, cx);
49 });
50 }
51 }
52 }
53 })
54 .detach();
55}
56
/// A single source file after parsing and embedding: its path, a content
/// hash, and one embedded [`Document`] per extracted outline item.
///
/// NOTE(review): `sha1` is currently always set to an empty string by
/// `VectorStore::index_file` — no hash is actually computed yet.
#[derive(Debug)]
pub struct IndexedFile {
    path: PathBuf,
    sha1: String,
    documents: Vec<Document>,
}
63
64// struct SearchResult {
65// path: PathBuf,
66// offset: usize,
67// name: String,
68// distance: f32,
69// }
70struct VectorStore {
71 fs: Arc<dyn Fs>,
72 database_url: Arc<str>,
73 embedding_provider: Arc<dyn EmbeddingProvider>,
74 language_registry: Arc<LanguageRegistry>,
75}
76
/// A single semantic-search hit: the matched document's name, its offset
/// within the file (a byte offset of the outline item, per `index_file`),
/// and the path of the file it came from.
pub struct SearchResult {
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}
82
impl VectorStore {
    /// Creates a new store. `database_url` is the location handed to
    /// [`VectorDatabase::new`] each time a connection is opened.
    fn new(
        fs: Arc<dyn Fs>,
        database_url: String,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: database_url.into(),
            embedding_provider,
            language_registry,
        }
    }

    /// Parses `content` with tree-sitter, extracts every outline item via
    /// the grammar's outline query, embeds each item's source text with
    /// `embedding_provider`, and returns the resulting [`IndexedFile`].
    ///
    /// Only Rust files are indexed for now; any other language yields an
    /// error. The returned `sha1` is left empty.
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language_registry: &Arc<LanguageRegistry>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        // NOTE(review): leftover debugging output — remove before shipping.
        dbg!(&file_path, &content);

        let language = language_registry
            .language_for_file(&file_path, None)
            .await?;

        // Bail out early for unsupported languages.
        if language.name().as_ref() != "Rust" {
            Err(anyhow!("unsupported language"))?;
        }

        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline_config = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        parser.set_language(grammar.ts_language).unwrap();
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        // Collect one Document (plus its source text span, kept in
        // `context_spans` for batch embedding) per outline match that
        // captured both an item range and a name range.
        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == outline_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == outline_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                // `get` (rather than slicing) skips ranges that don't fall
                // on char boundaries instead of panicking.
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        // Embed all item texts in one batch, then patch the embeddings back
        // into the documents (zip preserves the original order).
        let embeddings = embedding_provider.embed_batch(context_spans).await?;
        for (document, embedding) in documents.iter_mut().zip(embeddings) {
            document.embedding = embedding;
        }

        return Ok(IndexedFile {
            path: file_path,
            sha1: String::new(), // TODO: compute a real content hash.
            documents,
        });
    }

    /// Kicks off indexing of every worktree in `project`: loads all files,
    /// parses/embeds them on a pool of background workers, and persists the
    /// results to the vector database.
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        // Wait for every worktree's initial scan so the snapshots taken
        // below are complete. `as_local().unwrap()` assumes all worktrees
        // are local — callers gate on `project.is_local()`.
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete())
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        // NOTE(review): cloned but never used below — the workers use a
        // hard-coded DummyEmbeddings instead; see `provider` further down.
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|_, cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let db = VectorDatabase::new(&database_url)?;
            let worktree_root_paths = worktrees
                .iter()
                .map(|worktree| worktree.abs_path().clone())
                .collect::<Vec<_>>();
            // Register each worktree with the database and fetch the hashes
            // of files already indexed.
            // NOTE(review): `file_hashes` is never consulted below, so every
            // file is re-indexed whether or not it changed.
            let (db, file_hashes) = cx
                .background()
                .spawn(async move {
                    let mut hashes = Vec::new();
                    for worktree_root_path in worktree_root_paths {
                        let worktree_id =
                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
                        hashes.push((worktree_id, db.get_file_hashes(worktree_id)?));
                    }
                    anyhow::Ok((db, hashes))
                })
                .await?;

            // Pipeline: a detached producer loads file contents into
            // `paths_tx`, a pool of workers parses/embeds into
            // `indexed_files_tx`, and a single writer task persists the
            // results. Channels are unbounded, so try_send never fails on
            // capacity (only when the receiver is dropped).
            let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);
                                dbg!(&absolute_path);
                                if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                    dbg!(&content);
                                    // NOTE(review): the worktree id is
                                    // hard-coded to 0 rather than the id
                                    // returned by find_or_create_worktree.
                                    paths_tx.try_send((0, absolute_path, content)).unwrap();
                                }
                            }
                        }
                    }
                })
                .detach();

            let db_write_task = cx.background().spawn(
                async move {
                    // Drain the worker channel and persist each indexed
                    // file; the loop ends once all senders are dropped.
                    while let Ok(indexed_file) = indexed_files_rx.recv().await {
                        db.insert_file(indexed_file).log_err();
                    }

                    // ALL OF THE BELOW IS FOR TESTING,
                    // and should be removed once we find an appropriate place to evaluate our search.

                    // let queries = vec![
                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
                    //     "compute an outline view of all of the symbols in a buffer",
                    //     "scan a directory on the file system and load all of its children into an in-memory snapshot",
                    // ];
                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;

                    // let t2 = Instant::now();
                    // let documents = db.get_documents().unwrap();
                    // let files = db.get_files().unwrap();
                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());

                    // let t1 = Instant::now();
                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
                    // for (idx, embed) in embeddings.into_iter().enumerate() {
                    //     let t0 = Instant::now();
                    //     println!("\nQuery: {:?}", queries[idx]);
                    //     let results = bfs.top_k_search(&embed, 5).await;
                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
                    //     for (id, distance) in results {
                    //         println!("");
                    //         println!("   distance: {:?}", distance);
                    //         println!("   document: {:?}", documents[&id].name);
                    //         println!("   path:     {:?}", files[&documents[&id].file_id].relative_path);
                    //     }

                    // }

                    anyhow::Ok(())
                }
                .log_err(),
            );

            // NOTE(review): the injected `embedding_provider` is ignored in
            // favor of this stub — presumably temporary while testing.
            let provider = DummyEmbeddings {};
            // let provider = OpenAIEmbeddings { client };

            // Worker pool: one parser + query cursor per CPU, each pulling
            // paths until the producer drops `paths_tx`.
            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    &provider,
                                    &language_registry,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx.try_send(indexed_file).unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            // Close the results channel so the writer's recv loop ends.
            drop(indexed_files_tx);

            db_write_task.await;
            anyhow::Ok(())
        })
    }

    /// Searches the indexed documents for `phrase`, returning at most
    /// `limit` results.
    ///
    /// NOTE(review): work in progress — the phrase is never embedded (the
    /// call is commented out), similarity is computed between each stored
    /// embedding and itself, and the final result list is always empty.
    pub fn search(
        &mut self,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let database = VectorDatabase::new(database_url.as_ref())?;

            // let embedding = embedding_provider.embed_batch(vec![&phrase]).await?;
            //
            // Keep at most `limit` (id, similarity) pairs, sorted ascending
            // by similarity.
            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);

            database.for_each_document(0, |id, embedding| {
                dbg!(id, &embedding);

                // NOTE(review): this dots each embedding with itself — the
                // query embedding (commented out above) is presumably what
                // was intended.
                let similarity = dot(&embedding.0, &embedding.0);
                let ix = match results.binary_search_by(|(_, s)| {
                    s.partial_cmp(&similarity).unwrap_or(Ordering::Equal)
                }) {
                    Ok(ix) => ix,
                    Err(ix) => ix,
                };

                // NOTE(review): with an ascending sort, truncate keeps the
                // *lowest* similarities — for top-k the ordering likely
                // needs to be reversed. Verify against intended semantics.
                results.insert(ix, (id, similarity));
                results.truncate(limit);
            })?;

            dbg!(&results);

            // NOTE(review): `ids` is computed but unused; the lookup that
            // would turn ids into SearchResults is still commented out.
            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
            // let documents = database.get_documents_by_ids(ids)?;

            // let search_provider = cx
            //     .background()
            //     .spawn(async move { BruteForceSearch::load(&database) })
            //     .await?;

            // let results = search_provider.top_k_search(&embedding, limit))

            anyhow::Ok(vec![])
        })
    }
}
356
// VectorStore participates in gpui's model system; it emits no events.
impl Entity for VectorStore {
    type Event = ();
}
360
/// Dot product of two `f32` slices.
///
/// Replaces the previous `unsafe matrixmultiply::sgemm` call (a 1×n · n×1
/// matrix multiply) with the equivalent safe iterator form, which needs no
/// unsafe block and no raw-pointer stride bookkeeping. Summation order may
/// differ from sgemm's in the last floating-point bits.
///
/// # Panics
/// Panics if the slices have different lengths (same contract as before).
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    assert_eq!(vec_a.len(), vec_b.len());
    // zip avoids per-element bounds checks and lets LLVM vectorize.
    vec_a.iter().zip(vec_b.iter()).map(|(a, b)| a * b).sum()
}