1use std::cmp::Ordering;
2
3use async_trait::async_trait;
4use ndarray::{Array1, Array2};
5
6use crate::db::{DocumentRecord, VectorDatabase};
7use anyhow::Result;
8
9#[async_trait]
10pub trait VectorSearch {
11 // Given a query vector, and a limit to return
12 // Return a vector of id, distance tuples.
13 async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)>;
14}
15
16pub struct BruteForceSearch {
17 document_ids: Vec<usize>,
18 candidate_array: ndarray::Array2<f32>,
19}
20
21impl BruteForceSearch {
22 pub fn load() -> Result<Self> {
23 let db = VectorDatabase {};
24 let documents = db.get_documents()?;
25 let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
26 let mut document_ids = vec![];
27 for i in documents.keys() {
28 document_ids.push(i.to_owned());
29 }
30
31 let mut candidate_array = Array2::<f32>::default((documents.len(), 1536));
32 for (i, mut row) in candidate_array.axis_iter_mut(ndarray::Axis(0)).enumerate() {
33 for (j, col) in row.iter_mut().enumerate() {
34 *col = embeddings[i].embedding.0[j];
35 }
36 }
37
38 return Ok(BruteForceSearch {
39 document_ids,
40 candidate_array,
41 });
42 }
43}
44
45#[async_trait]
46impl VectorSearch for BruteForceSearch {
47 async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
48 let target = Array1::from_vec(vec.to_owned());
49
50 let distances = self.candidate_array.dot(&target);
51
52 let distances = distances.to_vec();
53
54 // construct a tuple vector from the floats, the tuple being (index,float)
55 let mut with_indices = distances
56 .clone()
57 .into_iter()
58 .enumerate()
59 .map(|(index, value)| (index, value))
60 .collect::<Vec<(usize, f32)>>();
61
62 // sort the tuple vector by float
63 with_indices.sort_by(|&a, &b| match (a.1.is_nan(), b.1.is_nan()) {
64 (true, true) => Ordering::Equal,
65 (true, false) => Ordering::Greater,
66 (false, true) => Ordering::Less,
67 (false, false) => a.1.partial_cmp(&b.1).unwrap(),
68 });
69
70 // extract the sorted indices from the sorted tuple vector
71 let stored_indices = with_indices
72 .into_iter()
73 .map(|(index, value)| index)
74 .collect::<Vec<usize>>();
75
76 let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
77
78 let mut results = vec![];
79 for idx in sorted_indices[0..limit].to_vec() {
80 results.push((self.document_ids[idx], 1.0 - distances[idx]));
81 }
82
83 return results;
84 }
85}