embedding.rs

use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::serde_json;
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parse_duration::parse;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
use std::time::Duration;
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};

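// The API key is read from the environment once; a missing key is reported as
// an error from `embed_batch` rather than at startup.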
lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}

#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
}

#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}

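/// Interface implemented by embedding backends.
///
/// A minimal usage sketch (hypothetical caller; `DummyEmbeddings` needs no
/// API key and returns constant vectors):
///
/// ```ignore
/// let provider = DummyEmbeddings {};
/// let embeddings = provider.embed_batch(vec!["fn main() {}".to_string()]).await?;
/// assert_eq!(embeddings[0].len(), 1536);
/// ```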
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>>;
    fn max_tokens_per_batch(&self) -> usize;
    fn truncate(&self, span: &str) -> (String, usize);
}

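/// A stand-in provider for tests and offline development; it returns the same
/// constant vector for every span.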
pub struct DummyEmbeddings {}

#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
        // 1536 is the embedding dimension of OpenAI's text-embedding-ada-002,
        // the model we will likely be starting with.
        let dummy_vec = vec![0.32f32; 1536];
        Ok(vec![dummy_vec; spans.len()])
    }

    fn max_tokens_per_batch(&self) -> usize {
        OPENAI_INPUT_LIMIT
    }

    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            OPENAI_BPE_TOKENIZER
                .decode(tokens.clone())
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

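        // Report the token count of the (possibly truncated) output.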
        (output, tokens.len())
    }
}

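// text-embedding-ada-002 accepts inputs of up to 8191 tokens; the slightly
// lower limit here presumably leaves a small margin.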
const OPENAI_INPUT_LIMIT: usize = 8190;

impl OpenAIEmbeddings {
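    /// Sends a single embeddings request; retry handling lives in `embed_batch`.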
    async fn send_request(
        &self,
        api_key: &str,
        spans: Vec<&str>,
        request_timeout: u64,
    ) -> Result<Response<AsyncBody>> {
        let body = serde_json::to_string(&OpenAIEmbeddingRequest {
            input: spans,
            model: "text-embedding-ada-002",
        })?;
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .timeout(Duration::from_secs(request_timeout))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(body.into())?;

        Ok(self.client.send(request).await?)
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            OPENAI_BPE_TOKENIZER
                .decode(tokens.clone())
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

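        // Report the token count of the (possibly truncated) output, not the
        // original span's count.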
        (output, tokens.len())
    }

    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
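        // Make at most four attempts; after a rate-limit response, back off
        // 3, 5, 15, then 45 seconds unless the server supplies a reset time.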
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let mut request_number = 0;
        let mut request_timeout: u64 = 10;
        while request_number < MAX_RETRIES {
            let mut response = self
                .send_request(
                    api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;
            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
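                    // The request timed out; allow more time on the next attempt.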
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| embedding.embedding)
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
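                    // Read (and discard) the error body before retrying.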
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

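                    // Prefer the reset interval advertised by the server; fall
                    // back to the fixed backoff schedule.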
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "openai embedding request failed: {:?} {:?}",
                        response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries exceeded"))
    }
}