diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 6a7c6232d08b15fccacdd80a446432e453a80e20..d9d9c2243d81640a55133843669514d551f64902 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -586,10 +586,11 @@ impl EditPredictionStore { pub fn edit_history_for_project( &self, project: &Entity, + cx: &App, ) -> Vec> { self.projects .get(&project.entity_id()) - .map(|project_state| project_state.events.iter().cloned().collect()) + .map(|project_state| project_state.events(cx)) .unwrap_or_default() } diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index 8586e6caaea1fdc9c865ddba8894f680d766b4a9..9706e2b9ecd03f6e8ba592210722725f420643d3 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -228,13 +228,16 @@ pub fn zeta2_prompt_input( } #[cfg(feature = "cli-support")] -pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> String { - eprintln!("{}", patch); - eprintln!("---------------------"); - eprintln!("{}", input.cursor_excerpt); - crate::udiff::apply_diff_to_string( - patch, - &input.cursor_excerpt[input.editable_range_in_excerpt.clone()], - ) - .unwrap() +pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result { + let text = &input.cursor_excerpt; + let editable_region = input.editable_range_in_excerpt.clone(); + let old_prefix = &text[..editable_region.start]; + let old_suffix = &text[editable_region.end..]; + + let new = crate::udiff::apply_diff_to_string(patch, text)?; + if !new.starts_with(old_prefix) || !new.ends_with(old_suffix) { + anyhow::bail!("Patch shouldn't affect text outside of editable region"); + } + + Ok(new[editable_region.start..new.len() - old_suffix.len()].to_string()) } diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index f8fd9b2023a84abcf59bcb5ba54d2d228a0c6484..c778b708b701492b0cc85a0030a1e9d090ce0724 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -44,7 +44,7 @@ pub async fn run_format_prompt( let state = example.state.as_ref().context("state must be set")?; let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?; let project = state.project.clone(); - let (_, input) = ep_store.update(&mut cx, |ep_store, _cx| { + let (_, input) = ep_store.update(&mut cx, |ep_store, cx| { anyhow::Ok(zeta2_prompt_input( &snapshot, example @@ -53,7 +53,7 @@ pub async fn run_format_prompt( .context("context must be set")? .files .clone(), - ep_store.edit_history_for_project(&project), + ep_store.edit_history_for_project(&project, cx), example.cursor_path.clone(), example .buffer @@ -63,7 +63,7 @@ pub async fn run_format_prompt( )) })??; let prompt = format_zeta_prompt(&input); - let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone()); + let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone())?; example.prompt = Some(ExamplePrompt { input: prompt, expected_output,