zeta: Improve unified diff prompt (#42354)

Agus Zubiaga and Max Brunsfeld created

Extract some of the improvements from to the unified diff prompt from
https://github.com/zed-industries/zed/pull/42171 and adds some other
about how context work to improve the reliability of predictions.

We also now strip the `<|user_cursor|>` marker if it appears in the
output rather than failing.

Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs | 105 ++++++--------
crates/zeta2/src/zeta2.rs                           |  43 +++--
crates/zeta_cli/src/evaluate.rs                     |   2 
crates/zeta_cli/src/predict.rs                      |  14 +
4 files changed, 86 insertions(+), 78 deletions(-)

Detailed changes

crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs 🔗

@@ -56,48 +56,48 @@ const LABELED_SECTIONS_INSTRUCTIONS: &str = indoc! {r#"
 const NUMBERED_LINES_INSTRUCTIONS: &str = indoc! {r#"
     # Instructions
 
-    You are a code completion assistant helping a programmer finish their work. Your task is to:
+    You are an edit prediction agent in a code editor.
+    Your job is to predict the next edit that the user will make,
+    based on their last few edits and their current cursor location.
 
-    1. Analyze the edit history to understand what the programmer is trying to achieve
-    2. Identify any incomplete refactoring or changes that need to be finished
-    3. Make the remaining edits that a human programmer would logically make next
-    4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.
+    ## Output Format
 
-    Focus on:
-    - Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
-    - Completing any partially-applied changes across the codebase
-    - Ensuring consistency with the programming style and patterns already established
-    - Making edits that maintain or improve code quality
-    - If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
-    - Don't write a lot of code if you're not sure what to do
-
-    Rules:
-    - Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
-    - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
-    - Write the edits in the unified diff format as shown in the example.
-
-    # Example output:
+    You must briefly explain your understanding of the user's goal, in one
+    or two sentences, and then specify their next edit in the form of a
+    unified diff, like this:
 
     ```
     --- a/src/myapp/cli.py
     +++ b/src/myapp/cli.py
-    @@ -1,3 +1,3 @@
-    -
-    -
-    -import sys
-    +import json
+    @@ ... @@
+     import os
+     import time
+     import sys
+    +from constants import LOG_LEVEL_WARNING
+    @@ ... @@
+     config.headless()
+     config.set_interactive(false)
+    -config.set_log_level(LOG_L)
+    +config.set_log_level(LOG_LEVEL_WARNING)
+     config.set_use_color(True)
     ```
 
-    # Edit History:
+    ## Edit History
 
 "#};
 
 const UNIFIED_DIFF_REMINDER: &str = indoc! {"
     ---
 
-    Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
+    Analyze the edit history and the files, then provide the unified diff for your predicted edits.
     Do not include the cursor marker in your output.
-    If you're editing multiple files, be sure to reflect filename in the hunk's header.
+    Your diff should include edited file paths in its file headers (lines beginning with `---` and `+++`).
+    Do not include line numbers in the hunk headers, use `@@ ... @@`.
+    Removed lines begin with `-`.
+    Added lines begin with `+`.
+    Context lines begin with an extra space.
+    Context and removed lines are used to match the target edit location, so make sure to include enough of them
+    to uniquely identify it amongst all excerpts of code provided.
 "};
 
 pub fn build_prompt(
@@ -121,8 +121,7 @@ pub fn build_prompt(
                 EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
             ),
         ],
-        PromptFormat::LabeledSections => vec![(request.cursor_point, CURSOR_MARKER)],
-        PromptFormat::NumLinesUniDiff => {
+        PromptFormat::LabeledSections | PromptFormat::NumLinesUniDiff => {
             vec![(request.cursor_point, CURSOR_MARKER)]
         }
         PromptFormat::OnlySnippets => vec![],
@@ -132,46 +131,31 @@ pub fn build_prompt(
         PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(),
         PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(),
         PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(),
-        // only intended for use via zeta_cli
         PromptFormat::OnlySnippets => String::new(),
     };
 
     if request.events.is_empty() {
         prompt.push_str("(No edit history)\n\n");
     } else {
-        prompt.push_str(
-            "The following are the latest edits made by the user, from earlier to later.\n\n",
-        );
+        prompt.push_str("Here are the latest edits made by the user, from earlier to later.\n\n");
         push_events(&mut prompt, &request.events);
     }
 
+    prompt.push_str(indoc! {"
+        # Code Excerpts
+
+        The cursor marker <|user_cursor|> indicates the current user cursor position.
+        The file is in current state, edits from edit history have been applied.
+    "});
+
     if request.prompt_format == PromptFormat::NumLinesUniDiff {
-        if request.referenced_declarations.is_empty() {
-            prompt.push_str(indoc! {"
-                # File under the cursor:
-
-                The cursor marker <|user_cursor|> indicates the current user cursor position.
-                The file is in current state, edits from edit history have been applied.
-                We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
-
-            "});
-        } else {
-            // Note: This hasn't been trained on yet
-            prompt.push_str(indoc! {"
-                # Code Excerpts:
-
-                The cursor marker <|user_cursor|> indicates the current user cursor position.
-                Other excerpts of code from the project have been included as context based on their similarity to the code under the cursor.
-                Context excerpts are not guaranteed to be relevant, so use your own judgement.
-                Files are in their current state, edits from edit history have been applied.
-                We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
-
-            "});
-        }
-    } else {
-        prompt.push_str("\n## Code\n\n");
+        prompt.push_str(indoc! {"
+            We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
+        "});
     }
 
+    prompt.push('\n');
+
     let mut section_labels = Default::default();
 
     if !request.referenced_declarations.is_empty() || !request.signatures.is_empty() {
@@ -198,8 +182,11 @@ pub fn build_prompt(
         }
     }
 
-    if request.prompt_format == PromptFormat::NumLinesUniDiff {
-        prompt.push_str(UNIFIED_DIFF_REMINDER);
+    match request.prompt_format {
+        PromptFormat::NumLinesUniDiff => {
+            prompt.push_str(UNIFIED_DIFF_REMINDER);
+        }
+        _ => {}
     }
 
     Ok((prompt, section_labels))

crates/zeta2/src/zeta2.rs 🔗

@@ -1,4 +1,4 @@
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Context as _, Result, anyhow, bail};
 use chrono::TimeDelta;
 use client::{Client, EditPredictionUsage, UserStore};
 use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
@@ -6,8 +6,8 @@ use cloud_llm_client::{
     AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
     ZED_VERSION_HEADER_NAME,
 };
-use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
+use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
 use collections::HashMap;
 use edit_prediction_context::{
     DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
@@ -943,23 +943,34 @@ impl Zeta {
 
                 let (res, usage) = response?;
                 let request_id = EditPredictionId(res.id.clone().into());
-                let Some(output_text) = text_from_response(res) else {
+                let Some(mut output_text) = text_from_response(res) else {
                     return Ok((None, usage))
                 };
 
-                let (edited_buffer_snapshot, edits) =
-                    crate::udiff::parse_diff(&output_text, |path| {
-                        included_files
-                            .iter()
-                            .find_map(|(_, buffer, probe_path, ranges)| {
-                                if probe_path.as_ref() == path {
-                                    Some((buffer, ranges.as_slice()))
-                                } else {
-                                    None
-                                }
-                            })
-                    })
-                    .await?;
+                if output_text.contains(CURSOR_MARKER) {
+                    log::trace!("Stripping out {CURSOR_MARKER} from response");
+                    output_text = output_text.replace(CURSOR_MARKER, "");
+                }
+
+                let (edited_buffer_snapshot, edits) = match options.prompt_format {
+                    PromptFormat::NumLinesUniDiff => {
+                        crate::udiff::parse_diff(&output_text, |path| {
+                            included_files
+                                .iter()
+                                .find_map(|(_, buffer, probe_path, ranges)| {
+                                    if probe_path.as_ref() == path {
+                                        Some((buffer, ranges.as_slice()))
+                                    } else {
+                                        None
+                                    }
+                                })
+                        })
+                        .await?
+                    }
+                    _ => {
+                        bail!("unsupported prompt format {}", options.prompt_format)
+                    }
+                };
 
                 let edited_buffer = included_files
                     .iter()

crates/zeta_cli/src/evaluate.rs 🔗

@@ -67,7 +67,7 @@ pub async fn run_evaluate_one(
         );
         as_json
     } else {
-        zeta2_predict(example.clone(), &app_state, cx)
+        zeta2_predict(example.clone(), Default::default(), &app_state, cx)
             .await
             .unwrap()
     };

crates/zeta_cli/src/predict.rs 🔗

@@ -1,9 +1,11 @@
+use crate::PromptFormat;
 use crate::example::{ActualExcerpt, NamedExample};
 use crate::headless::ZetaCliAppState;
 use crate::paths::LOGS_DIR;
 use ::serde::Serialize;
 use anyhow::{Result, anyhow};
 use clap::Args;
+// use cloud_llm_client::predict_edits_v3::PromptFormat;
 use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 use futures::StreamExt as _;
 use gpui::{AppContext, AsyncApp};
@@ -19,9 +21,11 @@ use std::time::{Duration, Instant};
 
 #[derive(Debug, Args)]
 pub struct PredictArguments {
-    example_path: PathBuf,
+    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
+    prompt_format: PromptFormat,
     #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
     format: PredictionsOutputFormat,
+    example_path: PathBuf,
 }
 
 #[derive(clap::ValueEnum, Debug, Clone)]
@@ -36,7 +40,9 @@ pub async fn run_zeta2_predict(
     cx: &mut AsyncApp,
 ) {
     let example = NamedExample::load(args.example_path).unwrap();
-    let result = zeta2_predict(example, &app_state, cx).await.unwrap();
+    let result = zeta2_predict(example, args.prompt_format, &app_state, cx)
+        .await
+        .unwrap();
     result.write(args.format, std::io::stdout()).unwrap();
 }
 
@@ -46,6 +52,7 @@ thread_local! {
 
 pub async fn zeta2_predict(
     example: NamedExample,
+    prompt_format: PromptFormat,
     app_state: &Arc<ZetaCliAppState>,
     cx: &mut AsyncApp,
 ) -> Result<PredictionDetails> {
@@ -193,6 +200,9 @@ pub async fn zeta2_predict(
     });
 
     zeta.update(cx, |zeta, cx| {
+        let mut options = zeta.options().clone();
+        options.prompt_format = prompt_format.into();
+        zeta.set_options(options);
         zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
     })?
     .await?;