// embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::executor::Background;
  5use gpui::serde_json;
  6use isahc::http::StatusCode;
  7use isahc::prelude::Configurable;
  8use isahc::{AsyncBody, Response};
  9use lazy_static::lazy_static;
 10use parking_lot::Mutex;
 11use parse_duration::parse;
 12use postage::watch;
 13use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 14use rusqlite::ToSql;
 15use serde::{Deserialize, Serialize};
 16use std::env;
 17use std::ops::Add;
 18use std::sync::Arc;
 19use std::time::{Duration, Instant};
 20use tiktoken_rs::{cl100k_base, CoreBPE};
 21use util::http::{HttpClient, Request};
 22
lazy_static! {
    // Read once on first use; `None` when the env var is unset so callers can
    // surface a "no api key" error instead of panicking at startup.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // cl100k_base is the BPE tokenizer used by OpenAI's ada-002 embedding model.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
 27
/// A dense embedding vector: one `f32` per dimension.
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(Vec<f32>);
 30
 31impl From<Vec<f32>> for Embedding {
 32    fn from(value: Vec<f32>) -> Self {
 33        Embedding(value)
 34    }
 35}
 36
impl Embedding {
    /// Dot product of two embeddings of equal length.
    ///
    /// For unit-normalized vectors this equals cosine similarity.
    /// Panics if the two embeddings have different lengths.
    pub fn similarity(&self, other: &Self) -> f32 {
        let len = self.0.len();
        assert_eq!(len, other.0.len());

        let mut result = 0.0;
        // SAFETY: sgemm computes C = alpha*A*B + beta*C. With m=1, k=len, n=1
        // this is a (1×len)·(len×1) multiply — a dot product — written into
        // `result`. Both input pointers are valid for `len` reads (lengths
        // asserted equal above) and `result` is valid for a single write.
        unsafe {
            matrixmultiply::sgemm(
                1,
                len,
                1,
                1.0,
                self.0.as_ptr(),
                len as isize,
                1,
                other.0.as_ptr(),
                1,
                len as isize,
                0.0,
                &mut result as *mut f32,
                1,
                1,
            );
        }
        result
    }
}
 64
 65impl FromSql for Embedding {
 66    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 67        let bytes = value.as_blob()?;
 68        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 69        if embedding.is_err() {
 70            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 71        }
 72        Ok(Embedding(embedding.unwrap()))
 73    }
 74}
 75
 76impl ToSql for Embedding {
 77    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
 78        let bytes = bincode::serialize(&self.0)
 79            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
 80        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
 81    }
 82}
 83
/// [`EmbeddingProvider`] backed by OpenAI's embeddings HTTP API.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
    // When rate-limited, the watch holds the instant at which the limit is
    // expected to lift; `None` when no rate limit is in effect.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
 91
/// JSON request body for `POST /v1/embeddings`.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
 97
/// JSON response body from `POST /v1/embeddings`.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    // One entry per input span, in request order.
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
103
/// A single embedding entry in the API response.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    // Position of this embedding in the request's input list. Deserialized for
    // completeness; only `embedding` is consumed below.
    index: usize,
    object: String,
}
110
/// Token accounting returned by the API; used here for trace logging only.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
116
/// A source of text embeddings, batched for throughput.
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Embeds each span; the result is index-aligned with `spans`.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
    /// Upper bound on the summed token count of a single `embed_batch` call.
    fn max_tokens_per_batch(&self) -> usize;
    /// Truncates `span` to the provider's token limit, returning the (possibly
    /// shortened) text and its token count.
    fn truncate(&self, span: &str) -> (String, usize);
    /// When rate-limited, the instant at which the limit is expected to lift.
    fn rate_limit_expiration(&self) -> Option<Instant>;
}
124
/// No-op provider returning constant vectors; useful for tests and offline use.
pub struct DummyEmbeddings {}
126
127#[async_trait]
128impl EmbeddingProvider for DummyEmbeddings {
129    fn rate_limit_expiration(&self) -> Option<Instant> {
130        None
131    }
132    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
133        // 1024 is the OpenAI Embeddings size for ada models.
134        // the model we will likely be starting with.
135        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
136        return Ok(vec![dummy_vec; spans.len()]);
137    }
138
139    fn max_tokens_per_batch(&self) -> usize {
140        OPENAI_INPUT_LIMIT
141    }
142
143    fn truncate(&self, span: &str) -> (String, usize) {
144        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
145        let token_count = tokens.len();
146        let output = if token_count > OPENAI_INPUT_LIMIT {
147            tokens.truncate(OPENAI_INPUT_LIMIT);
148            let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
149            new_input.ok().unwrap_or_else(|| span.to_string())
150        } else {
151            span.to_string()
152        };
153
154        (output, tokens.len())
155    }
156}
157
/// Maximum input length, in BPE tokens, sent to the embedding model.
const OPENAI_INPUT_LIMIT: usize = 8190;
159
impl OpenAIEmbeddings {
    /// Creates a provider with no rate limit currently in effect.
    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));

        OpenAIEmbeddings {
            client,
            executor,
            rate_limit_count_rx,
            rate_limit_count_tx,
        }
    }

    /// Clears the stored rate-limit reset time once it has passed.
    fn resolve_rate_limit(&self) {
        let reset_time = *self.rate_limit_count_tx.lock().borrow();

        if let Some(reset_time) = reset_time {
            if Instant::now() >= reset_time {
                // NOTE(review): the mutex is released after the read above and
                // re-acquired here, so another task could update the value in
                // between — confirm this race is acceptable.
                *self.rate_limit_count_tx.lock().borrow_mut() = None
            }
        }

        log::trace!(
            "resolving reset time: {:?}",
            *self.rate_limit_count_tx.lock().borrow()
        );
    }

    /// Records `reset_time`, keeping the earlier of it and any already-stored
    /// reset time.
    fn update_reset_time(&self, reset_time: Instant) {
        let original_time = *self.rate_limit_count_tx.lock().borrow();

        let updated_time = if let Some(original_time) = original_time {
            if reset_time < original_time {
                Some(reset_time)
            } else {
                Some(original_time)
            }
        } else {
            Some(reset_time)
        };

        log::trace!("updating rate limit time: {:?}", updated_time);

        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
    }
    /// POSTs `spans` to the OpenAI embeddings endpoint with the given
    /// per-request timeout (seconds) and returns the raw HTTP response.
    async fn send_request(
        &self,
        api_key: &str,
        spans: Vec<&str>,
        request_timeout: u64,
    ) -> Result<Response<AsyncBody>> {
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .timeout(Duration::from_secs(request_timeout))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans.clone(),
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        Ok(self.client.send(request).await?)
    }
}
228
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    // Conservative cap on the summed token count of one batch request.
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }
    /// Truncates `span` to `OPENAI_INPUT_LIMIT` BPE tokens, returning the
    /// (possibly shortened) text and its token count. Falls back to the
    /// original span if the truncated tokens fail to decode.
    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            OPENAI_BPE_TOKENIZER
                .decode(tokens.clone())
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

        (output, tokens.len())
    }

    /// Embeds `spans` via the OpenAI API, retrying up to `MAX_RETRIES` times
    /// with escalating backoff on rate limits and a growing timeout on
    /// request timeouts. Requires the `OPENAI_API_KEY` env var.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 15;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;
            // Incremented before the match: BACKOFF_SECONDS below is indexed
            // with `request_number - 1`, so attempt 1 uses index 0.
            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    // Retry immediately with a longer timeout.
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If we complete a request successfully that was previously rate_limited
                    // resolve the rate limit
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Prefer the server-provided reset interval; fall back to
                    // the fixed backoff schedule for this attempt.
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // If we've previously rate limited, increment the duration but not the count
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    // Any other status is treated as fatal; include the body
                    // for diagnostics.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}
346
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    // Checks the unsafe sgemm-based dot product against hand-computed values
    // and a straightforward reference implementation.
    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        // Orthogonal vectors have zero dot product.
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        // 2*3 + 0*1 = 6.
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        // Randomized comparison at the ada-002 dimensionality (1536), rounded
        // to one decimal place to tolerate differing float summation order.
        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            assert_eq!(
                round_to_decimals(a.similarity(&b), 1),
                round_to_decimals(reference_dot(&a.0, &b.0), 1)
            );
        }

        fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
            let factor = (10.0 as f32).powi(decimal_places);
            (n * factor).round() / factor
        }

        // Naive dot product used as the reference implementation.
        fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
            a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
        }
    }
}