1mod cloud;
2mod ollama;
3mod open_ai;
4
5pub use cloud::*;
6pub use ollama::*;
7pub use open_ai::*;
8use sha2::{Digest, Sha256};
9
10use anyhow::Result;
11use futures::{future::BoxFuture, FutureExt};
12use serde::{Deserialize, Serialize};
13use std::{fmt, future};
14
15/// Trait for embedding providers. Texts in, vectors out.
16pub trait EmbeddingProvider: Sync + Send {
17 fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
18 fn batch_size(&self) -> usize;
19}
20
21#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
22pub struct Embedding(Vec<f32>);
23
24impl Embedding {
25 pub fn new(mut embedding: Vec<f32>) -> Self {
26 let len = embedding.len();
27 let mut norm = 0f32;
28
29 for i in 0..len {
30 norm += embedding[i] * embedding[i];
31 }
32
33 norm = norm.sqrt();
34 for dimension in &mut embedding {
35 *dimension /= norm;
36 }
37
38 Self(embedding)
39 }
40
41 fn len(&self) -> usize {
42 self.0.len()
43 }
44
45 pub fn similarity(self, other: &Embedding) -> f32 {
46 debug_assert_eq!(self.0.len(), other.0.len());
47 self.0
48 .iter()
49 .copied()
50 .zip(other.0.iter().copied())
51 .map(|(a, b)| a * b)
52 .sum()
53 }
54}
55
56impl fmt::Display for Embedding {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 let digits_to_display = 3;
59
60 // Start the Embedding display format
61 write!(f, "Embedding(sized: {}; values: [", self.len())?;
62
63 for (index, value) in self.0.iter().enumerate().take(digits_to_display) {
64 // Lead with comma if not the first element
65 if index != 0 {
66 write!(f, ", ")?;
67 }
68 write!(f, "{:.3}", value)?;
69 }
70 if self.len() > digits_to_display {
71 write!(f, "...")?;
72 }
73 write!(f, "])")
74 }
75}
76
77#[derive(Debug)]
78pub struct TextToEmbed<'a> {
79 pub text: &'a str,
80 pub digest: [u8; 32],
81}
82
83impl<'a> TextToEmbed<'a> {
84 pub fn new(text: &'a str) -> Self {
85 let digest = Sha256::digest(text.as_bytes());
86 Self {
87 text,
88 digest: digest.into(),
89 }
90 }
91}
92
93pub struct FakeEmbeddingProvider;
94
95impl EmbeddingProvider for FakeEmbeddingProvider {
96 fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
97 let embeddings = texts
98 .iter()
99 .map(|_text| {
100 let mut embedding = vec![0f32; 1536];
101 for i in 0..embedding.len() {
102 embedding[i] = i as f32;
103 }
104 Embedding::new(embedding)
105 })
106 .collect();
107 future::ready(Ok(embeddings)).boxed()
108 }
109
110 fn batch_size(&self) -> usize {
111 16
112 }
113}
114
#[cfg(test)]
mod test {
    use super::*;

    /// `Embedding::new` must scale the all-ones vector of length three to unit length,
    /// i.e. every component becomes `1 / sqrt(3)`.
    #[gpui::test]
    fn test_normalize_embedding() {
        let expected = Embedding(vec![1.0 / 3.0_f32.sqrt(); 3]);
        let actual = Embedding::new(vec![1.0; 3]);
        assert_eq!(actual, expected);
    }
}