1mod db;
2mod embedding;
3
4use anyhow::{anyhow, Result};
5use db::VectorDatabase;
6use embedding::{EmbeddingProvider, OpenAIEmbeddings};
7use gpui::{AppContext, Entity, ModelContext, ModelHandle};
8use language::LanguageRegistry;
9use project::{Fs, Project};
10use smol::channel;
11use std::{path::PathBuf, sync::Arc, time::Instant};
12use tree_sitter::{Parser, QueryCursor};
13use util::{http::HttpClient, ResultExt};
14use workspace::WorkspaceCreated;
15
16pub fn init(
17 fs: Arc<dyn Fs>,
18 http_client: Arc<dyn HttpClient>,
19 language_registry: Arc<LanguageRegistry>,
20 cx: &mut AppContext,
21) {
22 let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
23
24 cx.subscribe_global::<WorkspaceCreated, _>({
25 let vector_store = vector_store.clone();
26 move |event, cx| {
27 let workspace = &event.0;
28 if let Some(workspace) = workspace.upgrade(cx) {
29 let project = workspace.read(cx).project().clone();
30 if project.read(cx).is_local() {
31 vector_store.update(cx, |store, cx| {
32 store.add_project(project, cx);
33 });
34 }
35 }
36 }
37 })
38 .detach();
39}
40
41#[derive(Debug, sqlx::FromRow)]
42struct Document {
43 offset: usize,
44 name: String,
45 embedding: Vec<f32>,
46}
47
48#[derive(Debug, sqlx::FromRow)]
49pub struct IndexedFile {
50 path: PathBuf,
51 sha1: String,
52 documents: Vec<Document>,
53}
54
/// A single hit describing a matching document and its location.
///
/// Derives `Debug`, `Clone` and `PartialEq` so results can be logged, copied and
/// compared, consistent with the other plain-data structs in this file.
#[derive(Debug, Clone, PartialEq)]
struct SearchResult {
    /// Path of the file containing the matching document.
    path: PathBuf,
    /// Byte offset of the matching document within that file.
    offset: usize,
    /// Name of the matching document.
    name: String,
    /// Distance value associated with the match.
    // NOTE(review): presumably smaller means more similar — confirm against the query code.
    distance: f32,
}
61
62struct VectorStore {
63 fs: Arc<dyn Fs>,
64 http_client: Arc<dyn HttpClient>,
65 language_registry: Arc<LanguageRegistry>,
66}
67
68impl VectorStore {
69 fn new(
70 fs: Arc<dyn Fs>,
71 http_client: Arc<dyn HttpClient>,
72 language_registry: Arc<LanguageRegistry>,
73 ) -> Self {
74 Self {
75 fs,
76 http_client,
77 language_registry,
78 }
79 }
80
81 async fn index_file(
82 cursor: &mut QueryCursor,
83 parser: &mut Parser,
84 embedding_provider: &dyn EmbeddingProvider,
85 fs: &Arc<dyn Fs>,
86 language_registry: &Arc<LanguageRegistry>,
87 file_path: PathBuf,
88 ) -> Result<IndexedFile> {
89 let language = language_registry
90 .language_for_file(&file_path, None)
91 .await?;
92
93 if language.name().as_ref() != "Rust" {
94 Err(anyhow!("unsupported language"))?;
95 }
96
97 let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
98 let outline_config = grammar
99 .outline_config
100 .as_ref()
101 .ok_or_else(|| anyhow!("no outline query"))?;
102
103 let content = fs.load(&file_path).await?;
104 parser.set_language(grammar.ts_language).unwrap();
105 let tree = parser
106 .parse(&content, None)
107 .ok_or_else(|| anyhow!("parsing failed"))?;
108
109 let mut documents = Vec::new();
110 let mut context_spans = Vec::new();
111 for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
112 let mut item_range = None;
113 let mut name_range = None;
114 for capture in mat.captures {
115 if capture.index == outline_config.item_capture_ix {
116 item_range = Some(capture.node.byte_range());
117 } else if capture.index == outline_config.name_capture_ix {
118 name_range = Some(capture.node.byte_range());
119 }
120 }
121
122 if let Some((item_range, name_range)) = item_range.zip(name_range) {
123 if let Some((item, name)) =
124 content.get(item_range.clone()).zip(content.get(name_range))
125 {
126 context_spans.push(item);
127 documents.push(Document {
128 name: name.to_string(),
129 offset: item_range.start,
130 embedding: Vec::new(),
131 });
132 }
133 }
134 }
135
136 let embeddings = embedding_provider.embed_batch(context_spans).await?;
137 for (document, embedding) in documents.iter_mut().zip(embeddings) {
138 document.embedding = embedding;
139 }
140
141 return Ok(IndexedFile {
142 path: file_path,
143 sha1: String::new(),
144 documents,
145 });
146 }
147
148 fn add_project(&mut self, project: ModelHandle<Project>, cx: &mut ModelContext<Self>) {
149 let worktree_scans_complete = project
150 .read(cx)
151 .worktrees(cx)
152 .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete())
153 .collect::<Vec<_>>();
154
155 let fs = self.fs.clone();
156 let language_registry = self.language_registry.clone();
157 let client = self.http_client.clone();
158
159 cx.spawn(|_, cx| async move {
160 futures::future::join_all(worktree_scans_complete).await;
161
162 let worktrees = project.read_with(&cx, |project, cx| {
163 project
164 .worktrees(cx)
165 .map(|worktree| worktree.read(cx).snapshot())
166 .collect::<Vec<_>>()
167 });
168
169 let (paths_tx, paths_rx) = channel::unbounded::<PathBuf>();
170 let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
171 cx.background()
172 .spawn(async move {
173 for worktree in worktrees {
174 for file in worktree.files(false, 0) {
175 paths_tx.try_send(worktree.absolutize(&file.path)).unwrap();
176 }
177 }
178 })
179 .detach();
180
181 cx.background()
182 .spawn(async move {
183 // Initialize Database, creates database and tables if not exists
184 VectorDatabase::initialize_database().await.log_err();
185 while let Ok(indexed_file) = indexed_files_rx.recv().await {
186 VectorDatabase::insert_file(indexed_file).await.log_err();
187 }
188 })
189 .detach();
190
191 let provider = OpenAIEmbeddings { client };
192
193 let t0 = Instant::now();
194
195 cx.background()
196 .scoped(|scope| {
197 for _ in 0..cx.background().num_cpus() {
198 scope.spawn(async {
199 let mut parser = Parser::new();
200 let mut cursor = QueryCursor::new();
201 while let Ok(file_path) = paths_rx.recv().await {
202 if let Some(indexed_file) = Self::index_file(
203 &mut cursor,
204 &mut parser,
205 &provider,
206 &fs,
207 &language_registry,
208 file_path,
209 )
210 .await
211 .log_err()
212 {
213 indexed_files_tx.try_send(indexed_file).unwrap();
214 }
215 }
216 });
217 }
218 })
219 .await;
220
221 let duration = t0.elapsed();
222 log::info!("indexed project in {duration:?}");
223 })
224 .detach();
225 }
226}
227
228impl Entity for VectorStore {
229 type Event = ();
230}