// embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::AppContext;
  5use gpui::BackgroundExecutor;
  6use isahc::http::StatusCode;
  7use isahc::prelude::Configurable;
  8use isahc::{AsyncBody, Response};
  9use lazy_static::lazy_static;
 10use parking_lot::{Mutex, RwLock};
 11use parse_duration::parse;
 12use postage::watch;
 13use serde::{Deserialize, Serialize};
 14use serde_json;
 15use std::env;
 16use std::ops::Add;
 17use std::sync::Arc;
 18use std::time::{Duration, Instant};
 19use tiktoken_rs::{cl100k_base, CoreBPE};
 20use util::http::{HttpClient, Request};
 21use util::ResultExt;
 22
 23use crate::auth::{CredentialProvider, ProviderCredential};
 24use crate::embedding::{Embedding, EmbeddingProvider};
 25use crate::models::LanguageModel;
 26use crate::providers::open_ai::OpenAILanguageModel;
 27
 28use crate::providers::open_ai::OPENAI_API_URL;
 29
lazy_static! {
    // Shared cl100k_base BPE tokenizer (the encoding used by OpenAI's
    // embedding models), built once on first access.
    // NOTE(review): `unwrap` panics at first use if the embedded BPE data
    // fails to load — treated as a startup invariant.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
 33
/// Embedding provider backed by OpenAI's `text-embedding-ada-002` model.
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
    model: OpenAILanguageModel,
    // In-memory cache of the API credential; populated via the
    // `CredentialProvider` impl below.
    credential: Arc<RwLock<ProviderCredential>>,
    pub client: Arc<dyn HttpClient>,
    pub executor: BackgroundExecutor,
    // Watch pair tracking when an active rate limit expires
    // (`None` = not currently rate limited).
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    // Sender is Mutex-wrapped so `&self` methods can update the value.
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
 43
/// JSON request body for OpenAI's `/v1/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
 49
/// Top-level JSON response body from OpenAI's `/v1/embeddings` endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
 55
/// A single entry in the response's `data` array.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    // `index` and `object` are deserialized but not read here; kept to
    // mirror the API schema.
    index: usize,
    object: String,
}
 62
/// Token accounting reported by the API; `total_tokens` is logged after
/// each successful request.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
 68
 69impl OpenAIEmbeddingProvider {
 70    pub fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
 71        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
 72        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 73
 74        let model = OpenAILanguageModel::load("text-embedding-ada-002");
 75        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 76
 77        OpenAIEmbeddingProvider {
 78            model,
 79            credential,
 80            client,
 81            executor,
 82            rate_limit_count_rx,
 83            rate_limit_count_tx,
 84        }
 85    }
 86
 87    fn get_api_key(&self) -> Result<String> {
 88        match self.credential.read().clone() {
 89            ProviderCredential::Credentials { api_key } => Ok(api_key),
 90            _ => Err(anyhow!("api credentials not provided")),
 91        }
 92    }
 93
 94    fn resolve_rate_limit(&self) {
 95        let reset_time = *self.rate_limit_count_tx.lock().borrow();
 96
 97        if let Some(reset_time) = reset_time {
 98            if Instant::now() >= reset_time {
 99                *self.rate_limit_count_tx.lock().borrow_mut() = None
100            }
101        }
102
103        log::trace!(
104            "resolving reset time: {:?}",
105            *self.rate_limit_count_tx.lock().borrow()
106        );
107    }
108
109    fn update_reset_time(&self, reset_time: Instant) {
110        let original_time = *self.rate_limit_count_tx.lock().borrow();
111
112        let updated_time = if let Some(original_time) = original_time {
113            if reset_time < original_time {
114                Some(reset_time)
115            } else {
116                Some(original_time)
117            }
118        } else {
119            Some(reset_time)
120        };
121
122        log::trace!("updating rate limit time: {:?}", updated_time);
123
124        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
125    }
126    async fn send_request(
127        &self,
128        api_key: &str,
129        spans: Vec<&str>,
130        request_timeout: u64,
131    ) -> Result<Response<AsyncBody>> {
132        let request = Request::post("https://api.openai.com/v1/embeddings")
133            .redirect_policy(isahc::config::RedirectPolicy::Follow)
134            .timeout(Duration::from_secs(request_timeout))
135            .header("Content-Type", "application/json")
136            .header("Authorization", format!("Bearer {}", api_key))
137            .body(
138                serde_json::to_string(&OpenAIEmbeddingRequest {
139                    input: spans.clone(),
140                    model: "text-embedding-ada-002",
141                })
142                .unwrap()
143                .into(),
144            )?;
145
146        Ok(self.client.send(request).await?)
147    }
148}
149
150impl CredentialProvider for OpenAIEmbeddingProvider {
151    fn has_credentials(&self) -> bool {
152        match *self.credential.read() {
153            ProviderCredential::Credentials { .. } => true,
154            _ => false,
155        }
156    }
157    fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
158        let existing_credential = self.credential.read().clone();
159
160        let retrieved_credential = match existing_credential {
161            ProviderCredential::Credentials { .. } => existing_credential.clone(),
162            _ => {
163                if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
164                    ProviderCredential::Credentials { api_key }
165                } else if let Some(Some((_, api_key))) =
166                    cx.read_credentials(OPENAI_API_URL).log_err()
167                {
168                    if let Some(api_key) = String::from_utf8(api_key).log_err() {
169                        ProviderCredential::Credentials { api_key }
170                    } else {
171                        ProviderCredential::NoCredentials
172                    }
173                } else {
174                    ProviderCredential::NoCredentials
175                }
176            }
177        };
178
179        *self.credential.write() = retrieved_credential.clone();
180        retrieved_credential
181    }
182
183    fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
184        *self.credential.write() = credential.clone();
185        match credential {
186            ProviderCredential::Credentials { api_key } => {
187                cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
188                    .log_err();
189            }
190            _ => {}
191        }
192    }
193
194    fn delete_credentials(&self, cx: &mut AppContext) {
195        cx.delete_credentials(OPENAI_API_URL).log_err();
196        *self.credential.write() = ProviderCredential::NoCredentials;
197    }
198}
199
200#[async_trait]
201impl EmbeddingProvider for OpenAIEmbeddingProvider {
202    fn base_model(&self) -> Box<dyn LanguageModel> {
203        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
204        model
205    }
206
207    fn max_tokens_per_batch(&self) -> usize {
208        50000
209    }
210
211    fn rate_limit_expiration(&self) -> Option<Instant> {
212        *self.rate_limit_count_rx.borrow()
213    }
214
215    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
216        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
217        const MAX_RETRIES: usize = 4;
218
219        let api_key = self.get_api_key()?;
220
221        let mut request_number = 0;
222        let mut rate_limiting = false;
223        let mut request_timeout: u64 = 15;
224        let mut response: Response<AsyncBody>;
225        while request_number < MAX_RETRIES {
226            response = self
227                .send_request(
228                    &api_key,
229                    spans.iter().map(|x| &**x).collect(),
230                    request_timeout,
231                )
232                .await?;
233
234            request_number += 1;
235
236            match response.status() {
237                StatusCode::REQUEST_TIMEOUT => {
238                    request_timeout += 5;
239                }
240                StatusCode::OK => {
241                    let mut body = String::new();
242                    response.body_mut().read_to_string(&mut body).await?;
243                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
244
245                    log::trace!(
246                        "openai embedding completed. tokens: {:?}",
247                        response.usage.total_tokens
248                    );
249
250                    // If we complete a request successfully that was previously rate_limited
251                    // resolve the rate limit
252                    if rate_limiting {
253                        self.resolve_rate_limit()
254                    }
255
256                    return Ok(response
257                        .data
258                        .into_iter()
259                        .map(|embedding| Embedding::from(embedding.embedding))
260                        .collect());
261                }
262                StatusCode::TOO_MANY_REQUESTS => {
263                    rate_limiting = true;
264                    let mut body = String::new();
265                    response.body_mut().read_to_string(&mut body).await?;
266
267                    let delay_duration = {
268                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
269                        if let Some(time_to_reset) =
270                            response.headers().get("x-ratelimit-reset-tokens")
271                        {
272                            if let Ok(time_str) = time_to_reset.to_str() {
273                                parse(time_str).unwrap_or(delay)
274                            } else {
275                                delay
276                            }
277                        } else {
278                            delay
279                        }
280                    };
281
282                    // If we've previously rate limited, increment the duration but not the count
283                    let reset_time = Instant::now().add(delay_duration);
284                    self.update_reset_time(reset_time);
285
286                    log::trace!(
287                        "openai rate limiting: waiting {:?} until lifted",
288                        &delay_duration
289                    );
290
291                    self.executor.timer(delay_duration).await;
292                }
293                _ => {
294                    let mut body = String::new();
295                    response.body_mut().read_to_string(&mut body).await?;
296                    return Err(anyhow!(
297                        "open ai bad request: {:?} {:?}",
298                        &response.status(),
299                        body
300                    ));
301                }
302            }
303        }
304        Err(anyhow!("openai max retries"))
305    }
306}