@@ -65,7 +65,8 @@ pub struct ExampleState {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
pub input: String,
- pub expected_output: String,
+ #[serde(default)]
+ pub expected_output: Option<String>,
pub rejected_output: Option<String>, // For DPO
#[serde(default)]
pub prefill: Option<String>,
@@ -43,7 +43,7 @@ pub async fn run_format_prompt(
let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
example.prompt = Some(ExamplePrompt {
input: prompt,
- expected_output: String::new(),
+ expected_output: None,
rejected_output: None,
prefill: None,
provider: args.provider,
@@ -61,7 +61,7 @@ pub async fn run_format_prompt(
TeacherMultiRegionPrompt::format_prompt(example, editable_range, context_range);
example.prompt = Some(ExamplePrompt {
input: prompt,
- expected_output: String::new(),
+ expected_output: None,
rejected_output: None,
prefill: None,
provider: args.provider,
@@ -85,8 +85,7 @@ pub async fn run_format_prompt(
zeta_format,
)
.ok()
- })
- .unwrap_or_default();
+ });
let rejected_output = example.spec.rejected_patch.as_ref().and_then(|patch| {
zeta2_output_for_patch(prompt_inputs, patch, None, zeta_format).ok()
@@ -195,7 +195,7 @@ pub async fn run_prediction(
if matches!(provider, PredictionProvider::Zeta2(_)) {
updated_example.prompt.get_or_insert(ExamplePrompt {
input: prompt,
- expected_output: String::new(),
+ expected_output: None,
rejected_output: None,
provider,
prefill: None,
@@ -1674,7 +1674,7 @@ fn build_rejected_example(
example.spec.rejected_patch = Some(rejected_patch);
example.prompt = prompt.map(|prompt| ExamplePrompt {
input: prompt,
- expected_output: String::new(),
+ expected_output: None,
rejected_output: Some(output),
prefill: None,
provider: PredictionProvider::default(),