Detailed changes
@@ -6,7 +6,7 @@ use crate::example::Example;
pub async fn run_distill(example: &mut Example) -> Result<()> {
let predictions = mem::take(&mut example.predictions)
.into_iter()
- .map(|p| p.actual_patch)
+ .filter_map(|p| p.actual_patch)
.collect();
example.spec.expected_patches = predictions;
@@ -73,7 +73,8 @@ pub struct ExamplePrompt {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
- pub actual_patch: String,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub actual_patch: Option<String>,
pub actual_output: String,
pub provider: PredictionProvider,
}
@@ -6,6 +6,7 @@ mod git;
mod headless;
mod load_project;
mod metrics;
+mod parse_output;
mod paths;
mod predict;
mod progress;
@@ -130,6 +131,9 @@ enum Command {
FormatPrompt(FormatPromptArgs),
/// Runs edit prediction
Predict(PredictArgs),
+ /// Parse model outputs (actual_output) into unified diffs (actual_patch).
+ /// Requires format-prompt to have been run first. Uses provider from prompt.
+ ParseOutput,
/// Computes a score based on actual and expected patches
Score(PredictArgs),
/// Prepares a distillation dataset by copying expected outputs to
@@ -159,6 +163,7 @@ impl Display for Command {
Command::Predict(args) => {
write!(f, "predict --provider={}", args.provider)
}
+ Command::ParseOutput => write!(f, "parse-output"),
Command::Score(args) => {
write!(f, "score --provider={}", args.provider)
}
@@ -601,6 +606,9 @@ fn main() {
)
.await?;
}
+ Command::ParseOutput => {
+ parse_output::run_parse_output(example)?;
+ }
Command::Distill => {
run_distill(example).await?;
}
@@ -0,0 +1,234 @@
+use crate::{PredictionProvider, example::Example, format_prompt::TeacherPrompt};
+use anyhow::{Context as _, Result};
+use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
+
+/// Re-parses the raw model output of every prediction in `example` into a patch.
+///
+/// For each prediction whose `actual_output` is non-empty, the output is parsed with
+/// [`parse_prediction_output`] using the provider recorded on `example.prompt`. The
+/// resulting patch is stored in `actual_patch` and the prediction's `provider` is
+/// overwritten with the prompt's provider. Predictions with an empty `actual_output`
+/// are left untouched.
+///
+/// # Errors
+///
+/// Fails if `example.prompt` or `example.prompt_inputs` is missing, or if any output
+/// cannot be parsed; in the latter case the error names the index of the offending
+/// prediction. Nothing is written back unless every output parses.
+pub fn run_parse_output(example: &mut Example) -> Result<()> {
+    let provider = example
+        .prompt
+        .as_ref()
+        .context("prompt required (run format-prompt first)")?
+        .provider;
+    anyhow::ensure!(example.prompt_inputs.is_some(), "prompt_inputs required");
+
+    // Parse everything first while `example` is only borrowed immutably; the
+    // results are applied in a second pass so a failure leaves `example` unchanged.
+    let parsed_patches = example
+        .predictions
+        .iter()
+        .enumerate()
+        .filter(|(_, prediction)| !prediction.actual_output.is_empty())
+        .map(|(ix, prediction)| {
+            parse_prediction_output(example, &prediction.actual_output, provider)
+                .with_context(|| format!("failed to parse output of prediction {ix}"))
+                .map(|patch| (ix, patch))
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    for (ix, actual_patch) in parsed_patches {
+        let prediction = &mut example.predictions[ix];
+        prediction.actual_patch = Some(actual_patch);
+        prediction.provider = provider;
+    }
+
+    Ok(())
+}
+
+/// Converts a model's raw `actual_output` into a patch, dispatching on `provider`.
+///
+/// `Teacher` and `TeacherNonBatching` outputs are handed to `TeacherPrompt::parse`;
+/// `Zeta2` outputs are handed to `parse_zeta2_output` together with the prompt version.
+///
+/// # Errors
+///
+/// Returns an error for any other provider, or whatever error the selected parser
+/// reports.
+pub fn parse_prediction_output(
+    example: &Example,
+    actual_output: &str,
+    provider: PredictionProvider,
+) -> Result<String> {
+    let patch = match provider {
+        PredictionProvider::Zeta2(version) => {
+            parse_zeta2_output(example, actual_output, version)?
+        }
+        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
+            TeacherPrompt::parse(example, actual_output)?
+        }
+        _ => anyhow::bail!(
+            "parse-output only supports Teacher and Zeta2 providers, got {:?}",
+            provider
+        ),
+    };
+    Ok(patch)
+}
+
+/// Extracts the text of the current editable region from a formatted Zeta2 prompt.
+///
+/// The region is the text between a version-specific start marker and the first
+/// occurrence of the matching end marker after it. One trailing newline is dropped
+/// and every occurrence of `CURSOR_MARKER` is removed from the result.
+///
+/// # Errors
+///
+/// Returns an error naming the marker if either marker is missing from `prompt`.
+fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result<String> {
+    let (region_start, region_end) = match version {
+        ZetaVersion::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
+        ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
+            ("<|fim_middle|>current\n", "<|fim_suffix|>")
+        }
+        ZetaVersion::V0120GitMergeMarkers => (
+            zeta_prompt::v0120_git_merge_markers::START_MARKER,
+            zeta_prompt::v0120_git_merge_markers::SEPARATOR,
+        ),
+    };
+
+    // Keep only what follows the first start marker, then cut at the first end marker.
+    let (_, after_start) = prompt.split_once(region_start).with_context(|| {
+        format!(
+            "missing current marker '{}' in prompt",
+            region_start.trim()
+        )
+    })?;
+    let (region, _) = after_start
+        .split_once(region_end)
+        .with_context(|| format!("missing end marker '{}' in prompt", region_end.trim()))?;
+
+    let region = region.strip_suffix('\n').unwrap_or(region);
+    Ok(region.replace(CURSOR_MARKER, ""))
+}
+
+/// Builds a unified diff for a Zeta2 prediction from the model's raw output.
+///
+/// The "old" side is the current editable region recovered from the formatted prompt
+/// (`example.prompt.input`) via `extract_zeta2_current_region`; the "new" side is
+/// `actual_output` with every `CURSOR_MARKER` removed. For
+/// `ZetaVersion::V0120GitMergeMarkers` a trailing `END_MARKER` is also stripped from
+/// the output. The region is located inside `prompt_inputs.content` (choosing the
+/// occurrence closest to `cursor_offset`) so the hunk can be offset by the number of
+/// lines that precede it. The result is prefixed with `--- a/<path>` / `+++ b/<path>`
+/// headers built from `example.spec.cursor_path`.
+///
+/// # Errors
+///
+/// Fails if `example.prompt` or `example.prompt_inputs` is missing, if the prompt
+/// lacks the expected region markers, if the region text cannot be found in the
+/// content, or if the region's starting line does not fit in a `u32`.
+fn parse_zeta2_output(
+    example: &Example,
+    actual_output: &str,
+    version: ZetaVersion,
+) -> Result<String> {
+    let prompt = &example.prompt.as_ref().context("prompt required")?.input;
+    let prompt_inputs = example
+        .prompt_inputs
+        .as_ref()
+        .context("prompt_inputs required")?;
+
+    let mut old_text = extract_zeta2_current_region(prompt, version)?;
+    let mut new_text = actual_output.replace(CURSOR_MARKER, "");
+
+    if version == ZetaVersion::V0120GitMergeMarkers {
+        if let Some(stripped) =
+            new_text.strip_suffix(zeta_prompt::v0120_git_merge_markers::END_MARKER)
+        {
+            new_text = stripped.to_string();
+        }
+    }
+
+    // Locate the region before normalising newlines: the search ignores trailing
+    // newlines, and the occurrence nearest the cursor wins when the text repeats.
+    let old_text_trimmed = old_text.trim_end_matches('\n');
+    let (editable_region_offset, _) = prompt_inputs
+        .content
+        .match_indices(old_text_trimmed)
+        .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
+        .with_context(|| {
+            format!(
+                "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
+                old_text_trimmed, &prompt_inputs.content
+            )
+        })?;
+
+    // The number of newlines before the region is the line offset of the hunk.
+    // Use a checked conversion instead of a silently truncating `as` cast.
+    let newlines_before_region = prompt_inputs.content[..editable_region_offset]
+        .matches('\n')
+        .count();
+    let editable_region_start_line = u32::try_from(newlines_before_region)
+        .context("editable region start line does not fit in u32")?;
+
+    // Both sides of the diff must end with a newline unless they are empty.
+    for text in [&mut old_text, &mut new_text] {
+        if !text.is_empty() && !text.ends_with('\n') {
+            text.push('\n');
+        }
+    }
+
+    let diff = language::unified_diff_with_offsets(
+        &old_text,
+        &new_text,
+        editable_region_start_line,
+        editable_region_start_line,
+    );
+
+    Ok(format!(
+        "--- a/{path}\n+++ b/{path}\n{diff}",
+        path = example.spec.cursor_path.to_string_lossy(),
+    ))
+}
+
+#[cfg(test)]
+mod tests { // Unit tests for `extract_zeta2_current_region`, one prompt layout per test.
+ use super::*;
+
+ #[test]
+ fn test_extract_zeta2_current_region_v0113() { // V0113Ordered: region sits between `<|fim_middle|>current\n` and `<|fim_suffix|>`.
+ let prompt = indoc::indoc! {"
+ <|file_sep|>src/main.rs
+ <|fim_prefix|>
+ fn main() {
+ <|fim_middle|>current
+ println!(\"hello\");
+ <|fim_suffix|>
+ }
+ <|fim_middle|>updated
+ "};
+
+ let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap();
+ assert_eq!(region, "println!(\"hello\");");
+ }
+
+ #[test]
+ fn test_extract_zeta2_current_region_v0112() { // V0112MiddleAtEnd: region sits between `<|fim_middle|>current\n` and `<|fim_middle|>updated`.
+ let prompt = indoc::indoc! {"
+ <|file_sep|>src/main.rs
+ <|fim_prefix|>
+ fn main() {
+ <|fim_suffix|>
+ }
+ <|fim_middle|>current
+ println!(\"hello\");
+ <|fim_middle|>updated
+ "};
+
+ let region = extract_zeta2_current_region(prompt, ZetaVersion::V0112MiddleAtEnd).unwrap();
+ assert_eq!(region, "println!(\"hello\");");
+ }
+
+ #[test]
+ fn test_extract_zeta2_current_region_with_cursor_marker() { // Expects the `<|user_cursor|>` marker inside the region to be removed (presumably the value of CURSOR_MARKER; confirm in zeta_prompt).
+ let prompt = indoc::indoc! {"
+ <|file_sep|>src/main.rs
+ <|fim_prefix|>
+ fn main() {
+ <|fim_middle|>current
+ print<|user_cursor|>ln!(\"hello\");
+ <|fim_suffix|>
+ }
+ <|fim_middle|>updated
+ "};
+
+ let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap();
+ assert_eq!(region, "println!(\"hello\");");
+ }
+
+ #[test]
+ fn test_extract_zeta2_current_region_v0120_git_merge_markers() { // V0120GitMergeMarkers: region sits between START_MARKER and SEPARATOR (fixture suggests `<<<<<<< CURRENT` and `=======`; confirm in zeta_prompt).
+ let prompt = indoc::indoc! {"
+ <|file_sep|>src/main.rs
+ <|fim_prefix|>
+ fn main() {
+ <|fim_suffix|>
+ }
+ <|fim_middle|><<<<<<< CURRENT
+ println!(\"hello\");
+ =======
+ "};
+
+ let region =
+ extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap();
+ assert_eq!(region, "println!(\"hello\");");
+ }
+
+ #[test]
+ fn test_extract_zeta2_current_region_v0120_with_cursor_marker() { // Same git-merge-marker layout, additionally expecting the cursor marker to be removed.
+ let prompt = indoc::indoc! {"
+ <|file_sep|>src/main.rs
+ <|fim_prefix|>
+ fn main() {
+ <|fim_suffix|>
+ }
+ <|fim_middle|><<<<<<< CURRENT
+ print<|user_cursor|>ln!(\"hello\");
+ =======
+ "};
+
+ let region =
+ extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap();
+ assert_eq!(region, "println!(\"hello\");");
+ }
+}
@@ -186,7 +186,7 @@ pub async fn run_prediction(
.unwrap()
.predictions
.push(ExamplePrediction {
- actual_patch: String::new(),
+ actual_patch: None,
actual_output: String::new(),
provider,
});
@@ -204,16 +204,14 @@ pub async fn run_prediction(
})
.await?;
- let actual_patch = prediction
- .and_then(|prediction| {
- let prediction = prediction.prediction.ok()?;
- prediction
- .edit_preview
- .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
- })
- .unwrap_or_default();
+ let actual_patch = prediction.and_then(|prediction| {
+ let prediction = prediction.prediction.ok()?;
+ prediction
+ .edit_preview
+ .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
+ });
- let has_prediction = !actual_patch.is_empty();
+ let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());
updated_example
.lock()
@@ -293,7 +291,7 @@ async fn predict_anthropic(
let actual_patch = TeacherPrompt::parse(&example, &actual_output)?;
let prediction = ExamplePrediction {
- actual_patch,
+ actual_patch: Some(actual_patch),
actual_output,
provider: if batched {
PredictionProvider::Teacher(version)
@@ -3,6 +3,7 @@ use crate::{
example::{Example, ExampleScore},
headless::EpAppState,
metrics,
+ parse_output::parse_prediction_output,
predict::run_prediction,
progress::{ExampleProgress, Step},
};
@@ -37,7 +38,27 @@ pub async fn run_scoring(
progress.set_substatus("computing metrics");
let mut scores = vec![];
for prediction in &example.predictions {
- let actual_text = match apply_diff_to_string(&prediction.actual_patch, original_text) {
+ let actual_patch = match &prediction.actual_patch {
+ Some(patch) => patch.clone(),
+ None => {
+ if prediction.actual_output.is_empty() {
+ scores.push(ExampleScore { delta_chr_f: 0.0 });
+ continue;
+ }
+ match parse_prediction_output(
+ example,
+ &prediction.actual_output,
+ prediction.provider,
+ ) {
+ Ok(patch) => patch,
+ Err(_) => {
+ scores.push(ExampleScore { delta_chr_f: 0.0 });
+ continue;
+ }
+ }
+ }
+ };
+ let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
Ok(text) => text,
Err(_) => {
scores.push(ExampleScore { delta_chr_f: 0.0 });