From d420dd63ed1e8274691eaf8466d39fc9a4a60997 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 10 Nov 2025 11:58:42 -0300 Subject: [PATCH] zeta: Improve unified diff prompt (#42354) Extract some of the improvements from to the unified diff prompt from https://github.com/zed-industries/zed/pull/42171 and adds some other about how context work to improve the reliability of predictions. We also now strip the `<|user_cursor|>` marker if it appears in the output rather than failing. Release Notes: - N/A --------- Co-authored-by: Max Brunsfeld --- .../src/cloud_zeta2_prompt.rs | 105 ++++++++---------- crates/zeta2/src/zeta2.rs | 43 ++++--- crates/zeta_cli/src/evaluate.rs | 2 +- crates/zeta_cli/src/predict.rs | 14 ++- 4 files changed, 86 insertions(+), 78 deletions(-) diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index 7fb79906f29f38579feef82bb25e7ed42d1d6c83..6055c39e16ea95b38754bb26fd7371250d1fc525 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -56,48 +56,48 @@ const LABELED_SECTIONS_INSTRUCTIONS: &str = indoc! {r#" const NUMBERED_LINES_INSTRUCTIONS: &str = indoc! {r#" # Instructions - You are a code completion assistant helping a programmer finish their work. Your task is to: + You are an edit prediction agent in a code editor. + Your job is to predict the next edit that the user will make, + based on their last few edits and their current cursor location. - 1. Analyze the edit history to understand what the programmer is trying to achieve - 2. Identify any incomplete refactoring or changes that need to be finished - 3. Make the remaining edits that a human programmer would logically make next - 4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere. + ## Output Format - Focus on: - - Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs) - - Completing any partially-applied changes across the codebase - - Ensuring consistency with the programming style and patterns already established - - Making edits that maintain or improve code quality - - If the programmer started refactoring one instance of a pattern, find and update ALL similar instances - - Don't write a lot of code if you're not sure what to do - - Rules: - - Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals. - - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code. - - Write the edits in the unified diff format as shown in the example. - - # Example output: + You must briefly explain your understanding of the user's goal, in one + or two sentences, and then specify their next edit in the form of a + unified diff, like this: ``` --- a/src/myapp/cli.py +++ b/src/myapp/cli.py - @@ -1,3 +1,3 @@ - - - - - -import sys - +import json + @@ ... @@ + import os + import time + import sys + +from constants import LOG_LEVEL_WARNING + @@ ... @@ + config.headless() + config.set_interactive(false) + -config.set_log_level(LOG_L) + +config.set_log_level(LOG_LEVEL_WARNING) + config.set_use_color(True) ``` - # Edit History: + ## Edit History "#}; const UNIFIED_DIFF_REMINDER: &str = indoc! {" --- - Please analyze the edit history and the files, then provide the unified diff for your predicted edits. + Analyze the edit history and the files, then provide the unified diff for your predicted edits. Do not include the cursor marker in your output. - If you're editing multiple files, be sure to reflect filename in the hunk's header. + Your diff should include edited file paths in its file headers (lines beginning with `---` and `+++`). + Do not include line numbers in the hunk headers, use `@@ ... @@`. + Removed lines begin with `-`. + Added lines begin with `+`. + Context lines begin with an extra space. + Context and removed lines are used to match the target edit location, so make sure to include enough of them + to uniquely identify it amongst all excerpts of code provided. "}; pub fn build_prompt( @@ -121,8 +121,7 @@ pub fn build_prompt( EDITABLE_REGION_END_MARKER_WITH_NEWLINE, ), ], - PromptFormat::LabeledSections => vec![(request.cursor_point, CURSOR_MARKER)], - PromptFormat::NumLinesUniDiff => { + PromptFormat::LabeledSections | PromptFormat::NumLinesUniDiff => { vec![(request.cursor_point, CURSOR_MARKER)] } PromptFormat::OnlySnippets => vec![], @@ -132,46 +131,31 @@ pub fn build_prompt( PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(), PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(), PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(), - // only intended for use via zeta_cli PromptFormat::OnlySnippets => String::new(), }; if request.events.is_empty() { prompt.push_str("(No edit history)\n\n"); } else { - prompt.push_str( - "The following are the latest edits made by the user, from earlier to later.\n\n", - ); + prompt.push_str("Here are the latest edits made by the user, from earlier to later.\n\n"); push_events(&mut prompt, &request.events); } + prompt.push_str(indoc! {" + # Code Excerpts + + The cursor marker <|user_cursor|> indicates the current user cursor position. + The file is in current state, edits from edit history have been applied. + "}); + if request.prompt_format == PromptFormat::NumLinesUniDiff { - if request.referenced_declarations.is_empty() { - prompt.push_str(indoc! {" - # File under the cursor: - - The cursor marker <|user_cursor|> indicates the current user cursor position. - The file is in current state, edits from edit history have been applied. - We prepend line numbers (e.g., `123|`); they are not part of the file. - - "}); - } else { - // Note: This hasn't been trained on yet - prompt.push_str(indoc! {" - # Code Excerpts: - - The cursor marker <|user_cursor|> indicates the current user cursor position. - Other excerpts of code from the project have been included as context based on their similarity to the code under the cursor. - Context excerpts are not guaranteed to be relevant, so use your own judgement. - Files are in their current state, edits from edit history have been applied. - We prepend line numbers (e.g., `123|`); they are not part of the file. - - "}); - } - } else { - prompt.push_str("\n## Code\n\n"); + prompt.push_str(indoc! {" + We prepend line numbers (e.g., `123|`); they are not part of the file. + "}); } + prompt.push('\n'); + let mut section_labels = Default::default(); if !request.referenced_declarations.is_empty() || !request.signatures.is_empty() { @@ -198,8 +182,11 @@ pub fn build_prompt( } } - if request.prompt_format == PromptFormat::NumLinesUniDiff { - prompt.push_str(UNIFIED_DIFF_REMINDER); + match request.prompt_format { + PromptFormat::NumLinesUniDiff => { + prompt.push_str(UNIFIED_DIFF_REMINDER); + } + _ => {} } Ok((prompt, section_labels)) diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 3a51f9975ccbcf3fb325712f7aafadc5187da541..297bfa1c4a940448e7fdb570ea4b808556c3f416 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -1,4 +1,4 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Context as _, Result, anyhow, bail}; use chrono::TimeDelta; use client::{Client, EditPredictionUsage, UserStore}; use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature}; @@ -6,8 +6,8 @@ use cloud_llm_client::{ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME, }; -use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES; use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery}; +use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES}; use collections::HashMap; use edit_prediction_context::{ DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, @@ -943,23 +943,34 @@ impl Zeta { let (res, usage) = response?; let request_id = EditPredictionId(res.id.clone().into()); - let Some(output_text) = text_from_response(res) else { + let Some(mut output_text) = text_from_response(res) else { return Ok((None, usage)) }; - let (edited_buffer_snapshot, edits) = - crate::udiff::parse_diff(&output_text, |path| { - included_files - .iter() - .find_map(|(_, buffer, probe_path, ranges)| { - if probe_path.as_ref() == path { - Some((buffer, ranges.as_slice())) - } else { - None - } - }) - }) - .await?; + if output_text.contains(CURSOR_MARKER) { + log::trace!("Stripping out {CURSOR_MARKER} from response"); + output_text = output_text.replace(CURSOR_MARKER, ""); + } + + let (edited_buffer_snapshot, edits) = match options.prompt_format { + PromptFormat::NumLinesUniDiff => { + crate::udiff::parse_diff(&output_text, |path| { + included_files + .iter() + .find_map(|(_, buffer, probe_path, ranges)| { + if probe_path.as_ref() == path { + Some((buffer, ranges.as_slice())) + } else { + None + } + }) + }) + .await? + } + _ => { + bail!("unsupported prompt format {}", options.prompt_format) + } + }; let edited_buffer = included_files .iter() diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs index f99747e676b777e5d7a086c61db2f9e8d152c20b..c0f513fa38df5fb837be2294845eeae3214074bd 100644 --- a/crates/zeta_cli/src/evaluate.rs +++ b/crates/zeta_cli/src/evaluate.rs @@ -67,7 +67,7 @@ pub async fn run_evaluate_one( ); as_json } else { - zeta2_predict(example.clone(), &app_state, cx) + zeta2_predict(example.clone(), Default::default(), &app_state, cx) .await .unwrap() }; diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs index a593a1b12ceb2b72a316463076657f35ac2c4e9d..f7f503ffebe24d71023ad259ce76adfdea364efc 100644 --- a/crates/zeta_cli/src/predict.rs +++ b/crates/zeta_cli/src/predict.rs @@ -1,9 +1,11 @@ +use crate::PromptFormat; use crate::example::{ActualExcerpt, NamedExample}; use crate::headless::ZetaCliAppState; use crate::paths::LOGS_DIR; use ::serde::Serialize; use anyhow::{Result, anyhow}; use clap::Args; +// use cloud_llm_client::predict_edits_v3::PromptFormat; use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock}; use futures::StreamExt as _; use gpui::{AppContext, AsyncApp}; @@ -19,9 +21,11 @@ use std::time::{Duration, Instant}; #[derive(Debug, Args)] pub struct PredictArguments { - example_path: PathBuf, + #[arg(long, value_enum, default_value_t = PromptFormat::default())] + prompt_format: PromptFormat, #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)] format: PredictionsOutputFormat, + example_path: PathBuf, } #[derive(clap::ValueEnum, Debug, Clone)] @@ -36,7 +40,9 @@ pub async fn run_zeta2_predict( cx: &mut AsyncApp, ) { let example = NamedExample::load(args.example_path).unwrap(); - let result = zeta2_predict(example, &app_state, cx).await.unwrap(); + let result = zeta2_predict(example, args.prompt_format, &app_state, cx) + .await + .unwrap(); result.write(args.format, std::io::stdout()).unwrap(); } @@ -46,6 +52,7 @@ thread_local! { pub async fn zeta2_predict( example: NamedExample, + prompt_format: PromptFormat, app_state: &Arc, cx: &mut AsyncApp, ) -> Result { @@ -193,6 +200,9 @@ pub async fn zeta2_predict( }); zeta.update(cx, |zeta, cx| { + let mut options = zeta.options().clone(); + options.prompt_format = prompt_format.into(); + zeta.set_options(options); zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) })? .await?;