diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index 1da966ea9c5b2f3cf7b866bc82839de9d70e9fa6..2a1b49007bd19e721a6d95ebddda3758c86aaaef 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -251,7 +251,10 @@ impl TeacherPrompt { } } - if response.trim().ends_with(Self::NO_EDITS) { + if response + .trim_end_matches(&[' ', '\n', '`']) + .ends_with(Self::NO_EDITS) + { return Ok(no_edits); } @@ -886,4 +889,42 @@ mod tests { let result = extract_last_codeblock(text).unwrap(); assert_eq!(result, "content here\n"); } + + #[test] + fn test_parse_no_edits_response_with_trailing_backticks() { + let response = "NO_EDITS```"; + + let parsed = TeacherPrompt::parse( + &Example { + spec: edit_prediction::example_spec::ExampleSpec { + name: "test".to_string(), + repository_url: "https://github.com/zed-industries/zed.git".to_string(), + revision: "HEAD".to_string(), + tags: Vec::new(), + reasoning: None, + uncommitted_diff: String::new(), + cursor_path: std::sync::Arc::from(std::path::Path::new("src/main.rs")), + cursor_position: "0:0".to_string(), + edit_history: String::new(), + expected_patches: Vec::new(), + rejected_patch: None, + telemetry: None, + human_feedback: Vec::new(), + rating: None, + }, + prompt_inputs: None, + prompt: None, + predictions: Vec::new(), + score: Vec::new(), + qa: Vec::new(), + zed_version: None, + state: None, + }, + response, + ) + .unwrap(); + + assert!(parsed.0.is_empty()); + assert!(parsed.1.is_none()); + } } diff --git a/crates/edit_prediction_cli/src/repair.rs b/crates/edit_prediction_cli/src/repair.rs index a0c4242748c9ad83c3b0fbe9e70a4b132ac75c4d..58da6c47e91491cc785804c7f4c2aab30887a741 100644 --- a/crates/edit_prediction_cli/src/repair.rs +++ b/crates/edit_prediction_cli/src/repair.rs @@ -10,7 +10,7 @@ use crate::{ BatchProvider, PredictionProvider, anthropic_client::AnthropicClient, example::{ActualCursor, Example, ExamplePrediction}, - format_prompt::{TeacherPrompt, extract_last_codeblock}, + format_prompt::TeacherPrompt, metrics::count_patch_token_changes, openai_client::OpenAiClient, parse_output::run_parse_output, @@ -227,10 +227,7 @@ pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool { /// Handles the `KEEP_PREVIOUS` sentinel by copying the teacher's prediction, /// and delegates normal output to `TeacherPrompt::parse`. pub fn parse(example: &Example, actual_output: &str) -> Result<(String, Option)> { - let last_codeblock = - extract_last_codeblock(actual_output).unwrap_or_else(|| actual_output.to_string()); - - if last_codeblock.contains(KEEP_PREVIOUS) { + if actual_output.contains(KEEP_PREVIOUS) { let original = example .predictions .first() @@ -456,3 +453,71 @@ pub async fn sync_batches(args: &RepairArgs) -> Result<()> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{PredictionProvider, TeacherBackend}; + use edit_prediction::example_spec::ExampleSpec; + use std::{path::Path, sync::Arc}; + + fn example_with_previous_prediction() -> Example { + Example { + spec: ExampleSpec { + name: "example".to_string(), + repository_url: "https://github.com/zed-industries/zed.git".to_string(), + revision: "HEAD".to_string(), + tags: Vec::new(), + reasoning: None, + uncommitted_diff: String::new(), + cursor_path: Arc::from(Path::new("src/main.rs")), + cursor_position: "0:0".to_string(), + edit_history: String::new(), + expected_patches: Vec::new(), + rejected_patch: None, + telemetry: None, + human_feedback: Vec::new(), + rating: None, + }, + prompt_inputs: None, + prompt: None, + predictions: vec![ExamplePrediction { + actual_patch: Some("previous patch".to_string()), + actual_output: String::new(), + actual_cursor: Some(ActualCursor { + path: "src/main.rs".to_string(), + row: 1, + column: 2, + offset: 3, + editable_region_offset: Some(4), + }), + error: None, + provider: PredictionProvider::Teacher(TeacherBackend::Sonnet45), + cumulative_logprob: None, + avg_logprob: None, + }], + score: Vec::new(), + qa: Vec::new(), + zed_version: None, + state: None, + } + } + + #[test] + fn test_parse_keeps_previous_when_sentinel_appears_outside_last_codeblock() { + let example = example_with_previous_prediction(); + let actual_output = indoc::indoc! {" + After reviewing the feedback, the previous prediction is still correct. + Use `KEEP_PREVIOUS`. + + ``` + unrelated trailing code block + ``` + "}; + + let (patch, cursor) = parse(&example, actual_output).unwrap(); + + assert_eq!(patch, "previous patch"); + assert_eq!(cursor.unwrap().offset, 3); + } +}