embedding.rs

  1use anyhow::{anyhow, Result};
  2use async_trait::async_trait;
  3use futures::AsyncReadExt;
  4use gpui::serde_json;
  5use isahc::prelude::Configurable;
  6use lazy_static::lazy_static;
  7use serde::{Deserialize, Serialize};
  8use std::env;
  9use std::sync::Arc;
 10use util::http::{HttpClient, Request};
 11
 12lazy_static! {
 13    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
 14}
 15
 16pub struct OpenAIEmbeddings {
 17    pub client: Arc<dyn HttpClient>,
 18}
 19
 20#[derive(Serialize)]
 21struct OpenAIEmbeddingRequest<'a> {
 22    model: &'static str,
 23    input: Vec<&'a str>,
 24}
 25
 26#[derive(Deserialize)]
 27struct OpenAIEmbeddingResponse {
 28    data: Vec<OpenAIEmbedding>,
 29    usage: OpenAIEmbeddingUsage,
 30}
 31
 32#[derive(Debug, Deserialize)]
 33struct OpenAIEmbedding {
 34    embedding: Vec<f32>,
 35    index: usize,
 36    object: String,
 37}
 38
 39#[derive(Deserialize)]
 40struct OpenAIEmbeddingUsage {
 41    prompt_tokens: usize,
 42    total_tokens: usize,
 43}
 44
 45#[async_trait]
 46pub trait EmbeddingProvider: Sync {
 47    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
 48}
 49
 50#[async_trait]
 51impl EmbeddingProvider for OpenAIEmbeddings {
 52    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
 53        let api_key = OPENAI_API_KEY
 54            .as_ref()
 55            .ok_or_else(|| anyhow!("no api key"))?;
 56
 57        let request = Request::post("https://api.openai.com/v1/embeddings")
 58            .redirect_policy(isahc::config::RedirectPolicy::Follow)
 59            .header("Content-Type", "application/json")
 60            .header("Authorization", format!("Bearer {}", api_key))
 61            .body(
 62                serde_json::to_string(&OpenAIEmbeddingRequest {
 63                    input: spans,
 64                    model: "text-embedding-ada-002",
 65                })
 66                .unwrap()
 67                .into(),
 68            )?;
 69
 70        let mut response = self.client.send(request).await?;
 71        if !response.status().is_success() {
 72            return Err(anyhow!("openai embedding failed {}", response.status()));
 73        }
 74
 75        let mut body = String::new();
 76        response.body_mut().read_to_string(&mut body).await?;
 77        let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
 78
 79        log::info!(
 80            "openai embedding completed. tokens: {:?}",
 81            response.usage.total_tokens
 82        );
 83
 84        // do we need to re-order these based on the `index` field?
 85        eprintln!(
 86            "indices: {:?}",
 87            response
 88                .data
 89                .iter()
 90                .map(|embedding| embedding.index)
 91                .collect::<Vec<_>>()
 92        );
 93
 94        Ok(response
 95            .data
 96            .into_iter()
 97            .map(|embedding| embedding.embedding)
 98            .collect())
 99    }
100}