/// Opening tag that marks the start of a Lua script inside an assistant message.
pub const SCRIPT_START_TAG: &str = r#"<eval type="lua">"#;
/// Closing tag that marks the end of the script.
pub const SCRIPT_END_TAG: &str = "</eval>";

// Byte views of the tags; the parser matches its input one byte at a time.
const START_TAG: &[u8] = SCRIPT_START_TAG.as_bytes();
const END_TAG: &[u8] = SCRIPT_END_TAG.as_bytes();
6
7/// Parses a script tag in an assistant message as it is being streamed.
8pub struct ScriptTagParser {
9 state: State,
10 buffer: Vec<u8>,
11 tag_match_ix: usize,
12}
13
/// Progress of a [`ScriptTagParser`] through the streamed message.
enum State {
    /// The start tag has not been seen yet; input is surrounding content.
    Unstarted,
    /// Inside the script; input is script source until the end tag appears.
    Streaming,
    /// The end tag has been seen; all further input is surrounding content.
    Ended,
}
19
/// Result of feeding one chunk to [`ScriptTagParser::parse_chunk`].
#[derive(PartialEq, Debug)]
pub struct ChunkOutput {
    /// Text outside the script tag, with the tags themselves stripped out.
    pub content: String,
    /// The complete script body. `Some` only for the chunk in which the closing tag is
    /// parsed (and only if the body is non-empty); `None` before and after that.
    pub script_source: Option<String>,
}
27
28impl ScriptTagParser {
29 /// Create a new script tag parser.
30 pub fn new() -> Self {
31 Self {
32 state: State::Unstarted,
33 buffer: Vec::new(),
34 tag_match_ix: 0,
35 }
36 }
37
38 /// Returns true if the parser has found a script tag.
39 pub fn found_script(&self) -> bool {
40 match self.state {
41 State::Unstarted => false,
42 State::Streaming | State::Ended => true,
43 }
44 }
45
46 /// Process a new chunk of input, splitting it into surrounding content and script source.
47 pub fn parse_chunk(&mut self, input: &str) -> ChunkOutput {
48 let mut content = Vec::with_capacity(input.len());
49
50 for byte in input.bytes() {
51 match self.state {
52 State::Unstarted => {
53 if collect_until_tag(byte, START_TAG, &mut self.tag_match_ix, &mut content) {
54 self.state = State::Streaming;
55 self.buffer = Vec::with_capacity(1024);
56 self.tag_match_ix = 0;
57 }
58 }
59 State::Streaming => {
60 if collect_until_tag(byte, END_TAG, &mut self.tag_match_ix, &mut self.buffer) {
61 self.state = State::Ended;
62 }
63 }
64 State::Ended => content.push(byte),
65 }
66 }
67
68 let content = unsafe { String::from_utf8_unchecked(content) };
69
70 let script_source = if matches!(self.state, State::Ended) && !self.buffer.is_empty() {
71 let source = unsafe { String::from_utf8_unchecked(std::mem::take(&mut self.buffer)) };
72
73 Some(source)
74 } else {
75 None
76 };
77
78 ChunkOutput {
79 content,
80 script_source,
81 }
82 }
83}
84
85fn collect_until_tag(byte: u8, tag: &[u8], tag_match_ix: &mut usize, buffer: &mut Vec<u8>) -> bool {
86 // this can't be a method because it'd require a mutable borrow on both self and self.buffer
87
88 if match_tag_byte(byte, tag, tag_match_ix) {
89 *tag_match_ix >= tag.len()
90 } else {
91 if *tag_match_ix > 0 {
92 // push the partially matched tag to the buffer
93 buffer.extend_from_slice(&tag[..*tag_match_ix]);
94 *tag_match_ix = 0;
95
96 // the tag might start to match again
97 if match_tag_byte(byte, tag, tag_match_ix) {
98 return *tag_match_ix >= tag.len();
99 }
100 }
101
102 buffer.push(byte);
103
104 false
105 }
106}
107
/// Compares `byte` with the next expected tag byte, `tag[*tag_match_ix]`. On a match it
/// advances `tag_match_ix` by one and returns `true`. Otherwise it leaves the index alone
/// and returns `false`.
///
/// Panics if `*tag_match_ix >= tag.len()`.
fn match_tag_byte(byte: u8, tag: &[u8], tag_match_ix: &mut usize) -> bool {
    let expected = tag[*tag_match_ix];
    let is_match = expected == byte;
    if is_match {
        *tag_match_ix += 1;
    }
    is_match
}
116
// Unit tests for `ScriptTagParser`: whole-tag, no-tag, split-tag, and randomly chunked inputs.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_complete_tag() {
        // A whole script tag in a single chunk yields no content and the full script source.
        let mut parser = ScriptTagParser::new();
        let input = "<eval type=\"lua\">print(\"Hello, World!\")</eval>";
        let result = parser.parse_chunk(input);
        assert_eq!(result.content, "");
        assert_eq!(
            result.script_source,
            Some("print(\"Hello, World!\")".to_string())
        );
    }

    #[test]
    fn test_no_tag() {
        // Input without any tag passes through unchanged as content.
        let mut parser = ScriptTagParser::new();
        let input = "No tags here, just plain text";
        let result = parser.parse_chunk(input);
        assert_eq!(result.content, "No tags here, just plain text");
        assert_eq!(result.script_source, None);
    }

    #[test]
    fn test_partial_end_tag() {
        // Something that looks like the start of an end tag, split across chunks,
        // must end up in the script source once it stops matching.
        let mut parser = ScriptTagParser::new();

        // Start the tag
        let result = parser.parse_chunk("<eval type=\"lua\">let x = '</e");
        assert_eq!(result.content, "");
        assert_eq!(result.script_source, None);

        // Finish with the rest
        let result = parser.parse_chunk("val' + 'not the end';</eval>");
        assert_eq!(result.content, "");
        assert_eq!(
            result.script_source,
            Some("let x = '</eval' + 'not the end';".to_string())
        );
    }

    #[test]
    fn test_text_before_and_after_tag() {
        // Text on both sides of the tag is returned as content, with the tag removed.
        let mut parser = ScriptTagParser::new();
        let input = "Before tag <eval type=\"lua\">print(\"Hello\")</eval> After tag";
        let result = parser.parse_chunk(input);
        assert_eq!(result.content, "Before tag  After tag");
        assert_eq!(result.script_source, Some("print(\"Hello\")".to_string()));
    }

    #[test]
    fn test_multiple_chunks_with_surrounding_text() {
        // The script source is only reported once, on the chunk that closes the tag.
        let mut parser = ScriptTagParser::new();

        // First chunk with text before
        let result = parser.parse_chunk("Before script <eval type=\"lua\">local x = 10");
        assert_eq!(result.content, "Before script ");
        assert_eq!(result.script_source, None);

        // Second chunk with script content
        let result = parser.parse_chunk("\nlocal y = 20");
        assert_eq!(result.content, "");
        assert_eq!(result.script_source, None);

        // Last chunk with text after
        let result = parser.parse_chunk("\nprint(x + y)</eval> After script");
        assert_eq!(result.content, " After script");
        assert_eq!(
            result.script_source,
            Some("local x = 10\nlocal y = 20\nprint(x + y)".to_string())
        );

        // After the script has ended, everything is plain content.
        let result = parser.parse_chunk(" there's more text");
        assert_eq!(result.content, " there's more text");
        assert_eq!(result.script_source, None);
    }

    #[test]
    fn test_partial_start_tag_matching() {
        let mut parser = ScriptTagParser::new();

        // partial match of start tag...
        let result = parser.parse_chunk("<ev");
        assert_eq!(result.content, "");

        // ...that's abandoned when the < of a real tag is encountered
        let result = parser.parse_chunk("<eval type=\"lua\">script content</eval>");
        // ...so it gets pushed to content
        assert_eq!(result.content, "<ev");
        // ...and the real tag is parsed correctly
        assert_eq!(result.script_source, Some("script content".to_string()));
    }

    #[test]
    fn test_random_chunked_parsing() {
        use rand::rngs::StdRng;
        use rand::{Rng, SeedableRng};
        use std::time::{SystemTime, UNIX_EPOCH};

        let test_inputs = [
            "Before <eval type=\"lua\">print(\"Hello\")</eval> After",
            "No tags here at all",
            "<eval type=\"lua\">local x = 10\nlocal y = 20\nprint(x + y)</eval>",
            "Text <eval type=\"lua\">if true then\nprint(\"nested </e\")\nend</eval> more",
        ];

        // The seed is time-based, so it is printed to allow reproducing a failing run.
        let seed = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        eprintln!("Using random seed: {}", seed);
        let mut rng = StdRng::seed_from_u64(seed);

        for test_input in &test_inputs {
            // Parsing the whole input in one call is the reference result.
            let mut reference_parser = ScriptTagParser::new();
            let expected = reference_parser.parse_chunk(test_input);

            let mut chunked_parser = ScriptTagParser::new();
            let mut remaining = test_input.as_bytes();
            let mut actual_content = String::new();
            let mut actual_script = None;

            // Feed the same input in random chunks of 1..=5 bytes. The inputs are ASCII,
            // so every split point yields a valid `&str`.
            while !remaining.is_empty() {
                let chunk_size = rng.gen_range(1..=remaining.len().min(5));
                let (chunk, rest) = remaining.split_at(chunk_size);
                remaining = rest;

                let chunk_str = std::str::from_utf8(chunk).unwrap();
                let result = chunked_parser.parse_chunk(chunk_str);

                actual_content.push_str(&result.content);
                if result.script_source.is_some() {
                    actual_script = result.script_source;
                }
            }

            // Chunking must not change the concatenated content or the extracted script.
            assert_eq!(actual_content, expected.content);
            assert_eq!(actual_script, expected.script_source);
        }
    }
}