1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::AsyncReadExt;
4use gpui::executor::Background;
5use gpui::serde_json;
6use isahc::http::StatusCode;
7use isahc::prelude::Configurable;
8use isahc::{AsyncBody, Response};
9use lazy_static::lazy_static;
10use ordered_float::OrderedFloat;
11use parking_lot::Mutex;
12use parse_duration::parse;
13use postage::watch;
14use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
15use rusqlite::ToSql;
16use serde::{Deserialize, Serialize};
17use std::env;
18use std::ops::Add;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tiktoken_rs::{cl100k_base, CoreBPE};
22use util::http::{HttpClient, Request};
23
lazy_static! {
    // Read once on first access; `None` when the env var is unset, which makes
    // `OpenAIEmbeddings::is_authenticated` return false.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // cl100k_base is the BPE vocabulary used for token counting/truncation here.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
28
/// A dense `f32` vector produced by an embedding model.
///
/// Persisted in SQLite as a bincode-serialized blob via the `FromSql`/`ToSql`
/// impls in this file.
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
31
32// This is needed for semantic index functionality
33// Unfortunately it has to live wherever the "Embedding" struct is created.
34// Keeping this in here though, introduces a 'rusqlite' dependency into AI
35// which is less than ideal
36impl FromSql for Embedding {
37 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
38 let bytes = value.as_blob()?;
39 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
40 if embedding.is_err() {
41 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
42 }
43 Ok(Embedding(embedding.unwrap()))
44 }
45}
46
47impl ToSql for Embedding {
48 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
49 let bytes = bincode::serialize(&self.0)
50 .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
51 Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
52 }
53}
54impl From<Vec<f32>> for Embedding {
55 fn from(value: Vec<f32>) -> Self {
56 Embedding(value)
57 }
58}
59
impl Embedding {
    /// Computes the dot product of two embeddings by issuing a 1×len by len×1
    /// matrix multiply through `matrixmultiply::sgemm`.
    ///
    /// # Panics
    /// Panics if the two embeddings differ in length.
    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
        let len = self.0.len();
        assert_eq!(len, other.0.len());

        let mut result = 0.0;
        // SAFETY: sgemm with m=1, k=len, n=1 reads exactly `len` f32s from each
        // input pointer and writes exactly one f32 to the output pointer. The
        // assert above guarantees both slices contain `len` elements, and
        // `result` is a valid, exclusively-borrowed output location.
        unsafe {
            matrixmultiply::sgemm(
                1,
                len,
                1,
                1.0,
                self.0.as_ptr(),
                len as isize,
                1,
                other.0.as_ptr(),
                1,
                len as isize,
                0.0,
                &mut result as *mut f32,
                1,
                1,
            );
        }
        OrderedFloat(result)
    }
}
87
88// impl FromSql for Embedding {
89// fn column_result(value: ValueRef) -> FromSqlResult<Self> {
90// let bytes = value.as_blob()?;
91// let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
92// if embedding.is_err() {
93// return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
94// }
95// Ok(Embedding(embedding.unwrap()))
96// }
97// }
98
99// impl ToSql for Embedding {
100// fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
101// let bytes = bincode::serialize(&self.0)
102// .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
103// Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
104// }
105// }
106
/// Embedding provider backed by the OpenAI `/v1/embeddings` HTTP API.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
    // Most recently known rate-limit expiration; `None` when not rate limited.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    // Sender side of the watch channel, behind a mutex so `&self` methods can
    // update it (and so clones of this struct share one sender).
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
114
/// JSON request body for the OpenAI embeddings endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

/// Top-level JSON response from the embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

/// One embedding result within the response; `index` ties it back to the
/// position of the corresponding input span.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

/// Token accounting reported by the API (used here for trace logging).
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
139
/// Common interface over embedding backends (OpenAI, dummy/test provider).
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Whether the provider has credentials and can serve requests.
    fn is_authenticated(&self) -> bool;
    /// Embeds each span, returning one embedding per input, in input order.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
    /// Maximum total token count allowed in a single `embed_batch` call.
    fn max_tokens_per_batch(&self) -> usize;
    /// Truncates `span` to the provider's per-span token limit; returns the
    /// (possibly shortened) text and its token count.
    fn truncate(&self, span: &str) -> (String, usize);
    /// When the currently active rate limit lifts, if any.
    fn rate_limit_expiration(&self) -> Option<Instant>;
}
148
149pub struct DummyEmbeddings {}
150
151#[async_trait]
152impl EmbeddingProvider for DummyEmbeddings {
153 fn is_authenticated(&self) -> bool {
154 true
155 }
156 fn rate_limit_expiration(&self) -> Option<Instant> {
157 None
158 }
159 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
160 // 1024 is the OpenAI Embeddings size for ada models.
161 // the model we will likely be starting with.
162 let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
163 return Ok(vec![dummy_vec; spans.len()]);
164 }
165
166 fn max_tokens_per_batch(&self) -> usize {
167 OPENAI_INPUT_LIMIT
168 }
169
170 fn truncate(&self, span: &str) -> (String, usize) {
171 let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
172 let token_count = tokens.len();
173 let output = if token_count > OPENAI_INPUT_LIMIT {
174 tokens.truncate(OPENAI_INPUT_LIMIT);
175 let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
176 new_input.ok().unwrap_or_else(|| span.to_string())
177 } else {
178 span.to_string()
179 };
180
181 (output, tokens.len())
182 }
183}
184
185const OPENAI_INPUT_LIMIT: usize = 8190;
186
187impl OpenAIEmbeddings {
188 pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
189 let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
190 let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
191
192 OpenAIEmbeddings {
193 client,
194 executor,
195 rate_limit_count_rx,
196 rate_limit_count_tx,
197 }
198 }
199
200 fn resolve_rate_limit(&self) {
201 let reset_time = *self.rate_limit_count_tx.lock().borrow();
202
203 if let Some(reset_time) = reset_time {
204 if Instant::now() >= reset_time {
205 *self.rate_limit_count_tx.lock().borrow_mut() = None
206 }
207 }
208
209 log::trace!(
210 "resolving reset time: {:?}",
211 *self.rate_limit_count_tx.lock().borrow()
212 );
213 }
214
215 fn update_reset_time(&self, reset_time: Instant) {
216 let original_time = *self.rate_limit_count_tx.lock().borrow();
217
218 let updated_time = if let Some(original_time) = original_time {
219 if reset_time < original_time {
220 Some(reset_time)
221 } else {
222 Some(original_time)
223 }
224 } else {
225 Some(reset_time)
226 };
227
228 log::trace!("updating rate limit time: {:?}", updated_time);
229
230 *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
231 }
232 async fn send_request(
233 &self,
234 api_key: &str,
235 spans: Vec<&str>,
236 request_timeout: u64,
237 ) -> Result<Response<AsyncBody>> {
238 let request = Request::post("https://api.openai.com/v1/embeddings")
239 .redirect_policy(isahc::config::RedirectPolicy::Follow)
240 .timeout(Duration::from_secs(request_timeout))
241 .header("Content-Type", "application/json")
242 .header("Authorization", format!("Bearer {}", api_key))
243 .body(
244 serde_json::to_string(&OpenAIEmbeddingRequest {
245 input: spans.clone(),
246 model: "text-embedding-ada-002",
247 })
248 .unwrap()
249 .into(),
250 )?;
251
252 Ok(self.client.send(request).await?)
253 }
254}
255
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    // Authenticated iff OPENAI_API_KEY was set in the environment at startup.
    fn is_authenticated(&self) -> bool {
        OPENAI_API_KEY.as_ref().is_some()
    }
    fn max_tokens_per_batch(&self) -> usize {
        // Total token budget per batched request (several spans share a call).
        50000
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }
    /// Truncates `span` to at most `OPENAI_INPUT_LIMIT` BPE tokens, returning
    /// the (possibly shortened) text and its final token count.
    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            // If decoding the truncated tokens fails, fall back to the raw span.
            OPENAI_BPE_TOKENIZER
                .decode(tokens.clone())
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

        (output, tokens.len())
    }

    /// Embeds `spans` with retry and backoff:
    /// - REQUEST_TIMEOUT: retry with a request timeout 5s longer;
    /// - TOO_MANY_REQUESTS: sleep for the server-suggested reset interval
    ///   (or a fixed backoff keyed by attempt number), publish the reset time
    ///   through the watch channel, then retry;
    /// - any other non-OK status returns an error immediately.
    /// Gives up with an error after `MAX_RETRIES` attempts.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 30;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;
            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    // Give slow requests progressively more time on retry.
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If we complete a request successfully that was previously rate_limited
                    // resolve the rate limit
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    // Drain the response body even though it is unused.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    let delay_duration = {
                        // Default backoff keyed by attempt number; overridden
                        // below by the server-provided reset header when that
                        // header is present and parseable.
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // If we've previously rate limited, increment the duration but not the count
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    // Unexpected status: surface it with the response body.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}
376
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    // Verifies that the unsafe sgemm-based dot product in
    // `Embedding::similarity` matches a plain iterator dot product, on
    // hand-picked cases and on random vectors.
    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        // Orthogonal vectors have similarity 0.
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        // 2*3 + 0*1 = 6.
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        // Random vectors at production dimensionality (1536). Compare with
        // 1-decimal rounding to absorb floating-point accumulation-order
        // differences between sgemm and the reference sum.
        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            assert_eq!(
                round_to_decimals(a.similarity(&b), 1),
                round_to_decimals(reference_dot(&a.0, &b.0), 1)
            );
        }

        // Rounds `n` to `decimal_places` digits for loose f32 comparison.
        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
            let factor = (10.0 as f32).powi(decimal_places);
            (n * factor).round() / factor
        }

        // Straightforward O(n) dot product used as ground truth.
        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
        }
    }
}