embedding.rs

  1use std::time::Instant;
  2
  3use anyhow::Result;
  4use async_trait::async_trait;
  5use gpui::AppContext;
  6use ordered_float::OrderedFloat;
  7use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
  8use rusqlite::ToSql;
  9
 10use crate::models::LanguageModel;
 11
 12#[derive(Debug, PartialEq, Clone)]
 13pub struct Embedding(pub Vec<f32>);
 14
 15// This is needed for semantic index functionality
 16// Unfortunately it has to live wherever the "Embedding" struct is created.
 17// Keeping this in here though, introduces a 'rusqlite' dependency into AI
 18// which is less than ideal
 19impl FromSql for Embedding {
 20    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 21        let bytes = value.as_blob()?;
 22        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 23        if embedding.is_err() {
 24            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 25        }
 26        Ok(Embedding(embedding.unwrap()))
 27    }
 28}
 29
 30impl ToSql for Embedding {
 31    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
 32        let bytes = bincode::serialize(&self.0)
 33            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
 34        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
 35    }
 36}
 37impl From<Vec<f32>> for Embedding {
 38    fn from(value: Vec<f32>) -> Self {
 39        Embedding(value)
 40    }
 41}
 42
 43impl Embedding {
 44    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
 45        let len = self.0.len();
 46        assert_eq!(len, other.0.len());
 47
 48        let mut result = 0.0;
 49        unsafe {
 50            matrixmultiply::sgemm(
 51                1,
 52                len,
 53                1,
 54                1.0,
 55                self.0.as_ptr(),
 56                len as isize,
 57                1,
 58                other.0.as_ptr(),
 59                1,
 60                len as isize,
 61                0.0,
 62                &mut result as *mut f32,
 63                1,
 64                1,
 65            );
 66        }
 67        OrderedFloat(result)
 68    }
 69}
 70
 71#[async_trait]
 72pub trait EmbeddingProvider: Sync + Send {
 73    fn base_model(&self) -> Box<dyn LanguageModel>;
 74    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
 75    async fn embed_batch(
 76        &self,
 77        spans: Vec<String>,
 78        api_key: Option<String>,
 79    ) -> Result<Vec<Embedding>>;
 80    fn max_tokens_per_batch(&self) -> usize;
 81    fn rate_limit_expiration(&self) -> Option<Instant>;
 82}
 83
 84#[cfg(test)]
 85mod tests {
 86    use super::*;
 87    use rand::prelude::*;
 88
 89    #[gpui::test]
 90    fn test_similarity(mut rng: StdRng) {
 91        assert_eq!(
 92            Embedding::from(vec![1., 0., 0., 0., 0.])
 93                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
 94            0.
 95        );
 96        assert_eq!(
 97            Embedding::from(vec![2., 0., 0., 0., 0.])
 98                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
 99            6.
100        );
101
102        for _ in 0..100 {
103            let size = 1536;
104            let mut a = vec![0.; size];
105            let mut b = vec![0.; size];
106            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
107                *a = rng.gen();
108                *b = rng.gen();
109            }
110            let a = Embedding::from(a);
111            let b = Embedding::from(b);
112
113            assert_eq!(
114                round_to_decimals(a.similarity(&b), 1),
115                round_to_decimals(reference_dot(&a.0, &b.0), 1)
116            );
117        }
118
119        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
120            let factor = (10.0 as f32).powi(decimal_places);
121            (n * factor).round() / factor
122        }
123
124        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
125            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
126        }
127    }
128}