1use std::cmp::Reverse;
2use std::ops::Range;
3use std::sync::Arc;
4
5use language::BufferSnapshot;
6use util::ResultExt;
7
8use crate::models::LanguageModel;
9use crate::prompts::repository_context::PromptCodeSnippet;
10
/// Broad classification of the buffer a prompt targets, used to tailor
/// prompt wording (prose vs. source code).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PromptFileType {
    /// Prose-like buffers ("Markdown", "Plain Text" — see
    /// `PromptArguments::get_file_type`).
    Text,
    /// Every other (or unknown) language is treated as source code.
    Code,
}
15
// TODO: Set this up to manage for defaults well
/// Bundle of inputs handed to every `PromptTemplate::generate` call when a
/// `PromptChain` is rendered.
pub struct PromptArguments {
    /// Model used for token counting and capacity queries while assembling
    /// the prompt.
    pub model: Arc<dyn LanguageModel>,
    /// Free-form instruction from the user, if any.
    pub user_prompt: Option<String>,
    /// Display name of the buffer's language (e.g. "Markdown"); `None` when
    /// unknown. Drives `get_file_type`.
    pub language_name: Option<String>,
    /// Name of the enclosing project, if available.
    pub project_name: Option<String>,
    /// Repository code snippets templates may want to include.
    pub snippets: Vec<PromptCodeSnippet>,
    /// Token count subtracted from the model capacity before templates are
    /// rendered (kept free for e.g. the model's reply).
    pub reserved_tokens: usize,
    // NOTE(review): buffer/selected_range are not read in this file —
    // presumably consumed by individual templates; confirm at call sites.
    pub buffer: Option<BufferSnapshot>,
    pub selected_range: Option<Range<usize>>,
}
27
28impl PromptArguments {
29 pub(crate) fn get_file_type(&self) -> PromptFileType {
30 if self
31 .language_name
32 .as_ref()
33 .map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
34 .unwrap_or(true)
35 {
36 PromptFileType::Code
37 } else {
38 PromptFileType::Text
39 }
40 }
41}
42
/// A single composable section of a prompt.
pub trait PromptTemplate {
    /// Renders this template against `args`.
    ///
    /// When `max_token_length` is `Some`, it is the remaining token budget
    /// the rendered text must fit within (the chain passes its outstanding
    /// budget here). Returns the rendered text together with its token
    /// count.
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)>;
}
50
/// Relative priority of a template within a `PromptChain`.
///
/// `Mandatory` outranks every `Ordered` entry; among `Ordered` entries a
/// *lower* `order` value means a *higher* priority.
//
// Previously this derived `Ord` (discriminant order: `Mandatory` smallest,
// `order` ascending) while hand-writing a contradictory `PartialOrd`
// (`Mandatory` greatest, `order` descending). Sorting happened to go through
// `PartialOrd`, but any direct `cmp`/`max`/`BTreeMap` use would silently
// disagree, and inconsistent `Ord`/`PartialOrd` violates the std contract.
// `Ord` is now the single source of truth and `PartialOrd` delegates to it.
#[repr(i8)]
#[derive(Debug, PartialEq, Eq)]
pub enum PromptPriority {
    /// Ignores truncation.
    Mandatory,
    /// Truncates based on priority.
    Ordered { order: usize },
}

impl Ord for PromptPriority {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match (self, other) {
            (Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
            (Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
            (Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
            // Lower `order` wins, so compare operands in reverse.
            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
        }
    }
}

impl PartialOrd for PromptPriority {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
70
/// An ordered collection of prioritized prompt templates plus the arguments
/// used to render them into a single prompt string.
pub struct PromptChain {
    // Shared inputs passed to every template's `generate`.
    args: PromptArguments,
    // Templates in output order; the priority controls rendering order
    // (and therefore who gets the token budget first).
    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}
75
76impl PromptChain {
77 pub fn new(
78 args: PromptArguments,
79 templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
80 ) -> Self {
81 PromptChain { args, templates }
82 }
83
84 pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
85 // Argsort based on Prompt Priority
86 let separator = "\n";
87 let separator_tokens = self.args.model.count_tokens(separator)?;
88 let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
89 sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
90
91 let mut tokens_outstanding = if truncate {
92 Some(self.args.model.capacity()? - self.args.reserved_tokens)
93 } else {
94 None
95 };
96
97 let mut prompts = vec!["".to_string(); sorted_indices.len()];
98 for idx in sorted_indices {
99 let (_, template) = &self.templates[idx];
100
101 if let Some((template_prompt, prompt_token_count)) =
102 template.generate(&self.args, tokens_outstanding).log_err()
103 {
104 if template_prompt != "" {
105 prompts[idx] = template_prompt;
106
107 if let Some(remaining_tokens) = tokens_outstanding {
108 let new_tokens = prompt_token_count + separator_tokens;
109 tokens_outstanding = if remaining_tokens > new_tokens {
110 Some(remaining_tokens - new_tokens)
111 } else {
112 Some(0)
113 };
114 }
115 }
116 }
117 }
118
119 prompts.retain(|x| x != "");
120
121 let full_prompt = prompts.join(separator);
122 let total_token_count = self.args.model.count_tokens(&full_prompt)?;
123 anyhow::Ok((prompts.join(separator), total_token_count))
124 }
125}
126
#[cfg(test)]
pub(crate) mod tests {
    use crate::models::TruncationDirection;
    use crate::test::FakeLanguageModel;

    use super::*;

    #[test]
    pub fn test_prompt_chain() {
        // Fixture template: emits a fixed string, truncated from the end
        // whenever it exceeds the provided token budget.
        struct TestPromptTemplate {}
        impl PromptTemplate for TestPromptTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

        // Same truncation behavior as `TestPromptTemplate`, but with longer
        // content so it loses more text under a tight budget.
        struct TestLowPriorityTemplate {}
        impl PromptTemplate for TestLowPriorityTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a low priority test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

        // Capacity comfortably fits both templates; no truncation requested.
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with Truncation Off
        // Should ignore capacity and return all prompts
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with Truncation On
        // The highest-priority template is truncated to fill the capacity,
        // leaving no budget for the lower-priority ones.
        let capacity = 20;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 2 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(prompt, "This is a test promp".to_string());
        assert_eq!(token_count, capacity);

        // Change Ordering of Prompts Based on Priority
        // `Mandatory` gets the budget first, then `Ordered` by ascending
        // `order`; the output keeps the templates' original order.
        let capacity = 120;
        let reserved_tokens = 10;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };
        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Mandatory,
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(
            prompt,
            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
                .to_string()
        );
        assert_eq!(token_count, capacity - reserved_tokens);
    }
}