tag.rs

  1pub const SCRIPT_START_TAG: &str = "<eval type=\"lua\">";
  2pub const SCRIPT_END_TAG: &str = "</eval>";
  3
  4const START_TAG: &[u8] = SCRIPT_START_TAG.as_bytes();
  5const END_TAG: &[u8] = SCRIPT_END_TAG.as_bytes();
  6
  7/// Parses a script tag in an assistant message as it is being streamed.
  8pub struct ScriptTagParser {
  9    state: State,
 10    buffer: Vec<u8>,
 11    tag_match_ix: usize,
 12}
 13
 14enum State {
 15    Unstarted,
 16    Streaming,
 17    Ended,
 18}
 19
 20#[derive(Debug, PartialEq)]
 21pub struct ChunkOutput {
 22    /// The chunk with script tags removed.
 23    pub content: String,
 24    /// The full script tag content. `None` until closed.
 25    pub script_source: Option<String>,
 26}
 27
 28impl ScriptTagParser {
 29    /// Create a new script tag parser.
 30    pub fn new() -> Self {
 31        Self {
 32            state: State::Unstarted,
 33            buffer: Vec::new(),
 34            tag_match_ix: 0,
 35        }
 36    }
 37
 38    /// Returns true if the parser has found a script tag.
 39    pub fn found_script(&self) -> bool {
 40        match self.state {
 41            State::Unstarted => false,
 42            State::Streaming | State::Ended => true,
 43        }
 44    }
 45
 46    /// Process a new chunk of input, splitting it into surrounding content and script source.
 47    pub fn parse_chunk(&mut self, input: &str) -> ChunkOutput {
 48        let mut content = Vec::with_capacity(input.len());
 49
 50        for byte in input.bytes() {
 51            match self.state {
 52                State::Unstarted => {
 53                    if collect_until_tag(byte, START_TAG, &mut self.tag_match_ix, &mut content) {
 54                        self.state = State::Streaming;
 55                        self.buffer = Vec::with_capacity(1024);
 56                        self.tag_match_ix = 0;
 57                    }
 58                }
 59                State::Streaming => {
 60                    if collect_until_tag(byte, END_TAG, &mut self.tag_match_ix, &mut self.buffer) {
 61                        self.state = State::Ended;
 62                    }
 63                }
 64                State::Ended => content.push(byte),
 65            }
 66        }
 67
 68        let content = unsafe { String::from_utf8_unchecked(content) };
 69
 70        let script_source = if matches!(self.state, State::Ended) && !self.buffer.is_empty() {
 71            let source = unsafe { String::from_utf8_unchecked(std::mem::take(&mut self.buffer)) };
 72
 73            Some(source)
 74        } else {
 75            None
 76        };
 77
 78        ChunkOutput {
 79            content,
 80            script_source,
 81        }
 82    }
 83}
 84
 85fn collect_until_tag(byte: u8, tag: &[u8], tag_match_ix: &mut usize, buffer: &mut Vec<u8>) -> bool {
 86    // this can't be a method because it'd require a mutable borrow on both self and self.buffer
 87
 88    if match_tag_byte(byte, tag, tag_match_ix) {
 89        *tag_match_ix >= tag.len()
 90    } else {
 91        if *tag_match_ix > 0 {
 92            // push the partially matched tag to the buffer
 93            buffer.extend_from_slice(&tag[..*tag_match_ix]);
 94            *tag_match_ix = 0;
 95
 96            // the tag might start to match again
 97            if match_tag_byte(byte, tag, tag_match_ix) {
 98                return *tag_match_ix >= tag.len();
 99            }
100        }
101
102        buffer.push(byte);
103
104        false
105    }
106}
107
/// Compares `byte` with the tag byte at `*tag_match_ix` and advances the index on a match.
///
/// Returns whether the byte matched; on a mismatch the index is left unchanged.
/// Panics if `*tag_match_ix` is not a valid index into `tag`.
fn match_tag_byte(byte: u8, tag: &[u8], tag_match_ix: &mut usize) -> bool {
    let expected = tag[*tag_match_ix];
    let is_match = byte == expected;
    if is_match {
        *tag_match_ix += 1;
    }
    is_match
}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_complete_tag() {
        let mut parser = ScriptTagParser::new();
        let input = "<eval type=\"lua\">print(\"Hello, World!\")</eval>";
        let result = parser.parse_chunk(input);
        assert_eq!(result.content, "");
        assert_eq!(
            result.script_source,
            Some("print(\"Hello, World!\")".to_string())
        );
    }

    #[test]
    fn test_no_tag() {
        let mut parser = ScriptTagParser::new();
        let input = "No tags here, just plain text";
        let result = parser.parse_chunk(input);
        assert_eq!(result.content, "No tags here, just plain text");
        assert_eq!(result.script_source, None);
    }

    #[test]
    fn test_partial_end_tag() {
        let mut parser = ScriptTagParser::new();

        // Start the tag
        let result = parser.parse_chunk("<eval type=\"lua\">let x = '</e");
        assert_eq!(result.content, "");
        assert_eq!(result.script_source, None);

        // Finish with the rest
        let result = parser.parse_chunk("val' + 'not the end';</eval>");
        assert_eq!(result.content, "");
        assert_eq!(
            result.script_source,
            Some("let x = '</eval' + 'not the end';".to_string())
        );
    }

    #[test]
    fn test_text_before_and_after_tag() {
        let mut parser = ScriptTagParser::new();
        let input = "Before tag <eval type=\"lua\">print(\"Hello\")</eval> After tag";
        let result = parser.parse_chunk(input);
        assert_eq!(result.content, "Before tag  After tag");
        assert_eq!(result.script_source, Some("print(\"Hello\")".to_string()));
    }

    #[test]
    fn test_multibyte_text_around_tag() {
        // Exercises the `from_utf8_unchecked` paths with non-ASCII content and script source.
        let mut parser = ScriptTagParser::new();
        let input = "héllo <eval type=\"lua\">print(\"🌍\")</eval> wörld";
        let result = parser.parse_chunk(input);
        assert_eq!(result.content, "héllo  wörld");
        assert_eq!(result.script_source, Some("print(\"🌍\")".to_string()));
    }

    #[test]
    fn test_multiple_chunks_with_surrounding_text() {
        let mut parser = ScriptTagParser::new();

        // First chunk with text before
        let result = parser.parse_chunk("Before script <eval type=\"lua\">local x = 10");
        assert_eq!(result.content, "Before script ");
        assert_eq!(result.script_source, None);

        // Second chunk with script content
        let result = parser.parse_chunk("\nlocal y = 20");
        assert_eq!(result.content, "");
        assert_eq!(result.script_source, None);

        // Last chunk with text after
        let result = parser.parse_chunk("\nprint(x + y)</eval> After script");
        assert_eq!(result.content, " After script");
        assert_eq!(
            result.script_source,
            Some("local x = 10\nlocal y = 20\nprint(x + y)".to_string())
        );

        let result = parser.parse_chunk(" there's more text");
        assert_eq!(result.content, " there's more text");
        assert_eq!(result.script_source, None);
    }

    #[test]
    fn test_partial_start_tag_matching() {
        let mut parser = ScriptTagParser::new();

        // partial match of start tag...
        let result = parser.parse_chunk("<ev");
        assert_eq!(result.content, "");

        // ...that's abandoned when the < of a real tag is encountered
        let result = parser.parse_chunk("<eval type=\"lua\">script content</eval>");
        // ...so it gets pushed to content
        assert_eq!(result.content, "<ev");
        // ...and the real tag is parsed correctly
        assert_eq!(result.script_source, Some("script content".to_string()));
    }

    #[test]
    fn test_random_chunked_parsing() {
        use rand::rngs::StdRng;
        use rand::{Rng, SeedableRng};
        use std::time::{SystemTime, UNIX_EPOCH};

        let test_inputs = [
            "Before <eval type=\"lua\">print(\"Hello\")</eval> After",
            "No tags here at all",
            "<eval type=\"lua\">local x = 10\nlocal y = 20\nprint(x + y)</eval>",
            "Text <eval type=\"lua\">if true then\nprint(\"nested </e\")\nend</eval> more",
            "Ünïcode 👋 <eval type=\"lua\">print(\"héllo 🌍 </e\")</eval> wörld ✓",
        ];

        let seed = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        eprintln!("Using random seed: {}", seed);
        let mut rng = StdRng::seed_from_u64(seed);

        for test_input in &test_inputs {
            let mut reference_parser = ScriptTagParser::new();
            let expected = reference_parser.parse_chunk(test_input);

            let mut chunked_parser = ScriptTagParser::new();
            let mut remaining: &str = test_input;
            let mut actual_content = String::new();
            let mut actual_script = None;

            while !remaining.is_empty() {
                let mut chunk_size = rng.gen_range(1..=remaining.len().min(5));
                // Streams deliver `&str` chunks, so only split on character boundaries.
                while !remaining.is_char_boundary(chunk_size) {
                    chunk_size += 1;
                }
                let (chunk, rest) = remaining.split_at(chunk_size);
                remaining = rest;

                let result = chunked_parser.parse_chunk(chunk);

                actual_content.push_str(&result.content);
                if result.script_source.is_some() {
                    actual_script = result.script_source;
                }
            }

            assert_eq!(actual_content, expected.content);
            assert_eq!(actual_script, expected.script_source);
        }
    }
}