@@ -438,6 +438,13 @@ impl VectorDatabase {
.filter_map(|row| row.ok())
.collect::<Vec<(usize, Embedding)>>();
+ if deserialized_rows.len() == 0 {
+ return Ok(Vec::new());
+ }
+
+ // Get Length of Embeddings Returned
+ let embedding_len = deserialized_rows[0].1 .0.len();
+
let batch_n = 1000;
let mut batches = Vec::new();
let mut batch_ids = Vec::new();
@@ -449,7 +456,8 @@ impl VectorDatabase {
if batch_ids.len() == batch_n {
let embeddings = std::mem::take(&mut batch_embeddings);
let ids = std::mem::take(&mut batch_ids);
- let array = Array2::from_shape_vec((batch_ids.len(), 1536), embeddings);
+ let array =
+ Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings);
match array {
Ok(array) => {
batches.push((ids, array));
@@ -460,8 +468,10 @@ impl VectorDatabase {
});
if batch_ids.len() > 0 {
- let array =
- Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone());
+ let array = Array2::from_shape_vec(
+ (batch_ids.len(), embedding_len),
+ batch_embeddings.clone(),
+ );
match array {
Ok(array) => {
batches.push((batch_ids.clone(), array));