Tolerate edits ending with `</edits>` instead of `</new_text>` (#31453)

Antonio Scandurra created

Release Notes:

- Improve reliability of the agent when a model outputs malformed edits.

Change summary

crates/assistant_tools/src/edit_agent/edit_parser.rs | 58 ++++++++++---
1 file changed, 43 insertions(+), 15 deletions(-)

Detailed changes

crates/assistant_tools/src/edit_agent/edit_parser.rs 🔗

@@ -2,12 +2,12 @@ use derive_more::{Add, AddAssign};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use smallvec::SmallVec;
-use std::{cmp, mem, ops::Range};
+use std::{mem, ops::Range};
 
 const OLD_TEXT_END_TAG: &str = "</old_text>";
 const NEW_TEXT_END_TAG: &str = "</new_text>";
-const END_TAG_LEN: usize = OLD_TEXT_END_TAG.len();
-const _: () = debug_assert!(OLD_TEXT_END_TAG.len() == NEW_TEXT_END_TAG.len());
+const EDITS_END_TAG: &str = "</edits>";
+const END_TAGS: [&str; 3] = [OLD_TEXT_END_TAG, NEW_TEXT_END_TAG, EDITS_END_TAG];
 
 #[derive(Debug)]
 pub enum EditParserEvent {
@@ -115,8 +115,9 @@ impl EditParser {
                         self.state = EditParserState::Pending;
                         edit_events.push(EditParserEvent::NewTextChunk { chunk, done: true });
                     } else {
-                        let mut end_prefixes = (1..END_TAG_LEN)
-                            .flat_map(|i| [&NEW_TEXT_END_TAG[..i], &OLD_TEXT_END_TAG[..i]])
+                        let mut end_prefixes = END_TAGS
+                            .iter()
+                            .flat_map(|tag| (1..tag.len()).map(move |i| &tag[..i]))
                             .chain(["\n"]);
                         if end_prefixes.all(|prefix| !self.buffer.ends_with(&prefix)) {
                             edit_events.push(EditParserEvent::NewTextChunk {
@@ -133,16 +134,11 @@ impl EditParser {
     }
 
     fn find_end_tag(&self) -> Option<Range<usize>> {
-        let old_text_end_tag_ix = self.buffer.find(OLD_TEXT_END_TAG);
-        let new_text_end_tag_ix = self.buffer.find(NEW_TEXT_END_TAG);
-        let start_ix = if let Some((old_text_ix, new_text_ix)) =
-            old_text_end_tag_ix.zip(new_text_end_tag_ix)
-        {
-            cmp::min(old_text_ix, new_text_ix)
-        } else {
-            old_text_end_tag_ix.or(new_text_end_tag_ix)?
-        };
-        Some(start_ix..start_ix + END_TAG_LEN)
+        let (tag, start_ix) = END_TAGS
+            .iter()
+            .flat_map(|tag| Some((tag, self.buffer.find(tag)?)))
+            .min_by_key(|(_, ix)| *ix)?;
+        Some(start_ix..start_ix + tag.len())
     }
 
     pub fn finish(self) -> EditParserMetrics {
@@ -373,6 +369,35 @@ mod tests {
                 mismatched_tags: 4
             }
         );
+
+        let mut parser = EditParser::new();
+        assert_eq!(
+            parse_random_chunks(
+                // Reduced from an actual Opus 4 output
+                indoc! {"
+                    <edits>
+                    <old_text>
+                    Lorem
+                    </old_text>
+                    <new_text>
+                    LOREM
+                    </edits>
+                "},
+                &mut parser,
+                &mut rng
+            ),
+            vec![Edit {
+                old_text: "Lorem".to_string(),
+                new_text: "LOREM".to_string(),
+            },]
+        );
+        assert_eq!(
+            parser.finish(),
+            EditParserMetrics {
+                tags: 2,
+                mismatched_tags: 1
+            }
+        );
     }
 
     #[derive(Default, Debug, PartialEq, Eq)]
@@ -407,6 +432,9 @@ mod tests {
             }
             last_ix = chunk_ix;
         }
+
+        assert_eq!(pending_edit, Edit::default(), "unfinished edit");
+
         edits
     }
 }