From 86ec0b1d9f92dd22cbabc6e45d65e14ceff99229 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 25 Sep 2023 13:44:19 -0400 Subject: [PATCH 1/5] implement new search strategy --- Cargo.lock | 121 ++++++++++++++++++- crates/semantic_index/Cargo.toml | 2 + crates/semantic_index/src/db.rs | 123 +++++++++++++------- crates/semantic_index/src/semantic_index.rs | 29 +++-- script/evaluate_semantic_index | 2 +- 5 files changed, 227 insertions(+), 50 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 878604f3609fab58f724f62e16312a150522fc35..96cac48f7a610acddb8e7524ed2914ef130278a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -570,7 +570,7 @@ dependencies = [ "libc", "pin-project", "redox_syscall 0.2.16", - "xattr", + "xattr 0.2.3", ] [[package]] @@ -903,6 +903,15 @@ dependencies = [ "wyz", ] +[[package]] +name = "blas-src" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb48fbaa7a0cb9d6d96c46bac6cedb16f13a10aebcef1c4e73515aaae8c9909d" +dependencies = [ + "openblas-src", +] + [[package]] name = "block" version = "0.1.6" @@ -1181,6 +1190,15 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2698f953def977c68f935bb0dfa959375ad4638570e969e2f1e9f433cbf1af6" +[[package]] +name = "cblas-sys" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65" +dependencies = [ + "libc", +] + [[package]] name = "cc" version = "1.0.83" @@ -4580,6 +4598,21 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "cblas-sys", + "libc", + "matrixmultiply", + "num-complex 0.4.4", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "ndk" version = "0.7.0" @@ -4706,7 +4739,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36" dependencies = [ "num-bigint 0.2.6", - "num-complex", + "num-complex 0.2.4", "num-integer", "num-iter", "num-rational 0.2.4", @@ -4762,6 +4795,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "num-traits", +] + [[package]] name = "num-derive" version = "0.3.3" @@ -4948,6 +4990,32 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" +[[package]] +name = "openblas-build" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eba42c395477605f400a8d79ee0b756cfb82abe3eb5618e35fa70d3a36010a7f" +dependencies = [ + "anyhow", + "flate2", + "native-tls", + "tar", + "thiserror", + "ureq", + "walkdir", +] + +[[package]] +name = "openblas-src" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38e5d8af0b707ac2fe1574daa88b4157da73b0de3dc7c39fe3e2c0bb64070501" +dependencies = [ + "dirs 3.0.2", + "openblas-build", + "vcpkg", +] + [[package]] name = "openssl" version = "0.10.57" @@ -6422,6 +6490,18 @@ dependencies = [ "webpki 0.22.1", ] +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.3" @@ -6740,6 +6820,7 @@ dependencies = [ "ai", "anyhow", "async-trait", + "blas-src", "client", "collections", "ctor", @@ -6751,6 +6832,7 @@ dependencies = [ "language", "lazy_static", "log", + "ndarray", "node_runtime", "ordered-float", "parking_lot 0.11.2", @@ -7629,6 +7711,17 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tar" +version = "0.4.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb" +dependencies = [ + "filetime", + "libc", + "xattr 1.0.1", +] + [[package]] name = "target-lexicon" version = "0.12.11" @@ -8703,6 +8796,21 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +[[package]] +name = "ureq" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9" +dependencies = [ + "base64 0.21.4", + "flate2", + "log", + "native-tls", + "once_cell", + "rustls-native-certs", + "url", +] + [[package]] name = "url" version = "2.4.1" @@ -9754,6 +9862,15 @@ dependencies = [ "libc", ] +[[package]] +name = "xattr" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4686009f71ff3e5c4dbcf1a282d0a44db3f021ba69350cd42086b3e5f1c6985" +dependencies = [ + "libc", +] + [[package]] name = "xmlparser" version = "0.13.5" diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index e38ae1f06db35ed96f7ac367c0529cd27e8cbd4f..7e68399b10c6747df075302856ab3c5e4d9c8380 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -39,6 +39,8 @@ rand.workspace = true schemars.workspace = true globset.workspace = true sha1 = "0.10.5" +ndarray = { version = "0.15.0", features = ["blas"] } +blas-src = { version = "0.8", features = ["openblas"] } [dev-dependencies] collections = { path = "../collections", features = ["test-support"] } diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 8280dc7d65c44bdbfd656c40db1e4b9e91cb9b7e..3558bf6b0a0335bf1d4bca0114d08edb9d835d7d 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -1,3 +1,5 @@ +extern crate blas_src; + use crate::{ parsing::{Span, SpanDigest}, SEMANTIC_INDEX_VERSION, @@ -7,6 +9,7 @@ use anyhow::{anyhow, Context, Result}; use collections::HashMap; use futures::channel::oneshot; use gpui::executor; +use ndarray::{Array1, Array2}; use ordered_float::OrderedFloat; use project::{search::PathMatcher, Fs}; use rpc::proto::Timestamp; @@ -19,10 +22,16 @@ use std::{ path::{Path, PathBuf}, rc::Rc, sync::Arc, - time::SystemTime, + time::{Instant, SystemTime}, }; use util::TryFutureExt; +pub fn argsort(data: &[T]) -> Vec { + let mut indices = (0..data.len()).collect::>(); + indices.sort_by_key(|&i| &data[i]); + indices +} + #[derive(Debug)] pub struct FileRecord { pub id: usize, @@ -409,23 +418,82 @@ impl VectorDatabase { limit: usize, file_ids: &[i64], ) -> impl Future)>>> { - let query_embedding = query_embedding.clone(); let file_ids = file_ids.to_vec(); + let query = query_embedding.clone().0; + let query = Array1::from_vec(query); self.transact(move |db| { - let mut results = Vec::<(i64, OrderedFloat)>::with_capacity(limit + 1); - Self::for_each_span(db, &file_ids, |id, embedding| { - let similarity = embedding.similarity(&query_embedding); - let ix = match results - .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s)) - { - Ok(ix) => ix, - Err(ix) => ix, - }; - results.insert(ix, (id, similarity)); - results.truncate(limit); - })?; + let mut query_statement = db.prepare( + " + SELECT + id, embedding + FROM + spans + WHERE + file_id IN rarray(?) + ", + )?; + + let deserialized_rows = query_statement + .query_map(params![ids_to_sql(&file_ids)], |row| { + Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?)) + })? + .filter_map(|row| row.ok()) + .collect::>(); + + let batch_n = 250; + let mut batches = Vec::new(); + let mut batch_ids = Vec::new(); + let mut batch_embeddings: Vec = Vec::new(); + deserialized_rows.iter().for_each(|(id, embedding)| { + batch_ids.push(id); + batch_embeddings.extend(&embedding.0); + if batch_ids.len() == batch_n { + let array = + Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone()); + match array { + Ok(array) => { + batches.push((batch_ids.clone(), array)); + } + Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err), + } + + batch_ids = Vec::new(); + batch_embeddings = Vec::new(); + } + }); + + if batch_ids.len() > 0 { + let array = + Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone()); + match array { + Ok(array) => { + batches.push((batch_ids.clone(), array)); + } + Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err), + } + } + + let mut ids: Vec = Vec::new(); + let mut results = Vec::new(); + for (batch_ids, array) in batches { + let scores = array + .dot(&query.t()) + .to_vec() + .iter() + .map(|score| OrderedFloat(*score)) + .collect::>>(); + results.extend(scores); + ids.extend(batch_ids); + } - anyhow::Ok(results) + let sorted_idx = argsort(&results); + let mut sorted_results = Vec::new(); + let last_idx = limit.min(sorted_idx.len()); + for idx in &sorted_idx[0..last_idx] { + sorted_results.push((ids[*idx] as i64, results[*idx])) + } + + Ok(sorted_results) }) } @@ -468,31 +536,6 @@ impl VectorDatabase { }) } - fn for_each_span( - db: &rusqlite::Connection, - file_ids: &[i64], - mut f: impl FnMut(i64, Embedding), - ) -> Result<()> { - let mut query_statement = db.prepare( - " - SELECT - id, embedding - FROM - spans - WHERE - file_id IN rarray(?) - ", - )?; - - query_statement - .query_map(params![ids_to_sql(&file_ids)], |row| { - Ok((row.get(0)?, row.get::<_, Embedding>(1)?)) - })? - .filter_map(|row| row.ok()) - .for_each(|(id, embedding)| f(id, embedding)); - Ok(()) - } - pub fn spans_for_ids( &self, ids: &[i64], diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index fd41cb150024fd0ed5c2882f5835d76f8c152f41..59cf596e7f06eb136e2dbf5a2d1109a557883013 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -705,11 +705,13 @@ impl SemanticIndex { cx.spawn(|this, mut cx| async move { index.await?; + let t0 = Instant::now(); let query = embedding_provider .embed_batch(vec![query]) .await? .pop() .ok_or_else(|| anyhow!("could not embed query"))?; + log::trace!("Embedding Search Query: {:?}", t0.elapsed().as_millis()); let search_start = Instant::now(); let modified_buffer_results = this.update(&mut cx, |this, cx| { @@ -787,10 +789,15 @@ impl SemanticIndex { let batch_n = cx.background().num_cpus(); let ids_len = file_ids.clone().len(); - let batch_size = if ids_len <= batch_n { - ids_len - } else { - ids_len / batch_n + let minimum_batch_size = 50; + + let batch_size = { + let size = ids_len / batch_n; + if size < minimum_batch_size { + minimum_batch_size + } else { + size + } }; let mut batch_results = Vec::new(); @@ -813,17 +820,26 @@ impl SemanticIndex { let batch_results = futures::future::join_all(batch_results).await; let mut results = Vec::new(); + let mut min_similarity = None; for batch_result in batch_results { if batch_result.is_ok() { for (id, similarity) in batch_result.unwrap() { + if min_similarity.map_or_else(|| false, |min_sim| min_sim > similarity) { + continue; + } + let ix = match results .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s)) { Ok(ix) => ix, Err(ix) => ix, }; - results.insert(ix, (id, similarity)); - results.truncate(limit); + + if ix <= limit { + min_similarity = Some(similarity); + results.insert(ix, (id, similarity)); + results.truncate(limit); + } } } } @@ -856,7 +872,6 @@ impl SemanticIndex { })?; let buffers = futures::future::join_all(tasks).await; - Ok(buffers .into_iter() .zip(ranges) diff --git a/script/evaluate_semantic_index b/script/evaluate_semantic_index index e9a96a02b40d46c09dbf588361bfc98990321377..8dcb53c399dad42c7ceedb5d6181f78482843377 100755 --- a/script/evaluate_semantic_index +++ b/script/evaluate_semantic_index @@ -1,3 +1,3 @@ #!/bin/bash -cargo run -p semantic_index --example eval +RUST_LOG=semantic_index=trace cargo run -p semantic_index --example eval --release From ea278b5b12bec131bc4b587b71a45e4b74c11308 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 26 Sep 2023 09:53:49 -0400 Subject: [PATCH 2/5] ensure desc sort and cleanup unused imports --- crates/semantic_index/src/db.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 3558bf6b0a0335bf1d4bca0114d08edb9d835d7d..caa70a4cfaaa76cbfc6138f864a12f58acacea02 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -16,19 +16,19 @@ use rpc::proto::Timestamp; use rusqlite::params; use rusqlite::types::Value; use std::{ - cmp::Reverse, future::Future, ops::Range, path::{Path, PathBuf}, rc::Rc, sync::Arc, - time::{Instant, SystemTime}, + time::SystemTime, }; use util::TryFutureExt; pub fn argsort(data: &[T]) -> Vec { let mut indices = (0..data.len()).collect::>(); indices.sort_by_key(|&i| &data[i]); + indices.reverse(); indices } From e75f56a0f291c8a773fd5aef701033dcaeac7ddb Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 26 Sep 2023 12:39:22 -0400 Subject: [PATCH 3/5] move to system blas --- Cargo.lock | 1 + crates/semantic_index/Cargo.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 68c7a159d252df24f2ff7e66322d37080be5cb60..edb2272c1cad7cb7c07cd9a170804c6b9956f163 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6834,6 +6834,7 @@ dependencies = [ "log", "ndarray", "node_runtime", + "openblas-src", "ordered-float", "parking_lot 0.11.2", "picker", diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index 7e68399b10c6747df075302856ab3c5e4d9c8380..cd08aa0d63521b16066b5e4d8bfda06f60c88fef 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -41,6 +41,7 @@ globset.workspace = true sha1 = "0.10.5" ndarray = { version = "0.15.0", features = ["blas"] } blas-src = { version = "0.8", features = ["openblas"] } +openblas-src = { version = "0.10", features = ["cblas", "system"] } [dev-dependencies] collections = { path = "../collections", features = ["test-support"] } From abefa2738b8e8258e278ac0981fb54ab8441cb3e Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 27 Sep 2023 09:43:23 -0400 Subject: [PATCH 4/5] removed blas and increase batch size for vector search --- Cargo.lock | 97 +-------------------- crates/semantic_index/Cargo.toml | 4 +- crates/semantic_index/src/db.rs | 15 ++-- crates/semantic_index/src/semantic_index.rs | 2 +- 4 files changed, 9 insertions(+), 109 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index edb2272c1cad7cb7c07cd9a170804c6b9956f163..3342bf39b8f6f6b621ed97d5f06a9f1be43c2852 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -570,7 +570,7 @@ dependencies = [ "libc", "pin-project", "redox_syscall 0.2.16", - "xattr 0.2.3", + "xattr", ] [[package]] @@ -903,15 +903,6 @@ dependencies = [ "wyz", ] -[[package]] -name = "blas-src" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb48fbaa7a0cb9d6d96c46bac6cedb16f13a10aebcef1c4e73515aaae8c9909d" -dependencies = [ - "openblas-src", -] - [[package]] name = "block" version = "0.1.6" @@ -1190,15 +1181,6 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2698f953def977c68f935bb0dfa959375ad4638570e969e2f1e9f433cbf1af6" -[[package]] -name = "cblas-sys" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65" -dependencies = [ - "libc", -] - [[package]] name = "cc" version = "1.0.83" @@ -4604,8 +4586,6 @@ version = "0.15.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" dependencies = [ - "cblas-sys", - "libc", "matrixmultiply", "num-complex 0.4.4", "num-integer", @@ -4990,32 +4970,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" -[[package]] -name = "openblas-build" -version = "0.10.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eba42c395477605f400a8d79ee0b756cfb82abe3eb5618e35fa70d3a36010a7f" -dependencies = [ - "anyhow", - "flate2", - "native-tls", - "tar", - "thiserror", - "ureq", - "walkdir", -] - -[[package]] -name = "openblas-src" -version = "0.10.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38e5d8af0b707ac2fe1574daa88b4157da73b0de3dc7c39fe3e2c0bb64070501" -dependencies = [ - "dirs 3.0.2", - "openblas-build", - "vcpkg", -] - [[package]] name = "openssl" version = "0.10.57" @@ -6490,18 +6444,6 @@ dependencies = [ "webpki 0.22.1", ] -[[package]] -name = "rustls-native-certs" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" -dependencies = [ - "openssl-probe", - "rustls-pemfile", - "schannel", - "security-framework", -] - [[package]] name = "rustls-pemfile" version = "1.0.3" @@ -6820,7 +6762,6 @@ dependencies = [ "ai", "anyhow", "async-trait", - "blas-src", "client", "collections", "ctor", @@ -6834,7 +6775,6 @@ dependencies = [ "log", "ndarray", "node_runtime", - "openblas-src", "ordered-float", "parking_lot 0.11.2", "picker", @@ -7712,17 +7652,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" -[[package]] -name = "tar" -version = "0.4.40" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb" -dependencies = [ - "filetime", - "libc", - "xattr 1.0.1", -] - [[package]] name = "target-lexicon" version = "0.12.11" @@ -8797,21 +8726,6 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" -[[package]] -name = "ureq" -version = "2.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9" -dependencies = [ - "base64 0.21.4", - "flate2", - "log", - "native-tls", - "once_cell", - "rustls-native-certs", - "url", -] - [[package]] name = "url" version = "2.4.1" @@ -9865,15 +9779,6 @@ dependencies = [ "libc", ] -[[package]] -name = "xattr" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4686009f71ff3e5c4dbcf1a282d0a44db3f021ba69350cd42086b3e5f1c6985" -dependencies = [ - "libc", -] - [[package]] name = "xmlparser" version = "0.13.5" diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index cd08aa0d63521b16066b5e4d8bfda06f60c88fef..efda311633cad4b416adcea4b4f77c7a3e71c72d 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -39,9 +39,7 @@ rand.workspace = true schemars.workspace = true globset.workspace = true sha1 = "0.10.5" -ndarray = { version = "0.15.0", features = ["blas"] } -blas-src = { version = "0.8", features = ["openblas"] } -openblas-src = { version = "0.10", features = ["cblas", "system"] } +ndarray = { version = "0.15.0" } [dev-dependencies] collections = { path = "../collections", features = ["test-support"] } diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index caa70a4cfaaa76cbfc6138f864a12f58acacea02..18e38c6e4c1162eaa38fca0f954dcc1057920c96 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -1,5 +1,3 @@ -extern crate blas_src; - use crate::{ parsing::{Span, SpanDigest}, SEMANTIC_INDEX_VERSION, @@ -440,25 +438,24 @@ impl VectorDatabase { .filter_map(|row| row.ok()) .collect::>(); - let batch_n = 250; + let batch_n = 1000; let mut batches = Vec::new(); let mut batch_ids = Vec::new(); let mut batch_embeddings: Vec = Vec::new(); deserialized_rows.iter().for_each(|(id, embedding)| { batch_ids.push(id); batch_embeddings.extend(&embedding.0); + if batch_ids.len() == batch_n { - let array = - Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone()); + let embeddings = std::mem::take(&mut batch_embeddings); + let ids = std::mem::take(&mut batch_ids); + let array = Array2::from_shape_vec((batch_ids.len(), 1536), embeddings); match array { Ok(array) => { - batches.push((batch_ids.clone(), array)); + batches.push((ids, array)); } Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err), } - - batch_ids = Vec::new(); - batch_embeddings = Vec::new(); } }); diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index f42544af1c671505b101502e1355104c70c55f6b..ecdba4364315eb8f0a4ed2cf579fcc3149e56e67 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -711,7 +711,7 @@ impl SemanticIndex { .await? .pop() .ok_or_else(|| anyhow!("could not embed query"))?; - log::trace!("Embedding Search Query: {:?}", t0.elapsed().as_millis()); + log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis()); let search_start = Instant::now(); let modified_buffer_results = this.update(&mut cx, |this, cx| { From 0e6fd645fd71bf77d1bdff28c30985ac23229aaf Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 27 Sep 2023 10:33:04 -0400 Subject: [PATCH 5/5] leverage embeddings len returned in construction matrix multiplication --- crates/semantic_index/src/db.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 18e38c6e4c1162eaa38fca0f954dcc1057920c96..63527cea1ccacadec4fa2410d936f1eac948eb47 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -438,6 +438,13 @@ impl VectorDatabase { .filter_map(|row| row.ok()) .collect::>(); + if deserialized_rows.len() == 0 { + return Ok(Vec::new()); + } + + // Get Length of Embeddings Returned + let embedding_len = deserialized_rows[0].1 .0.len(); + let batch_n = 1000; let mut batches = Vec::new(); let mut batch_ids = Vec::new(); @@ -449,7 +456,8 @@ impl VectorDatabase { if batch_ids.len() == batch_n { let embeddings = std::mem::take(&mut batch_embeddings); let ids = std::mem::take(&mut batch_ids); - let array = Array2::from_shape_vec((batch_ids.len(), 1536), embeddings); + let array = + Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings); match array { Ok(array) => { batches.push((ids, array)); @@ -460,8 +468,10 @@ impl VectorDatabase { }); if batch_ids.len() > 0 { - let array = - Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone()); + let array = Array2::from_shape_vec( + (batch_ids.len(), embedding_len), + batch_embeddings.clone(), + ); match array { Ok(array) => { batches.push((batch_ids.clone(), array));