assistant edit tool: Do not include `\r` in old/new str (#26542)

Agus Zubiaga created

#26538 fixed part of the issue, but it still kept trailing carriage
returns in the old/new strings. The model is unlikely to produce those,
but we might as well support them.

Release Notes:

- N/A

Change summary

crates/assistant_tools/src/edit_files_tool/edit_action.rs | 62 ++++++--
1 file changed, 47 insertions(+), 15 deletions(-)

Detailed changes

crates/assistant_tools/src/edit_files_tool/edit_action.rs 🔗

@@ -177,10 +177,7 @@ impl EditActionParser {
 
         if pre_fence_line.ends_with(b"\n") {
             pre_fence_line.pop();
-
-            if pre_fence_line.ends_with(b"\r") {
-                pre_fence_line.pop();
-            }
+            pop_carriage_return(&mut pre_fence_line);
         }
 
         let file_path = String::from_utf8(pre_fence_line).log_err()?;
@@ -252,7 +249,7 @@ fn match_marker(
         if byte == b'\n' {
             MarkerMatch::Complete
         } else if byte == b'\r' {
-            MarkerMatch::Complete
+            MarkerMatch::Partial
         } else {
             MarkerMatch::None
         }
@@ -285,7 +282,10 @@ fn collect_until_marker(
     };
 
     match match_marker(byte, marker, trailing_newline, marker_ix) {
-        MarkerMatch::Complete => true,
+        MarkerMatch::Complete => {
+            pop_carriage_return(buf);
+            true
+        }
         MarkerMatch::Partial => false,
         MarkerMatch::None => {
             if *marker_ix > 0 {
@@ -307,6 +307,12 @@ fn collect_until_marker(
     }
 }
 
+fn pop_carriage_return(buf: &mut Vec<u8>) {
+    if buf.ends_with(b"\r") {
+        buf.pop();
+    }
+}
+
 #[derive(Debug, PartialEq, Eq)]
 pub struct ParseError {
     line: usize,
@@ -373,6 +379,7 @@ fn replacement() {}
                 new: "fn replacement() {}".to_string(),
             }
         );
+        assert_eq!(parser.errors().len(), 0);
     }
 
     #[test]
@@ -399,6 +406,7 @@ fn replacement() {}
                 new: "fn replacement() {}".to_string(),
             }
         );
+        assert_eq!(parser.errors().len(), 0);
     }
 
     #[test]
@@ -429,6 +437,7 @@ This change makes the function better.
                 new: "fn replacement() {}".to_string(),
             }
         );
+        assert_eq!(parser.errors().len(), 0);
     }
 
     #[test]
@@ -474,6 +483,7 @@ fn new_util() -> bool { true }
                 new: "fn new_util() -> bool { true }".to_string(),
             }
         );
+        assert_eq!(parser.errors().len(), 0);
     }
 
     #[test]
@@ -514,6 +524,7 @@ fn replacement() {
                 new: "fn replacement() {\n    println!(\"This is the replacement function\");\n    let x = 100;\n    if x > 50 {\n        println!(\"Large number\");\n    } else {\n        println!(\"Small number\");\n    }\n}".to_string(),
             }
         );
+        assert_eq!(parser.errors().len(), 0);
     }
 
     #[test]
@@ -543,6 +554,7 @@ fn new_function() {
                     .to_string(),
             }
         );
+        assert_eq!(parser.errors().len(), 0);
     }
 
     #[test]
@@ -559,8 +571,7 @@ fn this_will_be_deleted() {
 "#;
 
         let mut parser = EditActionParser::new();
-        let actions = parser.parse_chunk(input);
-
+        let actions = parser.parse_chunk(&input);
         assert_eq!(actions.len(), 1);
         assert_eq!(
             actions[0],
@@ -571,6 +582,21 @@ fn this_will_be_deleted() {
                 new: "".to_string(),
             }
         );
+        assert_eq!(parser.errors().len(), 0);
+
+        let actions = parser.parse_chunk(&input.replace("\n", "\r\n"));
+        assert_eq!(actions.len(), 1);
+        assert_eq!(
+            actions[0],
+            EditAction::Replace {
+                file_path: "src/main.rs".to_string(),
+                old:
+                    "fn this_will_be_deleted() {\r\n    println!(\"Deleting this function\");\r\n}"
+                        .to_string(),
+                new: "".to_string(),
+            }
+        );
+        assert_eq!(parser.errors().len(), 0);
     }
 
     #[test]
@@ -616,10 +642,12 @@ fn replacement() {}"#;
         let mut parser = EditActionParser::new();
         let actions1 = parser.parse_chunk(input_part1);
         assert_eq!(actions1.len(), 0);
+        assert_eq!(parser.errors().len(), 0);
 
         let actions2 = parser.parse_chunk(input_part2);
         // No actions should be complete yet
         assert_eq!(actions2.len(), 0);
+        assert_eq!(parser.errors().len(), 0);
 
         let actions3 = parser.parse_chunk(input_part3);
         // The third chunk should complete the action
@@ -632,6 +660,7 @@ fn replacement() {}"#;
                 new: "fn replacement() {}".to_string(),
             }
         );
+        assert_eq!(parser.errors().len(), 0);
     }
 
     #[test]
@@ -642,11 +671,13 @@ fn replacement() {}"#;
         // Check parser is in the correct state
         assert_eq!(parser.state, State::SearchBlock);
         assert_eq!(parser.pre_fence_line, b"src/main.rs\n");
+        assert_eq!(parser.errors().len(), 0);
 
         // Continue parsing
         let actions2 = parser.parse_chunk("original code\n=======\n");
         assert_eq!(parser.state, State::ReplaceBlock);
         assert_eq!(parser.old_bytes, b"original code");
+        assert_eq!(parser.errors().len(), 0);
 
         let actions3 = parser.parse_chunk("replacement code\n>>>>>>> REPLACE\n```\n");
 
@@ -659,6 +690,7 @@ fn replacement() {}"#;
         assert_eq!(actions1.len(), 0);
         assert_eq!(actions2.len(), 0);
         assert_eq!(actions3.len(), 1);
+        assert_eq!(parser.errors().len(), 0);
     }
 
     #[test]
@@ -680,14 +712,9 @@ fn replacement() {}
         assert_eq!(parser.errors().len(), 1);
         let error = &parser.errors()[0];
 
-        assert_eq!(error.line, 3);
-        assert_eq!(error.column, 9);
         assert_eq!(
-            error.kind,
-            ParseErrorKind::ExpectedMarker {
-                expected: b"<<<<<<< SEARCH",
-                found: b'W'
-            }
+            error.to_string(),
+            "input:3:9: Expected marker \"<<<<<<< SEARCH\", found 'W'"
         );
     }
 
@@ -725,6 +752,11 @@ fn new_utils_func() {}
                 new: "fn new_utils_func() {}".to_string(),
             }
         );
+        assert_eq!(parser.errors().len(), 1);
+        assert_eq!(
+            parser.errors()[0].to_string(),
+            "input:8:1: Expected marker \"```\", found '<'".to_string()
+        );
 
         // The parser should continue after an error
         assert_eq!(parser.state, State::Default);