use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::serde_json;
use isahc::prelude::Configurable;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::{env, time::Instant};
use util::http::{HttpClient, Request};

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}

/// Embeds text spans by calling OpenAI's `/v1/embeddings` endpoint through the
/// application's HTTP client.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
}

// Wire types mirroring the JSON request and response of the OpenAI embeddings API.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}

/// A backend that turns a batch of text spans into one embedding vector per span.
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}

/// A stand-in provider that returns fixed placeholder vectors, useful when no
/// API key is available.
pub struct DummyEmbeddings {}

#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        // 1536 is the embedding size for OpenAI's text-embedding-ada-002,
        // the model we will likely be starting with.
        let dummy_vec = vec![0.32_f32; 1536];
        Ok(vec![dummy_vec; spans.len()])
    }
}

// impl OpenAIEmbeddings {
//     async fn truncate(span: &str) -> String {
//         let bpe = cl100k_base().unwrap();
//         let mut tokens = bpe.encode_with_special_tokens(span);
//         if tokens.len() > 8192 {
//             tokens.truncate(8192);
//             let result = bpe.decode(tokens);
//             if result.is_ok() {
//                 return result.unwrap();
//             }
//         }

//         return span.to_string();
//     }
// }

#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        // Truncate spans to 8192 tokens if needed.
        // let t0 = Instant::now();
        // let mut truncated_spans = vec![];
        // for span in spans {
        //     truncated_spans.push(Self::truncate(span));
        // }
        // let spans = futures::future::join_all(truncated_spans).await;
        // log::info!("Truncated Spans in {:?}", t0.elapsed().as_secs());

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans,
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        let mut response = self.client.send(request).await?;
        if !response.status().is_success() {
            return Err(anyhow!("openai embedding failed {}", response.status()));
        }

        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

        log::info!(
            "openai embedding completed. tokens: {:?}",
            response.usage.total_tokens
        );

        // Collect just the embedding vectors from the response payload.
        Ok(response
            .data
            .into_iter()
            .map(|embedding| embedding.embedding)
            .collect())
    }
}
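
// A minimal test sketch (an addition, not from the original file). It only
// exercises `DummyEmbeddings`, so it needs no network access or API key, and it
// assumes the `futures` dependency's default `executor` feature is available;
// swap in the project's own test executor if it is not.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dummy_embeddings_return_one_vector_per_span() {
        let provider = DummyEmbeddings {};
        let spans = vec!["fn main() {}", "struct Foo;"];
        let embeddings = futures::executor::block_on(provider.embed_batch(spans)).unwrap();

        // One embedding per input span, each with ada-002's 1536 dimensions.
        assert_eq!(embeddings.len(), 2);
        assert!(embeddings.iter().all(|e| e.len() == 1536));
    }
}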