use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::serde_json;
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parking_lot::Mutex;
use parse_duration::parse;
use postage::watch;
use serde::{Deserialize, Serialize};
use std::env;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};

use crate::auth::{CredentialProvider, ProviderCredential};
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::auth::OpenAICredentialProvider;
use crate::providers::open_ai::OpenAILanguageModel;
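
// Tokenizer for OpenAI's `text-embedding-ada-002` model, which uses the
// `cl100k_base` encoding.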
lazy_static! {
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
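
/// An embedding provider backed by OpenAI's `/v1/embeddings` endpoint.
///
/// The expected rate-limit expiration is shared through a `postage::watch`
/// channel so that callers can observe when the current limit should lift.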
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
    model: OpenAILanguageModel,
    credential_provider: OpenAICredentialProvider,
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
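
// Wire types for the embeddings endpoint. A successful response body looks
// roughly like:
//
//     {
//         "data": [{ "embedding": [0.1, ...], "index": 0, "object": "embedding" }],
//         "usage": { "prompt_tokens": 8, "total_tokens": 8 }
//     }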
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}

impl OpenAIEmbeddingProvider {
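    /// Creates a provider that embeds text with `text-embedding-ada-002`.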
    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));

        let model = OpenAILanguageModel::load("text-embedding-ada-002");

        OpenAIEmbeddingProvider {
            model,
            credential_provider: OpenAICredentialProvider {},
            client,
            executor,
            rate_limit_count_rx,
            rate_limit_count_tx,
        }
    }
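
    /// Clears the shared rate-limit reset time once it has passed.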
    fn resolve_rate_limit(&self) {
        let reset_time = *self.rate_limit_count_tx.lock().borrow();

        if let Some(reset_time) = reset_time {
            if Instant::now() >= reset_time {
                *self.rate_limit_count_tx.lock().borrow_mut() = None
            }
        }

        log::trace!(
            "resolving reset time: {:?}",
            *self.rate_limit_count_tx.lock().borrow()
        );
    }
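
    /// Records a rate-limit reset time, preferring the soonest instant
    /// reported so far.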
    fn update_reset_time(&self, reset_time: Instant) {
        let original_time = *self.rate_limit_count_tx.lock().borrow();

        let updated_time = if let Some(original_time) = original_time {
            if reset_time < original_time {
                Some(reset_time)
            } else {
                Some(original_time)
            }
        } else {
            Some(reset_time)
        };

        log::trace!("updating rate limit time: {:?}", updated_time);

        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
    }
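
    /// Builds and sends a single embeddings request; retries are handled by
    /// the caller in `embed_batch`.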
    async fn send_request(
        &self,
        api_key: &str,
        spans: Vec<&str>,
        request_timeout: u64,
    ) -> Result<Response<AsyncBody>> {
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .timeout(Duration::from_secs(request_timeout))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans,
                    model: "text-embedding-ada-002",
                })?
                .into(),
            )?;

        Ok(self.client.send(request).await?)
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }

    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
        let credential_provider: Box<dyn CredentialProvider> =
            Box::new(self.credential_provider.clone());
        credential_provider
    }
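
    // Token budget for a single call to `embed_batch`; presumably chosen to
    // stay well under the API's own request-size limits.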
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }
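
    /// Embeds all `spans` in a single request, retrying on request timeouts
    /// (with a longer timeout) and on 429 responses (honoring the
    /// `x-ratelimit-reset-tokens` header when present).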
    async fn embed_batch(
        &self,
        spans: Vec<String>,
        credential: ProviderCredential,
    ) -> Result<Vec<Embedding>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = match credential {
            ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key),
            _ => Err(anyhow!("no api key provided")),
        }?;

        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 15;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    &api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;

            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If a previously rate-limited request now succeeds,
                    // clear the shared rate-limit state.
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Prefer the reset interval reported by the server; fall
                    // back to the fixed backoff schedule.
                    let delay_duration = {
                        let delay =
                            Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // Share the expected reset time so other callers can wait
                    // for the rate limit to lift.
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "OpenAI embedding request failed: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries exceeded"))
    }
}
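
// A minimal usage sketch (hypothetical names; assumes an `HttpClient`
// implementation and a gpui background executor are available in the
// caller's context):
//
//     let provider = OpenAIEmbeddingProvider::new(http_client, executor);
//     let credential = ProviderCredential::Credentials { api_key };
//     let embeddings = provider
//         .embed_batch(vec!["some text to embed".into()], credential)
//         .await?;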