//! embedding.rs — embedding vectors, their SQLite (de)serialization, and
//! embedding providers (OpenAI-backed and a dummy implementation).

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::executor::Background;
  5use gpui::serde_json;
  6use isahc::http::StatusCode;
  7use isahc::prelude::Configurable;
  8use isahc::{AsyncBody, Response};
  9use lazy_static::lazy_static;
 10use ordered_float::OrderedFloat;
 11use parking_lot::Mutex;
 12use parse_duration::parse;
 13use postage::watch;
 14use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 15use rusqlite::ToSql;
 16use serde::{Deserialize, Serialize};
 17use std::env;
 18use std::ops::Add;
 19use std::sync::Arc;
 20use std::time::{Duration, Instant};
 21use tiktoken_rs::{cl100k_base, CoreBPE};
 22use util::http::{HttpClient, Request};
 23
lazy_static! {
    // Read once at process startup; `None` when the variable is unset, which
    // makes `OpenAIEmbeddings::is_authenticated` report false.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // cl100k_base is the BPE tokenizer used by text-embedding-ada-002.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
 28
/// A dense embedding vector (one `f32` per dimension).
///
/// Newtype over `Vec<f32>` so similarity math and SQLite (de)serialization
/// can be attached to it.
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(Vec<f32>);

impl From<Vec<f32>> for Embedding {
    fn from(value: Vec<f32>) -> Self {
        Embedding(value)
    }
}
 37
impl Embedding {
    /// Dot-product similarity between two embeddings.
    ///
    /// Computed as a 1×len by len×1 matrix multiply via BLAS-style `sgemm`
    /// rather than a scalar loop, for speed.
    ///
    /// # Panics
    /// Panics if the two embeddings differ in length.
    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
        let len = self.0.len();
        assert_eq!(len, other.0.len());

        let mut result = 0.0;
        // SAFETY: `self.0` is a valid 1×len matrix, `other.0` a valid len×1
        // matrix (equal lengths asserted above), and `result` is a valid 1×1
        // output buffer, so every pointer/stride below stays in bounds.
        unsafe {
            matrixmultiply::sgemm(
                1,   // m: rows of A
                len, // k: cols of A == rows of B
                1,   // n: cols of B
                1.0, // alpha (scale on A*B)
                self.0.as_ptr(),
                len as isize, // row stride of A (irrelevant: m == 1)
                1,            // col stride of A: contiguous
                other.0.as_ptr(),
                1,            // row stride of B: contiguous
                len as isize, // col stride of B (irrelevant: n == 1)
                0.0,          // beta: overwrite the output
                &mut result as *mut f32,
                1, // row stride of C
                1, // col stride of C
            );
        }
        OrderedFloat(result)
    }
}
 65
 66impl FromSql for Embedding {
 67    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 68        let bytes = value.as_blob()?;
 69        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 70        if embedding.is_err() {
 71            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 72        }
 73        Ok(Embedding(embedding.unwrap()))
 74    }
 75}
 76
 77impl ToSql for Embedding {
 78    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
 79        let bytes = bincode::serialize(&self.0)
 80            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
 81        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
 82    }
 83}
 84
/// Embedding provider backed by the OpenAI embeddings HTTP API.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
    // Watch channel carrying the instant the current OpenAI rate limit is
    // expected to reset; `None` when no rate limit is in effect.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
 92
/// JSON request body for `POST /v1/embeddings`.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
 98
/// JSON response body from `POST /v1/embeddings`.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
104
/// One embedding entry in the `data` array of the API response.
/// Only `embedding` is consumed here; `index` and `object` are kept so the
/// payload deserializes faithfully.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}
111
/// Token accounting reported by the API; `total_tokens` is trace-logged
/// after each successful batch.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
117
/// Abstraction over embedding backends (OpenAI, dummy, ...).
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Whether the provider has the credentials it needs to serve requests.
    fn is_authenticated(&self) -> bool;
    /// Embeds every span in `spans`, returning one embedding per span, in order.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
    /// Maximum total token count allowed across a single `embed_batch` call.
    fn max_tokens_per_batch(&self) -> usize;
    /// Truncates `span` to the provider's per-span token limit; returns the
    /// (possibly shortened) text and its token count.
    fn truncate(&self, span: &str) -> (String, usize);
    /// When rate limited, the instant at which the limit is expected to lift.
    fn rate_limit_expiration(&self) -> Option<Instant>;
}
126
127pub struct DummyEmbeddings {}
128
129#[async_trait]
130impl EmbeddingProvider for DummyEmbeddings {
131    fn is_authenticated(&self) -> bool {
132        true
133    }
134    fn rate_limit_expiration(&self) -> Option<Instant> {
135        None
136    }
137    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
138        // 1024 is the OpenAI Embeddings size for ada models.
139        // the model we will likely be starting with.
140        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
141        return Ok(vec![dummy_vec; spans.len()]);
142    }
143
144    fn max_tokens_per_batch(&self) -> usize {
145        OPENAI_INPUT_LIMIT
146    }
147
148    fn truncate(&self, span: &str) -> (String, usize) {
149        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
150        let token_count = tokens.len();
151        let output = if token_count > OPENAI_INPUT_LIMIT {
152            tokens.truncate(OPENAI_INPUT_LIMIT);
153            let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
154            new_input.ok().unwrap_or_else(|| span.to_string())
155        } else {
156            span.to_string()
157        };
158
159        (output, tokens.len())
160    }
161}
162
163const OPENAI_INPUT_LIMIT: usize = 8190;
164
impl OpenAIEmbeddings {
    /// Creates a provider that talks to the OpenAI embeddings endpoint via
    /// `client`, using `executor` for backoff timers.
    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
        // Start with no rate limit in effect.
        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));

        OpenAIEmbeddings {
            client,
            executor,
            rate_limit_count_rx,
            rate_limit_count_tx,
        }
    }

    /// Clears the stored rate-limit reset time once it has passed.
    /// Called after a request succeeds while we believed we were limited.
    fn resolve_rate_limit(&self) {
        // Copy the current value out so the lock isn't held across the check.
        let reset_time = *self.rate_limit_count_tx.lock().borrow();

        if let Some(reset_time) = reset_time {
            if Instant::now() >= reset_time {
                *self.rate_limit_count_tx.lock().borrow_mut() = None
            }
        }

        log::trace!(
            "resolving reset time: {:?}",
            *self.rate_limit_count_tx.lock().borrow()
        );
    }

    /// Records `reset_time` as the rate-limit expiration, keeping whichever
    /// of the new and previously stored times is earlier.
    fn update_reset_time(&self, reset_time: Instant) {
        let original_time = *self.rate_limit_count_tx.lock().borrow();

        let updated_time = if let Some(original_time) = original_time {
            if reset_time < original_time {
                Some(reset_time)
            } else {
                Some(original_time)
            }
        } else {
            Some(reset_time)
        };

        log::trace!("updating rate limit time: {:?}", updated_time);

        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
    }

    /// Issues a single `POST /v1/embeddings` request for `spans` with the
    /// given per-request timeout (in seconds). Does not retry; retries and
    /// backoff are handled by `embed_batch`.
    async fn send_request(
        &self,
        api_key: &str,
        spans: Vec<&str>,
        request_timeout: u64,
    ) -> Result<Response<AsyncBody>> {
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .timeout(Duration::from_secs(request_timeout))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans.clone(),
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        Ok(self.client.send(request).await?)
    }
}
233
234#[async_trait]
235impl EmbeddingProvider for OpenAIEmbeddings {
236    fn is_authenticated(&self) -> bool {
237        OPENAI_API_KEY.as_ref().is_some()
238    }
239    fn max_tokens_per_batch(&self) -> usize {
240        50000
241    }
242
243    fn rate_limit_expiration(&self) -> Option<Instant> {
244        *self.rate_limit_count_rx.borrow()
245    }
246    fn truncate(&self, span: &str) -> (String, usize) {
247        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
248        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
249            tokens.truncate(OPENAI_INPUT_LIMIT);
250            OPENAI_BPE_TOKENIZER
251                .decode(tokens.clone())
252                .ok()
253                .unwrap_or_else(|| span.to_string())
254        } else {
255            span.to_string()
256        };
257
258        (output, tokens.len())
259    }
260
261    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
262        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
263        const MAX_RETRIES: usize = 4;
264
265        let api_key = OPENAI_API_KEY
266            .as_ref()
267            .ok_or_else(|| anyhow!("no api key"))?;
268
269        let mut request_number = 0;
270        let mut rate_limiting = false;
271        let mut request_timeout: u64 = 15;
272        let mut response: Response<AsyncBody>;
273        while request_number < MAX_RETRIES {
274            response = self
275                .send_request(
276                    api_key,
277                    spans.iter().map(|x| &**x).collect(),
278                    request_timeout,
279                )
280                .await?;
281            request_number += 1;
282
283            match response.status() {
284                StatusCode::REQUEST_TIMEOUT => {
285                    request_timeout += 5;
286                }
287                StatusCode::OK => {
288                    let mut body = String::new();
289                    response.body_mut().read_to_string(&mut body).await?;
290                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
291
292                    log::trace!(
293                        "openai embedding completed. tokens: {:?}",
294                        response.usage.total_tokens
295                    );
296
297                    // If we complete a request successfully that was previously rate_limited
298                    // resolve the rate limit
299                    if rate_limiting {
300                        self.resolve_rate_limit()
301                    }
302
303                    return Ok(response
304                        .data
305                        .into_iter()
306                        .map(|embedding| Embedding::from(embedding.embedding))
307                        .collect());
308                }
309                StatusCode::TOO_MANY_REQUESTS => {
310                    rate_limiting = true;
311                    let mut body = String::new();
312                    response.body_mut().read_to_string(&mut body).await?;
313
314                    let delay_duration = {
315                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
316                        if let Some(time_to_reset) =
317                            response.headers().get("x-ratelimit-reset-tokens")
318                        {
319                            if let Ok(time_str) = time_to_reset.to_str() {
320                                parse(time_str).unwrap_or(delay)
321                            } else {
322                                delay
323                            }
324                        } else {
325                            delay
326                        }
327                    };
328
329                    // If we've previously rate limited, increment the duration but not the count
330                    let reset_time = Instant::now().add(delay_duration);
331                    self.update_reset_time(reset_time);
332
333                    log::trace!(
334                        "openai rate limiting: waiting {:?} until lifted",
335                        &delay_duration
336                    );
337
338                    self.executor.timer(delay_duration).await;
339                }
340                _ => {
341                    let mut body = String::new();
342                    response.body_mut().read_to_string(&mut body).await?;
343                    return Err(anyhow!(
344                        "open ai bad request: {:?} {:?}",
345                        &response.status(),
346                        body
347                    ));
348                }
349            }
350        }
351        Err(anyhow!("openai max retries"))
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        // Orthogonal vectors have zero similarity.
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        // Non-trivial dot product: 2 * 3 + 0 * 1 = 6.
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        // Fuzz the sgemm path against a scalar reference implementation,
        // comparing at one decimal place to tolerate f32 accumulation
        // differences between the two orderings.
        const DIM: usize = 1536;
        for _ in 0..100 {
            let mut lhs = vec![0.; DIM];
            let mut rhs = vec![0.; DIM];
            for (l, r) in lhs.iter_mut().zip(rhs.iter_mut()) {
                *l = rng.gen();
                *r = rng.gen();
            }
            let lhs = Embedding::from(lhs);
            let rhs = Embedding::from(rhs);

            assert_eq!(
                round(lhs.similarity(&rhs), 1),
                round(naive_dot(&lhs.0, &rhs.0), 1)
            );
        }

        // Rounds `n` to the requested number of decimal places.
        fn round(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
            let factor = 10.0_f32.powi(decimal_places);
            (n * factor).round() / factor
        }

        // Plain scalar dot product used as ground truth.
        fn naive_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
            OrderedFloat(a.iter().zip(b.iter()).map(|(x, y)| x * y).sum())
        }
    }
}