1use std::time::Instant;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use gpui::AppContext;
6use ordered_float::OrderedFloat;
7use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
8use rusqlite::ToSql;
9
10use crate::auth::{CredentialProvider, ProviderCredential};
11use crate::models::LanguageModel;
12
/// An embedding vector: an ordered list of `f32` components.
///
/// The inner vector is public, so the raw components can be read or moved
/// out directly through `.0`.
#[derive(Clone, Debug, PartialEq)]
pub struct Embedding(pub Vec<f32>);
15
16// This is needed for semantic index functionality
17// Unfortunately it has to live wherever the "Embedding" struct is created.
18// Keeping this in here though, introduces a 'rusqlite' dependency into AI
19// which is less than ideal
20impl FromSql for Embedding {
21 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
22 let bytes = value.as_blob()?;
23 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
24 if embedding.is_err() {
25 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
26 }
27 Ok(Embedding(embedding.unwrap()))
28 }
29}
30
31impl ToSql for Embedding {
32 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
33 let bytes = bincode::serialize(&self.0)
34 .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
35 Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
36 }
37}
38impl From<Vec<f32>> for Embedding {
39 fn from(value: Vec<f32>) -> Self {
40 Embedding(value)
41 }
42}
43
44impl Embedding {
45 pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
46 let len = self.0.len();
47 assert_eq!(len, other.0.len());
48
49 let mut result = 0.0;
50 unsafe {
51 matrixmultiply::sgemm(
52 1,
53 len,
54 1,
55 1.0,
56 self.0.as_ptr(),
57 len as isize,
58 1,
59 other.0.as_ptr(),
60 1,
61 len as isize,
62 0.0,
63 &mut result as *mut f32,
64 1,
65 1,
66 );
67 }
68 OrderedFloat(result)
69 }
70}
71
/// Interface for a backend that turns text spans into [`Embedding`]s.
///
/// Implementors must be `Sync + Send`.
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Returns the language model associated with this provider.
    fn base_model(&self) -> Box<dyn LanguageModel>;
    /// Returns the credential provider that the default
    /// [`Self::retrieve_credentials`] implementation delegates to.
    fn credential_provider(&self) -> Box<dyn CredentialProvider>;
    /// Retrieves credentials by forwarding `cx` to the provider returned from
    /// [`Self::credential_provider`].
    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
        self.credential_provider().retrieve_credentials(cx)
    }
    /// Asynchronously embeds `spans` using `credential`, returning the
    /// resulting embeddings or an error.
    ///
    /// NOTE(review): presumably yields one embedding per span, in input
    /// order; confirm against implementors.
    async fn embed_batch(
        &self,
        spans: Vec<String>,
        credential: ProviderCredential,
    ) -> Result<Vec<Embedding>>;
    /// NOTE(review): presumably the maximum number of tokens that may be sent
    /// in one `embed_batch` call; confirm against implementors.
    fn max_tokens_per_batch(&self) -> usize;
    /// NOTE(review): presumably the instant at which an active rate limit
    /// lapses, or `None` when no limit is in effect; confirm against
    /// implementors.
    fn rate_limit_expiration(&self) -> Option<Instant>;
}
87
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Checks `Embedding::similarity` against hand-computed dot products and
    /// against a plain scalar reference implementation on random vectors.
    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        /// Scalar dot product used as the reference result.
        fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
            a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
        }

        // Largest absolute difference tolerated between the `sgemm` result and
        // the reference. Comparing values rounded to one decimal place would
        // fail spuriously whenever two nearly equal results fall on opposite
        // sides of a rounding boundary; a direct tolerance check does not.
        const TOLERANCE: f32 = 0.1;

        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            let actual = a.similarity(&b).0;
            let expected = reference_dot(&a.0, &b.0);
            assert!(
                (actual - expected).abs() < TOLERANCE,
                "similarity {} differs from reference dot product {} by {} or more",
                actual,
                expected,
                TOLERANCE
            );
        }
    }
}