1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::AsyncReadExt;
4use gpui::AppContext;
5use gpui::BackgroundExecutor;
6use isahc::http::StatusCode;
7use isahc::prelude::Configurable;
8use isahc::{AsyncBody, Response};
9use lazy_static::lazy_static;
10use parking_lot::{Mutex, RwLock};
11use parse_duration::parse;
12use postage::watch;
13use serde::{Deserialize, Serialize};
14use serde_json;
15use std::env;
16use std::ops::Add;
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19use tiktoken_rs::{cl100k_base, CoreBPE};
20use util::http::{HttpClient, Request};
21use util::ResultExt;
22
23use crate::auth::{CredentialProvider, ProviderCredential};
24use crate::embedding::{Embedding, EmbeddingProvider};
25use crate::models::LanguageModel;
26use crate::providers::open_ai::OpenAILanguageModel;
27
28use crate::providers::open_ai::OPENAI_API_URL;
29
lazy_static! {
    // Shared cl100k_base BPE tokenizer, built once on first use (the tables
    // are expensive to construct). Not referenced in this file — presumably
    // used by callers for token counting; confirm before removing.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
33
/// Embedding provider backed by OpenAI's `text-embedding-ada-002` model.
/// Tracks rate-limit back-off state that is shared across clones via the
/// `Arc`-wrapped watch channel.
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
    // Base language model exposed via `EmbeddingProvider::base_model`.
    model: OpenAILanguageModel,
    // Cached API credential; managed by the `CredentialProvider` impl below.
    credential: Arc<RwLock<ProviderCredential>>,
    pub client: Arc<dyn HttpClient>,
    pub executor: BackgroundExecutor,
    // Read by `rate_limit_expiration`: the earliest instant at which an
    // active rate limit is expected to lift (None when not rate limited).
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    // Sender side, mutex-guarded so concurrent requests update the reset
    // time without racing each other.
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
43
/// JSON request body for OpenAI's `/v1/embeddings` endpoint
/// (serialized in `send_request`).
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
49
/// Successful JSON response from the embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
55
/// A single embedding vector in the response.
/// `index` and `object` are deserialized for completeness but unused here.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}
62
/// Token accounting reported by the API; `total_tokens` is logged after each
/// successful batch.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
68
69impl OpenAIEmbeddingProvider {
70 pub fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
71 let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
72 let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
73
74 let model = OpenAILanguageModel::load("text-embedding-ada-002");
75 let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
76
77 OpenAIEmbeddingProvider {
78 model,
79 credential,
80 client,
81 executor,
82 rate_limit_count_rx,
83 rate_limit_count_tx,
84 }
85 }
86
87 fn get_api_key(&self) -> Result<String> {
88 match self.credential.read().clone() {
89 ProviderCredential::Credentials { api_key } => Ok(api_key),
90 _ => Err(anyhow!("api credentials not provided")),
91 }
92 }
93
94 fn resolve_rate_limit(&self) {
95 let reset_time = *self.rate_limit_count_tx.lock().borrow();
96
97 if let Some(reset_time) = reset_time {
98 if Instant::now() >= reset_time {
99 *self.rate_limit_count_tx.lock().borrow_mut() = None
100 }
101 }
102
103 log::trace!(
104 "resolving reset time: {:?}",
105 *self.rate_limit_count_tx.lock().borrow()
106 );
107 }
108
109 fn update_reset_time(&self, reset_time: Instant) {
110 let original_time = *self.rate_limit_count_tx.lock().borrow();
111
112 let updated_time = if let Some(original_time) = original_time {
113 if reset_time < original_time {
114 Some(reset_time)
115 } else {
116 Some(original_time)
117 }
118 } else {
119 Some(reset_time)
120 };
121
122 log::trace!("updating rate limit time: {:?}", updated_time);
123
124 *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
125 }
126 async fn send_request(
127 &self,
128 api_key: &str,
129 spans: Vec<&str>,
130 request_timeout: u64,
131 ) -> Result<Response<AsyncBody>> {
132 let request = Request::post("https://api.openai.com/v1/embeddings")
133 .redirect_policy(isahc::config::RedirectPolicy::Follow)
134 .timeout(Duration::from_secs(request_timeout))
135 .header("Content-Type", "application/json")
136 .header("Authorization", format!("Bearer {}", api_key))
137 .body(
138 serde_json::to_string(&OpenAIEmbeddingRequest {
139 input: spans.clone(),
140 model: "text-embedding-ada-002",
141 })
142 .unwrap()
143 .into(),
144 )?;
145
146 Ok(self.client.send(request).await?)
147 }
148}
149
150impl CredentialProvider for OpenAIEmbeddingProvider {
151 fn has_credentials(&self) -> bool {
152 match *self.credential.read() {
153 ProviderCredential::Credentials { .. } => true,
154 _ => false,
155 }
156 }
157 fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
158 let existing_credential = self.credential.read().clone();
159
160 let retrieved_credential = match existing_credential {
161 ProviderCredential::Credentials { .. } => existing_credential.clone(),
162 _ => {
163 if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
164 ProviderCredential::Credentials { api_key }
165 } else if let Some(Some((_, api_key))) =
166 cx.read_credentials(OPENAI_API_URL).log_err()
167 {
168 if let Some(api_key) = String::from_utf8(api_key).log_err() {
169 ProviderCredential::Credentials { api_key }
170 } else {
171 ProviderCredential::NoCredentials
172 }
173 } else {
174 ProviderCredential::NoCredentials
175 }
176 }
177 };
178
179 *self.credential.write() = retrieved_credential.clone();
180 retrieved_credential
181 }
182
183 fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
184 *self.credential.write() = credential.clone();
185 match credential {
186 ProviderCredential::Credentials { api_key } => {
187 cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
188 .log_err();
189 }
190 _ => {}
191 }
192 }
193
194 fn delete_credentials(&self, cx: &mut AppContext) {
195 cx.delete_credentials(OPENAI_API_URL).log_err();
196 *self.credential.write() = ProviderCredential::NoCredentials;
197 }
198}
199
200#[async_trait]
201impl EmbeddingProvider for OpenAIEmbeddingProvider {
202 fn base_model(&self) -> Box<dyn LanguageModel> {
203 let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
204 model
205 }
206
207 fn max_tokens_per_batch(&self) -> usize {
208 50000
209 }
210
211 fn rate_limit_expiration(&self) -> Option<Instant> {
212 *self.rate_limit_count_rx.borrow()
213 }
214
215 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
216 const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
217 const MAX_RETRIES: usize = 4;
218
219 let api_key = self.get_api_key()?;
220
221 let mut request_number = 0;
222 let mut rate_limiting = false;
223 let mut request_timeout: u64 = 15;
224 let mut response: Response<AsyncBody>;
225 while request_number < MAX_RETRIES {
226 response = self
227 .send_request(
228 &api_key,
229 spans.iter().map(|x| &**x).collect(),
230 request_timeout,
231 )
232 .await?;
233
234 request_number += 1;
235
236 match response.status() {
237 StatusCode::REQUEST_TIMEOUT => {
238 request_timeout += 5;
239 }
240 StatusCode::OK => {
241 let mut body = String::new();
242 response.body_mut().read_to_string(&mut body).await?;
243 let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
244
245 log::trace!(
246 "openai embedding completed. tokens: {:?}",
247 response.usage.total_tokens
248 );
249
250 // If we complete a request successfully that was previously rate_limited
251 // resolve the rate limit
252 if rate_limiting {
253 self.resolve_rate_limit()
254 }
255
256 return Ok(response
257 .data
258 .into_iter()
259 .map(|embedding| Embedding::from(embedding.embedding))
260 .collect());
261 }
262 StatusCode::TOO_MANY_REQUESTS => {
263 rate_limiting = true;
264 let mut body = String::new();
265 response.body_mut().read_to_string(&mut body).await?;
266
267 let delay_duration = {
268 let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
269 if let Some(time_to_reset) =
270 response.headers().get("x-ratelimit-reset-tokens")
271 {
272 if let Ok(time_str) = time_to_reset.to_str() {
273 parse(time_str).unwrap_or(delay)
274 } else {
275 delay
276 }
277 } else {
278 delay
279 }
280 };
281
282 // If we've previously rate limited, increment the duration but not the count
283 let reset_time = Instant::now().add(delay_duration);
284 self.update_reset_time(reset_time);
285
286 log::trace!(
287 "openai rate limiting: waiting {:?} until lifted",
288 &delay_duration
289 );
290
291 self.executor.timer(delay_duration).await;
292 }
293 _ => {
294 let mut body = String::new();
295 response.body_mut().read_to_string(&mut body).await?;
296 return Err(anyhow!(
297 "open ai bad request: {:?} {:?}",
298 &response.status(),
299 body
300 ));
301 }
302 }
303 }
304 Err(anyhow!("openai max retries"))
305 }
306}