1use std::cmp::Reverse;
2use std::ops::Range;
3use std::sync::Arc;
4
5use language::BufferSnapshot;
6use util::ResultExt;
7
8use crate::models::LanguageModel;
9use crate::prompts::repository_context::PromptCodeSnippet;
10
/// Broad classification of the buffer a prompt targets, used to choose
/// between prose-oriented and code-oriented prompt wording.
pub(crate) enum PromptFileType {
    /// Prose buffers ("Markdown" or "Plain Text" — see `get_file_type`).
    Text,
    /// Any other (or unknown) language; treated as source code.
    Code,
}
15
// TODO: Set this up to manage for defaults well
/// Shared inputs handed to every [`PromptTemplate`] when rendering a
/// [`PromptChain`].
pub struct PromptArguments {
    /// Model used for token counting, capacity queries, and truncation.
    pub model: Arc<dyn LanguageModel>,
    /// Free-form instruction supplied by the user, if any.
    pub user_prompt: Option<String>,
    /// Human-readable language name of the active buffer (e.g. "Markdown").
    pub language_name: Option<String>,
    /// Name of the current project, if known.
    pub project_name: Option<String>,
    /// Repository code snippets available as extra context.
    pub snippets: Vec<PromptCodeSnippet>,
    /// Token budget held back from the model's capacity when truncating.
    pub reserved_tokens: usize,
    /// Snapshot of the buffer the prompt concerns, if any.
    pub buffer: Option<BufferSnapshot>,
    /// Selection within `buffer` as an offset range
    /// (NOTE(review): byte vs. char offsets not visible here — confirm at call sites).
    pub selected_range: Option<Range<usize>>,
}
27
28impl PromptArguments {
29 pub(crate) fn get_file_type(&self) -> PromptFileType {
30 if self
31 .language_name
32 .as_ref()
33 .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
34 .unwrap_or(true)
35 {
36 PromptFileType::Code
37 } else {
38 PromptFileType::Text
39 }
40 }
41}
42
/// A single section of a prompt that renders itself against
/// [`PromptArguments`], optionally fitting within a token budget.
pub trait PromptTemplate {
    /// Render this template.
    ///
    /// Returns the rendered text together with its token count. When
    /// `max_token_length` is `Some`, implementations are expected to keep
    /// the result within that many tokens (see the test templates below for
    /// the truncation pattern).
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)>;
}
50
/// Priority of a template within a [`PromptChain`].
///
/// `Mandatory` always outranks `Ordered`, and among `Ordered` entries a
/// *smaller* `order` means a *higher* priority (it is truncated later).
///
/// BUG FIX: the previous version combined `#[derive(Ord)]` with a
/// hand-written `PartialOrd` that disagreed with it (derived `cmp` put
/// `Mandatory` first-declared and therefore *smallest*), so comparison
/// results depended on whether a caller went through `cmp` or `lt`
/// (clippy: `derive_ord_xor_partial_ord`). `Ord` is now implemented
/// manually with the intended ordering and `PartialOrd` delegates to it.
#[repr(i8)]
#[derive(PartialEq, Eq)]
pub enum PromptPriority {
    Mandatory,                // Ignores truncation
    Ordered { order: usize }, // Truncates based on priority
}

impl Ord for PromptPriority {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match (self, other) {
            (Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
            (Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
            (Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
            // Reversed on purpose: a smaller `order` compares as greater.
            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
        }
    }
}

impl PartialOrd for PromptPriority {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
68
/// An ordered collection of prioritized templates that are rendered and
/// joined into one prompt, truncating lower-priority sections first.
pub struct PromptChain {
    // Shared inputs passed to every template.
    args: PromptArguments,
    // Templates paired with their truncation priority; output keeps this
    // insertion order, priority only controls the token-budget order.
    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}
73
74impl PromptChain {
75 pub fn new(
76 args: PromptArguments,
77 templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
78 ) -> Self {
79 PromptChain { args, templates }
80 }
81
82 pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
83 // Argsort based on Prompt Priority
84 let separator = "\n";
85 let separator_tokens = self.args.model.count_tokens(separator)?;
86 let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
87 sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
88
89 // If Truncate
90 let mut tokens_outstanding = if truncate {
91 Some(self.args.model.capacity()? - self.args.reserved_tokens)
92 } else {
93 None
94 };
95
96 let mut prompts = vec!["".to_string(); sorted_indices.len()];
97 for idx in sorted_indices {
98 let (_, template) = &self.templates[idx];
99
100 if let Some((template_prompt, prompt_token_count)) =
101 template.generate(&self.args, tokens_outstanding).log_err()
102 {
103 if template_prompt != "" {
104 prompts[idx] = template_prompt;
105
106 if let Some(remaining_tokens) = tokens_outstanding {
107 let new_tokens = prompt_token_count + separator_tokens;
108 tokens_outstanding = if remaining_tokens > new_tokens {
109 Some(remaining_tokens - new_tokens)
110 } else {
111 Some(0)
112 };
113 }
114 }
115 }
116 }
117
118 prompts.retain(|x| x != "");
119
120 let full_prompt = prompts.join(separator);
121 let total_token_count = self.args.model.count_tokens(&full_prompt)?;
122 anyhow::Ok((prompts.join(separator), total_token_count))
123 }
124}
125
#[cfg(test)]
pub(crate) mod tests {
    use crate::models::TruncationDirection;
    use crate::test::FakeLanguageModel;

    use super::*;

    #[test]
    pub fn test_prompt_chain() {
        // Template that renders fixed text and truncates itself to fit
        // `max_token_length` when one is given.
        struct TestPromptTemplate {}
        impl PromptTemplate for TestPromptTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

        // Same truncation behavior with different text; used as the
        // lower-priority entry in the chains below.
        struct TestLowPriorityTemplate {}
        impl PromptTemplate for TestLowPriorityTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a low priority test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

        // Truncation off with ample capacity: both templates appear in full,
        // joined by the newline separator.
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with Truncation Off
        // Should ignore capacity and return all prompts
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with Truncation On
        // Should truncate the generated prompts to fit within `capacity`
        // (the previous comment wrongly said truncation was off here).
        let capacity = 20;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 2 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(prompt, "This is a test promp".to_string());
        assert_eq!(token_count, capacity);

        // Change Ordering of Prompts Based on Priority
        // Mandatory gets the budget first, then Ordered { order: 0 }, and the
        // last template is truncated to the remaining tokens. Output keeps
        // the original insertion order.
        let capacity = 120;
        let reserved_tokens = 10;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };
        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Mandatory,
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(
            prompt,
            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
                .to_string()
        );
        assert_eq!(token_count, capacity - reserved_tokens);
    }
}