1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::AsyncReadExt;
4use gpui::executor::Background;
5use gpui::serde_json;
6use isahc::http::StatusCode;
7use isahc::prelude::Configurable;
8use isahc::{AsyncBody, Response};
9use lazy_static::lazy_static;
10use parse_duration::parse;
11use serde::{Deserialize, Serialize};
12use std::env;
13use std::sync::Arc;
14use std::time::Duration;
15use tiktoken_rs::{cl100k_base, CoreBPE};
16use util::http::{HttpClient, Request};
17
lazy_static! {
    // Read once at process start; `None` means no OpenAI credentials were provided.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // Shared cl100k_base BPE tokenizer, built once and reused by all providers.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
22
/// Embedding provider backed by the OpenAI embeddings HTTP API
/// (`text-embedding-ada-002`).
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    // HTTP client used to issue requests to api.openai.com.
    pub client: Arc<dyn HttpClient>,
    // Background executor used to sleep between rate-limited retries.
    pub executor: Arc<Background>,
}
28
/// JSON request body for `POST /v1/embeddings`.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    // Model identifier, e.g. "text-embedding-ada-002".
    model: &'static str,
    // Batch of input texts to embed in a single request.
    input: Vec<&'a str>,
}
34
/// Top-level JSON response from the embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    // One entry per input span, in request order.
    data: Vec<OpenAIEmbedding>,
    // Token accounting reported by the API; used for trace logging.
    usage: OpenAIEmbeddingUsage,
}
40
/// A single embedding entry in the response payload.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    // The embedding vector itself (1536 floats for ada-002).
    embedding: Vec<f32>,
    // Position of the corresponding input in the request batch.
    // Deserialized for completeness; not read by this module.
    index: usize,
    // Object type tag from the API (always "embedding"); not read here.
    object: String,
}
47
/// Token usage reported by the embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    // Tokens consumed by the inputs. Deserialized but not read here.
    prompt_tokens: usize,
    // Total tokens billed for the request; logged on success.
    total_tokens: usize,
}
53
/// Abstraction over embedding backends, allowing the real OpenAI provider to
/// be swapped for [`DummyEmbeddings`] in tests and offline runs.
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Embeds each span, returning one vector per input in the same order.
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
    /// Number of tokens `span` occupies in this provider's tokenizer.
    fn count_tokens(&self, span: &str) -> usize;
    /// Whether `span` exceeds the provider's input limit and must be shortened.
    fn should_truncate(&self, span: &str) -> bool;
}
61
62pub struct DummyEmbeddings {}
63
64#[async_trait]
65impl EmbeddingProvider for DummyEmbeddings {
66 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
67 // 1024 is the OpenAI Embeddings size for ada models.
68 // the model we will likely be starting with.
69 let dummy_vec = vec![0.32 as f32; 1536];
70 return Ok(vec![dummy_vec; spans.len()]);
71 }
72
73 fn count_tokens(&self, span: &str) -> usize {
74 // For Dummy Providers, we are going to use OpenAI tokenization for ease
75 let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
76 tokens.len()
77 }
78
79 fn should_truncate(&self, span: &str) -> bool {
80 self.count_tokens(span) > OPENAI_INPUT_LIMIT
81
82 // let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
83 // let Ok(output) = {
84 // if tokens.len() > OPENAI_INPUT_LIMIT {
85 // tokens.truncate(OPENAI_INPUT_LIMIT);
86 // OPENAI_BPE_TOKENIZER.decode(tokens)
87 // } else {
88 // Ok(span)
89 // }
90 // };
91 }
92}
93
94const OPENAI_INPUT_LIMIT: usize = 8190;
95
96impl OpenAIEmbeddings {
97 fn truncate(span: String) -> String {
98 let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
99 if tokens.len() > OPENAI_INPUT_LIMIT {
100 tokens.truncate(OPENAI_INPUT_LIMIT);
101 let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
102 if result.is_ok() {
103 let transformed = result.unwrap();
104 return transformed;
105 }
106 }
107
108 span
109 }
110
111 async fn send_request(
112 &self,
113 api_key: &str,
114 spans: Vec<&str>,
115 request_timeout: u64,
116 ) -> Result<Response<AsyncBody>> {
117 let request = Request::post("https://api.openai.com/v1/embeddings")
118 .redirect_policy(isahc::config::RedirectPolicy::Follow)
119 .timeout(Duration::from_secs(request_timeout))
120 .header("Content-Type", "application/json")
121 .header("Authorization", format!("Bearer {}", api_key))
122 .body(
123 serde_json::to_string(&OpenAIEmbeddingRequest {
124 input: spans.clone(),
125 model: "text-embedding-ada-002",
126 })
127 .unwrap()
128 .into(),
129 )?;
130
131 Ok(self.client.send(request).await?)
132 }
133}
134
135#[async_trait]
136impl EmbeddingProvider for OpenAIEmbeddings {
137 fn count_tokens(&self, span: &str) -> usize {
138 // For Dummy Providers, we are going to use OpenAI tokenization for ease
139 let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
140 tokens.len()
141 }
142
143 fn should_truncate(&self, span: &str) -> bool {
144 self.count_tokens(span) > OPENAI_INPUT_LIMIT
145 }
146
147 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
148 const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
149 const MAX_RETRIES: usize = 4;
150
151 let api_key = OPENAI_API_KEY
152 .as_ref()
153 .ok_or_else(|| anyhow!("no api key"))?;
154
155 let mut request_number = 0;
156 let mut request_timeout: u64 = 10;
157 let mut truncated = false;
158 let mut response: Response<AsyncBody>;
159 let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
160 while request_number < MAX_RETRIES {
161 response = self
162 .send_request(
163 api_key,
164 spans.iter().map(|x| &**x).collect(),
165 request_timeout,
166 )
167 .await?;
168 request_number += 1;
169
170 match response.status() {
171 StatusCode::REQUEST_TIMEOUT => {
172 request_timeout += 5;
173 }
174 StatusCode::OK => {
175 let mut body = String::new();
176 response.body_mut().read_to_string(&mut body).await?;
177 let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
178
179 log::trace!(
180 "openai embedding completed. tokens: {:?}",
181 response.usage.total_tokens
182 );
183
184 return Ok(response
185 .data
186 .into_iter()
187 .map(|embedding| embedding.embedding)
188 .collect());
189 }
190 StatusCode::TOO_MANY_REQUESTS => {
191 let mut body = String::new();
192 response.body_mut().read_to_string(&mut body).await?;
193
194 let delay_duration = {
195 let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
196 if let Some(time_to_reset) =
197 response.headers().get("x-ratelimit-reset-tokens")
198 {
199 if let Ok(time_str) = time_to_reset.to_str() {
200 parse(time_str).unwrap_or(delay)
201 } else {
202 delay
203 }
204 } else {
205 delay
206 }
207 };
208
209 log::trace!(
210 "openai rate limiting: waiting {:?} until lifted",
211 &delay_duration
212 );
213
214 self.executor.timer(delay_duration).await;
215 }
216 _ => {
217 // TODO: Move this to parsing step
218 // Only truncate if it hasnt been truncated before
219 if !truncated {
220 for span in spans.iter_mut() {
221 *span = Self::truncate(span.clone());
222 }
223 truncated = true;
224 } else {
225 // If failing once already truncated, log the error and break the loop
226 let mut body = String::new();
227 response.body_mut().read_to_string(&mut body).await?;
228 return Err(anyhow!(
229 "open ai bad request: {:?} {:?}",
230 &response.status(),
231 body
232 ));
233 }
234 }
235 }
236 }
237 Err(anyhow!("openai max retries"))
238 }
239}