diff --git a/crates/edit_prediction_cli/src/distill.rs b/crates/edit_prediction_cli/src/distill.rs index d6343871e8054fc54062f3d3f7f5210374b36812..21b255d5f99dc00e5264ffe901cced1352515fa1 100644 --- a/crates/edit_prediction_cli/src/distill.rs +++ b/crates/edit_prediction_cli/src/distill.rs @@ -6,7 +6,7 @@ use crate::example::Example; pub async fn run_distill(example: &mut Example) -> Result<()> { let predictions = mem::take(&mut example.predictions) .into_iter() - .map(|p| p.actual_patch) + .filter_map(|p| p.actual_patch) .collect(); example.spec.expected_patches = predictions; diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index b7c98c4035e1aaf14f3de484ab3233849a65a2b5..b06ac57c54909d690d5aa65b99760586248e2bf9 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -73,7 +73,8 @@ pub struct ExamplePrompt { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ExamplePrediction { - pub actual_patch: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub actual_patch: Option, pub actual_output: String, pub provider: PredictionProvider, } diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 042ce6b1135f9e919dc8f8c6fbd7bdc66f660d05..c839be804fe9599f1b7a2b077218041ce58e238a 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -6,6 +6,7 @@ mod git; mod headless; mod load_project; mod metrics; +mod parse_output; mod paths; mod predict; mod progress; @@ -130,6 +131,9 @@ enum Command { FormatPrompt(FormatPromptArgs), /// Runs edit prediction Predict(PredictArgs), + /// Parse model outputs (actual_output) into unified diffs (actual_patch). + /// Requires format-prompt to have been run first. Uses provider from prompt. + ParseOutput, /// Computes a score based on actual and expected patches Score(PredictArgs), /// Prepares a distillation dataset by copying expected outputs to @@ -159,6 +163,7 @@ impl Display for Command { Command::Predict(args) => { write!(f, "predict --provider={}", args.provider) } + Command::ParseOutput => write!(f, "parse-output"), Command::Score(args) => { write!(f, "score --provider={}", args.provider) } @@ -601,6 +606,9 @@ fn main() { ) .await?; } + Command::ParseOutput => { + parse_output::run_parse_output(example)?; + } Command::Distill => { run_distill(example).await?; } diff --git a/crates/edit_prediction_cli/src/parse_output.rs b/crates/edit_prediction_cli/src/parse_output.rs new file mode 100644 index 0000000000000000000000000000000000000000..06e8e2dadd61c4e0df136acd14ff03d65ebe2bda --- /dev/null +++ b/crates/edit_prediction_cli/src/parse_output.rs @@ -0,0 +1,234 @@ +use crate::{PredictionProvider, example::Example, format_prompt::TeacherPrompt}; +use anyhow::{Context as _, Result}; +use zeta_prompt::{CURSOR_MARKER, ZetaVersion}; + +pub fn run_parse_output(example: &mut Example) -> Result<()> { + let provider = example + .prompt + .as_ref() + .context("prompt required (run format-prompt first)")? + .provider; + example + .prompt_inputs + .as_ref() + .context("prompt_inputs required")?; + + let parsed_patches: Vec<_> = example + .predictions + .iter() + .enumerate() + .filter(|(_, p)| !p.actual_output.is_empty()) + .map(|(ix, prediction)| { + let actual_patch = + parse_prediction_output(example, &prediction.actual_output, provider); + actual_patch.map(|patch| (ix, patch)) + }) + .collect::>>()?; + + for (ix, actual_patch) in parsed_patches { + example.predictions[ix].actual_patch = Some(actual_patch); + example.predictions[ix].provider = provider; + } + + Ok(()) +} + +pub fn parse_prediction_output( + example: &Example, + actual_output: &str, + provider: PredictionProvider, +) -> Result { + match provider { + PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => { + TeacherPrompt::parse(example, actual_output) + } + PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version), + _ => anyhow::bail!( + "parse-output only supports Teacher and Zeta2 providers, got {:?}", + provider + ), + } +} + +fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result { + let (current_marker, end_marker) = match version { + ZetaVersion::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"), + ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => { + ("<|fim_middle|>current\n", "<|fim_suffix|>") + } + ZetaVersion::V0120GitMergeMarkers => ( + zeta_prompt::v0120_git_merge_markers::START_MARKER, + zeta_prompt::v0120_git_merge_markers::SEPARATOR, + ), + }; + + let start = prompt.find(current_marker).with_context(|| { + format!( + "missing current marker '{}' in prompt", + current_marker.trim() + ) + })? + current_marker.len(); + + let end = prompt[start..] + .find(end_marker) + .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))? + + start; + + let region = &prompt[start..end]; + let region = region.strip_suffix('\n').unwrap_or(region); + Ok(region.replace(CURSOR_MARKER, "")) +} + +fn parse_zeta2_output( + example: &Example, + actual_output: &str, + version: ZetaVersion, +) -> Result { + let prompt = &example.prompt.as_ref().context("prompt required")?.input; + let prompt_inputs = example + .prompt_inputs + .as_ref() + .context("prompt_inputs required")?; + + let old_text = extract_zeta2_current_region(prompt, version)?; + + let mut new_text = actual_output.replace(CURSOR_MARKER, ""); + + if version == ZetaVersion::V0120GitMergeMarkers { + if let Some(stripped) = + new_text.strip_suffix(zeta_prompt::v0120_git_merge_markers::END_MARKER) + { + new_text = stripped.to_string(); + } + } + + let mut old_text_normalized = old_text.clone(); + if !new_text.is_empty() && !new_text.ends_with('\n') { + new_text.push('\n'); + } + if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') { + old_text_normalized.push('\n'); + } + + let old_text_trimmed = old_text.trim_end_matches('\n'); + let (editable_region_offset, _) = prompt_inputs + .content + .match_indices(old_text_trimmed) + .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset)) + .with_context(|| { + format!( + "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}", + old_text_trimmed, &prompt_inputs.content + ) + })?; + + let editable_region_start_line = prompt_inputs.content[..editable_region_offset] + .matches('\n') + .count(); + + let diff = language::unified_diff_with_offsets( + &old_text_normalized, + &new_text, + editable_region_start_line as u32, + editable_region_start_line as u32, + ); + + let formatted_diff = format!( + "--- a/{path}\n+++ b/{path}\n{diff}", + path = example.spec.cursor_path.to_string_lossy(), + ); + + Ok(formatted_diff) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_zeta2_current_region_v0113() { + let prompt = indoc::indoc! {" + <|file_sep|>src/main.rs + <|fim_prefix|> + fn main() { + <|fim_middle|>current + println!(\"hello\"); + <|fim_suffix|> + } + <|fim_middle|>updated + "}; + + let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap(); + assert_eq!(region, "println!(\"hello\");"); + } + + #[test] + fn test_extract_zeta2_current_region_v0112() { + let prompt = indoc::indoc! {" + <|file_sep|>src/main.rs + <|fim_prefix|> + fn main() { + <|fim_suffix|> + } + <|fim_middle|>current + println!(\"hello\"); + <|fim_middle|>updated + "}; + + let region = extract_zeta2_current_region(prompt, ZetaVersion::V0112MiddleAtEnd).unwrap(); + assert_eq!(region, "println!(\"hello\");"); + } + + #[test] + fn test_extract_zeta2_current_region_with_cursor_marker() { + let prompt = indoc::indoc! {" + <|file_sep|>src/main.rs + <|fim_prefix|> + fn main() { + <|fim_middle|>current + print<|user_cursor|>ln!(\"hello\"); + <|fim_suffix|> + } + <|fim_middle|>updated + "}; + + let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap(); + assert_eq!(region, "println!(\"hello\");"); + } + + #[test] + fn test_extract_zeta2_current_region_v0120_git_merge_markers() { + let prompt = indoc::indoc! {" + <|file_sep|>src/main.rs + <|fim_prefix|> + fn main() { + <|fim_suffix|> + } + <|fim_middle|><<<<<<< CURRENT + println!(\"hello\"); + ======= + "}; + + let region = + extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap(); + assert_eq!(region, "println!(\"hello\");"); + } + + #[test] + fn test_extract_zeta2_current_region_v0120_with_cursor_marker() { + let prompt = indoc::indoc! {" + <|file_sep|>src/main.rs + <|fim_prefix|> + fn main() { + <|fim_suffix|> + } + <|fim_middle|><<<<<<< CURRENT + print<|user_cursor|>ln!(\"hello\"); + ======= + "}; + + let region = + extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap(); + assert_eq!(region, "println!(\"hello\");"); + } +} diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 9749675a08ddf59e221204a89603a67e5ea329ec..f0bc99c9dd855c39fccb3068f20150aef407e6b8 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -186,7 +186,7 @@ pub async fn run_prediction( .unwrap() .predictions .push(ExamplePrediction { - actual_patch: String::new(), + actual_patch: None, actual_output: String::new(), provider, }); @@ -204,16 +204,14 @@ pub async fn run_prediction( }) .await?; - let actual_patch = prediction - .and_then(|prediction| { - let prediction = prediction.prediction.ok()?; - prediction - .edit_preview - .as_unified_diff(prediction.snapshot.file(), &prediction.edits) - }) - .unwrap_or_default(); + let actual_patch = prediction.and_then(|prediction| { + let prediction = prediction.prediction.ok()?; + prediction + .edit_preview + .as_unified_diff(prediction.snapshot.file(), &prediction.edits) + }); - let has_prediction = !actual_patch.is_empty(); + let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty()); updated_example .lock() @@ -293,7 +291,7 @@ async fn predict_anthropic( let actual_patch = TeacherPrompt::parse(&example, &actual_output)?; let prediction = ExamplePrediction { - actual_patch, + actual_patch: Some(actual_patch), actual_output, provider: if batched { PredictionProvider::Teacher(version) diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index f467f140b4975163e80ff28c8de4b12807edc034..3f7e0167ee667996fd9487e578f63680f7ca5803 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -3,6 +3,7 @@ use crate::{ example::{Example, ExampleScore}, headless::EpAppState, metrics, + parse_output::parse_prediction_output, predict::run_prediction, progress::{ExampleProgress, Step}, }; @@ -37,7 +38,27 @@ pub async fn run_scoring( progress.set_substatus("computing metrics"); let mut scores = vec![]; for prediction in &example.predictions { - let actual_text = match apply_diff_to_string(&prediction.actual_patch, original_text) { + let actual_patch = match &prediction.actual_patch { + Some(patch) => patch.clone(), + None => { + if prediction.actual_output.is_empty() { + scores.push(ExampleScore { delta_chr_f: 0.0 }); + continue; + } + match parse_prediction_output( + example, + &prediction.actual_output, + prediction.provider, + ) { + Ok(patch) => patch, + Err(_) => { + scores.push(ExampleScore { delta_chr_f: 0.0 }); + continue; + } + } + } + }; + let actual_text = match apply_diff_to_string(&actual_patch, original_text) { Ok(text) => text, Err(_) => { scores.push(ExampleScore { delta_chr_f: 0.0 });