use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::serde_json;
use isahc::prelude::Configurable;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
use util::http::{HttpClient, Request};

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}

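/// An embedding provider that requests embeddings from the OpenAI API over HTTP.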
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
}

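/// Request body sent to the OpenAI embeddings endpoint.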
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

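/// Response payload returned by the OpenAI embeddings endpoint.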
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

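/// A single embedding in the response, along with its position in the input batch.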
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

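/// Token usage reported by OpenAI for an embedding request.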
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}

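/// A source of text embeddings: implementations take a batch of spans and
/// return one embedding vector per span, in the same order.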
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}

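/// An embedding provider that returns a constant 1536-dimensional vector for
/// every span, without making any network requests.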
pub struct DummyEmbeddings {}

#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        // 1536 is the embedding dimension of OpenAI's text-embedding-ada-002,
        // the model we will likely be starting with.
        let dummy_vec = vec![0.32_f32; 1536];
        Ok(vec![dummy_vec; spans.len()])
    }
}

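// Sends the whole batch to OpenAI's `/v1/embeddings` endpoint in a single request.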
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans,
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        let mut response = self.client.send(request).await?;
        if !response.status().is_success() {
            return Err(anyhow!("openai embedding failed {}", response.status()));
        }

        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

        log::info!(
            "openai embedding completed. tokens: {:?}",
            response.usage.total_tokens
        );

        Ok(response
            .data
            .into_iter()
            .map(|embedding| embedding.embedding)
            .collect())
    }
}
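
// A minimal sanity test for the dummy provider; assumes the `futures` crate's
// default `executor` feature is enabled so `block_on` is available.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dummy_embeddings_return_one_vector_per_span() {
        let provider = DummyEmbeddings {};
        let embeddings =
            futures::executor::block_on(provider.embed_batch(vec!["first span", "second span"]))
                .unwrap();
        assert_eq!(embeddings.len(), 2);
        assert!(embeddings.iter().all(|embedding| embedding.len() == 1536));
    }
}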