Detailed changes
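This change collapses the `LanguageModel` trait's two truncation methods, `truncate` and `truncate_start`, into a single `truncate` method that takes an explicit `TruncationDirection` (`Start` or `End`). The OpenAI implementation, the character-based test model, and every call site in the prompt templates are updated to pass a direction.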
@@ -2,11 +2,20 @@ use anyhow::anyhow;
 use tiktoken_rs::CoreBPE;
 use util::ResultExt;
 
+pub enum TruncationDirection {
+    Start,
+    End,
+}
+
 pub trait LanguageModel {
     fn name(&self) -> String;
     fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
-    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
-    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String>;
     fn capacity(&self) -> anyhow::Result<usize>;
 }
 
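With the unified signature, `TruncationDirection::End` keeps the first `length` tokens and drops the tail, while `TruncationDirection::Start` drops the leading tokens and keeps the tail; former `truncate_start` callers now pass `TruncationDirection::Start`.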
@@ -36,23 +45,19 @@ impl LanguageModel for OpenAILanguageModel {
             Err(anyhow!("bpe for open ai model was not retrieved"))
         }
     }
-    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
-        if let Some(bpe) = &self.bpe {
-            let tokens = bpe.encode_with_special_tokens(content);
-            if tokens.len() > length {
-                bpe.decode(tokens[..length].to_vec())
-            } else {
-                bpe.decode(tokens)
-            }
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
         if let Some(bpe) = &self.bpe {
             let tokens = bpe.encode_with_special_tokens(content);
             if tokens.len() > length {
-                bpe.decode(tokens[length..].to_vec())
+                match direction {
+                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+                }
             } else {
                 bpe.decode(tokens)
             }
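Note that the OpenAI implementation only slices when the encoded content actually exceeds `length`; within budget, both directions decode the token stream unchanged.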
@@ -125,6 +125,8 @@ impl PromptChain {
 
 #[cfg(test)]
 pub(crate) mod tests {
+    use crate::models::TruncationDirection;
+
     use super::*;
 
     #[test]
@@ -141,7 +143,11 @@ pub(crate) mod tests {
             let mut token_count = args.model.count_tokens(&content)?;
             if let Some(max_token_length) = max_token_length {
                 if token_count > max_token_length {
-                    content = args.model.truncate(&content, max_token_length)?;
+                    content = args.model.truncate(
+                        &content,
+                        max_token_length,
+                        TruncationDirection::Start,
+                    )?;
                     token_count = max_token_length;
                 }
             }
@@ -162,7 +168,11 @@ pub(crate) mod tests {
             let mut token_count = args.model.count_tokens(&content)?;
             if let Some(max_token_length) = max_token_length {
                 if token_count > max_token_length {
-                    content = args.model.truncate(&content, max_token_length)?;
+                    content = args.model.truncate(
+                        &content,
+                        max_token_length,
+                        TruncationDirection::Start,
+                    )?;
                     token_count = max_token_length;
                 }
             }
@@ -183,19 +193,20 @@ pub(crate) mod tests {
         fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
             anyhow::Ok(content.chars().collect::<Vec<char>>().len())
         }
-        fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
-            anyhow::Ok(
-                content.chars().collect::<Vec<char>>()[..length]
+        fn truncate(
+            &self,
+            content: &str,
+            length: usize,
+            direction: TruncationDirection,
+        ) -> anyhow::Result<String> {
+            anyhow::Ok(match direction {
+                TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
                     .into_iter()
                     .collect::<String>(),
-            )
-        }
-        fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
-            anyhow::Ok(
-                content.chars().collect::<Vec<char>>()[length..]
+                TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
                     .into_iter()
                     .collect::<String>(),
-            )
+            })
         }
         fn capacity(&self) -> anyhow::Result<usize> {
             anyhow::Ok(self.capacity)
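To make the direction semantics concrete, here is a minimal, self-contained sketch, not part of the diff, that mirrors the character-based test model above. The `CharModel` name is hypothetical and the only dependency assumed is the `anyhow` crate:

use anyhow::Result;

pub enum TruncationDirection {
    Start,
    End,
}

// Character-based stand-in for a tokenizer: one char == one token.
struct CharModel;

impl CharModel {
    fn truncate(&self, content: &str, length: usize, direction: TruncationDirection) -> Result<String> {
        let chars: Vec<char> = content.chars().collect();
        if chars.len() <= length {
            // Content already fits the budget; return it unchanged.
            return Ok(content.to_string());
        }
        Ok(match direction {
            // End: keep the head, drop everything past `length`.
            TruncationDirection::End => chars[..length].iter().collect(),
            // Start: drop the head, keep the tail.
            TruncationDirection::Start => chars[length..].iter().collect(),
        })
    }
}

fn main() -> Result<()> {
    let model = CharModel;
    assert_eq!(model.truncate("hello world", 5, TruncationDirection::End)?, "hello");
    assert_eq!(model.truncate("hello world", 5, TruncationDirection::Start)?, " world");
    Ok(())
}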
@@ -3,6 +3,7 @@ use language::BufferSnapshot;
 use language::ToOffset;
 
 use crate::models::LanguageModel;
+use crate::models::TruncationDirection;
 use crate::templates::base::PromptArguments;
 use crate::templates::base::PromptTemplate;
 use std::fmt::Write;
@@ -70,8 +71,9 @@ fn retrieve_context(
         };
 
         let truncated_start_window =
-            model.truncate_start(&start_window, start_goal_tokens)?;
-        let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+            model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
+        let truncated_end_window =
+            model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
         writeln!(
             prompt,
             "{truncated_start_window}{selected_window}{truncated_end_window}"
@@ -89,7 +91,7 @@ fn retrieve_context(
     if let Some(max_token_count) = max_token_count {
         if model.count_tokens(&prompt)? > max_token_count {
             truncated = true;
-            prompt = model.truncate(&prompt, max_token_count)?;
+            prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
         }
     }
 }
@@ -148,7 +150,9 @@ impl PromptTemplate for FileContext {
 
         // Really dumb truncation strategy
        if let Some(max_tokens) = max_token_length {
-            prompt = args.model.truncate(&prompt, max_tokens)?;
+            prompt = args
+                .model
+                .truncate(&prompt, max_tokens, TruncationDirection::End)?;
        }
 
        let token_count = args.model.count_tokens(&prompt)?;
@@ -85,7 +85,11 @@ impl PromptTemplate for GenerateInlineContent {
 
         // Really dumb truncation strategy
        if let Some(max_tokens) = max_token_length {
-            prompt = args.model.truncate(&prompt, max_tokens)?;
+            prompt = args.model.truncate(
+                &prompt,
+                max_tokens,
+                crate::models::TruncationDirection::End,
+            )?;
        }
 
        let token_count = args.model.count_tokens(&prompt)?;
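In both templates the fully rendered prompt is clamped with `TruncationDirection::End`, so when the budget is exceeded it is the head of the prompt that survives.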