1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::AsyncReadExt;
4use gpui::executor::Background;
5use gpui::serde_json;
6use isahc::http::StatusCode;
7use isahc::prelude::Configurable;
8use isahc::{AsyncBody, Response};
9use lazy_static::lazy_static;
10use parse_duration::parse;
11use serde::{Deserialize, Serialize};
12use std::env;
13use std::sync::Arc;
14use std::time::Duration;
15use tiktoken_rs::{cl100k_base, CoreBPE};
16use util::http::{HttpClient, Request};
17
lazy_static! {
    // Read once from the environment; `None` when OPENAI_API_KEY is unset.
    // Checked per embed_batch call, which fails with "no api key" if absent.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // cl100k_base BPE tokenizer, used to count/truncate tokens before sending
    // spans to the embeddings endpoint.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
22
/// `EmbeddingProvider` backed by OpenAI's embeddings HTTP API.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    // HTTP client used to reach api.openai.com.
    pub client: Arc<dyn HttpClient>,
    // Background executor; used here to sleep between rate-limited retries.
    pub executor: Arc<Background>,
}
28
/// Request body for OpenAI's `POST /v1/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    // Model identifier, e.g. "text-embedding-ada-002".
    model: &'static str,
    // Texts to embed; borrowed from the caller's spans.
    input: Vec<&'a str>,
}
34
/// Top-level response body from the embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    // One entry per input span.
    data: Vec<OpenAIEmbedding>,
    // Token accounting for the request.
    usage: OpenAIEmbeddingUsage,
}
40
/// A single embedding entry within the API response's `data` array.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    // The embedding vector itself.
    embedding: Vec<f32>,
    // Position of this embedding in the request's input list; deserialized
    // but not currently consulted anywhere in this file.
    index: usize,
    // Object type tag reported by the API; deserialized but unused.
    object: String,
}
47
/// Token usage reported by the embeddings endpoint; `total_tokens` is logged
/// after each successful request.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
53
/// A source of text embeddings.
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Embeds each span in `spans`, returning one `f32` vector per input span.
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}
58
59pub struct DummyEmbeddings {}
60
61#[async_trait]
62impl EmbeddingProvider for DummyEmbeddings {
63 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
64 // 1024 is the OpenAI Embeddings size for ada models.
65 // the model we will likely be starting with.
66 let dummy_vec = vec![0.32 as f32; 1536];
67 return Ok(vec![dummy_vec; spans.len()]);
68 }
69}
70
71const OPENAI_INPUT_LIMIT: usize = 8190;
72
73impl OpenAIEmbeddings {
74 fn truncate(span: String) -> String {
75 let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
76 if tokens.len() > OPENAI_INPUT_LIMIT {
77 tokens.truncate(OPENAI_INPUT_LIMIT);
78 let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
79 if result.is_ok() {
80 let transformed = result.unwrap();
81 return transformed;
82 }
83 }
84
85 span
86 }
87
88 async fn send_request(
89 &self,
90 api_key: &str,
91 spans: Vec<&str>,
92 request_timeout: u64,
93 ) -> Result<Response<AsyncBody>> {
94 let request = Request::post("https://api.openai.com/v1/embeddings")
95 .redirect_policy(isahc::config::RedirectPolicy::Follow)
96 .timeout(Duration::from_secs(request_timeout))
97 .header("Content-Type", "application/json")
98 .header("Authorization", format!("Bearer {}", api_key))
99 .body(
100 serde_json::to_string(&OpenAIEmbeddingRequest {
101 input: spans.clone(),
102 model: "text-embedding-ada-002",
103 })
104 .unwrap()
105 .into(),
106 )?;
107
108 Ok(self.client.send(request).await?)
109 }
110}
111
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    /// Embeds `spans` with OpenAI's `text-embedding-ada-002` model, retrying
    /// up to `MAX_RETRIES` times:
    ///
    /// - request timeout: retried with a client timeout 5s longer each time,
    /// - 429 (rate limited): sleeps for the server's suggested reset interval
    ///   (the `x-ratelimit-reset-tokens` header) or a fixed backoff,
    /// - any other non-OK status: every span is token-truncated once and the
    ///   request retried; a second such failure is returned as an error.
    ///
    /// Errors when `OPENAI_API_KEY` is unset or retries are exhausted.
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        // Fixed backoff schedule, indexed by how many requests have been made.
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let mut request_number = 0;
        let mut request_timeout: u64 = 10;
        // Set once the spans have been truncated, so truncation only happens
        // a single time across all retries.
        let mut truncated = false;
        let mut response: Response<AsyncBody>;
        // Own the spans so the error path below can truncate them in place.
        let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;
            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    // Give the next attempt a little more time to complete.
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // NOTE(review): assumes the API returns embeddings in
                    // input order; the `index` field of each entry is not
                    // consulted here — confirm against the API contract.
                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| embedding.embedding)
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    // Drain the response body (contents unused).
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Prefer the server's reset hint when present and
                    // parseable; otherwise fall back to the fixed schedule.
                    // `request_number - 1` is safe: it was incremented above.
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    // TODO: Move this to parsing step
                    // Only truncate if it hasnt been truncated before
                    if !truncated {
                        for span in spans.iter_mut() {
                            *span = Self::truncate(span.clone());
                        }
                        truncated = true;
                    } else {
                        // If failing once already truncated, log the error and break the loop
                        let mut body = String::new();
                        response.body_mut().read_to_string(&mut body).await?;
                        return Err(anyhow!(
                            "open ai bad request: {:?} {:?}",
                            &response.status(),
                            body
                        ));
                    }
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}