external_source_prompt.rs

  1#[derive(Clone, Debug, PartialEq, Eq)]
  2pub struct ExternalSourcePrompt(String);
  3
  4impl ExternalSourcePrompt {
  5    pub fn new(prompt: &str) -> Option<Self> {
  6        sanitize(prompt).map(Self)
  7    }
  8
  9    pub fn as_str(&self) -> &str {
 10        &self.0
 11    }
 12
 13    pub fn into_string(self) -> String {
 14        self.0
 15    }
 16}
 17
 18fn sanitize(prompt: &str) -> Option<String> {
 19    let mut sanitized_prompt = String::with_capacity(prompt.len());
 20    let mut consecutive_newline_count = 0;
 21    let mut characters = prompt.chars().peekable();
 22
 23    while let Some(character) = characters.next() {
 24        let character = if character == '\r' {
 25            if characters.peek() == Some(&'\n') {
 26                characters.next();
 27            }
 28            '\n'
 29        } else {
 30            character
 31        };
 32
 33        if is_bidi_control_character(character) || is_disallowed_control_character(character) {
 34            continue;
 35        }
 36
 37        if character == '\n' {
 38            consecutive_newline_count += 1;
 39            if consecutive_newline_count > 2 {
 40                continue;
 41            }
 42        } else {
 43            consecutive_newline_count = 0;
 44        }
 45
 46        sanitized_prompt.push(character);
 47    }
 48
 49    if sanitized_prompt.is_empty() {
 50        None
 51    } else {
 52        Some(sanitized_prompt)
 53    }
 54}
 55
 56fn is_disallowed_control_character(character: char) -> bool {
 57    character.is_control() && !matches!(character, '\n' | '\t')
 58}
 59
 60fn is_bidi_control_character(character: char) -> bool {
 61    matches!(
 62        character,
 63          '\u{061C}' // ALM
 64        | '\u{200E}' // LRM
 65        | '\u{200F}' // RLM
 66        | '\u{202A}'..='\u{202E}' // LRE, RLE, PDF, LRO, RLO
 67        | '\u{2066}'..='\u{2069}' // LRI, RLI, FSI, PDI
 68    )
 69}
 70
 71#[cfg(test)]
 72mod tests {
 73    use super::ExternalSourcePrompt;
 74
 75    #[test]
 76    fn keeps_normal_prompt_text() {
 77        let prompt = ExternalSourcePrompt::new("Write me a script\nThanks");
 78
 79        assert_eq!(
 80            prompt.as_ref().map(ExternalSourcePrompt::as_str),
 81            Some("Write me a script\nThanks")
 82        );
 83    }
 84
 85    #[test]
 86    fn keeps_multilingual_text() {
 87        let prompt =
 88            ExternalSourcePrompt::new("日本語の依頼です。\n中文提示也应该保留。\nemoji 👩‍💻");
 89
 90        assert_eq!(
 91            prompt.as_ref().map(ExternalSourcePrompt::as_str),
 92            Some("日本語の依頼です。\n中文提示也应该保留。\nemoji 👩‍💻")
 93        );
 94    }
 95
 96    #[test]
 97    fn collapses_newline_padding() {
 98        let prompt = ExternalSourcePrompt::new(
 99            "Review this prompt carefully.\n\nThis paragraph should stay separated.\n\n\n\n\n\n\nWrite me a script to do fizz buzz.",
100        );
101
102        assert_eq!(
103            prompt.as_ref().map(ExternalSourcePrompt::as_str),
104            Some(
105                "Review this prompt carefully.\n\nThis paragraph should stay separated.\n\nWrite me a script to do fizz buzz."
106            )
107        );
108    }
109
110    #[test]
111    fn normalizes_carriage_returns() {
112        let prompt = ExternalSourcePrompt::new("Line one\r\nLine two\rLine three");
113
114        assert_eq!(
115            prompt.as_ref().map(ExternalSourcePrompt::as_str),
116            Some("Line one\nLine two\nLine three")
117        );
118    }
119
120    #[test]
121    fn strips_bidi_control_characters() {
122        let prompt = ExternalSourcePrompt::new("abc\u{202E}def\u{202C}ghi");
123
124        assert_eq!(
125            prompt.as_ref().map(ExternalSourcePrompt::as_str),
126            Some("abcdefghi")
127        );
128    }
129
130    #[test]
131    fn strips_other_control_characters() {
132        let prompt = ExternalSourcePrompt::new("safe\u{0000}\u{001B}\u{007F}text");
133
134        assert_eq!(
135            prompt.as_ref().map(ExternalSourcePrompt::as_str),
136            Some("safetext")
137        );
138    }
139
140    #[test]
141    fn keeps_tabs() {
142        let prompt = ExternalSourcePrompt::new("keep\tindentation");
143
144        assert_eq!(
145            prompt.as_ref().map(ExternalSourcePrompt::as_str),
146            Some("keep\tindentation")
147        );
148    }
149
150    #[test]
151    fn drops_empty_prompt() {
152        assert_eq!(ExternalSourcePrompt::new(""), None);
153    }
154
155    #[test]
156    fn drops_prompt_with_only_removed_characters() {
157        assert_eq!(
158            ExternalSourcePrompt::new("\u{202E}\u{202C}\u{0000}\u{001B}"),
159            None
160        );
161    }
162}