diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index fe57464c7f9aa33334fcb7b719ad65a297761db6..1f692eff2c062cf703e72117c6fd39c7a4e1efbb 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -53,7 +53,6 @@ use std::sync::Arc; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use thiserror::Error; use util::{RangeExt as _, ResultExt as _}; -use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; pub mod cursor_excerpt; pub mod example_spec; @@ -2470,49 +2469,6 @@ impl EditPredictionStore { .await } - fn handle_api_response( - this: &WeakEntity, - response: Result<(T, Option)>, - cx: &mut gpui::AsyncApp, - ) -> Result { - match response { - Ok((data, usage)) => { - if let Some(usage) = usage { - this.update(cx, |this, cx| { - this.user_store.update(cx, |user_store, cx| { - user_store.update_edit_prediction_usage(usage, cx); - }); - }) - .ok(); - } - Ok(data) - } - Err(err) => { - if err.is::() { - cx.update(|cx| { - this.update(cx, |this, _cx| { - this.update_required = true; - }) - .ok(); - - let error_message: SharedString = err.to_string().into(); - show_app_notification( - NotificationId::unique::(), - cx, - move |cx| { - cx.new(|cx| { - ErrorMessagePrompt::new(error_message.clone(), cx) - .with_link_button("Update Zed", "https://zed.dev/releases") - }) - }, - ); - }); - } - Err(err) - } - } - } - async fn send_api_request( build: impl Fn(http_client::http::request::Builder) -> Result>, client: Arc, diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index 355e10a743f6b778e67989a0b65b93318bfd007c..f16239dff0ca28781f36abfcdaab9fcc3873651d 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -1,24 +1,30 @@ -use crate::cursor_excerpt::compute_excerpt_ranges; -use crate::prediction::EditPredictionResult; use crate::{ CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent, + ZedUpdateRequiredError, cursor_excerpt::compute_excerpt_ranges, + prediction::EditPredictionResult, }; use anyhow::Result; -use cloud_llm_client::predict_edits_v3::RawCompletionRequest; -use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason}; +use cloud_llm_client::{ + AcceptEditPredictionBody, EditPredictionRejectReason, predict_edits_v3::RawCompletionRequest, +}; use edit_prediction_types::PredictedCursorPosition; -use gpui::{App, AppContext as _, Task, prelude::*}; -use language::language_settings::all_language_settings; -use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff}; +use gpui::{App, AppContext as _, Entity, Task, WeakEntity, prelude::*}; +use language::{ + Buffer, BufferSnapshot, ToOffset as _, ToPoint, language_settings::all_language_settings, + text_diff, +}; use release_channel::AppVersion; use settings::EditPredictionPromptFormat; use text::{Anchor, Bias}; +use ui::SharedString; +use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; +use zeta_prompt::ZetaPromptInput; use std::{env, ops::Range, path::Path, sync::Arc, time::Instant}; use zeta_prompt::{ - CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt, get_prefill, - output_with_context_for_format, prompt_input_contains_special_tokens, + CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output, + prompt_input_contains_special_tokens, zeta1::{self, EDITABLE_REGION_END_MARKER}, }; @@ -86,6 +92,17 @@ pub fn request_prediction_with_zeta( .map(|organization| organization.id.clone()); let app_version = AppVersion::global(cx); + struct Prediction { + prompt_input: ZetaPromptInput, + buffer: Entity, + snapshot: BufferSnapshot, + edits: Vec<(Range, Arc)>, + cursor_position: Option, + received_response_at: Instant, + editable_range_in_buffer: Range, + model_version: Option, + } + let request_task = cx.background_spawn({ async move { let zeta_version = raw_config @@ -94,7 +111,6 @@ pub fn request_prediction_with_zeta( .unwrap_or(ZetaFormat::default()); let cursor_offset = position.to_offset(&snapshot); - let editable_range_in_excerpt: Range; let (full_context_offset_range, prompt_input) = zeta2_prompt_input( &snapshot, related_files, @@ -108,7 +124,7 @@ pub fn request_prediction_with_zeta( ); if prompt_input_contains_special_tokens(&prompt_input, zeta_version) { - return Ok((None, None)); + return Err(anyhow::anyhow!("prompt contains special tokens")); } if let Some(debug_tx) = &debug_tx { @@ -126,19 +142,19 @@ pub fn request_prediction_with_zeta( log::trace!("Sending edit prediction request"); - let (request_id, output_text, model_version, usage) = + let (request_id, output, model_version, usage) = if let Some(custom_settings) = &custom_server_settings { let max_tokens = custom_settings.max_output_tokens * 4; match custom_settings.prompt_format { EditPredictionPromptFormat::Zeta => { let ranges = &prompt_input.excerpt_ranges; + let editable_range_in_excerpt = ranges.editable_350.clone(); let prompt = zeta1::format_zeta1_from_input( &prompt_input, - ranges.editable_350.clone(), + editable_range_in_excerpt.clone(), ranges.editable_350_context_150.clone(), ); - editable_range_in_excerpt = ranges.editable_350.clone(); let stop_tokens = vec![ EDITABLE_REGION_END_MARKER.to_string(), format!("{EDITABLE_REGION_END_MARKER}\n"), @@ -160,19 +176,18 @@ pub fn request_prediction_with_zeta( let request_id = EditPredictionId(request_id.into()); let output_text = zeta1::clean_zeta1_model_output(&response_text); - (request_id, output_text, None, None) + ( + request_id, + Some(editable_range_in_excerpt).zip(output_text), + None, + None, + ) } EditPredictionPromptFormat::Zeta2 => { let prompt = format_zeta_prompt(&prompt_input, zeta_version); let prefill = get_prefill(&prompt_input, zeta_version); let prompt = format!("{prompt}{prefill}"); - editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format( - zeta_version, - &prompt_input.excerpt_ranges, - ) - .0; - let (response_text, request_id) = send_custom_server_request( provider, custom_settings, @@ -189,7 +204,11 @@ pub fn request_prediction_with_zeta( None } else { let output = format!("{prefill}{response_text}"); - Some(clean_zeta2_model_output(&output, zeta_version).to_string()) + Some(parse_zeta2_model_output( + &output, + zeta_version, + &prompt_input, + )?) }; (request_id, output_text, None, None) @@ -213,12 +232,6 @@ pub fn request_prediction_with_zeta( environment, }; - editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format( - config.format, - &prompt_input.excerpt_ranges, - ) - .1; - let (mut response, usage) = EditPredictionStore::send_raw_llm_request( request, client, @@ -230,13 +243,19 @@ pub fn request_prediction_with_zeta( .await?; let request_id = EditPredictionId(response.id.clone().into()); - let output_text = response.choices.pop().map(|choice| { + let output = if let Some(choice) = response.choices.pop() { let response = &choice.text; let output = format!("{prefill}{response}"); - clean_zeta2_model_output(&output, config.format).to_string() - }); + Some(parse_zeta2_model_output( + &output, + config.format, + &prompt_input, + )?) + } else { + None + }; - (request_id, output_text, None, usage) + (request_id, output, None, usage) } else { // Use V3 endpoint - server handles model/version selection and suffix stripping let (response, usage) = EditPredictionStore::send_v3_request( @@ -250,23 +269,23 @@ pub fn request_prediction_with_zeta( .await?; let request_id = EditPredictionId(response.request_id.into()); - let output_text = if response.output.is_empty() { - None - } else { - Some(response.output) - }; - editable_range_in_excerpt = response.editable_range; + let output_text = Some(response.output).filter(|s| !s.is_empty()); let model_version = response.model_version; - (request_id, output_text, model_version, usage) + ( + request_id, + Some(response.editable_range).zip(output_text), + model_version, + usage, + ) }; let received_response_at = Instant::now(); log::trace!("Got edit prediction response"); - let Some(mut output_text) = output_text else { - return Ok((Some((request_id, None, model_version)), usage)); + let Some((editable_range_in_excerpt, mut output_text)) = output else { + return Ok(((request_id, None), None)); }; let editable_range_in_buffer = editable_range_in_excerpt.start @@ -277,17 +296,6 @@ pub fn request_prediction_with_zeta( .text_for_range(editable_range_in_buffer.clone()) .collect::(); - // For the hashline format, the model may return <|set|>/<|insert|> - // edit commands instead of a full replacement. Apply them against - // the original editable region to produce the full replacement text. - // This must happen before cursor marker stripping because the cursor - // marker is embedded inside edit command content. - if let Some(rewritten_output) = - output_with_context_for_format(zeta_version, &old_text, &output_text)? - { - output_text = rewritten_output; - } - // Client-side cursor marker processing (applies to both raw and v3 responses) let cursor_offset_in_output = output_text.find(CURSOR_MARKER); if let Some(offset) = cursor_offset_in_output { @@ -323,40 +331,37 @@ pub fn request_prediction_with_zeta( ); anyhow::Ok(( - Some(( + ( request_id, - Some(( + Some(Prediction { prompt_input, buffer, - snapshot.clone(), + snapshot: snapshot.clone(), edits, cursor_position, received_response_at, editable_range_in_buffer, - )), - model_version, - )), + model_version, + }), + ), usage, )) } }); cx.spawn(async move |this, cx| { - let Some((id, prediction, model_version)) = - EditPredictionStore::handle_api_response(&this, request_task.await, cx)? - else { - return Ok(None); - }; + let (id, prediction) = handle_api_response(&this, request_task.await, cx)?; - let Some(( - inputs, - edited_buffer, - edited_buffer_snapshot, + let Some(Prediction { + prompt_input: inputs, + buffer: edited_buffer, + snapshot: edited_buffer_snapshot, edits, cursor_position, received_response_at, editable_range_in_buffer, - )) = prediction + model_version, + }) = prediction else { return Ok(Some(EditPredictionResult { id, @@ -423,6 +428,49 @@ pub fn request_prediction_with_zeta( }) } +fn handle_api_response( + this: &WeakEntity, + response: Result<(T, Option)>, + cx: &mut gpui::AsyncApp, +) -> Result { + match response { + Ok((data, usage)) => { + if let Some(usage) = usage { + this.update(cx, |this, cx| { + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); + }); + }) + .ok(); + } + Ok(data) + } + Err(err) => { + if err.is::() { + cx.update(|cx| { + this.update(cx, |this, _cx| { + this.update_required = true; + }) + .ok(); + + let error_message: SharedString = err.to_string().into(); + show_app_notification( + NotificationId::unique::(), + cx, + move |cx| { + cx.new(|cx| { + ErrorMessagePrompt::new(error_message.clone(), cx) + .with_link_button("Update Zed", "https://zed.dev/releases") + }) + }, + ); + }); + } + Err(err) + } + } +} + pub fn zeta2_prompt_input( snapshot: &language::BufferSnapshot, related_files: Vec, diff --git a/crates/edit_prediction_cli/src/parse_output.rs b/crates/edit_prediction_cli/src/parse_output.rs index 2c066b8b32b3eaab54ad6e3b3bcb0796ff27f950..041c57c36e958df45dd000f48c33e00b05c751f3 100644 --- a/crates/edit_prediction_cli/src/parse_output.rs +++ b/crates/edit_prediction_cli/src/parse_output.rs @@ -6,11 +6,7 @@ use crate::{ }; use anyhow::{Context as _, Result}; use edit_prediction::example_spec::encode_cursor_in_patch; -use zeta_prompt::{ - CURSOR_MARKER, ZetaFormat, clean_extracted_region_for_format, - current_region_markers_for_format, output_end_marker_for_format, - output_with_context_for_format, -}; +use zeta_prompt::{CURSOR_MARKER, ZetaFormat, output_end_marker_for_format, resolve_cursor_region}; pub fn run_parse_output(example: &mut Example) -> Result<()> { example @@ -54,43 +50,20 @@ pub fn parse_prediction_output( } } -fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result { - let (current_marker, end_marker) = current_region_markers_for_format(format); - - let start = prompt.find(current_marker).with_context(|| { - format!( - "missing current marker '{}' in prompt", - current_marker.trim() - ) - })? + current_marker.len(); - - let end = prompt[start..] - .find(end_marker) - .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))? - + start; - - let region = &prompt[start..end]; - let region = region.replace(CURSOR_MARKER, ""); - Ok(clean_extracted_region_for_format(format, ®ion)) -} - fn parse_zeta2_output( example: &Example, actual_output: &str, format: ZetaFormat, ) -> Result<(String, Option)> { - let prompt = &example.prompt.as_ref().context("prompt required")?.input; let prompt_inputs = example .prompt_inputs .as_ref() .context("prompt_inputs required")?; - let old_text = extract_zeta2_current_region(prompt, format)?; + let (context, editable_range, _, _) = resolve_cursor_region(prompt_inputs, format); + let old_text = context[editable_range].to_string(); let mut new_text = actual_output.to_string(); - if let Some(transformed) = output_with_context_for_format(format, &old_text, &new_text)? { - new_text = transformed; - } let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) { new_text.replace_range(offset..offset + CURSOR_MARKER.len(), ""); Some(offset) @@ -157,95 +130,3 @@ fn parse_zeta2_output( Ok((formatted_diff, actual_cursor)) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_extract_zeta2_current_region_v0113() { - let prompt = indoc::indoc! {" - <|file_sep|>src/main.rs - <|fim_prefix|> - fn main() { - <|fim_middle|>current - println!(\"hello\"); - <|fim_suffix|> - } - <|fim_middle|>updated - "}; - - let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap(); - assert_eq!(region, "println!(\"hello\");\n"); - } - - #[test] - fn test_extract_zeta2_current_region_v0112() { - let prompt = indoc::indoc! {" - <|file_sep|>src/main.rs - <|fim_prefix|> - fn main() { - <|fim_suffix|> - } - <|fim_middle|>current - println!(\"hello\"); - <|fim_middle|>updated - "}; - - let region = extract_zeta2_current_region(prompt, ZetaFormat::V0112MiddleAtEnd).unwrap(); - assert_eq!(region, "println!(\"hello\");\n"); - } - - #[test] - fn test_extract_zeta2_current_region_with_cursor_marker() { - let prompt = indoc::indoc! {" - <|file_sep|>src/main.rs - <|fim_prefix|> - fn main() { - <|fim_middle|>current - print<|user_cursor|>ln!(\"hello\"); - <|fim_suffix|> - } - <|fim_middle|>updated - "}; - - let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap(); - assert_eq!(region, "println!(\"hello\");\n"); - } - - #[test] - fn test_extract_zeta2_current_region_v0120_git_merge_markers() { - let prompt = indoc::indoc! {" - <|file_sep|>src/main.rs - <|fim_prefix|> - fn main() { - <|fim_suffix|> - } - <|fim_middle|><<<<<<< CURRENT - println!(\"hello\"); - ======= - "}; - - let region = - extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap(); - assert_eq!(region, "println!(\"hello\");\n"); - } - - #[test] - fn test_extract_zeta2_current_region_v0120_with_cursor_marker() { - let prompt = indoc::indoc! {" - <|file_sep|>src/main.rs - <|fim_prefix|> - fn main() { - <|fim_suffix|> - } - <|fim_middle|><<<<<<< CURRENT - print<|user_cursor|>ln!(\"hello\"); - ======= - "}; - - let region = - extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap(); - assert_eq!(region, "println!(\"hello\");\n"); - } -} diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index d6313cc1f4d8dc5c9675c17b007e69d3c546ee92..52cda41ac07c52711bd381b8bebe9d8a172d0d09 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{Result, anyhow}; use serde::{Deserialize, Serialize}; use std::fmt::Write; use std::ops::Range; @@ -89,6 +89,7 @@ pub enum ZetaFormat { V0211Prefill, V0211SeedCoder, v0226Hashline, + V0304VariableEdit, V0304SeedNoEdits, } @@ -216,6 +217,7 @@ pub fn special_tokens_for_format(format: ZetaFormat) -> &'static [&'static str] ZetaFormat::V0211Prefill => v0211_prefill::special_tokens(), ZetaFormat::V0211SeedCoder => seed_coder::special_tokens(), ZetaFormat::v0226Hashline => hashline::special_tokens(), + ZetaFormat::V0304VariableEdit => v0304_variable_edit::special_tokens(), ZetaFormat::V0304SeedNoEdits => seed_coder::special_tokens(), } } @@ -242,6 +244,13 @@ pub fn excerpt_ranges_for_format( ranges.editable_350.clone(), ranges.editable_350_context_150.clone(), ), + ZetaFormat::V0304VariableEdit => { + let context = ranges + .context_8192 + .clone() + .unwrap_or_else(|| ranges.editable_350_context_150.clone()); + (context.clone(), context) + } } } @@ -302,6 +311,9 @@ pub fn write_cursor_excerpt_section_for_format( editable_range, cursor_offset, ), + ZetaFormat::V0304VariableEdit => { + v0304_variable_edit::write_cursor_excerpt_section(prompt, path, context, cursor_offset) + } } } @@ -418,7 +430,8 @@ pub fn get_prefill_for_format( | ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211SeedCoder | ZetaFormat::v0226Hashline - | ZetaFormat::V0304SeedNoEdits => String::new(), + | ZetaFormat::V0304VariableEdit => String::new(), + ZetaFormat::V0304SeedNoEdits => String::new(), } } @@ -431,32 +444,8 @@ pub fn output_end_marker_for_format(format: ZetaFormat) -> Option<&'static str> ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion - | ZetaFormat::v0226Hashline => None, - } -} - -pub fn current_region_markers_for_format(format: ZetaFormat) -> (&'static str, &'static str) { - match format { - ZetaFormat::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"), - ZetaFormat::V0113Ordered - | ZetaFormat::V0114180EditableRegion - | ZetaFormat::v0226Hashline => ("<|fim_middle|>current\n", "<|fim_suffix|>"), - ZetaFormat::V0120GitMergeMarkers - | ZetaFormat::V0131GitMergeMarkersPrefix - | ZetaFormat::V0211Prefill => ( - v0120_git_merge_markers::START_MARKER, - v0120_git_merge_markers::SEPARATOR, - ), - ZetaFormat::V0211SeedCoder | ZetaFormat::V0304SeedNoEdits => { - (seed_coder::START_MARKER, seed_coder::SEPARATOR) - } - } -} - -pub fn clean_extracted_region_for_format(format: ZetaFormat, region: &str) -> String { - match format { - ZetaFormat::v0226Hashline => hashline::strip_hashline_prefixes(region), - _ => region.to_string(), + | ZetaFormat::v0226Hashline + | ZetaFormat::V0304VariableEdit => None, } } @@ -470,43 +459,52 @@ pub fn encode_patch_as_output_for_format( ZetaFormat::v0226Hashline => { hashline::patch_to_edit_commands(old_editable_region, patch, cursor_offset).map(Some) } + ZetaFormat::V0304VariableEdit => v0304_variable_edit::patch_to_variable_edit_output( + old_editable_region, + patch, + cursor_offset, + ) + .map(Some), ZetaFormat::V0304SeedNoEdits => Ok(seed_coder::no_edits(patch)), _ => Ok(None), } } -pub fn output_with_context_for_format( - format: ZetaFormat, - old_editable_region: &str, +/// Parse model output for the given zeta format +pub fn parse_zeta2_model_output( output: &str, -) -> Result> { + format: ZetaFormat, + prompt_inputs: &ZetaPromptInput, +) -> Result<(Range, String)> { + let output = match output_end_marker_for_format(format) { + Some(marker) => output.strip_suffix(marker).unwrap_or(output), + None => output, + }; + + let (context, editable_range, _, _) = resolve_cursor_region(prompt_inputs, format); + let old_editable_region = &context[editable_range.clone()]; + match format { - ZetaFormat::v0226Hashline => { + ZetaFormat::v0226Hashline => Ok(( + editable_range, if hashline::output_has_edit_commands(output) { - Ok(Some(hashline::apply_edit_commands( - old_editable_region, - output, - ))) + hashline::apply_edit_commands(old_editable_region, output) } else { - Ok(None) - } + output.to_string() + }, + )), + ZetaFormat::V0304VariableEdit => { + v0304_variable_edit::apply_variable_edit(old_editable_region, output) } - ZetaFormat::V0304SeedNoEdits => { + ZetaFormat::V0304SeedNoEdits => Ok(( + editable_range, if output.starts_with(seed_coder::NO_EDITS) { - Ok(Some(old_editable_region.to_owned())) + old_editable_region.to_string() } else { - Ok(None) - } - } - _ => Ok(None), - } -} - -/// Post-processes model output for the given zeta format by stripping format-specific suffixes. -pub fn clean_zeta2_model_output(output: &str, format: ZetaFormat) -> &str { - match output_end_marker_for_format(format) { - Some(marker) => output.strip_suffix(marker).unwrap_or(output), - None => output, + output.to_string() + }, + )), + _ => Ok((editable_range, output.to_string())), } } @@ -2565,6 +2563,1009 @@ pub mod seed_coder { } } +pub mod v0304_variable_edit { + //! A prompt format with no fixed editable region. The entire context is shown + //! to the model, and it chooses which text to replace by outputting surrounding + //! context lines with `<|fim_middle|>` and `<|fim_suffix|>` delimiting the new + //! text. + //! + //! Example prompt: + //! + //! <|file_sep|>path/to/file.py + //! zero + //! one + //! two + //! three<|user_cursor|> + //! four + //! five + //! <|fim_prefix|> + // + //! Expected output (model generates): + //! + //! two + //! <|fim_middle|> + //! THREE + //! <|fim_suffix|> + //! four + //! + //! The output means: find "two\n...\nfour" in the context, and replace + //! everything between "two\n" and "four" with "THREE\n". + + use super::*; + + pub fn special_tokens() -> &'static [&'static str] { + &[ + "<|fim_prefix|>", + "<|fim_suffix|>", + "<|fim_middle|>", + "<|file_sep|>", + CURSOR_MARKER, + ] + } + + pub fn write_cursor_excerpt_section( + prompt: &mut String, + path: &Path, + context: &str, + cursor_offset: usize, + ) { + let path_str = path.to_string_lossy(); + write!(prompt, "<|file_sep|>{}\n", path_str).ok(); + + prompt.push_str(&context[..cursor_offset]); + prompt.push_str(CURSOR_MARKER); + prompt.push_str(&context[cursor_offset..]); + if !prompt.ends_with('\n') { + prompt.push('\n'); + } + prompt.push_str("<|fim_prefix|>\n") + } + + /// Apply a variable-edit model output to the original context text. + /// + /// The model output has the form: + /// + /// - prefix context lines + /// - `<|fim_middle|>` + /// - new text + /// - `<|fim_suffix|>` + /// - suffix context lines + /// + /// We locate the prefix/suffix context lines in the original text and replace + /// everything between them with the new text. + pub fn apply_variable_edit( + context: &str, + model_output: &str, + ) -> Result<(Range, String)> { + let (prefix_context, rest) = model_output + .split_once("<|fim_middle|>\n") + .or_else(|| model_output.split_once("<|fim_middle|>")) + .ok_or_else(|| anyhow::anyhow!("missing <|fim_middle|> in model output"))?; + + let (new_text, suffix_context) = rest + .split_once("<|fim_suffix|>\n") + .or_else(|| rest.split_once("<|fim_suffix|>")) + .unwrap_or((rest, "")); + + let suffix_context = if prefix_context.is_empty() && !suffix_context.is_empty() { + suffix_context.strip_prefix('\n').unwrap_or(suffix_context) + } else { + suffix_context + }; + + let prefix_offset = find_substring_at_line_boundary(context, prefix_context) + .ok_or_else(|| anyhow!("could not locate prefix lines"))? + + prefix_context.len(); + let suffix_offset = if suffix_context.is_empty() { + context.len() + } else { + find_substring_at_line_boundary(&context[prefix_offset..], suffix_context) + .ok_or_else(|| anyhow!("could not locate suffix lines"))? + + prefix_offset + }; + + let edit_range = prefix_offset..suffix_offset; + return Ok((edit_range, new_text.to_string())); + } + + fn find_substring_at_line_boundary(haystack: &str, needle: &str) -> Option { + if needle.is_empty() { + return Some(0); + } + + haystack.match_indices(needle).find_map(|(offset, _)| { + let matched_line_start = offset == 0 || haystack[..offset].ends_with('\n'); + matched_line_start.then_some(offset) + }) + } + + /// Convert a unified diff patch into the variable-edit output format. + /// + /// Parses `patch` as a unified diff against `old_text` and produces model + /// output with context lines surrounding `<|fim_middle|>` / `<|fim_suffix|>` + /// delimiters. The diff is resolved by content matching rather than line + /// numbers. + pub fn patch_to_variable_edit_output( + old_text: &str, + patch: &str, + cursor_offset: Option, + ) -> Result { + // Parse the unified diff into hunks. Each hunk has an `old_context` + // string (context + deleted lines interleaved in order) and a list of + // edits expressed as byte ranges within that context plus replacement + // text. + let hunks = parse_hunks(patch); + if hunks.is_empty() { + return Ok(String::new()); + } + + // Apply each hunk by finding its old_context in the text and + // performing the edits. We search forward from where the previous + // hunk ended so that hunks are applied in order. + let mut new_text = old_text.to_string(); + let mut search_from: usize = 0; + let mut first_hunk_pos: Option = None; + + for hunk in &hunks { + let context_pos = new_text[search_from..] + .find(&hunk.old_context) + .map(|pos| pos + search_from) + .ok_or_else(|| anyhow::anyhow!("could not locate hunk context in text"))?; + + if first_hunk_pos.is_none() { + first_hunk_pos = Some(context_pos); + } + + // Apply edits in reverse order so byte offsets remain valid. + for edit in hunk.edits.iter().rev() { + let abs_start = context_pos + edit.range.start; + let abs_end = context_pos + edit.range.end; + new_text.replace_range(abs_start..abs_end, &edit.text); + } + + // Advance past this hunk's region in the (now modified) text. + let new_region_len: usize = + hunk.edits.iter().fold(hunk.old_context.len(), |len, edit| { + len + edit.text.len() - (edit.range.end - edit.range.start) + }); + search_from = context_pos + new_region_len; + } + + // Now we have old_text and new_text. Find the changed line range by + // comparing them. + let old_lines: Vec<&str> = old_text.lines().collect(); + let new_lines: Vec<&str> = new_text.lines().collect(); + + // Find first differing line. + let first_changed_row = old_lines + .iter() + .zip(new_lines.iter()) + .position(|(a, b)| a != b) + .unwrap_or_else(|| old_lines.len().min(new_lines.len())); + + // Find last differing line (from the end). + let max_suffix = old_lines.len().min(new_lines.len()) - first_changed_row; + let common_suffix = old_lines + .iter() + .rev() + .zip(new_lines.iter().rev()) + .take(max_suffix) + .take_while(|(a, b)| a == b) + .count(); + + let old_end = old_lines.len() - common_suffix; + let new_end = new_lines.len() - common_suffix; + + if first_changed_row == old_end && first_changed_row == new_end { + return Ok(String::new()); + } + + // Build the replacement text from new_lines[first_diff..new_end]. + let mut merged_new_text = String::new(); + for line in &new_lines[first_changed_row..new_end] { + merged_new_text.push_str(line); + merged_new_text.push('\n'); + } + + // cursor_offset is relative to the first hunk's new content in + // new_text. Translate it to an offset within merged_new_text, which + // only contains lines first_diff..new_end of new_text. + if let Some(hunk_offset) = cursor_offset { + let hunk_start = first_hunk_pos.unwrap_or(0); + let absolute_pos = hunk_start + hunk_offset; + + // Byte offset where first_diff starts in new_text. + let merged_start: usize = new_lines[..first_changed_row] + .iter() + .map(|line| line.len() + 1) + .sum(); + + if absolute_pos >= merged_start { + let relative_offset = absolute_pos - merged_start; + if relative_offset <= merged_new_text.len() { + merged_new_text.insert_str(relative_offset, CURSOR_MARKER); + } + } + } + + // Build output with 2 lines of context above and below. + let context_lines_count = 2; + let mut prefix_start = first_changed_row.saturating_sub(context_lines_count); + let mut suffix_end = (old_end + context_lines_count).min(old_lines.len()); + + fn count_matches(line_range: Range, lines: &[&str]) -> usize { + let pattern = &lines[line_range]; + let pattern_len = pattern.len(); + + let mut count = 0; + for offset in 0..=lines.len() - pattern_len { + if &lines[offset..offset + pattern_len] == pattern { + count += 1; + } + } + count + } + + // Expand prefix and suffix until they are unique + while prefix_start > 0 { + if count_matches(prefix_start..first_changed_row, &old_lines) > 1 { + prefix_start -= 1; + } else { + break; + } + } + while suffix_end < old_lines.len() { + if count_matches(old_end..suffix_end, &old_lines) > 1 { + suffix_end += 1; + } else { + break; + } + } + + let mut output = String::new(); + for line in &old_lines[prefix_start..first_changed_row] { + output.push_str(line); + output.push('\n'); + } + output.push_str("<|fim_middle|>\n"); + output.push_str(&merged_new_text); + output.push_str("<|fim_suffix|>\n"); + for line in &old_lines[old_end..suffix_end] { + output.push_str(line); + output.push('\n'); + } + + Ok(output) + } + + struct ParsedHunk { + old_context: String, + edits: Vec, + } + + struct ParsedEdit { + range: Range, + text: String, + } + + /// Parse a unified diff into content-based hunks. Each hunk contains an + /// `old_context` string (context lines + deleted lines, which together + /// form the text that should be found in the original) and a list of edits + /// expressed as byte ranges within that context. + fn parse_hunks(patch: &str) -> Vec { + let mut hunks = Vec::new(); + let mut current: Option = None; + + for line in patch.lines() { + if line.starts_with("@@") { + if let Some(hunk) = current.take() { + if !hunk.old_context.is_empty() || !hunk.edits.is_empty() { + hunks.push(hunk); + } + } + current = Some(ParsedHunk { + old_context: String::new(), + edits: Vec::new(), + }); + } else if line.starts_with("---") || line.starts_with("+++") { + continue; + } else if let Some(hunk) = &mut current { + if let Some(added) = line.strip_prefix('+') { + let pos = hunk.old_context.len(); + if let Some(last_edit) = hunk.edits.last_mut() { + if last_edit.range.end == pos { + writeln!(&mut last_edit.text, "{added}").ok(); + continue; + } + } + hunk.edits.push(ParsedEdit { + range: pos..pos, + text: format!("{added}\n"), + }); + } else if let Some(removed) = line.strip_prefix('-') { + let start = hunk.old_context.len(); + writeln!(&mut hunk.old_context, "{removed}").ok(); + let end = hunk.old_context.len(); + if let Some(last_edit) = hunk.edits.last_mut() { + if last_edit.range.end == start { + last_edit.range.end = end; + continue; + } + } + hunk.edits.push(ParsedEdit { + range: start..end, + text: String::new(), + }); + } else { + let ctx = line.strip_prefix(' ').unwrap_or(line); + writeln!(&mut hunk.old_context, "{ctx}").ok(); + } + } + } + + if let Some(hunk) = current { + if !hunk.old_context.is_empty() || !hunk.edits.is_empty() { + hunks.push(hunk); + } + } + + hunks + } + + #[cfg(test)] + mod tests { + use super::*; + use indoc::indoc; + + #[test] + fn test_apply_variable_edit() { + struct Case { + name: &'static str, + original: &'static str, + model_output: &'static str, + expected: &'static str, + } + + let cases = [ + Case { + name: "simple_single_line_replacement", + original: indoc! {" + zero + one + two + three + four + five + "}, + model_output: indoc! {" + two + <|fim_middle|> + THREE + <|fim_suffix|> + four + "}, + expected: indoc! {" + zero + one + two + THREE + four + five + "}, + }, + Case { + name: "multi_line_replacement", + original: indoc! {" + a + b + c + d + e + "}, + model_output: indoc! {" + a + <|fim_middle|> + B + C + D + <|fim_suffix|> + e + "}, + expected: indoc! {" + a + B + C + D + e + "}, + }, + Case { + name: "insertion_between_existing_lines", + original: indoc! {" + a + b + c + "}, + model_output: indoc! {" + a + <|fim_middle|> + X + <|fim_suffix|> + b + "}, + expected: indoc! {" + a + X + b + c + "}, + }, + Case { + name: "deletion", + original: indoc! {" + a + b + c + d + "}, + model_output: indoc! {" + a + <|fim_middle|> + <|fim_suffix|> + c + "}, + expected: indoc! {" + a + c + d + "}, + }, + Case { + name: "replacement_at_start_no_prefix_context", + original: indoc! {" + a + b + c + "}, + model_output: indoc! {" + <|fim_middle|> + X + <|fim_suffix|> + b + "}, + expected: indoc! {" + X + b + c + "}, + }, + Case { + name: "replacement_at_end_no_suffix_context", + original: indoc! {" + a + b + c + "}, + model_output: indoc! {" + b + <|fim_middle|> + Z + <|fim_suffix|> + "}, + expected: indoc! {" + a + b + Z + "}, + }, + Case { + name: "context_with_trailing_newline_is_preserved", + original: indoc! {" + a + b + c + "}, + model_output: indoc! {" + a + <|fim_middle|> + B + <|fim_suffix|> + c + "}, + expected: indoc! {" + a + B + c + "}, + }, + Case { + name: "cursor_marker_passes_through_untouched", + original: indoc! {" + a + b + c + "}, + model_output: indoc! {" + a + <|fim_middle|> + B<|user_cursor|>B + <|fim_suffix|> + c + "}, + expected: indoc! {" + a + B<|user_cursor|>B + c + "}, + }, + Case { + name: "multiple_prefix_context_lines", + original: indoc! {" + a + b + c + d + e + "}, + model_output: indoc! {" + b + c + <|fim_middle|> + D + <|fim_suffix|> + e + "}, + expected: indoc! {" + a + b + c + D + e + "}, + }, + ]; + + for case in cases { + let (edit_range, replacement) = + apply_variable_edit(case.original, case.model_output).unwrap(); + let mut edited = case.original.to_string(); + edited.replace_range(edit_range, &replacement); + assert_eq!(edited, case.expected, "{}", case.name); + } + } + + #[test] + fn test_patch_to_variable_edit() { + struct Case { + name: &'static str, + old: &'static str, + patch: &'static str, + cursor_offset: Option, + expected_variable_edit: &'static str, + expected_after_apply: &'static str, + } + + let cases = [ + Case { + name: "simple_replacement", + old: indoc! {" + zero + one + two + three + four + five + "}, + patch: indoc! {" + @@ -3,3 +3,3 @@ + two + -three + +THREE + four + "}, + cursor_offset: None, + expected_variable_edit: indoc! {" + one + two + <|fim_middle|> + THREE + <|fim_suffix|> + four + five + "}, + expected_after_apply: indoc! {" + zero + one + two + THREE + four + five + "}, + }, + Case { + name: "insertion", + old: indoc! {" + a + b + c + d + e + "}, + patch: indoc! {" + @@ -2,0 +3,1 @@ + b + +X + c + "}, + cursor_offset: None, + expected_variable_edit: indoc! {" + a + b + <|fim_middle|> + X + <|fim_suffix|> + c + d + "}, + expected_after_apply: indoc! {" + a + b + X + c + d + e + "}, + }, + Case { + name: "deletion", + old: indoc! {" + a + b + c + d + e + "}, + patch: indoc! {" + @@ -2,3 +2,2 @@ + b + -c + d + "}, + cursor_offset: None, + expected_variable_edit: indoc! {" + a + b + <|fim_middle|> + <|fim_suffix|> + d + e + "}, + expected_after_apply: indoc! {" + a + b + d + e + "}, + }, + Case { + name: "edit_near_start", + old: indoc! {" + first + second + third + fourth + "}, + patch: indoc! {" + @@ -1,1 +1,1 @@ + -first + +FIRST + "}, + cursor_offset: None, + expected_variable_edit: indoc! {" + <|fim_middle|> + FIRST + <|fim_suffix|> + second + third + "}, + expected_after_apply: indoc! {" + FIRST + second + third + fourth + "}, + }, + Case { + name: "edit_near_end", + old: indoc! {" + first + second + third + fourth + "}, + patch: indoc! {" + @@ -4,1 +4,1 @@ + -fourth + +FOURTH + "}, + cursor_offset: None, + expected_variable_edit: indoc! {" + second + third + <|fim_middle|> + FOURTH + <|fim_suffix|> + "}, + expected_after_apply: indoc! {" + first + second + third + FOURTH + "}, + }, + Case { + name: "cursor_at_start_of_replacement", + old: indoc! {" + zero + one + two + three + four + five + "}, + patch: indoc! {" + @@ -3,3 +3,3 @@ + two + -three + +THREE + four + "}, + cursor_offset: Some(4), + expected_variable_edit: indoc! {" + one + two + <|fim_middle|> + <|user_cursor|>THREE + <|fim_suffix|> + four + five + "}, + expected_after_apply: indoc! {" + zero + one + two + <|user_cursor|>THREE + four + five + "}, + }, + Case { + name: "cursor_in_middle_of_replacement", + old: indoc! {" + zero + one + two + three + four + five + "}, + patch: indoc! {" + @@ -3,3 +3,3 @@ + two + -three + +THREE + four + "}, + cursor_offset: Some(6), + expected_variable_edit: indoc! {" + one + two + <|fim_middle|> + TH<|user_cursor|>REE + <|fim_suffix|> + four + five + "}, + expected_after_apply: indoc! {" + zero + one + two + TH<|user_cursor|>REE + four + five + "}, + }, + Case { + name: "expands_context_when_two_lines_not_unique_before_and_after", + old: indoc! {" + one + a + b + c + d + two + a + b + c + d + three + a + b + c + d + four + "}, + patch: indoc! {" + @@ -4,5 +4,5 @@ + two + a + b + -c + +C + d + three + "}, + cursor_offset: None, + expected_variable_edit: indoc! {" + two + a + b + <|fim_middle|> + C + <|fim_suffix|> + d + three + "}, + expected_after_apply: indoc! {" + one + a + b + c + d + two + a + b + C + d + three + a + b + c + d + four + "}, + }, + Case { + name: "expands_context_when_two_lines_not_unique_before_and_after", + old: indoc! {" + { + { + one(); + } + } + { + { + two(); + } + } + { + { + three(); + } + } + { + { + four(); + } + } + "}, + patch: indoc! {" + @@ -4,5 +4,5 @@ + { + - two(); + + TWO(); + } + "}, + cursor_offset: None, + expected_variable_edit: indoc! {" + one(); + } + } + { + { + <|fim_middle|> + TWO(); + <|fim_suffix|> + } + } + { + { + three(); + "}, + expected_after_apply: indoc! {" + { + { + one(); + } + } + { + { + TWO(); + } + } + { + { + three(); + } + } + { + { + four(); + } + } + "}, + }, + ]; + + for case in cases { + let output = + patch_to_variable_edit_output(case.old, case.patch, case.cursor_offset) + .unwrap_or_else(|error| { + panic!("failed converting patch for {}: {error}", case.name) + }); + assert_eq!( + output, case.expected_variable_edit, + "patch->variable_edit mismatch for {}", + case.name + ); + + let (edit_range, replacement) = apply_variable_edit(case.old, &output) + .unwrap_or_else(|error| { + panic!("failed applying variable_edit for {}: {error}", case.name) + }); + let mut edited_by_variable_edit = case.old.to_string(); + edited_by_variable_edit.replace_range(edit_range, &replacement); + assert_eq!( + edited_by_variable_edit, case.expected_after_apply, + "variable_edit apply mismatch for {}", + case.name + ); + + let (expected_edit_range, expected_replacement) = + apply_variable_edit(case.old, case.expected_variable_edit).unwrap_or_else( + |error| { + panic!( + "failed applying expected variable_edit for {}: {error}", + case.name + ) + }, + ); + let mut edited_by_expected_variable_edit = case.old.to_string(); + edited_by_expected_variable_edit + .replace_range(expected_edit_range, &expected_replacement); + assert_eq!( + edited_by_expected_variable_edit, case.expected_after_apply, + "expected variable_edit apply mismatch for {}", + case.name + ); + } + } + + #[test] + fn test_write_cursor_excerpt_section() { + let path = Path::new("test.rs"); + let context = "fn main() {\n hello();\n}\n"; + let cursor_offset = 17; + let mut prompt = String::new(); + write_cursor_excerpt_section(&mut prompt, path, context, cursor_offset); + assert_eq!( + prompt, + "<|file_sep|>test.rs\nfn main() {\n h<|user_cursor|>ello();\n}\n<|fim_prefix|>\n" + ); + } + } +} + /// The zeta1 prompt format pub mod zeta1 { use super::*; @@ -3356,21 +4357,6 @@ mod tests { ); } - #[test] - fn test_seed_coder_clean_output() { - let output_with_marker = "new code\n>>>>>>> UPDATED\n"; - let output_without_marker = "new code\n"; - - assert_eq!( - clean_zeta2_model_output(output_with_marker, ZetaFormat::V0211SeedCoder), - "new code\n" - ); - assert_eq!( - clean_zeta2_model_output(output_without_marker, ZetaFormat::V0211SeedCoder), - "new code\n" - ); - } - #[test] fn test_format_zeta1_from_input_basic() { let excerpt = "fn before() {}\nfn foo() {\n let x = 1;\n}\nfn after() {}\n";