@@ -800,16 +800,18 @@ fn new_utils_func() {}
old: "from flask import Flask".to_string(),
new: "import math\nfrom flask import Flask".to_string(),
}
+ .fix_lf(),
);
assert_eq!(
- actions[1],
- EditAction::Replace {
- file_path: PathBuf::from("mathweb/flask/app.py"),
- old: "def factorial(n):\n \"compute factorial\"\n\n if n == 0:\n return 1\n else:\n return n * factorial(n-1)\n".to_string(),
- new: "".to_string(),
- }
- );
+ actions[1],
+ EditAction::Replace {
+ file_path: PathBuf::from("mathweb/flask/app.py"),
+ old: "def factorial(n):\n \"compute factorial\"\n\n if n == 0:\n return 1\n else:\n return n * factorial(n-1)\n".to_string(),
+ new: "".to_string(),
+ }
+ .fix_lf()
+ );
assert_eq!(
actions[2],
@@ -818,6 +820,7 @@ fn new_utils_func() {}
old: " return str(factorial(n))".to_string(),
new: " return str(math.factorial(n))".to_string(),
}
+ .fix_lf(),
);
assert_eq!(
@@ -827,6 +830,7 @@ fn new_utils_func() {}
content: "def hello():\n \"print a greeting\"\n\n print(\"hello\")"
.to_string(),
}
+ .fix_lf(),
);
assert_eq!(
@@ -836,6 +840,7 @@ fn new_utils_func() {}
old: "def hello():\n \"print a greeting\"\n\n print(\"hello\")".to_string(),
new: "from hello import hello".to_string(),
}
+ .fix_lf(),
);
// The system prompt includes some text that would produce errors
@@ -843,10 +848,39 @@ fn new_utils_func() {}
errors[0].to_string(),
"input:102:1: Expected marker \"<<<<<<< SEARCH\", found '3'"
);
+ #[cfg(not(windows))]
assert_eq!(
errors[1].to_string(),
"input:109:0: Expected marker \"<<<<<<< SEARCH\", found '\\n'"
);
+ #[cfg(windows)]
+ assert_eq!(
+ errors[1].to_string(),
+ "input:108:1: Expected marker \"<<<<<<< SEARCH\", found '\\r'"
+ );
+ }
+
+ impl EditAction {
+ fn fix_lf(self: EditAction) -> EditAction {
+ #[cfg(windows)]
+ match self {
+ EditAction::Replace {
+ file_path,
+ old,
+ new,
+ } => EditAction::Replace {
+ file_path: file_path.clone(),
+ old: old.replace("\n", "\r\n"),
+ new: new.replace("\n", "\r\n"),
+ },
+ EditAction::Write { file_path, content } => EditAction::Write {
+ file_path: file_path.clone(),
+ content: content.replace("\n", "\r\n"),
+ },
+ }
+ #[cfg(not(windows))]
+ self
+ }
}
#[test]