embedding.rs

use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::serde_json;
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parse_duration::parse;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
use std::time::Duration;
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}

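/// A dense vector of `f32`s produced by an embedding model.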
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(Vec<f32>);

impl From<Vec<f32>> for Embedding {
    fn from(value: Vec<f32>) -> Self {
        Embedding(value)
    }
}

impl Embedding {
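    /// Returns the dot product of `self` and `other`.
    ///
    /// OpenAI embeddings are normalized to unit length, so for those vectors
    /// this equals the cosine similarity.
    ///
    /// Panics if the two embeddings have different lengths.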
    pub fn similarity(&self, other: &Self) -> f32 {
        let len = self.0.len();
        assert_eq!(len, other.0.len());

        let mut result = 0.0;
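        // Multiply `self` as a 1 × len row vector by `other` as a len × 1
        // column vector; the 1 × 1 product written to `result` is the dot
        // product.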
        unsafe {
            matrixmultiply::sgemm(
                1,
                len,
                1,
                1.0,
                self.0.as_ptr(),
                len as isize,
                1,
                other.0.as_ptr(),
                1,
                len as isize,
                0.0,
                &mut result as *mut f32,
                1,
                1,
            );
        }
        result
    }
}

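// Embeddings are stored in SQLite as bincode-serialized blobs of `f32`s.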
impl FromSql for Embedding {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        let embedding: Vec<f32> = bincode::deserialize(bytes)
            .map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
        Ok(Embedding(embedding))
    }
}

impl ToSql for Embedding {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
        let bytes = bincode::serialize(&self.0)
            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
    }
}

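/// An [`EmbeddingProvider`] that requests embeddings from the OpenAI API.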
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
}

#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}

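/// A source of text embeddings.
///
/// `max_tokens_per_batch` bounds the combined token count of a single
/// `embed_batch` call, and `truncate` clips one span to the provider's
/// per-input limit, returning the clipped text and its token count.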
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
    fn max_tokens_per_batch(&self) -> usize;
    fn truncate(&self, span: &str) -> (String, usize);
}

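/// A stand-in provider that returns constant vectors without making any
/// network requests.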
pub struct DummyEmbeddings {}

#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        // 1536 is the embedding dimension of OpenAI's text-embedding-ada-002,
        // the model we expect to start with.
        let dummy_vec = Embedding::from(vec![0.32_f32; 1536]);
        Ok(vec![dummy_vec; spans.len()])
    }

    fn max_tokens_per_batch(&self) -> usize {
        OPENAI_INPUT_LIMIT
    }

    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let token_count = tokens.len();
        let output = if token_count > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
            new_input.ok().unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

        (output, tokens.len())
    }
}

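/// Per-input token limit applied before a span is sent to the OpenAI
/// embeddings API.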
const OPENAI_INPUT_LIMIT: usize = 8190;

impl OpenAIEmbeddings {
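    /// Sends a single embeddings request for `spans` to the OpenAI API using
    /// the `text-embedding-ada-002` model and the given timeout in seconds.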
    async fn send_request(
        &self,
        api_key: &str,
        spans: Vec<&str>,
        request_timeout: u64,
    ) -> Result<Response<AsyncBody>> {
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .timeout(Duration::from_secs(request_timeout))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans.clone(),
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        Ok(self.client.send(request).await?)
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
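    // Upper bound on the combined token count of one batched request.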
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            OPENAI_BPE_TOKENIZER
                .decode(tokens.clone())
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

        // Report the token count of the text actually returned, which may be
        // smaller than the original span's count after truncation.
        (output, tokens.len())
    }

    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
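        // Retry transient failures up to MAX_RETRIES times: timeouts get a
        // progressively longer request timeout, and rate-limit responses wait
        // for the duration in `x-ratelimit-reset-tokens` when it parses,
        // falling back to the backoff schedule below.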
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let mut request_number = 0;
        let mut request_timeout: u64 = 10;
        while request_number < MAX_RETRIES {
            let mut response = self
                .send_request(
                    api_key,
                    spans.iter().map(|x| x.as_str()).collect(),
                    request_timeout,
                )
                .await?;
            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "openai embedding request failed: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries exceeded"))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

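        // Fuzz check: the gemm-based similarity should agree with a plain dot
        // product on random vectors, up to rounding.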
        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            assert_eq!(
                round_to_decimals(a.similarity(&b), 1),
                round_to_decimals(reference_dot(&a.0, &b.0), 1)
            );
        }

        fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
            let factor = 10.0_f32.powi(decimal_places);
            (n * factor).round() / factor
        }

        fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
            a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
        }
    }
}