embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::serde_json;
  5use isahc::prelude::Configurable;
  6use lazy_static::lazy_static;
  7use serde::{Deserialize, Serialize};
  8use std::sync::Arc;
  9use std::{env, time::Instant};
 10use util::http::{HttpClient, Request};
 11
 12lazy_static! {
 13    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
 14}
 15
 16#[derive(Clone)]
 17pub struct OpenAIEmbeddings {
 18    pub client: Arc<dyn HttpClient>,
 19}
 20
 21#[derive(Serialize)]
 22struct OpenAIEmbeddingRequest<'a> {
 23    model: &'static str,
 24    input: Vec<&'a str>,
 25}
 26
 27#[derive(Deserialize)]
 28struct OpenAIEmbeddingResponse {
 29    data: Vec<OpenAIEmbedding>,
 30    usage: OpenAIEmbeddingUsage,
 31}
 32
 33#[derive(Debug, Deserialize)]
 34struct OpenAIEmbedding {
 35    embedding: Vec<f32>,
 36    index: usize,
 37    object: String,
 38}
 39
 40#[derive(Deserialize)]
 41struct OpenAIEmbeddingUsage {
 42    prompt_tokens: usize,
 43    total_tokens: usize,
 44}
 45
 46#[async_trait]
 47pub trait EmbeddingProvider: Sync + Send {
 48    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
 49}
 50
 51pub struct DummyEmbeddings {}
 52
 53#[async_trait]
 54impl EmbeddingProvider for DummyEmbeddings {
 55    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
 56        // 1024 is the OpenAI Embeddings size for ada models.
 57        // the model we will likely be starting with.
 58        let dummy_vec = vec![0.32 as f32; 1536];
 59        return Ok(vec![dummy_vec; spans.len()]);
 60    }
 61}
 62
 63// impl OpenAIEmbeddings {
 64//     async fn truncate(span: &str) -> String {
 65//         let bpe = cl100k_base().unwrap();
 66//         let mut tokens = bpe.encode_with_special_tokens(span);
 67//         if tokens.len() > 8192 {
 68//             tokens.truncate(8192);
 69//             let result = bpe.decode(tokens);
 70//             if result.is_ok() {
 71//                 return result.unwrap();
 72//             }
 73//         }
 74
 75//         return span.to_string();
 76//     }
 77// }
 78
 79#[async_trait]
 80impl EmbeddingProvider for OpenAIEmbeddings {
 81    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
 82        // Truncate spans to 8192 if needed
 83        // let t0 = Instant::now();
 84        // let mut truncated_spans = vec![];
 85        // for span in spans {
 86        //     truncated_spans.push(Self::truncate(span));
 87        // }
 88        // let spans = futures::future::join_all(truncated_spans).await;
 89        // log::info!("Truncated Spans in {:?}", t0.elapsed().as_secs());
 90
 91        let api_key = OPENAI_API_KEY
 92            .as_ref()
 93            .ok_or_else(|| anyhow!("no api key"))?;
 94
 95        let request = Request::post("https://api.openai.com/v1/embeddings")
 96            .redirect_policy(isahc::config::RedirectPolicy::Follow)
 97            .header("Content-Type", "application/json")
 98            .header("Authorization", format!("Bearer {}", api_key))
 99            .body(
100                serde_json::to_string(&OpenAIEmbeddingRequest {
101                    input: spans,
102                    model: "text-embedding-ada-002",
103                })
104                .unwrap()
105                .into(),
106            )?;
107
108        let mut response = self.client.send(request).await?;
109        if !response.status().is_success() {
110            return Err(anyhow!("openai embedding failed {}", response.status()));
111        }
112
113        let mut body = String::new();
114        response.body_mut().read_to_string(&mut body).await?;
115        let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
116
117        log::info!(
118            "openai embedding completed. tokens: {:?}",
119            response.usage.total_tokens
120        );
121
122        Ok(response
123            .data
124            .into_iter()
125            .map(|embedding| embedding.embedding)
126            .collect())
127    }
128}