Detailed changes
@@ -82,6 +82,10 @@ pub struct ExamplePrediction {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
pub provider: PredictionProvider,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub cumulative_logprob: Option<f64>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub avg_logprob: Option<f64>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -166,6 +170,10 @@ pub struct ExampleScore {
pub inserted_tokens: usize,
#[serde(default)]
pub deleted_tokens: usize,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub cumulative_logprob: Option<f64>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub avg_logprob: Option<f64>,
}
impl Example {
@@ -263,6 +263,8 @@ pub async fn run_prediction(
actual_cursor: None,
error: None,
provider,
+ cumulative_logprob: None,
+ avg_logprob: None,
});
step_progress.set_substatus("requesting prediction");
@@ -455,6 +457,8 @@ async fn predict_anthropic(
_ => PredictionProvider::TeacherNonBatching(backend),
}
},
+ cumulative_logprob: None,
+ avg_logprob: None,
};
example.predictions.push(prediction);
@@ -572,6 +576,8 @@ async fn predict_openai(
_ => PredictionProvider::TeacherNonBatching(backend),
}
},
+ cumulative_logprob: None,
+ avg_logprob: None,
};
example.predictions.push(prediction);
@@ -656,6 +662,8 @@ pub async fn predict_baseten(
actual_cursor,
error: None,
provider: PredictionProvider::Baseten(format),
+ cumulative_logprob: None,
+ avg_logprob: None,
};
example.predictions.push(prediction);
@@ -426,6 +426,8 @@ pub async fn run_repair(
actual_cursor,
error: err,
provider: PredictionProvider::Repair,
+ cumulative_logprob: None,
+ avg_logprob: None,
});
Ok(())
@@ -78,6 +78,8 @@ pub async fn run_scoring(
has_isolated_whitespace_changes: false,
inserted_tokens: 0,
deleted_tokens: 0,
+ cumulative_logprob: None,
+ avg_logprob: None,
};
let cursor_path = example.spec.cursor_path.as_ref();
@@ -189,6 +191,8 @@ pub async fn run_scoring(
has_isolated_whitespace_changes,
inserted_tokens: token_changes.inserted_tokens,
deleted_tokens: token_changes.deleted_tokens,
+ cumulative_logprob: prediction.cumulative_logprob,
+ avg_logprob: prediction.avg_logprob,
});
}