1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::AsyncReadExt;
4use gpui::executor::Background;
5use gpui::serde_json;
6use isahc::http::StatusCode;
7use isahc::prelude::Configurable;
8use isahc::{AsyncBody, Response};
9use lazy_static::lazy_static;
10use parse_duration::parse;
11use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
12use rusqlite::ToSql;
13use serde::{Deserialize, Serialize};
14use std::env;
15use std::sync::Arc;
16use std::time::Duration;
17use tiktoken_rs::{cl100k_base, CoreBPE};
18use util::http::{HttpClient, Request};
19
20lazy_static! {
21 static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
22 static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
23}
24
/// A dense vector of `f32` components representing an embedding.
#[derive(Clone, Debug, PartialEq)]
pub struct Embedding(Vec<f32>);

impl From<Vec<f32>> for Embedding {
    /// Wraps a raw component vector without copying it.
    fn from(components: Vec<f32>) -> Self {
        Self(components)
    }
}
33
impl Embedding {
    /// Returns the dot product of `self` and `other`.
    ///
    /// This equals cosine similarity only when both vectors are unit-length; no
    /// normalization is performed here.
    ///
    /// # Panics
    ///
    /// Panics if the two embeddings have different lengths.
    pub fn similarity(&self, other: &Self) -> f32 {
        let len = self.0.len();
        assert_eq!(len, other.0.len());

        let mut result = 0.0;
        // The dot product is computed as a (1 x len) * (len x 1) matrix product so the
        // optimized sgemm kernel performs the multiply-accumulate.
        // SAFETY: `self.0` is read as a 1 x len row (row stride `len`, column stride 1) and
        // `other.0` as a len x 1 column (row stride 1, column stride `len`), so each pointer
        // is read at exactly `len` elements, which the assert above guarantees both vectors
        // have. `result` is an initialized, exclusively borrowed 1 x 1 output; with
        // beta = 0.0 its previous value is overwritten rather than accumulated into.
        unsafe {
            matrixmultiply::sgemm(
                1,   // m: rows of the output
                len, // k: shared inner dimension
                1,   // n: columns of the output
                1.0, // alpha
                self.0.as_ptr(),
                len as isize, // row stride of A
                1,            // column stride of A
                other.0.as_ptr(),
                1,            // row stride of B
                len as isize, // column stride of B
                0.0,          // beta
                &mut result as *mut f32,
                1, // row stride of C
                1, // column stride of C
            );
        }
        result
    }
}
61
62impl FromSql for Embedding {
63 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
64 let bytes = value.as_blob()?;
65 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
66 if embedding.is_err() {
67 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
68 }
69 Ok(Embedding(embedding.unwrap()))
70 }
71}
72
73impl ToSql for Embedding {
74 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
75 let bytes = bincode::serialize(&self.0)
76 .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
77 Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
78 }
79}
80
81#[derive(Clone)]
82pub struct OpenAIEmbeddings {
83 pub client: Arc<dyn HttpClient>,
84 pub executor: Arc<Background>,
85}
86
/// JSON body sent to the OpenAI embeddings endpoint.
// Field order is intentional: serde serializes fields in declaration order.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    /// Name of the embedding model to use.
    model: &'static str,
    /// Texts to embed, borrowed from the caller.
    input: Vec<&'a str>,
}
92
/// Body of a successful (HTTP 200) response from the OpenAI embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    /// Embedding entries returned by the API (presumably one per input; see `index`).
    data: Vec<OpenAIEmbedding>,
    /// Token accounting reported for the request.
    usage: OpenAIEmbeddingUsage,
}
98
/// A single embedding entry in an `OpenAIEmbeddingResponse`.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    /// The embedding vector.
    embedding: Vec<f32>,
    /// Position reported by the API for this entry — presumably the index of the
    /// corresponding input in the request; confirm against the API reference.
    index: usize,
    /// Object type tag reported by the API.
    object: String,
}
105
/// Token usage reported alongside an embeddings response.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    /// Tokens counted for the request's inputs.
    prompt_tokens: usize,
    /// Total tokens billed for the request; logged after a successful call.
    total_tokens: usize,
}
111
112#[async_trait]
113pub trait EmbeddingProvider: Sync + Send {
114 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
115 fn max_tokens_per_batch(&self) -> usize;
116 fn truncate(&self, span: &str) -> (String, usize);
117}
118
119pub struct DummyEmbeddings {}
120
121#[async_trait]
122impl EmbeddingProvider for DummyEmbeddings {
123 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
124 // 1024 is the OpenAI Embeddings size for ada models.
125 // the model we will likely be starting with.
126 let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
127 return Ok(vec![dummy_vec; spans.len()]);
128 }
129
130 fn max_tokens_per_batch(&self) -> usize {
131 OPENAI_INPUT_LIMIT
132 }
133
134 fn truncate(&self, span: &str) -> (String, usize) {
135 let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
136 let token_count = tokens.len();
137 let output = if token_count > OPENAI_INPUT_LIMIT {
138 tokens.truncate(OPENAI_INPUT_LIMIT);
139 let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
140 new_input.ok().unwrap_or_else(|| span.to_string())
141 } else {
142 span.to_string()
143 };
144
145 (output, tokens.len())
146 }
147}
148
/// Maximum number of tokens kept per input span when truncating before embedding.
const OPENAI_INPUT_LIMIT: usize = 8_190;
150
151impl OpenAIEmbeddings {
152 async fn send_request(
153 &self,
154 api_key: &str,
155 spans: Vec<&str>,
156 request_timeout: u64,
157 ) -> Result<Response<AsyncBody>> {
158 let request = Request::post("https://api.openai.com/v1/embeddings")
159 .redirect_policy(isahc::config::RedirectPolicy::Follow)
160 .timeout(Duration::from_secs(request_timeout))
161 .header("Content-Type", "application/json")
162 .header("Authorization", format!("Bearer {}", api_key))
163 .body(
164 serde_json::to_string(&OpenAIEmbeddingRequest {
165 input: spans.clone(),
166 model: "text-embedding-ada-002",
167 })
168 .unwrap()
169 .into(),
170 )?;
171
172 Ok(self.client.send(request).await?)
173 }
174}
175
176#[async_trait]
177impl EmbeddingProvider for OpenAIEmbeddings {
178 fn max_tokens_per_batch(&self) -> usize {
179 50000
180 }
181
182 fn truncate(&self, span: &str) -> (String, usize) {
183 let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
184 let token_count = tokens.len();
185 let output = if token_count > OPENAI_INPUT_LIMIT {
186 tokens.truncate(OPENAI_INPUT_LIMIT);
187 OPENAI_BPE_TOKENIZER
188 .decode(tokens)
189 .ok()
190 .unwrap_or_else(|| span.to_string())
191 } else {
192 span.to_string()
193 };
194
195 (output, token_count)
196 }
197
198 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
199 const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
200 const MAX_RETRIES: usize = 4;
201
202 let api_key = OPENAI_API_KEY
203 .as_ref()
204 .ok_or_else(|| anyhow!("no api key"))?;
205
206 let mut request_number = 0;
207 let mut request_timeout: u64 = 10;
208 let mut response: Response<AsyncBody>;
209 while request_number < MAX_RETRIES {
210 response = self
211 .send_request(
212 api_key,
213 spans.iter().map(|x| &**x).collect(),
214 request_timeout,
215 )
216 .await?;
217 request_number += 1;
218
219 match response.status() {
220 StatusCode::REQUEST_TIMEOUT => {
221 request_timeout += 5;
222 }
223 StatusCode::OK => {
224 let mut body = String::new();
225 response.body_mut().read_to_string(&mut body).await?;
226 let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
227
228 log::trace!(
229 "openai embedding completed. tokens: {:?}",
230 response.usage.total_tokens
231 );
232
233 return Ok(response
234 .data
235 .into_iter()
236 .map(|embedding| Embedding::from(embedding.embedding))
237 .collect());
238 }
239 StatusCode::TOO_MANY_REQUESTS => {
240 let mut body = String::new();
241 response.body_mut().read_to_string(&mut body).await?;
242
243 let delay_duration = {
244 let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
245 if let Some(time_to_reset) =
246 response.headers().get("x-ratelimit-reset-tokens")
247 {
248 if let Ok(time_str) = time_to_reset.to_str() {
249 parse(time_str).unwrap_or(delay)
250 } else {
251 delay
252 }
253 } else {
254 delay
255 }
256 };
257
258 log::trace!(
259 "openai rate limiting: waiting {:?} until lifted",
260 &delay_duration
261 );
262
263 self.executor.timer(delay_duration).await;
264 }
265 _ => {
266 let mut body = String::new();
267 response.body_mut().read_to_string(&mut body).await?;
268 return Err(anyhow!(
269 "open ai bad request: {:?} {:?}",
270 &response.status(),
271 body
272 ));
273 }
274 }
275 }
276 Err(anyhow!("openai max retries"))
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[gpui::test]
    fn test_similarity(mut rng: StdRng) {
        // Small integers are exact in f32, so these products can be compared exactly.
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            assert_approx_eq(a.similarity(&b), reference_dot(&a.0, &b.0));
        }

        /// Asserts `actual` is within a small relative tolerance of `expected`.
        ///
        /// The optimized kernel and the naive loop accumulate in different orders, so their
        /// f32 results can differ in the last few bits. Rounding both to a fixed number of
        /// decimals fails spuriously whenever the true value sits near a rounding boundary;
        /// a tolerance does not.
        fn assert_approx_eq(actual: f32, expected: f32) {
            let tolerance = expected.abs().max(1.0) * 1e-4;
            assert!(
                (actual - expected).abs() <= tolerance,
                "similarity {} differs from reference {} by more than {}",
                actual,
                expected,
                tolerance
            );
        }

        fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
            a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
        }
    }
}