Detailed changes
@@ -1,30 +1,9 @@
-use anyhow::{anyhow, Result};
+use anyhow::Result;
use async_trait::async_trait;
-use futures::AsyncReadExt;
-use gpui::executor::Background;
-use gpui::serde_json;
-use isahc::http::StatusCode;
-use isahc::prelude::Configurable;
-use isahc::{AsyncBody, Response};
-use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
-use parking_lot::Mutex;
-use parse_duration::parse;
-use postage::watch;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
-use serde::{Deserialize, Serialize};
-use std::env;
-use std::ops::Add;
-use std::sync::Arc;
-use std::time::{Duration, Instant};
-use tiktoken_rs::{cl100k_base, CoreBPE};
-use util::http::{HttpClient, Request};
-
-lazy_static! {
- static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
- static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
-}
+use std::time::Instant;
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
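The `impl Embedding` block that the next hunk trims is not reproduced in full here. For orientation, a minimal sketch of the dot-product similarity helper such a newtype typically carries (the method name and formula are assumptions, not part of this diff; `OrderedFloat` matches the import kept above):

```rust
use ordered_float::OrderedFloat;

pub struct Embedding(pub Vec<f32>);

impl Embedding {
    /// Dot-product similarity between two equal-length embeddings.
    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
        debug_assert_eq!(self.0.len(), other.0.len());
        OrderedFloat(self.0.iter().zip(&other.0).map(|(a, b)| a * b).sum())
    }
}
```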
@@ -85,39 +64,6 @@ impl Embedding {
}
}
-#[derive(Clone)]
-pub struct OpenAIEmbeddings {
- pub client: Arc<dyn HttpClient>,
- pub executor: Arc<Background>,
- rate_limit_count_rx: watch::Receiver<Option<Instant>>,
- rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
-}
-
-#[derive(Serialize)]
-struct OpenAIEmbeddingRequest<'a> {
- model: &'static str,
- input: Vec<&'a str>,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingResponse {
- data: Vec<OpenAIEmbedding>,
- usage: OpenAIEmbeddingUsage,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAIEmbedding {
- embedding: Vec<f32>,
- index: usize,
- object: String,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingUsage {
- prompt_tokens: usize,
- total_tokens: usize,
-}
-
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
fn is_authenticated(&self) -> bool;
@@ -127,235 +73,6 @@ pub trait EmbeddingProvider: Sync + Send {
fn rate_limit_expiration(&self) -> Option<Instant>;
}
-pub struct DummyEmbeddings {}
-
-#[async_trait]
-impl EmbeddingProvider for DummyEmbeddings {
- fn is_authenticated(&self) -> bool {
- true
- }
- fn rate_limit_expiration(&self) -> Option<Instant> {
- None
- }
- async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
- // 1024 is the OpenAI Embeddings size for ada models.
- // the model we will likely be starting with.
- let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
- return Ok(vec![dummy_vec; spans.len()]);
- }
-
- fn max_tokens_per_batch(&self) -> usize {
- OPENAI_INPUT_LIMIT
- }
-
- fn truncate(&self, span: &str) -> (String, usize) {
- let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
- let token_count = tokens.len();
- let output = if token_count > OPENAI_INPUT_LIMIT {
- tokens.truncate(OPENAI_INPUT_LIMIT);
- let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
- new_input.ok().unwrap_or_else(|| span.to_string())
- } else {
- span.to_string()
- };
-
- (output, tokens.len())
- }
-}
-
-const OPENAI_INPUT_LIMIT: usize = 8190;
-
-impl OpenAIEmbeddings {
- pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
- let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
- let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
-
- OpenAIEmbeddings {
- client,
- executor,
- rate_limit_count_rx,
- rate_limit_count_tx,
- }
- }
-
- fn resolve_rate_limit(&self) {
- let reset_time = *self.rate_limit_count_tx.lock().borrow();
-
- if let Some(reset_time) = reset_time {
- if Instant::now() >= reset_time {
- *self.rate_limit_count_tx.lock().borrow_mut() = None
- }
- }
-
- log::trace!(
- "resolving reset time: {:?}",
- *self.rate_limit_count_tx.lock().borrow()
- );
- }
-
- fn update_reset_time(&self, reset_time: Instant) {
- let original_time = *self.rate_limit_count_tx.lock().borrow();
-
- let updated_time = if let Some(original_time) = original_time {
- if reset_time < original_time {
- Some(reset_time)
- } else {
- Some(original_time)
- }
- } else {
- Some(reset_time)
- };
-
- log::trace!("updating rate limit time: {:?}", updated_time);
-
- *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
- }
- async fn send_request(
- &self,
- api_key: &str,
- spans: Vec<&str>,
- request_timeout: u64,
- ) -> Result<Response<AsyncBody>> {
- let request = Request::post("https://api.openai.com/v1/embeddings")
- .redirect_policy(isahc::config::RedirectPolicy::Follow)
- .timeout(Duration::from_secs(request_timeout))
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_key))
- .body(
- serde_json::to_string(&OpenAIEmbeddingRequest {
- input: spans.clone(),
- model: "text-embedding-ada-002",
- })
- .unwrap()
- .into(),
- )?;
-
- Ok(self.client.send(request).await?)
- }
-}
-
-#[async_trait]
-impl EmbeddingProvider for OpenAIEmbeddings {
- fn is_authenticated(&self) -> bool {
- OPENAI_API_KEY.as_ref().is_some()
- }
- fn max_tokens_per_batch(&self) -> usize {
- 50000
- }
-
- fn rate_limit_expiration(&self) -> Option<Instant> {
- *self.rate_limit_count_rx.borrow()
- }
- fn truncate(&self, span: &str) -> (String, usize) {
- let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
- let output = if tokens.len() > OPENAI_INPUT_LIMIT {
- tokens.truncate(OPENAI_INPUT_LIMIT);
- OPENAI_BPE_TOKENIZER
- .decode(tokens.clone())
- .ok()
- .unwrap_or_else(|| span.to_string())
- } else {
- span.to_string()
- };
-
- (output, tokens.len())
- }
-
- async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
- const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
- const MAX_RETRIES: usize = 4;
-
- let api_key = OPENAI_API_KEY
- .as_ref()
- .ok_or_else(|| anyhow!("no api key"))?;
-
- let mut request_number = 0;
- let mut rate_limiting = false;
- let mut request_timeout: u64 = 15;
- let mut response: Response<AsyncBody>;
- while request_number < MAX_RETRIES {
- response = self
- .send_request(
- api_key,
- spans.iter().map(|x| &**x).collect(),
- request_timeout,
- )
- .await?;
-
- request_number += 1;
-
- match response.status() {
- StatusCode::REQUEST_TIMEOUT => {
- request_timeout += 5;
- }
- StatusCode::OK => {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
-
- log::trace!(
- "openai embedding completed. tokens: {:?}",
- response.usage.total_tokens
- );
-
- // If we complete a request successfully that was previously rate_limited
- // resolve the rate limit
- if rate_limiting {
- self.resolve_rate_limit()
- }
-
- return Ok(response
- .data
- .into_iter()
- .map(|embedding| Embedding::from(embedding.embedding))
- .collect());
- }
- StatusCode::TOO_MANY_REQUESTS => {
- rate_limiting = true;
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- let delay_duration = {
- let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
- if let Some(time_to_reset) =
- response.headers().get("x-ratelimit-reset-tokens")
- {
- if let Ok(time_str) = time_to_reset.to_str() {
- parse(time_str).unwrap_or(delay)
- } else {
- delay
- }
- } else {
- delay
- }
- };
-
- // If we've previously rate limited, increment the duration but not the count
- let reset_time = Instant::now().add(delay_duration);
- self.update_reset_time(reset_time);
-
- log::trace!(
- "openai rate limiting: waiting {:?} until lifted",
- &delay_duration
- );
-
- self.executor.timer(delay_duration).await;
- }
- _ => {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- return Err(anyhow!(
- "open ai bad request: {:?} {:?}",
- &response.status(),
- body
- ));
- }
- }
- }
- Err(anyhow!("openai max retries"))
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
@@ -1,4 +1,10 @@
-use crate::completion::CompletionRequest;
+use std::time::Instant;
+
+use crate::{
+ completion::CompletionRequest,
+ embedding::{Embedding, EmbeddingProvider},
+};
+use async_trait::async_trait;
use serde::Serialize;
#[derive(Serialize)]
@@ -11,3 +17,32 @@ impl CompletionRequest for DummyCompletionRequest {
serde_json::to_string(self)
}
}
+
+pub struct DummyEmbeddingProvider {}
+
+#[async_trait]
+impl EmbeddingProvider for DummyEmbeddingProvider {
+ fn is_authenticated(&self) -> bool {
+ true
+ }
+ fn rate_limit_expiration(&self) -> Option<Instant> {
+ None
+ }
+ async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
+ // 1536 is the embedding size for OpenAI's ada models,
+ // the model we will likely be starting with.
+ let dummy_vec = Embedding::from(vec![0.32f32; 1536]);
+ Ok(vec![dummy_vec; spans.len()])
+ }
+
+ fn max_tokens_per_batch(&self) -> usize {
+ 8190
+ }
+
+ fn truncate(&self, span: &str) -> (String, usize) {
+ // `take` caps the span at 8190 chars without panicking on
+ // shorter inputs, which indexing with `[..8190]` would do.
+ let truncated: String = span.chars().take(8190).collect();
+ let token_count = truncated.chars().count();
+ (truncated, token_count)
+}
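With the dummy provider now living in `ai::providers::dummy`, a test can exercise the whole `EmbeddingProvider` surface without touching the network. A minimal smoke test, assuming the module paths shown in this diff:

```rust
use std::sync::Arc;

use ai::embedding::EmbeddingProvider;
use ai::providers::dummy::DummyEmbeddingProvider;

async fn dummy_provider_smoke() -> anyhow::Result<()> {
    let provider = Arc::new(DummyEmbeddingProvider {});
    // Truncation never exceeds the provider's batch limit.
    let (span, token_count) = provider.truncate("fn main() {}");
    assert!(token_count <= provider.max_tokens_per_batch());
    // One input span yields one (constant) embedding.
    let embeddings = provider.embed_batch(vec![span]).await?;
    assert_eq!(embeddings.len(), 1);
    Ok(())
}
```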
@@ -0,0 +1,252 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui::executor::Background;
+use gpui::serde_json;
+use isahc::http::StatusCode;
+use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
+use lazy_static::lazy_static;
+use parking_lot::Mutex;
+use parse_duration::parse;
+use postage::watch;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::ops::Add;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tiktoken_rs::{cl100k_base, CoreBPE};
+use util::http::{HttpClient, Request};
+
+use crate::embedding::{Embedding, EmbeddingProvider};
+
+lazy_static! {
+ static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
+ static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
+}
+
+#[derive(Clone)]
+pub struct OpenAIEmbeddingProvider {
+ pub client: Arc<dyn HttpClient>,
+ pub executor: Arc<Background>,
+ rate_limit_count_rx: watch::Receiver<Option<Instant>>,
+ rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
+}
+
+#[derive(Serialize)]
+struct OpenAIEmbeddingRequest<'a> {
+ model: &'static str,
+ input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingResponse {
+ data: Vec<OpenAIEmbedding>,
+ usage: OpenAIEmbeddingUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIEmbedding {
+ embedding: Vec<f32>,
+ index: usize,
+ object: String,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+ prompt_tokens: usize,
+ total_tokens: usize,
+}
+
+const OPENAI_INPUT_LIMIT: usize = 8190;
+
+impl OpenAIEmbeddingProvider {
+ pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+ let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
+ let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+ OpenAIEmbeddingProvider {
+ client,
+ executor,
+ rate_limit_count_rx,
+ rate_limit_count_tx,
+ }
+ }
+
+ fn resolve_rate_limit(&self) {
+ let reset_time = *self.rate_limit_count_tx.lock().borrow();
+
+ if let Some(reset_time) = reset_time {
+ if Instant::now() >= reset_time {
+ *self.rate_limit_count_tx.lock().borrow_mut() = None
+ }
+ }
+
+ log::trace!(
+ "resolving reset time: {:?}",
+ *self.rate_limit_count_tx.lock().borrow()
+ );
+ }
+
+ fn update_reset_time(&self, reset_time: Instant) {
+ let original_time = *self.rate_limit_count_tx.lock().borrow();
+
+ let updated_time = if let Some(original_time) = original_time {
+ if reset_time < original_time {
+ Some(reset_time)
+ } else {
+ Some(original_time)
+ }
+ } else {
+ Some(reset_time)
+ };
+
+ log::trace!("updating rate limit time: {:?}", updated_time);
+
+ *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
+ }
+ async fn send_request(
+ &self,
+ api_key: &str,
+ spans: Vec<&str>,
+ request_timeout: u64,
+ ) -> Result<Response<AsyncBody>> {
+ let request = Request::post("https://api.openai.com/v1/embeddings")
+ .redirect_policy(isahc::config::RedirectPolicy::Follow)
+ .timeout(Duration::from_secs(request_timeout))
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(
+ serde_json::to_string(&OpenAIEmbeddingRequest {
+ input: spans.clone(),
+ model: "text-embedding-ada-002",
+ })
+ .unwrap()
+ .into(),
+ )?;
+
+ Ok(self.client.send(request).await?)
+ }
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddingProvider {
+ fn is_authenticated(&self) -> bool {
+ OPENAI_API_KEY.as_ref().is_some()
+ }
+ fn max_tokens_per_batch(&self) -> usize {
+ 50000
+ }
+
+ fn rate_limit_expiration(&self) -> Option<Instant> {
+ *self.rate_limit_count_rx.borrow()
+ }
+ fn truncate(&self, span: &str) -> (String, usize) {
+ let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+ let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+ tokens.truncate(OPENAI_INPUT_LIMIT);
+ OPENAI_BPE_TOKENIZER
+ .decode(tokens.clone())
+ .ok()
+ .unwrap_or_else(|| span.to_string())
+ } else {
+ span.to_string()
+ };
+
+ (output, tokens.len())
+ }
+
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+ const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
+ const MAX_RETRIES: usize = 4;
+
+ let api_key = OPENAI_API_KEY
+ .as_ref()
+ .ok_or_else(|| anyhow!("no api key"))?;
+
+ let mut request_number = 0;
+ let mut rate_limiting = false;
+ let mut request_timeout: u64 = 15;
+ let mut response: Response<AsyncBody>;
+ while request_number < MAX_RETRIES {
+ response = self
+ .send_request(
+ api_key,
+ spans.iter().map(|x| &**x).collect(),
+ request_timeout,
+ )
+ .await?;
+
+ request_number += 1;
+
+ match response.status() {
+ StatusCode::REQUEST_TIMEOUT => {
+ request_timeout += 5;
+ }
+ StatusCode::OK => {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+ log::trace!(
+ "openai embedding completed. tokens: {:?}",
+ response.usage.total_tokens
+ );
+
+ // If a request that was previously rate-limited completes
+ // successfully, resolve the rate limit.
+ if rate_limiting {
+ self.resolve_rate_limit()
+ }
+
+ return Ok(response
+ .data
+ .into_iter()
+ .map(|embedding| Embedding::from(embedding.embedding))
+ .collect());
+ }
+ StatusCode::TOO_MANY_REQUESTS => {
+ rate_limiting = true;
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ let delay_duration = {
+ let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+ if let Some(time_to_reset) =
+ response.headers().get("x-ratelimit-reset-tokens")
+ {
+ if let Ok(time_str) = time_to_reset.to_str() {
+ parse(time_str).unwrap_or(delay)
+ } else {
+ delay
+ }
+ } else {
+ delay
+ }
+ };
+
+ // Track the earliest known reset time across rate-limited requests.
+ let reset_time = Instant::now().add(delay_duration);
+ self.update_reset_time(reset_time);
+
+ log::trace!(
+ "openai rate limiting: waiting {:?} until lifted",
+ &delay_duration
+ );
+
+ self.executor.timer(delay_duration).await;
+ }
+ _ => {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Err(anyhow!(
+ "open ai bad request: {:?} {:?}",
+ &response.status(),
+ body
+ ));
+ }
+ }
+ }
+ Err(anyhow!("openai max retries"))
+ }
+}
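The subtle part of `OpenAIEmbeddingProvider` is the rate-limit bookkeeping: `update_reset_time` keeps the earliest deadline it has seen, and `resolve_rate_limit` clears it once it has passed. A standalone sketch of that rule, using `std` types in place of the `postage` watch channel from the diff:

```rust
use std::sync::Mutex;
use std::time::Instant;

struct RateLimitState(Mutex<Option<Instant>>);

impl RateLimitState {
    // Record a reset time, never pushing an existing deadline later.
    fn update(&self, reset_time: Instant) {
        let mut slot = self.0.lock().unwrap();
        *slot = Some(match *slot {
            Some(existing) if existing < reset_time => existing,
            _ => reset_time,
        });
    }

    // Clear the deadline once it has passed.
    fn resolve(&self) {
        let mut slot = self.0.lock().unwrap();
        if (*slot).map_or(false, |t| Instant::now() >= t) {
            *slot = None;
        }
    }
}
```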
@@ -1,4 +1,7 @@
pub mod completion;
+pub mod embedding;
pub mod model;
+
pub use completion::*;
+pub use embedding::*;
pub use model::OpenAILanguageModel;
@@ -7,7 +7,8 @@ pub mod semantic_index_settings;
mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
@@ -88,7 +89,7 @@ pub fn init(
let semantic_index = SemanticIndex::new(
fs,
db_file_path,
- Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+ Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
language_registry,
cx.clone(),
)
@@ -4,7 +4,8 @@ use crate::{
semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
};
-use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::dummy::DummyEmbeddingProvider;
use anyhow::Result;
use async_trait::async_trait;
use gpui::{executor::Deterministic, Task, TestAppContext};
@@ -280,7 +281,7 @@ fn assert_search_results(
#[gpui::test]
async fn test_code_context_retrieval_rust() {
let language = rust_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -382,7 +383,7 @@ async fn test_code_context_retrieval_rust() {
#[gpui::test]
async fn test_code_context_retrieval_json() {
let language = json_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -466,7 +467,7 @@ fn assert_documents_eq(
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
let language = js_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -565,7 +566,7 @@ async fn test_code_context_retrieval_javascript() {
#[gpui::test]
async fn test_code_context_retrieval_lua() {
let language = lua_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -639,7 +640,7 @@ async fn test_code_context_retrieval_lua() {
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
let language = elixir_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -756,7 +757,7 @@ async fn test_code_context_retrieval_elixir() {
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
let language = cpp_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -909,7 +910,7 @@ async fn test_code_context_retrieval_cpp() {
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
let language = ruby_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -1100,7 +1101,7 @@ async fn test_code_context_retrieval_ruby() {
#[gpui::test]
async fn test_code_context_retrieval_php() {
let language = php_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -1,4 +1,4 @@
-use ai::embedding::OpenAIEmbeddings;
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Result};
use client::{self, UserStore};
use gpui::{AsyncAppContext, ModelHandle, Task};
@@ -474,7 +474,7 @@ fn main() {
let semantic_index = SemanticIndex::new(
fs.clone(),
db_file_path,
- Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+ Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
languages.clone(),
cx.clone(),
)
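Because both providers now implement the same `ai::embedding::EmbeddingProvider` trait, call sites like the two `SemanticIndex::new` calls above can swap implementations behind an `Arc<dyn EmbeddingProvider>`. A sketch of that wiring, where the `use_dummy` flag is hypothetical and the other types follow the imports in this diff:

```rust
use std::sync::Arc;

use ai::embedding::EmbeddingProvider;
use ai::providers::{dummy::DummyEmbeddingProvider, open_ai::OpenAIEmbeddingProvider};
use gpui::executor::Background;
use util::http::HttpClient;

fn embedding_provider(
    use_dummy: bool,
    http_client: Arc<dyn HttpClient>,
    executor: Arc<Background>,
) -> Arc<dyn EmbeddingProvider> {
    if use_dummy {
        // Deterministic, offline provider for tests.
        Arc::new(DummyEmbeddingProvider {})
    } else {
        // Real provider; reads OPENAI_API_KEY from the environment.
        Arc::new(OpenAIEmbeddingProvider::new(http_client, executor))
    }
}
```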