1use std::cmp::Reverse;
2use std::ops::Range;
3use std::sync::Arc;
4
5use language::BufferSnapshot;
6use util::ResultExt;
7
8use crate::models::LanguageModel;
9use crate::templates::repository_context::PromptCodeSnippet;
10
/// Broad classification of the buffer a prompt is generated for; selected by
/// [`PromptArguments::get_file_type`] from the buffer's language name.
pub(crate) enum PromptFileType {
    /// Prose buffers ("Markdown", "Plain Text").
    Text,
    /// Source-code buffers; also used when the language is unknown.
    Code,
}
15
16// TODO: Set this up to manage for defaults well
// TODO: Set this up to manage for defaults well
/// Inputs shared by every [`PromptTemplate`] when rendering a prompt chain.
pub struct PromptArguments {
    /// Model used for token counting, truncation, and capacity queries.
    pub model: Arc<dyn LanguageModel>,
    /// Free-form prompt supplied by the user, if any.
    pub user_prompt: Option<String>,
    /// Name of the buffer's language (e.g. "Markdown"), if known.
    pub language_name: Option<String>,
    /// Name of the containing project, if known.
    pub project_name: Option<String>,
    /// Repository snippets available as extra context for templates.
    pub snippets: Vec<PromptCodeSnippet>,
    /// Tokens subtracted from the model's capacity before templates are
    /// filled during truncation.
    pub reserved_tokens: usize,
    /// Snapshot of the active buffer, if any.
    pub buffer: Option<BufferSnapshot>,
    /// Selected range within `buffer`, if any.
    // NOTE(review): offset units (bytes vs chars) aren't visible here — confirm at call sites.
    pub selected_range: Option<Range<usize>>,
}
27
28impl PromptArguments {
29 pub(crate) fn get_file_type(&self) -> PromptFileType {
30 if self
31 .language_name
32 .as_ref()
33 .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
34 .unwrap_or(true)
35 {
36 PromptFileType::Code
37 } else {
38 PromptFileType::Text
39 }
40 }
41}
42
/// A single renderable section of a prompt chain.
pub trait PromptTemplate {
    /// Renders this template against `args`, returning the generated text and
    /// its token count. When `max_token_length` is `Some`, implementations are
    /// expected to truncate their output to fit that budget (see the test
    /// templates in this file for the reference behavior).
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)>;
}
50
/// Priority used to decide which templates receive tokens first when the
/// combined prompt must be truncated to fit the model's capacity.
#[repr(i8)]
#[derive(PartialEq, Eq)]
pub enum PromptPriority {
    /// Ignores truncation: ranks above every `Ordered` template.
    Mandatory,
    /// Truncates based on priority: lower `order` values rank higher.
    Ordered { order: usize },
}

impl Ord for PromptPriority {
    // Implemented by hand rather than derived: `Ordered` must rank in
    // *descending* `order`, and a derived `Ord` (declaration/field order)
    // would disagree with that. `PromptChain::generate` sorts through
    // `Reverse`, which calls `Ord::cmp`, so this impl — not `PartialOrd` —
    // is the ordering that actually drives truncation priority.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match (self, other) {
            (Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
            (Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
            (Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
        }
    }
}

impl PartialOrd for PromptPriority {
    // Delegate so `PartialOrd` and `Ord` can never disagree.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
68
/// An ordered collection of prompt templates, each tagged with a
/// [`PromptPriority`], plus the shared arguments used to render them into a
/// single prompt string.
pub struct PromptChain {
    args: PromptArguments,
    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}
73
74impl PromptChain {
75 pub fn new(
76 args: PromptArguments,
77 templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
78 ) -> Self {
79 PromptChain { args, templates }
80 }
81
82 pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
83 // Argsort based on Prompt Priority
84 let seperator = "\n";
85 let seperator_tokens = self.args.model.count_tokens(seperator)?;
86 let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
87 sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
88
89 // If Truncate
90 let mut tokens_outstanding = if truncate {
91 Some(self.args.model.capacity()? - self.args.reserved_tokens)
92 } else {
93 None
94 };
95
96 let mut prompts = vec!["".to_string(); sorted_indices.len()];
97 for idx in sorted_indices {
98 let (_, template) = &self.templates[idx];
99
100 if let Some((template_prompt, prompt_token_count)) =
101 template.generate(&self.args, tokens_outstanding).log_err()
102 {
103 if template_prompt != "" {
104 prompts[idx] = template_prompt;
105
106 if let Some(remaining_tokens) = tokens_outstanding {
107 let new_tokens = prompt_token_count + seperator_tokens;
108 tokens_outstanding = if remaining_tokens > new_tokens {
109 Some(remaining_tokens - new_tokens)
110 } else {
111 Some(0)
112 };
113 }
114 }
115 }
116 }
117
118 prompts.retain(|x| x != "");
119
120 let full_prompt = prompts.join(seperator);
121 let total_token_count = self.args.model.count_tokens(&full_prompt)?;
122 anyhow::Ok((prompts.join(seperator), total_token_count))
123 }
124}
125
#[cfg(test)]
pub(crate) mod tests {
    use super::*;

    #[test]
    pub fn test_prompt_chain() {
        // Template that always emits the same fixed string, truncating it to
        // fit `max_token_length` when a budget is provided.
        struct TestPromptTemplate {}
        impl PromptTemplate for TestPromptTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(&content, max_token_length)?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

        // Identical behavior to `TestPromptTemplate` except for the emitted
        // string, so the output reveals which template produced which part.
        struct TestLowPriorityTemplate {}
        impl PromptTemplate for TestLowPriorityTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a low priority test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(&content, max_token_length)?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

        // Toy model where every character counts as exactly one token, so the
        // expected token arithmetic below is easy to check by eye.
        #[derive(Clone)]
        struct DummyLanguageModel {
            capacity: usize,
        }

        impl LanguageModel for DummyLanguageModel {
            fn name(&self) -> String {
                "dummy".to_string()
            }
            // One token per character.
            fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
                anyhow::Ok(content.chars().collect::<Vec<char>>().len())
            }
            // Keeps the first `length` characters.
            fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
                anyhow::Ok(
                    content.chars().collect::<Vec<char>>()[..length]
                        .into_iter()
                        .collect::<String>(),
                )
            }
            // Drops the first `length` characters.
            fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
                anyhow::Ok(
                    content.chars().collect::<Vec<char>>()[length..]
                        .into_iter()
                        .collect::<String>(),
                )
            }
            fn capacity(&self) -> anyhow::Result<usize> {
                anyhow::Ok(self.capacity)
            }
        }

        // Ample capacity, truncation off: both prompts appear in full, joined
        // by a newline, in template order.
        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with truncation off
        // Should ignore capacity (20 < prompt length) and return all prompts
        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with truncation on
        // The highest-priority template (order 0) should claim the whole
        // 20-token budget and be cut to 20 chars; the rest produce nothing.
        let capacity = 20;
        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 2 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(prompt, "This is a test promp".to_string());
        assert_eq!(token_count, capacity);

        // Change ordering of prompts based on priority:
        // the Mandatory template is funded first (appearing in full), then
        // Ordered { order: 0 }, and Ordered { order: 1 } gets the remainder.
        let capacity = 120;
        let reserved_tokens = 10;
        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };
        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Mandatory,
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(
            prompt,
            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
                .to_string()
        );
        assert_eq!(token_count, capacity - reserved_tokens);
    }
}