use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::serde_json;
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
use std::time::Duration;
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};

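// The API key is read once from the environment; the cl100k_base tokenizer is
// used below to count and truncate input tokens before requests are sent.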
lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}

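/// An `EmbeddingProvider` backed by the OpenAI embeddings API.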
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
}

#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}

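/// A source of embedding vectors for batches of text spans.
///
/// A minimal usage sketch (the `provider` value is illustrative, not part of
/// this module):
///
/// ```ignore
/// let embeddings = provider.embed_batch(vec!["fn main() {}"]).await?;
/// assert_eq!(embeddings.len(), 1);
/// ```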
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}

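/// A stand-in provider that returns constant vectors, useful in tests where
/// no API key or network access is available.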
pub struct DummyEmbeddings {}

#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        // 1536 is the dimensionality of OpenAI's ada-002 embeddings,
        // the model we will likely be starting with.
        let dummy_vec = vec![0.32f32; 1536];
        Ok(vec![dummy_vec; spans.len()])
    }
}

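// text-embedding-ada-002 accepts at most 8191 input tokens, so truncate just
// below that limit.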
const OPENAI_INPUT_LIMIT: usize = 8190;

impl OpenAIEmbeddings {
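    /// Truncates `span` to at most `OPENAI_INPUT_LIMIT` tokens. If decoding
    /// the truncated tokens fails, the span is returned unmodified.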
    fn truncate(span: String) -> String {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
        if tokens.len() > OPENAI_INPUT_LIMIT {
            tokens.truncate(OPENAI_INPUT_LIMIT);
            if let Ok(truncated) = OPENAI_BPE_TOKENIZER.decode(tokens) {
                return truncated;
            }
        }

        span
    }

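    /// Issues a single request to the OpenAI embeddings endpoint. Retries and
    /// backoff are handled by `embed_batch`.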
    async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .timeout(Duration::from_secs(4))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans,
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        Ok(self.client.send(request).await?)
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
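    // Retries with a fixed backoff schedule on 429 responses, and truncates
    // the inputs once on a 400 response before giving up.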
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let mut request_number = 0;
        let mut truncated = false;
        let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
        while request_number < MAX_RETRIES {
            let mut response = self
                .send_request(api_key, spans.iter().map(|x| &**x).collect())
                .await?;
            request_number += 1;

            if request_number == MAX_RETRIES && response.status() != StatusCode::OK {
                return Err(anyhow!(
                    "openai max retries, error: {:?}",
                    &response.status()
                ));
            }

            match response.status() {
                StatusCode::TOO_MANY_REQUESTS => {
                    let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                    log::trace!(
                        "openai rate limiting, delaying request by {:?} seconds",
                        delay.as_secs()
                    );
                    self.executor.timer(delay).await;
                }
                StatusCode::BAD_REQUEST => {
                    // Only truncate if the spans haven't been truncated before.
                    if !truncated {
                        for span in spans.iter_mut() {
                            *span = Self::truncate(span.clone());
                        }
                        truncated = true;
                    } else {
                        // If the request fails again after truncating, log the
                        // error and break out of the loop.
                        let mut body = String::new();
                        response.body_mut().read_to_string(&mut body).await?;
                        log::trace!("openai bad request: {:?} {:?}", &response.status(), body);
                        break;
                    }
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );
                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| embedding.embedding)
                        .collect());
                }
                _ => {
                    return Err(anyhow!("openai embedding failed {}", response.status()));
                }
            }
        }

        Err(anyhow!(
            "openai embedding failed after {} attempts",
            request_number
        ))
    }
}
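
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch of the dummy provider's contract: one 1536-dimension
    // vector per input span. It assumes the `futures` dependency has its
    // default `executor` feature enabled; exercising `OpenAIEmbeddings`
    // itself would need a real client, executor, and OPENAI_API_KEY.
    #[test]
    fn dummy_embeddings_have_ada_002_dimensions() {
        let provider = DummyEmbeddings {};
        let embeddings =
            futures::executor::block_on(provider.embed_batch(vec!["hello", "world"])).unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 1536);
    }
}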