// embedding.rs — OpenAI embedding provider.

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::future::BoxFuture;
  4use futures::AsyncReadExt;
  5use futures::FutureExt;
  6use gpui::AppContext;
  7use gpui::BackgroundExecutor;
  8use isahc::http::StatusCode;
  9use isahc::prelude::Configurable;
 10use isahc::{AsyncBody, Response};
 11use parking_lot::{Mutex, RwLock};
 12use parse_duration::parse;
 13use postage::watch;
 14use serde::{Deserialize, Serialize};
 15use serde_json;
 16use std::env;
 17use std::ops::Add;
 18use std::sync::{Arc, OnceLock};
 19use std::time::{Duration, Instant};
 20use tiktoken_rs::{cl100k_base, CoreBPE};
 21use util::http::{HttpClient, Request};
 22use util::ResultExt;
 23
 24use crate::auth::{CredentialProvider, ProviderCredential};
 25use crate::embedding::{Embedding, EmbeddingProvider};
 26use crate::models::LanguageModel;
 27use crate::providers::open_ai::OpenAiLanguageModel;
 28
 29use crate::providers::open_ai::OPEN_AI_API_URL;
 30
 31pub(crate) fn open_ai_bpe_tokenizer() -> &'static CoreBPE {
 32    static OPEN_AI_BPE_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
 33    OPEN_AI_BPE_TOKENIZER.get_or_init(|| cl100k_base().unwrap())
 34}
 35
/// Embedding provider backed by the OpenAI `/embeddings` HTTP endpoint.
#[derive(Clone)]
pub struct OpenAiEmbeddingProvider {
    // Base URL the embeddings path is appended to (see `send_request`).
    api_url: String,
    // Language model used for token counting; loaded once in `new`.
    model: OpenAiLanguageModel,
    // Cached API credential, shared across clones of this provider.
    credential: Arc<RwLock<ProviderCredential>>,
    pub client: Arc<dyn HttpClient>,
    pub executor: BackgroundExecutor,
    // Read side of the rate-limit watch: `None` when not rate limited,
    // otherwise the instant the limit is expected to lift.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    // Write side, mutex-guarded so `&self` methods can update it.
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
 46
/// JSON request body for the OpenAI `/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    // Model identifier, e.g. "text-embedding-ada-002".
    model: &'static str,
    // Text spans to embed; borrowed to avoid copying the caller's strings.
    input: Vec<&'a str>,
}
 52
/// Top-level JSON response from the `/embeddings` endpoint.
#[derive(Deserialize)]
struct OpenAiEmbeddingResponse {
    // One embedding entry per input span.
    data: Vec<OpenAiEmbedding>,
    // Token accounting for the request.
    usage: OpenAiEmbeddingUsage,
}
 58
/// A single entry in the response's `data` array.
#[derive(Debug, Deserialize)]
struct OpenAiEmbedding {
    embedding: Vec<f32>,
    // Position of the corresponding input span (deserialized but unread here).
    index: usize,
    // Object type tag from the API (deserialized but unread here).
    object: String,
}
 65
/// Token usage reported by the API for an embeddings request.
#[derive(Deserialize)]
struct OpenAiEmbeddingUsage {
    prompt_tokens: usize,
    // Total tokens consumed; logged after each successful request.
    total_tokens: usize,
}
 71
 72impl OpenAiEmbeddingProvider {
 73    pub async fn new(
 74        api_url: String,
 75        client: Arc<dyn HttpClient>,
 76        executor: BackgroundExecutor,
 77    ) -> Self {
 78        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
 79        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 80
 81        // Loading the model is expensive, so ensure this runs off the main thread.
 82        let model = executor
 83            .spawn(async move { OpenAiLanguageModel::load("text-embedding-ada-002") })
 84            .await;
 85        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 86
 87        OpenAiEmbeddingProvider {
 88            api_url,
 89            model,
 90            credential,
 91            client,
 92            executor,
 93            rate_limit_count_rx,
 94            rate_limit_count_tx,
 95        }
 96    }
 97
 98    fn get_api_key(&self) -> Result<String> {
 99        match self.credential.read().clone() {
100            ProviderCredential::Credentials { api_key } => Ok(api_key),
101            _ => Err(anyhow!("api credentials not provided")),
102        }
103    }
104
105    fn resolve_rate_limit(&self) {
106        let reset_time = *self.rate_limit_count_tx.lock().borrow();
107
108        if let Some(reset_time) = reset_time {
109            if Instant::now() >= reset_time {
110                *self.rate_limit_count_tx.lock().borrow_mut() = None
111            }
112        }
113
114        log::trace!(
115            "resolving reset time: {:?}",
116            *self.rate_limit_count_tx.lock().borrow()
117        );
118    }
119
120    fn update_reset_time(&self, reset_time: Instant) {
121        let original_time = *self.rate_limit_count_tx.lock().borrow();
122
123        let updated_time = if let Some(original_time) = original_time {
124            if reset_time < original_time {
125                Some(reset_time)
126            } else {
127                Some(original_time)
128            }
129        } else {
130            Some(reset_time)
131        };
132
133        log::trace!("updating rate limit time: {:?}", updated_time);
134
135        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
136    }
137    async fn send_request(
138        &self,
139        api_url: &str,
140        api_key: &str,
141        spans: Vec<&str>,
142        request_timeout: u64,
143    ) -> Result<Response<AsyncBody>> {
144        let request = Request::post(format!("{api_url}/embeddings"))
145            .redirect_policy(isahc::config::RedirectPolicy::Follow)
146            .timeout(Duration::from_secs(request_timeout))
147            .header("Content-Type", "application/json")
148            .header("Authorization", format!("Bearer {}", api_key))
149            .body(
150                serde_json::to_string(&OpenAiEmbeddingRequest {
151                    input: spans.clone(),
152                    model: "text-embedding-ada-002",
153                })
154                .unwrap()
155                .into(),
156            )?;
157
158        Ok(self.client.send(request).await?)
159    }
160}
161
162impl CredentialProvider for OpenAiEmbeddingProvider {
163    fn has_credentials(&self) -> bool {
164        match *self.credential.read() {
165            ProviderCredential::Credentials { .. } => true,
166            _ => false,
167        }
168    }
169
170    fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
171        let existing_credential = self.credential.read().clone();
172        let retrieved_credential = match existing_credential {
173            ProviderCredential::Credentials { .. } => {
174                return async move { existing_credential }.boxed()
175            }
176            _ => {
177                if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
178                    async move { ProviderCredential::Credentials { api_key } }.boxed()
179                } else {
180                    let credentials = cx.read_credentials(OPEN_AI_API_URL);
181                    async move {
182                        if let Some(Some((_, api_key))) = credentials.await.log_err() {
183                            if let Some(api_key) = String::from_utf8(api_key).log_err() {
184                                ProviderCredential::Credentials { api_key }
185                            } else {
186                                ProviderCredential::NoCredentials
187                            }
188                        } else {
189                            ProviderCredential::NoCredentials
190                        }
191                    }
192                    .boxed()
193                }
194            }
195        };
196
197        async move {
198            let retrieved_credential = retrieved_credential.await;
199            *self.credential.write() = retrieved_credential.clone();
200            retrieved_credential
201        }
202        .boxed()
203    }
204
205    fn save_credentials(
206        &self,
207        cx: &mut AppContext,
208        credential: ProviderCredential,
209    ) -> BoxFuture<()> {
210        *self.credential.write() = credential.clone();
211        let credential = credential.clone();
212        let write_credentials = match credential {
213            ProviderCredential::Credentials { api_key } => {
214                Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
215            }
216            _ => None,
217        };
218
219        async move {
220            if let Some(write_credentials) = write_credentials {
221                write_credentials.await.log_err();
222            }
223        }
224        .boxed()
225    }
226
227    fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
228        *self.credential.write() = ProviderCredential::NoCredentials;
229        let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
230        async move {
231            delete_credentials.await.log_err();
232        }
233        .boxed()
234    }
235}
236
237#[async_trait]
238impl EmbeddingProvider for OpenAiEmbeddingProvider {
239    fn base_model(&self) -> Box<dyn LanguageModel> {
240        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
241        model
242    }
243
244    fn max_tokens_per_batch(&self) -> usize {
245        50000
246    }
247
248    fn rate_limit_expiration(&self) -> Option<Instant> {
249        *self.rate_limit_count_rx.borrow()
250    }
251
252    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
253        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
254        const MAX_RETRIES: usize = 4;
255
256        let api_url = self.api_url.as_str();
257        let api_key = self.get_api_key()?;
258
259        let mut request_number = 0;
260        let mut rate_limiting = false;
261        let mut request_timeout: u64 = 15;
262        let mut response: Response<AsyncBody>;
263        while request_number < MAX_RETRIES {
264            response = self
265                .send_request(
266                    &api_url,
267                    &api_key,
268                    spans.iter().map(|x| &**x).collect(),
269                    request_timeout,
270                )
271                .await?;
272
273            request_number += 1;
274
275            match response.status() {
276                StatusCode::REQUEST_TIMEOUT => {
277                    request_timeout += 5;
278                }
279                StatusCode::OK => {
280                    let mut body = String::new();
281                    response.body_mut().read_to_string(&mut body).await?;
282                    let response: OpenAiEmbeddingResponse = serde_json::from_str(&body)?;
283
284                    log::trace!(
285                        "openai embedding completed. tokens: {:?}",
286                        response.usage.total_tokens
287                    );
288
289                    // If we complete a request successfully that was previously rate_limited
290                    // resolve the rate limit
291                    if rate_limiting {
292                        self.resolve_rate_limit()
293                    }
294
295                    return Ok(response
296                        .data
297                        .into_iter()
298                        .map(|embedding| Embedding::from(embedding.embedding))
299                        .collect());
300                }
301                StatusCode::TOO_MANY_REQUESTS => {
302                    rate_limiting = true;
303                    let mut body = String::new();
304                    response.body_mut().read_to_string(&mut body).await?;
305
306                    let delay_duration = {
307                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
308                        if let Some(time_to_reset) =
309                            response.headers().get("x-ratelimit-reset-tokens")
310                        {
311                            if let Ok(time_str) = time_to_reset.to_str() {
312                                parse(time_str).unwrap_or(delay)
313                            } else {
314                                delay
315                            }
316                        } else {
317                            delay
318                        }
319                    };
320
321                    // If we've previously rate limited, increment the duration but not the count
322                    let reset_time = Instant::now().add(delay_duration);
323                    self.update_reset_time(reset_time);
324
325                    log::trace!(
326                        "openai rate limiting: waiting {:?} until lifted",
327                        &delay_duration
328                    );
329
330                    self.executor.timer(delay_duration).await;
331                }
332                _ => {
333                    let mut body = String::new();
334                    response.body_mut().read_to_string(&mut body).await?;
335                    return Err(anyhow!(
336                        "open ai bad request: {:?} {:?}",
337                        &response.status(),
338                        body
339                    ));
340                }
341            }
342        }
343        Err(anyhow!("openai max retries"))
344    }
345}