progress on smarter truncation strategy for file context

Created by KCaverly

Change summary

crates/ai/src/models.rs                 |  13 ++
crates/ai/src/templates/base.rs         |   7 +
crates/ai/src/templates/file_context.rs | 139 +++++++++++++++++++-------
crates/assistant/src/prompts.rs         |   2 
4 files changed, 124 insertions(+), 37 deletions(-)

Detailed changes

crates/ai/src/models.rs

@@ -6,6 +6,7 @@ pub trait LanguageModel {
     fn name(&self) -> String;
     fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
     fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
+    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
     fn capacity(&self) -> anyhow::Result<usize>;
 }
 
@@ -47,6 +48,18 @@ impl LanguageModel for OpenAILanguageModel {
             Err(anyhow!("bpe for open ai model was not retrieved"))
         }
     }
+    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+        if let Some(bpe) = &self.bpe {
+            let tokens = bpe.encode_with_special_tokens(content);
+            if tokens.len() > length {
+                // Keep the *last* `length` tokens so the end of the content survives.
+                bpe.decode(tokens[tokens.len() - length..].to_vec())
+            } else {
+                bpe.decode(tokens)
+            }
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
     fn capacity(&self) -> anyhow::Result<usize> {
         anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
     }
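The new `truncate_start` is the tail-keeping counterpart of the head-keeping `truncate`: it must keep the *last* `length` tokens (hence the slice starting at `tokens.len() - length`), so the text leading up to a selection survives. A minimal sketch of the slice direction, using integer ids as stand-in tokens (illustrative only, not part of the diff):

```rust
fn main() {
    let tokens: Vec<u32> = (0..10).collect();
    let length = 4;

    // Head-keeping `truncate`: keep the first `length` tokens.
    assert_eq!(&tokens[..length], &[0, 1, 2, 3][..]);

    // Tail-keeping `truncate_start`: the slice must begin at `len - length`;
    // slicing from `length` instead would keep the wrong tokens (and the
    // wrong number of them).
    let start = tokens.len().saturating_sub(length);
    assert_eq!(&tokens[start..], &[6, 7, 8, 9][..]);
}
```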

crates/ai/src/templates/base.rs

@@ -190,6 +190,13 @@ pub(crate) mod tests {
                         .collect::<String>(),
                 )
             }
+            fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+                // Keep the last `length` chars, mirroring the token-based version.
+                let chars = content.chars().collect::<Vec<char>>();
+                let start = chars.len().saturating_sub(length);
+                anyhow::Ok(chars[start..].iter().collect::<String>())
+            }
             fn capacity(&self) -> anyhow::Result<usize> {
                 anyhow::Ok(self.capacity)
             }
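In the template tests the fake model counts every character as one token, which makes truncation behavior easy to assert exactly. A self-contained sketch of that kind of test double (the trait shape mirrors the `models.rs` diff above; the struct and `main` are mine):

```rust
use std::sync::Arc;

trait LanguageModel {
    fn name(&self) -> String;
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
    fn capacity(&self) -> anyhow::Result<usize>;
}

struct FakeLanguageModel {
    capacity: usize,
}

impl LanguageModel for FakeLanguageModel {
    fn name(&self) -> String {
        "fake".to_string()
    }
    // One char == one token.
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        anyhow::Ok(content.chars().count())
    }
    // Keep the head of the content.
    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
        anyhow::Ok(content.chars().take(length).collect())
    }
    // Keep the tail of the content.
    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
        let chars: Vec<char> = content.chars().collect();
        let start = chars.len().saturating_sub(length);
        anyhow::Ok(chars[start..].iter().collect())
    }
    fn capacity(&self) -> anyhow::Result<usize> {
        anyhow::Ok(self.capacity)
    }
}

fn main() -> anyhow::Result<()> {
    let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 8 });
    assert_eq!(model.truncate("hello world", 5)?, "hello");
    assert_eq!(model.truncate_start("hello world", 5)?, "world");
    anyhow::Ok(())
}
```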

crates/ai/src/templates/file_context.rs

@@ -1,9 +1,103 @@
 use anyhow::anyhow;
+use language::BufferSnapshot;
 use language::ToOffset;
 
+use crate::models::LanguageModel;
 use crate::templates::base::PromptArguments;
 use crate::templates::base::PromptTemplate;
 use std::fmt::Write;
+use std::ops::Range;
+use std::sync::Arc;
+
+fn retrieve_context(
+    buffer: &BufferSnapshot,
+    selected_range: &Option<Range<usize>>,
+    model: Arc<dyn LanguageModel>,
+    max_token_count: Option<usize>,
+) -> anyhow::Result<(String, usize, bool)> {
+    let mut prompt = String::new();
+    let mut truncated = false;
+    if let Some(selected_range) = selected_range {
+        let start = selected_range.start.to_offset(buffer);
+        let end = selected_range.end.to_offset(buffer);
+
+        let start_window = buffer.text_for_range(0..start).collect::<String>();
+
+        let mut selected_window = String::new();
+        if start == end {
+            write!(selected_window, "<|START|>").unwrap();
+        } else {
+            write!(selected_window, "<|START|").unwrap();
+        }
+
+        write!(
+            selected_window,
+            "{}",
+            buffer.text_for_range(start..end).collect::<String>()
+        )
+        .unwrap();
+
+        if start != end {
+            write!(selected_window, "|END|>").unwrap();
+        }
+
+        let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
+
+        if let Some(max_token_count) = max_token_count {
+            let selected_tokens = model.count_tokens(&selected_window)?;
+            if selected_tokens > max_token_count {
+                return Err(anyhow!(
+                    "selected range is greater than model context window, truncation not possible"
+                ));
+            };
+
+            let mut remaining_tokens = max_token_count - selected_tokens;
+            let start_window_tokens = model.count_tokens(&start_window)?;
+            let end_window_tokens = model.count_tokens(&end_window)?;
+            let outside_tokens = start_window_tokens + end_window_tokens;
+            if outside_tokens > remaining_tokens {
+                let (start_goal_tokens, end_goal_tokens) =
+                    if start_window_tokens < end_window_tokens {
+                        let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
+                        remaining_tokens -= start_goal_tokens;
+                        let end_goal_tokens = remaining_tokens.min(end_window_tokens);
+                        (start_goal_tokens, end_goal_tokens)
+                    } else {
+                        let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
+                        remaining_tokens -= end_goal_tokens;
+                        let start_goal_tokens = remaining_tokens.min(start_window_tokens);
+                        (start_goal_tokens, end_goal_tokens)
+                    };
+
+                let truncated_start_window =
+                    model.truncate_start(&start_window, start_goal_tokens)?;
+                let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+                writeln!(
+                    prompt,
+                    "{truncated_start_window}{selected_window}{truncated_end_window}"
+                )
+                .unwrap();
+                truncated = true;
+            } else {
+                writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
+            }
+        } else {
+            // No token limit; include both windows untruncated.
+            writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
+        }
+    } else {
+        // If we don't have a selected range, include the entire file.
+        writeln!(prompt, "{}", &buffer.text()).unwrap();
+
+        // Dumb truncation strategy: keep the start of the file, drop the rest.
+        if let Some(max_token_count) = max_token_count {
+            if model.count_tokens(&prompt)? > max_token_count {
+                truncated = true;
+                prompt = model.truncate(&prompt, max_token_count)?;
+            }
+        }
+    }
+
+    let token_count = model.count_tokens(&prompt)?;
+    anyhow::Ok((prompt, token_count, truncated))
+}
 
 pub struct FileContext {}
 
@@ -28,53 +122,24 @@ impl PromptTemplate for FileContext {
                 .clone()
                 .unwrap_or("".to_string())
                 .to_lowercase();
-            writeln!(prompt, "```{language_name}").unwrap();
+
+            let (context, _, truncated) = retrieve_context(
+                buffer,
+                &args.selected_range,
+                args.model.clone(),
+                max_token_length,
+            )?;
+            writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
 
             if let Some(selected_range) = &args.selected_range {
                 let start = selected_range.start.to_offset(buffer);
                 let end = selected_range.end.to_offset(buffer);
 
-                writeln!(
-                    prompt,
-                    "{}",
-                    buffer.text_for_range(0..start).collect::<String>()
-                )
-                .unwrap();
-
-                if start == end {
-                    write!(prompt, "<|START|>").unwrap();
-                } else {
-                    write!(prompt, "<|START|").unwrap();
-                }
-
-                write!(
-                    prompt,
-                    "{}",
-                    buffer.text_for_range(start..end).collect::<String>()
-                )
-                .unwrap();
-                if start != end {
-                    write!(prompt, "|END|>").unwrap();
-                }
-
-                write!(
-                    prompt,
-                    "{}",
-                    buffer.text_for_range(end..buffer.len()).collect::<String>()
-                )
-                .unwrap();
-
-                writeln!(prompt, "```").unwrap();
-
                 if start == end {
                     writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
                 } else {
                     writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
                 }
-            } else {
-                // If we dont have a selected range, include entire file.
-                writeln!(prompt, "{}", &buffer.text()).unwrap();
-                writeln!(prompt, "```").unwrap();
             }
 
             // Really dumb truncation strategy
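The marker convention from `retrieve_context` is worth spelling out: a collapsed selection (cursor) renders a single `<|START|>` marker, while a non-empty selection is wrapped in `<|START|` and `|END|>`. A standalone restatement (function name hypothetical; offsets are byte offsets as with `to_offset`, so the example sticks to ASCII):

```rust
fn annotate(text: &str, start: usize, end: usize) -> String {
    let mut out = String::with_capacity(text.len() + 16);
    out.push_str(&text[..start]);
    if start == end {
        // Cursor only: a single self-closing marker.
        out.push_str("<|START|>");
    } else {
        // Real selection: wrap the selected span.
        out.push_str("<|START|");
        out.push_str(&text[start..end]);
        out.push_str("|END|>");
    }
    out.push_str(&text[end..]);
    out
}

fn main() {
    assert_eq!(annotate("fn main() {}", 3, 3), "fn <|START|>main() {}");
    assert_eq!(annotate("fn main() {}", 3, 7), "fn <|START|main|END|>() {}");
}
```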

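The budget arithmetic in `retrieve_context` gives the smaller window up to half the remaining budget first, then hands whatever is left to the larger one, so a short prefix never starves a long suffix (or vice versa). A standalone restatement of that split (function name mine):

```rust
// Divide `remaining` tokens between the text before and after the selection.
fn split_budget(remaining: usize, start_tokens: usize, end_tokens: usize) -> (usize, usize) {
    if start_tokens < end_tokens {
        let start_goal = (remaining / 2).min(start_tokens);
        let end_goal = (remaining - start_goal).min(end_tokens);
        (start_goal, end_goal)
    } else {
        let end_goal = (remaining / 2).min(end_tokens);
        let start_goal = (remaining - end_goal).min(start_tokens);
        (start_goal, end_goal)
    }
}

fn main() {
    // Short prefix, long suffix: the prefix keeps all 3 of its tokens and
    // the unused part of its half-budget rolls over to the suffix.
    assert_eq!(split_budget(10, 3, 20), (3, 7));
    // Both windows oversized: the budget is split evenly.
    assert_eq!(split_budget(10, 20, 20), (5, 5));
}
```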
crates/assistant/src/prompts.rs

@@ -166,6 +166,8 @@ pub fn generate_content_prompt(
     let chain = PromptChain::new(args, templates);
     let (prompt, _) = chain.generate(true)?;
 
+    println!("PROMPT: {:?}", &prompt);
+
     anyhow::Ok(prompt)
 }