embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui2::Executor;
  5use gpui2::{serde_json, AppContext};
  6use isahc::http::StatusCode;
  7use isahc::prelude::Configurable;
  8use isahc::{AsyncBody, Response};
  9use lazy_static::lazy_static;
 10use parking_lot::{Mutex, RwLock};
 11use parse_duration::parse;
 12use postage::watch;
 13use serde::{Deserialize, Serialize};
 14use std::env;
 15use std::ops::Add;
 16use std::sync::Arc;
 17use std::time::{Duration, Instant};
 18use tiktoken_rs::{cl100k_base, CoreBPE};
 19use util::http::{HttpClient, Request};
 20use util::ResultExt;
 21
 22use crate::auth::{CredentialProvider, ProviderCredential};
 23use crate::embedding::{Embedding, EmbeddingProvider};
 24use crate::models::LanguageModel;
 25use crate::providers::open_ai::OpenAILanguageModel;
 26
 27use crate::providers::open_ai::OPENAI_API_URL;
 28
lazy_static! {
    // Shared tokenizer for OpenAI's cl100k_base encoding, built once on first
    // use. `unwrap` is acceptable here: the encoding tables are bundled with
    // tiktoken_rs, so construction can only fail on a broken build.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
 32
/// Embedding provider backed by the OpenAI `/v1/embeddings` HTTP endpoint.
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
    // Language model handle for `text-embedding-ada-002` (see `new`).
    model: OpenAILanguageModel,
    // Cached API credential; populated/cleared by the `CredentialProvider` impl.
    credential: Arc<RwLock<ProviderCredential>>,
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Executor>,
    // Watch channel carrying the instant at which an active rate limit expires
    // (`None` when not rate-limited). The receiver feeds `rate_limit_expiration`;
    // the sender is updated from `embed_batch` retries.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
 42
/// JSON request body for the OpenAI embeddings API.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    // Model identifier, e.g. "text-embedding-ada-002".
    model: &'static str,
    // Text spans to embed in a single batched call.
    input: Vec<&'a str>,
}
 48
/// Top-level JSON response from the OpenAI embeddings API.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    // One entry per input span, in request order.
    data: Vec<OpenAIEmbedding>,
    // Token accounting reported by the API (logged for tracing).
    usage: OpenAIEmbeddingUsage,
}
 54
/// A single embedding vector from the API response.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    // Position of this embedding in the request's input list.
    index: usize,
    // Response object type tag (always "embedding" per the API schema —
    // deserialized for completeness, not read here).
    object: String,
}
 61
/// Token usage accounting attached to an embeddings response.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
 67
 68impl OpenAIEmbeddingProvider {
 69    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Executor>) -> Self {
 70        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
 71        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 72
 73        let model = OpenAILanguageModel::load("text-embedding-ada-002");
 74        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 75
 76        OpenAIEmbeddingProvider {
 77            model,
 78            credential,
 79            client,
 80            executor,
 81            rate_limit_count_rx,
 82            rate_limit_count_tx,
 83        }
 84    }
 85
 86    fn get_api_key(&self) -> Result<String> {
 87        match self.credential.read().clone() {
 88            ProviderCredential::Credentials { api_key } => Ok(api_key),
 89            _ => Err(anyhow!("api credentials not provided")),
 90        }
 91    }
 92
 93    fn resolve_rate_limit(&self) {
 94        let reset_time = *self.rate_limit_count_tx.lock().borrow();
 95
 96        if let Some(reset_time) = reset_time {
 97            if Instant::now() >= reset_time {
 98                *self.rate_limit_count_tx.lock().borrow_mut() = None
 99            }
100        }
101
102        log::trace!(
103            "resolving reset time: {:?}",
104            *self.rate_limit_count_tx.lock().borrow()
105        );
106    }
107
108    fn update_reset_time(&self, reset_time: Instant) {
109        let original_time = *self.rate_limit_count_tx.lock().borrow();
110
111        let updated_time = if let Some(original_time) = original_time {
112            if reset_time < original_time {
113                Some(reset_time)
114            } else {
115                Some(original_time)
116            }
117        } else {
118            Some(reset_time)
119        };
120
121        log::trace!("updating rate limit time: {:?}", updated_time);
122
123        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
124    }
125    async fn send_request(
126        &self,
127        api_key: &str,
128        spans: Vec<&str>,
129        request_timeout: u64,
130    ) -> Result<Response<AsyncBody>> {
131        let request = Request::post("https://api.openai.com/v1/embeddings")
132            .redirect_policy(isahc::config::RedirectPolicy::Follow)
133            .timeout(Duration::from_secs(request_timeout))
134            .header("Content-Type", "application/json")
135            .header("Authorization", format!("Bearer {}", api_key))
136            .body(
137                serde_json::to_string(&OpenAIEmbeddingRequest {
138                    input: spans.clone(),
139                    model: "text-embedding-ada-002",
140                })
141                .unwrap()
142                .into(),
143            )?;
144
145        Ok(self.client.send(request).await?)
146    }
147}
148
149#[async_trait]
150impl CredentialProvider for OpenAIEmbeddingProvider {
151    fn has_credentials(&self) -> bool {
152        match *self.credential.read() {
153            ProviderCredential::Credentials { .. } => true,
154            _ => false,
155        }
156    }
157    async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
158        let existing_credential = self.credential.read().clone();
159
160        let retrieved_credential = cx
161            .run_on_main(move |cx| match existing_credential {
162                ProviderCredential::Credentials { .. } => {
163                    return existing_credential.clone();
164                }
165                _ => {
166                    if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
167                        return ProviderCredential::Credentials { api_key };
168                    }
169
170                    if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
171                    {
172                        if let Some(api_key) = String::from_utf8(api_key).log_err() {
173                            return ProviderCredential::Credentials { api_key };
174                        } else {
175                            return ProviderCredential::NoCredentials;
176                        }
177                    } else {
178                        return ProviderCredential::NoCredentials;
179                    }
180                }
181            })
182            .await;
183
184        *self.credential.write() = retrieved_credential.clone();
185        retrieved_credential
186    }
187
188    async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
189        *self.credential.write() = credential.clone();
190        let credential = credential.clone();
191        cx.run_on_main(move |cx| match credential {
192            ProviderCredential::Credentials { api_key } => {
193                cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
194                    .log_err();
195            }
196            _ => {}
197        })
198        .await;
199    }
200    async fn delete_credentials(&self, cx: &mut AppContext) {
201        cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
202            .await;
203        *self.credential.write() = ProviderCredential::NoCredentials;
204    }
205}
206
207#[async_trait]
208impl EmbeddingProvider for OpenAIEmbeddingProvider {
209    fn base_model(&self) -> Box<dyn LanguageModel> {
210        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
211        model
212    }
213
214    fn max_tokens_per_batch(&self) -> usize {
215        50000
216    }
217
218    fn rate_limit_expiration(&self) -> Option<Instant> {
219        *self.rate_limit_count_rx.borrow()
220    }
221
222    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
223        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
224        const MAX_RETRIES: usize = 4;
225
226        let api_key = self.get_api_key()?;
227
228        let mut request_number = 0;
229        let mut rate_limiting = false;
230        let mut request_timeout: u64 = 15;
231        let mut response: Response<AsyncBody>;
232        while request_number < MAX_RETRIES {
233            response = self
234                .send_request(
235                    &api_key,
236                    spans.iter().map(|x| &**x).collect(),
237                    request_timeout,
238                )
239                .await?;
240
241            request_number += 1;
242
243            match response.status() {
244                StatusCode::REQUEST_TIMEOUT => {
245                    request_timeout += 5;
246                }
247                StatusCode::OK => {
248                    let mut body = String::new();
249                    response.body_mut().read_to_string(&mut body).await?;
250                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
251
252                    log::trace!(
253                        "openai embedding completed. tokens: {:?}",
254                        response.usage.total_tokens
255                    );
256
257                    // If we complete a request successfully that was previously rate_limited
258                    // resolve the rate limit
259                    if rate_limiting {
260                        self.resolve_rate_limit()
261                    }
262
263                    return Ok(response
264                        .data
265                        .into_iter()
266                        .map(|embedding| Embedding::from(embedding.embedding))
267                        .collect());
268                }
269                StatusCode::TOO_MANY_REQUESTS => {
270                    rate_limiting = true;
271                    let mut body = String::new();
272                    response.body_mut().read_to_string(&mut body).await?;
273
274                    let delay_duration = {
275                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
276                        if let Some(time_to_reset) =
277                            response.headers().get("x-ratelimit-reset-tokens")
278                        {
279                            if let Ok(time_str) = time_to_reset.to_str() {
280                                parse(time_str).unwrap_or(delay)
281                            } else {
282                                delay
283                            }
284                        } else {
285                            delay
286                        }
287                    };
288
289                    // If we've previously rate limited, increment the duration but not the count
290                    let reset_time = Instant::now().add(delay_duration);
291                    self.update_reset_time(reset_time);
292
293                    log::trace!(
294                        "openai rate limiting: waiting {:?} until lifted",
295                        &delay_duration
296                    );
297
298                    self.executor.timer(delay_duration).await;
299                }
300                _ => {
301                    let mut body = String::new();
302                    response.body_mut().read_to_string(&mut body).await?;
303                    return Err(anyhow!(
304                        "open ai bad request: {:?} {:?}",
305                        &response.status(),
306                        body
307                    ));
308                }
309            }
310        }
311        Err(anyhow!("openai max retries"))
312    }
313}