zeta2: Print average length of prompts and outputs (#42885)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/zeta_cli/src/evaluate.rs | 27 ++++++++++++++++++++-------
crates/zeta_cli/src/predict.rs  | 12 ++++++++----
2 files changed, 28 insertions(+), 11 deletions(-)

Detailed changes

crates/zeta_cli/src/evaluate.rs 🔗

@@ -46,7 +46,7 @@ pub async fn run_evaluate(
     }
     let all_tasks = args.example_paths.into_iter().map(|path| {
         let app_state = app_state.clone();
-        let example = NamedExample::load(&path).unwrap();
+        let example = NamedExample::load(&path).expect("Failed to load example");
 
         cx.spawn(async move |cx| {
             let (project, zetas, _edited_buffers) = example
@@ -129,12 +129,15 @@ fn write_aggregated_scores(
         let aggregated_result = EvaluationResult {
             context: Scores::aggregate(successful.iter().map(|r| &r.context)),
             edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
+            prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
+            generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
+                / successful.len(),
         };
 
         writeln!(w, "\n{}", "-".repeat(80))?;
         writeln!(w, "\n## TOTAL SCORES")?;
         writeln!(w, "\n### Success Rate")?;
-        writeln!(w, "{}", aggregated_result)?;
+        writeln!(w, "{:#}", aggregated_result)?;
     }
 
     if successful.len() + failed_count > 1 {
@@ -241,6 +244,8 @@ fn write_eval_result(
 pub struct EvaluationResult {
     pub edit_prediction: Option<Scores>,
     pub context: Scores,
+    pub prompt_len: usize,
+    pub generated_len: usize,
 }
 
 #[derive(Default, Debug)]
@@ -362,15 +367,17 @@ impl EvaluationResult {
         writeln!(f, "### Scores\n")?;
         writeln!(
             f,
-            "                   TP     FP     FN     Precision   Recall     F1"
+            "                   Prompt  Generated  TP     FP     FN     Precision   Recall     F1"
         )?;
         writeln!(
             f,
-            "──────────────────────────────────────────────────────────────────"
+            "────────────────────────────────────────────────────────────────────────────────────"
         )?;
         writeln!(
             f,
-            "Context Retrieval  {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+            "Context Retrieval  {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+            "",
+            "",
             self.context.true_positives,
             self.context.false_positives,
             self.context.false_negatives,
@@ -381,7 +388,9 @@ impl EvaluationResult {
         if let Some(edit_prediction) = &self.edit_prediction {
             writeln!(
                 f,
-                "Edit Prediction    {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+                "Edit Prediction    {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+                self.prompt_len,
+                self.generated_len,
                 edit_prediction.true_positives,
                 edit_prediction.false_positives,
                 edit_prediction.false_negatives,
@@ -395,7 +404,11 @@ impl EvaluationResult {
 }
 
 pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
-    let mut eval_result = EvaluationResult::default();
+    let mut eval_result = EvaluationResult {
+        prompt_len: preds.prompt_len,
+        generated_len: preds.generated_len,
+        ..Default::default()
+    };
 
     let actual_context_lines: HashSet<_> = preds
         .excerpts

crates/zeta_cli/src/predict.rs 🔗

@@ -179,13 +179,12 @@ pub async fn zeta2_predict(
                     zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                         let prediction_started_at = Instant::now();
                         start_time.get_or_insert(prediction_started_at);
-                        fs::write(
-                            example_run_dir.join("prediction_prompt.md"),
-                            &request.local_prompt.unwrap_or_default(),
-                        )?;
+                        let prompt = request.local_prompt.unwrap_or_default();
+                        fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
 
                         {
                             let mut result = result.lock().unwrap();
+                            result.prompt_len = prompt.chars().count();
 
                             for included_file in request.request.included_files {
                                 let insertions =
@@ -217,6 +216,7 @@ pub async fn zeta2_predict(
                         fs::write(example_run_dir.join("prediction_response.md"), &response)?;
 
                         let mut result = result.lock().unwrap();
+                        result.generated_len = response.chars().count();
 
                         if !use_expected_context {
                             result.planning_search_time =
@@ -411,6 +411,8 @@ pub struct PredictionDetails {
     pub prediction_time: Duration,
     pub total_time: Duration,
     pub run_example_dir: PathBuf,
+    pub prompt_len: usize,
+    pub generated_len: usize,
 }
 
 impl PredictionDetails {
@@ -424,6 +426,8 @@ impl PredictionDetails {
             prediction_time: Default::default(),
             total_time: Default::default(),
             run_example_dir,
+            prompt_len: 0,
+            generated_len: 0,
         }
     }