// embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::serde_json;
  5use isahc::http::StatusCode;
  6use isahc::prelude::Configurable;
  7use isahc::{AsyncBody, Response};
  8use lazy_static::lazy_static;
  9use serde::{Deserialize, Serialize};
 10use std::env;
 11use std::sync::Arc;
 12use std::time::Duration;
 13use tiktoken_rs::{cl100k_base, CoreBPE};
 14use util::http::{HttpClient, Request};
 15
lazy_static! {
    // Read once at first use; `None` when the environment variable is unset.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // cl100k_base is the BPE tokenizer matching OpenAI's ada-002 models;
    // constructing it is infallible in practice, hence the unwrap.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
 20
/// Embedding provider backed by OpenAI's `POST /v1/embeddings` HTTP endpoint.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    /// Shared HTTP client used to issue the embedding requests.
    pub client: Arc<dyn HttpClient>,
}
 25
/// JSON request body for the OpenAI embeddings endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    // Model identifier, e.g. "text-embedding-ada-002".
    model: &'static str,
    // The batch of texts to embed.
    input: Vec<&'a str>,
}
 31
/// JSON response body from the OpenAI embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    // One embedding entry per input span.
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
 37
/// A single embedding entry within the API response's `data` array.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    // Position of this embedding within the request batch; deserialized from
    // the response but not read elsewhere in this file.
    index: usize,
    // Presumably the API's object-type tag; not read in this file.
    object: String,
}
 44
/// Token accounting reported by the API alongside the embeddings.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    // Deserialized from the response but not read in this file.
    prompt_tokens: usize,
    // Logged after a successful batch (see `embed_batch` below).
    total_tokens: usize,
}
 50
/// Anything that can turn a batch of text spans into embedding vectors.
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Embeds each span in `spans`, returning one vector per span, in order.
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}
 55
/// Stub provider returning fixed placeholder vectors instead of calling an API.
pub struct DummyEmbeddings {}
 57
 58#[async_trait]
 59impl EmbeddingProvider for DummyEmbeddings {
 60    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
 61        // 1024 is the OpenAI Embeddings size for ada models.
 62        // the model we will likely be starting with.
 63        let dummy_vec = vec![0.32 as f32; 1536];
 64        return Ok(vec![dummy_vec; spans.len()]);
 65    }
 66}
 67
 68impl OpenAIEmbeddings {
 69    async fn truncate(span: String) -> String {
 70        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
 71        if tokens.len() > 8190 {
 72            tokens.truncate(8190);
 73            let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
 74            if result.is_ok() {
 75                let transformed = result.unwrap();
 76                // assert_ne!(transformed, span);
 77                return transformed;
 78            }
 79        }
 80
 81        return span.to_string();
 82    }
 83
 84    async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
 85        let request = Request::post("https://api.openai.com/v1/embeddings")
 86            .redirect_policy(isahc::config::RedirectPolicy::Follow)
 87            .header("Content-Type", "application/json")
 88            .header("Authorization", format!("Bearer {}", api_key))
 89            .body(
 90                serde_json::to_string(&OpenAIEmbeddingRequest {
 91                    input: spans.clone(),
 92                    model: "text-embedding-ada-002",
 93                })
 94                .unwrap()
 95                .into(),
 96            )?;
 97
 98        Ok(self.client.send(request).await?)
 99    }
100}
101
102#[async_trait]
103impl EmbeddingProvider for OpenAIEmbeddings {
104    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
105        const BACKOFF_SECONDS: [usize; 3] = [65, 180, 360];
106        const MAX_RETRIES: usize = 3;
107
108        let api_key = OPENAI_API_KEY
109            .as_ref()
110            .ok_or_else(|| anyhow!("no api key"))?;
111
112        let mut request_number = 0;
113        let mut response: Response<AsyncBody>;
114        let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
115        while request_number < MAX_RETRIES {
116            response = self
117                .send_request(api_key, spans.iter().map(|x| &**x).collect())
118                .await?;
119            request_number += 1;
120
121            if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK {
122                return Err(anyhow!(
123                    "openai max retries, error: {:?}",
124                    &response.status()
125                ));
126            }
127
128            match response.status() {
129                StatusCode::TOO_MANY_REQUESTS => {
130                    let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
131                    std::thread::sleep(delay);
132                }
133                StatusCode::BAD_REQUEST => {
134                    log::info!("BAD REQUEST: {:?}", &response.status());
135                    // Don't worry about delaying bad request, as we can assume
136                    // we haven't been rate limited yet.
137                    for span in spans.iter_mut() {
138                        *span = Self::truncate(span.to_string()).await;
139                    }
140                }
141                StatusCode::OK => {
142                    let mut body = String::new();
143                    response.body_mut().read_to_string(&mut body).await?;
144                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
145
146                    log::info!(
147                        "openai embedding completed. tokens: {:?}",
148                        response.usage.total_tokens
149                    );
150                    return Ok(response
151                        .data
152                        .into_iter()
153                        .map(|embedding| embedding.embedding)
154                        .collect());
155                }
156                _ => {
157                    return Err(anyhow!("openai embedding failed {}", response.status()));
158                }
159            }
160        }
161
162        Err(anyhow!("openai embedding failed"))
163    }
164}