@@ -6,6 +6,7 @@ pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
+ fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}
@@ -47,6 +48,18 @@ impl LanguageModel for OpenAILanguageModel {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
+ /// Shortens `content` to at most `length` tokens by discarding tokens from
+ /// the beginning, so the tail of the text is what survives.
+ ///
+ /// Content that already fits within `length` tokens is encoded and decoded
+ /// without dropping anything.
+ ///
+ /// # Errors
+ ///
+ /// Returns an error when the BPE tokenizer for this model was not
+ /// retrieved, and propagates any error reported while decoding.
+ fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+     if let Some(bpe) = &self.bpe {
+         let tokens = bpe.encode_with_special_tokens(content);
+         if tokens.len() > length {
+             // Keep the trailing `length` tokens. Slicing from `length` would
+             // instead drop the first `length` tokens and return
+             // `tokens.len() - length` of them, overshooting the requested size.
+             let first_kept = tokens.len() - length;
+             bpe.decode(tokens[first_kept..].to_vec())
+         } else {
+             bpe.decode(tokens)
+         }
+     } else {
+         Err(anyhow!("bpe for open ai model was not retrieved"))
+     }
+ }
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
}
@@ -190,6 +190,13 @@ pub(crate) mod tests {
.collect::<String>(),
)
}
+ /// Shortens `content` to at most `length` `char`s by discarding `char`s
+ /// from the beginning, so the tail of the text is what survives.
+ ///
+ /// Content with `length` or fewer `char`s is returned whole; this never
+ /// panics on a `length` larger than the content.
+ fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+     // Number of leading chars to discard; zero when `content` already fits.
+     let excess = content.chars().count().saturating_sub(length);
+     anyhow::Ok(content.chars().skip(excess).collect::<String>())
+ }
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(self.capacity)
}
@@ -1,9 +1,103 @@
use anyhow::anyhow;
+use language::BufferSnapshot;
use language::ToOffset;
+use crate::models::LanguageModel;
use crate::templates::base::PromptArguments;
use crate::templates::base::PromptTemplate;
use std::fmt::Write;
+use std::ops::Range;
+use std::sync::Arc;
+
+/// Builds the file-context text for a prompt from `buffer`.
+///
+/// With a selection, the output is the text before the selection, the marked
+/// selection, and the text after it. A non-empty selection is wrapped as
+/// `<|START|…|END|>`; an empty one (a bare cursor) is marked with `<|START|>`.
+/// Without a selection the whole buffer text is emitted.
+///
+/// When `max_token_count` is given, the output is fitted to that budget:
+/// - with a selection, the marked selection is always kept intact and the
+///   surrounding windows are cut via `model.truncate_start` (before) and
+///   `model.truncate` (after);
+/// - without a selection, the text is cut with `model.truncate`.
+///
+/// Returns `(prompt, token_count, truncated)`, where `token_count` is
+/// `model.count_tokens` of the final prompt and `truncated` tells whether any
+/// text was cut.
+///
+/// # Errors
+///
+/// Fails when the marked selection alone exceeds `max_token_count`, or when
+/// any call into `model` fails.
+fn retrieve_context(
+    buffer: &BufferSnapshot,
+    selected_range: &Option<Range<usize>>,
+    model: Arc<dyn LanguageModel>,
+    max_token_count: Option<usize>,
+) -> anyhow::Result<(String, usize, bool)> {
+    let mut prompt = String::new();
+    let mut truncated = false;
+    if let Some(selected_range) = selected_range {
+        let start = selected_range.start.to_offset(buffer);
+        let end = selected_range.end.to_offset(buffer);
+
+        let start_window = buffer.text_for_range(0..start).collect::<String>();
+        let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
+
+        let selected_window = if start == end {
+            "<|START|>".to_string()
+        } else {
+            let selected_text = buffer.text_for_range(start..end).collect::<String>();
+            format!("<|START|{selected_text}|END|>")
+        };
+
+        // `Some((start_goal, end_goal))` only when a budget exists and the text
+        // around the selection does not fit into what the selection leaves over.
+        let mut token_goals = None;
+        if let Some(max_token_count) = max_token_count {
+            let selected_tokens = model.count_tokens(&selected_window)?;
+            if selected_tokens > max_token_count {
+                return Err(anyhow!(
+                    "selected range is greater than model context window, truncation not possible"
+                ));
+            }
+
+            let mut remaining_tokens = max_token_count - selected_tokens;
+            let start_window_tokens = model.count_tokens(&start_window)?;
+            let end_window_tokens = model.count_tokens(&end_window)?;
+            if start_window_tokens + end_window_tokens > remaining_tokens {
+                // The smaller window gets at most half of the budget; whatever
+                // it leaves unused goes to the larger window.
+                token_goals = Some(if start_window_tokens < end_window_tokens {
+                    let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
+                    remaining_tokens -= start_goal_tokens;
+                    let end_goal_tokens = remaining_tokens.min(end_window_tokens);
+                    (start_goal_tokens, end_goal_tokens)
+                } else {
+                    let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
+                    remaining_tokens -= end_goal_tokens;
+                    let start_goal_tokens = remaining_tokens.min(start_window_tokens);
+                    (start_goal_tokens, end_goal_tokens)
+                });
+            }
+        }
+
+        // Writing into a `String` cannot fail, hence the `unwrap`s below.
+        if let Some((start_goal_tokens, end_goal_tokens)) = token_goals {
+            let truncated_start_window =
+                model.truncate_start(&start_window, start_goal_tokens)?;
+            let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+            writeln!(
+                prompt,
+                "{truncated_start_window}{selected_window}{truncated_end_window}"
+            )
+            .unwrap();
+            truncated = true;
+        } else {
+            // Either no budget was given or everything fits: emit the full
+            // text, still carrying the selection markers.
+            writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
+        }
+    } else {
+        // If we don't have a selected range, include the entire file.
+        writeln!(prompt, "{}", buffer.text()).unwrap();
+
+        // Simple truncation strategy: cut the text down to the budget.
+        if let Some(max_token_count) = max_token_count {
+            if model.count_tokens(&prompt)? > max_token_count {
+                truncated = true;
+                prompt = model.truncate(&prompt, max_token_count)?;
+            }
+        }
+    }
+
+    let token_count = model.count_tokens(&prompt)?;
+    anyhow::Ok((prompt, token_count, truncated))
+}
pub struct FileContext {}
@@ -28,53 +122,24 @@ impl PromptTemplate for FileContext {
.clone()
.unwrap_or("".to_string())
.to_lowercase();
- writeln!(prompt, "```{language_name}").unwrap();
+
+ let (context, _, truncated) = retrieve_context(
+ buffer,
+ &args.selected_range,
+ args.model.clone(),
+ max_token_length,
+ )?;
+ writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
if let Some(selected_range) = &args.selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
- writeln!(
- prompt,
- "{}",
- buffer.text_for_range(0..start).collect::<String>()
- )
- .unwrap();
-
- if start == end {
- write!(prompt, "<|START|>").unwrap();
- } else {
- write!(prompt, "<|START|").unwrap();
- }
-
- write!(
- prompt,
- "{}",
- buffer.text_for_range(start..end).collect::<String>()
- )
- .unwrap();
- if start != end {
- write!(prompt, "|END|>").unwrap();
- }
-
- write!(
- prompt,
- "{}",
- buffer.text_for_range(end..buffer.len()).collect::<String>()
- )
- .unwrap();
-
- writeln!(prompt, "```").unwrap();
-
if start == end {
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
} else {
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
}
- } else {
- // If we dont have a selected range, include entire file.
- writeln!(prompt, "{}", &buffer.text()).unwrap();
- writeln!(prompt, "```").unwrap();
}
// Really dumb truncation strategy