@@ -56,7 +56,7 @@ pub trait EmbeddingProvider: Sync + Send {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
fn count_tokens(&self, span: &str) -> usize;
fn should_truncate(&self, span: &str) -> bool;
- // fn truncate(&self, span: &str) -> Result<&str>;
+ fn truncate(&self, span: &str) -> String;
}
pub struct DummyEmbeddings {}
@@ -78,36 +78,27 @@ impl EmbeddingProvider for DummyEmbeddings {
fn should_truncate(&self, span: &str) -> bool {
self.count_tokens(span) > OPENAI_INPUT_LIMIT
+ }
- // let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
- // let Ok(output) = {
- // if tokens.len() > OPENAI_INPUT_LIMIT {
- // tokens.truncate(OPENAI_INPUT_LIMIT);
- // OPENAI_BPE_TOKENIZER.decode(tokens)
- // } else {
- // Ok(span)
- // }
- // };
+ /// Shortens `span` so that it encodes to at most `OPENAI_INPUT_LIMIT` tokens.
+ ///
+ /// A span that is already within the limit is returned unchanged (as an owned
+ /// `String`). Otherwise the token list is cut down to the limit and decoded
+ /// back into text.
+ fn truncate(&self, span: &str) -> String {
+     let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+     if tokens.len() <= OPENAI_INPUT_LIMIT {
+         return span.to_string();
+     }
+
+     tokens.truncate(OPENAI_INPUT_LIMIT);
+     loop {
+         // NOTE(review): decoding presumably fails when the cut lands inside a
+         // multi-byte character - confirm against the tokenizer. Returning the
+         // whole span in that case would hand back text that is still over the
+         // limit, so drop trailing tokens and retry instead.
+         if let Ok(truncated) = OPENAI_BPE_TOKENIZER.decode(tokens.clone()) {
+             return truncated;
+         }
+         // Nothing left to drop: keep the last-resort behaviour of returning
+         // the input unchanged.
+         if tokens.pop().is_none() {
+             return span.to_string();
+         }
+     }
}
}
const OPENAI_INPUT_LIMIT: usize = 8190;
impl OpenAIEmbeddings {
- fn truncate(span: String) -> String {
- let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
- if tokens.len() > OPENAI_INPUT_LIMIT {
- tokens.truncate(OPENAI_INPUT_LIMIT);
- let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
- if result.is_ok() {
- let transformed = result.unwrap();
- return transformed;
- }
- }
-
- span
- }
-
async fn send_request(
&self,
api_key: &str,
@@ -144,6 +135,21 @@ impl EmbeddingProvider for OpenAIEmbeddings {
self.count_tokens(span) > OPENAI_INPUT_LIMIT
}
+ /// Shortens `span` so that it encodes to at most `OPENAI_INPUT_LIMIT` tokens.
+ ///
+ /// A span that is already within the limit is returned unchanged (as an owned
+ /// `String`). Otherwise the token list is cut down to the limit and decoded
+ /// back into text.
+ fn truncate(&self, span: &str) -> String {
+     let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+     if tokens.len() <= OPENAI_INPUT_LIMIT {
+         return span.to_string();
+     }
+
+     tokens.truncate(OPENAI_INPUT_LIMIT);
+     loop {
+         // NOTE(review): decoding presumably fails when the cut lands inside a
+         // multi-byte character - confirm against the tokenizer. Returning the
+         // whole span in that case would hand back text that is still over the
+         // limit, so drop trailing tokens and retry instead.
+         if let Ok(truncated) = OPENAI_BPE_TOKENIZER.decode(tokens.clone()) {
+             return truncated;
+         }
+         // Nothing left to drop: keep the last-resort behaviour of returning
+         // the input unchanged.
+         if tokens.pop().is_none() {
+             return span.to_string();
+         }
+     }
+ }
+
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
@@ -214,23 +220,13 @@ impl EmbeddingProvider for OpenAIEmbeddings {
self.executor.timer(delay_duration).await;
}
_ => {
- // TODO: Move this to parsing step
- // Only truncate if it hasnt been truncated before
- if !truncated {
- for span in spans.iter_mut() {
- *span = Self::truncate(span.clone());
- }
- truncated = true;
- } else {
- // If failing once already truncated, log the error and break the loop
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- return Err(anyhow!(
- "open ai bad request: {:?} {:?}",
- &response.status(),
- body
- ));
- }
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Err(anyhow!(
+ "open ai bad request: {:?} {:?}",
+ &response.status(),
+ body
+ ));
}
}
}
@@ -73,6 +73,7 @@ impl CodeContextRetriever {
sha1.update(&document_span);
let token_count = self.embedding_provider.count_tokens(&document_span);
+ let document_span = self.embedding_provider.truncate(&document_span);
Ok(vec![Document {
range: 0..content.len(),
@@ -93,6 +94,7 @@ impl CodeContextRetriever {
sha1.update(&document_span);
let token_count = self.embedding_provider.count_tokens(&document_span);
+ let document_span = self.embedding_provider.truncate(&document_span);
Ok(vec![Document {
range: 0..content.len(),
@@ -182,6 +184,8 @@ impl CodeContextRetriever {
.replace("item", &document.content);
let token_count = self.embedding_provider.count_tokens(&document_content);
+ let document_content = self.embedding_provider.truncate(&document_content);
+
document.content = document_content;
document.token_count = token_count;
}
@@ -1232,6 +1232,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
false
}
+ /// Returns an owned copy of `span` without shortening it.
+ fn truncate(&self, span: &str) -> String {
+     String::from(span)
+ }
+
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);