1use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
2use anyhow::{anyhow, Context, Result};
3use project::Fs;
4use rpc::proto::Timestamp;
5use rusqlite::{
6 params,
7 types::{FromSql, FromSqlResult, ValueRef},
8};
9use std::{
10 cmp::Ordering,
11 collections::HashMap,
12 ops::Range,
13 path::{Path, PathBuf},
14 rc::Rc,
15 sync::Arc,
16 time::SystemTime,
17};
18
/// An indexed file's identity and modification time.
///
/// NOTE(review): nothing in this file constructs or reads a `FileRecord`; its fields
/// appear to mirror the columns of the `files` table — confirm against its users.
#[derive(Debug)]
pub struct FileRecord {
    /// Presumably the `files.id` row id — TODO confirm.
    pub id: usize,
    /// Presumably the path relative to the worktree root, as in `files.relative_path`.
    pub relative_path: String,
    /// Modification time, presumably as stored in `files.mtime_seconds` / `mtime_nanos`.
    pub mtime: Timestamp,
}
25
26#[derive(Debug)]
27struct Embedding(pub Vec<f32>);
28
29impl FromSql for Embedding {
30 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
31 let bytes = value.as_blob()?;
32 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
33 if embedding.is_err() {
34 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
35 }
36 return Ok(Embedding(embedding.unwrap()));
37 }
38}
39
/// SQLite-backed store of worktrees, indexed files, and their document embeddings,
/// used for semantic search.
///
/// Wraps a single `rusqlite::Connection`; every method issues its queries on it.
pub struct VectorDatabase {
    /// The open connection to the index database.
    db: rusqlite::Connection,
}
43
44impl VectorDatabase {
45 pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
46 if let Some(db_directory) = path.parent() {
47 fs.create_dir(db_directory).await?;
48 }
49
50 let this = Self {
51 db: rusqlite::Connection::open(path.as_path())?,
52 };
53 this.initialize_database()?;
54 Ok(this)
55 }
56
57 fn get_existing_version(&self) -> Result<i64> {
58 let mut version_query = self
59 .db
60 .prepare("SELECT version from semantic_index_config")?;
61 version_query
62 .query_row([], |row| Ok(row.get::<_, i64>(0)?))
63 .map_err(|err| anyhow!("version query failed: {err}"))
64 }
65
66 fn initialize_database(&self) -> Result<()> {
67 rusqlite::vtab::array::load_module(&self.db)?;
68
69 if self
70 .get_existing_version()
71 .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
72 {
73 return Ok(());
74 }
75
76 self.db
77 .execute(
78 "
79 DROP TABLE IF EXISTS documents;
80 DROP TABLE IF EXISTS files;
81 DROP TABLE IF EXISTS worktrees;
82 DROP TABLE IF EXISTS semantic_index_config;
83 ",
84 [],
85 )
86 .context("failed to drop tables")?;
87
88 // Initialize Vector Databasing Tables
89 self.db.execute(
90 "CREATE TABLE semantic_index_config (
91 version INTEGER NOT NULL
92 )",
93 [],
94 )?;
95
96 self.db.execute(
97 "INSERT INTO semantic_index_config (version) VALUES (?1)",
98 params![SEMANTIC_INDEX_VERSION],
99 )?;
100
101 self.db.execute(
102 "CREATE TABLE worktrees (
103 id INTEGER PRIMARY KEY AUTOINCREMENT,
104 absolute_path VARCHAR NOT NULL
105 );
106 CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
107 ",
108 [],
109 )?;
110
111 self.db.execute(
112 "CREATE TABLE files (
113 id INTEGER PRIMARY KEY AUTOINCREMENT,
114 worktree_id INTEGER NOT NULL,
115 relative_path VARCHAR NOT NULL,
116 mtime_seconds INTEGER NOT NULL,
117 mtime_nanos INTEGER NOT NULL,
118 FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
119 )",
120 [],
121 )?;
122
123 self.db.execute(
124 "CREATE TABLE documents (
125 id INTEGER PRIMARY KEY AUTOINCREMENT,
126 file_id INTEGER NOT NULL,
127 start_byte INTEGER NOT NULL,
128 end_byte INTEGER NOT NULL,
129 name VARCHAR NOT NULL,
130 embedding BLOB NOT NULL,
131 FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
132 )",
133 [],
134 )?;
135
136 Ok(())
137 }
138
139 pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
140 self.db.execute(
141 "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
142 params![worktree_id, delete_path.to_str()],
143 )?;
144 Ok(())
145 }
146
147 pub fn insert_file(
148 &self,
149 worktree_id: i64,
150 path: PathBuf,
151 mtime: SystemTime,
152 documents: Vec<Document>,
153 ) -> Result<()> {
154 // Write to files table, and return generated id.
155 self.db.execute(
156 "
157 DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
158 ",
159 params![worktree_id, path.to_str()],
160 )?;
161 let mtime = Timestamp::from(mtime);
162 self.db.execute(
163 "
164 INSERT INTO files
165 (worktree_id, relative_path, mtime_seconds, mtime_nanos)
166 VALUES
167 (?1, ?2, $3, $4);
168 ",
169 params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
170 )?;
171
172 let file_id = self.db.last_insert_rowid();
173
174 // Currently inserting at approximately 3400 documents a second
175 // I imagine we can speed this up with a bulk insert of some kind.
176 for document in documents {
177 let embedding_blob = bincode::serialize(&document.embedding)?;
178
179 self.db.execute(
180 "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, $5)",
181 params![
182 file_id,
183 document.range.start.to_string(),
184 document.range.end.to_string(),
185 document.name,
186 embedding_blob
187 ],
188 )?;
189 }
190
191 Ok(())
192 }
193
194 pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
195 // Check that the absolute path doesnt exist
196 let mut worktree_query = self
197 .db
198 .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
199
200 let worktree_id = worktree_query
201 .query_row(params![worktree_root_path.to_string_lossy()], |row| {
202 Ok(row.get::<_, i64>(0)?)
203 })
204 .map_err(|err| anyhow!(err));
205
206 if worktree_id.is_ok() {
207 return worktree_id;
208 }
209
210 // If worktree_id is Err, insert new worktree
211 self.db.execute(
212 "
213 INSERT into worktrees (absolute_path) VALUES (?1)
214 ",
215 params![worktree_root_path.to_string_lossy()],
216 )?;
217 Ok(self.db.last_insert_rowid())
218 }
219
220 pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
221 let mut statement = self.db.prepare(
222 "
223 SELECT relative_path, mtime_seconds, mtime_nanos
224 FROM files
225 WHERE worktree_id = ?1
226 ORDER BY relative_path",
227 )?;
228 let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
229 for row in statement.query_map(params![worktree_id], |row| {
230 Ok((
231 row.get::<_, String>(0)?.into(),
232 Timestamp {
233 seconds: row.get(1)?,
234 nanos: row.get(2)?,
235 }
236 .into(),
237 ))
238 })? {
239 let row = row?;
240 result.insert(row.0, row.1);
241 }
242 Ok(result)
243 }
244
245 pub fn top_k_search(
246 &self,
247 worktree_ids: &[i64],
248 query_embedding: &Vec<f32>,
249 limit: usize,
250 ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
251 let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
252 self.for_each_document(&worktree_ids, |id, embedding| {
253 let similarity = dot(&embedding, &query_embedding);
254 let ix = match results
255 .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
256 {
257 Ok(ix) => ix,
258 Err(ix) => ix,
259 };
260 results.insert(ix, (id, similarity));
261 results.truncate(limit);
262 })?;
263
264 let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
265 self.get_documents_by_ids(&ids)
266 }
267
268 fn for_each_document(
269 &self,
270 worktree_ids: &[i64],
271 mut f: impl FnMut(i64, Vec<f32>),
272 ) -> Result<()> {
273 let mut query_statement = self.db.prepare(
274 "
275 SELECT
276 documents.id, documents.embedding
277 FROM
278 documents, files
279 WHERE
280 documents.file_id = files.id AND
281 files.worktree_id IN rarray(?)
282 ",
283 )?;
284
285 query_statement
286 .query_map(params![ids_to_sql(worktree_ids)], |row| {
287 Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
288 })?
289 .filter_map(|row| row.ok())
290 .for_each(|(id, embedding)| f(id, embedding.0));
291 Ok(())
292 }
293
294 fn get_documents_by_ids(
295 &self,
296 ids: &[i64],
297 ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
298 let mut statement = self.db.prepare(
299 "
300 SELECT
301 documents.id,
302 files.worktree_id,
303 files.relative_path,
304 documents.start_byte,
305 documents.end_byte, documents.name
306 FROM
307 documents, files
308 WHERE
309 documents.file_id = files.id AND
310 documents.id in rarray(?)
311 ",
312 )?;
313
314 let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
315 Ok((
316 row.get::<_, i64>(0)?,
317 row.get::<_, i64>(1)?,
318 row.get::<_, String>(2)?.into(),
319 row.get(3)?..row.get(4)?,
320 row.get(5)?,
321 ))
322 })?;
323
324 let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>, String)>::default();
325 for row in result_iter {
326 let (id, worktree_id, path, range, name) = row?;
327 values_by_id.insert(id, (worktree_id, path, range, name));
328 }
329
330 let mut results = Vec::with_capacity(ids.len());
331 for id in ids {
332 let value = values_by_id
333 .remove(id)
334 .ok_or(anyhow!("missing document id {}", id))?;
335 results.push(value);
336 }
337
338 Ok(results)
339 }
340}
341
342fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
343 Rc::new(
344 ids.iter()
345 .copied()
346 .map(|v| rusqlite::types::Value::from(v))
347 .collect::<Vec<_>>(),
348 )
349}
350
/// Returns the dot product of `vec_a` and `vec_b`.
///
/// Computed as a (1 × len) · (len × 1) matrix product with `matrixmultiply::sgemm`.
///
/// # Panics
///
/// Panics if the two slices have different lengths.
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    // Checked up front because the unsafe call below reads `len` elements of `vec_b`.
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    // SAFETY: A is `vec_a` viewed as a 1×len matrix (row stride `len`, column stride 1),
    // so only vec_a[0..len] is read. B is `vec_b` viewed as a len×1 matrix (row stride 1,
    // column stride `len`), so only vec_b[0..len] is read, which is in bounds by the
    // assert above. C is the single f32 `result` (1×1, strides never stepped). With
    // beta = 0.0, C's prior contents are not read (per the matrixmultiply docs).
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}