use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::future::BoxFuture;
use futures::AsyncReadExt;
use futures::FutureExt;
use gpui::AppContext;
use gpui::BackgroundExecutor;
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parking_lot::{Mutex, RwLock};
use parse_duration::parse;
use postage::watch;
use serde::{Deserialize, Serialize};
use serde_json;
use std::env;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;

use crate::auth::{CredentialProvider, ProviderCredential};
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAiLanguageModel;

use crate::providers::open_ai::OPEN_AI_API_URL;

lazy_static! {
    pub(crate) static ref OPEN_AI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}

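/// An [`EmbeddingProvider`] backed by the OpenAI embeddings API.
///
/// Rate-limit state is published through a watch channel so callers can
/// observe when requests may resume.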
#[derive(Clone)]
pub struct OpenAiEmbeddingProvider {
    model: OpenAiLanguageModel,
    credential: Arc<RwLock<ProviderCredential>>,
    pub client: Arc<dyn HttpClient>,
    pub executor: BackgroundExecutor,
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}

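/// Request body for OpenAI's `/v1/embeddings` endpoint.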
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

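/// Response body from the embeddings endpoint: one embedding per input span,
/// plus token usage for the request.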
#[derive(Deserialize)]
struct OpenAiEmbeddingResponse {
    data: Vec<OpenAiEmbedding>,
    usage: OpenAiEmbeddingUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAiEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

#[derive(Deserialize)]
struct OpenAiEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}

impl OpenAiEmbeddingProvider {
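    /// Creates a provider that embeds text with `text-embedding-ada-002`.
    ///
    /// A minimal usage sketch, assuming an `HttpClient` implementation and a
    /// `BackgroundExecutor` are already in hand:
    ///
    /// ```ignore
    /// let provider = OpenAiEmbeddingProvider::new(client, executor).await;
    /// // Credentials must be loaded first, e.g. via `retrieve_credentials`.
    /// let embeddings = provider.embed_batch(vec!["hello world".into()]).await?;
    /// ```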
    pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));

        // Loading the model is expensive, so ensure this runs off the main thread.
        let model = executor
            .spawn(async move { OpenAiLanguageModel::load("text-embedding-ada-002") })
            .await;
        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));

        OpenAiEmbeddingProvider {
            model,
            credential,
            client,
            executor,
            rate_limit_count_rx,
            rate_limit_count_tx,
        }
    }

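    /// Returns the stored API key, or an error if no credentials are set.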
    fn get_api_key(&self) -> Result<String> {
        match self.credential.read().clone() {
            ProviderCredential::Credentials { api_key } => Ok(api_key),
            _ => Err(anyhow!("api credentials not provided")),
        }
    }

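    /// Clears the tracked rate-limit reset time once it has passed.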
    fn resolve_rate_limit(&self) {
        let reset_time = *self.rate_limit_count_tx.lock().borrow();

        if let Some(reset_time) = reset_time {
            if Instant::now() >= reset_time {
                *self.rate_limit_count_tx.lock().borrow_mut() = None
            }
        }

        log::trace!(
            "resolving reset time: {:?}",
            *self.rate_limit_count_tx.lock().borrow()
        );
    }

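    /// Records a rate-limit reset time, keeping the earliest one seen so far.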
    fn update_reset_time(&self, reset_time: Instant) {
        let original_time = *self.rate_limit_count_tx.lock().borrow();

        let updated_time = if let Some(original_time) = original_time {
            if reset_time < original_time {
                Some(reset_time)
            } else {
                Some(original_time)
            }
        } else {
            Some(reset_time)
        };

        log::trace!("updating rate limit time: {:?}", updated_time);

        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
    }
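
    /// Sends a single embeddings request for the given spans, with the given
    /// timeout in seconds.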
    async fn send_request(
        &self,
        api_key: &str,
        spans: Vec<&str>,
        request_timeout: u64,
    ) -> Result<Response<AsyncBody>> {
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .timeout(Duration::from_secs(request_timeout))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAiEmbeddingRequest {
                    input: spans,
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        Ok(self.client.send(request).await?)
    }
}

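// Credentials are cached in memory; the OPENAI_API_KEY environment variable
// takes precedence over the platform credential store when retrieving.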
impl CredentialProvider for OpenAiEmbeddingProvider {
    fn has_credentials(&self) -> bool {
        matches!(*self.credential.read(), ProviderCredential::Credentials { .. })
    }

    fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
        let existing_credential = self.credential.read().clone();
        let retrieved_credential = match existing_credential {
            ProviderCredential::Credentials { .. } => {
                return async move { existing_credential }.boxed()
            }
            _ => {
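                // Fall back to the OPENAI_API_KEY environment variable first,
                // then to the credentials persisted by the platform.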
                if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
                    async move { ProviderCredential::Credentials { api_key } }.boxed()
                } else {
                    let credentials = cx.read_credentials(OPEN_AI_API_URL);
                    async move {
                        if let Some(Some((_, api_key))) = credentials.await.log_err() {
                            if let Some(api_key) = String::from_utf8(api_key).log_err() {
                                ProviderCredential::Credentials { api_key }
                            } else {
                                ProviderCredential::NoCredentials
                            }
                        } else {
                            ProviderCredential::NoCredentials
                        }
                    }
                    .boxed()
                }
            }
        };

        async move {
            let retrieved_credential = retrieved_credential.await;
            *self.credential.write() = retrieved_credential.clone();
            retrieved_credential
        }
        .boxed()
    }

    fn save_credentials(
        &self,
        cx: &mut AppContext,
        credential: ProviderCredential,
    ) -> BoxFuture<()> {
        *self.credential.write() = credential.clone();
        let write_credentials = match credential {
            ProviderCredential::Credentials { api_key } => {
                Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
            }
            _ => None,
        };

        async move {
            if let Some(write_credentials) = write_credentials {
                write_credentials.await.log_err();
            }
        }
        .boxed()
    }

    fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
        *self.credential.write() = ProviderCredential::NoCredentials;
        let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
        async move {
            delete_credentials.await.log_err();
        }
        .boxed()
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAiEmbeddingProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        Box::new(self.model.clone())
    }

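    /// Maximum number of tokens to pack into a single embeddings request.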
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

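    /// The instant at which the current rate limit is expected to lift, if any.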
    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }

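    /// Embeds a batch of spans, retrying request timeouts and rate limits
    /// with backoff for up to `MAX_RETRIES` attempts.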
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = self.get_api_key()?;

        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 15;
        while request_number < MAX_RETRIES {
            let mut response = self
                .send_request(
                    &api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;

            request_number += 1;

            match response.status() {
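                // The request itself timed out; give the next attempt more time.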
                StatusCode::REQUEST_TIMEOUT => {
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAiEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If a request succeeds after we were previously rate-limited,
                    // resolve the rate limit.
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

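                    // Prefer the reset interval reported by the server in the
                    // `x-ratelimit-reset-tokens` header; otherwise fall back to
                    // the fixed backoff schedule.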
                    let delay_duration = {
                        let delay =
                            Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // Track the earliest known reset time so callers polling
                    // `rate_limit_expiration` see an accurate expiration.
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries exceeded"))
    }
}