// embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::executor::Background;
  5use gpui::serde_json;
  6use isahc::http::StatusCode;
  7use isahc::prelude::Configurable;
  8use isahc::{AsyncBody, Response};
  9use lazy_static::lazy_static;
 10use parse_duration::parse;
 11use serde::{Deserialize, Serialize};
 12use std::env;
 13use std::sync::Arc;
 14use std::time::Duration;
 15use tiktoken_rs::{cl100k_base, CoreBPE};
 16use util::http::{HttpClient, Request};
 17
lazy_static! {
    // OpenAI API key, read once from the environment; `None` when unset.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // cl100k_base is the BPE tokenizer matching the text-embedding-ada-002 model.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
 22
/// Embedding provider backed by the OpenAI embeddings HTTP API.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    // Background executor used to sleep between rate-limited retries.
    pub executor: Arc<Background>,
}
 28
/// JSON request body for POST /v1/embeddings.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
 34
/// Top-level JSON response from the embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
 40
/// One embedding vector within the response payload.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    // Position of the corresponding input span; part of the API payload,
    // not read by this module.
    index: usize,
    object: String,
}
 47
/// Token accounting reported by the API; `total_tokens` is logged for tracing.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
 53
/// Anything that can turn a batch of text spans into embedding vectors.
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Returns one embedding per input span, in the same order.
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}
 58
/// Offline/test provider that returns constant vectors with no network I/O.
pub struct DummyEmbeddings {}
 60
 61#[async_trait]
 62impl EmbeddingProvider for DummyEmbeddings {
 63    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
 64        // 1024 is the OpenAI Embeddings size for ada models.
 65        // the model we will likely be starting with.
 66        let dummy_vec = vec![0.32 as f32; 1536];
 67        return Ok(vec![dummy_vec; spans.len()]);
 68    }
 69}
 70
// Maximum number of BPE tokens the OpenAI embeddings endpoint accepts per input.
const OPENAI_INPUT_LIMIT: usize = 8190;
 72
 73impl OpenAIEmbeddings {
 74    fn truncate(span: String) -> String {
 75        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
 76        if tokens.len() > OPENAI_INPUT_LIMIT {
 77            tokens.truncate(OPENAI_INPUT_LIMIT);
 78            let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
 79            if result.is_ok() {
 80                let transformed = result.unwrap();
 81                return transformed;
 82            }
 83        }
 84
 85        span
 86    }
 87
 88    async fn send_request(
 89        &self,
 90        api_key: &str,
 91        spans: Vec<&str>,
 92        request_timeout: u64,
 93    ) -> Result<Response<AsyncBody>> {
 94        let request = Request::post("https://api.openai.com/v1/embeddings")
 95            .redirect_policy(isahc::config::RedirectPolicy::Follow)
 96            .timeout(Duration::from_secs(request_timeout))
 97            .header("Content-Type", "application/json")
 98            .header("Authorization", format!("Bearer {}", api_key))
 99            .body(
100                serde_json::to_string(&OpenAIEmbeddingRequest {
101                    input: spans.clone(),
102                    model: "text-embedding-ada-002",
103                })
104                .unwrap()
105                .into(),
106            )?;
107
108        Ok(self.client.send(request).await?)
109    }
110}
111
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    // Embeds `spans` via the OpenAI API, retrying on request timeouts and
    // rate limits, and truncating the inputs once before giving up on any
    // other failure status.
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        // Fallback sleep (seconds) for each retry attempt, used when the
        // rate-limit reset header is missing or unparseable.
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let mut request_number = 0;
        // Per-request timeout in seconds; grows after each REQUEST_TIMEOUT.
        let mut request_timeout: u64 = 10;
        let mut truncated = false;
        let mut response: Response<AsyncBody>;
        // Own the spans so they can be truncated in place on retry.
        let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;
            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    // Give the server more time on the next attempt.
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // The API returns one embedding per input, in order.
                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| embedding.embedding)
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Prefer the server-advertised reset delay from the
                    // "x-ratelimit-reset-tokens" header; fall back to our
                    // backoff schedule (indexed by attempt number) otherwise.
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    // TODO: Move this to parsing step
                    // Only truncate if it hasnt been truncated before
                    if !truncated {
                        for span in spans.iter_mut() {
                            *span = Self::truncate(span.clone());
                        }
                        truncated = true;
                    } else {
                        // If failing once already truncated, log the error and break the loop
                        let mut body = String::new();
                        response.body_mut().read_to_string(&mut body).await?;
                        return Err(anyhow!(
                            "open ai bad request: {:?} {:?}",
                            &response.status(),
                            body
                        ));
                    }
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}
206}