// embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::executor::Background;
  5use gpui::serde_json;
  6use isahc::http::StatusCode;
  7use isahc::prelude::Configurable;
  8use isahc::{AsyncBody, Response};
  9use lazy_static::lazy_static;
 10use parking_lot::Mutex;
 11use parse_duration::parse;
 12use postage::watch;
 13use serde::{Deserialize, Serialize};
 14use std::env;
 15use std::ops::Add;
 16use std::sync::Arc;
 17use std::time::{Duration, Instant};
 18use tiktoken_rs::{cl100k_base, CoreBPE};
 19use util::http::{HttpClient, Request};
 20
 21use crate::embedding::{Embedding, EmbeddingProvider};
 22use crate::models::LanguageModel;
 23use crate::providers::open_ai::OpenAILanguageModel;
 24
lazy_static! {
    // API key read from the `OPENAI_API_KEY` environment variable on first access;
    // `None` when the variable is unset (or not valid unicode).
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // cl100k_base tokenizer used by `truncate` to count and cut span tokens.
    // Panics on first access if the tokenizer cannot be constructed.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
 29
/// Embedding provider that sends spans to the OpenAI `/v1/embeddings` endpoint.
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
    /// Model handle returned (cloned) from `base_model`.
    model: OpenAILanguageModel,
    /// HTTP client used to send embedding requests.
    pub client: Arc<dyn HttpClient>,
    /// Executor used to sleep between rate-limited retries.
    pub executor: Arc<Background>,
    /// Receiving side of the rate-limit watch: `Some(instant)` while a reset time
    /// has been recorded and not yet cleared.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    /// Sending side of the rate-limit watch, shared between clones behind a mutex.
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
 38
/// JSON body posted to `https://api.openai.com/v1/embeddings`.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    /// Embedding model name.
    model: &'static str,
    /// Texts to embed, borrowed from the caller's spans.
    input: Vec<&'a str>,
}
 44
/// Successful (HTTP 200) response body from the embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    /// One entry per input span.
    data: Vec<OpenAIEmbedding>,
    /// Token accounting for the request.
    usage: OpenAIEmbeddingUsage,
}
 50
/// A single embedding vector within an `OpenAIEmbeddingResponse`.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    /// The embedding values.
    embedding: Vec<f32>,
    /// Position of the corresponding span in the request's `input` list.
    index: usize,
    /// Object type tag reported by the API (not otherwise used here).
    object: String,
}
 57
/// Token usage reported alongside an embeddings response.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    /// Tokens consumed by the request input.
    prompt_tokens: usize,
    /// Total tokens billed for the request; logged at trace level on success.
    total_tokens: usize,
}
 63
/// Maximum number of cl100k_base tokens `truncate` keeps for a single span.
/// NOTE(review): presumably chosen just under the model's per-input token limit — confirm.
const OPENAI_INPUT_LIMIT: usize = 8190;
 65
 66impl OpenAIEmbeddingProvider {
 67    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
 68        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
 69        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 70
 71        let model = OpenAILanguageModel::load("text-embedding-ada-002");
 72
 73        OpenAIEmbeddingProvider {
 74            model,
 75            client,
 76            executor,
 77            rate_limit_count_rx,
 78            rate_limit_count_tx,
 79        }
 80    }
 81
 82    fn resolve_rate_limit(&self) {
 83        let reset_time = *self.rate_limit_count_tx.lock().borrow();
 84
 85        if let Some(reset_time) = reset_time {
 86            if Instant::now() >= reset_time {
 87                *self.rate_limit_count_tx.lock().borrow_mut() = None
 88            }
 89        }
 90
 91        log::trace!(
 92            "resolving reset time: {:?}",
 93            *self.rate_limit_count_tx.lock().borrow()
 94        );
 95    }
 96
 97    fn update_reset_time(&self, reset_time: Instant) {
 98        let original_time = *self.rate_limit_count_tx.lock().borrow();
 99
100        let updated_time = if let Some(original_time) = original_time {
101            if reset_time < original_time {
102                Some(reset_time)
103            } else {
104                Some(original_time)
105            }
106        } else {
107            Some(reset_time)
108        };
109
110        log::trace!("updating rate limit time: {:?}", updated_time);
111
112        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
113    }
114    async fn send_request(
115        &self,
116        api_key: &str,
117        spans: Vec<&str>,
118        request_timeout: u64,
119    ) -> Result<Response<AsyncBody>> {
120        let request = Request::post("https://api.openai.com/v1/embeddings")
121            .redirect_policy(isahc::config::RedirectPolicy::Follow)
122            .timeout(Duration::from_secs(request_timeout))
123            .header("Content-Type", "application/json")
124            .header("Authorization", format!("Bearer {}", api_key))
125            .body(
126                serde_json::to_string(&OpenAIEmbeddingRequest {
127                    input: spans.clone(),
128                    model: "text-embedding-ada-002",
129                })
130                .unwrap()
131                .into(),
132            )?;
133
134        Ok(self.client.send(request).await?)
135    }
136}
137
138#[async_trait]
139impl EmbeddingProvider for OpenAIEmbeddingProvider {
140    fn base_model(&self) -> Box<dyn LanguageModel> {
141        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
142        model
143    }
144    fn is_authenticated(&self) -> bool {
145        OPENAI_API_KEY.as_ref().is_some()
146    }
147    fn max_tokens_per_batch(&self) -> usize {
148        50000
149    }
150
151    fn rate_limit_expiration(&self) -> Option<Instant> {
152        *self.rate_limit_count_rx.borrow()
153    }
154    fn truncate(&self, span: &str) -> (String, usize) {
155        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
156        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
157            tokens.truncate(OPENAI_INPUT_LIMIT);
158            OPENAI_BPE_TOKENIZER
159                .decode(tokens.clone())
160                .ok()
161                .unwrap_or_else(|| span.to_string())
162        } else {
163            span.to_string()
164        };
165
166        (output, tokens.len())
167    }
168
169    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
170        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
171        const MAX_RETRIES: usize = 4;
172
173        let api_key = OPENAI_API_KEY
174            .as_ref()
175            .ok_or_else(|| anyhow!("no api key"))?;
176
177        let mut request_number = 0;
178        let mut rate_limiting = false;
179        let mut request_timeout: u64 = 15;
180        let mut response: Response<AsyncBody>;
181        while request_number < MAX_RETRIES {
182            response = self
183                .send_request(
184                    api_key,
185                    spans.iter().map(|x| &**x).collect(),
186                    request_timeout,
187                )
188                .await?;
189
190            request_number += 1;
191
192            match response.status() {
193                StatusCode::REQUEST_TIMEOUT => {
194                    request_timeout += 5;
195                }
196                StatusCode::OK => {
197                    let mut body = String::new();
198                    response.body_mut().read_to_string(&mut body).await?;
199                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
200
201                    log::trace!(
202                        "openai embedding completed. tokens: {:?}",
203                        response.usage.total_tokens
204                    );
205
206                    // If we complete a request successfully that was previously rate_limited
207                    // resolve the rate limit
208                    if rate_limiting {
209                        self.resolve_rate_limit()
210                    }
211
212                    return Ok(response
213                        .data
214                        .into_iter()
215                        .map(|embedding| Embedding::from(embedding.embedding))
216                        .collect());
217                }
218                StatusCode::TOO_MANY_REQUESTS => {
219                    rate_limiting = true;
220                    let mut body = String::new();
221                    response.body_mut().read_to_string(&mut body).await?;
222
223                    let delay_duration = {
224                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
225                        if let Some(time_to_reset) =
226                            response.headers().get("x-ratelimit-reset-tokens")
227                        {
228                            if let Ok(time_str) = time_to_reset.to_str() {
229                                parse(time_str).unwrap_or(delay)
230                            } else {
231                                delay
232                            }
233                        } else {
234                            delay
235                        }
236                    };
237
238                    // If we've previously rate limited, increment the duration but not the count
239                    let reset_time = Instant::now().add(delay_duration);
240                    self.update_reset_time(reset_time);
241
242                    log::trace!(
243                        "openai rate limiting: waiting {:?} until lifted",
244                        &delay_duration
245                    );
246
247                    self.executor.timer(delay_duration).await;
248                }
249                _ => {
250                    let mut body = String::new();
251                    response.body_mut().read_to_string(&mut body).await?;
252                    return Err(anyhow!(
253                        "open ai bad request: {:?} {:?}",
254                        &response.status(),
255                        body
256                    ));
257                }
258            }
259        }
260        Err(anyhow!("openai max retries"))
261    }
262}