assistant edit tool: Support \r\n around markers (#26538)

Agus Zubiaga created

This should fix the tests on Windows

Release Notes:

- N/A

Change summary

crates/assistant_tools/src/edit_files_tool/edit_action.rs | 87 ++++++--
1 file changed, 60 insertions(+), 27 deletions(-)

Detailed changes

crates/assistant_tools/src/edit_files_tool/edit_action.rs 🔗

@@ -78,10 +78,10 @@ impl EditActionParser {
     pub fn parse_chunk(&mut self, input: &str) -> Vec<EditAction> {
         use State::*;
 
-        const FENCE: &[u8] = b"\n```";
-        const SEARCH_MARKER: &[u8] = b"<<<<<<< SEARCH\n";
-        const DIVIDER: &[u8] = b"=======\n";
-        const NL_DIVIDER: &[u8] = b"\n=======\n";
+        const FENCE: &[u8] = b"```";
+        const SEARCH_MARKER: &[u8] = b"<<<<<<< SEARCH";
+        const DIVIDER: &[u8] = b"=======";
+        const NL_DIVIDER: &[u8] = b"\n=======";
         const REPLACE_MARKER: &[u8] = b">>>>>>> REPLACE";
         const NL_REPLACE_MARKER: &[u8] = b"\n>>>>>>> REPLACE";
 
@@ -96,8 +96,8 @@ impl EditActionParser {
                 self.column += 1;
             }
 
-            match self.state {
-                Default => match match_marker(byte, FENCE, &mut self.marker_ix) {
+            match &self.state {
+                Default => match match_marker(byte, FENCE, false, &mut self.marker_ix) {
                     MarkerMatch::Complete => {
                         self.to_state(OpenFence);
                     }
@@ -105,12 +105,11 @@ impl EditActionParser {
                     MarkerMatch::None => {
                         if self.marker_ix > 0 {
                             self.marker_ix = 0;
+                        } else if self.pre_fence_line.ends_with(b"\n") {
                             self.pre_fence_line.clear();
                         }
 
-                        if byte != b'\n' {
-                            self.pre_fence_line.push(byte);
-                        }
+                        self.pre_fence_line.push(byte);
                     }
                 },
                 OpenFence => {
@@ -120,7 +119,7 @@ impl EditActionParser {
                     }
                 }
                 SearchMarker => {
-                    if self.expect_marker(byte, SEARCH_MARKER) {
+                    if self.expect_marker(byte, SEARCH_MARKER, true) {
                         self.to_state(SearchBlock);
                     }
                 }
@@ -129,6 +128,7 @@ impl EditActionParser {
                         byte,
                         DIVIDER,
                         NL_DIVIDER,
+                        true,
                         &mut self.marker_ix,
                         &mut self.old_bytes,
                     ) {
@@ -140,6 +140,7 @@ impl EditActionParser {
                         byte,
                         REPLACE_MARKER,
                         NL_REPLACE_MARKER,
+                        true,
                         &mut self.marker_ix,
                         &mut self.new_bytes,
                     ) {
@@ -147,10 +148,11 @@ impl EditActionParser {
                     }
                 }
                 CloseFence => {
-                    if self.expect_marker(byte, FENCE) {
+                    if self.expect_marker(byte, FENCE, false) {
                         if let Some(action) = self.action() {
                             actions.push(action);
                         }
+                        self.errors();
                         self.reset();
                     }
                 }
@@ -171,7 +173,17 @@ impl EditActionParser {
             return None;
         }
 
-        let file_path = String::from_utf8(std::mem::take(&mut self.pre_fence_line)).log_err()?;
+        let mut pre_fence_line = std::mem::take(&mut self.pre_fence_line);
+
+        if pre_fence_line.ends_with(b"\n") {
+            pre_fence_line.pop();
+
+            if pre_fence_line.ends_with(b"\r") {
+                pre_fence_line.pop();
+            }
+        }
+
+        let file_path = String::from_utf8(pre_fence_line).log_err()?;
         let content = String::from_utf8(std::mem::take(&mut self.new_bytes)).log_err()?;
 
         if self.old_bytes.is_empty() {
@@ -187,8 +199,8 @@ impl EditActionParser {
         }
     }
 
-    fn expect_marker(&mut self, byte: u8, marker: &'static [u8]) -> bool {
-        match match_marker(byte, marker, &mut self.marker_ix) {
+    fn expect_marker(&mut self, byte: u8, marker: &'static [u8], trailing_newline: bool) -> bool {
+        match match_marker(byte, marker, trailing_newline, &mut self.marker_ix) {
             MarkerMatch::Complete => true,
             MarkerMatch::Partial => false,
             MarkerMatch::None => {
@@ -230,14 +242,27 @@ enum MarkerMatch {
     Complete,
 }
 
-fn match_marker(byte: u8, marker: &[u8], marker_ix: &mut usize) -> MarkerMatch {
-    if byte == marker[*marker_ix] {
-        *marker_ix += 1;
-
-        if *marker_ix >= marker.len() {
+fn match_marker(
+    byte: u8,
+    marker: &[u8],
+    trailing_newline: bool,
+    marker_ix: &mut usize,
+) -> MarkerMatch {
+    if trailing_newline && *marker_ix >= marker.len() {
+        if byte == b'\n' {
+            MarkerMatch::Complete
+        } else if byte == b'\r' {
             MarkerMatch::Complete
         } else {
+            MarkerMatch::None
+        }
+    } else if byte == marker[*marker_ix] {
+        *marker_ix += 1;
+
+        if *marker_ix < marker.len() || trailing_newline {
             MarkerMatch::Partial
+        } else {
+            MarkerMatch::Complete
         }
     } else {
         MarkerMatch::None
@@ -248,6 +273,7 @@ fn collect_until_marker(
     byte: u8,
     marker: &[u8],
     nl_marker: &[u8],
+    trailing_newline: bool,
     marker_ix: &mut usize,
     buf: &mut Vec<u8>,
 ) -> bool {
@@ -258,7 +284,7 @@ fn collect_until_marker(
         nl_marker
     };
 
-    match match_marker(byte, marker, marker_ix) {
+    match match_marker(byte, marker, trailing_newline, marker_ix) {
         MarkerMatch::Complete => true,
         MarkerMatch::Partial => false,
         MarkerMatch::None => {
@@ -267,7 +293,7 @@ fn collect_until_marker(
                 *marker_ix = 0;
 
                 // The beginning of marker might match current byte
-                match match_marker(byte, marker, marker_ix) {
+                match match_marker(byte, marker, trailing_newline, marker_ix) {
                     MarkerMatch::Complete => return true,
                     MarkerMatch::Partial => return false,
                     MarkerMatch::None => { /* no match, keep collecting */ }
@@ -615,7 +641,7 @@ fn replacement() {}"#;
 
         // Check parser is in the correct state
         assert_eq!(parser.state, State::SearchBlock);
-        assert_eq!(parser.pre_fence_line, b"src/main.rs");
+        assert_eq!(parser.pre_fence_line, b"src/main.rs\n");
 
         // Continue parsing
         let actions2 = parser.parse_chunk("original code\n=======\n");
@@ -626,7 +652,7 @@ fn replacement() {}"#;
 
         // After complete parsing, state should reset
         assert_eq!(parser.state, State::Default);
-        assert!(parser.pre_fence_line.is_empty());
+        assert_eq!(parser.pre_fence_line, b"\n");
         assert!(parser.old_bytes.is_empty());
         assert!(parser.new_bytes.is_empty());
 
@@ -659,7 +685,7 @@ fn replacement() {}
         assert_eq!(
             error.kind,
             ParseErrorKind::ExpectedMarker {
-                expected: b"<<<<<<< SEARCH\n",
+                expected: b"<<<<<<< SEARCH",
                 found: b'W'
             }
         );
@@ -779,8 +805,15 @@ fn new_utils_func() {}
             }
         );
 
-        // Ensure we have no parsing errors
-        assert!(errors.is_empty(), "Parsing errors found: {:?}", errors);
+        // The system prompt includes some text that would produce errors
+        assert_eq!(
+            errors[0].to_string(),
+            "input:102:1: Expected marker \"<<<<<<< SEARCH\", found '3'"
+        );
+        assert_eq!(
+            errors[1].to_string(),
+            "input:109:0: Expected marker \"<<<<<<< SEARCH\", found '\\n'"
+        );
     }
 
     #[test]
@@ -800,7 +833,7 @@ fn replacement() {}
 
         assert_eq!(parser.errors().len(), 1);
         let error = &parser.errors()[0];
-        let expected_error = r#"input:3:9: Expected marker "<<<<<<< SEARCH\n", found 'W'"#;
+        let expected_error = r#"input:3:9: Expected marker "<<<<<<< SEARCH", found 'W'"#;
 
         assert_eq!(format!("{}", error), expected_error);
     }