1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::AsyncReadExt;
4use gpui::executor::Background;
5use gpui::serde_json;
6use isahc::http::StatusCode;
7use isahc::prelude::Configurable;
8use isahc::{AsyncBody, Response};
9use lazy_static::lazy_static;
10use ordered_float::OrderedFloat;
11use parking_lot::Mutex;
12use parse_duration::parse;
13use postage::watch;
14use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
15use rusqlite::ToSql;
16use serde::{Deserialize, Serialize};
17use std::env;
18use std::ops::Add;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tiktoken_rs::{cl100k_base, CoreBPE};
22use util::http::{HttpClient, Request};
23
lazy_static! {
    // Read once from the environment at first use; `None` when unset.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // cl100k_base is the BPE tokenizer matching the ada-002 model requested
    // in `send_request` below.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
28
/// A dense `f32` vector embedding of a text span, stored as a bincode BLOB
/// in SQLite via the `FromSql`/`ToSql` impls below.
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(Vec<f32>);
31
32impl From<Vec<f32>> for Embedding {
33 fn from(value: Vec<f32>) -> Self {
34 Embedding(value)
35 }
36}
37
impl Embedding {
    /// Computes the dot product of two embeddings by performing a
    /// 1×len · len×1 matrix multiply with `matrixmultiply::sgemm`.
    ///
    /// # Panics
    /// Panics if the two embeddings have different lengths.
    ///
    /// NOTE(review): a dot product equals cosine similarity only when the
    /// vectors are unit-normalized — confirm upstream vectors are normalized.
    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
        let len = self.0.len();
        assert_eq!(len, other.0.len());

        let mut result = 0.0;
        // SAFETY: both input pointers are valid for `len` contiguous f32
        // reads (lengths asserted equal above), `result` is valid for the
        // single f32 write implied by the 1×1 output, and the row/column
        // strides passed match a 1×len row vector times a len×1 column
        // vector over contiguous storage.
        unsafe {
            matrixmultiply::sgemm(
                1,
                len,
                1,
                1.0,
                self.0.as_ptr(),
                len as isize,
                1,
                other.0.as_ptr(),
                1,
                len as isize,
                0.0,
                &mut result as *mut f32,
                1,
                1,
            );
        }
        OrderedFloat(result)
    }
}
65
66impl FromSql for Embedding {
67 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
68 let bytes = value.as_blob()?;
69 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
70 if embedding.is_err() {
71 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
72 }
73 Ok(Embedding(embedding.unwrap()))
74 }
75}
76
77impl ToSql for Embedding {
78 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
79 let bytes = bincode::serialize(&self.0)
80 .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
81 Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
82 }
83}
84
/// An `EmbeddingProvider` backed by the OpenAI embeddings HTTP API.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
    // Earliest known instant at which the OpenAI rate limit resets;
    // `None` when no rate limiting is currently in effect.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
92
/// JSON request body for the OpenAI `/v1/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
98
/// JSON response body from the OpenAI `/v1/embeddings` endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
104
/// A single embedding entry within an OpenAI response.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    // `index` and `object` are part of the wire format but are not read
    // by this module.
    index: usize,
    object: String,
}
111
/// Token accounting reported by the OpenAI API for a request.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
117
/// A source of text embeddings with batching, truncation, and rate-limit
/// reporting.
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Whether the provider has the credentials it needs to make requests.
    fn is_authenticated(&self) -> bool;
    /// Embeds each span in `spans`, returning one embedding per span, in order.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
    /// Maximum total token budget a single `embed_batch` call should carry.
    fn max_tokens_per_batch(&self) -> usize;
    /// Truncates `span` to the provider's per-input token limit, returning
    /// the (possibly shortened) text and its token count.
    fn truncate(&self, span: &str) -> (String, usize);
    /// The instant at which the provider's rate limit lifts, if currently
    /// rate limited.
    fn rate_limit_expiration(&self) -> Option<Instant>;
}
126
127pub struct DummyEmbeddings {}
128
129#[async_trait]
130impl EmbeddingProvider for DummyEmbeddings {
131 fn is_authenticated(&self) -> bool {
132 true
133 }
134 fn rate_limit_expiration(&self) -> Option<Instant> {
135 None
136 }
137 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
138 // 1024 is the OpenAI Embeddings size for ada models.
139 // the model we will likely be starting with.
140 let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
141 return Ok(vec![dummy_vec; spans.len()]);
142 }
143
144 fn max_tokens_per_batch(&self) -> usize {
145 OPENAI_INPUT_LIMIT
146 }
147
148 fn truncate(&self, span: &str) -> (String, usize) {
149 let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
150 let token_count = tokens.len();
151 let output = if token_count > OPENAI_INPUT_LIMIT {
152 tokens.truncate(OPENAI_INPUT_LIMIT);
153 let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
154 new_input.ok().unwrap_or_else(|| span.to_string())
155 } else {
156 span.to_string()
157 };
158
159 (output, tokens.len())
160 }
161}
162
163const OPENAI_INPUT_LIMIT: usize = 8190;
164
165impl OpenAIEmbeddings {
166 pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
167 let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
168 let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
169
170 OpenAIEmbeddings {
171 client,
172 executor,
173 rate_limit_count_rx,
174 rate_limit_count_tx,
175 }
176 }
177
178 fn resolve_rate_limit(&self) {
179 let reset_time = *self.rate_limit_count_tx.lock().borrow();
180
181 if let Some(reset_time) = reset_time {
182 if Instant::now() >= reset_time {
183 *self.rate_limit_count_tx.lock().borrow_mut() = None
184 }
185 }
186
187 log::trace!(
188 "resolving reset time: {:?}",
189 *self.rate_limit_count_tx.lock().borrow()
190 );
191 }
192
193 fn update_reset_time(&self, reset_time: Instant) {
194 let original_time = *self.rate_limit_count_tx.lock().borrow();
195
196 let updated_time = if let Some(original_time) = original_time {
197 if reset_time < original_time {
198 Some(reset_time)
199 } else {
200 Some(original_time)
201 }
202 } else {
203 Some(reset_time)
204 };
205
206 log::trace!("updating rate limit time: {:?}", updated_time);
207
208 *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
209 }
210 async fn send_request(
211 &self,
212 api_key: &str,
213 spans: Vec<&str>,
214 request_timeout: u64,
215 ) -> Result<Response<AsyncBody>> {
216 let request = Request::post("https://api.openai.com/v1/embeddings")
217 .redirect_policy(isahc::config::RedirectPolicy::Follow)
218 .timeout(Duration::from_secs(request_timeout))
219 .header("Content-Type", "application/json")
220 .header("Authorization", format!("Bearer {}", api_key))
221 .body(
222 serde_json::to_string(&OpenAIEmbeddingRequest {
223 input: spans.clone(),
224 model: "text-embedding-ada-002",
225 })
226 .unwrap()
227 .into(),
228 )?;
229
230 Ok(self.client.send(request).await?)
231 }
232}
233
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    fn is_authenticated(&self) -> bool {
        OPENAI_API_KEY.as_ref().is_some()
    }
    /// Total-token budget per batch request.
    /// NOTE(review): 50,000 appears to be a locally chosen budget, not a
    /// documented API constant — confirm against current OpenAI limits.
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }
    /// Truncates `span` to at most `OPENAI_INPUT_LIMIT` BPE tokens,
    /// returning the (possibly shortened) text and its token count.
    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            // Fall back to the original text if the truncated token
            // sequence fails to decode.
            OPENAI_BPE_TOKENIZER
                .decode(tokens.clone())
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

        (output, tokens.len())
    }

    /// Embeds `spans` via the OpenAI API, retrying up to `MAX_RETRIES`
    /// times: timeouts grow the per-request timeout, 429s sleep until the
    /// server-reported (or table-driven) backoff elapses, and any other
    /// non-OK status aborts immediately.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 15;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;
            // Incremented before the match: BACKOFF_SECONDS below is
            // indexed with `request_number - 1`, so it is in range 0..=3.
            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    // Give the server more time on the next attempt.
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If we complete a request successfully that was previously rate_limited
                    // resolve the rate limit
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    // Drain the body even though it is unused, so the
                    // connection can be reused cleanly.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Prefer the delay from the `x-ratelimit-reset-tokens`
                    // header when it parses; otherwise use the backoff table.
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // If we've previously rate limited, increment the duration but not the count
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    // Unexpected status: surface the body text in the error.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}
354
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    // Checks `Embedding::similarity` against hand-computed dot products and
    // a scalar reference implementation on random vectors.
    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        // Orthogonal vectors have zero dot product.
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        // 2*3 + 0*1 = 6.
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        for _ in 0..100 {
            // 1536 matches the dummy-embedding dimension used in this file.
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            // Compare at one decimal place to tolerate accumulated f32
            // rounding differences between sgemm and the scalar sum.
            assert_eq!(
                round_to_decimals(a.similarity(&b), 1),
                round_to_decimals(reference_dot(&a.0, &b.0), 1)
            );
        }

        // Rounds `n` to `decimal_places` decimal digits.
        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
            let factor = (10.0 as f32).powi(decimal_places);
            (n * factor).round() / factor
        }

        // Scalar dot-product reference implementation.
        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
        }
    }
}