embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::executor::Background;
  5use gpui::serde_json;
  6use isahc::http::StatusCode;
  7use isahc::prelude::Configurable;
  8use isahc::{AsyncBody, Response};
  9use lazy_static::lazy_static;
 10use serde::{Deserialize, Serialize};
 11use std::env;
 12use std::sync::Arc;
 13use std::time::Duration;
 14use tiktoken_rs::{cl100k_base, CoreBPE};
 15use util::http::{HttpClient, Request};
 16
lazy_static! {
    /// API key taken from the `OPENAI_API_KEY` environment variable; `None` if it
    /// is unset or not valid Unicode. Initialized on first access, so later
    /// changes to the environment are not observed.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    /// `cl100k_base` tokenizer used to count and truncate span tokens against
    /// `OPENAI_INPUT_LIMIT`. Panics on first access if it cannot be constructed.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
 21
/// Embedding provider backed by the OpenAI `/v1/embeddings` HTTP API.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    /// HTTP client used to send embedding requests.
    pub client: Arc<dyn HttpClient>,
    /// Executor whose timers are used to wait between rate-limited retries.
    pub executor: Arc<Background>,
}
 27
/// JSON body sent to the embeddings endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    /// Name of the embedding model to use.
    model: &'static str,
    /// Texts to embed, borrowed from the caller.
    input: Vec<&'a str>,
}
 33
/// JSON body returned by the embeddings endpoint on success.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    /// One entry per embedded input.
    data: Vec<OpenAIEmbedding>,
    /// Token accounting for the request.
    usage: OpenAIEmbeddingUsage,
}
 39
/// A single embedding entry in an [`OpenAIEmbeddingResponse`].
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    /// The embedding vector.
    embedding: Vec<f32>,
    // Presumably the position of the matching input in the request's `input`
    // list — verify against the OpenAI API reference.
    index: usize,
    // Object type tag reported by the API; not otherwise used here.
    object: String,
}
 46
/// Token usage reported by the embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    /// Tokens in the prompt (not read in this file).
    prompt_tokens: usize,
    /// Total tokens billed for the request; logged after a successful call.
    total_tokens: usize,
}
 52
/// Source of embedding vectors for batches of text spans.
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Embeds every span in `spans`.
    ///
    /// NOTE(review): callers presumably expect exactly one vector per span, in
    /// the same order as `spans` — confirm against call sites.
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}
 57
/// Embedding provider that returns a fixed constant vector for every span,
/// without any network access.
pub struct DummyEmbeddings {}
 59
 60#[async_trait]
 61impl EmbeddingProvider for DummyEmbeddings {
 62    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
 63        // 1024 is the OpenAI Embeddings size for ada models.
 64        // the model we will likely be starting with.
 65        let dummy_vec = vec![0.32 as f32; 1536];
 66        return Ok(vec![dummy_vec; spans.len()]);
 67    }
 68}
 69
/// Maximum number of tokens (as counted by `OPENAI_BPE_TOKENIZER`) a single span
/// is truncated to before being resent after a `400 Bad Request`.
// NOTE(review): presumably chosen just below the model's per-input token limit —
// confirm against the OpenAI embeddings documentation.
const OPENAI_INPUT_LIMIT: usize = 8190;
 71
 72impl OpenAIEmbeddings {
 73    fn truncate(span: String) -> String {
 74        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
 75        if tokens.len() > OPENAI_INPUT_LIMIT {
 76            tokens.truncate(OPENAI_INPUT_LIMIT);
 77            let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
 78            if result.is_ok() {
 79                let transformed = result.unwrap();
 80                return transformed;
 81            }
 82        }
 83
 84        span
 85    }
 86
 87    async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
 88        let request = Request::post("https://api.openai.com/v1/embeddings")
 89            .redirect_policy(isahc::config::RedirectPolicy::Follow)
 90            .timeout(Duration::from_secs(4))
 91            .header("Content-Type", "application/json")
 92            .header("Authorization", format!("Bearer {}", api_key))
 93            .body(
 94                serde_json::to_string(&OpenAIEmbeddingRequest {
 95                    input: spans.clone(),
 96                    model: "text-embedding-ada-002",
 97                })
 98                .unwrap()
 99                .into(),
100            )?;
101
102        Ok(self.client.send(request).await?)
103    }
104}
105
106#[async_trait]
107impl EmbeddingProvider for OpenAIEmbeddings {
108    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
109        const BACKOFF_SECONDS: [usize; 3] = [45, 75, 125];
110        const MAX_RETRIES: usize = 3;
111
112        let api_key = OPENAI_API_KEY
113            .as_ref()
114            .ok_or_else(|| anyhow!("no api key"))?;
115
116        let mut request_number = 0;
117        let mut truncated = false;
118        let mut response: Response<AsyncBody>;
119        let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
120        while request_number < MAX_RETRIES {
121            response = self
122                .send_request(api_key, spans.iter().map(|x| &**x).collect())
123                .await?;
124            request_number += 1;
125
126            if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK {
127                return Err(anyhow!(
128                    "openai max retries, error: {:?}",
129                    &response.status()
130                ));
131            }
132
133            match response.status() {
134                StatusCode::TOO_MANY_REQUESTS => {
135                    let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
136                    log::trace!(
137                        "open ai rate limiting, delaying request by {:?} seconds",
138                        delay.as_secs()
139                    );
140                    self.executor.timer(delay).await;
141                }
142                StatusCode::BAD_REQUEST => {
143                    // Only truncate if it hasnt been truncated before
144                    if !truncated {
145                        for span in spans.iter_mut() {
146                            *span = Self::truncate(span.clone());
147                        }
148                        truncated = true;
149                    } else {
150                        // If failing once already truncated, log the error and break the loop
151                        let mut body = String::new();
152                        response.body_mut().read_to_string(&mut body).await?;
153                        log::trace!("open ai bad request: {:?} {:?}", &response.status(), body);
154                        break;
155                    }
156                }
157                StatusCode::OK => {
158                    let mut body = String::new();
159                    response.body_mut().read_to_string(&mut body).await?;
160                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
161
162                    log::trace!(
163                        "openai embedding completed. tokens: {:?}",
164                        response.usage.total_tokens
165                    );
166                    return Ok(response
167                        .data
168                        .into_iter()
169                        .map(|embedding| embedding.embedding)
170                        .collect());
171                }
172                _ => {
173                    return Err(anyhow!("openai embedding failed {}", response.status()));
174                }
175            }
176        }
177
178        Err(anyhow!("openai embedding failed"))
179    }
180}