// embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::future::BoxFuture;
  4use futures::AsyncReadExt;
  5use futures::FutureExt;
  6use gpui::AppContext;
  7use gpui::BackgroundExecutor;
  8use isahc::http::StatusCode;
  9use isahc::prelude::Configurable;
 10use isahc::{AsyncBody, Response};
 11use lazy_static::lazy_static;
 12use parking_lot::{Mutex, RwLock};
 13use parse_duration::parse;
 14use postage::watch;
 15use serde::{Deserialize, Serialize};
 16use serde_json;
 17use std::env;
 18use std::ops::Add;
 19use std::sync::Arc;
 20use std::time::{Duration, Instant};
 21use tiktoken_rs::{cl100k_base, CoreBPE};
 22use util::http::{HttpClient, Request};
 23use util::ResultExt;
 24
 25use crate::auth::{CredentialProvider, ProviderCredential};
 26use crate::embedding::{Embedding, EmbeddingProvider};
 27use crate::models::LanguageModel;
 28use crate::providers::open_ai::OpenAiLanguageModel;
 29
 30use crate::providers::open_ai::OPEN_AI_API_URL;
 31
 32lazy_static! {
 33    pub(crate) static ref OPEN_AI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
 34}
 35
 36#[derive(Clone)]
 37pub struct OpenAiEmbeddingProvider {
 38    model: OpenAiLanguageModel,
 39    credential: Arc<RwLock<ProviderCredential>>,
 40    pub client: Arc<dyn HttpClient>,
 41    pub executor: BackgroundExecutor,
 42    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
 43    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
 44}
 45
 46#[derive(Serialize)]
 47struct OpenAiEmbeddingRequest<'a> {
 48    model: &'static str,
 49    input: Vec<&'a str>,
 50}
 51
 52#[derive(Deserialize)]
 53struct OpenAiEmbeddingResponse {
 54    data: Vec<OpenAiEmbedding>,
 55    usage: OpenAiEmbeddingUsage,
 56}
 57
 58#[derive(Debug, Deserialize)]
 59struct OpenAiEmbedding {
 60    embedding: Vec<f32>,
 61    index: usize,
 62    object: String,
 63}
 64
 65#[derive(Deserialize)]
 66struct OpenAiEmbeddingUsage {
 67    prompt_tokens: usize,
 68    total_tokens: usize,
 69}
 70
 71impl OpenAiEmbeddingProvider {
 72    pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
 73        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
 74        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 75
 76        // Loading the model is expensive, so ensure this runs off the main thread.
 77        let model = executor
 78            .spawn(async move { OpenAiLanguageModel::load("text-embedding-ada-002") })
 79            .await;
 80        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 81
 82        OpenAiEmbeddingProvider {
 83            model,
 84            credential,
 85            client,
 86            executor,
 87            rate_limit_count_rx,
 88            rate_limit_count_tx,
 89        }
 90    }
 91
 92    fn get_api_key(&self) -> Result<String> {
 93        match self.credential.read().clone() {
 94            ProviderCredential::Credentials { api_key } => Ok(api_key),
 95            _ => Err(anyhow!("api credentials not provided")),
 96        }
 97    }
 98
 99    fn resolve_rate_limit(&self) {
100        let reset_time = *self.rate_limit_count_tx.lock().borrow();
101
102        if let Some(reset_time) = reset_time {
103            if Instant::now() >= reset_time {
104                *self.rate_limit_count_tx.lock().borrow_mut() = None
105            }
106        }
107
108        log::trace!(
109            "resolving reset time: {:?}",
110            *self.rate_limit_count_tx.lock().borrow()
111        );
112    }
113
114    fn update_reset_time(&self, reset_time: Instant) {
115        let original_time = *self.rate_limit_count_tx.lock().borrow();
116
117        let updated_time = if let Some(original_time) = original_time {
118            if reset_time < original_time {
119                Some(reset_time)
120            } else {
121                Some(original_time)
122            }
123        } else {
124            Some(reset_time)
125        };
126
127        log::trace!("updating rate limit time: {:?}", updated_time);
128
129        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
130    }
131    async fn send_request(
132        &self,
133        api_key: &str,
134        spans: Vec<&str>,
135        request_timeout: u64,
136    ) -> Result<Response<AsyncBody>> {
137        let request = Request::post("https://api.openai.com/v1/embeddings")
138            .redirect_policy(isahc::config::RedirectPolicy::Follow)
139            .timeout(Duration::from_secs(request_timeout))
140            .header("Content-Type", "application/json")
141            .header("Authorization", format!("Bearer {}", api_key))
142            .body(
143                serde_json::to_string(&OpenAiEmbeddingRequest {
144                    input: spans.clone(),
145                    model: "text-embedding-ada-002",
146                })
147                .unwrap()
148                .into(),
149            )?;
150
151        Ok(self.client.send(request).await?)
152    }
153}
154
155impl CredentialProvider for OpenAiEmbeddingProvider {
156    fn has_credentials(&self) -> bool {
157        match *self.credential.read() {
158            ProviderCredential::Credentials { .. } => true,
159            _ => false,
160        }
161    }
162
163    fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
164        let existing_credential = self.credential.read().clone();
165        let retrieved_credential = match existing_credential {
166            ProviderCredential::Credentials { .. } => {
167                return async move { existing_credential }.boxed()
168            }
169            _ => {
170                if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
171                    async move { ProviderCredential::Credentials { api_key } }.boxed()
172                } else {
173                    let credentials = cx.read_credentials(OPEN_AI_API_URL);
174                    async move {
175                        if let Some(Some((_, api_key))) = credentials.await.log_err() {
176                            if let Some(api_key) = String::from_utf8(api_key).log_err() {
177                                ProviderCredential::Credentials { api_key }
178                            } else {
179                                ProviderCredential::NoCredentials
180                            }
181                        } else {
182                            ProviderCredential::NoCredentials
183                        }
184                    }
185                    .boxed()
186                }
187            }
188        };
189
190        async move {
191            let retrieved_credential = retrieved_credential.await;
192            *self.credential.write() = retrieved_credential.clone();
193            retrieved_credential
194        }
195        .boxed()
196    }
197
198    fn save_credentials(
199        &self,
200        cx: &mut AppContext,
201        credential: ProviderCredential,
202    ) -> BoxFuture<()> {
203        *self.credential.write() = credential.clone();
204        let credential = credential.clone();
205        let write_credentials = match credential {
206            ProviderCredential::Credentials { api_key } => {
207                Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
208            }
209            _ => None,
210        };
211
212        async move {
213            if let Some(write_credentials) = write_credentials {
214                write_credentials.await.log_err();
215            }
216        }
217        .boxed()
218    }
219
220    fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
221        *self.credential.write() = ProviderCredential::NoCredentials;
222        let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
223        async move {
224            delete_credentials.await.log_err();
225        }
226        .boxed()
227    }
228}
229
230#[async_trait]
231impl EmbeddingProvider for OpenAiEmbeddingProvider {
232    fn base_model(&self) -> Box<dyn LanguageModel> {
233        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
234        model
235    }
236
237    fn max_tokens_per_batch(&self) -> usize {
238        50000
239    }
240
241    fn rate_limit_expiration(&self) -> Option<Instant> {
242        *self.rate_limit_count_rx.borrow()
243    }
244
245    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
246        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
247        const MAX_RETRIES: usize = 4;
248
249        let api_key = self.get_api_key()?;
250
251        let mut request_number = 0;
252        let mut rate_limiting = false;
253        let mut request_timeout: u64 = 15;
254        let mut response: Response<AsyncBody>;
255        while request_number < MAX_RETRIES {
256            response = self
257                .send_request(
258                    &api_key,
259                    spans.iter().map(|x| &**x).collect(),
260                    request_timeout,
261                )
262                .await?;
263
264            request_number += 1;
265
266            match response.status() {
267                StatusCode::REQUEST_TIMEOUT => {
268                    request_timeout += 5;
269                }
270                StatusCode::OK => {
271                    let mut body = String::new();
272                    response.body_mut().read_to_string(&mut body).await?;
273                    let response: OpenAiEmbeddingResponse = serde_json::from_str(&body)?;
274
275                    log::trace!(
276                        "openai embedding completed. tokens: {:?}",
277                        response.usage.total_tokens
278                    );
279
280                    // If we complete a request successfully that was previously rate_limited
281                    // resolve the rate limit
282                    if rate_limiting {
283                        self.resolve_rate_limit()
284                    }
285
286                    return Ok(response
287                        .data
288                        .into_iter()
289                        .map(|embedding| Embedding::from(embedding.embedding))
290                        .collect());
291                }
292                StatusCode::TOO_MANY_REQUESTS => {
293                    rate_limiting = true;
294                    let mut body = String::new();
295                    response.body_mut().read_to_string(&mut body).await?;
296
297                    let delay_duration = {
298                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
299                        if let Some(time_to_reset) =
300                            response.headers().get("x-ratelimit-reset-tokens")
301                        {
302                            if let Ok(time_str) = time_to_reset.to_str() {
303                                parse(time_str).unwrap_or(delay)
304                            } else {
305                                delay
306                            }
307                        } else {
308                            delay
309                        }
310                    };
311
312                    // If we've previously rate limited, increment the duration but not the count
313                    let reset_time = Instant::now().add(delay_duration);
314                    self.update_reset_time(reset_time);
315
316                    log::trace!(
317                        "openai rate limiting: waiting {:?} until lifted",
318                        &delay_duration
319                    );
320
321                    self.executor.timer(delay_duration).await;
322                }
323                _ => {
324                    let mut body = String::new();
325                    response.body_mut().read_to_string(&mut body).await?;
326                    return Err(anyhow!(
327                        "open ai bad request: {:?} {:?}",
328                        &response.status(),
329                        body
330                    ));
331                }
332            }
333        }
334        Err(anyhow!("openai max retries"))
335    }
336}