use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::serde_json;
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parking_lot::Mutex;
use parse_duration::parse;
use postage::watch;
use serde::{Deserialize, Serialize};
use std::env;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};

use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAILanguageModel;

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}

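/// Embedding provider backed by the OpenAI `/v1/embeddings` API.
///
/// Rate-limit state is shared between clones through a `postage::watch`
/// channel, so any clone can observe when requests are being throttled.
///
/// Usage sketch (assumes an `http_client` and background `executor` supplied
/// by the surrounding application):
///
/// ```ignore
/// let provider = OpenAIEmbeddingProvider::new(http_client, executor);
/// let embeddings = provider.embed_batch(vec!["hello world".into()]).await?;
/// ```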
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
    model: OpenAILanguageModel,
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}

#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}

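// Maximum number of tokens accepted per input span; text-embedding-ada-002
// accepts inputs of up to 8191 tokens, so stay just under that.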
const OPENAI_INPUT_LIMIT: usize = 8190;

impl OpenAIEmbeddingProvider {
    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));

        let model = OpenAILanguageModel::load("text-embedding-ada-002");

        OpenAIEmbeddingProvider {
            model,
            client,
            executor,
            rate_limit_count_rx,
            rate_limit_count_tx,
        }
    }

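    /// Clears the stored rate-limit reset time once it has passed.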
    fn resolve_rate_limit(&self) {
        let reset_time = *self.rate_limit_count_tx.lock().borrow();

        if let Some(reset_time) = reset_time {
            if Instant::now() >= reset_time {
                *self.rate_limit_count_tx.lock().borrow_mut() = None
            }
        }

        log::trace!(
            "resolving reset time: {:?}",
            *self.rate_limit_count_tx.lock().borrow()
        );
    }

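    /// Records `reset_time`, keeping the earlier of the stored and new values.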
    fn update_reset_time(&self, reset_time: Instant) {
        let original_time = *self.rate_limit_count_tx.lock().borrow();

        let updated_time = if let Some(original_time) = original_time {
            if reset_time < original_time {
                Some(reset_time)
            } else {
                Some(original_time)
            }
        } else {
            Some(reset_time)
        };

        log::trace!("updating rate limit time: {:?}", updated_time);

        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
    }
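
    /// Sends a single embeddings request for `spans`, timing out after
    /// `request_timeout` seconds.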
    async fn send_request(
        &self,
        api_key: &str,
        spans: Vec<&str>,
        request_timeout: u64,
    ) -> Result<Response<AsyncBody>> {
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .timeout(Duration::from_secs(request_timeout))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans,
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        Ok(self.client.send(request).await?)
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }

    fn is_authenticated(&self) -> bool {
        OPENAI_API_KEY.as_ref().is_some()
    }

    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }
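
    /// Truncates `span` to at most `OPENAI_INPUT_LIMIT` tokens, returning the
    /// possibly shortened text along with its token count.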
    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            OPENAI_BPE_TOKENIZER
                .decode(tokens.clone())
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

        (output, tokens.len())
    }

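    /// Embeds `spans` in one API call, retrying on timeouts with a longer
    /// deadline and backing off on 429s until the rate limit lifts.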
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 15;
        while request_number < MAX_RETRIES {
            let mut response = self
                .send_request(
                    api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;

            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    // Give the next attempt a little more time to complete.
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If we successfully complete a request after having been
                    // rate-limited, resolve the rate limit.
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Prefer the reset interval reported by the API, falling
                    // back to a fixed backoff schedule.
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // Record when the rate limit is expected to lift, keeping
                    // the earliest known reset time.
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }

        Err(anyhow!("openai max retries"))
    }
}