diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index 495ca26f97af5f2c2c1dc50ea339881853d9ebbc..196f4f96d99b64aed2ff3ae2d7a9897295a60b29 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -82,6 +82,10 @@ pub struct ExamplePrediction { #[serde(default, skip_serializing_if = "Option::is_none")] pub error: Option, pub provider: PredictionProvider, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cumulative_logprob: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub avg_logprob: Option, } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -166,6 +170,10 @@ pub struct ExampleScore { pub inserted_tokens: usize, #[serde(default)] pub deleted_tokens: usize, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cumulative_logprob: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub avg_logprob: Option, } impl Example { diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 9f70861b5ef7298141441ec09606fa77e341cbfd..df797b0abaa4933e73e40b746797ffb5581d7f79 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -263,6 +263,8 @@ pub async fn run_prediction( actual_cursor: None, error: None, provider, + cumulative_logprob: None, + avg_logprob: None, }); step_progress.set_substatus("requesting prediction"); @@ -455,6 +457,8 @@ async fn predict_anthropic( _ => PredictionProvider::TeacherNonBatching(backend), } }, + cumulative_logprob: None, + avg_logprob: None, }; example.predictions.push(prediction); @@ -572,6 +576,8 @@ async fn predict_openai( _ => PredictionProvider::TeacherNonBatching(backend), } }, + cumulative_logprob: None, + avg_logprob: None, }; example.predictions.push(prediction); @@ -656,6 +662,8 @@ pub async fn predict_baseten( actual_cursor, error: None, provider: PredictionProvider::Baseten(format), + cumulative_logprob: None, + avg_logprob: None, }; example.predictions.push(prediction); diff --git a/crates/edit_prediction_cli/src/repair.rs b/crates/edit_prediction_cli/src/repair.rs index 9d891314bc62a44e730b584cea3423df665dc381..a0c4242748c9ad83c3b0fbe9e70a4b132ac75c4d 100644 --- a/crates/edit_prediction_cli/src/repair.rs +++ b/crates/edit_prediction_cli/src/repair.rs @@ -426,6 +426,8 @@ pub async fn run_repair( actual_cursor, error: err, provider: PredictionProvider::Repair, + cumulative_logprob: None, + avg_logprob: None, }); Ok(()) diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index b6f745114f6dd2a091b95b724ee53869a04a8c4e..d75cf55e85b198bc28469e83d8f9209a8a59a83f 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -78,6 +78,8 @@ pub async fn run_scoring( has_isolated_whitespace_changes: false, inserted_tokens: 0, deleted_tokens: 0, + cumulative_logprob: None, + avg_logprob: None, }; let cursor_path = example.spec.cursor_path.as_ref(); @@ -189,6 +191,8 @@ pub async fn run_scoring( has_isolated_whitespace_changes, inserted_tokens: token_changes.inserted_tokens, deleted_tokens: token_changes.deleted_tokens, + cumulative_logprob: prediction.cumulative_logprob, + avg_logprob: prediction.avg_logprob, }); }