@@ -31,6 +31,8 @@ pub struct EvaluateArguments {
cache: CacheMode,
#[clap(short, long, default_value_t = 1, alias = "repeat")]
repetitions: u16,
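+ /// Skip edit-prediction scoring; only context retrieval is evaluated.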
+ #[clap(long)]
+ skip_prediction: bool,
}
pub async fn run_evaluate(
@@ -66,6 +68,7 @@ pub async fn run_evaluate(
zeta,
args.prompt_format,
args.use_expected_context,
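+ // `predict`: score the edit prediction unless --skip-prediction was passed.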
+ !args.skip_prediction,
args.cache,
cx,
)
@@ -118,9 +121,14 @@ fn write_aggregated_scores(
}
if successful.len() > 1 {
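+ // Aggregate edit-prediction scores only when at least one successful run produced them.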
+ let mut edit_predictions = successful
+ .iter()
+ .filter_map(|r| r.edit_prediction.as_ref())
+ .peekable();
+ let has_edit_predictions = edit_predictions.peek().is_some();
let aggregated_result = EvaluationResult {
context: Scores::aggregate(successful.iter().map(|r| &r.context)),
- edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
+ edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
};
writeln!(w, "\n{}", "-".repeat(80))?;
@@ -149,6 +157,7 @@ pub async fn run_evaluate_one(
zeta: Entity<Zeta>,
prompt_format: PromptFormat,
use_expected_context: bool,
+ predict: bool,
cache_mode: CacheMode,
cx: &mut AsyncApp,
) -> Result<EvaluationResult> {
@@ -164,7 +173,7 @@ pub async fn run_evaluate_one(
)
.await?;
- let evaluation_result = evaluate(&example.example, &predict_result);
+ let evaluation_result = evaluate(&example.example, &predict_result, predict);
if repetition_ix.is_none() {
write_eval_result(
@@ -173,6 +182,7 @@ pub async fn run_evaluate_one(
&evaluation_result,
&mut std::io::stdout(),
std::io::stdout().is_terminal(),
+ predict,
)?;
}
@@ -185,6 +195,7 @@ pub async fn run_evaluate_one(
&evaluation_result,
&mut results_file,
false,
+ predict,
)
.log_err();
}
@@ -198,25 +209,29 @@ fn write_eval_result(
evaluation_result: &EvaluationResult,
out: &mut impl Write,
use_color: bool,
+ predict: bool,
) -> Result<()> {
- writeln!(
- out,
- "## Expected edit prediction:\n\n```diff\n{}\n```\n",
- compare_diffs(
- &example.example.expected_patch,
- &predictions.diff,
- use_color
- )
- )?;
- writeln!(
- out,
- "## Actual edit prediction:\n\n```diff\n{}\n```\n",
- compare_diffs(
- &predictions.diff,
- &example.example.expected_patch,
- use_color
- )
- )?;
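+ // With --skip-prediction there is no prediction to diff against the expected patch.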
+ if predict {
+ writeln!(
+ out,
+ "## Expected edit prediction:\n\n```diff\n{}\n```\n",
+ compare_diffs(
+ &example.example.expected_patch,
+ &predictions.diff,
+ use_color
+ )
+ )?;
+ writeln!(
+ out,
+ "## Actual edit prediction:\n\n```diff\n{}\n```\n",
+ compare_diffs(
+ &predictions.diff,
+ &example.example.expected_patch,
+ use_color
+ )
+ )?;
+ }
+
writeln!(out, "{:#}", evaluation_result)?;
anyhow::Ok(())
@@ -224,7 +239,7 @@ fn write_eval_result(
#[derive(Debug, Default)]
pub struct EvaluationResult {
- pub edit_prediction: Scores,
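+ /// `None` when edit-prediction scoring was skipped.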
+ pub edit_prediction: Option<Scores>,
pub context: Scores,
}
@@ -328,13 +343,19 @@ impl EvaluationResult {
r#"
### Context Scores
{}
-
-### Edit Prediction Scores
-{}
"#,
self.context.to_markdown(),
- self.edit_prediction.to_markdown()
- )
+ )?;
+ if let Some(prediction) = &self.edit_prediction {
+ write!(
+ f,
+ r#"
+ ### Edit Prediction Scores
+ {}"#,
+ prediction.to_markdown()
+ )?;
+ }
+ Ok(())
}
fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -357,20 +378,23 @@ impl EvaluationResult {
self.context.recall() * 100.0,
self.context.f1_score() * 100.0
)?;
- writeln!(
- f,
- "Edit Prediction {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
- self.edit_prediction.true_positives,
- self.edit_prediction.false_positives,
- self.edit_prediction.false_negatives,
- self.edit_prediction.precision() * 100.0,
- self.edit_prediction.recall() * 100.0,
- self.edit_prediction.f1_score() * 100.0
- )
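+ // Omit the edit-prediction row when scoring was skipped.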
+ if let Some(edit_prediction) = &self.edit_prediction {
+ writeln!(
+ f,
+ "Edit Prediction {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+ edit_prediction.true_positives,
+ edit_prediction.false_positives,
+ edit_prediction.false_negatives,
+ edit_prediction.precision() * 100.0,
+ edit_prediction.recall() * 100.0,
+ edit_prediction.f1_score() * 100.0
+ )?;
+ }
+ Ok(())
}
}
-pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
+pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
let mut eval_result = EvaluationResult::default();
let actual_context_lines: HashSet<_> = preds
@@ -420,24 +444,27 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul
eval_result.context.false_positives = false_positive_lines.len();
- // todo: alternatives for patches
- let expected_patch_lines = example
- .expected_patch
- .lines()
- .map(DiffLine::parse)
- .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
- .map(|line| line.to_string())
- .collect();
-
- let actual_patch_lines = preds
- .diff
- .lines()
- .map(DiffLine::parse)
- .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
- .map(|line| line.to_string())
- .collect();
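+ // Only score the predicted patch when a prediction was requested.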
+ if predict {
+ // todo: alternatives for patches
+ let expected_patch_lines = example
+ .expected_patch
+ .lines()
+ .map(DiffLine::parse)
+ .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
+ .map(|line| line.to_string())
+ .collect();
+
+ let actual_patch_lines = preds
+ .diff
+ .lines()
+ .map(DiffLine::parse)
+ .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
+ .map(|line| line.to_string())
+ .collect();
+
+ eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
+ }
- eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
eval_result
}