Detailed changes
@@ -330,6 +330,7 @@ impl EditPredictionDelegate for CodestralEditPredictionDelegate {
Some(EditPrediction::Local {
id: None,
edits,
+ cursor_position: None,
edit_preview: Some(current_completion.edit_preview.clone()),
})
}
@@ -177,6 +177,7 @@ impl EditPredictionDelegate for CopilotEditPredictionDelegate {
Some(EditPrediction::Local {
id: None,
edits,
+ cursor_position: None,
edit_preview: Some(edit_preview.clone()),
})
}
@@ -1433,6 +1433,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
let prediction = EditPrediction {
edits,
+ cursor_position: None,
edit_preview,
buffer: buffer.clone(),
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
@@ -1,3 +1,4 @@
+use crate::udiff::DiffLine;
use anyhow::{Context as _, Result};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc};
@@ -491,6 +492,123 @@ impl ExampleSpec {
self.cursor_position = result;
}
+
+ /// Returns all of the possible expected patches for this example, each with an optional
+ /// cursor offset.
+ ///
+ /// The cursor offset is an offset within the new text (after applying the patch), relative
+ /// to the start of the hunk.
+ ///
+ /// In the serialized representation of this example, the cursor position is represented
+ /// using a comment line in the diff, beginning with `#`, and containing a `[CURSOR_POSITION]`
+ /// marker with the same format as the [`Self::cursor_excerpt`].
+ pub fn expected_patches_with_cursor_positions(&self) -> Vec<(String, Option<usize>)> {
+ self.expected_patches
+ .iter()
+ .map(|patch| {
+ let mut clean_patch = String::new();
+ let mut cursor_offset: Option<usize> = None;
+ let mut line_start_offset = 0usize;
+ let mut prev_line_start_offset = 0usize;
+
+ for line in patch.lines() {
+ let diff_line = DiffLine::parse(line);
+
+ match &diff_line {
+ DiffLine::Garbage(content)
+ if content.starts_with('#')
+ && content.contains(CURSOR_POSITION_MARKER) =>
+ {
+ let caret_column = if let Some(caret_pos) = content.find('^') {
+ caret_pos
+ } else if let Some(_) = content.find('<') {
+ 0
+ } else {
+ continue;
+ };
+ let cursor_column = caret_column.saturating_sub('#'.len_utf8());
+ cursor_offset = Some(prev_line_start_offset + cursor_column);
+ }
+ _ => {
+ if !clean_patch.is_empty() {
+ clean_patch.push('\n');
+ }
+ clean_patch.push_str(line);
+
+ match diff_line {
+ DiffLine::Addition(content) | DiffLine::Context(content) => {
+ prev_line_start_offset = line_start_offset;
+ line_start_offset += content.len() + 1;
+ }
+ _ => {}
+ }
+ }
+ }
+ }
+
+ if patch.ends_with('\n') && !clean_patch.is_empty() {
+ clean_patch.push('\n');
+ }
+
+ (clean_patch, cursor_offset)
+ })
+ .collect()
+ }
+
+ pub fn set_expected_patches_with_cursor_positions(
+ &mut self,
+ patches: Vec<(String, Option<usize>)>,
+ ) {
+ self.expected_patches = patches
+ .into_iter()
+ .map(|(patch, cursor_offset)| {
+ let Some(cursor_offset) = cursor_offset else {
+ return patch;
+ };
+
+ let mut result = String::new();
+ let mut line_start_offset = 0usize;
+
+ for line in patch.lines() {
+ if !result.is_empty() {
+ result.push('\n');
+ }
+ result.push_str(line);
+
+ match DiffLine::parse(line) {
+ DiffLine::Addition(content) => {
+ let line_end_offset = line_start_offset + content.len();
+
+ if cursor_offset >= line_start_offset
+ && cursor_offset <= line_end_offset
+ {
+ let cursor_column = cursor_offset - line_start_offset;
+
+ result.push('\n');
+ result.push('#');
+ for _ in 0..cursor_column {
+ result.push(' ');
+ }
+ write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap();
+ }
+
+ line_start_offset = line_end_offset + 1;
+ }
+ DiffLine::Context(content) => {
+ line_start_offset += content.len() + 1;
+ }
+ _ => {}
+ }
+ }
+
+ if patch.ends_with('\n') {
+ result.push('\n');
+ }
+
+ result
+ })
+ .collect();
+ }
}
#[cfg(test)]
@@ -707,4 +825,76 @@ mod tests {
(expected_excerpt.to_string(), expected_offset)
);
}
+
+ #[test]
+ fn test_expected_patches_with_cursor_positions() {
+ let mut spec = ExampleSpec {
+ name: String::new(),
+ repository_url: String::new(),
+ revision: String::new(),
+ tags: Vec::new(),
+ reasoning: None,
+ uncommitted_diff: String::new(),
+ cursor_path: Path::new("test.rs").into(),
+ cursor_position: String::new(),
+ edit_history: String::new(),
+ expected_patches: Vec::new(),
+ rejected_patch: None,
+ captured_prompt_input: None,
+ telemetry: None,
+ human_feedback: Vec::new(),
+ rating: None,
+ };
+
+ let new_content = indoc! {r#"
+ // prints a greeting
+ fn main() {
+ println!("hello, {}", );
+ let x = 42;
+ }
+ "#};
+ let cursor_offset = new_content.find(");").unwrap();
+
+ let clean_patch = indoc! {r#"
+ --- a/test.rs
+ +++ b/test.rs
+ @@ -1,3 +1,4 @@
+ +// prints a greeting
+ fn main() {
+ - println!("hi");
+ + println!("hello, {}", );
+ let x = 42;
+ }
+ "#}
+ .to_string();
+
+ let encoded_patch = indoc! {r#"
+ --- a/test.rs
+ +++ b/test.rs
+ @@ -1,3 +1,4 @@
+ +// prints a greeting
+ fn main() {
+ - println!("hi");
+ + println!("hello, {}", );
+ # ^[CURSOR_POSITION]
+ let x = 42;
+ }
+ "#}
+ .to_string();
+
+ spec.set_expected_patches_with_cursor_positions(vec![(
+ clean_patch.clone(),
+ Some(cursor_offset),
+ )]);
+ assert_eq!(spec.expected_patches, vec![encoded_patch]);
+
+ let results = spec.expected_patches_with_cursor_positions();
+ assert_eq!(results, vec![(clean_patch.clone(), Some(cursor_offset))]);
+
+ spec.set_expected_patches_with_cursor_positions(vec![(clean_patch.clone(), None)]);
+ assert_eq!(spec.expected_patches, vec![clean_patch.clone()]);
+
+ let results = spec.expected_patches_with_cursor_positions();
+ assert_eq!(results, vec![(clean_patch, None)]);
+ }
}
@@ -206,6 +206,7 @@ impl Mercury {
&buffer,
&old_snapshot,
edits.into(),
+ None,
buffer_snapshotted_at,
response_received_at,
inputs,
@@ -5,7 +5,7 @@ use std::{
};
use cloud_llm_client::EditPredictionRejectReason;
-use edit_prediction_types::interpolate_edits;
+use edit_prediction_types::{PredictedCursorPosition, interpolate_edits};
use gpui::{AsyncApp, Entity, SharedString};
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
use zeta_prompt::ZetaPromptInput;
@@ -37,6 +37,7 @@ impl EditPredictionResult {
edited_buffer: &Entity<Buffer>,
edited_buffer_snapshot: &BufferSnapshot,
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
+ cursor_position: Option<PredictedCursorPosition>,
buffer_snapshotted_at: Instant,
response_received_at: Instant,
inputs: ZetaPromptInput,
@@ -71,6 +72,7 @@ impl EditPredictionResult {
prediction: Ok(EditPrediction {
id,
edits,
+ cursor_position,
snapshot,
edit_preview,
inputs,
@@ -86,6 +88,7 @@ impl EditPredictionResult {
pub struct EditPrediction {
pub id: EditPredictionId,
pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
+ pub cursor_position: Option<PredictedCursorPosition>,
pub snapshot: BufferSnapshot,
pub edit_preview: EditPreview,
pub buffer: Entity<Buffer>,
@@ -143,6 +146,7 @@ mod tests {
let prediction = EditPrediction {
id: EditPredictionId("prediction-1".into()),
edits,
+ cursor_position: None,
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
buffer: buffer.clone(),
edit_preview,
@@ -291,6 +291,7 @@ impl SweepAi {
&buffer,
&old_snapshot,
edits.into(),
+ None,
buffer_snapshotted_at,
response_received_at,
inputs,
@@ -190,38 +190,6 @@ pub async fn refresh_worktree_entries(
Ok(())
}
-/// Extract the diff for a specific file from a multi-file diff.
-/// Returns an error if the file is not found in the diff.
-pub fn extract_file_diff(full_diff: &str, file_path: &str) -> Result<String> {
- let mut result = String::new();
- let mut in_target_file = false;
- let mut found_file = false;
-
- for line in full_diff.lines() {
- if line.starts_with("diff --git") {
- if in_target_file {
- break;
- }
- in_target_file = line.contains(&format!("a/{}", file_path))
- || line.contains(&format!("b/{}", file_path));
- if in_target_file {
- found_file = true;
- }
- }
-
- if in_target_file {
- result.push_str(line);
- result.push('\n');
- }
- }
-
- if !found_file {
- anyhow::bail!("File '{}' not found in diff", file_path);
- }
-
- Ok(result)
-}
-
pub fn strip_diff_path_prefix<'a>(diff: &'a str, prefix: &str) -> Cow<'a, str> {
if prefix.is_empty() {
return Cow::Borrowed(diff);
@@ -319,9 +287,20 @@ fn disambiguate_by_line_number(
}
pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
+ apply_diff_to_string_with_hunk_offset(diff_str, text).map(|(text, _)| text)
+}
+
+/// Applies a diff to a string and returns the result along with the offset where
+/// the first hunk's context matched in the original text. This offset can be used
+/// to adjust cursor positions that are relative to the hunk's content.
+pub fn apply_diff_to_string_with_hunk_offset(
+ diff_str: &str,
+ text: &str,
+) -> Result<(String, Option<usize>)> {
let mut diff = DiffParser::new(diff_str);
let mut text = text.to_string();
+ let mut first_hunk_offset = None;
while let Some(event) = diff.next().context("Failed to parse diff")? {
match event {
@@ -342,6 +321,10 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
})
.ok_or_else(|| anyhow!("couldn't resolve hunk"))?;
+ if first_hunk_offset.is_none() {
+ first_hunk_offset = Some(hunk_offset);
+ }
+
for edit in hunk.edits.iter().rev() {
let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
text.replace_range(range, &edit.text);
@@ -351,7 +334,7 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
}
}
- Ok(text)
+ Ok((text, first_hunk_offset))
}
/// Returns the individual edits that would be applied by a diff to the given content.
@@ -1457,63 +1440,6 @@ mod tests {
FakeFs::new(cx.background_executor.clone())
}
- #[test]
- fn test_extract_file_diff() {
- let multi_file_diff = indoc! {r#"
- diff --git a/file1.txt b/file1.txt
- index 1234567..abcdefg 100644
- --- a/file1.txt
- +++ b/file1.txt
- @@ -1,3 +1,4 @@
- line1
- +added line
- line2
- line3
- diff --git a/file2.txt b/file2.txt
- index 2345678..bcdefgh 100644
- --- a/file2.txt
- +++ b/file2.txt
- @@ -1,2 +1,2 @@
- -old line
- +new line
- unchanged
- "#};
-
- let file1_diff = extract_file_diff(multi_file_diff, "file1.txt").unwrap();
- assert_eq!(
- file1_diff,
- indoc! {r#"
- diff --git a/file1.txt b/file1.txt
- index 1234567..abcdefg 100644
- --- a/file1.txt
- +++ b/file1.txt
- @@ -1,3 +1,4 @@
- line1
- +added line
- line2
- line3
- "#}
- );
-
- let file2_diff = extract_file_diff(multi_file_diff, "file2.txt").unwrap();
- assert_eq!(
- file2_diff,
- indoc! {r#"
- diff --git a/file2.txt b/file2.txt
- index 2345678..bcdefgh 100644
- --- a/file2.txt
- +++ b/file2.txt
- @@ -1,2 +1,2 @@
- -old line
- +new line
- unchanged
- "#}
- );
-
- let result = extract_file_diff(multi_file_diff, "nonexistent.txt");
- assert!(result.is_err());
- }
-
#[test]
fn test_edits_for_diff() {
let content = indoc! {"
@@ -223,6 +223,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
Some(edit_prediction_types::EditPrediction::Local {
id: Some(prediction.id.to_string().into()),
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
+ cursor_position: None,
edit_preview: Some(prediction.edit_preview.clone()),
})
})
@@ -274,6 +274,7 @@ fn process_completion_response(
&buffer,
&snapshot,
edits,
+ None,
buffer_snapshotted_at,
received_response_at,
inputs,
@@ -8,8 +8,9 @@ use crate::{
use anyhow::{Result, anyhow};
use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
+use edit_prediction_types::PredictedCursorPosition;
use gpui::{App, Task, prelude::*};
-use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
+use language::{OffsetRangeExt as _, ToOffset as _, ToPoint, text_diff};
use release_channel::AppVersion;
use std::env;
@@ -145,9 +146,10 @@ pub fn request_prediction_with_zeta2(
.ok();
}
- if output_text.contains(CURSOR_MARKER) {
- log::trace!("Stripping out {CURSOR_MARKER} from response");
- output_text = output_text.replace(CURSOR_MARKER, "");
+ let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
+ if let Some(offset) = cursor_offset_in_output {
+ log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
+ output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
}
if zeta_version == ZetaVersion::V0120GitMergeMarkers {
@@ -170,12 +172,22 @@ pub fn request_prediction_with_zeta2(
}
let edits = compute_edits(
- old_text,
+ old_text.clone(),
&output_text,
editable_offset_range.start,
&snapshot,
);
+ let cursor_position = cursor_offset_in_output.map(|cursor_offset| {
+ compute_predicted_cursor_position(
+ &old_text,
+ &output_text,
+ cursor_offset,
+ editable_offset_range.start,
+ &snapshot,
+ )
+ });
+
anyhow::Ok((
Some((
request_id,
@@ -184,6 +196,7 @@ pub fn request_prediction_with_zeta2(
buffer,
snapshot.clone(),
edits,
+ cursor_position,
received_response_at,
)),
)),
@@ -199,8 +212,14 @@ pub fn request_prediction_with_zeta2(
return Ok(None);
};
- let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
- prediction
+ let Some((
+ inputs,
+ edited_buffer,
+ edited_buffer_snapshot,
+ edits,
+ cursor_position,
+ received_response_at,
+ )) = prediction
else {
return Ok(Some(EditPredictionResult {
id,
@@ -214,6 +233,7 @@ pub fn request_prediction_with_zeta2(
&edited_buffer,
&edited_buffer_snapshot,
edits.into(),
+ cursor_position,
buffer_snapshotted_at,
received_response_at,
inputs,
@@ -224,6 +244,65 @@ pub fn request_prediction_with_zeta2(
})
}
+/// Computes a `PredictedCursorPosition` from a cursor offset in the output text.
+///
+/// The cursor offset is relative to `new_text`. We need to determine if the cursor
+/// falls inside an edit's inserted text or in unchanged text:
+/// - If inside an edit: anchor = start of edit range, offset = position within insertion
+/// - If in unchanged text: anchor = corresponding position in old buffer, offset = 0
+fn compute_predicted_cursor_position(
+ old_text: &str,
+ new_text: &str,
+ cursor_offset_in_new: usize,
+ editable_region_start: usize,
+ snapshot: &language::BufferSnapshot,
+) -> PredictedCursorPosition {
+ let diffs = text_diff(old_text, new_text);
+
+ // Track position in both old and new text as we walk through diffs
+ let mut old_pos = 0usize;
+ let mut new_pos = 0usize;
+
+ for (old_range, new_text_chunk) in &diffs {
+ // Text before this diff is unchanged
+ let unchanged_len = old_range.start - old_pos;
+ let unchanged_end_in_new = new_pos + unchanged_len;
+
+ if cursor_offset_in_new < unchanged_end_in_new {
+ // Cursor is in unchanged text before this diff
+ let offset_in_unchanged = cursor_offset_in_new - new_pos;
+ let buffer_offset = editable_region_start + old_pos + offset_in_unchanged;
+ return PredictedCursorPosition::at_anchor(snapshot.anchor_after(buffer_offset));
+ }
+
+ // Move past the unchanged portion in new_text coordinates
+ new_pos = unchanged_end_in_new;
+
+ // Check if cursor is within this edit's new text
+ let edit_new_text_end = new_pos + new_text_chunk.len();
+ if cursor_offset_in_new < edit_new_text_end {
+ // Cursor is inside this edit's inserted text.
+ // Use anchor_before (left bias) so the anchor stays at the insertion point
+ // rather than moving past the inserted text.
+ let offset_within_insertion = cursor_offset_in_new - new_pos;
+ let buffer_offset = editable_region_start + old_range.start;
+ return PredictedCursorPosition::new(
+ snapshot.anchor_before(buffer_offset),
+ offset_within_insertion,
+ );
+ }
+
+ // Move past this edit
+ old_pos = old_range.end;
+ new_pos = edit_new_text_end;
+ }
+
+ // Cursor is in unchanged text after all diffs
+ let offset_in_unchanged = cursor_offset_in_new - new_pos;
+ let buffer_offset = (editable_region_start + old_pos + offset_in_unchanged).min(snapshot.len());
+ PredictedCursorPosition::at_anchor(snapshot.anchor_after(buffer_offset))
+}
+
pub fn zeta2_prompt_input(
snapshot: &language::BufferSnapshot,
related_files: Vec<zeta_prompt::RelatedFile>,
@@ -15,10 +15,12 @@ pub async fn run_distill(example: &mut Example) -> Result<()> {
let expected_patches = predictions
.into_iter()
- .filter_map(|p| p.actual_patch.clone())
+ .filter_map(|p| Some((p.actual_patch.clone()?, p.actual_cursor_offset)))
.collect();
- example.spec.expected_patches = expected_patches;
+ example
+ .spec
+ .set_expected_patches_with_cursor_positions(expected_patches);
example.prompt = None;
example.predictions = Vec::new();
example.score = Vec::new();
@@ -86,6 +86,8 @@ pub struct ExamplePrediction {
#[serde(deserialize_with = "deserialize_null_as_empty_string")]
pub actual_output: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
+ pub actual_cursor_offset: Option<usize>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
pub provider: PredictionProvider,
}
@@ -110,6 +112,10 @@ pub struct ExampleScore {
pub exact_lines_fn: usize,
#[serde(default)]
pub reversal_ratio: f32,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub cursor_distance: Option<usize>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub cursor_exact_match: Option<bool>,
}
impl Example {
@@ -6,7 +6,7 @@ use crate::{
retrieve_context::run_context_retrieval,
};
use anyhow::{Context as _, Result, anyhow};
-use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
+use edit_prediction::{cursor_excerpt::editable_and_context_ranges_for_cursor_position, udiff};
use gpui::{AppContext, AsyncApp};
use language::{Buffer, OffsetRangeExt, Point};
use similar::DiffableStr;
@@ -61,18 +61,10 @@ pub async fn run_format_prompt(
let context_range = context_range.to_offset(&snapshot);
let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
- let expected_output = example
- .spec
- .expected_patches
- .first()
- .cloned()
- .unwrap_or_default();
- let rejected_output = example.spec.rejected_patch.clone();
-
example.prompt = Some(ExamplePrompt {
input: prompt,
- expected_output,
- rejected_output,
+ expected_output: String::new(),
+ rejected_output: None,
provider: args.provider,
});
}
@@ -102,21 +94,19 @@ pub async fn run_format_prompt(
related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
};
let prompt = format_zeta_prompt(&input, version);
- let expected_output = zeta2_output_for_patch(
- &input,
- &example
- .spec
- .expected_patches
- .first()
- .context("expected patches is empty")?
- .clone(),
- version,
- )?;
+ let (expected_patch, expected_cursor_offset) = example
+ .spec
+ .expected_patches_with_cursor_positions()
+ .into_iter()
+ .next()
+ .context("expected patches is empty")?;
+ let expected_output =
+ zeta2_output_for_patch(&input, &expected_patch, expected_cursor_offset, version)?;
let rejected_output = example
.spec
.rejected_patch
.as_ref()
- .and_then(|patch| zeta2_output_for_patch(&input, &patch, version).ok());
+ .and_then(|patch| zeta2_output_for_patch(&input, patch, None, version).ok());
example.prompt = Some(ExamplePrompt {
input: prompt,
@@ -135,6 +125,7 @@ pub async fn run_format_prompt(
pub fn zeta2_output_for_patch(
input: &zeta_prompt::ZetaPromptInput,
patch: &str,
+ cursor_offset: Option<usize>,
version: ZetaVersion,
) -> Result<String> {
let mut old_editable_region =
@@ -144,13 +135,24 @@ pub fn zeta2_output_for_patch(
old_editable_region.push('\n');
}
- let mut result = edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region)
- .with_context(|| {
- format!(
- "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
- patch, old_editable_region
- )
- })?;
+ let (mut result, first_hunk_offset) =
+ udiff::apply_diff_to_string_with_hunk_offset(patch, &old_editable_region).with_context(
+ || {
+ format!(
+ "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
+ patch, old_editable_region
+ )
+ },
+ )?;
+
+ if let Some(cursor_offset) = cursor_offset {
+ // The cursor_offset is relative to the start of the hunk's new text (context + additions).
+ // We need to add where the hunk context matched in the editable region to compute
+ // the actual cursor position in the result.
+ let hunk_start = first_hunk_offset.unwrap_or(0);
+ let offset = (hunk_start + cursor_offset).min(result.len());
+ result.insert_str(offset, zeta_prompt::CURSOR_MARKER);
+ }
if version == ZetaVersion::V0120GitMergeMarkers {
if !result.ends_with('\n') {
@@ -191,24 +193,28 @@ impl TeacherPrompt {
prompt
}
- pub fn parse(example: &Example, response: &str) -> Result<String> {
+ pub fn parse(example: &Example, response: &str) -> Result<(String, Option<usize>)> {
// Extract updated (new) editable region from the model response.
// The model may include editable region markers in its output, so we need to strip them.
let new_editable_region = extract_last_codeblock(response);
// Check if the model indicated no edits are needed
if new_editable_region.trim() == Self::NO_EDITS {
- return Ok(String::new());
+ return Ok((String::new(), None));
}
- let mut new_editable_region = Self::extract_editable_region(&new_editable_region)?;
+ let new_editable_region = Self::extract_editable_region(&new_editable_region)?;
+ let cursor_offset = new_editable_region.find(Self::USER_CURSOR_MARKER);
+ let mut new_editable_region = new_editable_region.replace(Self::USER_CURSOR_MARKER, "");
let old_editable_region = Self::extract_editable_region(
&example
.prompt
.as_ref()
.context("example prompt missing")?
.input,
- )?;
+ )?
+ .replace(Self::USER_CURSOR_MARKER, "");
+
let prompt_inputs = example
.prompt_inputs
.as_ref()
@@ -230,11 +236,14 @@ impl TeacherPrompt {
.matches('\n')
.count();
- let diff = language::unified_diff_with_offsets(
+ // Use full context so cursor offset (relative to editable region start) aligns with diff content
+ let editable_region_lines = old_editable_region.lines().count() as u32;
+ let diff = language::unified_diff_with_context(
&old_editable_region,
&new_editable_region,
editable_region_start_line as u32,
editable_region_start_line as u32,
+ editable_region_lines,
);
let diff = indoc::formatdoc! {"
@@ -245,7 +254,7 @@ impl TeacherPrompt {
diff = diff,
};
- Ok(diff)
+ Ok((diff, cursor_offset))
}
fn format_edit_history(edit_history: &str) -> String {
@@ -328,7 +337,7 @@ impl TeacherPrompt {
result
}
- fn extract_editable_region(text: &str) -> Result<String> {
+ pub fn extract_editable_region(text: &str) -> Result<String> {
let start = text
.rfind(Self::EDITABLE_REGION_START)
.map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
@@ -339,9 +348,7 @@ impl TeacherPrompt {
}
let region = &text[start..end];
- let region = region.strip_suffix('\n').unwrap_or(region);
-
- Ok(region.replace(Self::USER_CURSOR_MARKER, ""))
+ Ok(region.strip_suffix('\n').unwrap_or(region).to_string())
}
fn is_udiff_content_line(s: &str) -> bool {
@@ -571,22 +578,4 @@ mod tests {
let codeblock = extract_last_codeblock(response);
assert_eq!(codeblock.trim(), TeacherPrompt::NO_EDITS);
}
-
- #[test]
- fn test_extract_editable_region_strips_cursor_marker() {
- let text = indoc::indoc! {"
- <|editable_region_start|>
- one
- <|user_cursor|>two three
-
- <|editable_region_end|>
- "};
- let parsed = TeacherPrompt::extract_editable_region(text).unwrap();
- assert_eq!(
- parsed,
- indoc::indoc! {"
- one
- two three"}
- );
- }
}
@@ -19,14 +19,14 @@ pub fn run_parse_output(example: &mut Example) -> Result<()> {
.enumerate()
.filter(|(_, p)| !p.actual_output.is_empty())
.map(|(ix, prediction)| {
- let actual_patch =
- parse_prediction_output(example, &prediction.actual_output, provider);
- actual_patch.map(|patch| (ix, patch))
+ let result = parse_prediction_output(example, &prediction.actual_output, provider);
+ result.map(|(patch, cursor_offset)| (ix, patch, cursor_offset))
})
.collect::<Result<Vec<_>>>()?;
- for (ix, actual_patch) in parsed_patches {
+ for (ix, actual_patch, actual_cursor_offset) in parsed_patches {
example.predictions[ix].actual_patch = Some(actual_patch);
+ example.predictions[ix].actual_cursor_offset = actual_cursor_offset;
example.predictions[ix].provider = provider;
}
@@ -37,7 +37,7 @@ pub fn parse_prediction_output(
example: &Example,
actual_output: &str,
provider: PredictionProvider,
-) -> Result<String> {
+) -> Result<(String, Option<usize>)> {
match provider {
PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
TeacherPrompt::parse(example, actual_output)
@@ -83,7 +83,7 @@ fn parse_zeta2_output(
example: &Example,
actual_output: &str,
version: ZetaVersion,
-) -> Result<String> {
+) -> Result<(String, Option<usize>)> {
let prompt = &example.prompt.as_ref().context("prompt required")?.input;
let prompt_inputs = example
.prompt_inputs
@@ -92,7 +92,13 @@ fn parse_zeta2_output(
let old_text = extract_zeta2_current_region(prompt, version)?;
- let mut new_text = actual_output.replace(CURSOR_MARKER, "");
+ let mut new_text = actual_output.to_string();
+ let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
+ new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
+ Some(offset)
+ } else {
+ None
+ };
if version == ZetaVersion::V0120GitMergeMarkers {
if let Some(stripped) =
@@ -126,11 +132,14 @@ fn parse_zeta2_output(
.matches('\n')
.count();
- let diff = language::unified_diff_with_offsets(
+ // Use full context so cursor offset (relative to editable region start) aligns with diff content
+ let editable_region_lines = old_text_normalized.lines().count() as u32;
+ let diff = language::unified_diff_with_context(
&old_text_normalized,
&new_text,
editable_region_start_line as u32,
editable_region_start_line as u32,
+ editable_region_lines,
);
let formatted_diff = format!(
@@ -138,7 +147,7 @@ fn parse_zeta2_output(
path = example.spec.cursor_path.to_string_lossy(),
);
- Ok(formatted_diff)
+ Ok((formatted_diff, cursor_offset))
}
#[cfg(test)]
@@ -201,6 +201,7 @@ pub async fn run_prediction(
.push(ExamplePrediction {
actual_patch: None,
actual_output: String::new(),
+ actual_cursor_offset: None,
error: None,
provider,
});
@@ -322,11 +323,12 @@ async fn predict_anthropic(
.collect::<Vec<String>>()
.join("\n");
- let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
+ let (actual_patch, actual_cursor_offset) = TeacherPrompt::parse(example, &actual_output)?;
let prediction = ExamplePrediction {
actual_patch: Some(actual_patch),
actual_output,
+ actual_cursor_offset,
error: None,
provider: if batched {
PredictionProvider::Teacher(backend)
@@ -394,11 +396,12 @@ async fn predict_openai(
.collect::<Vec<String>>()
.join("\n");
- let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
+ let (actual_patch, actual_cursor_offset) = TeacherPrompt::parse(example, &actual_output)?;
let prediction = ExamplePrediction {
actual_patch: Some(actual_patch),
actual_output,
+ actual_cursor_offset,
error: None,
provider: if batched {
PredictionProvider::Teacher(backend)
@@ -125,11 +125,12 @@ pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
/// Parse the repair response into a prediction.
fn parse_repair_response(example: &Example, response_text: &str) -> Result<ExamplePrediction> {
- let actual_patch = TeacherPrompt::parse(example, response_text)?;
+ let (actual_patch, actual_cursor_offset) = TeacherPrompt::parse(example, response_text)?;
Ok(ExamplePrediction {
actual_patch: Some(actual_patch),
actual_output: response_text.to_string(),
+ actual_cursor_offset,
error: None,
provider: PredictionProvider::Repair,
})
@@ -362,6 +363,7 @@ pub async fn run_repair(
example.predictions.push(ExamplePrediction {
actual_patch: None,
actual_output: response_text.clone(),
+ actual_cursor_offset: None,
error: Some(format!("Failed to parse repair response: {}", e)),
provider: PredictionProvider::Repair,
});
@@ -1,6 +1,7 @@
use crate::{
- PredictArgs,
+ PredictArgs, PredictionProvider,
example::{Example, ExampleScore},
+ format_prompt::TeacherPrompt,
headless::EpAppState,
metrics,
parse_output::parse_prediction_output,
@@ -9,7 +10,7 @@ use crate::{
reversal_tracking,
};
use anyhow::Context as _;
-use edit_prediction::udiff::apply_diff_to_string;
+use edit_prediction::udiff::{apply_diff_to_string, apply_diff_to_string_with_hunk_offset};
use gpui::AsyncApp;
use serde::Serialize;
use std::fs::File;
@@ -34,16 +35,36 @@ pub async fn run_scoring(
.as_ref()
.context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
.content;
- let expected_texts: Vec<String> = example
- .spec
- .expected_patches
+ let expected_patches_with_cursors = example.spec.expected_patches_with_cursor_positions();
+
+ let expected_texts: Vec<String> = expected_patches_with_cursors
.iter()
- .map(|patch| {
+ .map(|(patch, _)| {
apply_diff_to_string(patch, original_text)
.with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
})
.collect::<Result<Vec<_>, _>>()?;
+ // For Teacher prompts, we need to extract the editable region to properly compute cursor offsets.
+ // The actual_cursor_offset from Teacher is relative to the editable region, while the expected
+ // cursor from the patch is relative to the hunk. We need to apply the patch to the editable
+ // region to find where the hunk matched, then compute the expected cursor position.
+ let old_editable_region = if let Some(p) = example.prompt.as_ref() {
+ if matches!(
+ p.provider,
+ PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_)
+ ) {
+ Some(
+ TeacherPrompt::extract_editable_region(&p.input)?
+ .replace(TeacherPrompt::USER_CURSOR_MARKER, ""),
+ )
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
let zero_scores = ExampleScore {
delta_chr_f: 0.0,
braces_disbalance: 0,
@@ -51,6 +72,8 @@ pub async fn run_scoring(
exact_lines_fp: 0,
exact_lines_fn: 0,
reversal_ratio: 0.0,
+ cursor_distance: None,
+ cursor_exact_match: None,
};
let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
@@ -60,7 +83,9 @@ pub async fn run_scoring(
let mut scores = vec![];
for prediction in &example.predictions {
let actual_patch = prediction.actual_patch.clone().or_else(|| {
- parse_prediction_output(example, &prediction.actual_output, prediction.provider).ok()
+ parse_prediction_output(example, &prediction.actual_output, prediction.provider)
+ .ok()
+ .map(|(patch, _)| patch)
});
let Some(actual_patch) = actual_patch else {
@@ -75,10 +100,42 @@ pub async fn run_scoring(
continue;
}
};
- let best_delta_chr_f = expected_texts
- .iter()
- .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
- .fold(0.0, f32::max);
+
+ let mut best_delta_chr_f = 0.0f32;
+ let mut best_expected_cursor: Option<usize> = None;
+ let mut best_patch_idx: Option<usize> = None;
+
+ for (idx, expected) in expected_texts.iter().enumerate() {
+ let delta_chr_f = metrics::delta_chr_f(original_text, expected, &actual_text) as f32;
+ if delta_chr_f > best_delta_chr_f {
+ best_delta_chr_f = delta_chr_f;
+ best_patch_idx = Some(idx);
+ }
+ }
+
+ if let Some(idx) = best_patch_idx {
+ // Get the raw cursor offset from the expected patch (relative to hunk new text)
+ let expected_cursor_in_patch = expected_patches_with_cursors
+ .get(idx)
+ .and_then(|(_, cursor)| *cursor);
+
+ // For Teacher prompts, we need to apply the patch to the editable region
+ // to find where the hunk matched, then compute the actual cursor position
+ if let (Some(editable_region), Some(cursor_in_patch)) =
+ (&old_editable_region, expected_cursor_in_patch)
+ {
+ let (patch, _) = &expected_patches_with_cursors[idx];
+ if let Ok((_, hunk_offset)) =
+ apply_diff_to_string_with_hunk_offset(patch, editable_region)
+ {
+ let hunk_start = hunk_offset.unwrap_or(0);
+ best_expected_cursor = Some(hunk_start + cursor_in_patch);
+ }
+ } else {
+ // For non-Teacher prompts or if we can't compute, use raw offset
+ best_expected_cursor = expected_cursor_in_patch;
+ }
+ }
let disbalance_before = metrics::braces_disbalance(&original_text);
let disbalance_after = metrics::braces_disbalance(&actual_text);
@@ -95,11 +152,9 @@ pub async fn run_scoring(
}
// Compute exact lines match against best matching expected patch
- let best_exact_lines = example
- .spec
- .expected_patches
+ let best_exact_lines = expected_patches_with_cursors
.iter()
- .map(|expected_patch| metrics::exact_lines_match(expected_patch, &actual_patch))
+ .map(|(expected_patch, _)| metrics::exact_lines_match(expected_patch, &actual_patch))
.max_by_key(|m| m.true_positives)
.unwrap_or_default();
@@ -110,6 +165,10 @@ pub async fn run_scoring(
cursor_path,
);
+ // Compute cursor position metrics
+ let (cursor_distance, cursor_exact_match) =
+ compute_cursor_metrics(best_expected_cursor, prediction.actual_cursor_offset);
+
scores.push(ExampleScore {
delta_chr_f: best_delta_chr_f,
braces_disbalance,
@@ -117,6 +176,8 @@ pub async fn run_scoring(
exact_lines_fp: best_exact_lines.false_positives,
exact_lines_fn: best_exact_lines.false_negatives,
reversal_ratio,
+ cursor_distance,
+ cursor_exact_match,
});
}
@@ -124,16 +185,37 @@ pub async fn run_scoring(
Ok(())
}
+fn compute_cursor_metrics(
+ expected_cursor: Option<usize>,
+ actual_cursor: Option<usize>,
+) -> (Option<usize>, Option<bool>) {
+ match (expected_cursor, actual_cursor) {
+ (Some(expected), Some(actual)) => {
+ let distance = expected.abs_diff(actual);
+ let exact_match = expected == actual;
+ (Some(distance), Some(exact_match))
+ }
+ (None, None) => {
+ // Neither has cursor position - skip cursor scoring
+ (None, None)
+ }
+ (Some(_), None) | (None, Some(_)) => {
+ // Only one has cursor position - count as miss
+ (None, Some(false))
+ }
+ }
+}
+
pub fn print_report(examples: &[Example]) {
use crate::metrics::ClassificationMetrics;
- const LINE_WIDTH: usize = 82;
+ const LINE_WIDTH: usize = 94;
let separator = "─".repeat(LINE_WIDTH);
println!("{}", separator);
println!(
- "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7}",
- "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf"
+ "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7} {:>6}",
+ "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf", "Cursor"
);
println!("{}", separator);
@@ -146,6 +228,10 @@ pub fn print_report(examples: &[Example]) {
let mut qa_reverts_total: usize = 0;
let mut qa_confidence_sum: u64 = 0;
let mut qa_confidence_count: usize = 0;
+ let mut cursor_exact_matches: usize = 0;
+ let mut cursor_total: usize = 0;
+ let mut cursor_distance_sum: usize = 0;
+ let mut cursor_distance_count: usize = 0;
for example in examples {
for (score_idx, score) in example.score.iter().enumerate() {
@@ -166,15 +252,24 @@ pub fn print_report(examples: &[Example]) {
.map(|v| format!("{}", v))
.unwrap_or("-".to_string());
+ // Format cursor metric
+ let cursor_str = match (score.cursor_exact_match, score.cursor_distance) {
+ (Some(true), _) => "✓".to_string(),
+ (Some(false), Some(dist)) => format!("±{}", dist),
+ (Some(false), None) => "✗".to_string(),
+ (None, _) => "-".to_string(),
+ };
+
println!(
- "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7}",
+ "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6}",
truncate_name(&example.spec.name, 40),
score.delta_chr_f,
score.braces_disbalance,
exact_lines.f1() * 100.0,
score.reversal_ratio * 100.0,
qa_reverts_str,
- qa_conf_str
+ qa_conf_str,
+ cursor_str
);
all_delta_chr_f_scores.push(score.delta_chr_f);
@@ -198,6 +293,18 @@ pub fn print_report(examples: &[Example]) {
qa_confidence_count += 1;
}
}
+
+ // Accumulate cursor metrics
+ if let Some(exact_match) = score.cursor_exact_match {
+ cursor_total += 1;
+ if exact_match {
+ cursor_exact_matches += 1;
+ }
+ }
+ if let Some(dist) = score.cursor_distance {
+ cursor_distance_sum += dist;
+ cursor_distance_count += 1;
+ }
}
}
@@ -226,18 +333,43 @@ pub fn print_report(examples: &[Example]) {
} else {
"-".to_string()
};
+ let cursor_str = if cursor_total > 0 {
+ format!(
+ "{:.0}%",
+ cursor_exact_matches as f32 / cursor_total as f32 * 100.0
+ )
+ } else {
+ "-".to_string()
+ };
+ let avg_cursor_distance = if cursor_distance_count > 0 {
+ Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
+ } else {
+ None
+ };
println!(
- "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7}",
+ "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6}",
"TOTAL / AVERAGE",
avg_delta_chr_f,
braces_disbalance_avg,
total_exact_lines.f1() * 100.0,
avg_reversal_ratio * 100.0,
qa_reverts_str,
- qa_conf_str
+ qa_conf_str,
+ cursor_str
);
println!("{}", separator);
+
+ // Print additional cursor metrics if available
+ if let Some(avg_dist) = avg_cursor_distance {
+ println!(
+ "Cursor: {}/{} exact matches ({:.0}%), avg distance: {:.1} bytes",
+ cursor_exact_matches,
+ cursor_total,
+ cursor_exact_matches as f32 / cursor_total as f32 * 100.0,
+ avg_dist
+ );
+ }
}
println!("\n");
@@ -267,6 +399,12 @@ pub struct SummaryJson {
pub qa_avg_reverts_edits: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub qa_avg_confidence: Option<f32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub cursor_exact_match_rate: Option<f32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub cursor_avg_distance: Option<f32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub cursor_total_evaluated: Option<usize>,
}
pub fn compute_summary(examples: &[Example]) -> SummaryJson {
@@ -281,6 +419,10 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
let mut qa_reverts_total: usize = 0;
let mut qa_confidence_sum: u64 = 0;
let mut qa_confidence_count: usize = 0;
+ let mut cursor_exact_matches: usize = 0;
+ let mut cursor_total: usize = 0;
+ let mut cursor_distance_sum: usize = 0;
+ let mut cursor_distance_count: usize = 0;
for example in examples {
for (score_idx, score) in example.score.iter().enumerate() {
@@ -305,6 +447,18 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
qa_confidence_count += 1;
}
}
+
+ // Accumulate cursor metrics
+ if let Some(exact_match) = score.cursor_exact_match {
+ cursor_total += 1;
+ if exact_match {
+ cursor_exact_matches += 1;
+ }
+ }
+ if let Some(dist) = score.cursor_distance {
+ cursor_distance_sum += dist;
+ cursor_distance_count += 1;
+ }
}
}
@@ -338,6 +492,24 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
None
};
+ let cursor_exact_match_rate = if cursor_total > 0 {
+ Some(cursor_exact_matches as f32 / cursor_total as f32)
+ } else {
+ None
+ };
+
+ let cursor_avg_distance = if cursor_distance_count > 0 {
+ Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
+ } else {
+ None
+ };
+
+ let cursor_total_evaluated = if cursor_total > 0 {
+ Some(cursor_total)
+ } else {
+ None
+ };
+
SummaryJson {
total_examples: total_scores,
avg_delta_chr_f,
@@ -351,6 +523,9 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
avg_reversal_ratio,
qa_avg_reverts_edits,
qa_avg_confidence,
+ cursor_exact_match_rate,
+ cursor_avg_distance,
+ cursor_total_evaluated,
}
}
@@ -4,6 +4,34 @@ use client::EditPredictionUsage;
use gpui::{App, Context, Entity, SharedString};
use language::{Anchor, Buffer, OffsetRangeExt};
+/// Represents a predicted cursor position after an edit is applied.
+///
+/// Since the cursor may be positioned inside newly inserted text that doesn't
+/// exist in the original buffer, we store an anchor (which points to a position
+/// in the original buffer, typically the start of an edit) plus an offset into
+/// the inserted text.
+#[derive(Clone, Debug)]
+pub struct PredictedCursorPosition {
+ /// An anchor in the original buffer. If the cursor is inside an edit,
+ /// this points to the start of that edit's range.
+ pub anchor: language::Anchor,
+ /// Offset from the anchor into the new text. If the cursor is inside
+ /// inserted text, this is the offset within that insertion. If the cursor
+ /// is outside any edit, this is 0.
+ pub offset: usize,
+}
+
+impl PredictedCursorPosition {
+ pub fn new(anchor: language::Anchor, offset: usize) -> Self {
+ Self { anchor, offset }
+ }
+
+ /// Creates a predicted cursor position at an exact anchor location (offset = 0).
+ pub fn at_anchor(anchor: language::Anchor) -> Self {
+ Self { anchor, offset: 0 }
+ }
+}
+
/// The display mode used when showing an edit prediction to the user.
/// Used for metrics tracking.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -29,6 +57,7 @@ pub enum EditPrediction {
Local {
id: Option<SharedString>,
edits: Vec<(Range<language::Anchor>, Arc<str>)>,
+ cursor_position: Option<PredictedCursorPosition>,
edit_preview: Option<language::EditPreview>,
},
/// Jump to a different file from the one that requested the prediction
@@ -1,4 +1,4 @@
-use edit_prediction_types::EditPredictionDelegate;
+use edit_prediction_types::{EditPredictionDelegate, PredictedCursorPosition};
use gpui::{Entity, KeyBinding, Modifiers, prelude::*};
use indoc::indoc;
use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
@@ -32,6 +32,88 @@ async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) {
cx.assert_editor_state("let absolute_zero_celsius = -273.15ˇ;")
}
+#[gpui::test]
+async fn test_edit_prediction_cursor_position_inside_insertion(cx: &mut gpui::TestAppContext) {
+ init_test(cx, |_| {});
+
+ let mut cx = EditorTestContext::new(cx).await;
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
+ assign_editor_completion_provider(provider.clone(), &mut cx);
+ // Buffer: "fn foo() {}" - we'll insert text and position cursor inside the insertion
+ cx.set_state("fn foo() ˇ{}");
+
+ // Insert "bar()" at offset 9, with cursor at offset 2 within the insertion (after "ba")
+ // This tests the case where cursor is inside newly inserted text
+ propose_edits_with_cursor_position_in_insertion(
+ &provider,
+ vec![(9..9, "bar()")],
+ 9, // anchor at the insertion point
+ 2, // offset 2 within "bar()" puts cursor after "ba"
+ &mut cx,
+ );
+ cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
+
+ assert_editor_active_edit_completion(&mut cx, |_, edits| {
+ assert_eq!(edits.len(), 1);
+ assert_eq!(edits[0].1.as_ref(), "bar()");
+ });
+
+ accept_completion(&mut cx);
+
+ // Cursor should be inside the inserted text at "baˇr()"
+ cx.assert_editor_state("fn foo() baˇr(){}");
+}
+
+#[gpui::test]
+async fn test_edit_prediction_cursor_position_outside_edit(cx: &mut gpui::TestAppContext) {
+ init_test(cx, |_| {});
+
+ let mut cx = EditorTestContext::new(cx).await;
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
+ assign_editor_completion_provider(provider.clone(), &mut cx);
+ // Buffer: "let x = ;" with cursor before semicolon - we'll insert "42" and position cursor elsewhere
+ cx.set_state("let x = ˇ;");
+
+ // Insert "42" at offset 8, but set cursor_position to offset 4 (the 'x')
+ // This tests that cursor moves to the predicted position, not the end of the edit
+ propose_edits_with_cursor_position(
+ &provider,
+ vec![(8..8, "42")],
+ Some(4), // cursor at offset 4 (the 'x'), NOT at the edit location
+ &mut cx,
+ );
+ cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
+
+ assert_editor_active_edit_completion(&mut cx, |_, edits| {
+ assert_eq!(edits.len(), 1);
+ assert_eq!(edits[0].1.as_ref(), "42");
+ });
+
+ accept_completion(&mut cx);
+
+ // Cursor should be at offset 4 (the 'x'), not at the end of the inserted "42"
+ cx.assert_editor_state("let ˇx = 42;");
+}
+
+#[gpui::test]
+async fn test_edit_prediction_cursor_position_fallback(cx: &mut gpui::TestAppContext) {
+ init_test(cx, |_| {});
+
+ let mut cx = EditorTestContext::new(cx).await;
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
+ assign_editor_completion_provider(provider.clone(), &mut cx);
+ cx.set_state("let x = ˇ;");
+
+ // Propose an edit without a cursor position - should fall back to end of edit
+ propose_edits(&provider, vec![(8..8, "42")], &mut cx);
+ cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
+
+ accept_completion(&mut cx);
+
+ // Cursor should be at the end of the inserted text (default behavior)
+ cx.assert_editor_state("let x = 42ˇ;")
+}
+
#[gpui::test]
async fn test_edit_prediction_modification(cx: &mut gpui::TestAppContext) {
init_test(cx, |_| {});
@@ -374,8 +456,50 @@ fn propose_edits<T: ToOffset>(
provider: &Entity<FakeEditPredictionDelegate>,
edits: Vec<(Range<T>, &str)>,
cx: &mut EditorTestContext,
+) {
+ propose_edits_with_cursor_position(provider, edits, None, cx);
+}
+
+fn propose_edits_with_cursor_position<T: ToOffset>(
+ provider: &Entity<FakeEditPredictionDelegate>,
+ edits: Vec<(Range<T>, &str)>,
+ cursor_offset: Option<usize>,
+ cx: &mut EditorTestContext,
+) {
+ let snapshot = cx.buffer_snapshot();
+ let cursor_position = cursor_offset
+ .map(|offset| PredictedCursorPosition::at_anchor(snapshot.anchor_after(offset)));
+ let edits = edits.into_iter().map(|(range, text)| {
+ let range = snapshot.anchor_after(range.start)..snapshot.anchor_before(range.end);
+ (range, text.into())
+ });
+
+ cx.update(|_, cx| {
+ provider.update(cx, |provider, _| {
+ provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local {
+ id: None,
+ edits: edits.collect(),
+ cursor_position,
+ edit_preview: None,
+ }))
+ })
+ });
+}
+
+fn propose_edits_with_cursor_position_in_insertion<T: ToOffset>(
+ provider: &Entity<FakeEditPredictionDelegate>,
+ edits: Vec<(Range<T>, &str)>,
+ anchor_offset: usize,
+ offset_within_insertion: usize,
+ cx: &mut EditorTestContext,
) {
let snapshot = cx.buffer_snapshot();
+ // Use anchor_before (left bias) so the anchor stays at the insertion point
+ // rather than moving past the inserted text
+ let cursor_position = Some(PredictedCursorPosition::new(
+ snapshot.anchor_before(anchor_offset),
+ offset_within_insertion,
+ ));
let edits = edits.into_iter().map(|(range, text)| {
let range = snapshot.anchor_after(range.start)..snapshot.anchor_before(range.end);
(range, text.into())
@@ -386,6 +510,7 @@ fn propose_edits<T: ToOffset>(
provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local {
id: None,
edits: edits.collect(),
+ cursor_position,
edit_preview: None,
}))
})
@@ -417,6 +542,7 @@ fn propose_edits_non_zed<T: ToOffset>(
provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local {
id: None,
edits: edits.collect(),
+ cursor_position: None,
edit_preview: None,
}))
})
@@ -634,6 +634,10 @@ pub(crate) enum EditDisplayMode {
enum EditPrediction {
Edit {
edits: Vec<(Range<Anchor>, Arc<str>)>,
+ /// Predicted cursor position as (anchor, offset_from_anchor).
+ /// The anchor is in multibuffer coordinates; after applying edits,
+ /// resolve the anchor and add the offset to get the final cursor position.
+ cursor_position: Option<(Anchor, usize)>,
edit_preview: Option<EditPreview>,
display_mode: EditDisplayMode,
snapshot: BufferSnapshot,
@@ -7872,7 +7876,11 @@ impl Editor {
.detach_and_log_err(cx);
}
}
- EditPrediction::Edit { edits, .. } => {
+ EditPrediction::Edit {
+ edits,
+ cursor_position,
+ ..
+ } => {
self.report_edit_prediction_event(
active_edit_prediction.completion_id.clone(),
true,
@@ -7886,15 +7894,32 @@ impl Editor {
}
let transaction_id_prev = self.buffer.read(cx).last_transaction_id(cx);
- let snapshot = self.buffer.read(cx).snapshot(cx);
- let last_edit_end = edits.last().unwrap().0.end.bias_right(&snapshot);
+
+ // Compute fallback cursor position BEFORE applying the edit,
+ // so the anchor tracks through the edit correctly
+ let fallback_cursor_target = {
+ let snapshot = self.buffer.read(cx).snapshot(cx);
+ edits.last().unwrap().0.end.bias_right(&snapshot)
+ };
self.buffer.update(cx, |buffer, cx| {
buffer.edit(edits.iter().cloned(), None, cx)
});
+ // Resolve cursor position after the edit is applied
+ let cursor_target = if let Some((anchor, offset)) = cursor_position {
+ // The anchor tracks through the edit, then we add the offset
+ let snapshot = self.buffer.read(cx).snapshot(cx);
+ let base_offset = anchor.to_offset(&snapshot).0;
+ let target_offset =
+ MultiBufferOffset((base_offset + offset).min(snapshot.len().0));
+ snapshot.anchor_after(target_offset)
+ } else {
+ fallback_cursor_target
+ };
+
self.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
- s.select_anchor_ranges([last_edit_end..last_edit_end]);
+ s.select_anchor_ranges([cursor_target..cursor_target]);
});
let selections = self.selections.disjoint_anchors_arc();
@@ -8358,12 +8383,14 @@ impl Editor {
let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?;
- let (completion_id, edits, edit_preview) = match edit_prediction {
+ let (completion_id, edits, predicted_cursor_position, edit_preview) = match edit_prediction
+ {
edit_prediction_types::EditPrediction::Local {
id,
edits,
+ cursor_position,
edit_preview,
- } => (id, edits, edit_preview),
+ } => (id, edits, cursor_position, edit_preview),
edit_prediction_types::EditPrediction::Jump {
id,
snapshot,
@@ -8397,6 +8424,11 @@ impl Editor {
return None;
}
+ let cursor_position = predicted_cursor_position.and_then(|predicted| {
+ let anchor = multibuffer.anchor_in_excerpt(excerpt_id, predicted.anchor)?;
+ Some((anchor, predicted.offset))
+ });
+
let first_edit_start = edits.first().unwrap().0.start;
let first_edit_start_point = first_edit_start.to_point(&multibuffer);
let edit_start_row = first_edit_start_point.row.saturating_sub(2);
@@ -8491,6 +8523,7 @@ impl Editor {
EditPrediction::Edit {
edits,
+ cursor_position,
edit_preview,
display_mode,
snapshot,
@@ -9179,6 +9212,7 @@ impl Editor {
edit_preview,
display_mode: EditDisplayMode::DiffPopover,
snapshot,
+ ..
} => self.render_edit_prediction_diff_popover(
text_bounds,
content_origin,
@@ -10129,7 +10163,7 @@ impl Editor {
edits,
edit_preview,
snapshot,
- display_mode: _,
+ ..
} => {
let first_edit_row = edits.first()?.0.start.text_anchor.to_point(snapshot).row;
@@ -9017,6 +9017,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext)
provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local {
id: None,
edits: vec![(edit_position..edit_position, "X".into())],
+ cursor_position: None,
edit_preview: None,
}))
})
@@ -67,7 +67,7 @@ use task::RunnableTag;
pub use task_context::{ContextLocation, ContextProvider, RunnableRange};
pub use text_diff::{
DiffOptions, apply_diff_patch, line_diff, text_diff, text_diff_with_options, unified_diff,
- unified_diff_with_offsets, word_diff_ranges,
+ unified_diff_with_context, unified_diff_with_offsets, word_diff_ranges,
};
use theme::SyntaxTheme;
pub use toolchain::{
@@ -22,12 +22,25 @@ pub fn unified_diff_with_offsets(
new_text: &str,
old_start_line: u32,
new_start_line: u32,
+) -> String {
+ unified_diff_with_context(old_text, new_text, old_start_line, new_start_line, 3)
+}
+
+/// Computes a diff between two strings, returning a unified diff string with
+/// hunk headers adjusted to reflect the given starting line numbers (zero-indexed),
+/// and a configurable number of context lines around changes.
+pub fn unified_diff_with_context(
+ old_text: &str,
+ new_text: &str,
+ old_start_line: u32,
+ new_start_line: u32,
+ context_lines: u32,
) -> String {
let input = InternedInput::new(old_text, new_text);
diff(
Algorithm::Histogram,
&input,
- OffsetUnifiedDiffBuilder::new(&input, old_start_line, new_start_line),
+ OffsetUnifiedDiffBuilder::new(&input, old_start_line, new_start_line, context_lines),
)
}
@@ -45,13 +58,19 @@ struct OffsetUnifiedDiffBuilder<'a> {
old_line_offset: u32,
new_line_offset: u32,
+ context_lines: u32,
buffer: String,
dst: String,
}
impl<'a> OffsetUnifiedDiffBuilder<'a> {
- fn new(input: &'a InternedInput<&'a str>, old_line_offset: u32, new_line_offset: u32) -> Self {
+ fn new(
+ input: &'a InternedInput<&'a str>,
+ old_line_offset: u32,
+ new_line_offset: u32,
+ context_lines: u32,
+ ) -> Self {
Self {
before_hunk_start: 0,
after_hunk_start: 0,
@@ -59,6 +78,7 @@ impl<'a> OffsetUnifiedDiffBuilder<'a> {
after_hunk_len: 0,
old_line_offset,
new_line_offset,
+ context_lines,
buffer: String::with_capacity(8),
dst: String::new(),
interner: &input.interner,
@@ -79,7 +99,7 @@ impl<'a> OffsetUnifiedDiffBuilder<'a> {
return;
}
- let end = (self.pos + 3).min(self.before.len() as u32);
+ let end = (self.pos + self.context_lines).min(self.before.len() as u32);
self.update_pos(end, end);
writeln!(
@@ -110,13 +130,13 @@ impl Sink for OffsetUnifiedDiffBuilder<'_> {
type Out = String;
fn process_change(&mut self, before: Range<u32>, after: Range<u32>) {
- if before.start - self.pos > 6 {
+ if before.start - self.pos > self.context_lines * 2 {
self.flush();
}
if self.before_hunk_len == 0 && self.after_hunk_len == 0 {
- self.pos = before.start.saturating_sub(3);
+ self.pos = before.start.saturating_sub(self.context_lines);
self.before_hunk_start = self.pos;
- self.after_hunk_start = after.start.saturating_sub(3);
+ self.after_hunk_start = after.start.saturating_sub(self.context_lines);
}
self.update_pos(before.start, before.end);
self.before_hunk_len += before.end - before.start;
@@ -467,4 +487,29 @@ mod tests {
format!("@@ -100,3 +105,3 @@\n{}", expected_diff_body)
);
}
+
+ #[test]
+ fn test_unified_diff_with_context() {
+ // Test that full context includes all lines from the start
+ let old_text = "line1\nline2\nline3\nline4\nline5\nCHANGE_ME\nline7\nline8\n";
+ let new_text = "line1\nline2\nline3\nline4\nline5\nCHANGED\nline7\nline8\n";
+
+ // With default 3 lines of context, the diff starts at line 3
+ let diff_default = unified_diff_with_offsets(old_text, new_text, 0, 0);
+ assert_eq!(
+ diff_default,
+ "@@ -3,6 +3,6 @@\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n"
+ );
+
+ // With full context (8 lines), the diff starts at line 1
+ let diff_full_context = unified_diff_with_context(old_text, new_text, 0, 0, 8);
+ assert_eq!(
+ diff_full_context,
+ "@@ -1,8 +1,8 @@\n line1\n line2\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n"
+ );
+
+ // With 0 context, only the changed line is shown
+ let diff_no_context = unified_diff_with_context(old_text, new_text, 0, 0, 0);
+ assert_eq!(diff_no_context, "@@ -6,1 +6,1 @@\n-CHANGE_ME\n+CHANGED\n");
+ }
}
@@ -100,6 +100,7 @@ fn completion_from_diff(
EditPrediction::Local {
id: None,
edits,
+ cursor_position: None,
edit_preview: None,
}
}
@@ -24,6 +24,7 @@ extend-exclude = [
# Editor and file finder rely on partial typing and custom in-string syntax.
"crates/file_finder/src/file_finder_tests.rs",
"crates/editor/src/editor_tests.rs",
+ "crates/editor/src/edit_prediction_tests.rs",
# There are some names in the test data that are incorrectly flagged as typos.
"crates/git/test_data/blame_incremental_complex",
"crates/git/test_data/golden/blame_incremental_complex.json",