cleaned up truncate vs truncate start

KCaverly created

Change summary

crates/ai/src/models.rs                 | 37 +++++++++++++++-----------
crates/ai/src/templates/base.rs         | 33 ++++++++++++++++--------
crates/ai/src/templates/file_context.rs | 12 +++++--
crates/ai/src/templates/generate.rs     |  6 +++
4 files changed, 56 insertions(+), 32 deletions(-)

Detailed changes

crates/ai/src/models.rs 🔗

@@ -2,11 +2,20 @@ use anyhow::anyhow;
 use tiktoken_rs::CoreBPE;
 use util::ResultExt;
 
+pub enum TruncationDirection {
+    Start,
+    End,
+}
+
 pub trait LanguageModel {
     fn name(&self) -> String;
     fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
-    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
-    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String>;
     fn capacity(&self) -> anyhow::Result<usize>;
 }
 
@@ -36,23 +45,19 @@ impl LanguageModel for OpenAILanguageModel {
             Err(anyhow!("bpe for open ai model was not retrieved"))
         }
     }
-    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
-        if let Some(bpe) = &self.bpe {
-            let tokens = bpe.encode_with_special_tokens(content);
-            if tokens.len() > length {
-                bpe.decode(tokens[..length].to_vec())
-            } else {
-                bpe.decode(tokens)
-            }
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
         if let Some(bpe) = &self.bpe {
             let tokens = bpe.encode_with_special_tokens(content);
             if tokens.len() > length {
-                bpe.decode(tokens[length..].to_vec())
+                match direction {
+                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+                }
             } else {
                 bpe.decode(tokens)
             }

crates/ai/src/templates/base.rs 🔗

@@ -125,6 +125,8 @@ impl PromptChain {
 
 #[cfg(test)]
 pub(crate) mod tests {
+    use crate::models::TruncationDirection;
+
     use super::*;
 
     #[test]
@@ -141,7 +143,11 @@ pub(crate) mod tests {
                 let mut token_count = args.model.count_tokens(&content)?;
                 if let Some(max_token_length) = max_token_length {
                     if token_count > max_token_length {
-                        content = args.model.truncate(&content, max_token_length)?;
+                        content = args.model.truncate(
+                            &content,
+                            max_token_length,
+                            TruncationDirection::Start,
+                        )?;
                         token_count = max_token_length;
                     }
                 }
@@ -162,7 +168,11 @@ pub(crate) mod tests {
                 let mut token_count = args.model.count_tokens(&content)?;
                 if let Some(max_token_length) = max_token_length {
                     if token_count > max_token_length {
-                        content = args.model.truncate(&content, max_token_length)?;
+                        content = args.model.truncate(
+                            &content,
+                            max_token_length,
+                            TruncationDirection::Start,
+                        )?;
                         token_count = max_token_length;
                     }
                 }
@@ -183,19 +193,20 @@ pub(crate) mod tests {
             fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
                 anyhow::Ok(content.chars().collect::<Vec<char>>().len())
             }
-            fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
-                anyhow::Ok(
-                    content.chars().collect::<Vec<char>>()[..length]
+            fn truncate(
+                &self,
+                content: &str,
+                length: usize,
+                direction: TruncationDirection,
+            ) -> anyhow::Result<String> {
+                anyhow::Ok(match direction {
+                    TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
                         .into_iter()
                         .collect::<String>(),
-                )
-            }
-            fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
-                anyhow::Ok(
-                    content.chars().collect::<Vec<char>>()[length..]
+                    TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
                         .into_iter()
                         .collect::<String>(),
-                )
+                })
             }
             fn capacity(&self) -> anyhow::Result<usize> {
                 anyhow::Ok(self.capacity)

crates/ai/src/templates/file_context.rs 🔗

@@ -3,6 +3,7 @@ use language::BufferSnapshot;
 use language::ToOffset;
 
 use crate::models::LanguageModel;
+use crate::models::TruncationDirection;
 use crate::templates::base::PromptArguments;
 use crate::templates::base::PromptTemplate;
 use std::fmt::Write;
@@ -70,8 +71,9 @@ fn retrieve_context(
                     };
 
                 let truncated_start_window =
-                    model.truncate_start(&start_window, start_goal_tokens)?;
-                let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+                    model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
+                let truncated_end_window =
+                    model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
                 writeln!(
                     prompt,
                     "{truncated_start_window}{selected_window}{truncated_end_window}"
@@ -89,7 +91,7 @@ fn retrieve_context(
             if let Some(max_token_count) = max_token_count {
                 if model.count_tokens(&prompt)? > max_token_count {
                     truncated = true;
-                    prompt = model.truncate(&prompt, max_token_count)?;
+                    prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
                 }
             }
         }
@@ -148,7 +150,9 @@ impl PromptTemplate for FileContext {
 
             // Really dumb truncation strategy
             if let Some(max_tokens) = max_token_length {
-                prompt = args.model.truncate(&prompt, max_tokens)?;
+                prompt = args
+                    .model
+                    .truncate(&prompt, max_tokens, TruncationDirection::End)?;
             }
 
             let token_count = args.model.count_tokens(&prompt)?;

crates/ai/src/templates/generate.rs 🔗

@@ -85,7 +85,11 @@ impl PromptTemplate for GenerateInlineContent {
 
         // Really dumb truncation strategy
         if let Some(max_tokens) = max_token_length {
-            prompt = args.model.truncate(&prompt, max_tokens)?;
+            prompt = args.model.truncate(
+                &prompt,
+                max_tokens,
+                crate::models::TruncationDirection::End,
+            )?;
         }
 
         let token_count = args.model.count_tokens(&prompt)?;