1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::AsyncReadExt;
4use gpui::executor::Background;
5use gpui::{serde_json, AppContext};
6use isahc::http::StatusCode;
7use isahc::prelude::Configurable;
8use isahc::{AsyncBody, Response};
9use lazy_static::lazy_static;
10use ordered_float::OrderedFloat;
11use parking_lot::Mutex;
12use parse_duration::parse;
13use postage::watch;
14use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
15use rusqlite::ToSql;
16use serde::{Deserialize, Serialize};
17use std::env;
18use std::ops::Add;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tiktoken_rs::{cl100k_base, CoreBPE};
22use util::http::{HttpClient, Request};
23use util::ResultExt;
24
25use crate::completion::OPENAI_API_URL;
26
lazy_static! {
    // Shared cl100k_base BPE tokenizer (the encoding used by OpenAI's
    // ada-002 era models); built once on first use since construction
    // loads the full token vocabulary.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
30
/// A dense embedding vector (newtype over `Vec<f32>`) as returned by an
/// embedding provider.
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
33
34// This is needed for semantic index functionality
35// Unfortunately it has to live wherever the "Embedding" struct is created.
36// Keeping this in here though, introduces a 'rusqlite' dependency into AI
37// which is less than ideal
38impl FromSql for Embedding {
39 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
40 let bytes = value.as_blob()?;
41 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
42 if embedding.is_err() {
43 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
44 }
45 Ok(Embedding(embedding.unwrap()))
46 }
47}
48
49impl ToSql for Embedding {
50 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
51 let bytes = bincode::serialize(&self.0)
52 .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
53 Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
54 }
55}
56impl From<Vec<f32>> for Embedding {
57 fn from(value: Vec<f32>) -> Self {
58 Embedding(value)
59 }
60}
61
impl Embedding {
    /// Computes the dot product of two embeddings via a BLAS-style `sgemm`
    /// call (a 1×len by len×1 matrix multiply).
    ///
    /// Panics if the two embeddings have different lengths. Note this is a
    /// raw dot product, not cosine similarity — vectors are assumed to be
    /// comparable as-is (presumably unit-normalized upstream; confirm).
    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
        let len = self.0.len();
        assert_eq!(len, other.0.len());

        let mut result = 0.0;
        // SAFETY: both pointers are valid for `len` contiguous f32s (the
        // assert above guarantees equal lengths), `result` is a valid
        // pointer to a single f32, and the stride arguments describe
        // exactly those buffers: A is 1×len (row stride len, col stride 1),
        // B is len×1 (row stride 1, col stride len), C is 1×1.
        unsafe {
            matrixmultiply::sgemm(
                1,   // m: rows of A / C
                len, // k: shared inner dimension
                1,   // n: cols of B / C
                1.0, // alpha: scale on A*B
                self.0.as_ptr(),
                len as isize,
                1,
                other.0.as_ptr(),
                1,
                len as isize,
                0.0, // beta: C is overwritten, not accumulated
                &mut result as *mut f32,
                1,
                1,
            );
        }
        OrderedFloat(result)
    }
}
89
/// Embedding provider backed by OpenAI's embeddings HTTP API, with
/// rate-limit tracking shared across clones via a watch channel.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
    // Receiver side of the rate-limit state: `Some(t)` means requests are
    // rate-limited until instant `t`; `None` means no active limit.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    // Sender is wrapped in a Mutex so `&self` methods can update the state.
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
97
/// JSON request body for OpenAI's `/v1/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
103
/// JSON response body from OpenAI's `/v1/embeddings` endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
109
/// One embedding entry within the response's `data` array.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}
116
/// Token accounting reported by the API for a single request.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
122
/// Abstraction over embedding backends (OpenAI, dummy/test implementations).
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Returns an API key for the provider, or `None` if unavailable.
    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
    /// Embeds a batch of text spans; returns one embedding per span,
    /// in the same order.
    async fn embed_batch(
        &self,
        spans: Vec<String>,
        api_key: Option<String>,
    ) -> Result<Vec<Embedding>>;
    /// Maximum number of tokens the provider accepts in a single batch.
    fn max_tokens_per_batch(&self) -> usize;
    /// Truncates `span` to the provider's per-input token limit; returns
    /// the (possibly shortened) text and its token count.
    fn truncate(&self, span: &str) -> (String, usize);
    /// If the provider is currently rate-limited, the instant at which the
    /// limit is expected to lift.
    fn rate_limit_expiration(&self) -> Option<Instant>;
}
135
136pub struct DummyEmbeddings {}
137
138#[async_trait]
139impl EmbeddingProvider for DummyEmbeddings {
140 fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
141 Some("Dummy API KEY".to_string())
142 }
143 fn rate_limit_expiration(&self) -> Option<Instant> {
144 None
145 }
146 async fn embed_batch(
147 &self,
148 spans: Vec<String>,
149 _api_key: Option<String>,
150 ) -> Result<Vec<Embedding>> {
151 // 1024 is the OpenAI Embeddings size for ada models.
152 // the model we will likely be starting with.
153 let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
154 return Ok(vec![dummy_vec; spans.len()]);
155 }
156
157 fn max_tokens_per_batch(&self) -> usize {
158 OPENAI_INPUT_LIMIT
159 }
160
161 fn truncate(&self, span: &str) -> (String, usize) {
162 let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
163 let token_count = tokens.len();
164 let output = if token_count > OPENAI_INPUT_LIMIT {
165 tokens.truncate(OPENAI_INPUT_LIMIT);
166 let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
167 new_input.ok().unwrap_or_else(|| span.to_string())
168 } else {
169 span.to_string()
170 };
171
172 (output, tokens.len())
173 }
174}
175
176const OPENAI_INPUT_LIMIT: usize = 8190;
177
178impl OpenAIEmbeddings {
179 pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
180 let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
181 let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
182
183 OpenAIEmbeddings {
184 client,
185 executor,
186 rate_limit_count_rx,
187 rate_limit_count_tx,
188 }
189 }
190
191 fn resolve_rate_limit(&self) {
192 let reset_time = *self.rate_limit_count_tx.lock().borrow();
193
194 if let Some(reset_time) = reset_time {
195 if Instant::now() >= reset_time {
196 *self.rate_limit_count_tx.lock().borrow_mut() = None
197 }
198 }
199
200 log::trace!(
201 "resolving reset time: {:?}",
202 *self.rate_limit_count_tx.lock().borrow()
203 );
204 }
205
206 fn update_reset_time(&self, reset_time: Instant) {
207 let original_time = *self.rate_limit_count_tx.lock().borrow();
208
209 let updated_time = if let Some(original_time) = original_time {
210 if reset_time < original_time {
211 Some(reset_time)
212 } else {
213 Some(original_time)
214 }
215 } else {
216 Some(reset_time)
217 };
218
219 log::trace!("updating rate limit time: {:?}", updated_time);
220
221 *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
222 }
223 async fn send_request(
224 &self,
225 api_key: &str,
226 spans: Vec<&str>,
227 request_timeout: u64,
228 ) -> Result<Response<AsyncBody>> {
229 let request = Request::post("https://api.openai.com/v1/embeddings")
230 .redirect_policy(isahc::config::RedirectPolicy::Follow)
231 .timeout(Duration::from_secs(request_timeout))
232 .header("Content-Type", "application/json")
233 .header("Authorization", format!("Bearer {}", api_key))
234 .body(
235 serde_json::to_string(&OpenAIEmbeddingRequest {
236 input: spans.clone(),
237 model: "text-embedding-ada-002",
238 })
239 .unwrap()
240 .into(),
241 )?;
242
243 Ok(self.client.send(request).await?)
244 }
245}
246
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    /// Resolves an API key: the `OPENAI_API_KEY` environment variable wins;
    /// otherwise falls back to platform-stored credentials keyed by
    /// `OPENAI_API_URL`.
    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
        if let Ok(api_key) = env::var("OPENAI_API_KEY") {
            Some(api_key)
        } else if let Some((_, api_key)) = cx
            .platform()
            .read_credentials(OPENAI_API_URL)
            .log_err()
            .flatten()
        {
            String::from_utf8(api_key).log_err()
        } else {
            None
        }
    }

    // Batch budget in tokens — larger than the 8190 per-input limit, so a
    // batch may carry multiple spans. NOTE(review): source of this figure
    // is not visible here; confirm against OpenAI's batch limits.
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    /// Current rate-limit expiry, if any, as last observed on the channel.
    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }

    /// Truncates `span` to at most `OPENAI_INPUT_LIMIT` BPE tokens; returns
    /// the (possibly shortened) text and its final token count.
    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            // Decoding truncated tokens can fail mid-codepoint; fall back
            // to the original text rather than erroring.
            OPENAI_BPE_TOKENIZER
                .decode(tokens.clone())
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

        (output, tokens.len())
    }

    /// Embeds `spans` with up to `MAX_RETRIES` attempts, backing off on
    /// 429s (honoring the `x-ratelimit-reset-tokens` header when parseable)
    /// and extending the timeout on 408s. Returns one embedding per span.
    async fn embed_batch(
        &self,
        spans: Vec<String>,
        api_key: Option<String>,
    ) -> Result<Vec<Embedding>> {
        // Fallback backoff per retry (seconds), indexed by attempt number.
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let Some(api_key) = api_key else {
            return Err(anyhow!("no open ai key provided"));
        };

        let mut request_number = 0;
        let mut rate_limiting = false;
        // Starting timeout; grows by 5s on each 408 response.
        let mut request_timeout: u64 = 15;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    &api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;

            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If we complete a request successfully that was previously rate_limited
                    // resolve the rate limit
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Prefer the server-advertised reset delay; fall back to
                    // the static backoff table. `request_number - 1` is safe:
                    // the counter was incremented above, so it is >= 1.
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // If we've previously rate limited, increment the duration but not the count
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    // Any other status is treated as fatal — surface the
                    // status code and body without retrying.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}
384
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        // Orthogonal vectors: dot product is zero.
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        // Overlapping first component: 2 * 3 = 6.
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        // Cross-check the sgemm-backed dot product against a naive
        // reference implementation on random full-size vectors, comparing
        // only to one decimal place to tolerate float rounding.
        for _ in 0..100 {
            let size = 1536;
            let mut lhs = vec![0.; size];
            let mut rhs = vec![0.; size];
            for (l, r) in lhs.iter_mut().zip(rhs.iter_mut()) {
                *l = rng.gen();
                *r = rng.gen();
            }
            let lhs = Embedding::from(lhs);
            let rhs = Embedding::from(rhs);

            assert_eq!(
                round_to_decimals(lhs.similarity(&rhs), 1),
                round_to_decimals(reference_dot(&lhs.0, &rhs.0), 1)
            );
        }

        // Rounds an OrderedFloat to the given number of decimal places.
        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
            let factor = 10.0_f32.powi(decimal_places);
            (n * factor).round() / factor
        }

        // Naive dot product used as ground truth.
        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
            OrderedFloat(a.iter().zip(b).map(|(x, y)| x * y).sum())
        }
    }
}