1mod db;
2mod embedding;
3mod search;
4
5use anyhow::{anyhow, Result};
6use db::VectorDatabase;
7use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
8use gpui::{AppContext, Entity, ModelContext, ModelHandle};
9use language::LanguageRegistry;
10use project::{Fs, Project};
11use smol::channel;
12use std::{path::PathBuf, sync::Arc, time::Instant};
13use tree_sitter::{Parser, QueryCursor};
14use util::{http::HttpClient, ResultExt};
15use workspace::WorkspaceCreated;
16
17pub fn init(
18 fs: Arc<dyn Fs>,
19 http_client: Arc<dyn HttpClient>,
20 language_registry: Arc<LanguageRegistry>,
21 cx: &mut AppContext,
22) {
23 let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
24
25 cx.subscribe_global::<WorkspaceCreated, _>({
26 let vector_store = vector_store.clone();
27 move |event, cx| {
28 let workspace = &event.0;
29 if let Some(workspace) = workspace.upgrade(cx) {
30 let project = workspace.read(cx).project().clone();
31 if project.read(cx).is_local() {
32 vector_store.update(cx, |store, cx| {
33 store.add_project(project, cx);
34 });
35 }
36 }
37 }
38 })
39 .detach();
40}
41
/// A single outline item extracted from a source file, together with its embedding.
#[derive(Debug)]
pub struct Document {
    /// Byte offset at which the item starts within the file's contents.
    pub offset: usize,
    /// Text of the item's name, as captured by the language's outline query.
    pub name: String,
    /// Embedding vector for the item's text. The indexer creates documents with an
    /// empty vector and fills it in once the embedding provider has responded.
    pub embedding: Vec<f32>,
}
48
/// The result of indexing one file: its path, a content hash, and the documents
/// extracted from it.
#[derive(Debug)]
pub struct IndexedFile {
    /// Path of the file that was indexed.
    path: PathBuf,
    /// Presumably a SHA-1 digest of the file contents (going by the name).
    /// NOTE(review): the indexer currently always stores an empty string here.
    sha1: String,
    /// One document per outline item found in the file.
    documents: Vec<Document>,
}
55
/// A single hit from a vector search.
/// NOTE(review): nothing in this file constructs it; presumably used by the `search`
/// module — confirm.
struct SearchResult {
    /// Path of the file containing the matched item.
    path: PathBuf,
    /// Presumably the byte offset of the matched item within its file (mirrors
    /// `Document::offset`).
    offset: usize,
    /// Name of the matched item.
    name: String,
    /// Presumably the distance between the query embedding and the item's embedding;
    /// the metric is not visible here — confirm.
    distance: f32,
}
62
/// Model that indexes local projects and stores the results in the vector database.
struct VectorStore {
    /// Filesystem used to load file contents for indexing.
    fs: Arc<dyn Fs>,
    /// HTTP client, presumably intended for a remote embedding provider.
    /// NOTE(review): indexing currently uses `DummyEmbeddings` and does not use it.
    http_client: Arc<dyn HttpClient>,
    /// Registry used to detect each file's language and obtain its grammar.
    language_registry: Arc<LanguageRegistry>,
}
68
69impl VectorStore {
70 fn new(
71 fs: Arc<dyn Fs>,
72 http_client: Arc<dyn HttpClient>,
73 language_registry: Arc<LanguageRegistry>,
74 ) -> Self {
75 Self {
76 fs,
77 http_client,
78 language_registry,
79 }
80 }
81
82 async fn index_file(
83 cursor: &mut QueryCursor,
84 parser: &mut Parser,
85 embedding_provider: &dyn EmbeddingProvider,
86 fs: &Arc<dyn Fs>,
87 language_registry: &Arc<LanguageRegistry>,
88 file_path: PathBuf,
89 ) -> Result<IndexedFile> {
90 let language = language_registry
91 .language_for_file(&file_path, None)
92 .await?;
93
94 if language.name().as_ref() != "Rust" {
95 Err(anyhow!("unsupported language"))?;
96 }
97
98 let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
99 let outline_config = grammar
100 .outline_config
101 .as_ref()
102 .ok_or_else(|| anyhow!("no outline query"))?;
103
104 let content = fs.load(&file_path).await?;
105 parser.set_language(grammar.ts_language).unwrap();
106 let tree = parser
107 .parse(&content, None)
108 .ok_or_else(|| anyhow!("parsing failed"))?;
109
110 let mut documents = Vec::new();
111 let mut context_spans = Vec::new();
112 for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
113 let mut item_range = None;
114 let mut name_range = None;
115 for capture in mat.captures {
116 if capture.index == outline_config.item_capture_ix {
117 item_range = Some(capture.node.byte_range());
118 } else if capture.index == outline_config.name_capture_ix {
119 name_range = Some(capture.node.byte_range());
120 }
121 }
122
123 if let Some((item_range, name_range)) = item_range.zip(name_range) {
124 if let Some((item, name)) =
125 content.get(item_range.clone()).zip(content.get(name_range))
126 {
127 context_spans.push(item);
128 documents.push(Document {
129 name: name.to_string(),
130 offset: item_range.start,
131 embedding: Vec::new(),
132 });
133 }
134 }
135 }
136
137 let embeddings = embedding_provider.embed_batch(context_spans).await?;
138 for (document, embedding) in documents.iter_mut().zip(embeddings) {
139 document.embedding = embedding;
140 }
141
142 return Ok(IndexedFile {
143 path: file_path,
144 sha1: String::new(),
145 documents,
146 });
147 }
148
149 fn add_project(&mut self, project: ModelHandle<Project>, cx: &mut ModelContext<Self>) {
150 let worktree_scans_complete = project
151 .read(cx)
152 .worktrees(cx)
153 .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete())
154 .collect::<Vec<_>>();
155
156 let fs = self.fs.clone();
157 let language_registry = self.language_registry.clone();
158 let client = self.http_client.clone();
159
160 cx.spawn(|_, cx| async move {
161 futures::future::join_all(worktree_scans_complete).await;
162
163 let worktrees = project.read_with(&cx, |project, cx| {
164 project
165 .worktrees(cx)
166 .map(|worktree| worktree.read(cx).snapshot())
167 .collect::<Vec<_>>()
168 });
169
170 let (paths_tx, paths_rx) = channel::unbounded::<PathBuf>();
171 let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
172 cx.background()
173 .spawn(async move {
174 for worktree in worktrees {
175 for file in worktree.files(false, 0) {
176 paths_tx.try_send(worktree.absolutize(&file.path)).unwrap();
177 }
178 }
179 })
180 .detach();
181
182 cx.background()
183 .spawn(async move {
184 // Initialize Database, creates database and tables if not exists
185 VectorDatabase::initialize_database().await.log_err();
186 while let Ok(indexed_file) = indexed_files_rx.recv().await {
187 VectorDatabase::insert_file(indexed_file).await.log_err();
188 }
189
190 anyhow::Ok(())
191 })
192 .detach();
193
194 let provider = DummyEmbeddings {};
195
196 cx.background()
197 .scoped(|scope| {
198 for _ in 0..cx.background().num_cpus() {
199 scope.spawn(async {
200 let mut parser = Parser::new();
201 let mut cursor = QueryCursor::new();
202 while let Ok(file_path) = paths_rx.recv().await {
203 if let Some(indexed_file) = Self::index_file(
204 &mut cursor,
205 &mut parser,
206 &provider,
207 &fs,
208 &language_registry,
209 file_path,
210 )
211 .await
212 .log_err()
213 {
214 indexed_files_tx.try_send(indexed_file).unwrap();
215 }
216 }
217 });
218 }
219 })
220 .await;
221 })
222 .detach();
223 }
224}
225
/// Registers `VectorStore` as a gpui model entity.
impl Entity for VectorStore {
    /// The store's event type is `()`; it defines no event payloads.
    type Event = ();
}