1use std::cmp::Reverse;
2use std::ops::Range;
3use std::sync::Arc;
4
5use language::BufferSnapshot;
6use util::ResultExt;
7
8use crate::models::LanguageModel;
9use crate::prompts::repository_context::PromptCodeSnippet;
10
/// Broad classification of the buffer a prompt targets: prose-like files
/// ("Markdown", "Plain Text") vs. everything else, treated as code.
pub(crate) enum PromptFileType {
    Text,
    Code,
}
15
// TODO: Provide sensible defaults for these fields (e.g. via a builder or a `Default` impl).
/// Everything a [`PromptTemplate`] needs in order to render itself.
pub struct PromptArguments {
    /// Model used for token counting, capacity lookup, and truncation.
    pub model: Arc<dyn LanguageModel>,
    /// Free-form prompt text supplied by the user, if any.
    pub user_prompt: Option<String>,
    /// Display name of the buffer's language (e.g. "Markdown");
    /// drives the prose-vs-code decision in `get_file_type`.
    pub language_name: Option<String>,
    /// Name of the current project, if known.
    pub project_name: Option<String>,
    /// Repository code snippets available for inclusion in the prompt.
    pub snippets: Vec<PromptCodeSnippet>,
    /// Token count held back from the model's capacity when computing the
    /// truncation budget in `PromptChain::generate`.
    pub reserved_tokens: usize,
    /// Snapshot of the buffer the prompt concerns, if any.
    pub buffer: Option<BufferSnapshot>,
    /// Selected range within `buffer`, if any.
    /// NOTE(review): presumably byte offsets — confirm against callers.
    pub selected_range: Option<Range<usize>>,
}
27
28impl PromptArguments {
29 pub(crate) fn get_file_type(&self) -> PromptFileType {
30 if self
31 .language_name
32 .as_ref()
33 .map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
34 .unwrap_or(true)
35 {
36 PromptFileType::Code
37 } else {
38 PromptFileType::Text
39 }
40 }
41}
42
/// A single section of a prompt that can render itself against
/// [`PromptArguments`].
pub trait PromptTemplate {
    /// Produces this template's text along with its token count.
    ///
    /// When `max_token_length` is `Some`, the implementation is expected to
    /// fit its output within that many tokens (implementations in this file
    /// truncate from the end).
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)>;
}
50
/// How strongly a template lays claim to the token budget.
///
/// Ordering: `Mandatory` outranks every `Ordered` priority, and among
/// `Ordered` priorities a *smaller* `order` value ranks higher.
#[repr(i8)]
#[derive(PartialEq, Eq)]
pub enum PromptPriority {
    /// Ignores truncation.
    Mandatory,
    /// Truncates based on priority.
    Ordered { order: usize },
}

impl PartialOrd for PromptPriority {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PromptPriority {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Map each priority to a sort key: the bool puts `Mandatory` above
        // all `Ordered` variants, and `Reverse` makes a smaller `order`
        // compare greater.
        fn rank(priority: &PromptPriority) -> (bool, Reverse<usize>) {
            match priority {
                PromptPriority::Mandatory => (true, Reverse(0)),
                PromptPriority::Ordered { order } => (false, Reverse(*order)),
            }
        }
        rank(self).cmp(&rank(other))
    }
}
76
/// An ordered collection of prioritized templates plus the arguments used
/// to render them into a single prompt string.
pub struct PromptChain {
    args: PromptArguments,
    // Kept in insertion order; priorities only affect the order in which
    // templates are *generated* (and thus claim tokens), not output order.
    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}
81
82impl PromptChain {
83 pub fn new(
84 args: PromptArguments,
85 templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
86 ) -> Self {
87 PromptChain { args, templates }
88 }
89
90 pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
91 // Argsort based on Prompt Priority
92 let separator = "\n";
93 let separator_tokens = self.args.model.count_tokens(separator)?;
94 let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
95 sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
96
97 let mut tokens_outstanding = if truncate {
98 Some(self.args.model.capacity()? - self.args.reserved_tokens)
99 } else {
100 None
101 };
102
103 let mut prompts = vec!["".to_string(); sorted_indices.len()];
104 for idx in sorted_indices {
105 let (_, template) = &self.templates[idx];
106
107 if let Some((template_prompt, prompt_token_count)) =
108 template.generate(&self.args, tokens_outstanding).log_err()
109 {
110 if template_prompt != "" {
111 prompts[idx] = template_prompt;
112
113 if let Some(remaining_tokens) = tokens_outstanding {
114 let new_tokens = prompt_token_count + separator_tokens;
115 tokens_outstanding = if remaining_tokens > new_tokens {
116 Some(remaining_tokens - new_tokens)
117 } else {
118 Some(0)
119 };
120 }
121 }
122 }
123 }
124
125 prompts.retain(|x| x != "");
126
127 let full_prompt = prompts.join(separator);
128 let total_token_count = self.args.model.count_tokens(&full_prompt)?;
129 anyhow::Ok((prompts.join(separator), total_token_count))
130 }
131}
132
#[cfg(test)]
pub(crate) mod tests {
    use crate::models::TruncationDirection;
    use crate::test::FakeLanguageModel;

    use super::*;

    const HIGH_PRIORITY_CONTENT: &str = "This is a test prompt template";
    const LOW_PRIORITY_CONTENT: &str = "This is a low priority test prompt template";

    /// Test template that renders a fixed string, truncating from the end
    /// when it exceeds the provided token budget. Replaces the two
    /// copy-pasted template structs that differed only in their literal.
    struct StaticTemplate {
        content: &'static str,
    }

    impl PromptTemplate for StaticTemplate {
        fn generate(
            &self,
            args: &PromptArguments,
            max_token_length: Option<usize>,
        ) -> anyhow::Result<(String, usize)> {
            let mut content = self.content.to_string();

            let mut token_count = args.model.count_tokens(&content)?;
            if let Some(max_token_length) = max_token_length {
                if token_count > max_token_length {
                    content = args.model.truncate(
                        &content,
                        max_token_length,
                        TruncationDirection::End,
                    )?;
                    token_count = max_token_length;
                }
            }

            anyhow::Ok((content, token_count))
        }
    }

    /// Builds `PromptArguments` with everything empty except the model and
    /// the reserved-token count.
    fn test_args(model: Arc<dyn LanguageModel>, reserved_tokens: usize) -> PromptArguments {
        PromptArguments {
            model,
            user_prompt: None,
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens,
            buffer: None,
            selected_range: None,
        }
    }

    fn high_priority() -> Box<dyn PromptTemplate> {
        Box::new(StaticTemplate {
            content: HIGH_PRIORITY_CONTENT,
        })
    }

    fn low_priority() -> Box<dyn PromptTemplate> {
        Box::new(StaticTemplate {
            content: LOW_PRIORITY_CONTENT,
        })
    }

    #[test]
    pub fn test_prompt_chain() {
        // Truncation off, ample capacity: all prompts returned in order.
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
        let chain = PromptChain::new(
            test_args(model.clone(), 0),
            vec![
                (PromptPriority::Ordered { order: 0 }, high_priority()),
                (PromptPriority::Ordered { order: 1 }, low_priority()),
            ],
        );

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );
        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Truncation off, tiny capacity: capacity is ignored and all
        // prompts still come back whole.
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
        let chain = PromptChain::new(
            test_args(model.clone(), 0),
            vec![
                (PromptPriority::Ordered { order: 0 }, high_priority()),
                (PromptPriority::Ordered { order: 1 }, low_priority()),
            ],
        );

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );
        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Truncation on: only the highest-priority template fits, truncated
        // to exactly the model capacity. (The original comment here said
        // "Truncation Off", which was wrong.)
        let capacity = 20;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let chain = PromptChain::new(
            test_args(model.clone(), 0),
            vec![
                (PromptPriority::Ordered { order: 0 }, high_priority()),
                (PromptPriority::Ordered { order: 1 }, low_priority()),
                (PromptPriority::Ordered { order: 2 }, low_priority()),
            ],
        );

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(prompt, "This is a test promp".to_string());
        assert_eq!(token_count, capacity);

        // Mandatory templates claim budget first, but the output preserves
        // the original template order; reserved tokens shrink the budget.
        let capacity = 120;
        let reserved_tokens = 10;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let chain = PromptChain::new(
            test_args(model.clone(), reserved_tokens),
            vec![
                (PromptPriority::Mandatory, low_priority()),
                (PromptPriority::Ordered { order: 0 }, high_priority()),
                (PromptPriority::Ordered { order: 1 }, low_priority()),
            ],
        );

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(
            prompt,
            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
                .to_string()
        );
        assert_eq!(token_count, capacity - reserved_tokens);
    }
}