embedding.rs

  1mod lmstudio;
  2mod ollama;
  3mod open_ai;
  4
  5pub use lmstudio::*;
  6pub use ollama::*;
  7pub use open_ai::*;
  8use sha2::{Digest, Sha256};
  9
 10use anyhow::Result;
 11use futures::{FutureExt, future::BoxFuture};
 12use serde::{Deserialize, Serialize};
 13use std::{fmt, future};
 14
 15/// Trait for embedding providers. Texts in, vectors out.
 16pub trait EmbeddingProvider: Sync + Send {
 17    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
 18    fn batch_size(&self) -> usize;
 19}
 20
 21#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
 22pub struct Embedding(Vec<f32>);
 23
 24impl Embedding {
 25    pub fn new(mut embedding: Vec<f32>) -> Self {
 26        let len = embedding.len();
 27        let mut norm = 0f32;
 28
 29        for i in 0..len {
 30            norm += embedding[i] * embedding[i];
 31        }
 32
 33        norm = norm.sqrt();
 34        for dimension in &mut embedding {
 35            *dimension /= norm;
 36        }
 37
 38        Self(embedding)
 39    }
 40
 41    fn len(&self) -> usize {
 42        self.0.len()
 43    }
 44
 45    pub fn similarity(&self, others: &[Embedding]) -> (f32, usize) {
 46        debug_assert!(others.iter().all(|other| self.0.len() == other.0.len()));
 47        others
 48            .iter()
 49            .enumerate()
 50            .map(|(index, other)| {
 51                let dot_product: f32 = self
 52                    .0
 53                    .iter()
 54                    .copied()
 55                    .zip(other.0.iter().copied())
 56                    .map(|(a, b)| a * b)
 57                    .sum();
 58                (dot_product, index)
 59            })
 60            .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
 61            .unwrap_or((0.0, 0))
 62    }
 63}
 64
 65impl fmt::Display for Embedding {
 66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 67        let digits_to_display = 3;
 68
 69        // Start the Embedding display format
 70        write!(f, "Embedding(sized: {}; values: [", self.len())?;
 71
 72        for (index, value) in self.0.iter().enumerate().take(digits_to_display) {
 73            // Lead with comma if not the first element
 74            if index != 0 {
 75                write!(f, ", ")?;
 76            }
 77            write!(f, "{:.3}", value)?;
 78        }
 79        if self.len() > digits_to_display {
 80            write!(f, "...")?;
 81        }
 82        write!(f, "])")
 83    }
 84}
 85
 86#[derive(Debug)]
 87pub struct TextToEmbed<'a> {
 88    pub text: &'a str,
 89    pub digest: [u8; 32],
 90}
 91
 92impl<'a> TextToEmbed<'a> {
 93    pub fn new(text: &'a str) -> Self {
 94        let digest = Sha256::digest(text.as_bytes());
 95        Self {
 96            text,
 97            digest: digest.into(),
 98        }
 99    }
100}
101
/// Deterministic stand-in embedding provider; it ignores text contents and
/// always produces the same fixed vector.
pub struct FakeEmbeddingProvider;
103
104impl EmbeddingProvider for FakeEmbeddingProvider {
105    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
106        let embeddings = texts
107            .iter()
108            .map(|_text| {
109                let mut embedding = vec![0f32; 1536];
110                for i in 0..embedding.len() {
111                    embedding[i] = i as f32;
112                }
113                Embedding::new(embedding)
114            })
115            .collect();
116        future::ready(Ok(embeddings)).boxed()
117    }
118
119    fn batch_size(&self) -> usize {
120        16
121    }
122}
123
#[cfg(test)]
mod test {
    use super::*;

    /// `Embedding::new` must scale a vector to unit length.
    #[gpui::test]
    fn test_normalize_embedding() {
        let component: f32 = 1.0 / 3.0_f32.sqrt();
        let expected = Embedding(vec![component; 3]);

        assert_eq!(Embedding::new(vec![1.0, 1.0, 1.0]), expected);
    }
}