From b64256552659d7cf54ff23e43ea2f11311b6e500 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Sat, 31 Jan 2026 09:16:32 -0700 Subject: [PATCH] Allow zeta2 to predict next cursor position along with edits (#47916) * [x] capture and store teacher model's predicted cursor position * [x] provide cursor position to student during distillation * [x] eval cursor positions * [x] parse and apply cursor position predictions at runtime Release Notes: - N/A --- crates/codestral/src/codestral.rs | 1 + .../src/copilot_edit_prediction_delegate.rs | 1 + .../src/edit_prediction_tests.rs | 1 + crates/edit_prediction/src/example_spec.rs | 190 +++++++++++++++ crates/edit_prediction/src/mercury.rs | 1 + crates/edit_prediction/src/prediction.rs | 6 +- crates/edit_prediction/src/sweep_ai.rs | 1 + crates/edit_prediction/src/udiff.rs | 106 ++------- .../src/zed_edit_prediction_delegate.rs | 1 + crates/edit_prediction/src/zeta1.rs | 1 + crates/edit_prediction/src/zeta2.rs | 93 +++++++- crates/edit_prediction_cli/src/distill.rs | 6 +- crates/edit_prediction_cli/src/example.rs | 6 + .../edit_prediction_cli/src/format_prompt.rs | 103 ++++---- .../edit_prediction_cli/src/parse_output.rs | 27 ++- crates/edit_prediction_cli/src/predict.rs | 7 +- crates/edit_prediction_cli/src/repair.rs | 4 +- crates/edit_prediction_cli/src/score.rs | 219 ++++++++++++++++-- .../src/edit_prediction_types.rs | 29 +++ crates/editor/src/edit_prediction_tests.rs | 128 +++++++++- crates/editor/src/editor.rs | 48 +++- crates/editor/src/editor_tests.rs | 1 + crates/language/src/language.rs | 2 +- crates/language/src/text_diff.rs | 57 ++++- .../supermaven_edit_prediction_delegate.rs | 1 + typos.toml | 1 + 26 files changed, 835 insertions(+), 206 deletions(-) diff --git a/crates/codestral/src/codestral.rs b/crates/codestral/src/codestral.rs index afec79bef7f6d5f523b1ad2d110982e0a1dd467a..d5fc8d53b7e24b64598bc9b717fa043080f99680 100644 --- a/crates/codestral/src/codestral.rs +++ b/crates/codestral/src/codestral.rs @@ -330,6 +330,7 @@ impl EditPredictionDelegate for CodestralEditPredictionDelegate { Some(EditPrediction::Local { id: None, edits, + cursor_position: None, edit_preview: Some(current_completion.edit_preview.clone()), }) } diff --git a/crates/copilot/src/copilot_edit_prediction_delegate.rs b/crates/copilot/src/copilot_edit_prediction_delegate.rs index ffd4414a49066175cc58ad1c59dacb8d31a94bff..a8e8e231baa94397b72205ee90a86e80dad8ca80 100644 --- a/crates/copilot/src/copilot_edit_prediction_delegate.rs +++ b/crates/copilot/src/copilot_edit_prediction_delegate.rs @@ -177,6 +177,7 @@ impl EditPredictionDelegate for CopilotEditPredictionDelegate { Some(EditPrediction::Local { id: None, edits, + cursor_position: None, edit_preview: Some(edit_preview.clone()), }) } diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 14e6efa903d0cf074477564ebd6d1c6cfd8bdf30..990373b5f9dedd5dd4c5c07e09a86d28a57c8135 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1433,6 +1433,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { let prediction = EditPrediction { edits, + cursor_position: None, edit_preview, buffer: buffer.clone(), snapshot: cx.read(|cx| buffer.read(cx).snapshot()), diff --git a/crates/edit_prediction/src/example_spec.rs b/crates/edit_prediction/src/example_spec.rs index d3ba165e51048c14d6e6d9bcc857557bf352c81a..f326ab749b31bb32895ea5f5d9e92e4b4f4853ec 100644 --- a/crates/edit_prediction/src/example_spec.rs +++ b/crates/edit_prediction/src/example_spec.rs @@ -1,3 +1,4 @@ +use crate::udiff::DiffLine; use anyhow::{Context as _, Result}; use serde::{Deserialize, Serialize}; use std::{borrow::Cow, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc}; @@ -491,6 +492,123 @@ impl ExampleSpec { self.cursor_position = result; } + + /// Returns all of the possible expected patches for this example, each with an optional + /// cursor offset. + /// + /// The cursor offset is an offset within the new text (after applying the patch), relative + /// to the start of the hunk. + /// + /// In the serialized representation of this example, the cursor position is represented + /// using a comment line in the diff, beginning with `#`, and containing a `[CURSOR_POSITION]` + /// marker with the same format as the [`Self::cursor_excerpt`]. + pub fn expected_patches_with_cursor_positions(&self) -> Vec<(String, Option)> { + self.expected_patches + .iter() + .map(|patch| { + let mut clean_patch = String::new(); + let mut cursor_offset: Option = None; + let mut line_start_offset = 0usize; + let mut prev_line_start_offset = 0usize; + + for line in patch.lines() { + let diff_line = DiffLine::parse(line); + + match &diff_line { + DiffLine::Garbage(content) + if content.starts_with('#') + && content.contains(CURSOR_POSITION_MARKER) => + { + let caret_column = if let Some(caret_pos) = content.find('^') { + caret_pos + } else if let Some(_) = content.find('<') { + 0 + } else { + continue; + }; + let cursor_column = caret_column.saturating_sub('#'.len_utf8()); + cursor_offset = Some(prev_line_start_offset + cursor_column); + } + _ => { + if !clean_patch.is_empty() { + clean_patch.push('\n'); + } + clean_patch.push_str(line); + + match diff_line { + DiffLine::Addition(content) | DiffLine::Context(content) => { + prev_line_start_offset = line_start_offset; + line_start_offset += content.len() + 1; + } + _ => {} + } + } + } + } + + if patch.ends_with('\n') && !clean_patch.is_empty() { + clean_patch.push('\n'); + } + + (clean_patch, cursor_offset) + }) + .collect() + } + + pub fn set_expected_patches_with_cursor_positions( + &mut self, + patches: Vec<(String, Option)>, + ) { + self.expected_patches = patches + .into_iter() + .map(|(patch, cursor_offset)| { + let Some(cursor_offset) = cursor_offset else { + return patch; + }; + + let mut result = String::new(); + let mut line_start_offset = 0usize; + + for line in patch.lines() { + if !result.is_empty() { + result.push('\n'); + } + result.push_str(line); + + match DiffLine::parse(line) { + DiffLine::Addition(content) => { + let line_end_offset = line_start_offset + content.len(); + + if cursor_offset >= line_start_offset + && cursor_offset <= line_end_offset + { + let cursor_column = cursor_offset - line_start_offset; + + result.push('\n'); + result.push('#'); + for _ in 0..cursor_column { + result.push(' '); + } + write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap(); + } + + line_start_offset = line_end_offset + 1; + } + DiffLine::Context(content) => { + line_start_offset += content.len() + 1; + } + _ => {} + } + } + + if patch.ends_with('\n') { + result.push('\n'); + } + + result + }) + .collect(); + } } #[cfg(test)] @@ -707,4 +825,76 @@ mod tests { (expected_excerpt.to_string(), expected_offset) ); } + + #[test] + fn test_expected_patches_with_cursor_positions() { + let mut spec = ExampleSpec { + name: String::new(), + repository_url: String::new(), + revision: String::new(), + tags: Vec::new(), + reasoning: None, + uncommitted_diff: String::new(), + cursor_path: Path::new("test.rs").into(), + cursor_position: String::new(), + edit_history: String::new(), + expected_patches: Vec::new(), + rejected_patch: None, + captured_prompt_input: None, + telemetry: None, + human_feedback: Vec::new(), + rating: None, + }; + + let new_content = indoc! {r#" + // prints a greeting + fn main() { + println!("hello, {}", ); + let x = 42; + } + "#}; + let cursor_offset = new_content.find(");").unwrap(); + + let clean_patch = indoc! {r#" + --- a/test.rs + +++ b/test.rs + @@ -1,3 +1,4 @@ + +// prints a greeting + fn main() { + - println!("hi"); + + println!("hello, {}", ); + let x = 42; + } + "#} + .to_string(); + + let encoded_patch = indoc! {r#" + --- a/test.rs + +++ b/test.rs + @@ -1,3 +1,4 @@ + +// prints a greeting + fn main() { + - println!("hi"); + + println!("hello, {}", ); + # ^[CURSOR_POSITION] + let x = 42; + } + "#} + .to_string(); + + spec.set_expected_patches_with_cursor_positions(vec![( + clean_patch.clone(), + Some(cursor_offset), + )]); + assert_eq!(spec.expected_patches, vec![encoded_patch]); + + let results = spec.expected_patches_with_cursor_positions(); + assert_eq!(results, vec![(clean_patch.clone(), Some(cursor_offset))]); + + spec.set_expected_patches_with_cursor_positions(vec![(clean_patch.clone(), None)]); + assert_eq!(spec.expected_patches, vec![clean_patch.clone()]); + + let results = spec.expected_patches_with_cursor_positions(); + assert_eq!(results, vec![(clean_patch, None)]); + } } diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs index 3905939f05358fba18379ea17f8407d26f3b3a2a..4396f5ac880bcea15d44bbf987ac50f16a746615 100644 --- a/crates/edit_prediction/src/mercury.rs +++ b/crates/edit_prediction/src/mercury.rs @@ -206,6 +206,7 @@ impl Mercury { &buffer, &old_snapshot, edits.into(), + None, buffer_snapshotted_at, response_received_at, inputs, diff --git a/crates/edit_prediction/src/prediction.rs b/crates/edit_prediction/src/prediction.rs index 9035d326c2295671ee382e3b905508b577929b9c..8d4a40d8b9ddf7a2ed8a68773da83a9498c4d516 100644 --- a/crates/edit_prediction/src/prediction.rs +++ b/crates/edit_prediction/src/prediction.rs @@ -5,7 +5,7 @@ use std::{ }; use cloud_llm_client::EditPredictionRejectReason; -use edit_prediction_types::interpolate_edits; +use edit_prediction_types::{PredictedCursorPosition, interpolate_edits}; use gpui::{AsyncApp, Entity, SharedString}; use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot}; use zeta_prompt::ZetaPromptInput; @@ -37,6 +37,7 @@ impl EditPredictionResult { edited_buffer: &Entity, edited_buffer_snapshot: &BufferSnapshot, edits: Arc<[(Range, Arc)]>, + cursor_position: Option, buffer_snapshotted_at: Instant, response_received_at: Instant, inputs: ZetaPromptInput, @@ -71,6 +72,7 @@ impl EditPredictionResult { prediction: Ok(EditPrediction { id, edits, + cursor_position, snapshot, edit_preview, inputs, @@ -86,6 +88,7 @@ impl EditPredictionResult { pub struct EditPrediction { pub id: EditPredictionId, pub edits: Arc<[(Range, Arc)]>, + pub cursor_position: Option, pub snapshot: BufferSnapshot, pub edit_preview: EditPreview, pub buffer: Entity, @@ -143,6 +146,7 @@ mod tests { let prediction = EditPrediction { id: EditPredictionId("prediction-1".into()), edits, + cursor_position: None, snapshot: cx.read(|cx| buffer.read(cx).snapshot()), buffer: buffer.clone(), edit_preview, diff --git a/crates/edit_prediction/src/sweep_ai.rs b/crates/edit_prediction/src/sweep_ai.rs index 15dccee48bea85da2d8e72c0860f4e9f5f4de77f..d781d175a247a0ee7c92565cb9becc7446d34df0 100644 --- a/crates/edit_prediction/src/sweep_ai.rs +++ b/crates/edit_prediction/src/sweep_ai.rs @@ -291,6 +291,7 @@ impl SweepAi { &buffer, &old_snapshot, edits.into(), + None, buffer_snapshotted_at, response_received_at, inputs, diff --git a/crates/edit_prediction/src/udiff.rs b/crates/edit_prediction/src/udiff.rs index e914d4c95f349aee07f32e21caf3c04c318af4d2..f0d55b6899a47e6366e8fef0a7e0d6faaa63c32a 100644 --- a/crates/edit_prediction/src/udiff.rs +++ b/crates/edit_prediction/src/udiff.rs @@ -190,38 +190,6 @@ pub async fn refresh_worktree_entries( Ok(()) } -/// Extract the diff for a specific file from a multi-file diff. -/// Returns an error if the file is not found in the diff. -pub fn extract_file_diff(full_diff: &str, file_path: &str) -> Result { - let mut result = String::new(); - let mut in_target_file = false; - let mut found_file = false; - - for line in full_diff.lines() { - if line.starts_with("diff --git") { - if in_target_file { - break; - } - in_target_file = line.contains(&format!("a/{}", file_path)) - || line.contains(&format!("b/{}", file_path)); - if in_target_file { - found_file = true; - } - } - - if in_target_file { - result.push_str(line); - result.push('\n'); - } - } - - if !found_file { - anyhow::bail!("File '{}' not found in diff", file_path); - } - - Ok(result) -} - pub fn strip_diff_path_prefix<'a>(diff: &'a str, prefix: &str) -> Cow<'a, str> { if prefix.is_empty() { return Cow::Borrowed(diff); @@ -319,9 +287,20 @@ fn disambiguate_by_line_number( } pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result { + apply_diff_to_string_with_hunk_offset(diff_str, text).map(|(text, _)| text) +} + +/// Applies a diff to a string and returns the result along with the offset where +/// the first hunk's context matched in the original text. This offset can be used +/// to adjust cursor positions that are relative to the hunk's content. +pub fn apply_diff_to_string_with_hunk_offset( + diff_str: &str, + text: &str, +) -> Result<(String, Option)> { let mut diff = DiffParser::new(diff_str); let mut text = text.to_string(); + let mut first_hunk_offset = None; while let Some(event) = diff.next().context("Failed to parse diff")? { match event { @@ -342,6 +321,10 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result { }) .ok_or_else(|| anyhow!("couldn't resolve hunk"))?; + if first_hunk_offset.is_none() { + first_hunk_offset = Some(hunk_offset); + } + for edit in hunk.edits.iter().rev() { let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end); text.replace_range(range, &edit.text); @@ -351,7 +334,7 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result { } } - Ok(text) + Ok((text, first_hunk_offset)) } /// Returns the individual edits that would be applied by a diff to the given content. @@ -1457,63 +1440,6 @@ mod tests { FakeFs::new(cx.background_executor.clone()) } - #[test] - fn test_extract_file_diff() { - let multi_file_diff = indoc! {r#" - diff --git a/file1.txt b/file1.txt - index 1234567..abcdefg 100644 - --- a/file1.txt - +++ b/file1.txt - @@ -1,3 +1,4 @@ - line1 - +added line - line2 - line3 - diff --git a/file2.txt b/file2.txt - index 2345678..bcdefgh 100644 - --- a/file2.txt - +++ b/file2.txt - @@ -1,2 +1,2 @@ - -old line - +new line - unchanged - "#}; - - let file1_diff = extract_file_diff(multi_file_diff, "file1.txt").unwrap(); - assert_eq!( - file1_diff, - indoc! {r#" - diff --git a/file1.txt b/file1.txt - index 1234567..abcdefg 100644 - --- a/file1.txt - +++ b/file1.txt - @@ -1,3 +1,4 @@ - line1 - +added line - line2 - line3 - "#} - ); - - let file2_diff = extract_file_diff(multi_file_diff, "file2.txt").unwrap(); - assert_eq!( - file2_diff, - indoc! {r#" - diff --git a/file2.txt b/file2.txt - index 2345678..bcdefgh 100644 - --- a/file2.txt - +++ b/file2.txt - @@ -1,2 +1,2 @@ - -old line - +new line - unchanged - "#} - ); - - let result = extract_file_diff(multi_file_diff, "nonexistent.txt"); - assert!(result.is_err()); - } - #[test] fn test_edits_for_diff() { let content = indoc! {" diff --git a/crates/edit_prediction/src/zed_edit_prediction_delegate.rs b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs index cdfe230d4bf84e9e5c891efa0a1f38a07839e37f..b2a7f34c73b37eabee51d91f1bed2b6735936239 100644 --- a/crates/edit_prediction/src/zed_edit_prediction_delegate.rs +++ b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs @@ -223,6 +223,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate { Some(edit_prediction_types::EditPrediction::Local { id: Some(prediction.id.to_string().into()), edits: edits[edit_start_ix..edit_end_ix].to_vec(), + cursor_position: None, edit_preview: Some(prediction.edit_preview.clone()), }) }) diff --git a/crates/edit_prediction/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs index c82a949d669f03f4e811c6ef8a3479a8bc85c4c3..6785189a2e1a3e8dd235f903eaee3b7d95df262b 100644 --- a/crates/edit_prediction/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -274,6 +274,7 @@ fn process_completion_response( &buffer, &snapshot, edits, + None, buffer_snapshotted_at, received_response_at, inputs, diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index 0e776f1ecc5f862a5676505900de2b3eef5029fc..8889f9f4fcf59009b5bcdbd6087cf00522bc9e61 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -8,8 +8,9 @@ use crate::{ use anyhow::{Result, anyhow}; use cloud_llm_client::predict_edits_v3::RawCompletionRequest; use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason}; +use edit_prediction_types::PredictedCursorPosition; use gpui::{App, Task, prelude::*}; -use language::{OffsetRangeExt as _, ToOffset as _, ToPoint}; +use language::{OffsetRangeExt as _, ToOffset as _, ToPoint, text_diff}; use release_channel::AppVersion; use std::env; @@ -145,9 +146,10 @@ pub fn request_prediction_with_zeta2( .ok(); } - if output_text.contains(CURSOR_MARKER) { - log::trace!("Stripping out {CURSOR_MARKER} from response"); - output_text = output_text.replace(CURSOR_MARKER, ""); + let cursor_offset_in_output = output_text.find(CURSOR_MARKER); + if let Some(offset) = cursor_offset_in_output { + log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}"); + output_text.replace_range(offset..offset + CURSOR_MARKER.len(), ""); } if zeta_version == ZetaVersion::V0120GitMergeMarkers { @@ -170,12 +172,22 @@ pub fn request_prediction_with_zeta2( } let edits = compute_edits( - old_text, + old_text.clone(), &output_text, editable_offset_range.start, &snapshot, ); + let cursor_position = cursor_offset_in_output.map(|cursor_offset| { + compute_predicted_cursor_position( + &old_text, + &output_text, + cursor_offset, + editable_offset_range.start, + &snapshot, + ) + }); + anyhow::Ok(( Some(( request_id, @@ -184,6 +196,7 @@ pub fn request_prediction_with_zeta2( buffer, snapshot.clone(), edits, + cursor_position, received_response_at, )), )), @@ -199,8 +212,14 @@ pub fn request_prediction_with_zeta2( return Ok(None); }; - let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) = - prediction + let Some(( + inputs, + edited_buffer, + edited_buffer_snapshot, + edits, + cursor_position, + received_response_at, + )) = prediction else { return Ok(Some(EditPredictionResult { id, @@ -214,6 +233,7 @@ pub fn request_prediction_with_zeta2( &edited_buffer, &edited_buffer_snapshot, edits.into(), + cursor_position, buffer_snapshotted_at, received_response_at, inputs, @@ -224,6 +244,65 @@ pub fn request_prediction_with_zeta2( }) } +/// Computes a `PredictedCursorPosition` from a cursor offset in the output text. +/// +/// The cursor offset is relative to `new_text`. We need to determine if the cursor +/// falls inside an edit's inserted text or in unchanged text: +/// - If inside an edit: anchor = start of edit range, offset = position within insertion +/// - If in unchanged text: anchor = corresponding position in old buffer, offset = 0 +fn compute_predicted_cursor_position( + old_text: &str, + new_text: &str, + cursor_offset_in_new: usize, + editable_region_start: usize, + snapshot: &language::BufferSnapshot, +) -> PredictedCursorPosition { + let diffs = text_diff(old_text, new_text); + + // Track position in both old and new text as we walk through diffs + let mut old_pos = 0usize; + let mut new_pos = 0usize; + + for (old_range, new_text_chunk) in &diffs { + // Text before this diff is unchanged + let unchanged_len = old_range.start - old_pos; + let unchanged_end_in_new = new_pos + unchanged_len; + + if cursor_offset_in_new < unchanged_end_in_new { + // Cursor is in unchanged text before this diff + let offset_in_unchanged = cursor_offset_in_new - new_pos; + let buffer_offset = editable_region_start + old_pos + offset_in_unchanged; + return PredictedCursorPosition::at_anchor(snapshot.anchor_after(buffer_offset)); + } + + // Move past the unchanged portion in new_text coordinates + new_pos = unchanged_end_in_new; + + // Check if cursor is within this edit's new text + let edit_new_text_end = new_pos + new_text_chunk.len(); + if cursor_offset_in_new < edit_new_text_end { + // Cursor is inside this edit's inserted text. + // Use anchor_before (left bias) so the anchor stays at the insertion point + // rather than moving past the inserted text. + let offset_within_insertion = cursor_offset_in_new - new_pos; + let buffer_offset = editable_region_start + old_range.start; + return PredictedCursorPosition::new( + snapshot.anchor_before(buffer_offset), + offset_within_insertion, + ); + } + + // Move past this edit + old_pos = old_range.end; + new_pos = edit_new_text_end; + } + + // Cursor is in unchanged text after all diffs + let offset_in_unchanged = cursor_offset_in_new - new_pos; + let buffer_offset = (editable_region_start + old_pos + offset_in_unchanged).min(snapshot.len()); + PredictedCursorPosition::at_anchor(snapshot.anchor_after(buffer_offset)) +} + pub fn zeta2_prompt_input( snapshot: &language::BufferSnapshot, related_files: Vec, diff --git a/crates/edit_prediction_cli/src/distill.rs b/crates/edit_prediction_cli/src/distill.rs index bed15c347dc619e772350469a19500ebc18a6da2..b1dbb0b9fcf1639d3a9da00502e97a0778393327 100644 --- a/crates/edit_prediction_cli/src/distill.rs +++ b/crates/edit_prediction_cli/src/distill.rs @@ -15,10 +15,12 @@ pub async fn run_distill(example: &mut Example) -> Result<()> { let expected_patches = predictions .into_iter() - .filter_map(|p| p.actual_patch.clone()) + .filter_map(|p| Some((p.actual_patch.clone()?, p.actual_cursor_offset))) .collect(); - example.spec.expected_patches = expected_patches; + example + .spec + .set_expected_patches_with_cursor_positions(expected_patches); example.prompt = None; example.predictions = Vec::new(); example.score = Vec::new(); diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index 0c8905b6291b15eaf2dcf3b80573fcc8966e79a5..a491b48e242a1a648ed4535a1397122ad9674183 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -86,6 +86,8 @@ pub struct ExamplePrediction { #[serde(deserialize_with = "deserialize_null_as_empty_string")] pub actual_output: String, #[serde(default, skip_serializing_if = "Option::is_none")] + pub actual_cursor_offset: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] pub error: Option, pub provider: PredictionProvider, } @@ -110,6 +112,10 @@ pub struct ExampleScore { pub exact_lines_fn: usize, #[serde(default)] pub reversal_ratio: f32, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cursor_distance: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cursor_exact_match: Option, } impl Example { diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index ed248070ccc960513324ebdec1dc68c7cc2042be..2e13045511c361919a4e1ae9dada17323d109c2b 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -6,7 +6,7 @@ use crate::{ retrieve_context::run_context_retrieval, }; use anyhow::{Context as _, Result, anyhow}; -use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position; +use edit_prediction::{cursor_excerpt::editable_and_context_ranges_for_cursor_position, udiff}; use gpui::{AppContext, AsyncApp}; use language::{Buffer, OffsetRangeExt, Point}; use similar::DiffableStr; @@ -61,18 +61,10 @@ pub async fn run_format_prompt( let context_range = context_range.to_offset(&snapshot); let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range); - let expected_output = example - .spec - .expected_patches - .first() - .cloned() - .unwrap_or_default(); - let rejected_output = example.spec.rejected_patch.clone(); - example.prompt = Some(ExamplePrompt { input: prompt, - expected_output, - rejected_output, + expected_output: String::new(), + rejected_output: None, provider: args.provider, }); } @@ -102,21 +94,19 @@ pub async fn run_format_prompt( related_files: prompt_inputs.related_files.clone().unwrap_or_default(), }; let prompt = format_zeta_prompt(&input, version); - let expected_output = zeta2_output_for_patch( - &input, - &example - .spec - .expected_patches - .first() - .context("expected patches is empty")? - .clone(), - version, - )?; + let (expected_patch, expected_cursor_offset) = example + .spec + .expected_patches_with_cursor_positions() + .into_iter() + .next() + .context("expected patches is empty")?; + let expected_output = + zeta2_output_for_patch(&input, &expected_patch, expected_cursor_offset, version)?; let rejected_output = example .spec .rejected_patch .as_ref() - .and_then(|patch| zeta2_output_for_patch(&input, &patch, version).ok()); + .and_then(|patch| zeta2_output_for_patch(&input, patch, None, version).ok()); example.prompt = Some(ExamplePrompt { input: prompt, @@ -135,6 +125,7 @@ pub async fn run_format_prompt( pub fn zeta2_output_for_patch( input: &zeta_prompt::ZetaPromptInput, patch: &str, + cursor_offset: Option, version: ZetaVersion, ) -> Result { let mut old_editable_region = @@ -144,13 +135,24 @@ pub fn zeta2_output_for_patch( old_editable_region.push('\n'); } - let mut result = edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region) - .with_context(|| { - format!( - "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```", - patch, old_editable_region - ) - })?; + let (mut result, first_hunk_offset) = + udiff::apply_diff_to_string_with_hunk_offset(patch, &old_editable_region).with_context( + || { + format!( + "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```", + patch, old_editable_region + ) + }, + )?; + + if let Some(cursor_offset) = cursor_offset { + // The cursor_offset is relative to the start of the hunk's new text (context + additions). + // We need to add where the hunk context matched in the editable region to compute + // the actual cursor position in the result. + let hunk_start = first_hunk_offset.unwrap_or(0); + let offset = (hunk_start + cursor_offset).min(result.len()); + result.insert_str(offset, zeta_prompt::CURSOR_MARKER); + } if version == ZetaVersion::V0120GitMergeMarkers { if !result.ends_with('\n') { @@ -191,24 +193,28 @@ impl TeacherPrompt { prompt } - pub fn parse(example: &Example, response: &str) -> Result { + pub fn parse(example: &Example, response: &str) -> Result<(String, Option)> { // Extract updated (new) editable region from the model response. // The model may include editable region markers in its output, so we need to strip them. let new_editable_region = extract_last_codeblock(response); // Check if the model indicated no edits are needed if new_editable_region.trim() == Self::NO_EDITS { - return Ok(String::new()); + return Ok((String::new(), None)); } - let mut new_editable_region = Self::extract_editable_region(&new_editable_region)?; + let new_editable_region = Self::extract_editable_region(&new_editable_region)?; + let cursor_offset = new_editable_region.find(Self::USER_CURSOR_MARKER); + let mut new_editable_region = new_editable_region.replace(Self::USER_CURSOR_MARKER, ""); let old_editable_region = Self::extract_editable_region( &example .prompt .as_ref() .context("example prompt missing")? .input, - )?; + )? + .replace(Self::USER_CURSOR_MARKER, ""); + let prompt_inputs = example .prompt_inputs .as_ref() @@ -230,11 +236,14 @@ impl TeacherPrompt { .matches('\n') .count(); - let diff = language::unified_diff_with_offsets( + // Use full context so cursor offset (relative to editable region start) aligns with diff content + let editable_region_lines = old_editable_region.lines().count() as u32; + let diff = language::unified_diff_with_context( &old_editable_region, &new_editable_region, editable_region_start_line as u32, editable_region_start_line as u32, + editable_region_lines, ); let diff = indoc::formatdoc! {" @@ -245,7 +254,7 @@ impl TeacherPrompt { diff = diff, }; - Ok(diff) + Ok((diff, cursor_offset)) } fn format_edit_history(edit_history: &str) -> String { @@ -328,7 +337,7 @@ impl TeacherPrompt { result } - fn extract_editable_region(text: &str) -> Result { + pub fn extract_editable_region(text: &str) -> Result { let start = text .rfind(Self::EDITABLE_REGION_START) .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len()); @@ -339,9 +348,7 @@ impl TeacherPrompt { } let region = &text[start..end]; - let region = region.strip_suffix('\n').unwrap_or(region); - - Ok(region.replace(Self::USER_CURSOR_MARKER, "")) + Ok(region.strip_suffix('\n').unwrap_or(region).to_string()) } fn is_udiff_content_line(s: &str) -> bool { @@ -571,22 +578,4 @@ mod tests { let codeblock = extract_last_codeblock(response); assert_eq!(codeblock.trim(), TeacherPrompt::NO_EDITS); } - - #[test] - fn test_extract_editable_region_strips_cursor_marker() { - let text = indoc::indoc! {" - <|editable_region_start|> - one - <|user_cursor|>two three - - <|editable_region_end|> - "}; - let parsed = TeacherPrompt::extract_editable_region(text).unwrap(); - assert_eq!( - parsed, - indoc::indoc! {" - one - two three"} - ); - } } diff --git a/crates/edit_prediction_cli/src/parse_output.rs b/crates/edit_prediction_cli/src/parse_output.rs index 06e8e2dadd61c4e0df136acd14ff03d65ebe2bda..5a3a49870ad3e0d3ace594b6613dae9e70df19c2 100644 --- a/crates/edit_prediction_cli/src/parse_output.rs +++ b/crates/edit_prediction_cli/src/parse_output.rs @@ -19,14 +19,14 @@ pub fn run_parse_output(example: &mut Example) -> Result<()> { .enumerate() .filter(|(_, p)| !p.actual_output.is_empty()) .map(|(ix, prediction)| { - let actual_patch = - parse_prediction_output(example, &prediction.actual_output, provider); - actual_patch.map(|patch| (ix, patch)) + let result = parse_prediction_output(example, &prediction.actual_output, provider); + result.map(|(patch, cursor_offset)| (ix, patch, cursor_offset)) }) .collect::>>()?; - for (ix, actual_patch) in parsed_patches { + for (ix, actual_patch, actual_cursor_offset) in parsed_patches { example.predictions[ix].actual_patch = Some(actual_patch); + example.predictions[ix].actual_cursor_offset = actual_cursor_offset; example.predictions[ix].provider = provider; } @@ -37,7 +37,7 @@ pub fn parse_prediction_output( example: &Example, actual_output: &str, provider: PredictionProvider, -) -> Result { +) -> Result<(String, Option)> { match provider { PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => { TeacherPrompt::parse(example, actual_output) @@ -83,7 +83,7 @@ fn parse_zeta2_output( example: &Example, actual_output: &str, version: ZetaVersion, -) -> Result { +) -> Result<(String, Option)> { let prompt = &example.prompt.as_ref().context("prompt required")?.input; let prompt_inputs = example .prompt_inputs @@ -92,7 +92,13 @@ fn parse_zeta2_output( let old_text = extract_zeta2_current_region(prompt, version)?; - let mut new_text = actual_output.replace(CURSOR_MARKER, ""); + let mut new_text = actual_output.to_string(); + let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) { + new_text.replace_range(offset..offset + CURSOR_MARKER.len(), ""); + Some(offset) + } else { + None + }; if version == ZetaVersion::V0120GitMergeMarkers { if let Some(stripped) = @@ -126,11 +132,14 @@ fn parse_zeta2_output( .matches('\n') .count(); - let diff = language::unified_diff_with_offsets( + // Use full context so cursor offset (relative to editable region start) aligns with diff content + let editable_region_lines = old_text_normalized.lines().count() as u32; + let diff = language::unified_diff_with_context( &old_text_normalized, &new_text, editable_region_start_line as u32, editable_region_start_line as u32, + editable_region_lines, ); let formatted_diff = format!( @@ -138,7 +147,7 @@ fn parse_zeta2_output( path = example.spec.cursor_path.to_string_lossy(), ); - Ok(formatted_diff) + Ok((formatted_diff, cursor_offset)) } #[cfg(test)] diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 192f8dc8f2dea95b756c7d699b91c388cced3cdf..69258c964ea109d27a1ede2a950a49a1066c0cfe 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -201,6 +201,7 @@ pub async fn run_prediction( .push(ExamplePrediction { actual_patch: None, actual_output: String::new(), + actual_cursor_offset: None, error: None, provider, }); @@ -322,11 +323,12 @@ async fn predict_anthropic( .collect::>() .join("\n"); - let actual_patch = TeacherPrompt::parse(example, &actual_output)?; + let (actual_patch, actual_cursor_offset) = TeacherPrompt::parse(example, &actual_output)?; let prediction = ExamplePrediction { actual_patch: Some(actual_patch), actual_output, + actual_cursor_offset, error: None, provider: if batched { PredictionProvider::Teacher(backend) @@ -394,11 +396,12 @@ async fn predict_openai( .collect::>() .join("\n"); - let actual_patch = TeacherPrompt::parse(example, &actual_output)?; + let (actual_patch, actual_cursor_offset) = TeacherPrompt::parse(example, &actual_output)?; let prediction = ExamplePrediction { actual_patch: Some(actual_patch), actual_output, + actual_cursor_offset, error: None, provider: if batched { PredictionProvider::Teacher(backend) diff --git a/crates/edit_prediction_cli/src/repair.rs b/crates/edit_prediction_cli/src/repair.rs index 78d7232209ef6268fce943bff34e3b08274a02e8..23b89133ae6183ee9444be7c1da21668351f8d2d 100644 --- a/crates/edit_prediction_cli/src/repair.rs +++ b/crates/edit_prediction_cli/src/repair.rs @@ -125,11 +125,12 @@ pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool { /// Parse the repair response into a prediction. fn parse_repair_response(example: &Example, response_text: &str) -> Result { - let actual_patch = TeacherPrompt::parse(example, response_text)?; + let (actual_patch, actual_cursor_offset) = TeacherPrompt::parse(example, response_text)?; Ok(ExamplePrediction { actual_patch: Some(actual_patch), actual_output: response_text.to_string(), + actual_cursor_offset, error: None, provider: PredictionProvider::Repair, }) @@ -362,6 +363,7 @@ pub async fn run_repair( example.predictions.push(ExamplePrediction { actual_patch: None, actual_output: response_text.clone(), + actual_cursor_offset: None, error: Some(format!("Failed to parse repair response: {}", e)), provider: PredictionProvider::Repair, }); diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index 486e433ca0e9a69712023c418c06f331c758ec02..772f518b157a717f87ca9d5b704104fd4dde7181 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -1,6 +1,7 @@ use crate::{ - PredictArgs, + PredictArgs, PredictionProvider, example::{Example, ExampleScore}, + format_prompt::TeacherPrompt, headless::EpAppState, metrics, parse_output::parse_prediction_output, @@ -9,7 +10,7 @@ use crate::{ reversal_tracking, }; use anyhow::Context as _; -use edit_prediction::udiff::apply_diff_to_string; +use edit_prediction::udiff::{apply_diff_to_string, apply_diff_to_string_with_hunk_offset}; use gpui::AsyncApp; use serde::Serialize; use std::fs::File; @@ -34,16 +35,36 @@ pub async fn run_scoring( .as_ref() .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")? .content; - let expected_texts: Vec = example - .spec - .expected_patches + let expected_patches_with_cursors = example.spec.expected_patches_with_cursor_positions(); + + let expected_texts: Vec = expected_patches_with_cursors .iter() - .map(|patch| { + .map(|(patch, _)| { apply_diff_to_string(patch, original_text) .with_context(|| format!("Expected patch did not apply for {}", example.spec.name)) }) .collect::, _>>()?; + // For Teacher prompts, we need to extract the editable region to properly compute cursor offsets. + // The actual_cursor_offset from Teacher is relative to the editable region, while the expected + // cursor from the patch is relative to the hunk. We need to apply the patch to the editable + // region to find where the hunk matched, then compute the expected cursor position. + let old_editable_region = if let Some(p) = example.prompt.as_ref() { + if matches!( + p.provider, + PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) + ) { + Some( + TeacherPrompt::extract_editable_region(&p.input)? + .replace(TeacherPrompt::USER_CURSOR_MARKER, ""), + ) + } else { + None + } + } else { + None + }; + let zero_scores = ExampleScore { delta_chr_f: 0.0, braces_disbalance: 0, @@ -51,6 +72,8 @@ pub async fn run_scoring( exact_lines_fp: 0, exact_lines_fn: 0, reversal_ratio: 0.0, + cursor_distance: None, + cursor_exact_match: None, }; let prompt_inputs = example.prompt_inputs.as_ref().unwrap(); @@ -60,7 +83,9 @@ pub async fn run_scoring( let mut scores = vec![]; for prediction in &example.predictions { let actual_patch = prediction.actual_patch.clone().or_else(|| { - parse_prediction_output(example, &prediction.actual_output, prediction.provider).ok() + parse_prediction_output(example, &prediction.actual_output, prediction.provider) + .ok() + .map(|(patch, _)| patch) }); let Some(actual_patch) = actual_patch else { @@ -75,10 +100,42 @@ pub async fn run_scoring( continue; } }; - let best_delta_chr_f = expected_texts - .iter() - .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32) - .fold(0.0, f32::max); + + let mut best_delta_chr_f = 0.0f32; + let mut best_expected_cursor: Option = None; + let mut best_patch_idx: Option = None; + + for (idx, expected) in expected_texts.iter().enumerate() { + let delta_chr_f = metrics::delta_chr_f(original_text, expected, &actual_text) as f32; + if delta_chr_f > best_delta_chr_f { + best_delta_chr_f = delta_chr_f; + best_patch_idx = Some(idx); + } + } + + if let Some(idx) = best_patch_idx { + // Get the raw cursor offset from the expected patch (relative to hunk new text) + let expected_cursor_in_patch = expected_patches_with_cursors + .get(idx) + .and_then(|(_, cursor)| *cursor); + + // For Teacher prompts, we need to apply the patch to the editable region + // to find where the hunk matched, then compute the actual cursor position + if let (Some(editable_region), Some(cursor_in_patch)) = + (&old_editable_region, expected_cursor_in_patch) + { + let (patch, _) = &expected_patches_with_cursors[idx]; + if let Ok((_, hunk_offset)) = + apply_diff_to_string_with_hunk_offset(patch, editable_region) + { + let hunk_start = hunk_offset.unwrap_or(0); + best_expected_cursor = Some(hunk_start + cursor_in_patch); + } + } else { + // For non-Teacher prompts or if we can't compute, use raw offset + best_expected_cursor = expected_cursor_in_patch; + } + } let disbalance_before = metrics::braces_disbalance(&original_text); let disbalance_after = metrics::braces_disbalance(&actual_text); @@ -95,11 +152,9 @@ pub async fn run_scoring( } // Compute exact lines match against best matching expected patch - let best_exact_lines = example - .spec - .expected_patches + let best_exact_lines = expected_patches_with_cursors .iter() - .map(|expected_patch| metrics::exact_lines_match(expected_patch, &actual_patch)) + .map(|(expected_patch, _)| metrics::exact_lines_match(expected_patch, &actual_patch)) .max_by_key(|m| m.true_positives) .unwrap_or_default(); @@ -110,6 +165,10 @@ pub async fn run_scoring( cursor_path, ); + // Compute cursor position metrics + let (cursor_distance, cursor_exact_match) = + compute_cursor_metrics(best_expected_cursor, prediction.actual_cursor_offset); + scores.push(ExampleScore { delta_chr_f: best_delta_chr_f, braces_disbalance, @@ -117,6 +176,8 @@ pub async fn run_scoring( exact_lines_fp: best_exact_lines.false_positives, exact_lines_fn: best_exact_lines.false_negatives, reversal_ratio, + cursor_distance, + cursor_exact_match, }); } @@ -124,16 +185,37 @@ pub async fn run_scoring( Ok(()) } +fn compute_cursor_metrics( + expected_cursor: Option, + actual_cursor: Option, +) -> (Option, Option) { + match (expected_cursor, actual_cursor) { + (Some(expected), Some(actual)) => { + let distance = expected.abs_diff(actual); + let exact_match = expected == actual; + (Some(distance), Some(exact_match)) + } + (None, None) => { + // Neither has cursor position - skip cursor scoring + (None, None) + } + (Some(_), None) | (None, Some(_)) => { + // Only one has cursor position - count as miss + (None, Some(false)) + } + } +} + pub fn print_report(examples: &[Example]) { use crate::metrics::ClassificationMetrics; - const LINE_WIDTH: usize = 82; + const LINE_WIDTH: usize = 94; let separator = "─".repeat(LINE_WIDTH); println!("{}", separator); println!( - "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7}", - "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf" + "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7} {:>6}", + "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf", "Cursor" ); println!("{}", separator); @@ -146,6 +228,10 @@ pub fn print_report(examples: &[Example]) { let mut qa_reverts_total: usize = 0; let mut qa_confidence_sum: u64 = 0; let mut qa_confidence_count: usize = 0; + let mut cursor_exact_matches: usize = 0; + let mut cursor_total: usize = 0; + let mut cursor_distance_sum: usize = 0; + let mut cursor_distance_count: usize = 0; for example in examples { for (score_idx, score) in example.score.iter().enumerate() { @@ -166,15 +252,24 @@ pub fn print_report(examples: &[Example]) { .map(|v| format!("{}", v)) .unwrap_or("-".to_string()); + // Format cursor metric + let cursor_str = match (score.cursor_exact_match, score.cursor_distance) { + (Some(true), _) => "✓".to_string(), + (Some(false), Some(dist)) => format!("±{}", dist), + (Some(false), None) => "✗".to_string(), + (None, _) => "-".to_string(), + }; + println!( - "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7}", + "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6}", truncate_name(&example.spec.name, 40), score.delta_chr_f, score.braces_disbalance, exact_lines.f1() * 100.0, score.reversal_ratio * 100.0, qa_reverts_str, - qa_conf_str + qa_conf_str, + cursor_str ); all_delta_chr_f_scores.push(score.delta_chr_f); @@ -198,6 +293,18 @@ pub fn print_report(examples: &[Example]) { qa_confidence_count += 1; } } + + // Accumulate cursor metrics + if let Some(exact_match) = score.cursor_exact_match { + cursor_total += 1; + if exact_match { + cursor_exact_matches += 1; + } + } + if let Some(dist) = score.cursor_distance { + cursor_distance_sum += dist; + cursor_distance_count += 1; + } } } @@ -226,18 +333,43 @@ pub fn print_report(examples: &[Example]) { } else { "-".to_string() }; + let cursor_str = if cursor_total > 0 { + format!( + "{:.0}%", + cursor_exact_matches as f32 / cursor_total as f32 * 100.0 + ) + } else { + "-".to_string() + }; + let avg_cursor_distance = if cursor_distance_count > 0 { + Some(cursor_distance_sum as f32 / cursor_distance_count as f32) + } else { + None + }; println!( - "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7}", + "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6}", "TOTAL / AVERAGE", avg_delta_chr_f, braces_disbalance_avg, total_exact_lines.f1() * 100.0, avg_reversal_ratio * 100.0, qa_reverts_str, - qa_conf_str + qa_conf_str, + cursor_str ); println!("{}", separator); + + // Print additional cursor metrics if available + if let Some(avg_dist) = avg_cursor_distance { + println!( + "Cursor: {}/{} exact matches ({:.0}%), avg distance: {:.1} bytes", + cursor_exact_matches, + cursor_total, + cursor_exact_matches as f32 / cursor_total as f32 * 100.0, + avg_dist + ); + } } println!("\n"); @@ -267,6 +399,12 @@ pub struct SummaryJson { pub qa_avg_reverts_edits: Option, #[serde(skip_serializing_if = "Option::is_none")] pub qa_avg_confidence: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cursor_exact_match_rate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cursor_avg_distance: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cursor_total_evaluated: Option, } pub fn compute_summary(examples: &[Example]) -> SummaryJson { @@ -281,6 +419,10 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { let mut qa_reverts_total: usize = 0; let mut qa_confidence_sum: u64 = 0; let mut qa_confidence_count: usize = 0; + let mut cursor_exact_matches: usize = 0; + let mut cursor_total: usize = 0; + let mut cursor_distance_sum: usize = 0; + let mut cursor_distance_count: usize = 0; for example in examples { for (score_idx, score) in example.score.iter().enumerate() { @@ -305,6 +447,18 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { qa_confidence_count += 1; } } + + // Accumulate cursor metrics + if let Some(exact_match) = score.cursor_exact_match { + cursor_total += 1; + if exact_match { + cursor_exact_matches += 1; + } + } + if let Some(dist) = score.cursor_distance { + cursor_distance_sum += dist; + cursor_distance_count += 1; + } } } @@ -338,6 +492,24 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { None }; + let cursor_exact_match_rate = if cursor_total > 0 { + Some(cursor_exact_matches as f32 / cursor_total as f32) + } else { + None + }; + + let cursor_avg_distance = if cursor_distance_count > 0 { + Some(cursor_distance_sum as f32 / cursor_distance_count as f32) + } else { + None + }; + + let cursor_total_evaluated = if cursor_total > 0 { + Some(cursor_total) + } else { + None + }; + SummaryJson { total_examples: total_scores, avg_delta_chr_f, @@ -351,6 +523,9 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { avg_reversal_ratio, qa_avg_reverts_edits, qa_avg_confidence, + cursor_exact_match_rate, + cursor_avg_distance, + cursor_total_evaluated, } } diff --git a/crates/edit_prediction_types/src/edit_prediction_types.rs b/crates/edit_prediction_types/src/edit_prediction_types.rs index 9cb191f7b99da7229bd686abeb46fdef3f2274be..a077e43ff7850c0c4c5f0fe460664d6b642a4a14 100644 --- a/crates/edit_prediction_types/src/edit_prediction_types.rs +++ b/crates/edit_prediction_types/src/edit_prediction_types.rs @@ -4,6 +4,34 @@ use client::EditPredictionUsage; use gpui::{App, Context, Entity, SharedString}; use language::{Anchor, Buffer, OffsetRangeExt}; +/// Represents a predicted cursor position after an edit is applied. +/// +/// Since the cursor may be positioned inside newly inserted text that doesn't +/// exist in the original buffer, we store an anchor (which points to a position +/// in the original buffer, typically the start of an edit) plus an offset into +/// the inserted text. +#[derive(Clone, Debug)] +pub struct PredictedCursorPosition { + /// An anchor in the original buffer. If the cursor is inside an edit, + /// this points to the start of that edit's range. + pub anchor: language::Anchor, + /// Offset from the anchor into the new text. If the cursor is inside + /// inserted text, this is the offset within that insertion. If the cursor + /// is outside any edit, this is 0. + pub offset: usize, +} + +impl PredictedCursorPosition { + pub fn new(anchor: language::Anchor, offset: usize) -> Self { + Self { anchor, offset } + } + + /// Creates a predicted cursor position at an exact anchor location (offset = 0). + pub fn at_anchor(anchor: language::Anchor) -> Self { + Self { anchor, offset: 0 } + } +} + /// The display mode used when showing an edit prediction to the user. /// Used for metrics tracking. #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -29,6 +57,7 @@ pub enum EditPrediction { Local { id: Option, edits: Vec<(Range, Arc)>, + cursor_position: Option, edit_preview: Option, }, /// Jump to a different file from the one that requested the prediction diff --git a/crates/editor/src/edit_prediction_tests.rs b/crates/editor/src/edit_prediction_tests.rs index b5931cde42a4e2c0e21b2d1f68558879de9750b4..45cae0ef956e8fc05aeef84099cddedff343f4e9 100644 --- a/crates/editor/src/edit_prediction_tests.rs +++ b/crates/editor/src/edit_prediction_tests.rs @@ -1,4 +1,4 @@ -use edit_prediction_types::EditPredictionDelegate; +use edit_prediction_types::{EditPredictionDelegate, PredictedCursorPosition}; use gpui::{Entity, KeyBinding, Modifiers, prelude::*}; use indoc::indoc; use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint}; @@ -32,6 +32,88 @@ async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) { cx.assert_editor_state("let absolute_zero_celsius = -273.15ˇ;") } +#[gpui::test] +async fn test_edit_prediction_cursor_position_inside_insertion(cx: &mut gpui::TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); + assign_editor_completion_provider(provider.clone(), &mut cx); + // Buffer: "fn foo() {}" - we'll insert text and position cursor inside the insertion + cx.set_state("fn foo() ˇ{}"); + + // Insert "bar()" at offset 9, with cursor at offset 2 within the insertion (after "ba") + // This tests the case where cursor is inside newly inserted text + propose_edits_with_cursor_position_in_insertion( + &provider, + vec![(9..9, "bar()")], + 9, // anchor at the insertion point + 2, // offset 2 within "bar()" puts cursor after "ba" + &mut cx, + ); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); + + assert_editor_active_edit_completion(&mut cx, |_, edits| { + assert_eq!(edits.len(), 1); + assert_eq!(edits[0].1.as_ref(), "bar()"); + }); + + accept_completion(&mut cx); + + // Cursor should be inside the inserted text at "baˇr()" + cx.assert_editor_state("fn foo() baˇr(){}"); +} + +#[gpui::test] +async fn test_edit_prediction_cursor_position_outside_edit(cx: &mut gpui::TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); + assign_editor_completion_provider(provider.clone(), &mut cx); + // Buffer: "let x = ;" with cursor before semicolon - we'll insert "42" and position cursor elsewhere + cx.set_state("let x = ˇ;"); + + // Insert "42" at offset 8, but set cursor_position to offset 4 (the 'x') + // This tests that cursor moves to the predicted position, not the end of the edit + propose_edits_with_cursor_position( + &provider, + vec![(8..8, "42")], + Some(4), // cursor at offset 4 (the 'x'), NOT at the edit location + &mut cx, + ); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); + + assert_editor_active_edit_completion(&mut cx, |_, edits| { + assert_eq!(edits.len(), 1); + assert_eq!(edits[0].1.as_ref(), "42"); + }); + + accept_completion(&mut cx); + + // Cursor should be at offset 4 (the 'x'), not at the end of the inserted "42" + cx.assert_editor_state("let ˇx = 42;"); +} + +#[gpui::test] +async fn test_edit_prediction_cursor_position_fallback(cx: &mut gpui::TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); + assign_editor_completion_provider(provider.clone(), &mut cx); + cx.set_state("let x = ˇ;"); + + // Propose an edit without a cursor position - should fall back to end of edit + propose_edits(&provider, vec![(8..8, "42")], &mut cx); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); + + accept_completion(&mut cx); + + // Cursor should be at the end of the inserted text (default behavior) + cx.assert_editor_state("let x = 42ˇ;") +} + #[gpui::test] async fn test_edit_prediction_modification(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); @@ -374,8 +456,50 @@ fn propose_edits( provider: &Entity, edits: Vec<(Range, &str)>, cx: &mut EditorTestContext, +) { + propose_edits_with_cursor_position(provider, edits, None, cx); +} + +fn propose_edits_with_cursor_position( + provider: &Entity, + edits: Vec<(Range, &str)>, + cursor_offset: Option, + cx: &mut EditorTestContext, +) { + let snapshot = cx.buffer_snapshot(); + let cursor_position = cursor_offset + .map(|offset| PredictedCursorPosition::at_anchor(snapshot.anchor_after(offset))); + let edits = edits.into_iter().map(|(range, text)| { + let range = snapshot.anchor_after(range.start)..snapshot.anchor_before(range.end); + (range, text.into()) + }); + + cx.update(|_, cx| { + provider.update(cx, |provider, _| { + provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local { + id: None, + edits: edits.collect(), + cursor_position, + edit_preview: None, + })) + }) + }); +} + +fn propose_edits_with_cursor_position_in_insertion( + provider: &Entity, + edits: Vec<(Range, &str)>, + anchor_offset: usize, + offset_within_insertion: usize, + cx: &mut EditorTestContext, ) { let snapshot = cx.buffer_snapshot(); + // Use anchor_before (left bias) so the anchor stays at the insertion point + // rather than moving past the inserted text + let cursor_position = Some(PredictedCursorPosition::new( + snapshot.anchor_before(anchor_offset), + offset_within_insertion, + )); let edits = edits.into_iter().map(|(range, text)| { let range = snapshot.anchor_after(range.start)..snapshot.anchor_before(range.end); (range, text.into()) @@ -386,6 +510,7 @@ fn propose_edits( provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local { id: None, edits: edits.collect(), + cursor_position, edit_preview: None, })) }) @@ -417,6 +542,7 @@ fn propose_edits_non_zed( provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local { id: None, edits: edits.collect(), + cursor_position: None, edit_preview: None, })) }) diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index e4b3d1a6bb37f0080ba470c575fca2f82e8a6e31..aae80501b4fd57ab7036167fdd99358fe36c1edc 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -634,6 +634,10 @@ pub(crate) enum EditDisplayMode { enum EditPrediction { Edit { edits: Vec<(Range, Arc)>, + /// Predicted cursor position as (anchor, offset_from_anchor). + /// The anchor is in multibuffer coordinates; after applying edits, + /// resolve the anchor and add the offset to get the final cursor position. + cursor_position: Option<(Anchor, usize)>, edit_preview: Option, display_mode: EditDisplayMode, snapshot: BufferSnapshot, @@ -7872,7 +7876,11 @@ impl Editor { .detach_and_log_err(cx); } } - EditPrediction::Edit { edits, .. } => { + EditPrediction::Edit { + edits, + cursor_position, + .. + } => { self.report_edit_prediction_event( active_edit_prediction.completion_id.clone(), true, @@ -7886,15 +7894,32 @@ impl Editor { } let transaction_id_prev = self.buffer.read(cx).last_transaction_id(cx); - let snapshot = self.buffer.read(cx).snapshot(cx); - let last_edit_end = edits.last().unwrap().0.end.bias_right(&snapshot); + + // Compute fallback cursor position BEFORE applying the edit, + // so the anchor tracks through the edit correctly + let fallback_cursor_target = { + let snapshot = self.buffer.read(cx).snapshot(cx); + edits.last().unwrap().0.end.bias_right(&snapshot) + }; self.buffer.update(cx, |buffer, cx| { buffer.edit(edits.iter().cloned(), None, cx) }); + // Resolve cursor position after the edit is applied + let cursor_target = if let Some((anchor, offset)) = cursor_position { + // The anchor tracks through the edit, then we add the offset + let snapshot = self.buffer.read(cx).snapshot(cx); + let base_offset = anchor.to_offset(&snapshot).0; + let target_offset = + MultiBufferOffset((base_offset + offset).min(snapshot.len().0)); + snapshot.anchor_after(target_offset) + } else { + fallback_cursor_target + }; + self.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { - s.select_anchor_ranges([last_edit_end..last_edit_end]); + s.select_anchor_ranges([cursor_target..cursor_target]); }); let selections = self.selections.disjoint_anchors_arc(); @@ -8358,12 +8383,14 @@ impl Editor { let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?; - let (completion_id, edits, edit_preview) = match edit_prediction { + let (completion_id, edits, predicted_cursor_position, edit_preview) = match edit_prediction + { edit_prediction_types::EditPrediction::Local { id, edits, + cursor_position, edit_preview, - } => (id, edits, edit_preview), + } => (id, edits, cursor_position, edit_preview), edit_prediction_types::EditPrediction::Jump { id, snapshot, @@ -8397,6 +8424,11 @@ impl Editor { return None; } + let cursor_position = predicted_cursor_position.and_then(|predicted| { + let anchor = multibuffer.anchor_in_excerpt(excerpt_id, predicted.anchor)?; + Some((anchor, predicted.offset)) + }); + let first_edit_start = edits.first().unwrap().0.start; let first_edit_start_point = first_edit_start.to_point(&multibuffer); let edit_start_row = first_edit_start_point.row.saturating_sub(2); @@ -8491,6 +8523,7 @@ impl Editor { EditPrediction::Edit { edits, + cursor_position, edit_preview, display_mode, snapshot, @@ -9179,6 +9212,7 @@ impl Editor { edit_preview, display_mode: EditDisplayMode::DiffPopover, snapshot, + .. } => self.render_edit_prediction_diff_popover( text_bounds, content_origin, @@ -10129,7 +10163,7 @@ impl Editor { edits, edit_preview, snapshot, - display_mode: _, + .. } => { let first_edit_row = edits.first()?.0.start.text_anchor.to_point(snapshot).row; diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index 6a5242d30d53a4408b126a5cb9c35521bf7203f1..d4335524ab4f9f7a41a48426533e9fc4bbacf1f5 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -9017,6 +9017,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext) provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local { id: None, edits: vec![(edit_position..edit_position, "X".into())], + cursor_position: None, edit_preview: None, })) }) diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index a294f1b5ae81a0d1b59ae1d685ab0d1f8fd67b5a..bfdffabf31142ca297608a62d2692288b82e696d 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -67,7 +67,7 @@ use task::RunnableTag; pub use task_context::{ContextLocation, ContextProvider, RunnableRange}; pub use text_diff::{ DiffOptions, apply_diff_patch, line_diff, text_diff, text_diff_with_options, unified_diff, - unified_diff_with_offsets, word_diff_ranges, + unified_diff_with_context, unified_diff_with_offsets, word_diff_ranges, }; use theme::SyntaxTheme; pub use toolchain::{ diff --git a/crates/language/src/text_diff.rs b/crates/language/src/text_diff.rs index 774fae2cb832397b07aaa2fbcedef22c119f8bf3..96108fc33ae03f2d83e50e2df9433854a608cac5 100644 --- a/crates/language/src/text_diff.rs +++ b/crates/language/src/text_diff.rs @@ -22,12 +22,25 @@ pub fn unified_diff_with_offsets( new_text: &str, old_start_line: u32, new_start_line: u32, +) -> String { + unified_diff_with_context(old_text, new_text, old_start_line, new_start_line, 3) +} + +/// Computes a diff between two strings, returning a unified diff string with +/// hunk headers adjusted to reflect the given starting line numbers (zero-indexed), +/// and a configurable number of context lines around changes. +pub fn unified_diff_with_context( + old_text: &str, + new_text: &str, + old_start_line: u32, + new_start_line: u32, + context_lines: u32, ) -> String { let input = InternedInput::new(old_text, new_text); diff( Algorithm::Histogram, &input, - OffsetUnifiedDiffBuilder::new(&input, old_start_line, new_start_line), + OffsetUnifiedDiffBuilder::new(&input, old_start_line, new_start_line, context_lines), ) } @@ -45,13 +58,19 @@ struct OffsetUnifiedDiffBuilder<'a> { old_line_offset: u32, new_line_offset: u32, + context_lines: u32, buffer: String, dst: String, } impl<'a> OffsetUnifiedDiffBuilder<'a> { - fn new(input: &'a InternedInput<&'a str>, old_line_offset: u32, new_line_offset: u32) -> Self { + fn new( + input: &'a InternedInput<&'a str>, + old_line_offset: u32, + new_line_offset: u32, + context_lines: u32, + ) -> Self { Self { before_hunk_start: 0, after_hunk_start: 0, @@ -59,6 +78,7 @@ impl<'a> OffsetUnifiedDiffBuilder<'a> { after_hunk_len: 0, old_line_offset, new_line_offset, + context_lines, buffer: String::with_capacity(8), dst: String::new(), interner: &input.interner, @@ -79,7 +99,7 @@ impl<'a> OffsetUnifiedDiffBuilder<'a> { return; } - let end = (self.pos + 3).min(self.before.len() as u32); + let end = (self.pos + self.context_lines).min(self.before.len() as u32); self.update_pos(end, end); writeln!( @@ -110,13 +130,13 @@ impl Sink for OffsetUnifiedDiffBuilder<'_> { type Out = String; fn process_change(&mut self, before: Range, after: Range) { - if before.start - self.pos > 6 { + if before.start - self.pos > self.context_lines * 2 { self.flush(); } if self.before_hunk_len == 0 && self.after_hunk_len == 0 { - self.pos = before.start.saturating_sub(3); + self.pos = before.start.saturating_sub(self.context_lines); self.before_hunk_start = self.pos; - self.after_hunk_start = after.start.saturating_sub(3); + self.after_hunk_start = after.start.saturating_sub(self.context_lines); } self.update_pos(before.start, before.end); self.before_hunk_len += before.end - before.start; @@ -467,4 +487,29 @@ mod tests { format!("@@ -100,3 +105,3 @@\n{}", expected_diff_body) ); } + + #[test] + fn test_unified_diff_with_context() { + // Test that full context includes all lines from the start + let old_text = "line1\nline2\nline3\nline4\nline5\nCHANGE_ME\nline7\nline8\n"; + let new_text = "line1\nline2\nline3\nline4\nline5\nCHANGED\nline7\nline8\n"; + + // With default 3 lines of context, the diff starts at line 3 + let diff_default = unified_diff_with_offsets(old_text, new_text, 0, 0); + assert_eq!( + diff_default, + "@@ -3,6 +3,6 @@\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n" + ); + + // With full context (8 lines), the diff starts at line 1 + let diff_full_context = unified_diff_with_context(old_text, new_text, 0, 0, 8); + assert_eq!( + diff_full_context, + "@@ -1,8 +1,8 @@\n line1\n line2\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n" + ); + + // With 0 context, only the changed line is shown + let diff_no_context = unified_diff_with_context(old_text, new_text, 0, 0, 0); + assert_eq!(diff_no_context, "@@ -6,1 +6,1 @@\n-CHANGE_ME\n+CHANGED\n"); + } } diff --git a/crates/supermaven/src/supermaven_edit_prediction_delegate.rs b/crates/supermaven/src/supermaven_edit_prediction_delegate.rs index 9563a0aa99f1760b5af214be28f25dbf1734c371..4c3216f2e283dd3cc0d1f54a4dfd39d510e72e08 100644 --- a/crates/supermaven/src/supermaven_edit_prediction_delegate.rs +++ b/crates/supermaven/src/supermaven_edit_prediction_delegate.rs @@ -100,6 +100,7 @@ fn completion_from_diff( EditPrediction::Local { id: None, edits, + cursor_position: None, edit_preview: None, } } diff --git a/typos.toml b/typos.toml index b8dc55b8e53a066933a7c8b70bf521b663c16cba..7ce5d047e6113dc9b22755dcdfb2d0c3f016db12 100644 --- a/typos.toml +++ b/typos.toml @@ -24,6 +24,7 @@ extend-exclude = [ # Editor and file finder rely on partial typing and custom in-string syntax. "crates/file_finder/src/file_finder_tests.rs", "crates/editor/src/editor_tests.rs", + "crates/editor/src/edit_prediction_tests.rs", # There are some names in the test data that are incorrectly flagged as typos. "crates/git/test_data/blame_incremental_complex", "crates/git/test_data/golden/blame_incremental_complex.json",