// embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::executor::Background;
  5use gpui::serde_json;
  6use isahc::http::StatusCode;
  7use isahc::prelude::Configurable;
  8use isahc::{AsyncBody, Response};
  9use lazy_static::lazy_static;
 10use ordered_float::OrderedFloat;
 11use parking_lot::Mutex;
 12use parse_duration::parse;
 13use postage::watch;
 14use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 15use rusqlite::ToSql;
 16use serde::{Deserialize, Serialize};
 17use std::env;
 18use std::ops::Add;
 19use std::sync::Arc;
 20use std::time::{Duration, Instant};
 21use tiktoken_rs::{cl100k_base, CoreBPE};
 22use util::http::{HttpClient, Request};
 23
lazy_static! {
    // API key read from the environment once on first access; `None` when unset.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // Shared cl100k_base BPE tokenizer, used below to count/truncate tokens
    // before sending spans to the embeddings endpoint.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
 28
/// A dense embedding vector: a newtype over the raw `f32` components.
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
 31
 32// This is needed for semantic index functionality
 33// Unfortunately it has to live wherever the "Embedding" struct is created.
 34// Keeping this in here though, introduces a 'rusqlite' dependency into AI
 35// which is less than ideal
 36impl FromSql for Embedding {
 37    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 38        let bytes = value.as_blob()?;
 39        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 40        if embedding.is_err() {
 41            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 42        }
 43        Ok(Embedding(embedding.unwrap()))
 44    }
 45}
 46
 47impl ToSql for Embedding {
 48    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
 49        let bytes = bincode::serialize(&self.0)
 50            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
 51        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
 52    }
 53}
 54impl From<Vec<f32>> for Embedding {
 55    fn from(value: Vec<f32>) -> Self {
 56        Embedding(value)
 57    }
 58}
 59
impl Embedding {
    /// Computes the dot product of `self` and `other` by expressing it as a
    /// 1×len by len×1 matrix multiply via `matrixmultiply::sgemm`.
    ///
    /// # Panics
    /// Panics if the two embeddings have different lengths.
    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
        let len = self.0.len();
        assert_eq!(len, other.0.len());

        let mut result = 0.0;
        // SAFETY: both input pointers are valid for `len` contiguous f32
        // reads (lengths asserted equal above), `result` is valid for a
        // single f32 write, and the row/column strides passed describe a
        // 1×len matrix times a len×1 matrix producing a 1×1 result.
        unsafe {
            matrixmultiply::sgemm(
                1,    // m: rows of A
                len,  // k: shared dimension
                1,    // n: columns of B
                1.0,  // alpha: no scaling of the product
                self.0.as_ptr(),
                len as isize, // A row stride
                1,            // A column stride
                other.0.as_ptr(),
                1,            // B row stride
                len as isize, // B column stride
                0.0,          // beta: ignore prior contents of `result`
                &mut result as *mut f32,
                1, // C row stride
                1, // C column stride
            );
        }
        OrderedFloat(result)
    }
}
 87
// NOTE(review): removed commented-out duplicates of the `FromSql` and
// `ToSql` impls; the live implementations appear earlier in this file.
106
/// `EmbeddingProvider` backed by the OpenAI embeddings HTTP API.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
    // Currently-known rate-limit reset time; `None` when not rate-limited.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    // Sender side of the watch, wrapped in a Mutex so `&self` methods can
    // update it (see `update_reset_time` / `resolve_rate_limit`).
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
114
/// JSON request body for OpenAI's `POST /v1/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
120
/// JSON response body returned by the OpenAI embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
126
/// One embedding entry inside an OpenAI response.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    // `index` and `object` are part of the wire format but are not read
    // anywhere in this file.
    index: usize,
    object: String,
}
133
/// Token accounting OpenAI reports for a request; only `total_tokens` is
/// read here (for trace logging).
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
139
/// A source of text embeddings (real API-backed or a local dummy).
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Whether the provider has the credentials it needs (e.g. an API key).
    fn is_authenticated(&self) -> bool;
    /// Embeds each span in `spans`, returning one `Embedding` per span.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
    /// Upper bound on the total token count of a single `embed_batch` call.
    fn max_tokens_per_batch(&self) -> usize;
    /// Truncates `span` to the provider's per-span token limit; returns the
    /// (possibly shortened) text together with its token count.
    fn truncate(&self, span: &str) -> (String, usize);
    /// When rate-limited, the instant the limit is expected to lift;
    /// otherwise `None`.
    fn rate_limit_expiration(&self) -> Option<Instant>;
}
148
/// Offline provider that returns a constant vector for every span (see its
/// `EmbeddingProvider` impl below); no network access or API key required.
pub struct DummyEmbeddings {}
150
151#[async_trait]
152impl EmbeddingProvider for DummyEmbeddings {
153    fn is_authenticated(&self) -> bool {
154        true
155    }
156    fn rate_limit_expiration(&self) -> Option<Instant> {
157        None
158    }
159    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
160        // 1024 is the OpenAI Embeddings size for ada models.
161        // the model we will likely be starting with.
162        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
163        return Ok(vec![dummy_vec; spans.len()]);
164    }
165
166    fn max_tokens_per_batch(&self) -> usize {
167        OPENAI_INPUT_LIMIT
168    }
169
170    fn truncate(&self, span: &str) -> (String, usize) {
171        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
172        let token_count = tokens.len();
173        let output = if token_count > OPENAI_INPUT_LIMIT {
174            tokens.truncate(OPENAI_INPUT_LIMIT);
175            let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
176            new_input.ok().unwrap_or_else(|| span.to_string())
177        } else {
178            span.to_string()
179        };
180
181        (output, tokens.len())
182    }
183}
184
// Per-span BPE token budget used when truncating input before embedding.
const OPENAI_INPUT_LIMIT: usize = 8190;
186
187impl OpenAIEmbeddings {
188    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
189        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
190        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
191
192        OpenAIEmbeddings {
193            client,
194            executor,
195            rate_limit_count_rx,
196            rate_limit_count_tx,
197        }
198    }
199
200    fn resolve_rate_limit(&self) {
201        let reset_time = *self.rate_limit_count_tx.lock().borrow();
202
203        if let Some(reset_time) = reset_time {
204            if Instant::now() >= reset_time {
205                *self.rate_limit_count_tx.lock().borrow_mut() = None
206            }
207        }
208
209        log::trace!(
210            "resolving reset time: {:?}",
211            *self.rate_limit_count_tx.lock().borrow()
212        );
213    }
214
215    fn update_reset_time(&self, reset_time: Instant) {
216        let original_time = *self.rate_limit_count_tx.lock().borrow();
217
218        let updated_time = if let Some(original_time) = original_time {
219            if reset_time < original_time {
220                Some(reset_time)
221            } else {
222                Some(original_time)
223            }
224        } else {
225            Some(reset_time)
226        };
227
228        log::trace!("updating rate limit time: {:?}", updated_time);
229
230        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
231    }
232    async fn send_request(
233        &self,
234        api_key: &str,
235        spans: Vec<&str>,
236        request_timeout: u64,
237    ) -> Result<Response<AsyncBody>> {
238        let request = Request::post("https://api.openai.com/v1/embeddings")
239            .redirect_policy(isahc::config::RedirectPolicy::Follow)
240            .timeout(Duration::from_secs(request_timeout))
241            .header("Content-Type", "application/json")
242            .header("Authorization", format!("Bearer {}", api_key))
243            .body(
244                serde_json::to_string(&OpenAIEmbeddingRequest {
245                    input: spans.clone(),
246                    model: "text-embedding-ada-002",
247                })
248                .unwrap()
249                .into(),
250            )?;
251
252        Ok(self.client.send(request).await?)
253    }
254}
255
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    fn is_authenticated(&self) -> bool {
        OPENAI_API_KEY.as_ref().is_some()
    }
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }
    /// Truncates `span` to at most `OPENAI_INPUT_LIMIT` BPE tokens, returning
    /// the (possibly shortened) text and its token count.
    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            // If the truncated token sequence fails to decode, fall back to
            // the original text rather than erroring.
            OPENAI_BPE_TOKENIZER
                .decode(tokens.clone())
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

        (output, tokens.len())
    }

    /// Embeds `spans` via the OpenAI API, retrying up to `MAX_RETRIES` times
    /// with escalating backoff on timeouts and rate limiting.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 30;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;
            // Incremented before the match, so inside it `request_number` is
            // the 1-based count of attempts made so far.
            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    // Give the next attempt a little more time.
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If we complete a request successfully that was previously rate_limited
                    // resolve the rate limit
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Prefer the server-provided reset delay when the header
                    // parses; otherwise use our own escalating backoff
                    // (attempt N uses BACKOFF_SECONDS[N - 1]).
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // If we've previously rate limited, increment the duration but not the count
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    // Any other status is treated as fatal: surface the body
                    // in the error rather than retrying.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}
376
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        // Orthogonal vectors: dot product is zero.
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        // Hand-computed case: 2*3 + 0*1 = 6.
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        // Fuzz: compare the sgemm-backed similarity against a scalar
        // reference dot product on random 1536-dimensional vectors.
        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            assert_eq!(
                round_to_decimals(a.similarity(&b), 1),
                round_to_decimals(reference_dot(&a.0, &b.0), 1)
            );
        }

        // Rounds to `decimal_places` digits to tolerate floating-point
        // accumulation differences between sgemm and the scalar reference.
        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
            let factor = (10.0 as f32).powi(decimal_places);
            (n * factor).round() / factor
        }

        // Naive scalar dot product used as ground truth.
        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
        }
    }
}