base.rs

use std::cmp::Reverse;
use std::ops::Range;
use std::sync::Arc;

use language::BufferSnapshot;
use util::ResultExt;

use crate::models::LanguageModel;
use crate::prompts::repository_context::PromptCodeSnippet;

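/// Whether a prompt treats the buffer contents as prose or as source code.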
pub(crate) enum PromptFileType {
    Text,
    Code,
}

// TODO: Set this up to manage defaults well
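/// Everything a template needs to render itself: the target model, optional
/// user prompt and buffer context, and the number of tokens to hold in reserve.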
pub struct PromptArguments {
    pub model: Arc<dyn LanguageModel>,
    pub user_prompt: Option<String>,
    pub language_name: Option<String>,
    pub project_name: Option<String>,
    pub snippets: Vec<PromptCodeSnippet>,
    pub reserved_tokens: usize,
    pub buffer: Option<BufferSnapshot>,
    pub selected_range: Option<Range<usize>>,
}

impl PromptArguments {
    pub(crate) fn get_file_type(&self) -> PromptFileType {
        // Markdown and plain text buffers are prose; any other language, or
        // no language at all, is treated as code.
        if self
            .language_name
            .as_ref()
            .map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
            .unwrap_or(true)
        {
            PromptFileType::Code
        } else {
            PromptFileType::Text
        }
    }
}

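/// One section of a prompt. `generate` returns the rendered text along with
/// its token count; when `max_token_length` is given, implementations are
/// expected to truncate their output to fit within it.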
pub trait PromptTemplate {
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)>;
}

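/// How a template competes for the shared token budget during truncation.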
#[repr(i8)]
#[derive(PartialEq, Eq)]
pub enum PromptPriority {
    /// Always included; exempt from truncation.
    Mandatory,
    /// Subject to truncation; lower `order` values are truncated last.
    Ordered { order: usize },
}

impl PartialOrd for PromptPriority {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PromptPriority {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match (self, other) {
            (Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
            (Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
            (Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
            // Reversed so that lower `order` values compare as higher priority.
            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
        }
    }
}

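/// A set of prioritized templates that render into a single prompt,
/// optionally truncated to fit the model's context window.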
pub struct PromptChain {
    args: PromptArguments,
    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}

impl PromptChain {
    pub fn new(
        args: PromptArguments,
        templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
    ) -> Self {
        PromptChain { args, templates }
    }

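    /// Renders each template in priority order so higher-priority templates
    /// claim tokens first, then joins the non-empty results with newlines in
    /// their original template order. When `truncate` is true, templates
    /// share a budget of `model.capacity() - reserved_tokens` tokens;
    /// otherwise no limit is applied.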
    pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
        // Argsort the template indices from highest to lowest priority.
        let separator = "\n";
        let separator_tokens = self.args.model.count_tokens(separator)?;
        let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
        sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));

        let mut tokens_outstanding = if truncate {
            Some(
                self.args
                    .model
                    .capacity()?
                    .saturating_sub(self.args.reserved_tokens),
            )
        } else {
            None
        };

        // Render in priority order so higher-priority templates claim tokens
        // first, but store each result at its original index so the final
        // prompt preserves the template order.
        let mut prompts = vec!["".to_string(); self.templates.len()];
        for idx in sorted_indices {
            let (_, template) = &self.templates[idx];

            if let Some((template_prompt, prompt_token_count)) =
                template.generate(&self.args, tokens_outstanding).log_err()
            {
                if !template_prompt.is_empty() {
                    prompts[idx] = template_prompt;

                    if let Some(remaining_tokens) = tokens_outstanding {
                        let used_tokens = prompt_token_count + separator_tokens;
                        tokens_outstanding = Some(remaining_tokens.saturating_sub(used_tokens));
                    }
                }
            }
        }

        prompts.retain(|prompt| !prompt.is_empty());

        let full_prompt = prompts.join(separator);
        let total_token_count = self.args.model.count_tokens(&full_prompt)?;
        anyhow::Ok((full_prompt, total_token_count))
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use crate::models::TruncationDirection;
    use crate::test::FakeLanguageModel;

    use super::*;

    #[test]
    pub fn test_prompt_chain() {
        struct TestPromptTemplate {}
        impl PromptTemplate for TestPromptTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

        struct TestLowPriorityTemplate {}
        impl PromptTemplate for TestLowPriorityTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a low priority test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with truncation off:
        // should ignore capacity and return all prompts.
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with truncation on:
        // should truncate the prompts to fit within the model's capacity.
        let capacity = 20;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 2 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(prompt, "This is a test promp".to_string());
        assert_eq!(token_count, capacity);

        // Mandatory prompts are allocated tokens before ordered prompts,
        // regardless of where they appear in the template list.
        let capacity = 120;
        let reserved_tokens = 10;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };
        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Mandatory,
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(
            prompt,
            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
                .to_string()
        );
        assert_eq!(token_count, capacity - reserved_tokens);
    }
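
    // A minimal sanity check of the `PromptPriority` ordering that
    // `PromptChain::generate` relies on: `Mandatory` outranks any `Ordered`
    // priority, and lower `order` values outrank higher ones.
    #[test]
    pub fn test_prompt_priority_ordering() {
        assert!(PromptPriority::Mandatory > PromptPriority::Ordered { order: 0 });
        assert!(PromptPriority::Ordered { order: 0 } > PromptPriority::Ordered { order: 1 });
        // `PromptPriority` doesn't derive `Debug`, so plain `assert!` is used
        // instead of `assert_eq!`.
        assert!(PromptPriority::Ordered { order: 7 } == PromptPriority::Ordered { order: 7 });
    }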
}