1use std::{cmp::Ordering, path::PathBuf};
2
3use async_trait::async_trait;
4use ndarray::{Array1, Array2};
5
6use crate::db::{DocumentRecord, VectorDatabase};
7use anyhow::Result;
8
9#[async_trait]
10pub trait VectorSearch {
11 // Given a query vector, and a limit to return
12 // Return a vector of id, distance tuples.
13 async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)>;
14}
15
16pub struct BruteForceSearch {
17 document_ids: Vec<usize>,
18 candidate_array: ndarray::Array2<f32>,
19}
20
21impl BruteForceSearch {
22 pub fn load(db: &VectorDatabase) -> Result<Self> {
23 let documents = db.get_documents()?;
24 let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
25 let mut document_ids = vec![];
26 for i in documents.keys() {
27 document_ids.push(i.to_owned());
28 }
29
30 let mut candidate_array = Array2::<f32>::default((documents.len(), 1536));
31 for (i, mut row) in candidate_array.axis_iter_mut(ndarray::Axis(0)).enumerate() {
32 for (j, col) in row.iter_mut().enumerate() {
33 *col = embeddings[i].embedding.0[j];
34 }
35 }
36
37 return Ok(BruteForceSearch {
38 document_ids,
39 candidate_array,
40 });
41 }
42}
43
44#[async_trait]
45impl VectorSearch for BruteForceSearch {
46 async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
47 let target = Array1::from_vec(vec.to_owned());
48
49 let similarities = self.candidate_array.dot(&target);
50
51 let similarities = similarities.to_vec();
52
53 // construct a tuple vector from the floats, the tuple being (index,float)
54 let mut with_indices = similarities
55 .iter()
56 .copied()
57 .enumerate()
58 .map(|(index, value)| (self.document_ids[index], value))
59 .collect::<Vec<(usize, f32)>>();
60
61 // sort the tuple vector by float
62 with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
63 with_indices.truncate(limit);
64 with_indices
65 }
66}