base.rs

use std::cmp::Reverse;
use std::ops::Range;
use std::sync::Arc;

use language::BufferSnapshot;
use util::ResultExt;

use crate::models::LanguageModel;
use crate::prompts::repository_context::PromptCodeSnippet;

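/// Coarse classification of the active buffer, used to decide whether prompt
/// content should be treated as prose or as code.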
pub(crate) enum PromptFileType {
    Text,
    Code,
}

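/// Shared inputs for every [`PromptTemplate`] in a [`PromptChain`]: the target
/// model, optional user and project context, and a count of tokens to reserve
/// out of the model's capacity.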
// TODO: Set this up to manage defaults well
pub struct PromptArguments {
    pub model: Arc<dyn LanguageModel>,
    pub user_prompt: Option<String>,
    pub language_name: Option<String>,
    pub project_name: Option<String>,
    pub snippets: Vec<PromptCodeSnippet>,
    pub reserved_tokens: usize,
    pub buffer: Option<BufferSnapshot>,
    pub selected_range: Option<Range<usize>>,
}

impl PromptArguments {
    pub(crate) fn get_file_type(&self) -> PromptFileType {
        if self
            .language_name
            .as_ref()
            .map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
            .unwrap_or(true)
        {
            PromptFileType::Code
        } else {
            PromptFileType::Text
        }
    }
}

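/// A composable section of a prompt. Implementations render their content for
/// the given [`PromptArguments`], truncating to `max_token_length` tokens when
/// a limit is supplied, and return the rendered text with its token count.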
pub trait PromptTemplate {
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)>;
}

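/// Priority of a template within a [`PromptChain`]. `Mandatory` templates are
/// given the token budget first; `Ordered` templates share whatever remains,
/// with lower `order` values served first.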
#[repr(i8)]
#[derive(PartialEq, Eq)]
pub enum PromptPriority {
    Mandatory,                // Ignores truncation
    Ordered { order: usize }, // Truncates based on priority
}

impl Ord for PromptPriority {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match (self, other) {
            (Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
            (Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
            (Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
            // Lower `order` values rank higher, so they are rendered first.
            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
        }
    }
}

impl PartialOrd for PromptPriority {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

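/// A prioritized list of [`PromptTemplate`]s that renders into a single prompt
/// string sized to the model's context window.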
pub struct PromptChain {
    args: PromptArguments,
    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}

impl PromptChain {
    pub fn new(
        args: PromptArguments,
        templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
    ) -> Self {
        PromptChain { args, templates }
    }

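    /// Renders every template and joins the non-empty results with newlines.
    /// When `truncate` is true, templates are rendered in priority order
    /// against the model's capacity minus `reserved_tokens`, so lower-priority
    /// templates may come back truncated or empty. The output preserves the
    /// original template order; the returned count covers the full prompt.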
    pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
        // Argsort the template indices by descending priority
        let separator = "\n";
        let separator_tokens = self.args.model.count_tokens(separator)?;
        let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
        sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));

        // Token budget left for templates when truncation is enabled
        let mut tokens_outstanding = if truncate {
            Some(
                self.args
                    .model
                    .capacity()?
                    .saturating_sub(self.args.reserved_tokens),
            )
        } else {
            None
        };

        // Render each template in priority order, charging its tokens (plus a
        // separator) against the remaining budget.
        let mut prompts = vec!["".to_string(); sorted_indices.len()];
        for idx in sorted_indices {
            let (_, template) = &self.templates[idx];

            if let Some((template_prompt, prompt_token_count)) =
                template.generate(&self.args, tokens_outstanding).log_err()
            {
                if !template_prompt.is_empty() {
                    prompts[idx] = template_prompt;

                    if let Some(remaining_tokens) = tokens_outstanding {
                        let new_tokens = prompt_token_count + separator_tokens;
                        tokens_outstanding = Some(remaining_tokens.saturating_sub(new_tokens));
                    }
                }
            }
        }

        prompts.retain(|prompt| !prompt.is_empty());

        let full_prompt = prompts.join(separator);
        let total_token_count = self.args.model.count_tokens(&full_prompt)?;
        anyhow::Ok((full_prompt, total_token_count))
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use crate::models::TruncationDirection;
    use crate::test::FakeLanguageModel;

    use super::*;

    #[test]
    pub fn test_prompt_chain() {
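        // Test templates with fixed content that truncate themselves to fit
        // whatever token budget they are given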
        struct TestPromptTemplate {}
        impl PromptTemplate for TestPromptTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

        struct TestLowPriorityTemplate {}
        impl PromptTemplate for TestLowPriorityTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a low priority test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

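        // Testing with Truncation Off
        // Should return both prompts in full, joined by the separator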
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with Truncation Off
        // Should ignore capacity and return all prompts
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with Truncation On
        // Should truncate prompts to fit within capacity
        let capacity = 20;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 2 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(prompt, "This is a test promp".to_string());
        assert_eq!(token_count, capacity);

        // Mandatory prompts receive the token budget first; remaining ordered
        // prompts are truncated to fit
        let capacity = 120;
        let reserved_tokens = 10;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };
        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Mandatory,
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(
            prompt,
            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
                .to_string()
        );
        assert_eq!(token_count, capacity - reserved_tokens);
    }
}