1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::AsyncReadExt;
4use gpui::executor::Background;
5use gpui::serde_json;
6use isahc::http::StatusCode;
7use isahc::prelude::Configurable;
8use isahc::{AsyncBody, Response};
9use lazy_static::lazy_static;
10use parse_duration::parse;
11use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
12use rusqlite::ToSql;
13use serde::{Deserialize, Serialize};
14use std::env;
15use std::sync::Arc;
16use std::time::Duration;
17use tiktoken_rs::{cl100k_base, CoreBPE};
18use util::http::{HttpClient, Request};
19
20lazy_static! {
21 static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
22 static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
23}
24
/// An embedding vector: an ordered list of `f32` components.
///
/// The wrapped vector is private to this module; build values through
/// `Embedding::from(Vec<f32>)`.
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(Vec<f32>);

impl From<Vec<f32>> for Embedding {
    /// Wraps `value` without copying or validating its contents.
    fn from(value: Vec<f32>) -> Self {
        Self(value)
    }
}
33
34impl Embedding {
35 pub fn similarity(&self, other: &Self) -> f32 {
36 let len = self.0.len();
37 assert_eq!(len, other.0.len());
38
39 let mut result = 0.0;
40 unsafe {
41 matrixmultiply::sgemm(
42 1,
43 len,
44 1,
45 1.0,
46 self.0.as_ptr(),
47 len as isize,
48 1,
49 other.0.as_ptr(),
50 1,
51 len as isize,
52 0.0,
53 &mut result as *mut f32,
54 1,
55 1,
56 );
57 }
58 result
59 }
60}
61
62impl FromSql for Embedding {
63 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
64 let bytes = value.as_blob()?;
65 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
66 if embedding.is_err() {
67 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
68 }
69 Ok(Embedding(embedding.unwrap()))
70 }
71}
72
73impl ToSql for Embedding {
74 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
75 let bytes = bincode::serialize(&self.0)
76 .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
77 Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
78 }
79}
80
81#[derive(Clone)]
82pub struct OpenAIEmbeddings {
83 pub client: Arc<dyn HttpClient>,
84 pub executor: Arc<Background>,
85}
86
87#[derive(Serialize)]
88struct OpenAIEmbeddingRequest<'a> {
89 model: &'static str,
90 input: Vec<&'a str>,
91}
92
93#[derive(Deserialize)]
94struct OpenAIEmbeddingResponse {
95 data: Vec<OpenAIEmbedding>,
96 usage: OpenAIEmbeddingUsage,
97}
98
99#[derive(Debug, Deserialize)]
100struct OpenAIEmbedding {
101 embedding: Vec<f32>,
102 index: usize,
103 object: String,
104}
105
106#[derive(Deserialize)]
107struct OpenAIEmbeddingUsage {
108 prompt_tokens: usize,
109 total_tokens: usize,
110}
111
112#[async_trait]
113pub trait EmbeddingProvider: Sync + Send {
114 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
115 fn max_tokens_per_batch(&self) -> usize;
116 fn truncate(&self, span: &str) -> (String, usize);
117}
118
119pub struct DummyEmbeddings {}
120
121#[async_trait]
122impl EmbeddingProvider for DummyEmbeddings {
123 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
124 // 1024 is the OpenAI Embeddings size for ada models.
125 // the model we will likely be starting with.
126 let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
127 return Ok(vec![dummy_vec; spans.len()]);
128 }
129
130 fn max_tokens_per_batch(&self) -> usize {
131 OPENAI_INPUT_LIMIT
132 }
133
134 fn truncate(&self, span: &str) -> (String, usize) {
135 let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
136 let token_count = tokens.len();
137 let output = if token_count > OPENAI_INPUT_LIMIT {
138 tokens.truncate(OPENAI_INPUT_LIMIT);
139 let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
140 new_input.ok().unwrap_or_else(|| span.to_string())
141 } else {
142 span.to_string()
143 };
144
145 (output, tokens.len())
146 }
147}
148
/// Maximum number of tokens kept for a single span; longer spans are cut down
/// to this many tokens by the `truncate` implementations in this file.
const OPENAI_INPUT_LIMIT: usize = 8190;
150
151impl OpenAIEmbeddings {
152 async fn send_request(
153 &self,
154 api_key: &str,
155 spans: Vec<&str>,
156 request_timeout: u64,
157 ) -> Result<Response<AsyncBody>> {
158 let request = Request::post("https://api.openai.com/v1/embeddings")
159 .redirect_policy(isahc::config::RedirectPolicy::Follow)
160 .timeout(Duration::from_secs(request_timeout))
161 .header("Content-Type", "application/json")
162 .header("Authorization", format!("Bearer {}", api_key))
163 .body(
164 serde_json::to_string(&OpenAIEmbeddingRequest {
165 input: spans.clone(),
166 model: "text-embedding-ada-002",
167 })
168 .unwrap()
169 .into(),
170 )?;
171
172 Ok(self.client.send(request).await?)
173 }
174}
175
176#[async_trait]
177impl EmbeddingProvider for OpenAIEmbeddings {
178 fn max_tokens_per_batch(&self) -> usize {
179 50000
180 }
181
182 fn truncate(&self, span: &str) -> (String, usize) {
183 let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
184 let output = if tokens.len() > OPENAI_INPUT_LIMIT {
185 tokens.truncate(OPENAI_INPUT_LIMIT);
186 OPENAI_BPE_TOKENIZER
187 .decode(tokens.clone())
188 .ok()
189 .unwrap_or_else(|| span.to_string())
190 } else {
191 span.to_string()
192 };
193
194 (output, tokens.len())
195 }
196
197 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
198 const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
199 const MAX_RETRIES: usize = 4;
200
201 let api_key = OPENAI_API_KEY
202 .as_ref()
203 .ok_or_else(|| anyhow!("no api key"))?;
204
205 let mut request_number = 0;
206 let mut request_timeout: u64 = 15;
207 let mut response: Response<AsyncBody>;
208 while request_number < MAX_RETRIES {
209 response = self
210 .send_request(
211 api_key,
212 spans.iter().map(|x| &**x).collect(),
213 request_timeout,
214 )
215 .await?;
216 request_number += 1;
217
218 match response.status() {
219 StatusCode::REQUEST_TIMEOUT => {
220 request_timeout += 5;
221 }
222 StatusCode::OK => {
223 let mut body = String::new();
224 response.body_mut().read_to_string(&mut body).await?;
225 let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
226
227 log::trace!(
228 "openai embedding completed. tokens: {:?}",
229 response.usage.total_tokens
230 );
231
232 return Ok(response
233 .data
234 .into_iter()
235 .map(|embedding| Embedding::from(embedding.embedding))
236 .collect());
237 }
238 StatusCode::TOO_MANY_REQUESTS => {
239 let mut body = String::new();
240 response.body_mut().read_to_string(&mut body).await?;
241
242 let delay_duration = {
243 let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
244 if let Some(time_to_reset) =
245 response.headers().get("x-ratelimit-reset-tokens")
246 {
247 if let Ok(time_str) = time_to_reset.to_str() {
248 parse(time_str).unwrap_or(delay)
249 } else {
250 delay
251 }
252 } else {
253 delay
254 }
255 };
256
257 log::trace!(
258 "openai rate limiting: waiting {:?} until lifted",
259 &delay_duration
260 );
261
262 self.executor.timer(delay_duration).await;
263 }
264 _ => {
265 let mut body = String::new();
266 response.body_mut().read_to_string(&mut body).await?;
267 return Err(anyhow!(
268 "open ai bad request: {:?} {:?}",
269 &response.status(),
270 body
271 ));
272 }
273 }
274 }
275 Err(anyhow!("openai max retries"))
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Checks `Embedding::similarity` against hand-computed dot products and
    /// against a straightforward reference implementation on random vectors.
    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        // Largest accepted absolute difference between the sgemm-based result
        // and the reference. The two sum in different orders, so their `f32`
        // results can differ slightly; an absolute tolerance avoids the
        // spurious failures that comparing rounded values has when two nearly
        // equal numbers fall on either side of a rounding boundary.
        const TOLERANCE: f32 = 0.1;

        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            let actual = a.similarity(&b);
            let expected = reference_dot(&a.0, &b.0);
            assert!(
                (actual - expected).abs() < TOLERANCE,
                "similarity {} differs from reference dot product {}",
                actual,
                expected
            );
        }

        /// Plain element-wise dot product used as the expected value.
        fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
            a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
        }
    }
}