1use anyhow::Result;
2use async_trait::async_trait;
3use ordered_float::OrderedFloat;
4use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
5use rusqlite::ToSql;
6use std::time::Instant;
7
8use crate::models::LanguageModel;
9
/// A dense embedding vector, as returned by [`EmbeddingProvider::embed_batch`].
///
/// The inner `Vec<f32>` is public so callers can read or build the raw values directly.
#[derive(Clone, Debug, PartialEq)]
pub struct Embedding(pub Vec<f32>);
12
// This is needed for semantic index functionality.
// Unfortunately, Rust's orphan rule means it has to live in the crate that defines
// the `Embedding` struct. Keeping it here, though, introduces a `rusqlite`
// dependency into this (AI) crate, which is less than ideal.
17impl FromSql for Embedding {
18 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
19 let bytes = value.as_blob()?;
20 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
21 if embedding.is_err() {
22 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
23 }
24 Ok(Embedding(embedding.unwrap()))
25 }
26}
27
28impl ToSql for Embedding {
29 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
30 let bytes = bincode::serialize(&self.0)
31 .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
32 Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
33 }
34}
35impl From<Vec<f32>> for Embedding {
36 fn from(value: Vec<f32>) -> Self {
37 Embedding(value)
38 }
39}
40
41impl Embedding {
42 pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
43 let len = self.0.len();
44 assert_eq!(len, other.0.len());
45
46 let mut result = 0.0;
47 unsafe {
48 matrixmultiply::sgemm(
49 1,
50 len,
51 1,
52 1.0,
53 self.0.as_ptr(),
54 len as isize,
55 1,
56 other.0.as_ptr(),
57 1,
58 len as isize,
59 0.0,
60 &mut result as *mut f32,
61 1,
62 1,
63 );
64 }
65 OrderedFloat(result)
66 }
67}
68
69#[async_trait]
70pub trait EmbeddingProvider: Sync + Send {
71 fn base_model(&self) -> Box<dyn LanguageModel>;
72 fn is_authenticated(&self) -> bool;
73 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
74 fn max_tokens_per_batch(&self) -> usize;
75 fn truncate(&self, span: &str) -> (String, usize);
76 fn rate_limit_expiration(&self) -> Option<Instant>;
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0f32; size];
            let mut b = vec![0f32; size];
            // Interleave draws (a0, b0, a1, b1, ...) so the seeded sequence is unchanged.
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }

            let expected = reference_dot(&a, &b);
            let actual = Embedding::from(a).similarity(&Embedding::from(b)).0;

            // sgemm and the naive loop sum in different orders, so compare with a
            // relative tolerance instead of rounding both sides. Rounding can reject
            // values 1e-6 apart that straddle a rounding boundary.
            let tolerance = expected.abs().max(1.0) * 1e-4;
            assert!(
                (actual - expected).abs() <= tolerance,
                "sgemm dot product {actual} differs from reference {expected} by more than {tolerance}"
            );
        }
    }

    /// Naive sequential dot product, used as the reference implementation.
    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(a, b)| a * b).sum()
    }
}