embedding.rs

  1mod cloud;
  2mod lmstudio;
  3mod ollama;
  4mod open_ai;
  5
  6pub use cloud::*;
  7pub use lmstudio::*;
  8pub use ollama::*;
  9pub use open_ai::*;
 10use sha2::{Digest, Sha256};
 11
 12use anyhow::Result;
 13use futures::{FutureExt, future::BoxFuture};
 14use serde::{Deserialize, Serialize};
 15use std::{fmt, future};
 16
 17/// Trait for embedding providers. Texts in, vectors out.
 18pub trait EmbeddingProvider: Sync + Send {
 19    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
 20    fn batch_size(&self) -> usize;
 21}
 22
 23#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
 24pub struct Embedding(Vec<f32>);
 25
 26impl Embedding {
 27    pub fn new(mut embedding: Vec<f32>) -> Self {
 28        let len = embedding.len();
 29        let mut norm = 0f32;
 30
 31        for i in 0..len {
 32            norm += embedding[i] * embedding[i];
 33        }
 34
 35        norm = norm.sqrt();
 36        for dimension in &mut embedding {
 37            *dimension /= norm;
 38        }
 39
 40        Self(embedding)
 41    }
 42
 43    fn len(&self) -> usize {
 44        self.0.len()
 45    }
 46
 47    pub fn similarity(&self, others: &[Embedding]) -> (f32, usize) {
 48        debug_assert!(others.iter().all(|other| self.0.len() == other.0.len()));
 49        others
 50            .iter()
 51            .enumerate()
 52            .map(|(index, other)| {
 53                let dot_product: f32 = self
 54                    .0
 55                    .iter()
 56                    .copied()
 57                    .zip(other.0.iter().copied())
 58                    .map(|(a, b)| a * b)
 59                    .sum();
 60                (dot_product, index)
 61            })
 62            .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
 63            .unwrap_or((0.0, 0))
 64    }
 65}
 66
 67impl fmt::Display for Embedding {
 68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 69        let digits_to_display = 3;
 70
 71        // Start the Embedding display format
 72        write!(f, "Embedding(sized: {}; values: [", self.len())?;
 73
 74        for (index, value) in self.0.iter().enumerate().take(digits_to_display) {
 75            // Lead with comma if not the first element
 76            if index != 0 {
 77                write!(f, ", ")?;
 78            }
 79            write!(f, "{:.3}", value)?;
 80        }
 81        if self.len() > digits_to_display {
 82            write!(f, "...")?;
 83        }
 84        write!(f, "])")
 85    }
 86}
 87
 88#[derive(Debug)]
 89pub struct TextToEmbed<'a> {
 90    pub text: &'a str,
 91    pub digest: [u8; 32],
 92}
 93
 94impl<'a> TextToEmbed<'a> {
 95    pub fn new(text: &'a str) -> Self {
 96        let digest = Sha256::digest(text.as_bytes());
 97        Self {
 98            text,
 99            digest: digest.into(),
100        }
101    }
102}
103
104pub struct FakeEmbeddingProvider;
105
106impl EmbeddingProvider for FakeEmbeddingProvider {
107    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
108        let embeddings = texts
109            .iter()
110            .map(|_text| {
111                let mut embedding = vec![0f32; 1536];
112                for i in 0..embedding.len() {
113                    embedding[i] = i as f32;
114                }
115                Embedding::new(embedding)
116            })
117            .collect();
118        future::ready(Ok(embeddings)).boxed()
119    }
120
121    fn batch_size(&self) -> usize {
122        16
123    }
124}
125
#[cfg(test)]
mod test {
    use super::*;

    /// `Embedding::new` on `[1, 1, 1]` must yield three components equal to
    /// `1 / sqrt(3)`.
    #[gpui::test]
    fn test_normalize_embedding() {
        let actual = Embedding::new(vec![1.0, 1.0, 1.0]);
        let component: f32 = 1.0 / 3.0_f32.sqrt();
        let expected = Embedding(vec![component; 3]);
        assert_eq!(actual, expected);
    }
}
136}