1use std::time::Instant;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use gpui::AppContext;
6use ordered_float::OrderedFloat;
7use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
8use rusqlite::ToSql;
9
10use crate::models::LanguageModel;
11
/// A dense embedding vector, stored as its raw `f32` components.
#[derive(Clone, Debug, PartialEq)]
pub struct Embedding(pub Vec<f32>);
14
// `FromSql`/`ToSql` are needed by the semantic index to store embeddings in SQLite.
// Rust's orphan rule only allows these impls in the crate that defines the trait
// (`rusqlite`) or the type (`Embedding`), so they have to live alongside `Embedding`.
// The downside is that this pulls a `rusqlite` dependency into the AI crate.
19impl FromSql for Embedding {
20 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
21 let bytes = value.as_blob()?;
22 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
23 if embedding.is_err() {
24 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
25 }
26 Ok(Embedding(embedding.unwrap()))
27 }
28}
29
30impl ToSql for Embedding {
31 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
32 let bytes = bincode::serialize(&self.0)
33 .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
34 Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
35 }
36}
37impl From<Vec<f32>> for Embedding {
38 fn from(value: Vec<f32>) -> Self {
39 Embedding(value)
40 }
41}
42
43impl Embedding {
44 pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
45 let len = self.0.len();
46 assert_eq!(len, other.0.len());
47
48 let mut result = 0.0;
49 unsafe {
50 matrixmultiply::sgemm(
51 1,
52 len,
53 1,
54 1.0,
55 self.0.as_ptr(),
56 len as isize,
57 1,
58 other.0.as_ptr(),
59 1,
60 len as isize,
61 0.0,
62 &mut result as *mut f32,
63 1,
64 1,
65 );
66 }
67 OrderedFloat(result)
68 }
69}
70
71#[async_trait]
72pub trait EmbeddingProvider: Sync + Send {
73 fn base_model(&self) -> Box<dyn LanguageModel>;
74 fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
75 async fn embed_batch(
76 &self,
77 spans: Vec<String>,
78 api_key: Option<String>,
79 ) -> Result<Vec<Embedding>>;
80 fn max_tokens_per_batch(&self) -> usize;
81 fn rate_limit_expiration(&self) -> Option<Instant>;
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        // Orthogonal vectors have a dot product of exactly zero.
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        for _ in 0..100 {
            let size = 1536;
            // Components are drawn pairwise (a[i], then b[i]), consuming the RNG in
            // the same interleaved order as filling both vectors index by index.
            let (a, b): (Vec<f32>, Vec<f32>) =
                (0..size).map(|_| (rng.gen(), rng.gen())).unzip();
            let expected = reference_dot(&a, &b);
            let actual = f64::from(Embedding::from(a).similarity(&Embedding::from(b)).0);

            // Use a relative tolerance rather than rounding both sides: two values
            // that differ only by f32 summation-order error can still round to
            // different decimals when they straddle a rounding boundary.
            let tolerance = 1e-4 * expected.abs().max(1.0);
            assert!(
                (actual - expected).abs() <= tolerance,
                "similarity {actual} differs from reference {expected} by more than {tolerance}"
            );
        }
    }

    #[test]
    #[should_panic]
    fn test_similarity_length_mismatch() {
        let _ = Embedding::from(vec![1., 2.]).similarity(&Embedding::from(vec![1.]));
    }

    /// Dot product accumulated in `f64`, used as a high-precision reference.
    fn reference_dot(a: &[f32], b: &[f32]) -> f64 {
        a.iter()
            .zip(b)
            .map(|(&x, &y)| f64::from(x) * f64::from(y))
            .sum()
    }
}