1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::AsyncReadExt;
4use gpui::serde_json;
5use isahc::http::StatusCode;
6use isahc::prelude::Configurable;
7use isahc::{AsyncBody, Response};
8use lazy_static::lazy_static;
9use serde::{Deserialize, Serialize};
10use std::env;
11use std::sync::Arc;
12use std::time::Duration;
13use tiktoken_rs::{cl100k_base, CoreBPE};
14use util::http::{HttpClient, Request};
15
lazy_static! {
    // Read once, lazily, at first access; `None` when the OPENAI_API_KEY
    // environment variable is unset.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    // cl100k_base is the BPE vocabulary used by text-embedding-ada-002,
    // the model this file requests; built once and shared.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
20
/// Embedding provider backed by the OpenAI `/v1/embeddings` HTTP endpoint.
#[derive(Clone)]
pub struct OpenAIEmbeddings {
    // Shared HTTP client used to send embedding requests.
    pub client: Arc<dyn HttpClient>,
}
25
/// JSON request body sent to the OpenAI embeddings endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    // Model identifier, e.g. "text-embedding-ada-002".
    model: &'static str,
    // The batch of text spans to embed; one embedding is returned per span.
    input: Vec<&'a str>,
}
31
/// Top-level JSON response returned by the OpenAI embeddings endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    // One entry per input span.
    data: Vec<OpenAIEmbedding>,
    // Token accounting for the request.
    usage: OpenAIEmbeddingUsage,
}
37
/// A single embedding entry from the response's `data` array.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    // The embedding vector itself.
    embedding: Vec<f32>,
    // Position of the corresponding input span in the request batch.
    index: usize,
    // Object type tag reported by the API (e.g. "embedding").
    object: String,
}
44
/// Token usage reported by the API for a single embeddings request.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    // Tokens consumed by the input spans.
    prompt_tokens: usize,
    // Total tokens billed for the request.
    total_tokens: usize,
}
50
/// A source of text embeddings, implemented below by both the real OpenAI
/// client and a constant-vector dummy.
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    /// Embeds each span in `spans`, returning one vector per span, in order.
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}
55
56pub struct DummyEmbeddings {}
57
58#[async_trait]
59impl EmbeddingProvider for DummyEmbeddings {
60 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
61 // 1024 is the OpenAI Embeddings size for ada models.
62 // the model we will likely be starting with.
63 let dummy_vec = vec![0.32 as f32; 1536];
64 return Ok(vec![dummy_vec; spans.len()]);
65 }
66}
67
68impl OpenAIEmbeddings {
69 async fn truncate(span: String) -> String {
70 let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
71 if tokens.len() > 8190 {
72 tokens.truncate(8190);
73 let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
74 if result.is_ok() {
75 let transformed = result.unwrap();
76 // assert_ne!(transformed, span);
77 return transformed;
78 }
79 }
80
81 return span.to_string();
82 }
83
84 async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
85 let request = Request::post("https://api.openai.com/v1/embeddings")
86 .redirect_policy(isahc::config::RedirectPolicy::Follow)
87 .header("Content-Type", "application/json")
88 .header("Authorization", format!("Bearer {}", api_key))
89 .body(
90 serde_json::to_string(&OpenAIEmbeddingRequest {
91 input: spans.clone(),
92 model: "text-embedding-ada-002",
93 })
94 .unwrap()
95 .into(),
96 )?;
97
98 Ok(self.client.send(request).await?)
99 }
100}
101
102#[async_trait]
103impl EmbeddingProvider for OpenAIEmbeddings {
104 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
105 const BACKOFF_SECONDS: [usize; 3] = [65, 180, 360];
106 const MAX_RETRIES: usize = 3;
107
108 let api_key = OPENAI_API_KEY
109 .as_ref()
110 .ok_or_else(|| anyhow!("no api key"))?;
111
112 let mut request_number = 0;
113 let mut response: Response<AsyncBody>;
114 let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
115 while request_number < MAX_RETRIES {
116 response = self
117 .send_request(api_key, spans.iter().map(|x| &**x).collect())
118 .await?;
119 request_number += 1;
120
121 if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK {
122 return Err(anyhow!(
123 "openai max retries, error: {:?}",
124 &response.status()
125 ));
126 }
127
128 match response.status() {
129 StatusCode::TOO_MANY_REQUESTS => {
130 let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
131 std::thread::sleep(delay);
132 }
133 StatusCode::BAD_REQUEST => {
134 log::info!("BAD REQUEST: {:?}", &response.status());
135 // Don't worry about delaying bad request, as we can assume
136 // we haven't been rate limited yet.
137 for span in spans.iter_mut() {
138 *span = Self::truncate(span.to_string()).await;
139 }
140 }
141 StatusCode::OK => {
142 let mut body = String::new();
143 response.body_mut().read_to_string(&mut body).await?;
144 let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
145
146 log::info!(
147 "openai embedding completed. tokens: {:?}",
148 response.usage.total_tokens
149 );
150 return Ok(response
151 .data
152 .into_iter()
153 .map(|embedding| embedding.embedding)
154 .collect());
155 }
156 _ => {
157 return Err(anyhow!("openai embedding failed {}", response.status()));
158 }
159 }
160 }
161
162 Err(anyhow!("openai embedding failed"))
163 }
164}