1use std::time::Instant;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use ordered_float::OrderedFloat;
6use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
7use rusqlite::ToSql;
8
9use crate::auth::CredentialProvider;
10use crate::models::LanguageModel;
11
/// Newtype around the `f32` components that make up a single embedding.
#[derive(Clone, Debug, PartialEq)]
pub struct Embedding(pub Vec<f32>);
14
15// This is needed for semantic index functionality
16// Unfortunately it has to live wherever the "Embedding" struct is created.
17// Keeping this in here though, introduces a 'rusqlite' dependency into AI
18// which is less than ideal
19impl FromSql for Embedding {
20 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
21 let bytes = value.as_blob()?;
22 let embedding =
23 bincode::deserialize(bytes).map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
24 Ok(Embedding(embedding))
25 }
26}
27
28impl ToSql for Embedding {
29 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
30 let bytes = bincode::serialize(&self.0)
31 .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
32 Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
33 }
34}
35impl From<Vec<f32>> for Embedding {
36 fn from(value: Vec<f32>) -> Self {
37 Embedding(value)
38 }
39}
40
41impl Embedding {
42 pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
43 let len = self.0.len();
44 assert_eq!(len, other.0.len());
45
46 let mut result = 0.0;
47 unsafe {
48 matrixmultiply::sgemm(
49 1,
50 len,
51 1,
52 1.0,
53 self.0.as_ptr(),
54 len as isize,
55 1,
56 other.0.as_ptr(),
57 1,
58 len as isize,
59 0.0,
60 &mut result as *mut f32,
61 1,
62 1,
63 );
64 }
65 OrderedFloat(result)
66 }
67}
68
69#[async_trait]
70pub trait EmbeddingProvider: CredentialProvider {
71 fn base_model(&self) -> Box<dyn LanguageModel>;
72 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
73 fn max_tokens_per_batch(&self) -> usize;
74 fn rate_limit_expiration(&self) -> Option<Instant>;
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Largest absolute gap tolerated between the `sgemm`-based similarity and
    /// the naive reference.
    ///
    /// Two values that agree after rounding to one decimal place are always
    /// closer than this, so the check accepts everything a rounding comparison
    /// would. Unlike rounding, it cannot fail merely because two nearly
    /// identical values fall on opposite sides of a rounding boundary.
    const MAX_ABS_ERROR: f32 = 0.1;

    /// Straightforward element-wise dot product used as the ground truth.
    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
    }

    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        // Small hand-computed cases; these are exact in `f32`.
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        // Randomized comparison against the reference implementation.
        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            let actual = a.similarity(&b).0;
            let expected = reference_dot(&a.0, &b.0);
            assert!(
                (actual - expected).abs() < MAX_ABS_ERROR,
                "similarity {} differs from reference dot product {}",
                actual,
                expected
            );
        }
    }
}