embedding.rs

  1use std::time::Instant;
  2
  3use anyhow::Result;
  4use async_trait::async_trait;
  5use gpui::AppContext;
  6use ordered_float::OrderedFloat;
  7use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
  8use rusqlite::ToSql;
  9
 10use crate::auth::{CredentialProvider, ProviderCredential};
 11use crate::models::LanguageModel;
 12
 13#[derive(Debug, PartialEq, Clone)]
 14pub struct Embedding(pub Vec<f32>);
 15
 16// This is needed for semantic index functionality
 17// Unfortunately it has to live wherever the "Embedding" struct is created.
 18// Keeping this in here though, introduces a 'rusqlite' dependency into AI
 19// which is less than ideal
 20impl FromSql for Embedding {
 21    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 22        let bytes = value.as_blob()?;
 23        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 24        if embedding.is_err() {
 25            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 26        }
 27        Ok(Embedding(embedding.unwrap()))
 28    }
 29}
 30
 31impl ToSql for Embedding {
 32    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
 33        let bytes = bincode::serialize(&self.0)
 34            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
 35        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
 36    }
 37}
 38impl From<Vec<f32>> for Embedding {
 39    fn from(value: Vec<f32>) -> Self {
 40        Embedding(value)
 41    }
 42}
 43
 44impl Embedding {
 45    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
 46        let len = self.0.len();
 47        assert_eq!(len, other.0.len());
 48
 49        let mut result = 0.0;
 50        unsafe {
 51            matrixmultiply::sgemm(
 52                1,
 53                len,
 54                1,
 55                1.0,
 56                self.0.as_ptr(),
 57                len as isize,
 58                1,
 59                other.0.as_ptr(),
 60                1,
 61                len as isize,
 62                0.0,
 63                &mut result as *mut f32,
 64                1,
 65                1,
 66            );
 67        }
 68        OrderedFloat(result)
 69    }
 70}
 71
 72#[async_trait]
 73pub trait EmbeddingProvider: Sync + Send {
 74    fn base_model(&self) -> Box<dyn LanguageModel>;
 75    fn credential_provider(&self) -> Box<dyn CredentialProvider>;
 76    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
 77        self.credential_provider().retrieve_credentials(cx)
 78    }
 79    async fn embed_batch(
 80        &self,
 81        spans: Vec<String>,
 82        credential: ProviderCredential,
 83    ) -> Result<Vec<Embedding>>;
 84    fn max_tokens_per_batch(&self) -> usize;
 85    fn rate_limit_expiration(&self) -> Option<Instant>;
 86}
 87
 88#[cfg(test)]
 89mod tests {
 90    use super::*;
 91    use rand::prelude::*;
 92
 93    #[gpui::test]
 94    fn test_similarity(mut rng: StdRng) {
 95        assert_eq!(
 96            Embedding::from(vec![1., 0., 0., 0., 0.])
 97                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
 98            0.
 99        );
100        assert_eq!(
101            Embedding::from(vec![2., 0., 0., 0., 0.])
102                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
103            6.
104        );
105
106        for _ in 0..100 {
107            let size = 1536;
108            let mut a = vec![0.; size];
109            let mut b = vec![0.; size];
110            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
111                *a = rng.gen();
112                *b = rng.gen();
113            }
114            let a = Embedding::from(a);
115            let b = Embedding::from(b);
116
117            assert_eq!(
118                round_to_decimals(a.similarity(&b), 1),
119                round_to_decimals(reference_dot(&a.0, &b.0), 1)
120            );
121        }
122
123        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
124            let factor = (10.0 as f32).powi(decimal_places);
125            (n * factor).round() / factor
126        }
127
128        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
129            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
130        }
131    }
132}