// embedding.rs — embedding generation, storage, and similarity for the semantic index.

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::executor::Background;
  5use gpui::{serde_json, AppContext};
  6use isahc::http::StatusCode;
  7use isahc::prelude::Configurable;
  8use isahc::{AsyncBody, Response};
  9use lazy_static::lazy_static;
 10use ordered_float::OrderedFloat;
 11use parking_lot::Mutex;
 12use parse_duration::parse;
 13use postage::watch;
 14use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 15use rusqlite::ToSql;
 16use serde::{Deserialize, Serialize};
 17use std::env;
 18use std::ops::Add;
 19use std::sync::Arc;
 20use std::time::{Duration, Instant};
 21use tiktoken_rs::{cl100k_base, CoreBPE};
 22use util::http::{HttpClient, Request};
 23use util::ResultExt;
 24
 25use crate::completion::OPENAI_API_URL;
 26
lazy_static! {
    // Shared tokenizer for OpenAI's `cl100k_base` encoding, built once on first use.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
 30
/// A dense embedding vector produced by an [`EmbeddingProvider`].
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
 33
 34// This is needed for semantic index functionality
 35// Unfortunately it has to live wherever the "Embedding" struct is created.
 36// Keeping this in here though, introduces a 'rusqlite' dependency into AI
 37// which is less than ideal
 38impl FromSql for Embedding {
 39    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 40        let bytes = value.as_blob()?;
 41        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 42        if embedding.is_err() {
 43            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 44        }
 45        Ok(Embedding(embedding.unwrap()))
 46    }
 47}
 48
 49impl ToSql for Embedding {
 50    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
 51        let bytes = bincode::serialize(&self.0)
 52            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
 53        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
 54    }
 55}
 56impl From<Vec<f32>> for Embedding {
 57    fn from(value: Vec<f32>) -> Self {
 58        Embedding(value)
 59    }
 60}
 61
impl Embedding {
    /// Computes the dot product of two embeddings of equal length via a
    /// (1 x len) * (len x 1) BLAS-style matrix multiply.
    ///
    /// # Panics
    /// Panics if the two embeddings have different lengths.
    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
        let len = self.0.len();
        assert_eq!(len, other.0.len());

        let mut result = 0.0;
        // SAFETY: both input pointers are valid for `len` f32 reads (lengths
        // asserted equal above), the output pointer is valid for a single
        // write, and the row/column strides describe a (1 x len) row vector
        // times a (len x 1) column vector producing one scalar.
        unsafe {
            matrixmultiply::sgemm(
                1,
                len,
                1,
                1.0,
                self.0.as_ptr(),
                len as isize,
                1,
                other.0.as_ptr(),
                1,
                len as isize,
                0.0,
                &mut result as *mut f32,
                1,
                1,
            );
        }
        OrderedFloat(result)
    }
}
 89
/// [`EmbeddingProvider`] backed by the OpenAI embeddings HTTP API.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
    // Watch pair holding the Instant at which an active rate limit is
    // expected to lift, or None when not rate limited. The sender is wrapped
    // in Arc<Mutex<…>> so clones of this struct share one writer.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
 97
/// JSON request body for the OpenAI `/v1/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
103
/// Top-level JSON response from the OpenAI `/v1/embeddings` endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
109
/// One embedding entry within an OpenAI embeddings response.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}
116
/// Token accounting reported by the OpenAI embeddings API.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
122
/// A source of text embeddings for the semantic index.
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Looks up the provider's API key (e.g. from the environment or the
    /// platform credential store). Returns None when unavailable.
    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
    /// Embeds a batch of spans, returning one embedding per span in order.
    async fn embed_batch(
        &self,
        spans: Vec<String>,
        api_key: Option<String>,
    ) -> Result<Vec<Embedding>>;
    /// Maximum total tokens the provider accepts in a single batch.
    fn max_tokens_per_batch(&self) -> usize;
    /// Truncates `span` to the provider's per-input token limit, returning
    /// the (possibly shortened) text and its token count.
    fn truncate(&self, span: &str) -> (String, usize);
    /// When rate limited, the Instant at which the limit is expected to
    /// lift; None otherwise.
    fn rate_limit_expiration(&self) -> Option<Instant>;
}
135
/// A no-network [`EmbeddingProvider`] returning constant vectors; useful for tests.
pub struct DummyEmbeddings {}
137
138#[async_trait]
139impl EmbeddingProvider for DummyEmbeddings {
140    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
141        Some("Dummy API KEY".to_string())
142    }
143    fn rate_limit_expiration(&self) -> Option<Instant> {
144        None
145    }
146    async fn embed_batch(
147        &self,
148        spans: Vec<String>,
149        _api_key: Option<String>,
150    ) -> Result<Vec<Embedding>> {
151        // 1024 is the OpenAI Embeddings size for ada models.
152        // the model we will likely be starting with.
153        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
154        return Ok(vec![dummy_vec; spans.len()]);
155    }
156
157    fn max_tokens_per_batch(&self) -> usize {
158        OPENAI_INPUT_LIMIT
159    }
160
161    fn truncate(&self, span: &str) -> (String, usize) {
162        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
163        let token_count = tokens.len();
164        let output = if token_count > OPENAI_INPUT_LIMIT {
165            tokens.truncate(OPENAI_INPUT_LIMIT);
166            let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
167            new_input.ok().unwrap_or_else(|| span.to_string())
168        } else {
169            span.to_string()
170        };
171
172        (output, tokens.len())
173    }
174}
175
// Maximum tokens allowed in a single embedding input.
// NOTE(review): presumably chosen to sit just under the model's documented
// input limit — confirm against the OpenAI embeddings docs.
const OPENAI_INPUT_LIMIT: usize = 8190;
177
178impl OpenAIEmbeddings {
179    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
180        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
181        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
182
183        OpenAIEmbeddings {
184            client,
185            executor,
186            rate_limit_count_rx,
187            rate_limit_count_tx,
188        }
189    }
190
191    fn resolve_rate_limit(&self) {
192        let reset_time = *self.rate_limit_count_tx.lock().borrow();
193
194        if let Some(reset_time) = reset_time {
195            if Instant::now() >= reset_time {
196                *self.rate_limit_count_tx.lock().borrow_mut() = None
197            }
198        }
199
200        log::trace!(
201            "resolving reset time: {:?}",
202            *self.rate_limit_count_tx.lock().borrow()
203        );
204    }
205
206    fn update_reset_time(&self, reset_time: Instant) {
207        let original_time = *self.rate_limit_count_tx.lock().borrow();
208
209        let updated_time = if let Some(original_time) = original_time {
210            if reset_time < original_time {
211                Some(reset_time)
212            } else {
213                Some(original_time)
214            }
215        } else {
216            Some(reset_time)
217        };
218
219        log::trace!("updating rate limit time: {:?}", updated_time);
220
221        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
222    }
223    async fn send_request(
224        &self,
225        api_key: &str,
226        spans: Vec<&str>,
227        request_timeout: u64,
228    ) -> Result<Response<AsyncBody>> {
229        let request = Request::post("https://api.openai.com/v1/embeddings")
230            .redirect_policy(isahc::config::RedirectPolicy::Follow)
231            .timeout(Duration::from_secs(request_timeout))
232            .header("Content-Type", "application/json")
233            .header("Authorization", format!("Bearer {}", api_key))
234            .body(
235                serde_json::to_string(&OpenAIEmbeddingRequest {
236                    input: spans.clone(),
237                    model: "text-embedding-ada-002",
238                })
239                .unwrap()
240                .into(),
241            )?;
242
243        Ok(self.client.send(request).await?)
244    }
245}
246
247#[async_trait]
248impl EmbeddingProvider for OpenAIEmbeddings {
249    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
250        if let Ok(api_key) = env::var("OPENAI_API_KEY") {
251            Some(api_key)
252        } else if let Some((_, api_key)) = cx
253            .platform()
254            .read_credentials(OPENAI_API_URL)
255            .log_err()
256            .flatten()
257        {
258            String::from_utf8(api_key).log_err()
259        } else {
260            None
261        }
262    }
263
264    fn max_tokens_per_batch(&self) -> usize {
265        50000
266    }
267
268    fn rate_limit_expiration(&self) -> Option<Instant> {
269        *self.rate_limit_count_rx.borrow()
270    }
271    fn truncate(&self, span: &str) -> (String, usize) {
272        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
273        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
274            tokens.truncate(OPENAI_INPUT_LIMIT);
275            OPENAI_BPE_TOKENIZER
276                .decode(tokens.clone())
277                .ok()
278                .unwrap_or_else(|| span.to_string())
279        } else {
280            span.to_string()
281        };
282
283        (output, tokens.len())
284    }
285
286    async fn embed_batch(
287        &self,
288        spans: Vec<String>,
289        api_key: Option<String>,
290    ) -> Result<Vec<Embedding>> {
291        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
292        const MAX_RETRIES: usize = 4;
293
294        let Some(api_key) = api_key else {
295            return Err(anyhow!("no open ai key provided"));
296        };
297
298        let mut request_number = 0;
299        let mut rate_limiting = false;
300        let mut request_timeout: u64 = 15;
301        let mut response: Response<AsyncBody>;
302        while request_number < MAX_RETRIES {
303            response = self
304                .send_request(
305                    &api_key,
306                    spans.iter().map(|x| &**x).collect(),
307                    request_timeout,
308                )
309                .await?;
310
311            request_number += 1;
312
313            match response.status() {
314                StatusCode::REQUEST_TIMEOUT => {
315                    request_timeout += 5;
316                }
317                StatusCode::OK => {
318                    let mut body = String::new();
319                    response.body_mut().read_to_string(&mut body).await?;
320                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
321
322                    log::trace!(
323                        "openai embedding completed. tokens: {:?}",
324                        response.usage.total_tokens
325                    );
326
327                    // If we complete a request successfully that was previously rate_limited
328                    // resolve the rate limit
329                    if rate_limiting {
330                        self.resolve_rate_limit()
331                    }
332
333                    return Ok(response
334                        .data
335                        .into_iter()
336                        .map(|embedding| Embedding::from(embedding.embedding))
337                        .collect());
338                }
339                StatusCode::TOO_MANY_REQUESTS => {
340                    rate_limiting = true;
341                    let mut body = String::new();
342                    response.body_mut().read_to_string(&mut body).await?;
343
344                    let delay_duration = {
345                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
346                        if let Some(time_to_reset) =
347                            response.headers().get("x-ratelimit-reset-tokens")
348                        {
349                            if let Ok(time_str) = time_to_reset.to_str() {
350                                parse(time_str).unwrap_or(delay)
351                            } else {
352                                delay
353                            }
354                        } else {
355                            delay
356                        }
357                    };
358
359                    // If we've previously rate limited, increment the duration but not the count
360                    let reset_time = Instant::now().add(delay_duration);
361                    self.update_reset_time(reset_time);
362
363                    log::trace!(
364                        "openai rate limiting: waiting {:?} until lifted",
365                        &delay_duration
366                    );
367
368                    self.executor.timer(delay_duration).await;
369                }
370                _ => {
371                    let mut body = String::new();
372                    response.body_mut().read_to_string(&mut body).await?;
373                    return Err(anyhow!(
374                        "open ai bad request: {:?} {:?}",
375                        &response.status(),
376                        body
377                    ));
378                }
379            }
380        }
381        Err(anyhow!("openai max retries"))
382    }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    // Checks the unsafe sgemm-based dot product against hand-computed values
    // and a straightforward iterator-based reference implementation.
    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        // Orthogonal vectors have similarity 0.
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        // Simple dot product: 2*3 + 0*1 = 6.
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        // Randomized comparison against the reference dot product, rounded
        // to 1 decimal place to tolerate float accumulation-order differences.
        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            assert_eq!(
                round_to_decimals(a.similarity(&b), 1),
                round_to_decimals(reference_dot(&a.0, &b.0), 1)
            );
        }

        // Rounds to the given number of decimal places.
        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
            let factor = (10.0 as f32).powi(decimal_places);
            (n * factor).round() / factor
        }

        // Naive dot product used as ground truth for `similarity`.
        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
        }
    }
}