@@ -1,24 +1,30 @@
-use crate::cursor_excerpt::compute_excerpt_ranges;
-use crate::prediction::EditPredictionResult;
use crate::{
CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent,
+ ZedUpdateRequiredError, cursor_excerpt::compute_excerpt_ranges,
+ prediction::EditPredictionResult,
};
use anyhow::Result;
-use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
-use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
+use cloud_llm_client::{
+ AcceptEditPredictionBody, EditPredictionRejectReason, predict_edits_v3::RawCompletionRequest,
+};
use edit_prediction_types::PredictedCursorPosition;
-use gpui::{App, AppContext as _, Task, prelude::*};
-use language::language_settings::all_language_settings;
-use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
+use gpui::{App, AppContext as _, Entity, Task, WeakEntity, prelude::*};
+use language::{
+ Buffer, BufferSnapshot, ToOffset as _, ToPoint, language_settings::all_language_settings,
+ text_diff,
+};
use release_channel::AppVersion;
use settings::EditPredictionPromptFormat;
use text::{Anchor, Bias};
+use ui::SharedString;
+use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
+use zeta_prompt::ZetaPromptInput;
use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
use zeta_prompt::{
- CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt, get_prefill,
- output_with_context_for_format, prompt_input_contains_special_tokens,
+ CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
+ prompt_input_contains_special_tokens,
zeta1::{self, EDITABLE_REGION_END_MARKER},
};
@@ -86,6 +92,17 @@ pub fn request_prediction_with_zeta(
.map(|organization| organization.id.clone());
let app_version = AppVersion::global(cx);
+ struct Prediction {
+ prompt_input: ZetaPromptInput,
+ buffer: Entity<Buffer>,
+ snapshot: BufferSnapshot,
+ edits: Vec<(Range<Anchor>, Arc<str>)>,
+ cursor_position: Option<PredictedCursorPosition>,
+ received_response_at: Instant,
+ editable_range_in_buffer: Range<usize>,
+ model_version: Option<String>,
+ }
+
let request_task = cx.background_spawn({
async move {
let zeta_version = raw_config
@@ -94,7 +111,6 @@ pub fn request_prediction_with_zeta(
.unwrap_or(ZetaFormat::default());
let cursor_offset = position.to_offset(&snapshot);
- let editable_range_in_excerpt: Range<usize>;
let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
&snapshot,
related_files,
@@ -108,7 +124,7 @@ pub fn request_prediction_with_zeta(
);
if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
- return Ok((None, None));
+ return Err(anyhow::anyhow!("prompt contains special tokens"));
}
if let Some(debug_tx) = &debug_tx {
@@ -126,19 +142,19 @@ pub fn request_prediction_with_zeta(
log::trace!("Sending edit prediction request");
- let (request_id, output_text, model_version, usage) =
+ let (request_id, output, model_version, usage) =
if let Some(custom_settings) = &custom_server_settings {
let max_tokens = custom_settings.max_output_tokens * 4;
match custom_settings.prompt_format {
EditPredictionPromptFormat::Zeta => {
let ranges = &prompt_input.excerpt_ranges;
+ let editable_range_in_excerpt = ranges.editable_350.clone();
let prompt = zeta1::format_zeta1_from_input(
&prompt_input,
- ranges.editable_350.clone(),
+ editable_range_in_excerpt.clone(),
ranges.editable_350_context_150.clone(),
);
- editable_range_in_excerpt = ranges.editable_350.clone();
let stop_tokens = vec![
EDITABLE_REGION_END_MARKER.to_string(),
format!("{EDITABLE_REGION_END_MARKER}\n"),
@@ -160,19 +176,18 @@ pub fn request_prediction_with_zeta(
let request_id = EditPredictionId(request_id.into());
let output_text = zeta1::clean_zeta1_model_output(&response_text);
- (request_id, output_text, None, None)
+ (
+ request_id,
+ Some(editable_range_in_excerpt).zip(output_text),
+ None,
+ None,
+ )
}
EditPredictionPromptFormat::Zeta2 => {
let prompt = format_zeta_prompt(&prompt_input, zeta_version);
let prefill = get_prefill(&prompt_input, zeta_version);
let prompt = format!("{prompt}{prefill}");
- editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
- zeta_version,
- &prompt_input.excerpt_ranges,
- )
- .0;
-
let (response_text, request_id) = send_custom_server_request(
provider,
custom_settings,
@@ -189,7 +204,11 @@ pub fn request_prediction_with_zeta(
None
} else {
let output = format!("{prefill}{response_text}");
- Some(clean_zeta2_model_output(&output, zeta_version).to_string())
+ Some(parse_zeta2_model_output(
+ &output,
+ zeta_version,
+ &prompt_input,
+ )?)
};
(request_id, output_text, None, None)
@@ -213,12 +232,6 @@ pub fn request_prediction_with_zeta(
environment,
};
- editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
- config.format,
- &prompt_input.excerpt_ranges,
- )
- .1;
-
let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
request,
client,
@@ -230,13 +243,19 @@ pub fn request_prediction_with_zeta(
.await?;
let request_id = EditPredictionId(response.id.clone().into());
- let output_text = response.choices.pop().map(|choice| {
+ let output = if let Some(choice) = response.choices.pop() {
let response = &choice.text;
let output = format!("{prefill}{response}");
- clean_zeta2_model_output(&output, config.format).to_string()
- });
+ Some(parse_zeta2_model_output(
+ &output,
+ config.format,
+ &prompt_input,
+ )?)
+ } else {
+ None
+ };
- (request_id, output_text, None, usage)
+ (request_id, output, None, usage)
} else {
// Use V3 endpoint - server handles model/version selection and suffix stripping
let (response, usage) = EditPredictionStore::send_v3_request(
@@ -250,23 +269,23 @@ pub fn request_prediction_with_zeta(
.await?;
let request_id = EditPredictionId(response.request_id.into());
- let output_text = if response.output.is_empty() {
- None
- } else {
- Some(response.output)
- };
- editable_range_in_excerpt = response.editable_range;
+ let output_text = Some(response.output).filter(|s| !s.is_empty());
let model_version = response.model_version;
- (request_id, output_text, model_version, usage)
+ (
+ request_id,
+ Some(response.editable_range).zip(output_text),
+ model_version,
+ usage,
+ )
};
let received_response_at = Instant::now();
log::trace!("Got edit prediction response");
- let Some(mut output_text) = output_text else {
- return Ok((Some((request_id, None, model_version)), usage));
+ let Some((editable_range_in_excerpt, mut output_text)) = output else {
+ return Ok(((request_id, None), None));
};
let editable_range_in_buffer = editable_range_in_excerpt.start
@@ -277,17 +296,6 @@ pub fn request_prediction_with_zeta(
.text_for_range(editable_range_in_buffer.clone())
.collect::<String>();
- // For the hashline format, the model may return <|set|>/<|insert|>
- // edit commands instead of a full replacement. Apply them against
- // the original editable region to produce the full replacement text.
- // This must happen before cursor marker stripping because the cursor
- // marker is embedded inside edit command content.
- if let Some(rewritten_output) =
- output_with_context_for_format(zeta_version, &old_text, &output_text)?
- {
- output_text = rewritten_output;
- }
-
// Client-side cursor marker processing (applies to both raw and v3 responses)
let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
if let Some(offset) = cursor_offset_in_output {
@@ -323,40 +331,37 @@ pub fn request_prediction_with_zeta(
);
anyhow::Ok((
- Some((
+ (
request_id,
- Some((
+ Some(Prediction {
prompt_input,
buffer,
- snapshot.clone(),
+ snapshot: snapshot.clone(),
edits,
cursor_position,
received_response_at,
editable_range_in_buffer,
- )),
- model_version,
- )),
+ model_version,
+ }),
+ ),
usage,
))
}
});
cx.spawn(async move |this, cx| {
- let Some((id, prediction, model_version)) =
- EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
- else {
- return Ok(None);
- };
+ let (id, prediction) = handle_api_response(&this, request_task.await, cx)?;
- let Some((
- inputs,
- edited_buffer,
- edited_buffer_snapshot,
+ let Some(Prediction {
+ prompt_input: inputs,
+ buffer: edited_buffer,
+ snapshot: edited_buffer_snapshot,
edits,
cursor_position,
received_response_at,
editable_range_in_buffer,
- )) = prediction
+ model_version,
+ }) = prediction
else {
return Ok(Some(EditPredictionResult {
id,
@@ -423,6 +428,49 @@ pub fn request_prediction_with_zeta(
})
}
+fn handle_api_response<T>(
+ this: &WeakEntity<EditPredictionStore>,
+ response: Result<(T, Option<client::EditPredictionUsage>)>,
+ cx: &mut gpui::AsyncApp,
+) -> Result<T> {
+ match response {
+ Ok((data, usage)) => {
+ if let Some(usage) = usage {
+ this.update(cx, |this, cx| {
+ this.user_store.update(cx, |user_store, cx| {
+ user_store.update_edit_prediction_usage(usage, cx);
+ });
+ })
+ .ok();
+ }
+ Ok(data)
+ }
+ Err(err) => {
+ if err.is::<ZedUpdateRequiredError>() {
+ cx.update(|cx| {
+ this.update(cx, |this, _cx| {
+ this.update_required = true;
+ })
+ .ok();
+
+ let error_message: SharedString = err.to_string().into();
+ show_app_notification(
+ NotificationId::unique::<ZedUpdateRequiredError>(),
+ cx,
+ move |cx| {
+ cx.new(|cx| {
+ ErrorMessagePrompt::new(error_message.clone(), cx)
+ .with_link_button("Update Zed", "https://zed.dev/releases")
+ })
+ },
+ );
+ });
+ }
+ Err(err)
+ }
+ }
+}
+
pub fn zeta2_prompt_input(
snapshot: &language::BufferSnapshot,
related_files: Vec<zeta_prompt::RelatedFile>,
@@ -1,4 +1,4 @@
-use anyhow::Result;
+use anyhow::{Result, anyhow};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::ops::Range;
@@ -89,6 +89,7 @@ pub enum ZetaFormat {
V0211Prefill,
V0211SeedCoder,
v0226Hashline,
+ V0304VariableEdit,
V0304SeedNoEdits,
}
@@ -216,6 +217,7 @@ pub fn special_tokens_for_format(format: ZetaFormat) -> &'static [&'static str]
ZetaFormat::V0211Prefill => v0211_prefill::special_tokens(),
ZetaFormat::V0211SeedCoder => seed_coder::special_tokens(),
ZetaFormat::v0226Hashline => hashline::special_tokens(),
+ ZetaFormat::V0304VariableEdit => v0304_variable_edit::special_tokens(),
ZetaFormat::V0304SeedNoEdits => seed_coder::special_tokens(),
}
}
@@ -242,6 +244,13 @@ pub fn excerpt_ranges_for_format(
ranges.editable_350.clone(),
ranges.editable_350_context_150.clone(),
),
+ ZetaFormat::V0304VariableEdit => {
+ let context = ranges
+ .context_8192
+ .clone()
+ .unwrap_or_else(|| ranges.editable_350_context_150.clone());
+ (context.clone(), context)
+ }
}
}
@@ -302,6 +311,9 @@ pub fn write_cursor_excerpt_section_for_format(
editable_range,
cursor_offset,
),
+ ZetaFormat::V0304VariableEdit => {
+ v0304_variable_edit::write_cursor_excerpt_section(prompt, path, context, cursor_offset)
+ }
}
}
@@ -418,7 +430,8 @@ pub fn get_prefill_for_format(
| ZetaFormat::V0131GitMergeMarkersPrefix
| ZetaFormat::V0211SeedCoder
| ZetaFormat::v0226Hashline
- | ZetaFormat::V0304SeedNoEdits => String::new(),
+ | ZetaFormat::V0304VariableEdit => String::new(),
+ ZetaFormat::V0304SeedNoEdits => String::new(),
}
}
@@ -431,32 +444,8 @@ pub fn output_end_marker_for_format(format: ZetaFormat) -> Option<&'static str>
ZetaFormat::V0112MiddleAtEnd
| ZetaFormat::V0113Ordered
| ZetaFormat::V0114180EditableRegion
- | ZetaFormat::v0226Hashline => None,
- }
-}
-
-pub fn current_region_markers_for_format(format: ZetaFormat) -> (&'static str, &'static str) {
- match format {
- ZetaFormat::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
- ZetaFormat::V0113Ordered
- | ZetaFormat::V0114180EditableRegion
- | ZetaFormat::v0226Hashline => ("<|fim_middle|>current\n", "<|fim_suffix|>"),
- ZetaFormat::V0120GitMergeMarkers
- | ZetaFormat::V0131GitMergeMarkersPrefix
- | ZetaFormat::V0211Prefill => (
- v0120_git_merge_markers::START_MARKER,
- v0120_git_merge_markers::SEPARATOR,
- ),
- ZetaFormat::V0211SeedCoder | ZetaFormat::V0304SeedNoEdits => {
- (seed_coder::START_MARKER, seed_coder::SEPARATOR)
- }
- }
-}
-
-pub fn clean_extracted_region_for_format(format: ZetaFormat, region: &str) -> String {
- match format {
- ZetaFormat::v0226Hashline => hashline::strip_hashline_prefixes(region),
- _ => region.to_string(),
+ | ZetaFormat::v0226Hashline
+ | ZetaFormat::V0304VariableEdit => None,
}
}
@@ -470,43 +459,52 @@ pub fn encode_patch_as_output_for_format(
ZetaFormat::v0226Hashline => {
hashline::patch_to_edit_commands(old_editable_region, patch, cursor_offset).map(Some)
}
+ ZetaFormat::V0304VariableEdit => v0304_variable_edit::patch_to_variable_edit_output(
+ old_editable_region,
+ patch,
+ cursor_offset,
+ )
+ .map(Some),
ZetaFormat::V0304SeedNoEdits => Ok(seed_coder::no_edits(patch)),
_ => Ok(None),
}
}
-pub fn output_with_context_for_format(
- format: ZetaFormat,
- old_editable_region: &str,
+/// Parse model output for the given zeta format
+pub fn parse_zeta2_model_output(
output: &str,
-) -> Result<Option<String>> {
+ format: ZetaFormat,
+ prompt_inputs: &ZetaPromptInput,
+) -> Result<(Range<usize>, String)> {
+ let output = match output_end_marker_for_format(format) {
+ Some(marker) => output.strip_suffix(marker).unwrap_or(output),
+ None => output,
+ };
+
+ let (context, editable_range, _, _) = resolve_cursor_region(prompt_inputs, format);
+ let old_editable_region = &context[editable_range.clone()];
+
match format {
- ZetaFormat::v0226Hashline => {
+ ZetaFormat::v0226Hashline => Ok((
+ editable_range,
if hashline::output_has_edit_commands(output) {
- Ok(Some(hashline::apply_edit_commands(
- old_editable_region,
- output,
- )))
+ hashline::apply_edit_commands(old_editable_region, output)
} else {
- Ok(None)
- }
+ output.to_string()
+ },
+ )),
+ ZetaFormat::V0304VariableEdit => {
+ v0304_variable_edit::apply_variable_edit(old_editable_region, output)
}
- ZetaFormat::V0304SeedNoEdits => {
+ ZetaFormat::V0304SeedNoEdits => Ok((
+ editable_range,
if output.starts_with(seed_coder::NO_EDITS) {
- Ok(Some(old_editable_region.to_owned()))
+ old_editable_region.to_string()
} else {
- Ok(None)
- }
- }
- _ => Ok(None),
- }
-}
-
-/// Post-processes model output for the given zeta format by stripping format-specific suffixes.
-pub fn clean_zeta2_model_output(output: &str, format: ZetaFormat) -> &str {
- match output_end_marker_for_format(format) {
- Some(marker) => output.strip_suffix(marker).unwrap_or(output),
- None => output,
+ output.to_string()
+ },
+ )),
+ _ => Ok((editable_range, output.to_string())),
}
}
@@ -2565,6 +2563,1009 @@ pub mod seed_coder {
}
}
+pub mod v0304_variable_edit {
+ //! A prompt format with no fixed editable region. The entire context is shown
+ //! to the model, and it chooses which text to replace by outputting surrounding
+ //! context lines with `<|fim_middle|>` and `<|fim_suffix|>` delimiting the new
+ //! text.
+ //!
+ //! Example prompt:
+ //!
+ //! <|file_sep|>path/to/file.py
+ //! zero
+ //! one
+ //! two
+ //! three<|user_cursor|>
+ //! four
+ //! five
+ //! <|fim_prefix|>
+ //
+ //! Expected output (model generates):
+ //!
+ //! two
+ //! <|fim_middle|>
+ //! THREE
+ //! <|fim_suffix|>
+ //! four
+ //!
+ //! The output means: find "two\n...\nfour" in the context, and replace
+ //! everything between "two\n" and "four" with "THREE\n".
+
+ use super::*;
+
+ pub fn special_tokens() -> &'static [&'static str] {
+ &[
+ "<|fim_prefix|>",
+ "<|fim_suffix|>",
+ "<|fim_middle|>",
+ "<|file_sep|>",
+ CURSOR_MARKER,
+ ]
+ }
+
+ pub fn write_cursor_excerpt_section(
+ prompt: &mut String,
+ path: &Path,
+ context: &str,
+ cursor_offset: usize,
+ ) {
+ let path_str = path.to_string_lossy();
+ write!(prompt, "<|file_sep|>{}\n", path_str).ok();
+
+ prompt.push_str(&context[..cursor_offset]);
+ prompt.push_str(CURSOR_MARKER);
+ prompt.push_str(&context[cursor_offset..]);
+ if !prompt.ends_with('\n') {
+ prompt.push('\n');
+ }
+ prompt.push_str("<|fim_prefix|>\n")
+ }
+
+ /// Apply a variable-edit model output to the original context text.
+ ///
+ /// The model output has the form:
+ ///
+ /// - prefix context lines
+ /// - `<|fim_middle|>`
+ /// - new text
+ /// - `<|fim_suffix|>`
+ /// - suffix context lines
+ ///
+ /// We locate the prefix/suffix context lines in the original text and replace
+ /// everything between them with the new text.
+ pub fn apply_variable_edit(
+ context: &str,
+ model_output: &str,
+ ) -> Result<(Range<usize>, String)> {
+ let (prefix_context, rest) = model_output
+ .split_once("<|fim_middle|>\n")
+ .or_else(|| model_output.split_once("<|fim_middle|>"))
+ .ok_or_else(|| anyhow::anyhow!("missing <|fim_middle|> in model output"))?;
+
+ let (new_text, suffix_context) = rest
+ .split_once("<|fim_suffix|>\n")
+ .or_else(|| rest.split_once("<|fim_suffix|>"))
+ .unwrap_or((rest, ""));
+
+ let suffix_context = if prefix_context.is_empty() && !suffix_context.is_empty() {
+ suffix_context.strip_prefix('\n').unwrap_or(suffix_context)
+ } else {
+ suffix_context
+ };
+
+ let prefix_offset = find_substring_at_line_boundary(context, prefix_context)
+ .ok_or_else(|| anyhow!("could not locate prefix lines"))?
+ + prefix_context.len();
+ let suffix_offset = if suffix_context.is_empty() {
+ context.len()
+ } else {
+ find_substring_at_line_boundary(&context[prefix_offset..], suffix_context)
+ .ok_or_else(|| anyhow!("could not locate suffix lines"))?
+ + prefix_offset
+ };
+
+ let edit_range = prefix_offset..suffix_offset;
+ return Ok((edit_range, new_text.to_string()));
+ }
+
+ fn find_substring_at_line_boundary(haystack: &str, needle: &str) -> Option<usize> {
+ if needle.is_empty() {
+ return Some(0);
+ }
+
+ haystack.match_indices(needle).find_map(|(offset, _)| {
+ let matched_line_start = offset == 0 || haystack[..offset].ends_with('\n');
+ matched_line_start.then_some(offset)
+ })
+ }
+
+ /// Convert a unified diff patch into the variable-edit output format.
+ ///
+ /// Parses `patch` as a unified diff against `old_text` and produces model
+ /// output with context lines surrounding `<|fim_middle|>` / `<|fim_suffix|>`
+ /// delimiters. The diff is resolved by content matching rather than line
+ /// numbers.
+ pub fn patch_to_variable_edit_output(
+ old_text: &str,
+ patch: &str,
+ cursor_offset: Option<usize>,
+ ) -> Result<String> {
+ // Parse the unified diff into hunks. Each hunk has an `old_context`
+ // string (context + deleted lines interleaved in order) and a list of
+ // edits expressed as byte ranges within that context plus replacement
+ // text.
+ let hunks = parse_hunks(patch);
+ if hunks.is_empty() {
+ return Ok(String::new());
+ }
+
+ // Apply each hunk by finding its old_context in the text and
+ // performing the edits. We search forward from where the previous
+ // hunk ended so that hunks are applied in order.
+ let mut new_text = old_text.to_string();
+ let mut search_from: usize = 0;
+ let mut first_hunk_pos: Option<usize> = None;
+
+ for hunk in &hunks {
+ let context_pos = new_text[search_from..]
+ .find(&hunk.old_context)
+ .map(|pos| pos + search_from)
+ .ok_or_else(|| anyhow::anyhow!("could not locate hunk context in text"))?;
+
+ if first_hunk_pos.is_none() {
+ first_hunk_pos = Some(context_pos);
+ }
+
+ // Apply edits in reverse order so byte offsets remain valid.
+ for edit in hunk.edits.iter().rev() {
+ let abs_start = context_pos + edit.range.start;
+ let abs_end = context_pos + edit.range.end;
+ new_text.replace_range(abs_start..abs_end, &edit.text);
+ }
+
+ // Advance past this hunk's region in the (now modified) text.
+ let new_region_len: usize =
+ hunk.edits.iter().fold(hunk.old_context.len(), |len, edit| {
+ len + edit.text.len() - (edit.range.end - edit.range.start)
+ });
+ search_from = context_pos + new_region_len;
+ }
+
+ // Now we have old_text and new_text. Find the changed line range by
+ // comparing them.
+ let old_lines: Vec<&str> = old_text.lines().collect();
+ let new_lines: Vec<&str> = new_text.lines().collect();
+
+ // Find first differing line.
+ let first_changed_row = old_lines
+ .iter()
+ .zip(new_lines.iter())
+ .position(|(a, b)| a != b)
+ .unwrap_or_else(|| old_lines.len().min(new_lines.len()));
+
+ // Find last differing line (from the end).
+ let max_suffix = old_lines.len().min(new_lines.len()) - first_changed_row;
+ let common_suffix = old_lines
+ .iter()
+ .rev()
+ .zip(new_lines.iter().rev())
+ .take(max_suffix)
+ .take_while(|(a, b)| a == b)
+ .count();
+
+ let old_end = old_lines.len() - common_suffix;
+ let new_end = new_lines.len() - common_suffix;
+
+ if first_changed_row == old_end && first_changed_row == new_end {
+ return Ok(String::new());
+ }
+
+ // Build the replacement text from new_lines[first_diff..new_end].
+ let mut merged_new_text = String::new();
+ for line in &new_lines[first_changed_row..new_end] {
+ merged_new_text.push_str(line);
+ merged_new_text.push('\n');
+ }
+
+ // cursor_offset is relative to the first hunk's new content in
+ // new_text. Translate it to an offset within merged_new_text, which
+ // only contains lines first_diff..new_end of new_text.
+ if let Some(hunk_offset) = cursor_offset {
+ let hunk_start = first_hunk_pos.unwrap_or(0);
+ let absolute_pos = hunk_start + hunk_offset;
+
+ // Byte offset where first_diff starts in new_text.
+ let merged_start: usize = new_lines[..first_changed_row]
+ .iter()
+ .map(|line| line.len() + 1)
+ .sum();
+
+ if absolute_pos >= merged_start {
+ let relative_offset = absolute_pos - merged_start;
+ if relative_offset <= merged_new_text.len() {
+ merged_new_text.insert_str(relative_offset, CURSOR_MARKER);
+ }
+ }
+ }
+
+ // Build output with 2 lines of context above and below.
+ let context_lines_count = 2;
+ let mut prefix_start = first_changed_row.saturating_sub(context_lines_count);
+ let mut suffix_end = (old_end + context_lines_count).min(old_lines.len());
+
+ fn count_matches(line_range: Range<usize>, lines: &[&str]) -> usize {
+ let pattern = &lines[line_range];
+ let pattern_len = pattern.len();
+
+ let mut count = 0;
+ for offset in 0..=lines.len() - pattern_len {
+ if &lines[offset..offset + pattern_len] == pattern {
+ count += 1;
+ }
+ }
+ count
+ }
+
+ // Expand prefix and suffix until they are unique
+ while prefix_start > 0 {
+ if count_matches(prefix_start..first_changed_row, &old_lines) > 1 {
+ prefix_start -= 1;
+ } else {
+ break;
+ }
+ }
+ while suffix_end < old_lines.len() {
+ if count_matches(old_end..suffix_end, &old_lines) > 1 {
+ suffix_end += 1;
+ } else {
+ break;
+ }
+ }
+
+ let mut output = String::new();
+ for line in &old_lines[prefix_start..first_changed_row] {
+ output.push_str(line);
+ output.push('\n');
+ }
+ output.push_str("<|fim_middle|>\n");
+ output.push_str(&merged_new_text);
+ output.push_str("<|fim_suffix|>\n");
+ for line in &old_lines[old_end..suffix_end] {
+ output.push_str(line);
+ output.push('\n');
+ }
+
+ Ok(output)
+ }
+
+ struct ParsedHunk {
+ old_context: String,
+ edits: Vec<ParsedEdit>,
+ }
+
+ struct ParsedEdit {
+ range: Range<usize>,
+ text: String,
+ }
+
+ /// Parse a unified diff into content-based hunks. Each hunk contains an
+ /// `old_context` string (context lines + deleted lines, which together
+ /// form the text that should be found in the original) and a list of edits
+ /// expressed as byte ranges within that context.
+ fn parse_hunks(patch: &str) -> Vec<ParsedHunk> {
+ let mut hunks = Vec::new();
+ let mut current: Option<ParsedHunk> = None;
+
+ for line in patch.lines() {
+ if line.starts_with("@@") {
+ if let Some(hunk) = current.take() {
+ if !hunk.old_context.is_empty() || !hunk.edits.is_empty() {
+ hunks.push(hunk);
+ }
+ }
+ current = Some(ParsedHunk {
+ old_context: String::new(),
+ edits: Vec::new(),
+ });
+ } else if line.starts_with("---") || line.starts_with("+++") {
+ continue;
+ } else if let Some(hunk) = &mut current {
+ if let Some(added) = line.strip_prefix('+') {
+ let pos = hunk.old_context.len();
+ if let Some(last_edit) = hunk.edits.last_mut() {
+ if last_edit.range.end == pos {
+ writeln!(&mut last_edit.text, "{added}").ok();
+ continue;
+ }
+ }
+ hunk.edits.push(ParsedEdit {
+ range: pos..pos,
+ text: format!("{added}\n"),
+ });
+ } else if let Some(removed) = line.strip_prefix('-') {
+ let start = hunk.old_context.len();
+ writeln!(&mut hunk.old_context, "{removed}").ok();
+ let end = hunk.old_context.len();
+ if let Some(last_edit) = hunk.edits.last_mut() {
+ if last_edit.range.end == start {
+ last_edit.range.end = end;
+ continue;
+ }
+ }
+ hunk.edits.push(ParsedEdit {
+ range: start..end,
+ text: String::new(),
+ });
+ } else {
+ let ctx = line.strip_prefix(' ').unwrap_or(line);
+ writeln!(&mut hunk.old_context, "{ctx}").ok();
+ }
+ }
+ }
+
+ if let Some(hunk) = current {
+ if !hunk.old_context.is_empty() || !hunk.edits.is_empty() {
+ hunks.push(hunk);
+ }
+ }
+
+ hunks
+ }
+
+ #[cfg(test)]
+ mod tests {
+ use super::*;
+ use indoc::indoc;
+
+ #[test]
+ fn test_apply_variable_edit() {
+ struct Case {
+ name: &'static str,
+ original: &'static str,
+ model_output: &'static str,
+ expected: &'static str,
+ }
+
+ let cases = [
+ Case {
+ name: "simple_single_line_replacement",
+ original: indoc! {"
+ zero
+ one
+ two
+ three
+ four
+ five
+ "},
+ model_output: indoc! {"
+ two
+ <|fim_middle|>
+ THREE
+ <|fim_suffix|>
+ four
+ "},
+ expected: indoc! {"
+ zero
+ one
+ two
+ THREE
+ four
+ five
+ "},
+ },
+ Case {
+ name: "multi_line_replacement",
+ original: indoc! {"
+ a
+ b
+ c
+ d
+ e
+ "},
+ model_output: indoc! {"
+ a
+ <|fim_middle|>
+ B
+ C
+ D
+ <|fim_suffix|>
+ e
+ "},
+ expected: indoc! {"
+ a
+ B
+ C
+ D
+ e
+ "},
+ },
+ Case {
+ name: "insertion_between_existing_lines",
+ original: indoc! {"
+ a
+ b
+ c
+ "},
+ model_output: indoc! {"
+ a
+ <|fim_middle|>
+ X
+ <|fim_suffix|>
+ b
+ "},
+ expected: indoc! {"
+ a
+ X
+ b
+ c
+ "},
+ },
+ Case {
+ name: "deletion",
+ original: indoc! {"
+ a
+ b
+ c
+ d
+ "},
+ model_output: indoc! {"
+ a
+ <|fim_middle|>
+ <|fim_suffix|>
+ c
+ "},
+ expected: indoc! {"
+ a
+ c
+ d
+ "},
+ },
+ Case {
+ name: "replacement_at_start_no_prefix_context",
+ original: indoc! {"
+ a
+ b
+ c
+ "},
+ model_output: indoc! {"
+ <|fim_middle|>
+ X
+ <|fim_suffix|>
+ b
+ "},
+ expected: indoc! {"
+ X
+ b
+ c
+ "},
+ },
+ Case {
+ name: "replacement_at_end_no_suffix_context",
+ original: indoc! {"
+ a
+ b
+ c
+ "},
+ model_output: indoc! {"
+ b
+ <|fim_middle|>
+ Z
+ <|fim_suffix|>
+ "},
+ expected: indoc! {"
+ a
+ b
+ Z
+ "},
+ },
+ Case {
+ name: "context_with_trailing_newline_is_preserved",
+ original: indoc! {"
+ a
+ b
+ c
+ "},
+ model_output: indoc! {"
+ a
+ <|fim_middle|>
+ B
+ <|fim_suffix|>
+ c
+ "},
+ expected: indoc! {"
+ a
+ B
+ c
+ "},
+ },
+ Case {
+ name: "cursor_marker_passes_through_untouched",
+ original: indoc! {"
+ a
+ b
+ c
+ "},
+ model_output: indoc! {"
+ a
+ <|fim_middle|>
+ B<|user_cursor|>B
+ <|fim_suffix|>
+ c
+ "},
+ expected: indoc! {"
+ a
+ B<|user_cursor|>B
+ c
+ "},
+ },
+ Case {
+ name: "multiple_prefix_context_lines",
+ original: indoc! {"
+ a
+ b
+ c
+ d
+ e
+ "},
+ model_output: indoc! {"
+ b
+ c
+ <|fim_middle|>
+ D
+ <|fim_suffix|>
+ e
+ "},
+ expected: indoc! {"
+ a
+ b
+ c
+ D
+ e
+ "},
+ },
+ ];
+
+ for case in cases {
+ let (edit_range, replacement) =
+ apply_variable_edit(case.original, case.model_output).unwrap();
+ let mut edited = case.original.to_string();
+ edited.replace_range(edit_range, &replacement);
+ assert_eq!(edited, case.expected, "{}", case.name);
+ }
+ }
+
+ #[test]
+ fn test_patch_to_variable_edit() {
+ struct Case {
+ name: &'static str,
+ old: &'static str,
+ patch: &'static str,
+ cursor_offset: Option<usize>,
+ expected_variable_edit: &'static str,
+ expected_after_apply: &'static str,
+ }
+
+ let cases = [
+ Case {
+ name: "simple_replacement",
+ old: indoc! {"
+ zero
+ one
+ two
+ three
+ four
+ five
+ "},
+ patch: indoc! {"
+ @@ -3,3 +3,3 @@
+ two
+ -three
+ +THREE
+ four
+ "},
+ cursor_offset: None,
+ expected_variable_edit: indoc! {"
+ one
+ two
+ <|fim_middle|>
+ THREE
+ <|fim_suffix|>
+ four
+ five
+ "},
+ expected_after_apply: indoc! {"
+ zero
+ one
+ two
+ THREE
+ four
+ five
+ "},
+ },
+ Case {
+ name: "insertion",
+ old: indoc! {"
+ a
+ b
+ c
+ d
+ e
+ "},
+ patch: indoc! {"
+ @@ -2,0 +3,1 @@
+ b
+ +X
+ c
+ "},
+ cursor_offset: None,
+ expected_variable_edit: indoc! {"
+ a
+ b
+ <|fim_middle|>
+ X
+ <|fim_suffix|>
+ c
+ d
+ "},
+ expected_after_apply: indoc! {"
+ a
+ b
+ X
+ c
+ d
+ e
+ "},
+ },
+ Case {
+ name: "deletion",
+ old: indoc! {"
+ a
+ b
+ c
+ d
+ e
+ "},
+ patch: indoc! {"
+ @@ -2,3 +2,2 @@
+ b
+ -c
+ d
+ "},
+ cursor_offset: None,
+ expected_variable_edit: indoc! {"
+ a
+ b
+ <|fim_middle|>
+ <|fim_suffix|>
+ d
+ e
+ "},
+ expected_after_apply: indoc! {"
+ a
+ b
+ d
+ e
+ "},
+ },
+ Case {
+ name: "edit_near_start",
+ old: indoc! {"
+ first
+ second
+ third
+ fourth
+ "},
+ patch: indoc! {"
+ @@ -1,1 +1,1 @@
+ -first
+ +FIRST
+ "},
+ cursor_offset: None,
+ expected_variable_edit: indoc! {"
+ <|fim_middle|>
+ FIRST
+ <|fim_suffix|>
+ second
+ third
+ "},
+ expected_after_apply: indoc! {"
+ FIRST
+ second
+ third
+ fourth
+ "},
+ },
+ Case {
+ name: "edit_near_end",
+ old: indoc! {"
+ first
+ second
+ third
+ fourth
+ "},
+ patch: indoc! {"
+ @@ -4,1 +4,1 @@
+ -fourth
+ +FOURTH
+ "},
+ cursor_offset: None,
+ expected_variable_edit: indoc! {"
+ second
+ third
+ <|fim_middle|>
+ FOURTH
+ <|fim_suffix|>
+ "},
+ expected_after_apply: indoc! {"
+ first
+ second
+ third
+ FOURTH
+ "},
+ },
+ Case {
+ name: "cursor_at_start_of_replacement",
+ old: indoc! {"
+ zero
+ one
+ two
+ three
+ four
+ five
+ "},
+ patch: indoc! {"
+ @@ -3,3 +3,3 @@
+ two
+ -three
+ +THREE
+ four
+ "},
+ cursor_offset: Some(4),
+ expected_variable_edit: indoc! {"
+ one
+ two
+ <|fim_middle|>
+ <|user_cursor|>THREE
+ <|fim_suffix|>
+ four
+ five
+ "},
+ expected_after_apply: indoc! {"
+ zero
+ one
+ two
+ <|user_cursor|>THREE
+ four
+ five
+ "},
+ },
+ Case {
+ name: "cursor_in_middle_of_replacement",
+ old: indoc! {"
+ zero
+ one
+ two
+ three
+ four
+ five
+ "},
+ patch: indoc! {"
+ @@ -3,3 +3,3 @@
+ two
+ -three
+ +THREE
+ four
+ "},
+ cursor_offset: Some(6),
+ expected_variable_edit: indoc! {"
+ one
+ two
+ <|fim_middle|>
+ TH<|user_cursor|>REE
+ <|fim_suffix|>
+ four
+ five
+ "},
+ expected_after_apply: indoc! {"
+ zero
+ one
+ two
+ TH<|user_cursor|>REE
+ four
+ five
+ "},
+ },
+ Case {
+ name: "expands_context_when_two_lines_not_unique_before_and_after",
+ old: indoc! {"
+ one
+ a
+ b
+ c
+ d
+ two
+ a
+ b
+ c
+ d
+ three
+ a
+ b
+ c
+ d
+ four
+ "},
+ patch: indoc! {"
+ @@ -4,5 +4,5 @@
+ two
+ a
+ b
+ -c
+ +C
+ d
+ three
+ "},
+ cursor_offset: None,
+ expected_variable_edit: indoc! {"
+ two
+ a
+ b
+ <|fim_middle|>
+ C
+ <|fim_suffix|>
+ d
+ three
+ "},
+ expected_after_apply: indoc! {"
+ one
+ a
+ b
+ c
+ d
+ two
+ a
+ b
+ C
+ d
+ three
+ a
+ b
+ c
+ d
+ four
+ "},
+ },
+ Case {
+ name: "expands_context_when_two_lines_not_unique_before_and_after",
+ old: indoc! {"
+ {
+ {
+ one();
+ }
+ }
+ {
+ {
+ two();
+ }
+ }
+ {
+ {
+ three();
+ }
+ }
+ {
+ {
+ four();
+ }
+ }
+ "},
+ patch: indoc! {"
+ @@ -4,5 +4,5 @@
+ {
+ - two();
+ + TWO();
+ }
+ "},
+ cursor_offset: None,
+ expected_variable_edit: indoc! {"
+ one();
+ }
+ }
+ {
+ {
+ <|fim_middle|>
+ TWO();
+ <|fim_suffix|>
+ }
+ }
+ {
+ {
+ three();
+ "},
+ expected_after_apply: indoc! {"
+ {
+ {
+ one();
+ }
+ }
+ {
+ {
+ TWO();
+ }
+ }
+ {
+ {
+ three();
+ }
+ }
+ {
+ {
+ four();
+ }
+ }
+ "},
+ },
+ ];
+
+ for case in cases {
+ let output =
+ patch_to_variable_edit_output(case.old, case.patch, case.cursor_offset)
+ .unwrap_or_else(|error| {
+ panic!("failed converting patch for {}: {error}", case.name)
+ });
+ assert_eq!(
+ output, case.expected_variable_edit,
+ "patch->variable_edit mismatch for {}",
+ case.name
+ );
+
+ let (edit_range, replacement) = apply_variable_edit(case.old, &output)
+ .unwrap_or_else(|error| {
+ panic!("failed applying variable_edit for {}: {error}", case.name)
+ });
+ let mut edited_by_variable_edit = case.old.to_string();
+ edited_by_variable_edit.replace_range(edit_range, &replacement);
+ assert_eq!(
+ edited_by_variable_edit, case.expected_after_apply,
+ "variable_edit apply mismatch for {}",
+ case.name
+ );
+
+ let (expected_edit_range, expected_replacement) =
+ apply_variable_edit(case.old, case.expected_variable_edit).unwrap_or_else(
+ |error| {
+ panic!(
+ "failed applying expected variable_edit for {}: {error}",
+ case.name
+ )
+ },
+ );
+ let mut edited_by_expected_variable_edit = case.old.to_string();
+ edited_by_expected_variable_edit
+ .replace_range(expected_edit_range, &expected_replacement);
+ assert_eq!(
+ edited_by_expected_variable_edit, case.expected_after_apply,
+ "expected variable_edit apply mismatch for {}",
+ case.name
+ );
+ }
+ }
+
+ #[test]
+ fn test_write_cursor_excerpt_section() {
+ let path = Path::new("test.rs");
+ let context = "fn main() {\n hello();\n}\n";
+ let cursor_offset = 17;
+ let mut prompt = String::new();
+ write_cursor_excerpt_section(&mut prompt, path, context, cursor_offset);
+ assert_eq!(
+ prompt,
+ "<|file_sep|>test.rs\nfn main() {\n h<|user_cursor|>ello();\n}\n<|fim_prefix|>\n"
+ );
+ }
+ }
+}
+
/// The zeta1 prompt format
pub mod zeta1 {
use super::*;