diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs
index 2afcc87ff5dc49072b558fffc4f22da1a34909e9..923e1833c2115953a27044d198497db256287907 100644
--- a/crates/ai/src/templates/base.rs
+++ b/crates/ai/src/templates/base.rs
@@ -1,4 +1,3 @@
-use anyhow::anyhow;
 use std::cmp::Reverse;
 use std::ops::Range;
 use std::sync::Arc;
@@ -96,36 +95,15 @@ impl PromptChain {
         let mut prompts = vec!["".to_string(); sorted_indices.len()];
 
         for idx in sorted_indices {
-            let (priority, template) = &self.templates[idx];
-
-            // If PromptPriority is marked as mandatory, we ignore the tokens outstanding
-            // However, if a prompt is generated in excess of the available tokens,
-            // we raise an error outlining that a mandatory prompt has exceeded the available
-            // balance
-            let template_tokens = if let Some(template_tokens) = tokens_outstanding {
-                match priority {
-                    &PromptPriority::Mandatory => None,
-                    _ => Some(template_tokens),
-                }
-            } else {
-                None
-            };
+            let (_, template) = &self.templates[idx];
 
             if let Some((template_prompt, prompt_token_count)) =
-                template.generate(&self.args, template_tokens).log_err()
+                template.generate(&self.args, tokens_outstanding).log_err()
             {
                 if template_prompt != "" {
                     prompts[idx] = template_prompt;
 
                     if let Some(remaining_tokens) = tokens_outstanding {
-                        if prompt_token_count > remaining_tokens
-                            && priority == &PromptPriority::Mandatory
-                        {
-                            return Err(anyhow!(
-                                "mandatory template added in excess of model capacity"
-                            ));
-                        }
-
                         let new_tokens = prompt_token_count + seperator_tokens;
                         tokens_outstanding = if remaining_tokens > new_tokens {
                             Some(remaining_tokens - new_tokens)
diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs
index e28f9ccdedb293817c22f54e0b2a12f17a40ac9f..00fe99dd7ffc257339caa8c5198e532b967bee40 100644
--- a/crates/ai/src/templates/file_context.rs
+++ b/crates/ai/src/templates/file_context.rs
@@ -1,4 +1,3 @@
-use anyhow::anyhow;
 use language::ToOffset;
 
 use crate::templates::base::PromptArguments;
@@ -13,12 +12,6 @@ impl PromptTemplate for FileContext {
         args: &PromptArguments,
         max_token_length: Option<usize>,
     ) -> anyhow::Result<(String, usize)> {
-        if max_token_length.is_some() {
-            return Err(anyhow!(
-                "no truncation strategy established for file_context template"
-            ));
-        }
-
         let mut prompt = String::new();
 
         // Add Initial Preamble
@@ -84,6 +77,11 @@ impl PromptTemplate for FileContext {
             }
         }
 
+        // Really dumb truncation strategy
+        if let Some(max_tokens) = max_token_length {
+            prompt = args.model.truncate(&prompt, max_tokens)?;
+        }
+
         let token_count = args.model.count_tokens(&prompt)?;
         anyhow::Ok((prompt, token_count))
     }
diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/templates/generate.rs
index 053398e873828a22e35e90e5a12b7137c83b2de0..34d874cc4128ccae034a0ecf3beace159bbec1ac 100644
--- a/crates/ai/src/templates/generate.rs
+++ b/crates/ai/src/templates/generate.rs
@@ -18,12 +18,6 @@ impl PromptTemplate for GenerateInlineContent {
         args: &PromptArguments,
         max_token_length: Option<usize>,
     ) -> anyhow::Result<(String, usize)> {
-        if max_token_length.is_some() {
-            return Err(anyhow!(
-                "no truncation strategy established for generating inline content template"
-            ));
-        }
-
         let Some(user_prompt) = &args.user_prompt else {
             return Err(anyhow!("user prompt not provided"));
         };
@@ -88,6 +82,11 @@ impl PromptTemplate for GenerateInlineContent {
             _ => {}
         }
 
+        // Really dumb truncation strategy
+        if let Some(max_tokens) = max_token_length {
+            prompt = args.model.truncate(&prompt, max_tokens)?;
+        }
+
         let token_count = args.model.count_tokens(&prompt)?;
         anyhow::Ok((prompt, token_count))
     }
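
For context on the change: instead of rejecting a max_token_length up front, each template now truncates its own output via args.model.truncate and reports its token count, so PromptChain::generate can simply pass the outstanding budget to every template and subtract what each one consumed. The sketch below illustrates the count_tokens/truncate contract the templates rely on. The real LanguageModel implementation is not part of this diff, so the trait shape and the whitespace "tokenizer" here are illustrative assumptions only; an actual model would count and cut by its real token encoding (e.g. BPE).

// Minimal sketch, assuming a LanguageModel trait with count_tokens and
// truncate. WhitespaceModel is a hypothetical stand-in, not the crate's
// actual model type.
trait LanguageModel {
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
}

struct WhitespaceModel;

impl LanguageModel for WhitespaceModel {
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        // Treat each whitespace-separated word as one token.
        Ok(content.split_whitespace().count())
    }

    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
        // The "really dumb" strategy: keep the first `length` tokens and drop
        // the rest, with no regard for sentence or syntax boundaries.
        let tokens: Vec<&str> = content.split_whitespace().collect();
        if tokens.len() <= length {
            Ok(content.to_string())
        } else {
            Ok(tokens[..length].join(" "))
        }
    }
}

With this contract in place, the mandatory-priority special case in base.rs becomes unnecessary: no template can overrun the budget, because every template is handed tokens_outstanding and is responsible for fitting itself into it.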