diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index f16239dff0ca28781f36abfcdaab9fcc3873651d..93fc6aa99a27f18436bc564fbaa39a15d3be0b44 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -19,7 +19,7 @@ use settings::EditPredictionPromptFormat; use text::{Anchor, Bias}; use ui::SharedString; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; -use zeta_prompt::ZetaPromptInput; +use zeta_prompt::{ParsedOutput, ZetaPromptInput}; use std::{env, ops::Range, path::Path, sync::Arc, time::Instant}; use zeta_prompt::{ @@ -175,13 +175,12 @@ pub fn request_prediction_with_zeta( let request_id = EditPredictionId(request_id.into()); let output_text = zeta1::clean_zeta1_model_output(&response_text); + let parsed_output = output_text.map(|text| ParsedOutput { + new_editable_region: text, + range_in_excerpt: editable_range_in_excerpt, + }); - ( - request_id, - Some(editable_range_in_excerpt).zip(output_text), - None, - None, - ) + (request_id, parsed_output, None, None) } EditPredictionPromptFormat::Zeta2 => { let prompt = format_zeta_prompt(&prompt_input, zeta_version); @@ -271,20 +270,23 @@ pub fn request_prediction_with_zeta( let request_id = EditPredictionId(response.request_id.into()); let output_text = Some(response.output).filter(|s| !s.is_empty()); let model_version = response.model_version; + let parsed_output = ParsedOutput { + new_editable_region: output_text.unwrap_or_default(), + range_in_excerpt: response.editable_range, + }; - ( - request_id, - Some(response.editable_range).zip(output_text), - model_version, - usage, - ) + (request_id, Some(parsed_output), model_version, usage) }; let received_response_at = Instant::now(); log::trace!("Got edit prediction response"); - let Some((editable_range_in_excerpt, mut output_text)) = output else { + let Some(ParsedOutput { + new_editable_region: mut output_text, + range_in_excerpt: editable_range_in_excerpt, + }) = output + else { return Ok(((request_id, None), None)); }; diff --git a/crates/edit_prediction_cli/src/parse_output.rs b/crates/edit_prediction_cli/src/parse_output.rs index 041c57c36e958df45dd000f48c33e00b05c751f3..94058efd92ca4a166ba4976819963ef5d3286f5d 100644 --- a/crates/edit_prediction_cli/src/parse_output.rs +++ b/crates/edit_prediction_cli/src/parse_output.rs @@ -6,7 +6,7 @@ use crate::{ }; use anyhow::{Context as _, Result}; use edit_prediction::example_spec::encode_cursor_in_patch; -use zeta_prompt::{CURSOR_MARKER, ZetaFormat, output_end_marker_for_format, resolve_cursor_region}; +use zeta_prompt::{CURSOR_MARKER, ZetaFormat, parse_zeta2_model_output}; pub fn run_parse_output(example: &mut Example) -> Result<()> { example @@ -60,10 +60,13 @@ fn parse_zeta2_output( .as_ref() .context("prompt_inputs required")?; - let (context, editable_range, _, _) = resolve_cursor_region(prompt_inputs, format); - let old_text = context[editable_range].to_string(); + let parsed = parse_zeta2_model_output(actual_output, format, prompt_inputs)?; + let range_in_excerpt = parsed.range_in_excerpt; + + let excerpt = prompt_inputs.cursor_excerpt.as_ref(); + let old_text = excerpt[range_in_excerpt.clone()].to_string(); + let mut new_text = parsed.new_editable_region; - let mut new_text = actual_output.to_string(); let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) { new_text.replace_range(offset..offset + CURSOR_MARKER.len(), ""); Some(offset) @@ -71,14 +74,8 @@ fn parse_zeta2_output( None }; - if let Some(marker) = output_end_marker_for_format(format) { - new_text = new_text - .strip_suffix(marker) - .unwrap_or(&new_text) - .to_string(); - } - - let mut old_text_normalized = old_text.clone(); + // Normalize trailing newlines for diff generation + let mut old_text_normalized = old_text; if !new_text.is_empty() && !new_text.ends_with('\n') { new_text.push('\n'); } @@ -86,22 +83,10 @@ fn parse_zeta2_output( old_text_normalized.push('\n'); } - let old_text_trimmed = old_text.trim_end_matches('\n'); - let excerpt = prompt_inputs.cursor_excerpt.as_ref(); - let (editable_region_offset, _) = excerpt - .match_indices(old_text_trimmed) - .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset_in_excerpt)) - .with_context(|| { - format!( - "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}", - old_text_trimmed, excerpt - ) - })?; - + let editable_region_offset = range_in_excerpt.start; let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count(); - - // Use full context so cursor offset (relative to editable region start) aligns with diff content let editable_region_lines = old_text_normalized.lines().count() as u32; + let diff = language::unified_diff_with_context( &old_text_normalized, &new_text, diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 9469c056468ed91fe9c95aa5e5cd2edf3590b8bd..b7b67ed851419dcf0f125f46e5a17e7f9ac9aa92 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -470,12 +470,19 @@ pub fn encode_patch_as_output_for_format( } } +pub struct ParsedOutput { + /// Text that should replace the editable region + pub new_editable_region: String, + /// The byte range within `cursor_excerpt` that this replacement applies to + pub range_in_excerpt: Range, +} + /// Parse model output for the given zeta format pub fn parse_zeta2_model_output( output: &str, format: ZetaFormat, prompt_inputs: &ZetaPromptInput, -) -> Result<(Range, String)> { +) -> Result { let output = match output_end_marker_for_format(format) { Some(marker) => output.strip_suffix(marker).unwrap_or(output), None => output, @@ -509,7 +516,11 @@ pub fn parse_zeta2_model_output( let range_in_excerpt = range_in_context.start + context_start..range_in_context.end + context_start; - Ok((range_in_excerpt, output)) + + Ok(ParsedOutput { + new_editable_region: output, + range_in_excerpt, + }) } pub fn excerpt_range_for_format( @@ -4612,9 +4623,12 @@ mod tests { assert_eq!(cleaned, ""); } - fn apply_edit(excerpt: &str, range: &Range, new_text: &str) -> String { + fn apply_edit(excerpt: &str, parsed_output: &ParsedOutput) -> String { let mut result = excerpt.to_string(); - result.replace_range(range.clone(), new_text); + result.replace_range( + parsed_output.range_in_excerpt.clone(), + &parsed_output.new_editable_region, + ); result } @@ -4632,7 +4646,7 @@ mod tests { editable_start, ); - let (range, text) = parse_zeta2_model_output( + let output = parse_zeta2_model_output( "editable new\n>>>>>>> UPDATED\n", ZetaFormat::V0131GitMergeMarkersPrefix, &input, @@ -4640,7 +4654,7 @@ mod tests { .unwrap(); assert_eq!( - apply_edit(excerpt, &range, &text), + apply_edit(excerpt, &output), "before ctx\nctx start\neditable new\nctx end\nafter ctx\n" ); } @@ -4658,10 +4672,10 @@ mod tests { ); let format = ZetaFormat::V0131GitMergeMarkersPrefix; - let (range, text) = + let output = parse_zeta2_model_output("bbb\nccc\n>>>>>>> UPDATED\n", format, &input).unwrap(); - assert_eq!(apply_edit(excerpt, &range, &text), excerpt); + assert_eq!(apply_edit(excerpt, &output), excerpt); } #[test] @@ -4670,14 +4684,11 @@ mod tests { let input = make_input_with_context_range(excerpt, 0..excerpt.len(), 0..excerpt.len(), 0); let format = ZetaFormat::V0131GitMergeMarkersPrefix; - let (range1, text1) = + let output1 = parse_zeta2_model_output("new content\n>>>>>>> UPDATED\n", format, &input).unwrap(); - let (range2, text2) = parse_zeta2_model_output("new content\n", format, &input).unwrap(); + let output2 = parse_zeta2_model_output("new content\n", format, &input).unwrap(); - assert_eq!( - apply_edit(excerpt, &range1, &text1), - apply_edit(excerpt, &range2, &text2) - ); - assert_eq!(apply_edit(excerpt, &range1, &text1), "new content\n"); + assert_eq!(apply_edit(excerpt, &output1), apply_edit(excerpt, &output2)); + assert_eq!(apply_edit(excerpt, &output1), "new content\n"); } }