1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::AsyncReadExt;
4use gpui::executor::Background;
5use gpui::serde_json;
6use isahc::http::StatusCode;
7use isahc::prelude::Configurable;
8use isahc::{AsyncBody, Response};
9use lazy_static::lazy_static;
10use ordered_float::OrderedFloat;
11use parking_lot::Mutex;
12use parse_duration::parse;
13use postage::watch;
14use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
15use rusqlite::ToSql;
16use serde::{Deserialize, Serialize};
17use std::env;
18use std::ops::Add;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tiktoken_rs::{cl100k_base, CoreBPE};
22use util::http::{HttpClient, Request};
23
lazy_static! {
    // Read once on first access; `None` when the environment variable is unset.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // Shared cl100k_base BPE tokenizer used for token counting and truncation.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
28
29#[derive(Debug, PartialEq, Clone)]
30pub struct Embedding(pub Vec<f32>);
31
32// This is needed for semantic index functionality
33// Unfortunately it has to live wherever the "Embedding" struct is created.
34// Keeping this in here though, introduces a 'rusqlite' dependency into AI
35// which is less than ideal
36impl FromSql for Embedding {
37 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
38 let bytes = value.as_blob()?;
39 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
40 if embedding.is_err() {
41 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
42 }
43 Ok(Embedding(embedding.unwrap()))
44 }
45}
46
47impl ToSql for Embedding {
48 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
49 let bytes = bincode::serialize(&self.0)
50 .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
51 Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
52 }
53}
54impl From<Vec<f32>> for Embedding {
55 fn from(value: Vec<f32>) -> Self {
56 Embedding(value)
57 }
58}
59
impl Embedding {
    /// Computes the dot product of two embedding vectors via a 1 x len x 1
    /// matrix multiply. Panics if the vectors differ in length.
    ///
    /// NOTE(review): callers presumably pass unit-normalized vectors so this
    /// equals cosine similarity — confirm at the call sites.
    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
        let len = self.0.len();
        assert_eq!(len, other.0.len());

        let mut result = 0.0;
        // SAFETY: sgemm is invoked with m = 1, k = len, n = 1, so it reads
        // exactly `len` f32s from each input pointer and writes exactly one
        // f32 through the output pointer. The assert above guarantees both
        // slices hold `len` elements, and `result` is a valid, exclusively
        // borrowed f32 on the stack.
        unsafe {
            matrixmultiply::sgemm(
                1,   // m: rows of A
                len, // k: shared inner dimension
                1,   // n: columns of B
                1.0, // alpha: scale A*B by 1
                self.0.as_ptr(),
                len as isize, // A row stride
                1,            // A column stride
                other.0.as_ptr(),
                1,            // B row stride
                len as isize, // B column stride
                0.0,          // beta: discard any initial value of C
                &mut result as *mut f32,
                1, // C row stride
                1, // C column stride
            );
        }
        OrderedFloat(result)
    }
}
87
/// `EmbeddingProvider` backed by the OpenAI embeddings HTTP API.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
    // Receiver half of the rate-limit watch channel: `Some(instant)` while
    // requests are rate limited, holding when the limit is expected to lift.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    // Sender half, behind a mutex so `&self` methods can update the value.
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
95
/// JSON request body for the OpenAI `/v1/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
101
/// Top-level JSON response from the OpenAI embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
107
/// One embedding entry in the response's `data` array.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    // `index` and `object` are deserialized to match the API schema but are
    // not read anywhere in this file.
    index: usize,
    object: String,
}
114
/// Token accounting reported by the API for a request.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
120
/// Interface for services that turn text spans into embedding vectors.
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Whether credentials are available to make embedding requests.
    fn is_authenticated(&self) -> bool;
    /// Embeds a batch of spans, returning one embedding per span, in order.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
    /// Maximum total token count allowed in a single batch.
    fn max_tokens_per_batch(&self) -> usize;
    /// Truncates `span` to the provider's token limit, returning the
    /// (possibly shortened) text and its token count.
    fn truncate(&self, span: &str) -> (String, usize);
    /// When rate limited, the instant the limit is expected to lift.
    fn rate_limit_expiration(&self) -> Option<Instant>;
}
129
130pub struct DummyEmbeddings {}
131
132#[async_trait]
133impl EmbeddingProvider for DummyEmbeddings {
134 fn is_authenticated(&self) -> bool {
135 true
136 }
137 fn rate_limit_expiration(&self) -> Option<Instant> {
138 None
139 }
140 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
141 // 1024 is the OpenAI Embeddings size for ada models.
142 // the model we will likely be starting with.
143 let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
144 return Ok(vec![dummy_vec; spans.len()]);
145 }
146
147 fn max_tokens_per_batch(&self) -> usize {
148 OPENAI_INPUT_LIMIT
149 }
150
151 fn truncate(&self, span: &str) -> (String, usize) {
152 let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
153 let token_count = tokens.len();
154 let output = if token_count > OPENAI_INPUT_LIMIT {
155 tokens.truncate(OPENAI_INPUT_LIMIT);
156 let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
157 new_input.ok().unwrap_or_else(|| span.to_string())
158 } else {
159 span.to_string()
160 };
161
162 (output, tokens.len())
163 }
164}
165
166const OPENAI_INPUT_LIMIT: usize = 8190;
167
168impl OpenAIEmbeddings {
169 pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
170 let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
171 let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
172
173 OpenAIEmbeddings {
174 client,
175 executor,
176 rate_limit_count_rx,
177 rate_limit_count_tx,
178 }
179 }
180
181 fn resolve_rate_limit(&self) {
182 let reset_time = *self.rate_limit_count_tx.lock().borrow();
183
184 if let Some(reset_time) = reset_time {
185 if Instant::now() >= reset_time {
186 *self.rate_limit_count_tx.lock().borrow_mut() = None
187 }
188 }
189
190 log::trace!(
191 "resolving reset time: {:?}",
192 *self.rate_limit_count_tx.lock().borrow()
193 );
194 }
195
196 fn update_reset_time(&self, reset_time: Instant) {
197 let original_time = *self.rate_limit_count_tx.lock().borrow();
198
199 let updated_time = if let Some(original_time) = original_time {
200 if reset_time < original_time {
201 Some(reset_time)
202 } else {
203 Some(original_time)
204 }
205 } else {
206 Some(reset_time)
207 };
208
209 log::trace!("updating rate limit time: {:?}", updated_time);
210
211 *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
212 }
213 async fn send_request(
214 &self,
215 api_key: &str,
216 spans: Vec<&str>,
217 request_timeout: u64,
218 ) -> Result<Response<AsyncBody>> {
219 let request = Request::post("https://api.openai.com/v1/embeddings")
220 .redirect_policy(isahc::config::RedirectPolicy::Follow)
221 .timeout(Duration::from_secs(request_timeout))
222 .header("Content-Type", "application/json")
223 .header("Authorization", format!("Bearer {}", api_key))
224 .body(
225 serde_json::to_string(&OpenAIEmbeddingRequest {
226 input: spans.clone(),
227 model: "text-embedding-ada-002",
228 })
229 .unwrap()
230 .into(),
231 )?;
232
233 Ok(self.client.send(request).await?)
234 }
235}
236
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    fn is_authenticated(&self) -> bool {
        OPENAI_API_KEY.as_ref().is_some()
    }

    // Token budget across a whole batch; distinct from the per-input
    // OPENAI_INPUT_LIMIT enforced by `truncate`.
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }

    /// Truncates `span` to at most OPENAI_INPUT_LIMIT BPE tokens, returning
    /// the (possibly shortened) text and its resulting token count.
    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            // If decoding the truncated token list fails, fall back to the
            // original, untruncated span text.
            OPENAI_BPE_TOKENIZER
                .decode(tokens.clone())
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

        // `tokens` was truncated in place above, so this is the final count.
        (output, tokens.len())
    }

    /// Embeds `spans` via the OpenAI API, retrying up to MAX_RETRIES times:
    /// timeouts get a longer per-request timeout, 429s wait out the server's
    /// suggested (or a backoff-table) delay, other statuses fail immediately.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 15;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;

            // Incremented before the match so BACKOFF_SECONDS below is
            // indexed with `request_number - 1` (0..=3).
            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    // Give the server more time on the next attempt.
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If we complete a request successfully that was previously rate_limited
                    // resolve the rate limit
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    // Drain the body so the connection can be reused.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Prefer the server-suggested reset duration from the
                    // x-ratelimit-reset-tokens header; fall back to the
                    // backoff table when absent or unparseable.
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // If we've previously rate limited, increment the duration but not the count
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    // Any other status is treated as a hard failure; include
                    // the response body in the error for diagnostics.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}
358
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    // Checks the unsafe sgemm-based dot product against a plain iterator
    // reference implementation, both on hand-picked vectors and on random
    // 1536-dimensional inputs (compared at 1 decimal place to absorb
    // floating-point summation-order differences).
    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            assert_eq!(
                round_to_decimals(a.similarity(&b), 1),
                round_to_decimals(reference_dot(&a.0, &b.0), 1)
            );
        }

        // Rounds `n` to the given number of decimal places.
        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
            let factor = (10.0 as f32).powi(decimal_places);
            (n * factor).round() / factor
        }

        // Straightforward dot product used as the ground truth.
        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
        }
    }
}