// embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::executor::Background;
  5use gpui::{serde_json, ViewContext};
  6use isahc::http::StatusCode;
  7use isahc::prelude::Configurable;
  8use isahc::{AsyncBody, Response};
  9use lazy_static::lazy_static;
 10use ordered_float::OrderedFloat;
 11use parking_lot::Mutex;
 12use parse_duration::parse;
 13use postage::watch;
 14use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 15use rusqlite::ToSql;
 16use serde::{Deserialize, Serialize};
 17use std::env;
 18use std::ops::Add;
 19use std::sync::Arc;
 20use std::time::{Duration, Instant};
 21use tiktoken_rs::{cl100k_base, CoreBPE};
 22use util::http::{HttpClient, Request};
 23use util::ResultExt;
 24
 25use crate::completion::OPENAI_API_URL;
 26
lazy_static! {
    // Shared tokenizer for OpenAI's cl100k_base BPE encoding, built once on
    // first access. `unwrap` is acceptable here: the BPE data is compiled
    // into the binary, so construction can only fail on a build defect.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
 30
/// A dense embedding vector: one `f32` per dimension.
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
 33
 34// This is needed for semantic index functionality
 35// Unfortunately it has to live wherever the "Embedding" struct is created.
 36// Keeping this in here though, introduces a 'rusqlite' dependency into AI
 37// which is less than ideal
 38impl FromSql for Embedding {
 39    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 40        let bytes = value.as_blob()?;
 41        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 42        if embedding.is_err() {
 43            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 44        }
 45        Ok(Embedding(embedding.unwrap()))
 46    }
 47}
 48
 49impl ToSql for Embedding {
 50    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
 51        let bytes = bincode::serialize(&self.0)
 52            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
 53        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
 54    }
 55}
// Cheap wrapper conversion; takes ownership of the raw vector.
impl From<Vec<f32>> for Embedding {
    fn from(value: Vec<f32>) -> Self {
        Embedding(value)
    }
}
 61
impl Embedding {
    /// Dot product of two embeddings, computed as a 1×len by len×1 matrix
    /// multiply via `matrixmultiply::sgemm`.
    ///
    /// For unit-normalized vectors this equals cosine similarity; the test
    /// below checks it against a scalar reference dot product.
    ///
    /// # Panics
    /// Panics if the two embeddings have different lengths.
    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
        let len = self.0.len();
        assert_eq!(len, other.0.len());

        let mut result = 0.0;
        // SAFETY: sgemm reads a 1×len matrix from `self.0`, a len×1 matrix
        // from `other.0`, and writes exactly one f32 to `result`. The assert
        // above guarantees both slices hold `len` elements, and the strides
        // passed (row stride `len`, column stride 1 for `a`; row stride 1,
        // column stride `len` for `b`) describe those buffers exactly, so
        // every access is in bounds.
        unsafe {
            matrixmultiply::sgemm(
                1,
                len,
                1,
                1.0,
                self.0.as_ptr(),
                len as isize,
                1,
                other.0.as_ptr(),
                1,
                len as isize,
                0.0,
                &mut result as *mut f32,
                1,
                1,
            );
        }
        OrderedFloat(result)
    }
}
 89
/// An `EmbeddingProvider` backed by the OpenAI embeddings HTTP API.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub api_key: Option<String>,
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
    // Earliest instant at which the current rate limit is expected to lift
    // (`None` when not rate limited); read by `rate_limit_expiration`.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    // Sender side of the same watch channel, behind a mutex so `&self`
    // methods can update it.
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
 98
/// JSON request body sent to the OpenAI embeddings endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
104
/// Top-level JSON response from the embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
110
/// One entry in the response's `data` array.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    // `index` and `object` are declared so deserialization matches the API
    // payload, even though only `embedding` is read by this file.
    index: usize,
    object: String,
}
117
/// Token accounting reported by the API; used here only for trace logging.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
123
/// A source of text embeddings (the OpenAI API, or a test dummy).
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Whether the provider currently has credentials to make requests.
    fn is_authenticated(&self) -> bool;
    /// Embeds each span, returning one embedding per input, in order.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
    /// Maximum total tokens allowed across a single `embed_batch` call.
    fn max_tokens_per_batch(&self) -> usize;
    /// Truncates `span` to the model's input limit; returns the (possibly
    /// shortened) text and its token count.
    fn truncate(&self, span: &str) -> (String, usize);
    /// When rate limited, the instant at which the limit is expected to lift.
    fn rate_limit_expiration(&self) -> Option<Instant>;
}
132
133pub struct DummyEmbeddings {}
134
135#[async_trait]
136impl EmbeddingProvider for DummyEmbeddings {
137    fn is_authenticated(&self) -> bool {
138        true
139    }
140    fn rate_limit_expiration(&self) -> Option<Instant> {
141        None
142    }
143    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
144        // 1024 is the OpenAI Embeddings size for ada models.
145        // the model we will likely be starting with.
146        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
147        return Ok(vec![dummy_vec; spans.len()]);
148    }
149
150    fn max_tokens_per_batch(&self) -> usize {
151        OPENAI_INPUT_LIMIT
152    }
153
154    fn truncate(&self, span: &str) -> (String, usize) {
155        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
156        let token_count = tokens.len();
157        let output = if token_count > OPENAI_INPUT_LIMIT {
158            tokens.truncate(OPENAI_INPUT_LIMIT);
159            let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
160            new_input.ok().unwrap_or_else(|| span.to_string())
161        } else {
162            span.to_string()
163        };
164
165        (output, tokens.len())
166    }
167}
168
169const OPENAI_INPUT_LIMIT: usize = 8190;
170
171impl OpenAIEmbeddings {
172    pub fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
173        if self.api_key.is_none() {
174            let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
175                Some(api_key)
176            } else if let Some((_, api_key)) = cx
177                .platform()
178                .read_credentials(OPENAI_API_URL)
179                .log_err()
180                .flatten()
181            {
182                String::from_utf8(api_key).log_err()
183            } else {
184                None
185            };
186
187            if let Some(api_key) = api_key {
188                self.api_key = Some(api_key);
189            }
190        }
191    }
192    pub fn new(
193        api_key: Option<String>,
194        client: Arc<dyn HttpClient>,
195        executor: Arc<Background>,
196    ) -> Self {
197        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
198        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
199
200        OpenAIEmbeddings {
201            api_key,
202            client,
203            executor,
204            rate_limit_count_rx,
205            rate_limit_count_tx,
206        }
207    }
208
209    fn resolve_rate_limit(&self) {
210        let reset_time = *self.rate_limit_count_tx.lock().borrow();
211
212        if let Some(reset_time) = reset_time {
213            if Instant::now() >= reset_time {
214                *self.rate_limit_count_tx.lock().borrow_mut() = None
215            }
216        }
217
218        log::trace!(
219            "resolving reset time: {:?}",
220            *self.rate_limit_count_tx.lock().borrow()
221        );
222    }
223
224    fn update_reset_time(&self, reset_time: Instant) {
225        let original_time = *self.rate_limit_count_tx.lock().borrow();
226
227        let updated_time = if let Some(original_time) = original_time {
228            if reset_time < original_time {
229                Some(reset_time)
230            } else {
231                Some(original_time)
232            }
233        } else {
234            Some(reset_time)
235        };
236
237        log::trace!("updating rate limit time: {:?}", updated_time);
238
239        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
240    }
241    async fn send_request(
242        &self,
243        api_key: &str,
244        spans: Vec<&str>,
245        request_timeout: u64,
246    ) -> Result<Response<AsyncBody>> {
247        let request = Request::post("https://api.openai.com/v1/embeddings")
248            .redirect_policy(isahc::config::RedirectPolicy::Follow)
249            .timeout(Duration::from_secs(request_timeout))
250            .header("Content-Type", "application/json")
251            .header("Authorization", format!("Bearer {}", api_key))
252            .body(
253                serde_json::to_string(&OpenAIEmbeddingRequest {
254                    input: spans.clone(),
255                    model: "text-embedding-ada-002",
256                })
257                .unwrap()
258                .into(),
259            )?;
260
261        Ok(self.client.send(request).await?)
262    }
263}
264
265#[async_trait]
266impl EmbeddingProvider for OpenAIEmbeddings {
267    fn is_authenticated(&self) -> bool {
268        self.api_key.is_some()
269    }
270
271    fn max_tokens_per_batch(&self) -> usize {
272        50000
273    }
274
275    fn rate_limit_expiration(&self) -> Option<Instant> {
276        *self.rate_limit_count_rx.borrow()
277    }
278    fn truncate(&self, span: &str) -> (String, usize) {
279        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
280        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
281            tokens.truncate(OPENAI_INPUT_LIMIT);
282            OPENAI_BPE_TOKENIZER
283                .decode(tokens.clone())
284                .ok()
285                .unwrap_or_else(|| span.to_string())
286        } else {
287            span.to_string()
288        };
289
290        (output, tokens.len())
291    }
292
293    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
294        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
295        const MAX_RETRIES: usize = 4;
296
297        let Some(api_key) = self.api_key.clone() else {
298            return Err(anyhow!("no open ai key provided"));
299        };
300
301        let mut request_number = 0;
302        let mut rate_limiting = false;
303        let mut request_timeout: u64 = 15;
304        let mut response: Response<AsyncBody>;
305        while request_number < MAX_RETRIES {
306            response = self
307                .send_request(
308                    &api_key,
309                    spans.iter().map(|x| &**x).collect(),
310                    request_timeout,
311                )
312                .await?;
313
314            request_number += 1;
315
316            match response.status() {
317                StatusCode::REQUEST_TIMEOUT => {
318                    request_timeout += 5;
319                }
320                StatusCode::OK => {
321                    let mut body = String::new();
322                    response.body_mut().read_to_string(&mut body).await?;
323                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
324
325                    log::trace!(
326                        "openai embedding completed. tokens: {:?}",
327                        response.usage.total_tokens
328                    );
329
330                    // If we complete a request successfully that was previously rate_limited
331                    // resolve the rate limit
332                    if rate_limiting {
333                        self.resolve_rate_limit()
334                    }
335
336                    return Ok(response
337                        .data
338                        .into_iter()
339                        .map(|embedding| Embedding::from(embedding.embedding))
340                        .collect());
341                }
342                StatusCode::TOO_MANY_REQUESTS => {
343                    rate_limiting = true;
344                    let mut body = String::new();
345                    response.body_mut().read_to_string(&mut body).await?;
346
347                    let delay_duration = {
348                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
349                        if let Some(time_to_reset) =
350                            response.headers().get("x-ratelimit-reset-tokens")
351                        {
352                            if let Ok(time_str) = time_to_reset.to_str() {
353                                parse(time_str).unwrap_or(delay)
354                            } else {
355                                delay
356                            }
357                        } else {
358                            delay
359                        }
360                    };
361
362                    // If we've previously rate limited, increment the duration but not the count
363                    let reset_time = Instant::now().add(delay_duration);
364                    self.update_reset_time(reset_time);
365
366                    log::trace!(
367                        "openai rate limiting: waiting {:?} until lifted",
368                        &delay_duration
369                    );
370
371                    self.executor.timer(delay_duration).await;
372                }
373                _ => {
374                    let mut body = String::new();
375                    response.body_mut().read_to_string(&mut body).await?;
376                    return Err(anyhow!(
377                        "open ai bad request: {:?} {:?}",
378                        &response.status(),
379                        body
380                    ));
381                }
382            }
383        }
384        Err(anyhow!("openai max retries"))
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        // Orthogonal vectors have zero dot product.
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        // Plain dot product: 2*3 + 0*1 = 6.
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        // Random vectors must agree with a scalar reference implementation,
        // rounded to one decimal place since sgemm may reorder additions.
        for _ in 0..100 {
            let size = 1536;
            let (lhs, rhs): (Vec<f32>, Vec<f32>) =
                (0..size).map(|_| (rng.gen(), rng.gen())).unzip();
            let lhs = Embedding::from(lhs);
            let rhs = Embedding::from(rhs);

            assert_eq!(
                round_to_decimals(lhs.similarity(&rhs), 1),
                round_to_decimals(reference_dot(&lhs.0, &rhs.0), 1)
            );
        }

        // Rounds an OrderedFloat to the given number of decimal places.
        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
            let factor = 10.0_f32.powi(decimal_places);
            (n * factor).round() / factor
        }

        // Straightforward scalar dot product used as ground truth.
        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
            OrderedFloat(a.iter().zip(b).map(|(x, y)| x * y).sum())
        }
    }
}