// base.rs

  1use std::cmp::Reverse;
  2use std::ops::Range;
  3use std::sync::Arc;
  4
  5use language::BufferSnapshot;
  6use util::ResultExt;
  7
  8use crate::models::LanguageModel;
  9use crate::prompts::repository_context::PromptCodeSnippet;
 10
/// Broad classification of the buffer a prompt targets, used to choose
/// between prose-oriented and code-oriented prompt wording.
pub(crate) enum PromptFileType {
    /// Prose formats (e.g. "Markdown", "Plain Text" — see `get_file_type`).
    Text,
    /// Any other language; treated as source code.
    Code,
}
 15
// TODO: Set this up to manage for defaults well
/// Shared inputs handed to every [`PromptTemplate`] in a [`PromptChain`].
pub struct PromptArguments {
    /// Model used for token counting, truncation, and capacity queries.
    pub model: Arc<dyn LanguageModel>,
    /// Free-form user request, if any.
    pub user_prompt: Option<String>,
    /// Display name of the buffer's language (e.g. "Markdown"), if known.
    pub language_name: Option<String>,
    /// Name of the enclosing project, if known.
    pub project_name: Option<String>,
    /// Code snippets gathered from the repository for extra context.
    pub snippets: Vec<PromptCodeSnippet>,
    /// Token budget to hold back from the model's capacity (e.g. for the reply).
    pub reserved_tokens: usize,
    /// Snapshot of the active buffer, if any.
    pub buffer: Option<BufferSnapshot>,
    /// Byte range of the user's selection within `buffer`, if any.
    pub selected_range: Option<Range<usize>>,
}
 27
 28impl PromptArguments {
 29    pub(crate) fn get_file_type(&self) -> PromptFileType {
 30        if self
 31            .language_name
 32            .as_ref()
 33            .map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
 34            .unwrap_or(true)
 35        {
 36            PromptFileType::Code
 37        } else {
 38            PromptFileType::Text
 39        }
 40    }
 41}
 42
/// A single renderable section of a composite prompt.
pub trait PromptTemplate {
    /// Renders this template against `args`, truncating to fit within
    /// `max_token_length` tokens when a limit is provided.
    ///
    /// Returns the rendered text together with its token count.
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)>;
}
 50
 51#[repr(i8)]
 52#[derive(PartialEq, Eq, Ord)]
 53pub enum PromptPriority {
 54    /// Ignores truncation.
 55    Mandatory,
 56    /// Truncates based on priority.
 57    Ordered { order: usize },
 58}
 59
 60impl PartialOrd for PromptPriority {
 61    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
 62        match (self, other) {
 63            (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
 64            (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
 65            (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
 66            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
 67        }
 68    }
 69}
 70
/// An ordered collection of prompt templates plus the shared arguments
/// needed to render them into a single prompt string.
pub struct PromptChain {
    // Shared inputs passed to every template.
    args: PromptArguments,
    // Templates paired with the priority that controls truncation order.
    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}
 75
 76impl PromptChain {
 77    pub fn new(
 78        args: PromptArguments,
 79        templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
 80    ) -> Self {
 81        PromptChain { args, templates }
 82    }
 83
 84    pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
 85        // Argsort based on Prompt Priority
 86        let separator = "\n";
 87        let separator_tokens = self.args.model.count_tokens(separator)?;
 88        let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
 89        sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
 90
 91        let mut tokens_outstanding = if truncate {
 92            Some(self.args.model.capacity()? - self.args.reserved_tokens)
 93        } else {
 94            None
 95        };
 96
 97        let mut prompts = vec!["".to_string(); sorted_indices.len()];
 98        for idx in sorted_indices {
 99            let (_, template) = &self.templates[idx];
100
101            if let Some((template_prompt, prompt_token_count)) =
102                template.generate(&self.args, tokens_outstanding).log_err()
103            {
104                if template_prompt != "" {
105                    prompts[idx] = template_prompt;
106
107                    if let Some(remaining_tokens) = tokens_outstanding {
108                        let new_tokens = prompt_token_count + separator_tokens;
109                        tokens_outstanding = if remaining_tokens > new_tokens {
110                            Some(remaining_tokens - new_tokens)
111                        } else {
112                            Some(0)
113                        };
114                    }
115                }
116            }
117        }
118
119        prompts.retain(|x| x != "");
120
121        let full_prompt = prompts.join(separator);
122        let total_token_count = self.args.model.count_tokens(&full_prompt)?;
123        anyhow::Ok((prompts.join(separator), total_token_count))
124    }
125}
126
#[cfg(test)]
pub(crate) mod tests {
    use crate::models::TruncationDirection;
    use crate::test::FakeLanguageModel;

    use super::*;

    #[test]
    pub fn test_prompt_chain() {
        // Template that renders a fixed string, truncating from the end
        // when a token budget is supplied.
        struct TestPromptTemplate {}
        impl PromptTemplate for TestPromptTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

        // Same shape as above, with a longer fixed string, used as the
        // lower-priority section.
        struct TestLowPriorityTemplate {}
        impl PromptTemplate for TestLowPriorityTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a low priority test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

        // Ample capacity, no truncation requested: both templates appear in full.
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with Truncation Off
        // Should ignore capacity and return all prompts
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with Truncation On
        // Should truncate lower-priority templates to fit within capacity
        let capacity = 20;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 2 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(prompt, "This is a test promp".to_string());
        assert_eq!(token_count, capacity);

        // Change Ordering of Prompts Based on Priority
        let capacity = 120;
        let reserved_tokens = 10;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };
        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Mandatory,
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(
            prompt,
            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
                .to_string()
        );
        assert_eq!(token_count, capacity - reserved_tokens);
    }
}