1use anyhow::Result;
2use async_trait::async_trait;
3use ordered_float::OrderedFloat;
4use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
5use rusqlite::ToSql;
6use std::time::Instant;
7
8use crate::models::LanguageModel;
9
/// A dense embedding vector, stored as a plain list of `f32` components.
///
/// This is a thin newtype over `Vec<f32>`; the inner vector is public so
/// callers can read the raw components directly via `.0`.
#[derive(Clone, Debug, PartialEq)]
pub struct Embedding(pub Vec<f32>);
12
13// This is needed for semantic index functionality
14// Unfortunately it has to live wherever the "Embedding" struct is created.
15// Keeping this in here though, introduces a 'rusqlite' dependency into AI
16// which is less than ideal
17impl FromSql for Embedding {
18 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
19 let bytes = value.as_blob()?;
20 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
21 if embedding.is_err() {
22 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
23 }
24 Ok(Embedding(embedding.unwrap()))
25 }
26}
27
28impl ToSql for Embedding {
29 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
30 let bytes = bincode::serialize(&self.0)
31 .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
32 Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
33 }
34}
35impl From<Vec<f32>> for Embedding {
36 fn from(value: Vec<f32>) -> Self {
37 Embedding(value)
38 }
39}
40
41impl Embedding {
42 pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
43 let len = self.0.len();
44 assert_eq!(len, other.0.len());
45
46 let mut result = 0.0;
47 unsafe {
48 matrixmultiply::sgemm(
49 1,
50 len,
51 1,
52 1.0,
53 self.0.as_ptr(),
54 len as isize,
55 1,
56 other.0.as_ptr(),
57 1,
58 len as isize,
59 0.0,
60 &mut result as *mut f32,
61 1,
62 1,
63 );
64 }
65 OrderedFloat(result)
66 }
67}
68
69#[async_trait]
70pub trait EmbeddingProvider: Sync + Send {
71 fn base_model(&self) -> Box<dyn LanguageModel>;
72 fn is_authenticated(&self) -> bool;
73 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
74 fn max_tokens_per_batch(&self) -> usize;
75 fn rate_limit_expiration(&self) -> Option<Instant>;
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Checks `Embedding::similarity` against hand-computed dot products and
    /// against a straightforward reference implementation on random vectors.
    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        /// Plain sequential dot product used as the reference result.
        fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
            a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
        }

        // Largest absolute difference tolerated between the `sgemm` result
        // and the reference. The two implementations add the products in a
        // different order, so their `f32` results may differ slightly.
        // Comparing with a tolerance (instead of rounding both values to one
        // decimal place and testing equality) avoids spurious failures when
        // the two results fall on opposite sides of a rounding boundary.
        const TOLERANCE: f32 = 0.1;

        // Orthogonal vectors have a dot product of zero.
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        // 2 * 3 + 0 * 1 = 6.
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            let actual = a.similarity(&b).0;
            let expected = reference_dot(&a.0, &b.0);
            assert!(
                (actual - expected).abs() < TOLERANCE,
                "similarity {} differs from reference dot product {} by more than {}",
                actual,
                expected,
                TOLERANCE
            );
        }
    }
}