embedding.rs

  1mod cloud;
  2mod ollama;
  3mod open_ai;
  4
  5pub use cloud::*;
  6pub use ollama::*;
  7pub use open_ai::*;
  8use sha2::{Digest, Sha256};
  9
 10use anyhow::Result;
 11use futures::{future::BoxFuture, FutureExt};
 12use serde::{Deserialize, Serialize};
 13use std::{fmt, future};
 14
/// A dense embedding vector.
///
/// Vectors built through [`Embedding::new`] from input with a non-zero norm are scaled to unit
/// (L2) length, which is what lets [`Embedding::similarity`] use a plain dot product. Values
/// produced by `Default` (an empty vector) or by `Deserialize` do not pass through `new` and
/// are therefore not normalized by this type.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct Embedding(Vec<f32>);
 17
 18impl Embedding {
 19    pub fn new(mut embedding: Vec<f32>) -> Self {
 20        let len = embedding.len();
 21        let mut norm = 0f32;
 22
 23        for i in 0..len {
 24            norm += embedding[i] * embedding[i];
 25        }
 26
 27        norm = norm.sqrt();
 28        for dimension in &mut embedding {
 29            *dimension /= norm;
 30        }
 31
 32        Self(embedding)
 33    }
 34
 35    fn len(&self) -> usize {
 36        self.0.len()
 37    }
 38
 39    pub fn similarity(self, other: &Embedding) -> f32 {
 40        debug_assert_eq!(self.0.len(), other.0.len());
 41        self.0
 42            .iter()
 43            .copied()
 44            .zip(other.0.iter().copied())
 45            .map(|(a, b)| a * b)
 46            .sum()
 47    }
 48}
 49
 50impl fmt::Display for Embedding {
 51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 52        let digits_to_display = 3;
 53
 54        // Start the Embedding display format
 55        write!(f, "Embedding(sized: {}; values: [", self.len())?;
 56
 57        for (index, value) in self.0.iter().enumerate().take(digits_to_display) {
 58            // Lead with comma if not the first element
 59            if index != 0 {
 60                write!(f, ", ")?;
 61            }
 62            write!(f, "{:.3}", value)?;
 63        }
 64        if self.len() > digits_to_display {
 65            write!(f, "...")?;
 66        }
 67        write!(f, "])")
 68    }
 69}
 70
/// Trait for embedding providers. Texts in, vectors out.
///
/// The trait requires `Sync + Send`, so a provider value may be shared and sent between
/// threads.
pub trait EmbeddingProvider: Sync + Send {
    /// Embeds the entries of `texts`, returning a boxed future that resolves to the vectors
    /// (or an error).
    ///
    /// NOTE(review): callers presumably rely on the result holding exactly one [`Embedding`]
    /// per input, in input order. The fake provider in this file behaves that way; confirm the
    /// contract for every implementor.
    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
    /// Number of texts this provider wants per `embed` call.
    ///
    /// NOTE(review): presumably callers chunk their input to this size before calling
    /// `embed`. Nothing in this trait enforces it; confirm against callers.
    fn batch_size(&self) -> usize;
}
 76
/// A borrowed piece of text to be embedded, together with a digest of its contents.
#[derive(Debug)]
pub struct TextToEmbed<'a> {
    /// The text to embed (borrowed for `'a`, not owned).
    pub text: &'a str,
    /// When built via [`TextToEmbed::new`], the SHA-256 digest of `text`'s UTF-8 bytes.
    /// NOTE(review): presumably used to identify or cache already-embedded text; confirm
    /// with callers.
    pub digest: [u8; 32],
}
 82
 83impl<'a> TextToEmbed<'a> {
 84    pub fn new(text: &'a str) -> Self {
 85        let digest = Sha256::digest(text.as_bytes());
 86        Self {
 87            text,
 88            digest: digest.into(),
 89        }
 90    }
 91}
 92
/// Stand-in [`EmbeddingProvider`] whose `embed` ignores the text contents and returns the same
/// fixed, deterministic vector for every input. Presumably intended for tests; confirm.
pub struct FakeEmbeddingProvider;
 94
 95impl EmbeddingProvider for FakeEmbeddingProvider {
 96    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
 97        let embeddings = texts
 98            .iter()
 99            .map(|_text| {
100                let mut embedding = vec![0f32; 1536];
101                for i in 0..embedding.len() {
102                    embedding[i] = i as f32;
103                }
104                Embedding::new(embedding)
105            })
106            .collect();
107        future::ready(Ok(embeddings)).boxed()
108    }
109
110    fn batch_size(&self) -> usize {
111        16
112    }
113}
114
#[cfg(test)]
mod test {
    use super::*;

    /// A vector of equal components must come out as `1/sqrt(3)` in every dimension once it is
    /// scaled to unit length.
    #[gpui::test]
    fn test_normalize_embedding() {
        let expected_component: f32 = 1.0 / 3.0_f32.sqrt();
        let expected = Embedding(vec![expected_component; 3]);

        let actual = Embedding::new(vec![1.0; 3]);

        assert_eq!(actual, expected);
    }
}