1mod cloud;
2mod lmstudio;
3mod ollama;
4mod open_ai;
5
6pub use cloud::*;
7pub use lmstudio::*;
8pub use ollama::*;
9pub use open_ai::*;
10use sha2::{Digest, Sha256};
11
12use anyhow::Result;
13use futures::{FutureExt, future::BoxFuture};
14use serde::{Deserialize, Serialize};
15use std::{fmt, future};
16
17/// Trait for embedding providers. Texts in, vectors out.
18pub trait EmbeddingProvider: Sync + Send {
19 fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
20 fn batch_size(&self) -> usize;
21}
22
23#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
24pub struct Embedding(Vec<f32>);
25
26impl Embedding {
27 pub fn new(mut embedding: Vec<f32>) -> Self {
28 let len = embedding.len();
29 let mut norm = 0f32;
30
31 for i in 0..len {
32 norm += embedding[i] * embedding[i];
33 }
34
35 norm = norm.sqrt();
36 for dimension in &mut embedding {
37 *dimension /= norm;
38 }
39
40 Self(embedding)
41 }
42
43 fn len(&self) -> usize {
44 self.0.len()
45 }
46
47 pub fn similarity(&self, others: &[Embedding]) -> (f32, usize) {
48 debug_assert!(others.iter().all(|other| self.0.len() == other.0.len()));
49 others
50 .iter()
51 .enumerate()
52 .map(|(index, other)| {
53 let dot_product: f32 = self
54 .0
55 .iter()
56 .copied()
57 .zip(other.0.iter().copied())
58 .map(|(a, b)| a * b)
59 .sum();
60 (dot_product, index)
61 })
62 .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
63 .unwrap_or((0.0, 0))
64 }
65}
66
67impl fmt::Display for Embedding {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 let digits_to_display = 3;
70
71 // Start the Embedding display format
72 write!(f, "Embedding(sized: {}; values: [", self.len())?;
73
74 for (index, value) in self.0.iter().enumerate().take(digits_to_display) {
75 // Lead with comma if not the first element
76 if index != 0 {
77 write!(f, ", ")?;
78 }
79 write!(f, "{:.3}", value)?;
80 }
81 if self.len() > digits_to_display {
82 write!(f, "...")?;
83 }
84 write!(f, "])")
85 }
86}
87
88#[derive(Debug)]
89pub struct TextToEmbed<'a> {
90 pub text: &'a str,
91 pub digest: [u8; 32],
92}
93
94impl<'a> TextToEmbed<'a> {
95 pub fn new(text: &'a str) -> Self {
96 let digest = Sha256::digest(text.as_bytes());
97 Self {
98 text,
99 digest: digest.into(),
100 }
101 }
102}
103
104pub struct FakeEmbeddingProvider;
105
106impl EmbeddingProvider for FakeEmbeddingProvider {
107 fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
108 let embeddings = texts
109 .iter()
110 .map(|_text| {
111 let mut embedding = vec![0f32; 1536];
112 for i in 0..embedding.len() {
113 embedding[i] = i as f32;
114 }
115 Embedding::new(embedding)
116 })
117 .collect();
118 future::ready(Ok(embeddings)).boxed()
119 }
120
121 fn batch_size(&self) -> usize {
122 16
123 }
124}
125
#[cfg(test)]
mod test {
    use super::*;

    #[gpui::test]
    fn test_normalize_embedding() {
        let normalized = Embedding::new(vec![1.0, 1.0, 1.0]);
        let value: f32 = 1.0 / 3.0_f32.sqrt();
        assert_eq!(normalized, Embedding(vec![value; 3]));
    }

    /// `similarity` reports the best-scoring candidate's score and index, and falls
    /// back to `(0.0, 0)` when there are no candidates.
    #[gpui::test]
    fn test_similarity_picks_best_match() {
        let query = Embedding::new(vec![1.0, 0.0]);
        let candidates = [
            Embedding::new(vec![0.0, 1.0]),
            Embedding::new(vec![1.0, 0.0]),
        ];
        assert_eq!(query.similarity(&candidates), (1.0, 1));
        assert_eq!(query.similarity(&[]), (0.0, 0));
    }

    /// `Display` previews only the first three values and marks the truncation.
    #[gpui::test]
    fn test_display_truncates_preview() {
        let embedding = Embedding(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(
            embedding.to_string(),
            "Embedding(sized: 4; values: [1.000, 2.000, 3.000...])"
        );
    }
}