1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::AsyncReadExt;
4use gpui::executor::Background;
5use gpui::{serde_json, ViewContext};
6use isahc::http::StatusCode;
7use isahc::prelude::Configurable;
8use isahc::{AsyncBody, Response};
9use lazy_static::lazy_static;
10use ordered_float::OrderedFloat;
11use parking_lot::Mutex;
12use parse_duration::parse;
13use postage::watch;
14use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
15use rusqlite::ToSql;
16use serde::{Deserialize, Serialize};
17use std::env;
18use std::ops::Add;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tiktoken_rs::{cl100k_base, CoreBPE};
22use util::http::{HttpClient, Request};
23use util::ResultExt;
24
25use crate::completion::OPENAI_API_URL;
26
lazy_static! {
    // Process-wide cl100k_base tokenizer, used to count and truncate span
    // tokens before sending text to the OpenAI embeddings API.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
30
/// A dense vector produced by an embedding model; a newtype over `Vec<f32>`.
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
33
34// This is needed for semantic index functionality
35// Unfortunately it has to live wherever the "Embedding" struct is created.
36// Keeping this in here though, introduces a 'rusqlite' dependency into AI
37// which is less than ideal
38impl FromSql for Embedding {
39 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
40 let bytes = value.as_blob()?;
41 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
42 if embedding.is_err() {
43 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
44 }
45 Ok(Embedding(embedding.unwrap()))
46 }
47}
48
49impl ToSql for Embedding {
50 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
51 let bytes = bincode::serialize(&self.0)
52 .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
53 Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
54 }
55}
56impl From<Vec<f32>> for Embedding {
57 fn from(value: Vec<f32>) -> Self {
58 Embedding(value)
59 }
60}
61
impl Embedding {
    /// Computes the dot product of `self` and `other` as a 1×len by len×1
    /// matrix multiply. Panics if the two embeddings differ in length.
    ///
    /// NOTE(review): this is cosine similarity only if the vectors are
    /// unit-normalized upstream — nothing here normalizes; confirm at the
    /// producer.
    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
        let len = self.0.len();
        assert_eq!(len, other.0.len());

        let mut result = 0.0;
        // SAFETY: both input pointers are valid for `len` contiguous f32s
        // (lengths asserted equal above), `result` is a valid 1x1 output, and
        // the strides passed below describe exactly those buffers.
        unsafe {
            matrixmultiply::sgemm(
                1,   // m: rows of A
                len, // k: shared inner dimension
                1,   // n: columns of B
                1.0, // alpha: use A*B unscaled
                self.0.as_ptr(),
                len as isize, // A row stride
                1,            // A column stride
                other.0.as_ptr(),
                1,            // B row stride
                len as isize, // B column stride
                0.0,          // beta: ignore initial contents of C
                &mut result as *mut f32,
                1, // C row stride
                1, // C column stride
            );
        }
        OrderedFloat(result)
    }
}
89
/// An `EmbeddingProvider` backed by OpenAI's embeddings HTTP API.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub api_key: Option<String>,
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
    // When rate limited, holds the instant at which the limit is expected to
    // lift; `None` when no rate limit is in effect.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
98
/// Request body for OpenAI's `/v1/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

/// Response body from OpenAI's `/v1/embeddings` endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

/// One embedding entry within the response's `data` array.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

/// Token accounting OpenAI reports alongside the embeddings.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
123
/// Abstraction over services that turn text spans into vector embeddings.
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Whether the provider currently has the credentials it needs.
    fn is_authenticated(&self) -> bool;
    /// Embeds each span, returning one embedding per input span, in order.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
    /// Maximum total token count allowed across one `embed_batch` call.
    fn max_tokens_per_batch(&self) -> usize;
    /// Truncates `span` to the provider's per-span token limit, returning the
    /// (possibly shortened) text and its token count.
    fn truncate(&self, span: &str) -> (String, usize);
    /// The instant at which an active rate limit expires, if any.
    fn rate_limit_expiration(&self) -> Option<Instant>;
}
132
133pub struct DummyEmbeddings {}
134
135#[async_trait]
136impl EmbeddingProvider for DummyEmbeddings {
137 fn is_authenticated(&self) -> bool {
138 true
139 }
140 fn rate_limit_expiration(&self) -> Option<Instant> {
141 None
142 }
143 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
144 // 1024 is the OpenAI Embeddings size for ada models.
145 // the model we will likely be starting with.
146 let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
147 return Ok(vec![dummy_vec; spans.len()]);
148 }
149
150 fn max_tokens_per_batch(&self) -> usize {
151 OPENAI_INPUT_LIMIT
152 }
153
154 fn truncate(&self, span: &str) -> (String, usize) {
155 let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
156 let token_count = tokens.len();
157 let output = if token_count > OPENAI_INPUT_LIMIT {
158 tokens.truncate(OPENAI_INPUT_LIMIT);
159 let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
160 new_input.ok().unwrap_or_else(|| span.to_string())
161 } else {
162 span.to_string()
163 };
164
165 (output, tokens.len())
166 }
167}
168
169const OPENAI_INPUT_LIMIT: usize = 8190;
170
impl OpenAIEmbeddings {
    /// Populates `api_key` from the `OPENAI_API_KEY` environment variable or,
    /// failing that, from the platform credential store. Does nothing if a
    /// key is already set or none can be found.
    pub fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
        if self.api_key.is_none() {
            let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
                Some(api_key)
            } else if let Some((_, api_key)) = cx
                .platform()
                .read_credentials(OPENAI_API_URL)
                .log_err()
                .flatten()
            {
                // Stored credentials are raw bytes; discard (but log) if they
                // are not valid UTF-8.
                String::from_utf8(api_key).log_err()
            } else {
                None
            };

            if let Some(api_key) = api_key {
                self.api_key = Some(api_key);
            }
        }
    }

    /// Creates a provider with no rate limit currently recorded.
    pub fn new(
        api_key: Option<String>,
        client: Arc<dyn HttpClient>,
        executor: Arc<Background>,
    ) -> Self {
        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));

        OpenAIEmbeddings {
            api_key,
            client,
            executor,
            rate_limit_count_rx,
            rate_limit_count_tx,
        }
    }

    /// Clears the recorded rate-limit reset time once it has passed.
    fn resolve_rate_limit(&self) {
        let reset_time = *self.rate_limit_count_tx.lock().borrow();

        if let Some(reset_time) = reset_time {
            if Instant::now() >= reset_time {
                *self.rate_limit_count_tx.lock().borrow_mut() = None
            }
        }

        log::trace!(
            "resolving reset time: {:?}",
            *self.rate_limit_count_tx.lock().borrow()
        );
    }

    /// Records `reset_time` as the rate-limit expiration, keeping the
    /// earlier of the new time and any previously recorded one.
    fn update_reset_time(&self, reset_time: Instant) {
        let original_time = *self.rate_limit_count_tx.lock().borrow();

        let updated_time = if let Some(original_time) = original_time {
            if reset_time < original_time {
                Some(reset_time)
            } else {
                Some(original_time)
            }
        } else {
            Some(reset_time)
        };

        log::trace!("updating rate limit time: {:?}", updated_time);

        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
    }

    /// Issues one embeddings request for `spans` with the given timeout in
    /// seconds, following redirects. Returns the raw HTTP response; status
    /// handling is the caller's responsibility.
    // NOTE(review): the URL is hard-coded here even though OPENAI_API_URL is
    // imported above — confirm whether it should be derived from that const.
    async fn send_request(
        &self,
        api_key: &str,
        spans: Vec<&str>,
        request_timeout: u64,
    ) -> Result<Response<AsyncBody>> {
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .timeout(Duration::from_secs(request_timeout))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans.clone(),
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        Ok(self.client.send(request).await?)
    }
}
264
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

    // NOTE(review): 50000 appears to be an empirically chosen per-request
    // token budget — confirm against OpenAI's current batch limits.
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }

    /// Truncates `span` to at most `OPENAI_INPUT_LIMIT` tokens, returning the
    /// (possibly shortened) text and its token count.
    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            // Fall back to the original span if the truncated tokens fail to
            // decode rather than surfacing an error.
            OPENAI_BPE_TOKENIZER
                .decode(tokens.clone())
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

        (output, tokens.len())
    }

    /// Embeds `spans` via the OpenAI API, retrying up to `MAX_RETRIES` times.
    /// Timeouts grow the per-request timeout; 429s record a rate-limit reset
    /// time (from the `x-ratelimit-reset-tokens` header when parseable,
    /// otherwise a fixed backoff) and sleep before retrying; any other
    /// non-OK status aborts with an error.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        // Fixed backoff per attempt; indexed by (attempt number - 1).
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let Some(api_key) = self.api_key.clone() else {
            return Err(anyhow!("no open ai key provided"));
        };

        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 15;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    &api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;

            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    // Give the server more time on the next attempt.
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If we complete a request successfully that was previously rate_limited
                    // resolve the rate limit
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    // Drain the body even though we don't use it.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Prefer the server-provided reset interval when present
                    // and parseable; otherwise use the fixed backoff table.
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // If we've previously rate limited, increment the duration but not the count
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    // Unexpected status: surface the status and body to the caller.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}
387
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        // Orthogonal vectors have zero dot product.
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        // (2,0,...)·(3,1,...) = 6.
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        // Fuzz: the sgemm-backed similarity should agree with a naive dot
        // product on random 1536-dim vectors, to one decimal place (the
        // tolerance absorbs float summation-order differences).
        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            assert_eq!(
                round_to_decimals(a.similarity(&b), 1),
                round_to_decimals(reference_dot(&a.0, &b.0), 1)
            );
        }

        // Rounds `n` to `decimal_places` decimal places.
        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
            let factor = (10.0 as f32).powi(decimal_places);
            (n * factor).round() / factor
        }

        // Reference implementation: plain pairwise-product sum.
        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
        }
    }
}