leverage embeddings len returned in construction matrix multiplication

KCaverly created

Change summary

crates/semantic_index/src/db.rs | 16 +++++++++++++---
1 file changed, 13 insertions(+), 3 deletions(-)

Detailed changes

crates/semantic_index/src/db.rs 🔗

@@ -438,6 +438,13 @@ impl VectorDatabase {
                 .filter_map(|row| row.ok())
                 .collect::<Vec<(usize, Embedding)>>();
 
+            if deserialized_rows.len() == 0 {
+                return Ok(Vec::new());
+            }
+
+            // Get Length of Embeddings Returned
+            let embedding_len = deserialized_rows[0].1 .0.len();
+
             let batch_n = 1000;
             let mut batches = Vec::new();
             let mut batch_ids = Vec::new();
@@ -449,7 +456,8 @@ impl VectorDatabase {
                 if batch_ids.len() == batch_n {
                     let embeddings = std::mem::take(&mut batch_embeddings);
                     let ids = std::mem::take(&mut batch_ids);
-                    let array = Array2::from_shape_vec((batch_ids.len(), 1536), embeddings);
+                    let array =
+                        Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings);
                     match array {
                         Ok(array) => {
                             batches.push((ids, array));
@@ -460,8 +468,10 @@ impl VectorDatabase {
             });
 
             if batch_ids.len() > 0 {
-                let array =
-                    Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone());
+                let array = Array2::from_shape_vec(
+                    (batch_ids.len(), embedding_len),
+                    batch_embeddings.clone(),
+                );
                 match array {
                     Ok(array) => {
                         batches.push((batch_ids.clone(), array));