1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::AsyncReadExt;
4use gpui::executor::Background;
5use gpui::serde_json;
6use isahc::http::StatusCode;
7use isahc::prelude::Configurable;
8use isahc::{AsyncBody, Response};
9use lazy_static::lazy_static;
10use parking_lot::Mutex;
11use parse_duration::parse;
12use postage::watch;
13use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
14use rusqlite::ToSql;
15use serde::{Deserialize, Serialize};
16use std::env;
17use std::ops::Add;
18use std::sync::Arc;
19use std::time::{Duration, Instant};
20use tiktoken_rs::{cl100k_base, CoreBPE};
21use util::http::{HttpClient, Request};
22
lazy_static! {
    // OpenAI API key, read once from the environment; `None` when unset.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // Shared cl100k_base BPE tokenizer used for token counting and truncation.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
27
/// A dense embedding vector, wrapping the raw `f32` components.
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(Vec<f32>);
30
31impl From<Vec<f32>> for Embedding {
32 fn from(value: Vec<f32>) -> Self {
33 Embedding(value)
34 }
35}
36
impl Embedding {
    /// Computes the dot product of two embeddings of equal length.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `other` have different lengths.
    pub fn similarity(&self, other: &Self) -> f32 {
        let len = self.0.len();
        assert_eq!(len, other.0.len());

        let mut result = 0.0;
        // SAFETY: both input pointers are valid for `len` consecutive f32
        // reads (lengths asserted equal above), and `result` is a valid
        // 1x1 output location. The arguments describe a (1 x len) * (len x 1)
        // matrix product — i.e. a dot product — written into `result`.
        unsafe {
            matrixmultiply::sgemm(
                1,                       // m: rows of A (and of C)
                len,                     // k: shared inner dimension
                1,                       // n: columns of B (and of C)
                1.0,                     // alpha scaling factor
                self.0.as_ptr(),         // A: 1 x len row vector
                len as isize,            // row stride of A
                1,                       // column stride of A
                other.0.as_ptr(),        // B: len x 1 column vector
                1,                       // row stride of B
                len as isize,            // column stride of B
                0.0,                     // beta: overwrite C rather than accumulate
                &mut result as *mut f32, // C: 1 x 1 result
                1,                       // row stride of C
                1,                       // column stride of C
            );
        }
        result
    }
}
64
65impl FromSql for Embedding {
66 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
67 let bytes = value.as_blob()?;
68 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
69 if embedding.is_err() {
70 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
71 }
72 Ok(Embedding(embedding.unwrap()))
73 }
74}
75
76impl ToSql for Embedding {
77 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
78 let bytes = bincode::serialize(&self.0)
79 .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
80 Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
81 }
82}
83
/// Embedding provider backed by the OpenAI embeddings HTTP API, with shared
/// rate-limit state used to pause requests after a 429 response.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
    // Receiver side of the rate-limit watch: `Some(instant)` while requests
    // should be deferred until `instant`, `None` when no limit is active.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    // Sender side, shared behind a mutex so `&self` methods can update it.
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
91
/// Request body sent to OpenAI's `/v1/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
97
/// Successful response body from the embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
103
/// A single embedding entry in the response's `data` array.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}
110
/// Token accounting reported by the API for a request.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
116
/// Abstraction over embedding backends (the OpenAI API or a local dummy).
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Embeds a batch of text spans, returning one embedding per span.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
    /// Maximum total tokens the provider accepts in a single batch.
    fn max_tokens_per_batch(&self) -> usize;
    /// Truncates `span` to the provider's token limit, returning the
    /// (possibly shortened) text together with its token count.
    fn truncate(&self, span: &str) -> (String, usize);
    /// When rate limited, the instant at which requests may resume.
    fn rate_limit_expiration(&self) -> Option<Instant>;
}
124
125pub struct DummyEmbeddings {}
126
127#[async_trait]
128impl EmbeddingProvider for DummyEmbeddings {
129 fn rate_limit_expiration(&self) -> Option<Instant> {
130 None
131 }
132 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
133 // 1024 is the OpenAI Embeddings size for ada models.
134 // the model we will likely be starting with.
135 let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
136 return Ok(vec![dummy_vec; spans.len()]);
137 }
138
139 fn max_tokens_per_batch(&self) -> usize {
140 OPENAI_INPUT_LIMIT
141 }
142
143 fn truncate(&self, span: &str) -> (String, usize) {
144 let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
145 let token_count = tokens.len();
146 let output = if token_count > OPENAI_INPUT_LIMIT {
147 tokens.truncate(OPENAI_INPUT_LIMIT);
148 let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
149 new_input.ok().unwrap_or_else(|| span.to_string())
150 } else {
151 span.to_string()
152 };
153
154 (output, tokens.len())
155 }
156}
157
158const OPENAI_INPUT_LIMIT: usize = 8190;
159
160impl OpenAIEmbeddings {
161 pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
162 let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
163 let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
164
165 OpenAIEmbeddings {
166 client,
167 executor,
168 rate_limit_count_rx,
169 rate_limit_count_tx,
170 }
171 }
172
173 fn resolve_rate_limit(&self) {
174 let reset_time = *self.rate_limit_count_tx.lock().borrow();
175
176 if let Some(reset_time) = reset_time {
177 if Instant::now() >= reset_time {
178 *self.rate_limit_count_tx.lock().borrow_mut() = None
179 }
180 }
181
182 log::trace!(
183 "resolving reset time: {:?}",
184 *self.rate_limit_count_tx.lock().borrow()
185 );
186 }
187
188 fn update_reset_time(&self, reset_time: Instant) {
189 let original_time = *self.rate_limit_count_tx.lock().borrow();
190
191 let updated_time = if let Some(original_time) = original_time {
192 if reset_time < original_time {
193 Some(reset_time)
194 } else {
195 Some(original_time)
196 }
197 } else {
198 Some(reset_time)
199 };
200
201 log::trace!("updating rate limit time: {:?}", updated_time);
202
203 *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
204 }
205 async fn send_request(
206 &self,
207 api_key: &str,
208 spans: Vec<&str>,
209 request_timeout: u64,
210 ) -> Result<Response<AsyncBody>> {
211 let request = Request::post("https://api.openai.com/v1/embeddings")
212 .redirect_policy(isahc::config::RedirectPolicy::Follow)
213 .timeout(Duration::from_secs(request_timeout))
214 .header("Content-Type", "application/json")
215 .header("Authorization", format!("Bearer {}", api_key))
216 .body(
217 serde_json::to_string(&OpenAIEmbeddingRequest {
218 input: spans.clone(),
219 model: "text-embedding-ada-002",
220 })
221 .unwrap()
222 .into(),
223 )?;
224
225 Ok(self.client.send(request).await?)
226 }
227}
228
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    /// Token budget per batched request sent to the API.
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    /// Current rate-limit reset instant, if a limit is active.
    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }

    /// Truncates `span` to at most `OPENAI_INPUT_LIMIT` BPE tokens,
    /// returning the (possibly shortened) text and its token count.
    /// Falls back to the original text if decoding fails.
    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            OPENAI_BPE_TOKENIZER
                .decode(tokens.clone())
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

        // `tokens.len()` reflects the count after any truncation above.
        (output, tokens.len())
    }

    /// Embeds `spans` via the OpenAI API, retrying up to `MAX_RETRIES` times
    /// with escalating backoff on 429s and a growing timeout on 408s.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 15;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;
            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    // Give the next attempt a little more time.
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If we complete a request successfully that was previously rate_limited
                    // resolve the rate limit
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    // Drain the body so the connection can be reused.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Prefer the server-provided reset delay; otherwise use
                    // the backoff entry for this attempt number.
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // If we've previously rate limited, increment the duration but not the count
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    // Any other status is treated as a hard failure; surface
                    // the status code and response body to the caller.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}
346
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    // Checks `Embedding::similarity` against hand-computed dot products and
    // against a scalar reference implementation on random vectors.
    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        // Orthogonal vectors have zero similarity.
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        // 2*3 + 0*1 = 6.
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        // Random vectors at the real embedding dimension; compare against the
        // naive dot product, rounded to absorb float accumulation-order error.
        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            assert_eq!(
                round_to_decimals(a.similarity(&b), 1),
                round_to_decimals(reference_dot(&a.0, &b.0), 1)
            );
        }

        // Rounds `n` to the given number of decimal places.
        fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
            let factor = (10.0 as f32).powi(decimal_places);
            (n * factor).round() / factor
        }

        // Scalar reference dot product.
        fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
            a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
        }
    }
}