Introduce zeta2 format with cursor content in original order (#46732)

Max Brunsfeld , Agus Zubiaga , and Ben Kunkle created

This one does `fim_prefix`, `fim_middle`, and `fim_suffix` in that
order, in the prompt, instead of putting the current middle last.

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Kunkle <ben@zed.dev>

Change summary

Cargo.lock                                          |   2 
crates/edit_prediction/src/edit_prediction.rs       |  22 +
crates/edit_prediction/src/edit_prediction_tests.rs |  86 ++++++++
crates/edit_prediction/src/zeta2.rs                 |  15 +
crates/edit_prediction_cli/src/example.rs           |  26 -
crates/edit_prediction_cli/src/format_prompt.rs     | 122 +++++------
crates/edit_prediction_cli/src/load_project.rs      |  23 +
crates/edit_prediction_cli/src/main.rs              |  35 ++-
crates/edit_prediction_cli/src/predict.rs           |  26 +
crates/edit_prediction_cli/src/pull_examples.rs     |   3 
crates/edit_prediction_cli/src/retrieve_context.rs  |  14 
crates/edit_prediction_cli/src/score.rs             |  11 
crates/zed/src/zed/edit_prediction_registry.rs      |   4 
crates/zeta_prompt/Cargo.toml                       |   4 
crates/zeta_prompt/src/zeta_prompt.rs               | 149 ++++++++++++--
15 files changed, 382 insertions(+), 160 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -21179,7 +21179,9 @@ dependencies = [
 name = "zeta_prompt"
 version = "0.1.0"
 dependencies = [
+ "anyhow",
  "serde",
+ "strum 0.27.2",
 ]
 
 [[package]]

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -38,6 +38,7 @@ use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
 use std::collections::{VecDeque, hash_map};
 use text::Edit;
 use workspace::Workspace;
+use zeta_prompt::ZetaVersion;
 
 use std::ops::Range;
 use std::path::Path;
@@ -183,7 +184,9 @@ pub struct EditPredictionStore {
 pub enum EditPredictionModel {
     #[default]
     Zeta1,
-    Zeta2,
+    Zeta2 {
+        version: ZetaVersion,
+    },
     Sweep,
     Mercury,
 }
@@ -654,7 +657,9 @@ impl EditPredictionStore {
             update_required: false,
             #[cfg(feature = "cli-support")]
             eval_cache: None,
-            edit_prediction_model: EditPredictionModel::Zeta2,
+            edit_prediction_model: EditPredictionModel::Zeta2 {
+                version: Default::default(),
+            },
             sweep_ai: SweepAi::new(cx),
             mercury: Mercury::new(cx),
             data_collection_choice,
@@ -794,7 +799,10 @@ impl EditPredictionStore {
     }
 
     pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
-        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
+        if matches!(
+            self.edit_prediction_model,
+            EditPredictionModel::Zeta2 { .. }
+        ) {
             self.user_store.read(cx).edit_prediction_usage()
         } else {
             None
@@ -1204,7 +1212,7 @@ impl EditPredictionStore {
                 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
             }
             EditPredictionModel::Mercury => {}
-            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
+            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
                 zeta2::edit_prediction_accepted(self, current_prediction, cx)
             }
         }
@@ -1338,7 +1346,7 @@ impl EditPredictionStore {
         was_shown: bool,
     ) {
         match self.edit_prediction_model {
-            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
+            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
                 if self.custom_predict_edits_url.is_some() {
                     return;
                 }
@@ -1773,7 +1781,9 @@ impl EditPredictionStore {
         }
         let task = match self.edit_prediction_model {
             EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
-            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
+            EditPredictionModel::Zeta2 { version } => {
+                zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
+            }
             EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
             EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
         };

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -1332,12 +1332,20 @@ fn model_response(request: RawCompletionRequest, diff_to_apply: &str) -> RawComp
 
     let current_marker = "<|fim_middle|>current\n";
     let updated_marker = "<|fim_middle|>updated\n";
+    let suffix_marker = "<|fim_suffix|>\n";
     let cursor = "<|user_cursor|>";
 
     let start_ix = current_marker.len() + prompt.find(current_marker).unwrap();
     let end_ix = start_ix + &prompt[start_ix..].find(updated_marker).unwrap();
     let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
-    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
+    // In v0113_ordered format, the excerpt contains <|fim_suffix|> and suffix content.
+    // Strip that out to get just the editable region.
+    let excerpt = if let Some(suffix_pos) = excerpt.find(suffix_marker) {
+        &excerpt[..suffix_pos]
+    } else {
+        &excerpt
+    };
+    let new_excerpt = apply_diff_to_string(diff_to_apply, excerpt).unwrap();
 
     RawCompletionResponse {
         id: Uuid::new_v4().to_string(),
@@ -1629,6 +1637,82 @@ async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
     );
 }
 
+#[gpui::test]
+async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
+    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
+    // When the buffer ends without a trailing newline, but the model returns output
+    // with a trailing newline, zeta2 should normalize both sides before diffing
+    // so no spurious newline is inserted.
+    let (ep_store, mut requests) = init_test_with_fake_client(cx);
+    let fs = FakeFs::new(cx.executor());
+
+    // Single line buffer with no trailing newline
+    fs.insert_tree(
+        "/root",
+        json!({
+            "foo.txt": "hello"
+        }),
+    )
+    .await;
+    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+    let buffer = project
+        .update(cx, |project, cx| {
+            let path = project
+                .find_project_path(path!("root/foo.txt"), cx)
+                .unwrap();
+            project.open_buffer(path, cx)
+        })
+        .await
+        .unwrap();
+
+    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+    let position = snapshot.anchor_before(language::Point::new(0, 5));
+
+    ep_store.update(cx, |ep_store, cx| {
+        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+    });
+
+    let (_request, respond_tx) = requests.predict.next().await.unwrap();
+
+    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
+    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
+    let response = RawCompletionResponse {
+        id: Uuid::new_v4().to_string(),
+        object: "text_completion".into(),
+        created: 0,
+        model: "model".into(),
+        choices: vec![RawCompletionChoice {
+            text: "hello world\n".to_string(),
+            finish_reason: None,
+        }],
+        usage: RawCompletionUsage {
+            prompt_tokens: 0,
+            completion_tokens: 0,
+            total_tokens: 0,
+        },
+    };
+    respond_tx.send(response).unwrap();
+
+    cx.run_until_parked();
+
+    // The prediction should insert " world" without adding a newline
+    ep_store.update(cx, |ep_store, cx| {
+        let prediction = ep_store
+            .prediction_at(&buffer, None, &project, cx)
+            .expect("should have prediction");
+        let edits: Vec<_> = prediction
+            .edits
+            .iter()
+            .map(|(range, text)| {
+                let snapshot = buffer.read(cx).snapshot();
+                (range.to_offset(&snapshot), text.clone())
+            })
+            .collect();
+        assert_eq!(edits, vec![(5..5, " world".into())]);
+    });
+}
+
 #[gpui::test]
 async fn test_can_collect_data(cx: &mut TestAppContext) {
     init_test(cx);

crates/edit_prediction/src/zeta2.rs 🔗

@@ -15,8 +15,8 @@ use release_channel::AppVersion;
 
 use std::env;
 use std::{path::Path, sync::Arc, time::Instant};
-use zeta_prompt::CURSOR_MARKER;
 use zeta_prompt::format_zeta_prompt;
+use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
 
 pub const MAX_CONTEXT_TOKENS: usize = 350;
 pub const MAX_EDITABLE_TOKENS: usize = 150;
@@ -32,6 +32,7 @@ pub fn request_prediction_with_zeta2(
         debug_tx,
         ..
     }: EditPredictionModelInput,
+    zeta_version: ZetaVersion,
     cx: &mut Context<EditPredictionStore>,
 ) -> Task<Result<Option<EditPredictionResult>>> {
     let buffer_snapshotted_at = Instant::now();
@@ -62,7 +63,7 @@ pub fn request_prediction_with_zeta2(
                 cursor_offset,
             );
 
-            let prompt = format_zeta_prompt(&prompt_input);
+            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
 
             if let Some(debug_tx) = &debug_tx {
                 debug_tx
@@ -125,9 +126,17 @@ pub fn request_prediction_with_zeta2(
                 output_text = output_text.replace(CURSOR_MARKER, "");
             }
 
-            let old_text = snapshot
+            let mut old_text = snapshot
                 .text_for_range(editable_offset_range.clone())
                 .collect::<String>();
+
+            if !output_text.is_empty() && !output_text.ends_with('\n') {
+                output_text.push('\n');
+            }
+            if !old_text.is_empty() && !old_text.ends_with('\n') {
+                old_text.push('\n');
+            }
+
             let edits: Vec<_> = language::text_diff(&old_text, &output_text)
                 .into_iter()
                 .map(|(range, text)| {

crates/edit_prediction_cli/src/example.rs 🔗

@@ -1,5 +1,5 @@
+use crate::PredictionProvider;
 use crate::paths::WORKTREES_DIR;
-use crate::{PredictionProvider, PromptFormat};
 use anyhow::{Context as _, Result};
 use collections::HashMap;
 use edit_prediction::example_spec::ExampleSpec;
@@ -9,11 +9,12 @@ use http_client::Url;
 use language::{Anchor, Buffer};
 use project::Project;
 use serde::{Deserialize, Serialize};
-use std::ops::Range;
 use std::{
     borrow::Cow,
     io::Read,
+    ops::Range,
     path::{Path, PathBuf},
+    sync::Arc,
 };
 use zeta_prompt::RelatedFile;
 
@@ -25,12 +26,7 @@ pub struct Example {
     /// The full content of the file where an edit is being predicted, and the
     /// actual cursor offset.
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub buffer: Option<ExampleBuffer>,
-
-    /// The context retrieved for the prediction. This requires the worktree to
-    /// be loaded and the language server to be started.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub context: Option<ExampleContext>,
+    pub prompt_inputs: Option<ExamplePromptInputs>,
 
     /// The input and expected output from the edit prediction model.
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -59,25 +55,22 @@ pub struct ExampleState {
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct ExampleContext {
-    pub files: Vec<RelatedFile>,
-}
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct ExampleBuffer {
+pub struct ExamplePromptInputs {
     pub content: String,
     pub cursor_row: u32,
     pub cursor_column: u32,
     pub cursor_offset: usize,
     pub context_range: Range<usize>,
     pub editable_range: Range<usize>,
+    pub edit_history: Vec<Arc<zeta_prompt::Event>>,
+    pub related_files: Option<Vec<RelatedFile>>,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ExamplePrompt {
     pub input: String,
     pub expected_output: String,
-    pub format: PromptFormat,
+    pub provider: PredictionProvider,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
@@ -239,8 +232,7 @@ fn parse_markdown_example(input: &str) -> Result<Example> {
     let spec = ExampleSpec::from_markdown(input)?;
     Ok(Example {
         spec,
-        buffer: None,
-        context: None,
+        prompt_inputs: None,
         prompt: None,
         predictions: Vec::new(),
         score: Vec::new(),

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -1,14 +1,12 @@
 use crate::{
-    PromptFormat,
+    FormatPromptArgs, PredictionProvider,
     example::{Example, ExamplePrompt},
     headless::EpAppState,
-    load_project::run_load_project,
     progress::{Progress, Step},
     retrieve_context::run_context_retrieval,
 };
 use anyhow::{Context as _, Result};
-use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
-use gpui::{AsyncApp, Entity};
+use gpui::AsyncApp;
 use similar::DiffableStr;
 use std::fmt::Write as _;
 use std::sync::Arc;
@@ -16,16 +14,21 @@ use zeta_prompt::format_zeta_prompt;
 
 pub async fn run_format_prompt(
     example: &mut Example,
-    prompt_format: PromptFormat,
+    args: &FormatPromptArgs,
     app_state: Arc<EpAppState>,
-    mut cx: AsyncApp,
+    cx: AsyncApp,
 ) -> Result<()> {
-    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
+    run_context_retrieval(example, app_state, cx).await?;
 
     let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
 
-    match prompt_format {
-        PromptFormat::Teacher => {
+    let prompt_inputs = example
+        .prompt_inputs
+        .as_ref()
+        .context("prompt_inputs must be set after context retrieval")?;
+
+    match args.provider {
+        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
             step_progress.set_substatus("formatting teacher prompt");
             let prompt = TeacherPrompt::format_prompt(example);
             example.prompt = Some(ExamplePrompt {
@@ -36,47 +39,27 @@ pub async fn run_format_prompt(
                     .first()
                     .cloned()
                     .unwrap_or_default(),
-                format: prompt_format,
+                provider: args.provider,
             });
         }
-        PromptFormat::Zeta2 => {
-            step_progress.set_substatus("loading project");
-            run_load_project(example, app_state, cx.clone()).await?;
-
+        PredictionProvider::Zeta2 => {
             step_progress.set_substatus("formatting zeta2 prompt");
 
-            let ep_store: Entity<EditPredictionStore> = cx.update(|cx| {
-                EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
-            })?;
-
-            let state = example.state.as_ref().context("state must be set")?;
-            let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot());
-            let project = state.project.clone();
-            let (_, input) =
-                ep_store.update(&mut cx, |ep_store: &mut EditPredictionStore, cx| {
-                    let events = ep_store
-                        .edit_history_for_project(&project, cx)
-                        .into_iter()
-                        .map(|e| e.event)
-                        .collect();
-                    anyhow::Ok(zeta2_prompt_input(
-                        &snapshot,
-                        example
-                            .context
-                            .as_ref()
-                            .context("context must be set")?
-                            .files
-                            .clone(),
-                        events,
-                        example.spec.cursor_path.clone(),
-                        example
-                            .buffer
-                            .as_ref()
-                            .context("buffer must be set")?
-                            .cursor_offset,
-                    ))
-                })?;
-            let prompt = format_zeta_prompt(&input);
+            let context_start = prompt_inputs.context_range.start;
+            let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
+            let editable_range_in_excerpt = (prompt_inputs.editable_range.start - context_start)
+                ..(prompt_inputs.editable_range.end - context_start);
+            let input = zeta_prompt::ZetaPromptInput {
+                cursor_path: example.spec.cursor_path.clone(),
+                cursor_excerpt: prompt_inputs.content[prompt_inputs.context_range.clone()]
+                    .to_string()
+                    .into(),
+                editable_range_in_excerpt,
+                cursor_offset_in_excerpt,
+                events: prompt_inputs.edit_history.clone(),
+                related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
+            };
+            let prompt = format_zeta_prompt(&input, args.version);
             let expected_output = zeta2_output_for_patch(
                 &input,
                 &example
@@ -89,9 +72,12 @@ pub async fn run_format_prompt(
             example.prompt = Some(ExamplePrompt {
                 input: prompt,
                 expected_output,
-                format: prompt_format,
+                provider: args.provider,
             });
         }
+        _ => {
+            panic!("Cannot format prompt for {:?}", args.provider);
+        }
     };
     Ok(())
 }
@@ -144,10 +130,10 @@ impl TeacherPrompt {
         // 2. Context retriever just didn't include cursor line.
         //
         // In that case, fallback to using `cursor_position` as excerpt.
-        let example_buffer = example
-            .buffer
+        let prompt_inputs = example
+            .prompt_inputs
             .as_ref()
-            .context("`buffer` should be filled in in the context collection step")?;
+            .context("`prompt_inputs` should be filled in in the context collection step")?;
 
         // Extract updated (new) editable region from the model response.
         // The model may include editable region markers in its output, so we need to strip them.
@@ -155,7 +141,7 @@ impl TeacherPrompt {
         let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
 
         let old_editable_region =
-            example_buffer.content[example_buffer.editable_range.clone()].to_string();
+            prompt_inputs.content[prompt_inputs.editable_range.clone()].to_string();
 
         // Normalize leading newlines: if old starts with newline but new doesn't,
         // prepend newline to new to preserve whitespace structure.
@@ -164,8 +150,8 @@ impl TeacherPrompt {
             new_editable_region.insert(0, '\n');
         }
 
-        let editable_region_start_line = example_buffer.content
-            [..example_buffer.editable_range.start]
+        let editable_region_start_line = prompt_inputs.content
+            [..prompt_inputs.editable_range.start]
             .matches('\n')
             .count();
 
@@ -208,17 +194,21 @@ impl TeacherPrompt {
     }
 
     fn format_context(example: &Example) -> String {
-        let context = example
-            .context
+        let related_files = example
+            .prompt_inputs
             .as_ref()
-            .expect("Missing context retriever step");
+            .and_then(|pi| pi.related_files.as_ref());
+
+        let Some(related_files) = related_files else {
+            return "(No context)".to_string();
+        };
 
-        if context.files.is_empty() {
+        if related_files.is_empty() {
             return "(No context)".to_string();
         }
 
         let mut prompt = String::new();
-        for file in context.files.iter() {
+        for file in related_files {
             let path_str = file.path.to_string_lossy();
             writeln!(&mut prompt, "`````{path_str}").ok();
             let mut prev_row = 0;
@@ -242,28 +232,26 @@ impl TeacherPrompt {
     fn format_cursor_excerpt(example: &Example) -> String {
         let mut result = String::new();
 
-        let example_buffer = example.buffer.as_ref().unwrap();
+        let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
 
         let path_str = example.spec.cursor_path.to_string_lossy();
         result.push_str(&format!("`````{path_str}\n"));
         result.push_str(
-            &example_buffer.content
-                [example_buffer.context_range.start..example_buffer.editable_range.start],
+            &prompt_inputs.content
+                [prompt_inputs.context_range.start..prompt_inputs.editable_range.start],
         );
         result.push_str(Self::EDITABLE_REGION_START);
         result.push_str(
-            &example_buffer.content
-                [example_buffer.editable_range.start..example_buffer.cursor_offset],
+            &prompt_inputs.content[prompt_inputs.editable_range.start..prompt_inputs.cursor_offset],
         );
         result.push_str(Self::USER_CURSOR_MARKER);
         result.push_str(
-            &example_buffer.content
-                [example_buffer.cursor_offset..example_buffer.editable_range.end],
+            &prompt_inputs.content[prompt_inputs.cursor_offset..prompt_inputs.editable_range.end],
         );
         result.push_str(Self::EDITABLE_REGION_END);
         result.push_str(
-            &example_buffer.content
-                [example_buffer.editable_range.end..example_buffer.context_range.end],
+            &prompt_inputs.content
+                [prompt_inputs.editable_range.end..prompt_inputs.context_range.end],
         );
         result.push_str("\n`````");
 

crates/edit_prediction_cli/src/load_project.rs 🔗

@@ -1,5 +1,5 @@
 use crate::{
-    example::{Example, ExampleBuffer, ExampleState},
+    example::{Example, ExamplePromptInputs, ExampleState},
     git,
     headless::EpAppState,
     progress::{InfoStyle, Progress, Step, StepProgress},
@@ -38,7 +38,20 @@ pub async fn run_load_project(
     buffer
         .read_with(&cx, |buffer, _| buffer.parsing_idle())
         .await;
-    let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
+
+    let ep_store = cx
+        .update(|cx| EditPredictionStore::try_global(cx))
+        .context("EditPredictionStore not initialized")?;
+
+    let edit_history = ep_store.update(&mut cx, |store, cx| {
+        store
+            .edit_history_for_project(&project, cx)
+            .into_iter()
+            .map(|e| e.event)
+            .collect()
+    });
+
+    let (prompt_inputs, language_name) = buffer.read_with(&cx, |buffer, _cx| {
         let cursor_point = cursor_position.to_point(&buffer);
         let snapshot = buffer.snapshot();
         let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
@@ -54,13 +67,15 @@ pub async fn run_load_project(
             .map(|l| l.name().to_string())
             .unwrap_or_else(|| "Unknown".to_string());
         (
-            ExampleBuffer {
+            ExamplePromptInputs {
                 content: buffer.text(),
                 cursor_row: cursor_point.row,
                 cursor_column: cursor_point.column,
                 cursor_offset: cursor_position.to_offset(&buffer),
                 context_range,
                 editable_range,
+                edit_history,
+                related_files: None,
             },
             language_name,
         )
@@ -68,7 +83,7 @@ pub async fn run_load_project(
 
     progress.set_info(language_name, InfoStyle::Normal);
 
-    example.buffer = Some(example_buffer);
+    example.prompt_inputs = Some(prompt_inputs);
     example.state = Some(ExampleState {
         buffer,
         project,

crates/edit_prediction_cli/src/main.rs 🔗

@@ -22,6 +22,7 @@ use edit_prediction::EditPredictionStore;
 use futures::channel::mpsc;
 use futures::{SinkExt as _, StreamExt as _};
 use gpui::{AppContext as _, Application};
+use zeta_prompt::ZetaVersion;
 
 use reqwest_client::ReqwestClient;
 use serde::{Deserialize, Serialize};
@@ -155,7 +156,7 @@ impl Display for Command {
                 f,
                 "format-prompt --prompt-format={}",
                 format_prompt_args
-                    .prompt_format
+                    .provider
                     .to_possible_value()
                     .unwrap()
                     .get_name()
@@ -204,22 +205,31 @@ impl Display for Command {
 
 #[derive(Debug, Args, Clone)]
 struct FormatPromptArgs {
-    #[clap(long, short('p'))]
-    prompt_format: PromptFormat,
-}
-
-#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
-enum PromptFormat {
-    Teacher,
-    Zeta2,
+    #[clap(long, short)]
+    provider: PredictionProvider,
+    #[clap(
+        long,
+        short,
+        help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
+        value_parser = ZetaVersion::parse,
+        default_value_t = ZetaVersion::default(),
+    )]
+    version: ZetaVersion,
 }
 
 #[derive(Debug, Args, Clone)]
 struct PredictArgs {
-    #[clap(long)]
+    #[clap(long, short)]
     provider: PredictionProvider,
     #[clap(long, default_value_t = 1)]
     repetitions: usize,
+    #[clap(
+        long,
+        short,
+        help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
+        value_parser = ZetaVersion::parse,
+    )]
+    version: ZetaVersion,
 }
 
 #[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
@@ -514,7 +524,7 @@ fn main() {
                                     Command::FormatPrompt(args) => {
                                         run_format_prompt(
                                             example,
-                                            args.prompt_format,
+                                            args,
                                             app_state.clone(),
                                             cx.clone(),
                                         )
@@ -523,8 +533,7 @@ fn main() {
                                     Command::Predict(args) => {
                                         run_prediction(
                                             example,
-                                            Some(args.provider),
-                                            args.repetitions,
+                                            args,
                                             app_state.clone(),
                                             cx.clone(),
                                         )

crates/edit_prediction_cli/src/predict.rs 🔗

@@ -1,5 +1,5 @@
 use crate::{
-    PredictionProvider, PromptFormat,
+    FormatPromptArgs, PredictArgs, PredictionProvider,
     anthropic_client::AnthropicClient,
     example::{Example, ExamplePrediction, ExamplePrompt},
     format_prompt::{TeacherPrompt, run_format_prompt},
@@ -25,12 +25,13 @@ static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
 
 pub async fn run_prediction(
     example: &mut Example,
-    provider: Option<PredictionProvider>,
-    repetition_count: usize,
+    args: &PredictArgs,
     app_state: Arc<EpAppState>,
     mut cx: AsyncApp,
 ) -> anyhow::Result<()> {
-    let provider = provider.context("provider is required")?;
+    let provider = args.provider;
+    let repetition_count = args.repetitions;
+    let zeta_version = args.version;
 
     if let Some(existing_prediction) = example.predictions.first() {
         if existing_prediction.provider == provider {
@@ -48,7 +49,16 @@ pub async fn run_prediction(
     ) {
         let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);
 
-        run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
+        run_format_prompt(
+            example,
+            &FormatPromptArgs {
+                provider,
+                version: args.version,
+            },
+            app_state.clone(),
+            cx,
+        )
+        .await?;
 
         let batched = matches!(provider, PredictionProvider::Teacher);
         return predict_anthropic(example, repetition_count, batched).await;
@@ -85,7 +95,9 @@ pub async fn run_prediction(
     ep_store.update(&mut cx, |store, _cx| {
         let model = match provider {
             PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
-            PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
+            PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2 {
+                version: zeta_version,
+            },
             PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
             PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
             PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
@@ -127,7 +139,7 @@ pub async fn run_prediction(
                                 updated_example.prompt.get_or_insert(ExamplePrompt {
                                     input: prompt,
                                     expected_output: String::new(),
-                                    format: PromptFormat::Zeta2,
+                                    provider,
                                 });
                             }
                         }

crates/edit_prediction_cli/src/pull_examples.rs 🔗

@@ -149,8 +149,7 @@ fn examples_from_response(
         match parse_result {
             Ok(spec) => Some(Example {
                 spec,
-                buffer: None,
-                context: None,
+                prompt_inputs: None,
                 prompt: None,
                 predictions: Vec::new(),
                 score: Vec::new(),

crates/edit_prediction_cli/src/retrieve_context.rs 🔗

@@ -1,5 +1,5 @@
 use crate::{
-    example::{Example, ExampleContext},
+    example::Example,
     headless::EpAppState,
     load_project::run_load_project,
     progress::{InfoStyle, Progress, Step, StepProgress},
@@ -19,7 +19,11 @@ pub async fn run_context_retrieval(
     app_state: Arc<EpAppState>,
     mut cx: AsyncApp,
 ) -> anyhow::Result<()> {
-    if example.context.is_some() {
+    if example
+        .prompt_inputs
+        .as_ref()
+        .is_some_and(|inputs| inputs.related_files.is_some())
+    {
         return Ok(());
     }
 
@@ -63,9 +67,9 @@ pub async fn run_context_retrieval(
     let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
     step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
 
-    example.context = Some(ExampleContext {
-        files: context_files,
-    });
+    if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
+        prompt_inputs.related_files = Some(context_files);
+    }
     Ok(())
 }
 

crates/edit_prediction_cli/src/score.rs 🔗

@@ -17,19 +17,12 @@ pub async fn run_scoring(
     app_state: Arc<EpAppState>,
     cx: AsyncApp,
 ) -> anyhow::Result<()> {
-    run_prediction(
-        example,
-        Some(args.provider),
-        args.repetitions,
-        app_state,
-        cx,
-    )
-    .await?;
+    run_prediction(example, args, app_state, cx).await?;
 
     let progress = Progress::global().start(Step::Score, &example.spec.name);
 
     progress.set_substatus("applying patches");
-    let original_text = &example.buffer.as_ref().unwrap().content;
+    let original_text = &example.prompt_inputs.as_ref().unwrap().content;
     let expected_texts: Vec<String> = example
         .spec
         .expected_patches

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -204,7 +204,9 @@ fn assign_edit_prediction_provider(
                         } else if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME
                             && cx.has_flag::<Zeta2FeatureFlag>()
                         {
-                            edit_prediction::EditPredictionModel::Zeta2
+                            edit_prediction::EditPredictionModel::Zeta2 {
+                                version: Default::default(),
+                            }
                         } else if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME
                             && cx.has_flag::<MercuryFeatureFlag>()
                         {

crates/zeta_prompt/Cargo.toml 🔗

@@ -12,4 +12,6 @@ workspace = true
 path = "src/zeta_prompt.rs"
 
 [dependencies]
-serde.workspace = true
+anyhow.workspace = true
+serde.workspace = true
+strum.workspace = true

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -1,8 +1,10 @@
+use anyhow::Result;
 use serde::{Deserialize, Serialize};
 use std::fmt::Write;
 use std::ops::Range;
 use std::path::Path;
 use std::sync::Arc;
+use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
 
 pub const CURSOR_MARKER: &str = "<|user_cursor|>";
 
@@ -16,6 +18,54 @@ pub struct ZetaPromptInput {
     pub related_files: Vec<RelatedFile>,
 }
 
+#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, EnumIter, IntoStaticStr)]
+#[allow(non_camel_case_types)]
+pub enum ZetaVersion {
+    V0112_MiddleAtEnd,
+    #[default]
+    V0113_Ordered,
+}
+
+impl std::fmt::Display for ZetaVersion {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", <&'static str>::from(self))
+    }
+}
+
+impl ZetaVersion {
+    pub fn parse(version_string: &str) -> Result<Self> {
+        let mut results = ZetaVersion::iter().filter(|version| {
+            <&'static str>::from(version)
+                .to_lowercase()
+                .contains(&version_string.to_lowercase())
+        });
+        let Some(result) = results.next() else {
+            anyhow::bail!(
+                "`{version_string}` did not match any of:\n{}",
+                Self::options_as_string()
+            );
+        };
+        if results.next().is_some() {
+            anyhow::bail!(
+                "`{version_string}` matched more than one of:\n{}",
+                Self::options_as_string()
+            );
+        }
+        Ok(result)
+    }
+
+    fn options_as_string() -> String {
+        ZetaVersion::iter()
+            .map(|version| format!("- {}\n", <&'static str>::from(version)))
+            .collect::<Vec<_>>()
+            .concat()
+    }
+
+    pub fn default_as_string() -> String {
+        <&'static str>::from(Self::default()).to_string()
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(tag = "event")]
 pub enum Event {
@@ -69,11 +119,20 @@ pub struct RelatedExcerpt {
     pub text: Arc<str>,
 }
 
-pub fn format_zeta_prompt(input: &ZetaPromptInput) -> String {
+pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
     let mut prompt = String::new();
     write_related_files(&mut prompt, &input.related_files);
     write_edit_history_section(&mut prompt, input);
-    write_cursor_excerpt_section(&mut prompt, input);
+
+    match version {
+        ZetaVersion::V0112_MiddleAtEnd => {
+            v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, input);
+        }
+        ZetaVersion::V0113_Ordered => {
+            v0113_ordered::write_cursor_excerpt_section(&mut prompt, input)
+        }
+    }
+
     prompt
 }
 
@@ -100,31 +159,73 @@ fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
     }
 }
 
-fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
-    let path_str = input.cursor_path.to_string_lossy();
-    write!(prompt, "<|file_sep|>{}\n", path_str).ok();
+mod v0112_middle_at_end {
+    use super::*;
 
-    prompt.push_str("<|fim_prefix|>\n");
-    prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
+    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
+        let path_str = input.cursor_path.to_string_lossy();
+        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
 
-    prompt.push_str("<|fim_suffix|>\n");
-    prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
-    if !prompt.ends_with('\n') {
-        prompt.push('\n');
-    }
+        prompt.push_str("<|fim_prefix|>\n");
+        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
 
-    prompt.push_str("<|fim_middle|>current\n");
-    prompt.push_str(
-        &input.cursor_excerpt
-            [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
-    );
-    prompt.push_str(CURSOR_MARKER);
-    prompt.push_str(
-        &input.cursor_excerpt[input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
-    );
-    if !prompt.ends_with('\n') {
-        prompt.push('\n');
+        prompt.push_str("<|fim_suffix|>\n");
+        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
+        if !prompt.ends_with('\n') {
+            prompt.push('\n');
+        }
+
+        prompt.push_str("<|fim_middle|>current\n");
+        prompt.push_str(
+            &input.cursor_excerpt
+                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
+        );
+        prompt.push_str(CURSOR_MARKER);
+        prompt.push_str(
+            &input.cursor_excerpt
+                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
+        );
+        if !prompt.ends_with('\n') {
+            prompt.push('\n');
+        }
+
+        prompt.push_str("<|fim_middle|>updated\n");
     }
+}
+
+mod v0113_ordered {
+    use super::*;
+
+    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
+        let path_str = input.cursor_path.to_string_lossy();
+        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
 
-    prompt.push_str("<|fim_middle|>updated\n");
+        prompt.push_str("<|fim_prefix|>\n");
+        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
+        if !prompt.ends_with('\n') {
+            prompt.push('\n');
+        }
+
+        prompt.push_str("<|fim_middle|>current\n");
+        prompt.push_str(
+            &input.cursor_excerpt
+                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
+        );
+        prompt.push_str(CURSOR_MARKER);
+        prompt.push_str(
+            &input.cursor_excerpt
+                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
+        );
+        if !prompt.ends_with('\n') {
+            prompt.push('\n');
+        }
+
+        prompt.push_str("<|fim_suffix|>\n");
+        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
+        if !prompt.ends_with('\n') {
+            prompt.push('\n');
+        }
+
+        prompt.push_str("<|fim_middle|>updated\n");
+    }
 }