1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::future::BoxFuture;
4use futures::AsyncReadExt;
5use futures::FutureExt;
6use gpui::AppContext;
7use gpui::BackgroundExecutor;
8use isahc::http::StatusCode;
9use isahc::prelude::Configurable;
10use isahc::{AsyncBody, Response};
11use parking_lot::{Mutex, RwLock};
12use parse_duration::parse;
13use postage::watch;
14use serde::{Deserialize, Serialize};
15use serde_json;
16use std::env;
17use std::ops::Add;
18use std::sync::{Arc, OnceLock};
19use std::time::{Duration, Instant};
20use tiktoken_rs::{cl100k_base, CoreBPE};
21use util::http::{HttpClient, Request};
22use util::ResultExt;
23
24use crate::auth::{CredentialProvider, ProviderCredential};
25use crate::embedding::{Embedding, EmbeddingProvider};
26use crate::models::LanguageModel;
27use crate::providers::open_ai::OpenAiLanguageModel;
28
29use crate::providers::open_ai::OPEN_AI_API_URL;
30
31pub(crate) fn open_ai_bpe_tokenizer() -> &'static CoreBPE {
32 static OPEN_AI_BPE_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
33 OPEN_AI_BPE_TOKENIZER.get_or_init(|| cl100k_base().unwrap())
34}
35
/// Computes text embeddings by calling an OpenAI-compatible `/embeddings`
/// HTTP endpoint.
///
/// Cloning is cheap: shared state lives behind `Arc`s, and the rate-limit
/// watch channel is shared between clones.
#[derive(Clone)]
pub struct OpenAiEmbeddingProvider {
    /// Base URL of the API; `"/embeddings"` is appended per request.
    api_url: String,
    /// Language model used for token counting (`text-embedding-ada-002`).
    model: OpenAiLanguageModel,
    /// Current API credential; `NoCredentials` until retrieved or saved.
    credential: Arc<RwLock<ProviderCredential>>,
    pub client: Arc<dyn HttpClient>,
    pub executor: BackgroundExecutor,
    /// Observers read the instant at which an active rate limit should reset.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    /// Sender side of the rate-limit channel; `None` means no limit in effect.
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
46
/// JSON request body sent to the embeddings endpoint.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    /// Model identifier; always `"text-embedding-ada-002"` in this provider.
    model: &'static str,
    /// Batch of text spans embedded in a single request.
    input: Vec<&'a str>,
}
52
/// Top-level JSON response from the embeddings endpoint.
#[derive(Deserialize)]
struct OpenAiEmbeddingResponse {
    /// One entry per input span, in request order.
    data: Vec<OpenAiEmbedding>,
    /// Token accounting reported by the API.
    usage: OpenAiEmbeddingUsage,
}
58
/// A single embedding entry within the response's `data` array.
#[derive(Debug, Deserialize)]
struct OpenAiEmbedding {
    /// The embedding vector itself; the only field this crate consumes.
    embedding: Vec<f32>,
    // `index` and `object` are deserialized but currently unused; kept so the
    // struct mirrors the wire format.
    index: usize,
    object: String,
}
65
/// Token usage reported by the API; used only for trace logging here.
#[derive(Deserialize)]
struct OpenAiEmbeddingUsage {
    // Deserialized for wire-format completeness; not read by this crate.
    prompt_tokens: usize,
    /// Total tokens consumed by the request; logged on success.
    total_tokens: usize,
}
71
72impl OpenAiEmbeddingProvider {
73 pub async fn new(
74 api_url: String,
75 client: Arc<dyn HttpClient>,
76 executor: BackgroundExecutor,
77 ) -> Self {
78 let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
79 let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
80
81 // Loading the model is expensive, so ensure this runs off the main thread.
82 let model = executor
83 .spawn(async move { OpenAiLanguageModel::load("text-embedding-ada-002") })
84 .await;
85 let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
86
87 OpenAiEmbeddingProvider {
88 api_url,
89 model,
90 credential,
91 client,
92 executor,
93 rate_limit_count_rx,
94 rate_limit_count_tx,
95 }
96 }
97
98 fn get_api_key(&self) -> Result<String> {
99 match self.credential.read().clone() {
100 ProviderCredential::Credentials { api_key } => Ok(api_key),
101 _ => Err(anyhow!("api credentials not provided")),
102 }
103 }
104
105 fn resolve_rate_limit(&self) {
106 let reset_time = *self.rate_limit_count_tx.lock().borrow();
107
108 if let Some(reset_time) = reset_time {
109 if Instant::now() >= reset_time {
110 *self.rate_limit_count_tx.lock().borrow_mut() = None
111 }
112 }
113
114 log::trace!(
115 "resolving reset time: {:?}",
116 *self.rate_limit_count_tx.lock().borrow()
117 );
118 }
119
120 fn update_reset_time(&self, reset_time: Instant) {
121 let original_time = *self.rate_limit_count_tx.lock().borrow();
122
123 let updated_time = if let Some(original_time) = original_time {
124 if reset_time < original_time {
125 Some(reset_time)
126 } else {
127 Some(original_time)
128 }
129 } else {
130 Some(reset_time)
131 };
132
133 log::trace!("updating rate limit time: {:?}", updated_time);
134
135 *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
136 }
137 async fn send_request(
138 &self,
139 api_url: &str,
140 api_key: &str,
141 spans: Vec<&str>,
142 request_timeout: u64,
143 ) -> Result<Response<AsyncBody>> {
144 let request = Request::post(format!("{api_url}/embeddings"))
145 .redirect_policy(isahc::config::RedirectPolicy::Follow)
146 .timeout(Duration::from_secs(request_timeout))
147 .header("Content-Type", "application/json")
148 .header("Authorization", format!("Bearer {}", api_key))
149 .body(
150 serde_json::to_string(&OpenAiEmbeddingRequest {
151 input: spans.clone(),
152 model: "text-embedding-ada-002",
153 })
154 .unwrap()
155 .into(),
156 )?;
157
158 Ok(self.client.send(request).await?)
159 }
160}
161
162impl CredentialProvider for OpenAiEmbeddingProvider {
163 fn has_credentials(&self) -> bool {
164 match *self.credential.read() {
165 ProviderCredential::Credentials { .. } => true,
166 _ => false,
167 }
168 }
169
170 fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
171 let existing_credential = self.credential.read().clone();
172 let retrieved_credential = match existing_credential {
173 ProviderCredential::Credentials { .. } => {
174 return async move { existing_credential }.boxed()
175 }
176 _ => {
177 if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
178 async move { ProviderCredential::Credentials { api_key } }.boxed()
179 } else {
180 let credentials = cx.read_credentials(OPEN_AI_API_URL);
181 async move {
182 if let Some(Some((_, api_key))) = credentials.await.log_err() {
183 if let Some(api_key) = String::from_utf8(api_key).log_err() {
184 ProviderCredential::Credentials { api_key }
185 } else {
186 ProviderCredential::NoCredentials
187 }
188 } else {
189 ProviderCredential::NoCredentials
190 }
191 }
192 .boxed()
193 }
194 }
195 };
196
197 async move {
198 let retrieved_credential = retrieved_credential.await;
199 *self.credential.write() = retrieved_credential.clone();
200 retrieved_credential
201 }
202 .boxed()
203 }
204
205 fn save_credentials(
206 &self,
207 cx: &mut AppContext,
208 credential: ProviderCredential,
209 ) -> BoxFuture<()> {
210 *self.credential.write() = credential.clone();
211 let credential = credential.clone();
212 let write_credentials = match credential {
213 ProviderCredential::Credentials { api_key } => {
214 Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
215 }
216 _ => None,
217 };
218
219 async move {
220 if let Some(write_credentials) = write_credentials {
221 write_credentials.await.log_err();
222 }
223 }
224 .boxed()
225 }
226
227 fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
228 *self.credential.write() = ProviderCredential::NoCredentials;
229 let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
230 async move {
231 delete_credentials.await.log_err();
232 }
233 .boxed()
234 }
235}
236
#[async_trait]
impl EmbeddingProvider for OpenAiEmbeddingProvider {
    /// Exposes the underlying language model used for token counting.
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }

    /// Maximum total tokens allowed across all spans in one batch.
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    /// Returns the instant at which an active rate limit is expected to
    /// reset, or `None` when no limit is in effect.
    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }

    /// Embeds all `spans` in a single API call, retrying up to `MAX_RETRIES`
    /// times. Retry policy by HTTP status:
    /// * 408: retry with the request timeout extended by 5 seconds.
    /// * 429: wait out the server-provided `x-ratelimit-reset-tokens`
    ///   duration when parseable (otherwise a fixed backoff), publish the
    ///   reset instant for `rate_limit_expiration`, then retry.
    /// * 200: parse embeddings, clear any published rate limit, and return.
    /// * anything else: fail immediately with the response body in the error.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        // Fallback wait times for the 1st..4th rate-limited attempt; indexed
        // by `request_number - 1`, so its length must stay >= MAX_RETRIES.
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_url = self.api_url.as_str();
        let api_key = self.get_api_key()?;

        let mut request_number = 0;
        // Tracks whether any attempt in this call hit a 429, so a later
        // success can clear the published rate limit.
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 15;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    &api_url,
                    &api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;

            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    // Give the server progressively more time on each retry.
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAiEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If we complete a request successfully that was previously rate_limited
                    // resolve the rate limit
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    // Drain the body even though it is unused, so the
                    // connection can be reused cleanly.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Prefer the server's reset hint over the static backoff
                    // table; fall back whenever the header is absent or
                    // unparseable.
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // If we've previously rate limited, increment the duration but not the count
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    // Unexpected status: surface it with the body for context.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}