// base.rs

  1use std::cmp::Reverse;
  2use std::ops::Range;
  3use std::sync::Arc;
  4
  5use language::BufferSnapshot;
  6use util::ResultExt;
  7
  8use crate::models::LanguageModel;
  9use crate::prompts::repository_context::PromptCodeSnippet;
 10
/// Coarse classification of the buffer a prompt targets, used to pick
/// code-oriented vs prose-oriented prompt wording.
pub(crate) enum PromptFileType {
    /// Prose-like buffers ("Markdown", "Plain Text").
    Text,
    /// Source-code buffers; also the fallback when no language name is known.
    Code,
}
 15
 16// TODO: Set this up to manage for defaults well
// TODO: Set this up to manage for defaults well
/// Bundle of inputs handed to every [`PromptTemplate`] when rendering a prompt.
pub struct PromptArguments {
    /// Model used for token counting and capacity checks while assembling the prompt.
    pub model: Arc<dyn LanguageModel>,
    /// Free-form instruction supplied by the user, if any.
    pub user_prompt: Option<String>,
    /// Human-readable language name of the active buffer (e.g. "Markdown").
    pub language_name: Option<String>,
    /// Name of the project the buffer belongs to, if known.
    pub project_name: Option<String>,
    /// Repository code snippets available as extra context for templates.
    pub snippets: Vec<PromptCodeSnippet>,
    /// Token budget held back from the model's capacity (e.g. for the reply).
    pub reserved_tokens: usize,
    /// Snapshot of the buffer the prompt concerns, if any.
    pub buffer: Option<BufferSnapshot>,
    /// Selected range within `buffer`, if any (offsets — units per caller; TODO confirm).
    pub selected_range: Option<Range<usize>>,
}
 27
 28impl PromptArguments {
 29    pub(crate) fn get_file_type(&self) -> PromptFileType {
 30        if self
 31            .language_name
 32            .as_ref()
 33            .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
 34            .unwrap_or(true)
 35        {
 36            PromptFileType::Code
 37        } else {
 38            PromptFileType::Text
 39        }
 40    }
 41}
 42
/// One composable section of a prompt.
pub trait PromptTemplate {
    /// Renders this template using `args`.
    ///
    /// When `max_token_length` is `Some`, the implementation is expected to
    /// truncate its output to fit that many tokens (as the test templates in
    /// this file do). Returns the rendered text together with its token count
    /// as measured by `args.model`.
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)>;
}
 50
 51#[repr(i8)]
 52#[derive(PartialEq, Eq, Ord)]
 53pub enum PromptPriority {
 54    Mandatory,                // Ignores truncation
 55    Ordered { order: usize }, // Truncates based on priority
 56}
 57
 58impl PartialOrd for PromptPriority {
 59    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
 60        match (self, other) {
 61            (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
 62            (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
 63            (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
 64            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
 65        }
 66    }
 67}
 68
/// Composes a complete prompt from a prioritized list of [`PromptTemplate`]s.
pub struct PromptChain {
    // Shared arguments handed to every template.
    args: PromptArguments,
    // Templates paired with the priority that governs truncation order.
    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}
 73
 74impl PromptChain {
 75    pub fn new(
 76        args: PromptArguments,
 77        templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
 78    ) -> Self {
 79        PromptChain { args, templates }
 80    }
 81
 82    pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
 83        // Argsort based on Prompt Priority
 84        let seperator = "\n";
 85        let seperator_tokens = self.args.model.count_tokens(seperator)?;
 86        let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
 87        sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
 88
 89        // If Truncate
 90        let mut tokens_outstanding = if truncate {
 91            Some(self.args.model.capacity()? - self.args.reserved_tokens)
 92        } else {
 93            None
 94        };
 95
 96        let mut prompts = vec!["".to_string(); sorted_indices.len()];
 97        for idx in sorted_indices {
 98            let (_, template) = &self.templates[idx];
 99
100            if let Some((template_prompt, prompt_token_count)) =
101                template.generate(&self.args, tokens_outstanding).log_err()
102            {
103                if template_prompt != "" {
104                    prompts[idx] = template_prompt;
105
106                    if let Some(remaining_tokens) = tokens_outstanding {
107                        let new_tokens = prompt_token_count + seperator_tokens;
108                        tokens_outstanding = if remaining_tokens > new_tokens {
109                            Some(remaining_tokens - new_tokens)
110                        } else {
111                            Some(0)
112                        };
113                    }
114                }
115            }
116        }
117
118        prompts.retain(|x| x != "");
119
120        let full_prompt = prompts.join(seperator);
121        let total_token_count = self.args.model.count_tokens(&full_prompt)?;
122        anyhow::Ok((prompts.join(seperator), total_token_count))
123    }
124}
125
#[cfg(test)]
pub(crate) mod tests {
    use crate::models::TruncationDirection;

    use super::*;

    #[test]
    pub fn test_prompt_chain() {
        // Fixed-content template used as the "high priority" section.
        struct TestPromptTemplate {}
        impl PromptTemplate for TestPromptTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        // Truncate the *end*, keeping the leading
                        // `max_token_length` tokens (matches the dummy model's
                        // `End` arm below and the assertions further down;
                        // `Start` would keep the tail instead).
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

        // Longer template used as the "low priority" section.
        struct TestLowPriorityTemplate {}
        impl PromptTemplate for TestLowPriorityTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a low priority test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        // Same as above: cut the end, keep the head.
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

        // Model stub: one token per char, fixed capacity.
        #[derive(Clone)]
        struct DummyLanguageModel {
            capacity: usize,
        }

        impl LanguageModel for DummyLanguageModel {
            fn name(&self) -> String {
                "dummy".to_string()
            }
            fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
                anyhow::Ok(content.chars().collect::<Vec<char>>().len())
            }
            fn truncate(
                &self,
                content: &str,
                length: usize,
                direction: TruncationDirection,
            ) -> anyhow::Result<String> {
                anyhow::Ok(match direction {
                    // End: drop the tail, keep the first `length` chars.
                    TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
                        .into_iter()
                        .collect::<String>(),
                    // Start: drop the first `length` chars, keep the tail.
                    TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
                        .into_iter()
                        .collect::<String>(),
                })
            }
            fn capacity(&self) -> anyhow::Result<usize> {
                anyhow::Ok(self.capacity)
            }
        }

        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with Truncation Off
        // Should ignore capacity and return all prompts
        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with Truncation On
        // Should truncate the output to fit within the model's capacity
        // (the comment here previously said "Truncation Off", contradicting
        // the `generate(true)` call below).
        let capacity = 20;
        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 2 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(prompt, "This is a test promp".to_string());
        assert_eq!(token_count, capacity);

        // Change Ordering of Prompts Based on Priority:
        // the Mandatory template is listed first in the output *and* gets the
        // token budget first, even though it renders the low-priority text.
        let capacity = 120;
        let reserved_tokens = 10;
        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };
        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Mandatory,
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(
            prompt,
            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
                .to_string()
        );
        assert_eq!(token_count, capacity - reserved_tokens);
    }
}