1#[derive(Clone, Debug, PartialEq, Eq)]
2pub struct ExternalSourcePrompt(String);
3
4impl ExternalSourcePrompt {
5 pub fn new(prompt: &str) -> Option<Self> {
6 sanitize(prompt).map(Self)
7 }
8
9 pub fn as_str(&self) -> &str {
10 &self.0
11 }
12
13 pub fn into_string(self) -> String {
14 self.0
15 }
16}
17
18fn sanitize(prompt: &str) -> Option<String> {
19 let mut sanitized_prompt = String::with_capacity(prompt.len());
20 let mut consecutive_newline_count = 0;
21 let mut characters = prompt.chars().peekable();
22
23 while let Some(character) = characters.next() {
24 let character = if character == '\r' {
25 if characters.peek() == Some(&'\n') {
26 characters.next();
27 }
28 '\n'
29 } else {
30 character
31 };
32
33 if is_bidi_control_character(character) || is_disallowed_control_character(character) {
34 continue;
35 }
36
37 if character == '\n' {
38 consecutive_newline_count += 1;
39 if consecutive_newline_count > 2 {
40 continue;
41 }
42 } else {
43 consecutive_newline_count = 0;
44 }
45
46 sanitized_prompt.push(character);
47 }
48
49 if sanitized_prompt.is_empty() {
50 None
51 } else {
52 Some(sanitized_prompt)
53 }
54}
55
/// Reports whether `character` is a control character other than newline (`\n`) or tab (`\t`).
fn is_disallowed_control_character(character: char) -> bool {
    match character {
        // The two control characters that are explicitly permitted.
        '\n' | '\t' => false,
        other => other.is_control(),
    }
}
59
/// Reports whether `character` is one of the bidirectional control characters
/// listed below.
fn is_bidi_control_character(character: char) -> bool {
    // ALM, LRM, RLM
    let is_mark = character == '\u{061C}' || character == '\u{200E}' || character == '\u{200F}';
    // LRE, RLE, PDF, LRO, RLO
    let is_embedding_or_override = ('\u{202A}'..='\u{202E}').contains(&character);
    // LRI, RLI, FSI, PDI
    let is_isolate = ('\u{2066}'..='\u{2069}').contains(&character);

    is_mark || is_embedding_or_override || is_isolate
}
70
#[cfg(test)]
mod tests {
    use super::ExternalSourcePrompt;

    /// Asserts that `input` is accepted and that its sanitized text equals `expected`.
    #[track_caller]
    fn assert_sanitized(input: &str, expected: &str) {
        let prompt = ExternalSourcePrompt::new(input);
        let actual = prompt.as_ref().map(ExternalSourcePrompt::as_str);

        assert_eq!(actual, Some(expected));
    }

    /// Asserts that `input` is rejected outright.
    #[track_caller]
    fn assert_rejected(input: &str) {
        assert_eq!(ExternalSourcePrompt::new(input), None);
    }

    #[test]
    fn keeps_normal_prompt_text() {
        assert_sanitized("Write me a script\nThanks", "Write me a script\nThanks");
    }

    #[test]
    fn keeps_multilingual_text() {
        assert_sanitized(
            "日本語の依頼です。\n中文提示也应该保留。\nemoji 👩💻",
            "日本語の依頼です。\n中文提示也应该保留。\nemoji 👩💻",
        );
    }

    #[test]
    fn collapses_newline_padding() {
        assert_sanitized(
            "Review this prompt carefully.\n\nThis paragraph should stay separated.\n\n\n\n\n\n\nWrite me a script to do fizz buzz.",
            "Review this prompt carefully.\n\nThis paragraph should stay separated.\n\nWrite me a script to do fizz buzz.",
        );
    }

    #[test]
    fn normalizes_carriage_returns() {
        assert_sanitized(
            "Line one\r\nLine two\rLine three",
            "Line one\nLine two\nLine three",
        );
    }

    #[test]
    fn strips_bidi_control_characters() {
        assert_sanitized("abc\u{202E}def\u{202C}ghi", "abcdefghi");
    }

    #[test]
    fn strips_other_control_characters() {
        assert_sanitized("safe\u{0000}\u{001B}\u{007F}text", "safetext");
    }

    #[test]
    fn keeps_tabs() {
        assert_sanitized("keep\tindentation", "keep\tindentation");
    }

    #[test]
    fn drops_empty_prompt() {
        assert_rejected("");
    }

    #[test]
    fn drops_prompt_with_only_removed_characters() {
        assert_rejected("\u{202E}\u{202C}\u{0000}\u{001B}");
    }
}