@@ -2,36 +2,30 @@ use crate::{
PredictionProvider,
example::{ActualCursor, Example},
format_prompt::TeacherPrompt,
+ repair,
};
use anyhow::{Context as _, Result};
use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
pub fn run_parse_output(example: &mut Example) -> Result<()> {
- let provider = example
- .prompt
- .as_ref()
- .context("prompt required (run format-prompt first)")?
- .provider;
example
.prompt_inputs
.as_ref()
.context("prompt_inputs required")?;
- let parsed_patches: Vec<_> = example
+ let to_parse: Vec<_> = example
.predictions
.iter()
.enumerate()
.filter(|(_, p)| !p.actual_output.is_empty())
- .map(|(ix, prediction)| {
- let result = parse_prediction_output(example, &prediction.actual_output, provider);
- result.map(|(patch, cursor_offset)| (ix, patch, cursor_offset))
- })
- .collect::<Result<Vec<_>>>()?;
+ .map(|(ix, p)| (ix, p.actual_output.clone(), p.provider))
+ .collect();
- for (ix, actual_patch, actual_cursor) in parsed_patches {
+ for (ix, actual_output, provider) in to_parse {
+ let (actual_patch, actual_cursor) =
+ parse_prediction_output(example, &actual_output, provider)?;
example.predictions[ix].actual_patch = Some(actual_patch);
example.predictions[ix].actual_cursor = actual_cursor;
- example.predictions[ix].provider = provider;
}
Ok(())
@@ -47,6 +41,7 @@ pub fn parse_prediction_output(
TeacherPrompt::parse(example, actual_output)
}
PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
+ PredictionProvider::Repair => repair::parse(example, actual_output),
_ => anyhow::bail!(
"parse-output only supports Teacher and Zeta2 providers, got {:?}",
provider
@@ -9,7 +9,7 @@
use crate::{
BatchProvider, PredictionProvider,
anthropic_client::AnthropicClient,
- example::{Example, ExamplePrediction},
+ example::{ActualCursor, Example, ExamplePrediction},
format_prompt::{TeacherPrompt, extract_cursor_excerpt_from_example, extract_last_codeblock},
openai_client::OpenAiClient,
parse_output::run_parse_output,
@@ -233,6 +233,25 @@ pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
false
}
+/// Parse repair model output into a patch and optional cursor.
+///
+/// Handles the `KEEP_PREVIOUS` sentinel by copying the teacher's prediction,
+/// and delegates normal output to `TeacherPrompt::parse`.
+pub fn parse(example: &Example, actual_output: &str) -> Result<(String, Option<ActualCursor>)> {
+ let last_codeblock = extract_last_codeblock(actual_output);
+ if last_codeblock.trim() == KEEP_PREVIOUS {
+ let original = example
+ .predictions
+ .first()
+ .context("no original prediction to keep")?;
+ let patch = original.actual_patch.clone().unwrap_or_default();
+ let cursor = original.actual_cursor.clone();
+ return Ok((patch, cursor));
+ }
+
+ TeacherPrompt::parse(example, actual_output)
+}
+
/// Check if an example already has a successful repair prediction.
fn has_successful_repair(example: &Example) -> bool {
example
@@ -354,37 +373,22 @@ pub async fn run_repair(
}
};
- let last_codeblock = extract_last_codeblock(&response);
- if last_codeblock.trim() == KEEP_PREVIOUS {
- let original = example
- .predictions
- .first()
- .context("no original prediction to keep")?;
- example.predictions.push(ExamplePrediction {
- actual_patch: original.actual_patch.clone(),
- actual_output: response,
- actual_cursor: original.actual_cursor.clone(),
- error: None,
- provider: PredictionProvider::Repair,
- });
- } else {
- let parse_result = TeacherPrompt::parse(example, &response);
- let err = parse_result
- .as_ref()
- .err()
- .map(|e| format!("Failed to parse repair response: {}", e));
-
- let (actual_patch, actual_cursor) = parse_result.ok().unzip();
- let actual_cursor = actual_cursor.flatten();
-
- example.predictions.push(ExamplePrediction {
- actual_patch,
- actual_output: response,
- actual_cursor,
- error: err,
- provider: PredictionProvider::Repair,
- });
- }
+ let parse_result = parse(example, &response);
+ let err = parse_result
+ .as_ref()
+ .err()
+ .map(|e| format!("Failed to parse repair response: {}", e));
+
+ let (actual_patch, actual_cursor) = parse_result.ok().unzip();
+ let actual_cursor = actual_cursor.flatten();
+
+ example.predictions.push(ExamplePrediction {
+ actual_patch,
+ actual_output: response,
+ actual_cursor,
+ error: err,
+ provider: PredictionProvider::Repair,
+ });
Ok(())
}