use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::serde_json;
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parse_duration::parse;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
use std::time::Duration;
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // cl100k_base is the tokenizer used by OpenAI's text-embedding-ada-002.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}

#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
}

#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}

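/// An embedding backend: `embed_batch` turns a batch of text spans into one
/// embedding vector per span, `truncate` clamps a span to the provider's
/// per-input token budget and reports the token count of the result, and
/// `max_tokens_per_batch` is the budget callers should pack batches against.
///
/// A minimal usage sketch (hypothetical caller):
/// ```ignore
/// let provider = DummyEmbeddings {};
/// let (span, _token_count) = provider.truncate(&some_long_text);
/// let embeddings = provider.embed_batch(vec![span]).await?;
/// ```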
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>>;
    fn max_tokens_per_batch(&self) -> usize;
    fn truncate(&self, span: &str) -> (String, usize);
}

pub struct DummyEmbeddings {}

#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
        // 1536 is the embedding dimension of OpenAI's text-embedding-ada-002,
        // the model we will likely be starting with.
        let dummy_vec = vec![0.32f32; 1536];
        Ok(vec![dummy_vec; spans.len()])
    }

    fn max_tokens_per_batch(&self) -> usize {
        OPENAI_INPUT_LIMIT
    }

    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            OPENAI_BPE_TOKENIZER
                .decode(tokens.clone())
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

        // `tokens` has already been truncated here, so its length is the
        // token count of the returned span.
        (output, tokens.len())
    }
}

// OpenAI's text-embedding-ada-002 accepts inputs of at most 8191 tokens;
// 8190 keeps us just under that limit.
const OPENAI_INPUT_LIMIT: usize = 8190;

impl OpenAIEmbeddings {
    async fn send_request(
        &self,
        api_key: &str,
        spans: Vec<&str>,
        request_timeout: u64,
    ) -> Result<Response<AsyncBody>> {
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .timeout(Duration::from_secs(request_timeout))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans,
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        Ok(self.client.send(request).await?)
    }
}
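
// Construction sketch (hypothetical caller; `http_client` and
// `background_executor` stand in for values the application provides):
//
//     let provider = OpenAIEmbeddings {
//         client: http_client,
//         executor: background_executor,
//     };
//
// With OPENAI_API_KEY set in the environment, `provider.embed_batch(spans)`
// embeds the whole batch in one request, retrying on timeouts and rate
// limits as implemented below.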

#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    fn truncate(&self, span: &str) -> (String, usize) {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            OPENAI_BPE_TOKENIZER
                .decode(tokens.clone())
                .ok()
                .unwrap_or_else(|| span.to_string())
        } else {
            span.to_string()
        };

        // `tokens` has already been truncated here, so its length is the
        // token count of the returned span.
        (output, tokens.len())
    }

    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        // Retry on timeouts and rate limits, up to MAX_RETRIES attempts.
        let mut request_number = 0;
        let mut request_timeout: u64 = 10;
        while request_number < MAX_RETRIES {
            let mut response = self
                .send_request(
                    api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;
            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    // Give the next attempt a little longer to complete.
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| embedding.embedding)
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Back off according to the retry schedule, unless the
                    // x-ratelimit-reset-tokens header tells us exactly when
                    // the limit lifts.
                    let delay_duration = {
                        let delay =
                            Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "openai embedding request failed: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries exceeded"))
    }
}
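
// Minimal smoke tests: a sketch assuming the `futures` crate's default
// "executor" feature is available for block_on. They exercise the dummy
// provider and the shared truncation logic without touching the network.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dummy_embeddings_have_ada_dimensions() {
        let provider = DummyEmbeddings {};
        let spans = vec!["hello".to_string(), "world".to_string()];
        let embeddings = futures::executor::block_on(provider.embed_batch(spans)).unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 1536);
    }

    #[test]
    fn truncate_clamps_to_input_limit() {
        let provider = DummyEmbeddings {};
        // Each repetition is at least one token, so this span is far over
        // the limit before truncation.
        let long_span = "token ".repeat(OPENAI_INPUT_LIMIT * 2);
        let (truncated, token_count) = provider.truncate(&long_span);
        assert!(token_count <= OPENAI_INPUT_LIMIT);
        assert!(truncated.len() < long_span.len());
    }
}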