embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::serde_json;
  5use isahc::prelude::Configurable;
  6use lazy_static::lazy_static;
  7use serde::{Deserialize, Serialize};
  8use std::env;
  9use std::sync::Arc;
 10use util::http::{HttpClient, Request};
 11
 12lazy_static! {
 13    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
 14}
 15
 16#[derive(Clone)]
 17pub struct OpenAIEmbeddings {
 18    pub client: Arc<dyn HttpClient>,
 19}
 20
 21#[derive(Serialize)]
 22struct OpenAIEmbeddingRequest<'a> {
 23    model: &'static str,
 24    input: Vec<&'a str>,
 25}
 26
 27#[derive(Deserialize)]
 28struct OpenAIEmbeddingResponse {
 29    data: Vec<OpenAIEmbedding>,
 30    usage: OpenAIEmbeddingUsage,
 31}
 32
 33#[derive(Debug, Deserialize)]
 34struct OpenAIEmbedding {
 35    embedding: Vec<f32>,
 36    index: usize,
 37    object: String,
 38}
 39
 40#[derive(Deserialize)]
 41struct OpenAIEmbeddingUsage {
 42    prompt_tokens: usize,
 43    total_tokens: usize,
 44}
 45
 46#[async_trait]
 47pub trait EmbeddingProvider: Sync + Send {
 48    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
 49}
 50
 51pub struct DummyEmbeddings {}
 52
 53#[async_trait]
 54impl EmbeddingProvider for DummyEmbeddings {
 55    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
 56        // 1024 is the OpenAI Embeddings size for ada models.
 57        // the model we will likely be starting with.
 58        let dummy_vec = vec![0.32 as f32; 1536];
 59        return Ok(vec![dummy_vec; spans.len()]);
 60    }
 61}
 62
 63#[async_trait]
 64impl EmbeddingProvider for OpenAIEmbeddings {
 65    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
 66        let api_key = OPENAI_API_KEY
 67            .as_ref()
 68            .ok_or_else(|| anyhow!("no api key"))?;
 69
 70        let request = Request::post("https://api.openai.com/v1/embeddings")
 71            .redirect_policy(isahc::config::RedirectPolicy::Follow)
 72            .header("Content-Type", "application/json")
 73            .header("Authorization", format!("Bearer {}", api_key))
 74            .body(
 75                serde_json::to_string(&OpenAIEmbeddingRequest {
 76                    input: spans,
 77                    model: "text-embedding-ada-002",
 78                })
 79                .unwrap()
 80                .into(),
 81            )?;
 82
 83        let mut response = self.client.send(request).await?;
 84        if !response.status().is_success() {
 85            return Err(anyhow!("openai embedding failed {}", response.status()));
 86        }
 87
 88        let mut body = String::new();
 89        response.body_mut().read_to_string(&mut body).await?;
 90        let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
 91
 92        log::info!(
 93            "openai embedding completed. tokens: {:?}",
 94            response.usage.total_tokens
 95        );
 96
 97        Ok(response
 98            .data
 99            .into_iter()
100            .map(|embedding| embedding.embedding)
101            .collect())
102    }
103}