1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::AsyncReadExt;
4use gpui2::Executor;
5use gpui2::{serde_json, AppContext};
6use isahc::http::StatusCode;
7use isahc::prelude::Configurable;
8use isahc::{AsyncBody, Response};
9use lazy_static::lazy_static;
10use parking_lot::{Mutex, RwLock};
11use parse_duration::parse;
12use postage::watch;
13use serde::{Deserialize, Serialize};
14use std::env;
15use std::ops::Add;
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18use tiktoken_rs::{cl100k_base, CoreBPE};
19use util::http::{HttpClient, Request};
20use util::ResultExt;
21
22use crate::auth::{CredentialProvider, ProviderCredential};
23use crate::embedding::{Embedding, EmbeddingProvider};
24use crate::models::LanguageModel;
25use crate::providers::open_ai::OpenAILanguageModel;
26
27use crate::providers::open_ai::OPENAI_API_URL;
28
lazy_static! {
    // Process-wide tokenizer for the cl100k_base BPE encoding, built once on
    // first use. `unwrap` is acceptable here: cl100k_base() only fails if the
    // embedded encoding data is corrupt, which is a build defect, not a
    // runtime condition.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
32
/// Embedding provider backed by OpenAI's `/v1/embeddings` HTTP endpoint.
///
/// Cloning is cheap: all mutable state lives behind `Arc`ed locks, so clones
/// share credentials and rate-limit bookkeeping.
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
    // Language model used to describe/tokenize for this provider
    // (loaded as "text-embedding-ada-002" in `new`).
    model: OpenAILanguageModel,
    // Cached credential; populated by `retrieve_credentials`/`save_credentials`.
    credential: Arc<RwLock<ProviderCredential>>,
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Executor>,
    // Read side of the rate-limit watch: `Some(instant)` while we believe we
    // are rate limited until that instant, `None` otherwise.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    // Write side, shared across clones via Arc<Mutex<..>> so any clone can
    // record or clear a rate-limit reset time.
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
42
/// JSON body sent to the OpenAI embeddings endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    // Model identifier, e.g. "text-embedding-ada-002".
    model: &'static str,
    // The text spans to embed, borrowed from the caller for this request.
    input: Vec<&'a str>,
}
48
/// Top-level JSON response from the OpenAI embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    // One entry per input span (order per the API; only `embedding` is
    // consumed by `embed_batch`).
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
54
/// A single embedding entry within the response's `data` array.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    // The embedding vector itself; converted via `Embedding::from` downstream.
    embedding: Vec<f32>,
    // Position of this entry relative to the request's inputs (deserialized
    // but not read by the code in this file).
    index: usize,
    // API object tag (deserialized but not read by the code in this file).
    object: String,
}
61
/// Token accounting attached to an embeddings response; `total_tokens` is
/// logged after each successful batch.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
67
68impl OpenAIEmbeddingProvider {
69 pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Executor>) -> Self {
70 let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
71 let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
72
73 let model = OpenAILanguageModel::load("text-embedding-ada-002");
74 let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
75
76 OpenAIEmbeddingProvider {
77 model,
78 credential,
79 client,
80 executor,
81 rate_limit_count_rx,
82 rate_limit_count_tx,
83 }
84 }
85
86 fn get_api_key(&self) -> Result<String> {
87 match self.credential.read().clone() {
88 ProviderCredential::Credentials { api_key } => Ok(api_key),
89 _ => Err(anyhow!("api credentials not provided")),
90 }
91 }
92
93 fn resolve_rate_limit(&self) {
94 let reset_time = *self.rate_limit_count_tx.lock().borrow();
95
96 if let Some(reset_time) = reset_time {
97 if Instant::now() >= reset_time {
98 *self.rate_limit_count_tx.lock().borrow_mut() = None
99 }
100 }
101
102 log::trace!(
103 "resolving reset time: {:?}",
104 *self.rate_limit_count_tx.lock().borrow()
105 );
106 }
107
108 fn update_reset_time(&self, reset_time: Instant) {
109 let original_time = *self.rate_limit_count_tx.lock().borrow();
110
111 let updated_time = if let Some(original_time) = original_time {
112 if reset_time < original_time {
113 Some(reset_time)
114 } else {
115 Some(original_time)
116 }
117 } else {
118 Some(reset_time)
119 };
120
121 log::trace!("updating rate limit time: {:?}", updated_time);
122
123 *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
124 }
125 async fn send_request(
126 &self,
127 api_key: &str,
128 spans: Vec<&str>,
129 request_timeout: u64,
130 ) -> Result<Response<AsyncBody>> {
131 let request = Request::post("https://api.openai.com/v1/embeddings")
132 .redirect_policy(isahc::config::RedirectPolicy::Follow)
133 .timeout(Duration::from_secs(request_timeout))
134 .header("Content-Type", "application/json")
135 .header("Authorization", format!("Bearer {}", api_key))
136 .body(
137 serde_json::to_string(&OpenAIEmbeddingRequest {
138 input: spans.clone(),
139 model: "text-embedding-ada-002",
140 })
141 .unwrap()
142 .into(),
143 )?;
144
145 Ok(self.client.send(request).await?)
146 }
147}
148
149#[async_trait]
150impl CredentialProvider for OpenAIEmbeddingProvider {
151 fn has_credentials(&self) -> bool {
152 match *self.credential.read() {
153 ProviderCredential::Credentials { .. } => true,
154 _ => false,
155 }
156 }
157 async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
158 let existing_credential = self.credential.read().clone();
159
160 let retrieved_credential = cx
161 .run_on_main(move |cx| match existing_credential {
162 ProviderCredential::Credentials { .. } => {
163 return existing_credential.clone();
164 }
165 _ => {
166 if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
167 return ProviderCredential::Credentials { api_key };
168 }
169
170 if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
171 {
172 if let Some(api_key) = String::from_utf8(api_key).log_err() {
173 return ProviderCredential::Credentials { api_key };
174 } else {
175 return ProviderCredential::NoCredentials;
176 }
177 } else {
178 return ProviderCredential::NoCredentials;
179 }
180 }
181 })
182 .await;
183
184 *self.credential.write() = retrieved_credential.clone();
185 retrieved_credential
186 }
187
188 async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
189 *self.credential.write() = credential.clone();
190 let credential = credential.clone();
191 cx.run_on_main(move |cx| match credential {
192 ProviderCredential::Credentials { api_key } => {
193 cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
194 .log_err();
195 }
196 _ => {}
197 })
198 .await;
199 }
200 async fn delete_credentials(&self, cx: &mut AppContext) {
201 cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
202 .await;
203 *self.credential.write() = ProviderCredential::NoCredentials;
204 }
205}
206
207#[async_trait]
208impl EmbeddingProvider for OpenAIEmbeddingProvider {
209 fn base_model(&self) -> Box<dyn LanguageModel> {
210 let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
211 model
212 }
213
214 fn max_tokens_per_batch(&self) -> usize {
215 50000
216 }
217
218 fn rate_limit_expiration(&self) -> Option<Instant> {
219 *self.rate_limit_count_rx.borrow()
220 }
221
222 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
223 const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
224 const MAX_RETRIES: usize = 4;
225
226 let api_key = self.get_api_key()?;
227
228 let mut request_number = 0;
229 let mut rate_limiting = false;
230 let mut request_timeout: u64 = 15;
231 let mut response: Response<AsyncBody>;
232 while request_number < MAX_RETRIES {
233 response = self
234 .send_request(
235 &api_key,
236 spans.iter().map(|x| &**x).collect(),
237 request_timeout,
238 )
239 .await?;
240
241 request_number += 1;
242
243 match response.status() {
244 StatusCode::REQUEST_TIMEOUT => {
245 request_timeout += 5;
246 }
247 StatusCode::OK => {
248 let mut body = String::new();
249 response.body_mut().read_to_string(&mut body).await?;
250 let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
251
252 log::trace!(
253 "openai embedding completed. tokens: {:?}",
254 response.usage.total_tokens
255 );
256
257 // If we complete a request successfully that was previously rate_limited
258 // resolve the rate limit
259 if rate_limiting {
260 self.resolve_rate_limit()
261 }
262
263 return Ok(response
264 .data
265 .into_iter()
266 .map(|embedding| Embedding::from(embedding.embedding))
267 .collect());
268 }
269 StatusCode::TOO_MANY_REQUESTS => {
270 rate_limiting = true;
271 let mut body = String::new();
272 response.body_mut().read_to_string(&mut body).await?;
273
274 let delay_duration = {
275 let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
276 if let Some(time_to_reset) =
277 response.headers().get("x-ratelimit-reset-tokens")
278 {
279 if let Ok(time_str) = time_to_reset.to_str() {
280 parse(time_str).unwrap_or(delay)
281 } else {
282 delay
283 }
284 } else {
285 delay
286 }
287 };
288
289 // If we've previously rate limited, increment the duration but not the count
290 let reset_time = Instant::now().add(delay_duration);
291 self.update_reset_time(reset_time);
292
293 log::trace!(
294 "openai rate limiting: waiting {:?} until lifted",
295 &delay_duration
296 );
297
298 self.executor.timer(delay_duration).await;
299 }
300 _ => {
301 let mut body = String::new();
302 response.body_mut().read_to_string(&mut body).await?;
303 return Err(anyhow!(
304 "open ai bad request: {:?} {:?}",
305 &response.status(),
306 body
307 ));
308 }
309 }
310 }
311 Err(anyhow!("openai max retries"))
312 }
313}