1mod lmstudio;
2mod ollama;
3mod open_ai;
4
5pub use lmstudio::*;
6pub use ollama::*;
7pub use open_ai::*;
8use sha2::{Digest, Sha256};
9
10use anyhow::Result;
11use futures::{FutureExt, future::BoxFuture};
12use serde::{Deserialize, Serialize};
13use std::{fmt, future};
14
15/// Trait for embedding providers. Texts in, vectors out.
16pub trait EmbeddingProvider: Sync + Send {
17 fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
18 fn batch_size(&self) -> usize;
19}
20
21#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
22pub struct Embedding(Vec<f32>);
23
24impl Embedding {
25 pub fn new(mut embedding: Vec<f32>) -> Self {
26 let len = embedding.len();
27 let mut norm = 0f32;
28
29 for i in 0..len {
30 norm += embedding[i] * embedding[i];
31 }
32
33 norm = norm.sqrt();
34 for dimension in &mut embedding {
35 *dimension /= norm;
36 }
37
38 Self(embedding)
39 }
40
41 fn len(&self) -> usize {
42 self.0.len()
43 }
44
45 pub fn similarity(&self, others: &[Embedding]) -> (f32, usize) {
46 debug_assert!(others.iter().all(|other| self.0.len() == other.0.len()));
47 others
48 .iter()
49 .enumerate()
50 .map(|(index, other)| {
51 let dot_product: f32 = self
52 .0
53 .iter()
54 .copied()
55 .zip(other.0.iter().copied())
56 .map(|(a, b)| a * b)
57 .sum();
58 (dot_product, index)
59 })
60 .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
61 .unwrap_or((0.0, 0))
62 }
63}
64
65impl fmt::Display for Embedding {
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 let digits_to_display = 3;
68
69 // Start the Embedding display format
70 write!(f, "Embedding(sized: {}; values: [", self.len())?;
71
72 for (index, value) in self.0.iter().enumerate().take(digits_to_display) {
73 // Lead with comma if not the first element
74 if index != 0 {
75 write!(f, ", ")?;
76 }
77 write!(f, "{:.3}", value)?;
78 }
79 if self.len() > digits_to_display {
80 write!(f, "...")?;
81 }
82 write!(f, "])")
83 }
84}
85
/// A borrowed piece of text queued for embedding, paired with its content digest.
#[derive(Debug)]
pub struct TextToEmbed<'t> {
    /// The text to embed.
    pub text: &'t str,
    /// A 32-byte digest of `text` (SHA-256 when built via `TextToEmbed::new`).
    pub digest: [u8; 32],
}
91
92impl<'a> TextToEmbed<'a> {
93 pub fn new(text: &'a str) -> Self {
94 let digest = Sha256::digest(text.as_bytes());
95 Self {
96 text,
97 digest: digest.into(),
98 }
99 }
100}
101
102pub struct FakeEmbeddingProvider;
103
104impl EmbeddingProvider for FakeEmbeddingProvider {
105 fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
106 let embeddings = texts
107 .iter()
108 .map(|_text| {
109 let mut embedding = vec![0f32; 1536];
110 for i in 0..embedding.len() {
111 embedding[i] = i as f32;
112 }
113 Embedding::new(embedding)
114 })
115 .collect();
116 future::ready(Ok(embeddings)).boxed()
117 }
118
119 fn batch_size(&self) -> usize {
120 16
121 }
122}
123
#[cfg(test)]
mod test {
    use super::*;

    #[gpui::test]
    fn test_normalize_embedding() {
        // (1, 1, 1) has L2 norm sqrt(3), so every component should become 1 / sqrt(3).
        let expected_component: f32 = 1.0 / 3.0_f32.sqrt();
        let expected = Embedding(vec![expected_component; 3]);

        let actual = Embedding::new(vec![1.0, 1.0, 1.0]);

        assert_eq!(actual, expected);
    }
}