// embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::executor::Background;
  5use gpui::serde_json;
  6use isahc::http::StatusCode;
  7use isahc::prelude::Configurable;
  8use isahc::{AsyncBody, Response};
  9use lazy_static::lazy_static;
 10use parse_duration::parse;
 11use serde::{Deserialize, Serialize};
 12use std::env;
 13use std::sync::Arc;
 14use std::time::Duration;
 15use tiktoken_rs::{cl100k_base, CoreBPE};
 16use util::http::{HttpClient, Request};
 17
lazy_static! {
    // Read once at first use; `None` when the env var is unset.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // Shared cl100k_base tokenizer (the encoding used by text-embedding-ada-002).
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
 22
/// Embedding provider backed by the OpenAI `/v1/embeddings` HTTP endpoint.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
}
 28
/// JSON request body sent to the OpenAI embeddings API.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
 34
/// Top-level JSON response from the OpenAI embeddings API.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
 40
/// One embedding entry in the response's `data` array.
/// `index` and `object` are part of the wire format but unused here.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}
 47
/// Token accounting reported by the API; logged for observability.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
 53
/// A source of text embeddings.
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Embed each span, returning one vector per input span, in input order.
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
    /// Number of tokens `span` encodes to under this provider's tokenizer.
    fn count_tokens(&self, span: &str) -> usize;
    /// Whether `span` exceeds this provider's input token limit.
    fn should_truncate(&self, span: &str) -> bool;
}
 61
/// No-op provider returning constant vectors; useful for tests and offline runs.
pub struct DummyEmbeddings {}
 63
 64#[async_trait]
 65impl EmbeddingProvider for DummyEmbeddings {
 66    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
 67        // 1024 is the OpenAI Embeddings size for ada models.
 68        // the model we will likely be starting with.
 69        let dummy_vec = vec![0.32 as f32; 1536];
 70        return Ok(vec![dummy_vec; spans.len()]);
 71    }
 72
 73    fn count_tokens(&self, span: &str) -> usize {
 74        // For Dummy Providers, we are going to use OpenAI tokenization for ease
 75        let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
 76        tokens.len()
 77    }
 78
 79    fn should_truncate(&self, span: &str) -> bool {
 80        self.count_tokens(span) > OPENAI_INPUT_LIMIT
 81
 82        // let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
 83        // let Ok(output) = {
 84        //     if tokens.len() > OPENAI_INPUT_LIMIT {
 85        //         tokens.truncate(OPENAI_INPUT_LIMIT);
 86        //         OPENAI_BPE_TOKENIZER.decode(tokens)
 87        //     } else {
 88        //         Ok(span)
 89        //     }
 90        // };
 91    }
 92}
 93
// Maximum number of input tokens the OpenAI embeddings endpoint accepts per span.
const OPENAI_INPUT_LIMIT: usize = 8190;
 95
 96impl OpenAIEmbeddings {
 97    fn truncate(span: String) -> String {
 98        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
 99        if tokens.len() > OPENAI_INPUT_LIMIT {
100            tokens.truncate(OPENAI_INPUT_LIMIT);
101            let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
102            if result.is_ok() {
103                let transformed = result.unwrap();
104                return transformed;
105            }
106        }
107
108        span
109    }
110
111    async fn send_request(
112        &self,
113        api_key: &str,
114        spans: Vec<&str>,
115        request_timeout: u64,
116    ) -> Result<Response<AsyncBody>> {
117        let request = Request::post("https://api.openai.com/v1/embeddings")
118            .redirect_policy(isahc::config::RedirectPolicy::Follow)
119            .timeout(Duration::from_secs(request_timeout))
120            .header("Content-Type", "application/json")
121            .header("Authorization", format!("Bearer {}", api_key))
122            .body(
123                serde_json::to_string(&OpenAIEmbeddingRequest {
124                    input: spans.clone(),
125                    model: "text-embedding-ada-002",
126                })
127                .unwrap()
128                .into(),
129            )?;
130
131        Ok(self.client.send(request).await?)
132    }
133}
134
135#[async_trait]
136impl EmbeddingProvider for OpenAIEmbeddings {
137    fn count_tokens(&self, span: &str) -> usize {
138        // For Dummy Providers, we are going to use OpenAI tokenization for ease
139        let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
140        tokens.len()
141    }
142
143    fn should_truncate(&self, span: &str) -> bool {
144        self.count_tokens(span) > OPENAI_INPUT_LIMIT
145    }
146
147    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
148        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
149        const MAX_RETRIES: usize = 4;
150
151        let api_key = OPENAI_API_KEY
152            .as_ref()
153            .ok_or_else(|| anyhow!("no api key"))?;
154
155        let mut request_number = 0;
156        let mut request_timeout: u64 = 10;
157        let mut truncated = false;
158        let mut response: Response<AsyncBody>;
159        let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
160        while request_number < MAX_RETRIES {
161            response = self
162                .send_request(
163                    api_key,
164                    spans.iter().map(|x| &**x).collect(),
165                    request_timeout,
166                )
167                .await?;
168            request_number += 1;
169
170            match response.status() {
171                StatusCode::REQUEST_TIMEOUT => {
172                    request_timeout += 5;
173                }
174                StatusCode::OK => {
175                    let mut body = String::new();
176                    response.body_mut().read_to_string(&mut body).await?;
177                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
178
179                    log::trace!(
180                        "openai embedding completed. tokens: {:?}",
181                        response.usage.total_tokens
182                    );
183
184                    return Ok(response
185                        .data
186                        .into_iter()
187                        .map(|embedding| embedding.embedding)
188                        .collect());
189                }
190                StatusCode::TOO_MANY_REQUESTS => {
191                    let mut body = String::new();
192                    response.body_mut().read_to_string(&mut body).await?;
193
194                    let delay_duration = {
195                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
196                        if let Some(time_to_reset) =
197                            response.headers().get("x-ratelimit-reset-tokens")
198                        {
199                            if let Ok(time_str) = time_to_reset.to_str() {
200                                parse(time_str).unwrap_or(delay)
201                            } else {
202                                delay
203                            }
204                        } else {
205                            delay
206                        }
207                    };
208
209                    log::trace!(
210                        "openai rate limiting: waiting {:?} until lifted",
211                        &delay_duration
212                    );
213
214                    self.executor.timer(delay_duration).await;
215                }
216                _ => {
217                    // TODO: Move this to parsing step
218                    // Only truncate if it hasnt been truncated before
219                    if !truncated {
220                        for span in spans.iter_mut() {
221                            *span = Self::truncate(span.clone());
222                        }
223                        truncated = true;
224                    } else {
225                        // If failing once already truncated, log the error and break the loop
226                        let mut body = String::new();
227                        response.body_mut().read_to_string(&mut body).await?;
228                        return Err(anyhow!(
229                            "open ai bad request: {:?} {:?}",
230                            &response.status(),
231                            body
232                        ));
233                    }
234                }
235            }
236        }
237        Err(anyhow!("openai max retries"))
238    }
239}