search.rs

 1use std::{cmp::Ordering, path::PathBuf};
 2
 3use async_trait::async_trait;
 4use ndarray::{Array1, Array2};
 5
 6use crate::db::{DocumentRecord, VectorDatabase};
 7use anyhow::Result;
 8
 9#[async_trait]
10pub trait VectorSearch {
11    // Given a query vector, and a limit to return
12    // Return a vector of id, distance tuples.
13    async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)>;
14}
15
16pub struct BruteForceSearch {
17    document_ids: Vec<usize>,
18    candidate_array: ndarray::Array2<f32>,
19}
20
21impl BruteForceSearch {
22    pub fn load(db: &VectorDatabase) -> Result<Self> {
23        let documents = db.get_documents()?;
24        let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
25        let mut document_ids = vec![];
26        for i in documents.keys() {
27            document_ids.push(i.to_owned());
28        }
29
30        let mut candidate_array = Array2::<f32>::default((documents.len(), 1536));
31        for (i, mut row) in candidate_array.axis_iter_mut(ndarray::Axis(0)).enumerate() {
32            for (j, col) in row.iter_mut().enumerate() {
33                *col = embeddings[i].embedding.0[j];
34            }
35        }
36
37        return Ok(BruteForceSearch {
38            document_ids,
39            candidate_array,
40        });
41    }
42}
43
44#[async_trait]
45impl VectorSearch for BruteForceSearch {
46    async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
47        let target = Array1::from_vec(vec.to_owned());
48
49        let similarities = self.candidate_array.dot(&target);
50
51        let similarities = similarities.to_vec();
52
53        // construct a tuple vector from the floats, the tuple being (index,float)
54        let mut with_indices = similarities
55            .iter()
56            .copied()
57            .enumerate()
58            .map(|(index, value)| (self.document_ids[index], value))
59            .collect::<Vec<(usize, f32)>>();
60
61        // sort the tuple vector by float
62        with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
63        with_indices.truncate(limit);
64        with_indices
65    }
66}