search.rs

 1use std::cmp::Ordering;
 2
 3use async_trait::async_trait;
 4use ndarray::{Array1, Array2};
 5
 6use crate::db::{DocumentRecord, VectorDatabase};
 7use anyhow::Result;
 8
 9#[async_trait]
10pub trait VectorSearch {
11    // Given a query vector, and a limit to return
12    // Return a vector of id, distance tuples.
13    async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)>;
14}
15
16pub struct BruteForceSearch {
17    document_ids: Vec<usize>,
18    candidate_array: ndarray::Array2<f32>,
19}
20
21impl BruteForceSearch {
22    pub fn load() -> Result<Self> {
23        let db = VectorDatabase {};
24        let documents = db.get_documents()?;
25        let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
26        let mut document_ids = vec![];
27        for i in documents.keys() {
28            document_ids.push(i.to_owned());
29        }
30
31        let mut candidate_array = Array2::<f32>::default((documents.len(), 1536));
32        for (i, mut row) in candidate_array.axis_iter_mut(ndarray::Axis(0)).enumerate() {
33            for (j, col) in row.iter_mut().enumerate() {
34                *col = embeddings[i].embedding.0[j];
35            }
36        }
37
38        return Ok(BruteForceSearch {
39            document_ids,
40            candidate_array,
41        });
42    }
43}
44
45#[async_trait]
46impl VectorSearch for BruteForceSearch {
47    async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
48        let target = Array1::from_vec(vec.to_owned());
49
50        let distances = self.candidate_array.dot(&target);
51
52        let distances = distances.to_vec();
53
54        // construct a tuple vector from the floats, the tuple being (index,float)
55        let mut with_indices = distances
56            .clone()
57            .into_iter()
58            .enumerate()
59            .map(|(index, value)| (index, value))
60            .collect::<Vec<(usize, f32)>>();
61
62        // sort the tuple vector by float
63        with_indices.sort_by(|&a, &b| match (a.1.is_nan(), b.1.is_nan()) {
64            (true, true) => Ordering::Equal,
65            (true, false) => Ordering::Greater,
66            (false, true) => Ordering::Less,
67            (false, false) => a.1.partial_cmp(&b.1).unwrap(),
68        });
69
70        // extract the sorted indices from the sorted tuple vector
71        let stored_indices = with_indices
72            .into_iter()
73            .map(|(index, value)| index)
74            .collect::<Vec<usize>>();
75
76        let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
77
78        let mut results = vec![];
79        for idx in sorted_indices[0..limit].to_vec() {
80            results.push((self.document_ids[idx], 1.0 - distances[idx]));
81        }
82
83        return results;
84    }
85}