embedding.rs

  1use anyhow::Result;
  2use async_trait::async_trait;
  3use ordered_float::OrderedFloat;
  4use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
  5use rusqlite::ToSql;
  6use std::time::Instant;
  7
  8use crate::models::LanguageModel;
  9
 10#[derive(Debug, PartialEq, Clone)]
 11pub struct Embedding(pub Vec<f32>);
 12
 13// This is needed for semantic index functionality
 14// Unfortunately it has to live wherever the "Embedding" struct is created.
 15// Keeping this in here though, introduces a 'rusqlite' dependency into AI
 16// which is less than ideal
 17impl FromSql for Embedding {
 18    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 19        let bytes = value.as_blob()?;
 20        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 21        if embedding.is_err() {
 22            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 23        }
 24        Ok(Embedding(embedding.unwrap()))
 25    }
 26}
 27
 28impl ToSql for Embedding {
 29    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
 30        let bytes = bincode::serialize(&self.0)
 31            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
 32        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
 33    }
 34}
 35impl From<Vec<f32>> for Embedding {
 36    fn from(value: Vec<f32>) -> Self {
 37        Embedding(value)
 38    }
 39}
 40
 41impl Embedding {
 42    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
 43        let len = self.0.len();
 44        assert_eq!(len, other.0.len());
 45
 46        let mut result = 0.0;
 47        unsafe {
 48            matrixmultiply::sgemm(
 49                1,
 50                len,
 51                1,
 52                1.0,
 53                self.0.as_ptr(),
 54                len as isize,
 55                1,
 56                other.0.as_ptr(),
 57                1,
 58                len as isize,
 59                0.0,
 60                &mut result as *mut f32,
 61                1,
 62                1,
 63            );
 64        }
 65        OrderedFloat(result)
 66    }
 67}
 68
 69#[async_trait]
 70pub trait EmbeddingProvider: Sync + Send {
 71    fn base_model(&self) -> Box<dyn LanguageModel>;
 72    fn is_authenticated(&self) -> bool;
 73    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
 74    fn max_tokens_per_batch(&self) -> usize;
 75    fn rate_limit_expiration(&self) -> Option<Instant>;
 76}
 77
 78#[cfg(test)]
 79mod tests {
 80    use super::*;
 81    use rand::prelude::*;
 82
 83    #[gpui::test]
 84    fn test_similarity(mut rng: StdRng) {
 85        assert_eq!(
 86            Embedding::from(vec![1., 0., 0., 0., 0.])
 87                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
 88            0.
 89        );
 90        assert_eq!(
 91            Embedding::from(vec![2., 0., 0., 0., 0.])
 92                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
 93            6.
 94        );
 95
 96        for _ in 0..100 {
 97            let size = 1536;
 98            let mut a = vec![0.; size];
 99            let mut b = vec![0.; size];
100            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
101                *a = rng.gen();
102                *b = rng.gen();
103            }
104            let a = Embedding::from(a);
105            let b = Embedding::from(b);
106
107            assert_eq!(
108                round_to_decimals(a.similarity(&b), 1),
109                round_to_decimals(reference_dot(&a.0, &b.0), 1)
110            );
111        }
112
113        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
114            let factor = (10.0 as f32).powi(decimal_places);
115            (n * factor).round() / factor
116        }
117
118        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
119            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
120        }
121    }
122}