From 0b57ab730332dbf0033d652b4b531b2898c88039 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Sun, 22 Oct 2023 13:34:22 +0200
Subject: [PATCH] cleaned up truncate vs truncate start

---
 crates/ai/src/models.rs                 | 37 ++++++++++++++-----------
 crates/ai/src/templates/base.rs         | 33 ++++++++++++++--------
 crates/ai/src/templates/file_context.rs | 12 +++++---
 crates/ai/src/templates/generate.rs     |  6 +++-
 4 files changed, 56 insertions(+), 32 deletions(-)

diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs
index d0206cc41c526f171fef8521a120f8f4ff70aa74..afb8496156f6521eb3125b6a0ba6d703d5d0fe50 100644
--- a/crates/ai/src/models.rs
+++ b/crates/ai/src/models.rs
@@ -2,11 +2,20 @@ use anyhow::anyhow;
 use tiktoken_rs::CoreBPE;
 use util::ResultExt;
 
+pub enum TruncationDirection {
+    Start,
+    End,
+}
+
 pub trait LanguageModel {
     fn name(&self) -> String;
     fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
-    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
-    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String>;
     fn capacity(&self) -> anyhow::Result<usize>;
 }
 
@@ -36,23 +45,19 @@ impl LanguageModel for OpenAILanguageModel {
             Err(anyhow!("bpe for open ai model was not retrieved"))
         }
     }
-    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
-        if let Some(bpe) = &self.bpe {
-            let tokens = bpe.encode_with_special_tokens(content);
-            if tokens.len() > length {
-                bpe.decode(tokens[..length].to_vec())
-            } else {
-                bpe.decode(tokens)
-            }
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
         if let Some(bpe) = &self.bpe {
             let tokens = bpe.encode_with_special_tokens(content);
             if tokens.len() > length {
-                bpe.decode(tokens[length..].to_vec())
+                match direction {
+                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+                }
             } else {
                 bpe.decode(tokens)
             }
diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs
index bda1d6c30e61a9e2fd3808fa45a34cbe041cf2b6..e5ac414bc1691b02090361aa19bd0c56ee1557f5 100644
--- a/crates/ai/src/templates/base.rs
+++ b/crates/ai/src/templates/base.rs
@@ -125,6 +125,8 @@ impl PromptChain {
 
 #[cfg(test)]
 pub(crate) mod tests {
+    use crate::models::TruncationDirection;
+
     use super::*;
 
     #[test]
@@ -141,7 +143,11 @@
             let mut token_count = args.model.count_tokens(&content)?;
             if let Some(max_token_length) = max_token_length {
                 if token_count > max_token_length {
-                    content = args.model.truncate(&content, max_token_length)?;
+                    content = args.model.truncate(
+                        &content,
+                        max_token_length,
+                        TruncationDirection::Start,
+                    )?;
                     token_count = max_token_length;
                 }
             }
@@ -162,7 +168,11 @@
             let mut token_count = args.model.count_tokens(&content)?;
             if let Some(max_token_length) = max_token_length {
                 if token_count > max_token_length {
-                    content = args.model.truncate(&content, max_token_length)?;
+                    content = args.model.truncate(
+                        &content,
+                        max_token_length,
+                        TruncationDirection::Start,
+                    )?;
                     token_count = max_token_length;
                 }
             }
@@ -183,19 +193,20 @@
         fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
             anyhow::Ok(content.chars().collect::<Vec<char>>().len())
         }
-        fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
-            anyhow::Ok(
-                content.chars().collect::<Vec<char>>()[..length]
+        fn truncate(
+            &self,
+            content: &str,
+            length: usize,
+            direction: TruncationDirection,
+        ) -> anyhow::Result<String> {
+            anyhow::Ok(match direction {
+                TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
                     .into_iter()
                     .collect::<String>(),
-            )
-        }
-        fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
-            anyhow::Ok(
-                content.chars().collect::<Vec<char>>()[length..]
+                TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
                     .into_iter()
                     .collect::<String>(),
-            )
+            })
         }
         fn capacity(&self) -> anyhow::Result<usize> {
             anyhow::Ok(self.capacity)
diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs
index 1afd61192edc02b153abe8cd00836d67caa42f02..1517134abb97c05866c007c7072175bc2f7f6aca 100644
--- a/crates/ai/src/templates/file_context.rs
+++ b/crates/ai/src/templates/file_context.rs
@@ -3,6 +3,7 @@ use language::BufferSnapshot;
 use language::ToOffset;
 
 use crate::models::LanguageModel;
+use crate::models::TruncationDirection;
 use crate::templates::base::PromptArguments;
 use crate::templates::base::PromptTemplate;
 use std::fmt::Write;
@@ -70,8 +71,9 @@ fn retrieve_context(
         };
 
         let truncated_start_window =
-            model.truncate_start(&start_window, start_goal_tokens)?;
-        let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+            model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
+        let truncated_end_window =
+            model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
         writeln!(
             prompt,
             "{truncated_start_window}{selected_window}{truncated_end_window}"
@@ -89,7 +91,7 @@
         if let Some(max_token_count) = max_token_count {
             if model.count_tokens(&prompt)? > max_token_count {
                 truncated = true;
-                prompt = model.truncate(&prompt, max_token_count)?;
+                prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
             }
         }
     }
@@ -148,7 +150,9 @@ impl PromptTemplate for FileContext {
 
         // Really dumb truncation strategy
         if let Some(max_tokens) = max_token_length {
-            prompt = args.model.truncate(&prompt, max_tokens)?;
+            prompt = args
+                .model
+                .truncate(&prompt, max_tokens, TruncationDirection::End)?;
         }
 
         let token_count = args.model.count_tokens(&prompt)?;
diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/templates/generate.rs
index 1eeb197f932db0dc13963982e7e8bc983c338db7..c9541c6b44a076fdd87b491669b34616fec04e24 100644
--- a/crates/ai/src/templates/generate.rs
+++ b/crates/ai/src/templates/generate.rs
@@ -85,7 +85,11 @@ impl PromptTemplate for GenerateInlineContent {
 
         // Really dumb truncation strategy
         if let Some(max_tokens) = max_token_length {
-            prompt = args.model.truncate(&prompt, max_tokens)?;
+            prompt = args.model.truncate(
+                &prompt,
+                max_tokens,
+                crate::models::TruncationDirection::End,
+            )?;
         }
 
         let token_count = args.model.count_tokens(&prompt)?;
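
Note for reviewers: a minimal sketch of how a caller might drive the consolidated
trait after this patch. The helper `clamp_prompt` is hypothetical (illustration
only, not part of this change); it assumes only the trait methods shown above.

use crate::models::{LanguageModel, TruncationDirection};

// Hypothetical helper: clamp a prompt to the model's token capacity,
// dropping tokens from the chosen end of the text.
fn clamp_prompt(model: &dyn LanguageModel, prompt: &str) -> anyhow::Result<String> {
    let limit = model.capacity()?;
    if model.count_tokens(prompt)? > limit {
        // TruncationDirection::End keeps the start of the prompt;
        // TruncationDirection::Start would keep its tail instead.
        model.truncate(prompt, limit, TruncationDirection::End)
    } else {
        Ok(prompt.to_string())
    }
}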