embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::executor::Background;
  5use gpui::serde_json;
  6use isahc::http::StatusCode;
  7use isahc::prelude::Configurable;
  8use isahc::{AsyncBody, Response};
  9use lazy_static::lazy_static;
 10use parking_lot::Mutex;
 11use parse_duration::parse;
 12use postage::watch;
 13use serde::{Deserialize, Serialize};
 14use std::env;
 15use std::ops::Add;
 16use std::sync::Arc;
 17use std::time::{Duration, Instant};
 18use tiktoken_rs::{cl100k_base, CoreBPE};
 19use util::http::{HttpClient, Request};
 20
 21use crate::auth::{CredentialProvider, ProviderCredential};
 22use crate::embedding::{Embedding, EmbeddingProvider};
 23use crate::models::LanguageModel;
 24use crate::providers::open_ai::OpenAILanguageModel;
 25
 26use crate::providers::open_ai::auth::OpenAICredentialProvider;
 27
 28lazy_static! {
 29    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
 30}
 31
 32#[derive(Clone)]
 33pub struct OpenAIEmbeddingProvider {
 34    model: OpenAILanguageModel,
 35    credential_provider: OpenAICredentialProvider,
 36    pub client: Arc<dyn HttpClient>,
 37    pub executor: Arc<Background>,
 38    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
 39    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
 40}
 41
 42#[derive(Serialize)]
 43struct OpenAIEmbeddingRequest<'a> {
 44    model: &'static str,
 45    input: Vec<&'a str>,
 46}
 47
 48#[derive(Deserialize)]
 49struct OpenAIEmbeddingResponse {
 50    data: Vec<OpenAIEmbedding>,
 51    usage: OpenAIEmbeddingUsage,
 52}
 53
 54#[derive(Debug, Deserialize)]
 55struct OpenAIEmbedding {
 56    embedding: Vec<f32>,
 57    index: usize,
 58    object: String,
 59}
 60
 61#[derive(Deserialize)]
 62struct OpenAIEmbeddingUsage {
 63    prompt_tokens: usize,
 64    total_tokens: usize,
 65}
 66
 67impl OpenAIEmbeddingProvider {
 68    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
 69        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
 70        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 71
 72        let model = OpenAILanguageModel::load("text-embedding-ada-002");
 73
 74        OpenAIEmbeddingProvider {
 75            model,
 76            credential_provider: OpenAICredentialProvider {},
 77            client,
 78            executor,
 79            rate_limit_count_rx,
 80            rate_limit_count_tx,
 81        }
 82    }
 83
 84    fn resolve_rate_limit(&self) {
 85        let reset_time = *self.rate_limit_count_tx.lock().borrow();
 86
 87        if let Some(reset_time) = reset_time {
 88            if Instant::now() >= reset_time {
 89                *self.rate_limit_count_tx.lock().borrow_mut() = None
 90            }
 91        }
 92
 93        log::trace!(
 94            "resolving reset time: {:?}",
 95            *self.rate_limit_count_tx.lock().borrow()
 96        );
 97    }
 98
 99    fn update_reset_time(&self, reset_time: Instant) {
100        let original_time = *self.rate_limit_count_tx.lock().borrow();
101
102        let updated_time = if let Some(original_time) = original_time {
103            if reset_time < original_time {
104                Some(reset_time)
105            } else {
106                Some(original_time)
107            }
108        } else {
109            Some(reset_time)
110        };
111
112        log::trace!("updating rate limit time: {:?}", updated_time);
113
114        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
115    }
116    async fn send_request(
117        &self,
118        api_key: &str,
119        spans: Vec<&str>,
120        request_timeout: u64,
121    ) -> Result<Response<AsyncBody>> {
122        let request = Request::post("https://api.openai.com/v1/embeddings")
123            .redirect_policy(isahc::config::RedirectPolicy::Follow)
124            .timeout(Duration::from_secs(request_timeout))
125            .header("Content-Type", "application/json")
126            .header("Authorization", format!("Bearer {}", api_key))
127            .body(
128                serde_json::to_string(&OpenAIEmbeddingRequest {
129                    input: spans.clone(),
130                    model: "text-embedding-ada-002",
131                })
132                .unwrap()
133                .into(),
134            )?;
135
136        Ok(self.client.send(request).await?)
137    }
138}
139
140#[async_trait]
141impl EmbeddingProvider for OpenAIEmbeddingProvider {
142    fn base_model(&self) -> Box<dyn LanguageModel> {
143        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
144        model
145    }
146
147    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
148        let credential_provider: Box<dyn CredentialProvider> =
149            Box::new(self.credential_provider.clone());
150        credential_provider
151    }
152
153    fn max_tokens_per_batch(&self) -> usize {
154        50000
155    }
156
157    fn rate_limit_expiration(&self) -> Option<Instant> {
158        *self.rate_limit_count_rx.borrow()
159    }
160
161    async fn embed_batch(
162        &self,
163        spans: Vec<String>,
164        credential: ProviderCredential,
165    ) -> Result<Vec<Embedding>> {
166        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
167        const MAX_RETRIES: usize = 4;
168
169        let api_key = match credential {
170            ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key),
171            _ => Err(anyhow!("no api key provided")),
172        }?;
173
174        let mut request_number = 0;
175        let mut rate_limiting = false;
176        let mut request_timeout: u64 = 15;
177        let mut response: Response<AsyncBody>;
178        while request_number < MAX_RETRIES {
179            response = self
180                .send_request(
181                    &api_key,
182                    spans.iter().map(|x| &**x).collect(),
183                    request_timeout,
184                )
185                .await?;
186
187            request_number += 1;
188
189            match response.status() {
190                StatusCode::REQUEST_TIMEOUT => {
191                    request_timeout += 5;
192                }
193                StatusCode::OK => {
194                    let mut body = String::new();
195                    response.body_mut().read_to_string(&mut body).await?;
196                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
197
198                    log::trace!(
199                        "openai embedding completed. tokens: {:?}",
200                        response.usage.total_tokens
201                    );
202
203                    // If we complete a request successfully that was previously rate_limited
204                    // resolve the rate limit
205                    if rate_limiting {
206                        self.resolve_rate_limit()
207                    }
208
209                    return Ok(response
210                        .data
211                        .into_iter()
212                        .map(|embedding| Embedding::from(embedding.embedding))
213                        .collect());
214                }
215                StatusCode::TOO_MANY_REQUESTS => {
216                    rate_limiting = true;
217                    let mut body = String::new();
218                    response.body_mut().read_to_string(&mut body).await?;
219
220                    let delay_duration = {
221                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
222                        if let Some(time_to_reset) =
223                            response.headers().get("x-ratelimit-reset-tokens")
224                        {
225                            if let Ok(time_str) = time_to_reset.to_str() {
226                                parse(time_str).unwrap_or(delay)
227                            } else {
228                                delay
229                            }
230                        } else {
231                            delay
232                        }
233                    };
234
235                    // If we've previously rate limited, increment the duration but not the count
236                    let reset_time = Instant::now().add(delay_duration);
237                    self.update_reset_time(reset_time);
238
239                    log::trace!(
240                        "openai rate limiting: waiting {:?} until lifted",
241                        &delay_duration
242                    );
243
244                    self.executor.timer(delay_duration).await;
245                }
246                _ => {
247                    let mut body = String::new();
248                    response.body_mut().read_to_string(&mut body).await?;
249                    return Err(anyhow!(
250                        "open ai bad request: {:?} {:?}",
251                        &response.status(),
252                        body
253                    ));
254                }
255            }
256        }
257        Err(anyhow!("openai max retries"))
258    }
259}