1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::AsyncReadExt;
4use gpui::serde_json;
5use isahc::prelude::Configurable;
6use lazy_static::lazy_static;
7use serde::{Deserialize, Serialize};
8use std::env;
9use std::sync::Arc;
10use util::http::{HttpClient, Request};
11
12lazy_static! {
13 static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
14}
15
/// [`EmbeddingProvider`] implementation that calls OpenAI's `/v1/embeddings` HTTP endpoint.
pub struct OpenAIEmbeddings {
    /// HTTP client used to send the embedding requests.
    pub client: Arc<dyn HttpClient>,
}
19
/// JSON request body serialized and POSTed to the embeddings endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    /// Name of the embedding model to use.
    model: &'static str,
    /// Text spans to embed, borrowed from the caller.
    input: Vec<&'a str>,
}
25
/// JSON response body returned by the embeddings endpoint.
///
/// Fields not declared here are ignored during deserialization.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    /// Embedding entries returned by the API.
    data: Vec<OpenAIEmbedding>,
    /// Token counts reported for the request.
    usage: OpenAIEmbeddingUsage,
}
31
/// A single embedding entry from the response's `data` array.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    /// The embedding vector.
    embedding: Vec<f32>,
    /// NOTE(review): per the OpenAI API reference this is the position of the
    /// corresponding span in the request's `input` list; confirm before relying on it.
    index: usize,
    /// Object type tag reported by the API; not used for any computation in this file.
    object: String,
}
38
/// Token usage reported by the API for an embeddings request.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    /// Tokens counted for the input; not read anywhere in this file.
    prompt_tokens: usize,
    /// Total tokens counted for the request; logged after each successful call.
    total_tokens: usize,
}
44
/// Something that can turn batches of text into embedding vectors.
///
/// The `Sync` bound lets a provider be shared by reference across threads.
#[async_trait]
pub trait EmbeddingProvider: Sync {
    /// Embeds each span in `spans`.
    ///
    /// NOTE(review): callers presumably expect one vector per span, in the same
    /// order as `spans`; the trait signature itself does not enforce this.
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}
49
50#[async_trait]
51impl EmbeddingProvider for OpenAIEmbeddings {
52 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
53 let api_key = OPENAI_API_KEY
54 .as_ref()
55 .ok_or_else(|| anyhow!("no api key"))?;
56
57 let request = Request::post("https://api.openai.com/v1/embeddings")
58 .redirect_policy(isahc::config::RedirectPolicy::Follow)
59 .header("Content-Type", "application/json")
60 .header("Authorization", format!("Bearer {}", api_key))
61 .body(
62 serde_json::to_string(&OpenAIEmbeddingRequest {
63 input: spans,
64 model: "text-embedding-ada-002",
65 })
66 .unwrap()
67 .into(),
68 )?;
69
70 let mut response = self.client.send(request).await?;
71 if !response.status().is_success() {
72 return Err(anyhow!("openai embedding failed {}", response.status()));
73 }
74
75 let mut body = String::new();
76 response.body_mut().read_to_string(&mut body).await?;
77 let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
78
79 log::info!(
80 "openai embedding completed. tokens: {:?}",
81 response.usage.total_tokens
82 );
83
84 // do we need to re-order these based on the `index` field?
85 eprintln!(
86 "indices: {:?}",
87 response
88 .data
89 .iter()
90 .map(|embedding| embedding.index)
91 .collect::<Vec<_>>()
92 );
93
94 Ok(response
95 .data
96 .into_iter()
97 .map(|embedding| embedding.embedding)
98 .collect())
99 }
100}