From 20284e4f218081a2c9f90ab78dbb7062e6203725 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 13 Jan 2026 13:53:44 -0800 Subject: [PATCH] Introduce zeta2 format with cursor content in original order (#46732) This one does `fim_prefix`, `fim_middle`, and `fim_suffix` in that order, in the prompt, instead of putting the current middle last. Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga Co-authored-by: Ben Kunkle --- Cargo.lock | 2 + crates/edit_prediction/src/edit_prediction.rs | 22 ++- .../src/edit_prediction_tests.rs | 86 +++++++++- crates/edit_prediction/src/zeta2.rs | 15 +- crates/edit_prediction_cli/src/example.rs | 26 ++- .../edit_prediction_cli/src/format_prompt.rs | 122 +++++++------- .../edit_prediction_cli/src/load_project.rs | 23 ++- crates/edit_prediction_cli/src/main.rs | 35 ++-- crates/edit_prediction_cli/src/predict.rs | 26 ++- .../edit_prediction_cli/src/pull_examples.rs | 3 +- .../src/retrieve_context.rs | 14 +- crates/edit_prediction_cli/src/score.rs | 11 +- .../zed/src/zed/edit_prediction_registry.rs | 4 +- crates/zeta_prompt/Cargo.toml | 4 +- crates/zeta_prompt/src/zeta_prompt.rs | 149 +++++++++++++++--- 15 files changed, 382 insertions(+), 160 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fb211271f995bcdff1121f9fef479c2cd3a1df8e..05ec45aded8d5449bbfb5d89bed3f2a299d214d2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -21179,7 +21179,9 @@ dependencies = [ name = "zeta_prompt" version = "0.1.0" dependencies = [ + "anyhow", "serde", + "strum 0.27.2", ] [[package]] diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 8b85862412a224e2a8ed87addab7b332d4fd5c1d..4f9a67880b31f8c958fb9b922bfd5c102d365c1a 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -38,6 +38,7 @@ use settings::{EditPredictionProvider, SettingsStore, update_settings_file}; use std::collections::{VecDeque, hash_map}; use text::Edit; use workspace::Workspace; +use zeta_prompt::ZetaVersion; use std::ops::Range; use std::path::Path; @@ -183,7 +184,9 @@ pub struct EditPredictionStore { pub enum EditPredictionModel { #[default] Zeta1, - Zeta2, + Zeta2 { + version: ZetaVersion, + }, Sweep, Mercury, } @@ -654,7 +657,9 @@ impl EditPredictionStore { update_required: false, #[cfg(feature = "cli-support")] eval_cache: None, - edit_prediction_model: EditPredictionModel::Zeta2, + edit_prediction_model: EditPredictionModel::Zeta2 { + version: Default::default(), + }, sweep_ai: SweepAi::new(cx), mercury: Mercury::new(cx), data_collection_choice, @@ -794,7 +799,10 @@ impl EditPredictionStore { } pub fn usage(&self, cx: &App) -> Option { - if self.edit_prediction_model == EditPredictionModel::Zeta2 { + if matches!( + self.edit_prediction_model, + EditPredictionModel::Zeta2 { .. } + ) { self.user_store.read(cx).edit_prediction_usage() } else { None @@ -1204,7 +1212,7 @@ impl EditPredictionStore { sweep_ai::edit_prediction_accepted(self, current_prediction, cx) } EditPredictionModel::Mercury => {} - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => { zeta2::edit_prediction_accepted(self, current_prediction, cx) } } @@ -1338,7 +1346,7 @@ impl EditPredictionStore { was_shown: bool, ) { match self.edit_prediction_model { - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => { if self.custom_predict_edits_url.is_some() { return; } @@ -1773,7 +1781,9 @@ impl EditPredictionStore { } let task = match self.edit_prediction_model { EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx), - EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx), + EditPredictionModel::Zeta2 { version } => { + zeta2::request_prediction_with_zeta2(self, inputs, version, cx) + } EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx), EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx), }; diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 3d13d2d7f8cec902fbde8560b7e251f1273e6d24..7432ebb888a2ca8648388d55d0b6bf52b40fb153 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1332,12 +1332,20 @@ fn model_response(request: RawCompletionRequest, diff_to_apply: &str) -> RawComp let current_marker = "<|fim_middle|>current\n"; let updated_marker = "<|fim_middle|>updated\n"; + let suffix_marker = "<|fim_suffix|>\n"; let cursor = "<|user_cursor|>"; let start_ix = current_marker.len() + prompt.find(current_marker).unwrap(); let end_ix = start_ix + &prompt[start_ix..].find(updated_marker).unwrap(); let excerpt = prompt[start_ix..end_ix].replace(cursor, ""); - let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap(); + // In v0113_ordered format, the excerpt contains <|fim_suffix|> and suffix content. + // Strip that out to get just the editable region. + let excerpt = if let Some(suffix_pos) = excerpt.find(suffix_marker) { + &excerpt[..suffix_pos] + } else { + &excerpt + }; + let new_excerpt = apply_diff_to_string(diff_to_apply, excerpt).unwrap(); RawCompletionResponse { id: Uuid::new_v4().to_string(), @@ -1629,6 +1637,82 @@ async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) { ); } +#[gpui::test] +async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) { + // Test that zeta2's newline normalization logic doesn't insert spurious newlines. + // When the buffer ends without a trailing newline, but the model returns output + // with a trailing newline, zeta2 should normalize both sides before diffing + // so no spurious newline is inserted. + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + + // Single line buffer with no trailing newline + fs.insert_tree( + "/root", + json!({ + "foo.txt": "hello" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project + .find_project_path(path!("root/foo.txt"), cx) + .unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(0, 5)); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_request, respond_tx) = requests.predict.next().await.unwrap(); + + // Model returns output WITH a trailing newline, even though the buffer doesn't have one. + // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted. + let response = RawCompletionResponse { + id: Uuid::new_v4().to_string(), + object: "text_completion".into(), + created: 0, + model: "model".into(), + choices: vec![RawCompletionChoice { + text: "hello world\n".to_string(), + finish_reason: None, + }], + usage: RawCompletionUsage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, + }; + respond_tx.send(response).unwrap(); + + cx.run_until_parked(); + + // The prediction should insert " world" without adding a newline + ep_store.update(cx, |ep_store, cx| { + let prediction = ep_store + .prediction_at(&buffer, None, &project, cx) + .expect("should have prediction"); + let edits: Vec<_> = prediction + .edits + .iter() + .map(|(range, text)| { + let snapshot = buffer.read(cx).snapshot(); + (range.to_offset(&snapshot), text.clone()) + }) + .collect(); + assert_eq!(edits, vec![(5..5, " world".into())]); + }); +} + #[gpui::test] async fn test_can_collect_data(cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index da3408fee5deb72f832f01e97690c87017ea0ca4..17f379f23eeac36f388dbcf72e00f4c63ed7a053 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -15,8 +15,8 @@ use release_channel::AppVersion; use std::env; use std::{path::Path, sync::Arc, time::Instant}; -use zeta_prompt::CURSOR_MARKER; use zeta_prompt::format_zeta_prompt; +use zeta_prompt::{CURSOR_MARKER, ZetaVersion}; pub const MAX_CONTEXT_TOKENS: usize = 350; pub const MAX_EDITABLE_TOKENS: usize = 150; @@ -32,6 +32,7 @@ pub fn request_prediction_with_zeta2( debug_tx, .. }: EditPredictionModelInput, + zeta_version: ZetaVersion, cx: &mut Context, ) -> Task>> { let buffer_snapshotted_at = Instant::now(); @@ -62,7 +63,7 @@ pub fn request_prediction_with_zeta2( cursor_offset, ); - let prompt = format_zeta_prompt(&prompt_input); + let prompt = format_zeta_prompt(&prompt_input, zeta_version); if let Some(debug_tx) = &debug_tx { debug_tx @@ -125,9 +126,17 @@ pub fn request_prediction_with_zeta2( output_text = output_text.replace(CURSOR_MARKER, ""); } - let old_text = snapshot + let mut old_text = snapshot .text_for_range(editable_offset_range.clone()) .collect::(); + + if !output_text.is_empty() && !output_text.ends_with('\n') { + output_text.push('\n'); + } + if !old_text.is_empty() && !old_text.ends_with('\n') { + old_text.push('\n'); + } + let edits: Vec<_> = language::text_diff(&old_text, &output_text) .into_iter() .map(|(range, text)| { diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index d6e45360a70b6a14f23311dd539cb16a72c66788..d3e10834f10b071e3602b7f399fbc8f28509fff1 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -1,5 +1,5 @@ +use crate::PredictionProvider; use crate::paths::WORKTREES_DIR; -use crate::{PredictionProvider, PromptFormat}; use anyhow::{Context as _, Result}; use collections::HashMap; use edit_prediction::example_spec::ExampleSpec; @@ -9,11 +9,12 @@ use http_client::Url; use language::{Anchor, Buffer}; use project::Project; use serde::{Deserialize, Serialize}; -use std::ops::Range; use std::{ borrow::Cow, io::Read, + ops::Range, path::{Path, PathBuf}, + sync::Arc, }; use zeta_prompt::RelatedFile; @@ -25,12 +26,7 @@ pub struct Example { /// The full content of the file where an edit is being predicted, and the /// actual cursor offset. #[serde(skip_serializing_if = "Option::is_none")] - pub buffer: Option, - - /// The context retrieved for the prediction. This requires the worktree to - /// be loaded and the language server to be started. - #[serde(skip_serializing_if = "Option::is_none")] - pub context: Option, + pub prompt_inputs: Option, /// The input and expected output from the edit prediction model. #[serde(skip_serializing_if = "Option::is_none")] @@ -59,25 +55,22 @@ pub struct ExampleState { } #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ExampleContext { - pub files: Vec, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ExampleBuffer { +pub struct ExamplePromptInputs { pub content: String, pub cursor_row: u32, pub cursor_column: u32, pub cursor_offset: usize, pub context_range: Range, pub editable_range: Range, + pub edit_history: Vec>, + pub related_files: Option>, } #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ExamplePrompt { pub input: String, pub expected_output: String, - pub format: PromptFormat, + pub provider: PredictionProvider, } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -239,8 +232,7 @@ fn parse_markdown_example(input: &str) -> Result { let spec = ExampleSpec::from_markdown(input)?; Ok(Example { spec, - buffer: None, - context: None, + prompt_inputs: None, prompt: None, predictions: Vec::new(), score: Vec::new(), diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index 04d83371bdb2e7c26b92c5f9adb28e2d8f29b096..a2c23bb37eb5119b50050a821ba564e09cf95b1b 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -1,14 +1,12 @@ use crate::{ - PromptFormat, + FormatPromptArgs, PredictionProvider, example::{Example, ExamplePrompt}, headless::EpAppState, - load_project::run_load_project, progress::{Progress, Step}, retrieve_context::run_context_retrieval, }; use anyhow::{Context as _, Result}; -use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input}; -use gpui::{AsyncApp, Entity}; +use gpui::AsyncApp; use similar::DiffableStr; use std::fmt::Write as _; use std::sync::Arc; @@ -16,16 +14,21 @@ use zeta_prompt::format_zeta_prompt; pub async fn run_format_prompt( example: &mut Example, - prompt_format: PromptFormat, + args: &FormatPromptArgs, app_state: Arc, - mut cx: AsyncApp, + cx: AsyncApp, ) -> Result<()> { - run_context_retrieval(example, app_state.clone(), cx.clone()).await?; + run_context_retrieval(example, app_state, cx).await?; let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name); - match prompt_format { - PromptFormat::Teacher => { + let prompt_inputs = example + .prompt_inputs + .as_ref() + .context("prompt_inputs must be set after context retrieval")?; + + match args.provider { + PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => { step_progress.set_substatus("formatting teacher prompt"); let prompt = TeacherPrompt::format_prompt(example); example.prompt = Some(ExamplePrompt { @@ -36,47 +39,27 @@ pub async fn run_format_prompt( .first() .cloned() .unwrap_or_default(), - format: prompt_format, + provider: args.provider, }); } - PromptFormat::Zeta2 => { - step_progress.set_substatus("loading project"); - run_load_project(example, app_state, cx.clone()).await?; - + PredictionProvider::Zeta2 => { step_progress.set_substatus("formatting zeta2 prompt"); - let ep_store: Entity = cx.update(|cx| { - EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized") - })?; - - let state = example.state.as_ref().context("state must be set")?; - let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot()); - let project = state.project.clone(); - let (_, input) = - ep_store.update(&mut cx, |ep_store: &mut EditPredictionStore, cx| { - let events = ep_store - .edit_history_for_project(&project, cx) - .into_iter() - .map(|e| e.event) - .collect(); - anyhow::Ok(zeta2_prompt_input( - &snapshot, - example - .context - .as_ref() - .context("context must be set")? - .files - .clone(), - events, - example.spec.cursor_path.clone(), - example - .buffer - .as_ref() - .context("buffer must be set")? - .cursor_offset, - )) - })?; - let prompt = format_zeta_prompt(&input); + let context_start = prompt_inputs.context_range.start; + let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start; + let editable_range_in_excerpt = (prompt_inputs.editable_range.start - context_start) + ..(prompt_inputs.editable_range.end - context_start); + let input = zeta_prompt::ZetaPromptInput { + cursor_path: example.spec.cursor_path.clone(), + cursor_excerpt: prompt_inputs.content[prompt_inputs.context_range.clone()] + .to_string() + .into(), + editable_range_in_excerpt, + cursor_offset_in_excerpt, + events: prompt_inputs.edit_history.clone(), + related_files: prompt_inputs.related_files.clone().unwrap_or_default(), + }; + let prompt = format_zeta_prompt(&input, args.version); let expected_output = zeta2_output_for_patch( &input, &example @@ -89,9 +72,12 @@ pub async fn run_format_prompt( example.prompt = Some(ExamplePrompt { input: prompt, expected_output, - format: prompt_format, + provider: args.provider, }); } + _ => { + panic!("Cannot format prompt for {:?}", args.provider); + } }; Ok(()) } @@ -144,10 +130,10 @@ impl TeacherPrompt { // 2. Context retriever just didn't include cursor line. // // In that case, fallback to using `cursor_position` as excerpt. - let example_buffer = example - .buffer + let prompt_inputs = example + .prompt_inputs .as_ref() - .context("`buffer` should be filled in in the context collection step")?; + .context("`prompt_inputs` should be filled in in the context collection step")?; // Extract updated (new) editable region from the model response. // The model may include editable region markers in its output, so we need to strip them. @@ -155,7 +141,7 @@ impl TeacherPrompt { let mut new_editable_region = Self::extract_editable_region(&new_editable_region); let old_editable_region = - example_buffer.content[example_buffer.editable_range.clone()].to_string(); + prompt_inputs.content[prompt_inputs.editable_range.clone()].to_string(); // Normalize leading newlines: if old starts with newline but new doesn't, // prepend newline to new to preserve whitespace structure. @@ -164,8 +150,8 @@ impl TeacherPrompt { new_editable_region.insert(0, '\n'); } - let editable_region_start_line = example_buffer.content - [..example_buffer.editable_range.start] + let editable_region_start_line = prompt_inputs.content + [..prompt_inputs.editable_range.start] .matches('\n') .count(); @@ -208,17 +194,21 @@ impl TeacherPrompt { } fn format_context(example: &Example) -> String { - let context = example - .context + let related_files = example + .prompt_inputs .as_ref() - .expect("Missing context retriever step"); + .and_then(|pi| pi.related_files.as_ref()); + + let Some(related_files) = related_files else { + return "(No context)".to_string(); + }; - if context.files.is_empty() { + if related_files.is_empty() { return "(No context)".to_string(); } let mut prompt = String::new(); - for file in context.files.iter() { + for file in related_files { let path_str = file.path.to_string_lossy(); writeln!(&mut prompt, "`````{path_str}").ok(); let mut prev_row = 0; @@ -242,28 +232,26 @@ impl TeacherPrompt { fn format_cursor_excerpt(example: &Example) -> String { let mut result = String::new(); - let example_buffer = example.buffer.as_ref().unwrap(); + let prompt_inputs = example.prompt_inputs.as_ref().unwrap(); let path_str = example.spec.cursor_path.to_string_lossy(); result.push_str(&format!("`````{path_str}\n")); result.push_str( - &example_buffer.content - [example_buffer.context_range.start..example_buffer.editable_range.start], + &prompt_inputs.content + [prompt_inputs.context_range.start..prompt_inputs.editable_range.start], ); result.push_str(Self::EDITABLE_REGION_START); result.push_str( - &example_buffer.content - [example_buffer.editable_range.start..example_buffer.cursor_offset], + &prompt_inputs.content[prompt_inputs.editable_range.start..prompt_inputs.cursor_offset], ); result.push_str(Self::USER_CURSOR_MARKER); result.push_str( - &example_buffer.content - [example_buffer.cursor_offset..example_buffer.editable_range.end], + &prompt_inputs.content[prompt_inputs.cursor_offset..prompt_inputs.editable_range.end], ); result.push_str(Self::EDITABLE_REGION_END); result.push_str( - &example_buffer.content - [example_buffer.editable_range.end..example_buffer.context_range.end], + &prompt_inputs.content + [prompt_inputs.editable_range.end..prompt_inputs.context_range.end], ); result.push_str("\n`````"); diff --git a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs index ae62c016862fd5bbb8de6b0def4ec6949f9c2604..8fda65452a4badf6dc21277058e413a29d000e98 100644 --- a/crates/edit_prediction_cli/src/load_project.rs +++ b/crates/edit_prediction_cli/src/load_project.rs @@ -1,5 +1,5 @@ use crate::{ - example::{Example, ExampleBuffer, ExampleState}, + example::{Example, ExamplePromptInputs, ExampleState}, git, headless::EpAppState, progress::{InfoStyle, Progress, Step, StepProgress}, @@ -38,7 +38,20 @@ pub async fn run_load_project( buffer .read_with(&cx, |buffer, _| buffer.parsing_idle()) .await; - let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| { + + let ep_store = cx + .update(|cx| EditPredictionStore::try_global(cx)) + .context("EditPredictionStore not initialized")?; + + let edit_history = ep_store.update(&mut cx, |store, cx| { + store + .edit_history_for_project(&project, cx) + .into_iter() + .map(|e| e.event) + .collect() + }); + + let (prompt_inputs, language_name) = buffer.read_with(&cx, |buffer, _cx| { let cursor_point = cursor_position.to_point(&buffer); let snapshot = buffer.snapshot(); let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position( @@ -54,13 +67,15 @@ pub async fn run_load_project( .map(|l| l.name().to_string()) .unwrap_or_else(|| "Unknown".to_string()); ( - ExampleBuffer { + ExamplePromptInputs { content: buffer.text(), cursor_row: cursor_point.row, cursor_column: cursor_point.column, cursor_offset: cursor_position.to_offset(&buffer), context_range, editable_range, + edit_history, + related_files: None, }, language_name, ) @@ -68,7 +83,7 @@ pub async fn run_load_project( progress.set_info(language_name, InfoStyle::Normal); - example.buffer = Some(example_buffer); + example.prompt_inputs = Some(prompt_inputs); example.state = Some(ExampleState { buffer, project, diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 819b5dff4090a84819c053751e4de51dbe40856b..1829ad18d80bee2a1c28c7da7f68ea910ba56d74 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -22,6 +22,7 @@ use edit_prediction::EditPredictionStore; use futures::channel::mpsc; use futures::{SinkExt as _, StreamExt as _}; use gpui::{AppContext as _, Application}; +use zeta_prompt::ZetaVersion; use reqwest_client::ReqwestClient; use serde::{Deserialize, Serialize}; @@ -155,7 +156,7 @@ impl Display for Command { f, "format-prompt --prompt-format={}", format_prompt_args - .prompt_format + .provider .to_possible_value() .unwrap() .get_name() @@ -204,22 +205,31 @@ impl Display for Command { #[derive(Debug, Args, Clone)] struct FormatPromptArgs { - #[clap(long, short('p'))] - prompt_format: PromptFormat, -} - -#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)] -enum PromptFormat { - Teacher, - Zeta2, + #[clap(long, short)] + provider: PredictionProvider, + #[clap( + long, + short, + help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use", + value_parser = ZetaVersion::parse, + default_value_t = ZetaVersion::default(), + )] + version: ZetaVersion, } #[derive(Debug, Args, Clone)] struct PredictArgs { - #[clap(long)] + #[clap(long, short)] provider: PredictionProvider, #[clap(long, default_value_t = 1)] repetitions: usize, + #[clap( + long, + short, + help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use", + value_parser = ZetaVersion::parse, + )] + version: ZetaVersion, } #[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)] @@ -514,7 +524,7 @@ fn main() { Command::FormatPrompt(args) => { run_format_prompt( example, - args.prompt_format, + args, app_state.clone(), cx.clone(), ) @@ -523,8 +533,7 @@ fn main() { Command::Predict(args) => { run_prediction( example, - Some(args.provider), - args.repetitions, + args, app_state.clone(), cx.clone(), ) diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index d2691246147ad27acac9961ee848997231bff8b4..25995ec960f1b73381a076aed5e27b7311be39a0 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -1,5 +1,5 @@ use crate::{ - PredictionProvider, PromptFormat, + FormatPromptArgs, PredictArgs, PredictionProvider, anthropic_client::AnthropicClient, example::{Example, ExamplePrediction, ExamplePrompt}, format_prompt::{TeacherPrompt, run_format_prompt}, @@ -25,12 +25,13 @@ static ANTHROPIC_CLIENT: OnceLock = OnceLock::new(); pub async fn run_prediction( example: &mut Example, - provider: Option, - repetition_count: usize, + args: &PredictArgs, app_state: Arc, mut cx: AsyncApp, ) -> anyhow::Result<()> { - let provider = provider.context("provider is required")?; + let provider = args.provider; + let repetition_count = args.repetitions; + let zeta_version = args.version; if let Some(existing_prediction) = example.predictions.first() { if existing_prediction.provider == provider { @@ -48,7 +49,16 @@ pub async fn run_prediction( ) { let _step_progress = Progress::global().start(Step::Predict, &example.spec.name); - run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?; + run_format_prompt( + example, + &FormatPromptArgs { + provider, + version: args.version, + }, + app_state.clone(), + cx, + ) + .await?; let batched = matches!(provider, PredictionProvider::Teacher); return predict_anthropic(example, repetition_count, batched).await; @@ -85,7 +95,9 @@ pub async fn run_prediction( ep_store.update(&mut cx, |store, _cx| { let model = match provider { PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1, - PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2, + PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2 { + version: zeta_version, + }, PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep, PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury, PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => { @@ -127,7 +139,7 @@ pub async fn run_prediction( updated_example.prompt.get_or_insert(ExamplePrompt { input: prompt, expected_output: String::new(), - format: PromptFormat::Zeta2, + provider, }); } } diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs index 3ddc006841233be6ee1797aa7bd082f3ad6e16da..91ffa53c4453d918082a6a1e7e9d84abb7d60770 100644 --- a/crates/edit_prediction_cli/src/pull_examples.rs +++ b/crates/edit_prediction_cli/src/pull_examples.rs @@ -149,8 +149,7 @@ fn examples_from_response( match parse_result { Ok(spec) => Some(Example { spec, - buffer: None, - context: None, + prompt_inputs: None, prompt: None, predictions: Vec::new(), score: Vec::new(), diff --git a/crates/edit_prediction_cli/src/retrieve_context.rs b/crates/edit_prediction_cli/src/retrieve_context.rs index 8ccfcae9fe17542b99e81df6168484fb1bcd55b0..8d9a5b072920527884d3b83e727551efa2ffb985 100644 --- a/crates/edit_prediction_cli/src/retrieve_context.rs +++ b/crates/edit_prediction_cli/src/retrieve_context.rs @@ -1,5 +1,5 @@ use crate::{ - example::{Example, ExampleContext}, + example::Example, headless::EpAppState, load_project::run_load_project, progress::{InfoStyle, Progress, Step, StepProgress}, @@ -19,7 +19,11 @@ pub async fn run_context_retrieval( app_state: Arc, mut cx: AsyncApp, ) -> anyhow::Result<()> { - if example.context.is_some() { + if example + .prompt_inputs + .as_ref() + .is_some_and(|inputs| inputs.related_files.is_some()) + { return Ok(()); } @@ -63,9 +67,9 @@ pub async fn run_context_retrieval( let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum(); step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal); - example.context = Some(ExampleContext { - files: context_files, - }); + if let Some(prompt_inputs) = example.prompt_inputs.as_mut() { + prompt_inputs.related_files = Some(context_files); + } Ok(()) } diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index 4df4564c883f897fc548e3d15e054764d1c27e86..d713137f3decae3a2e25e0bbe520724c8756018d 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -17,19 +17,12 @@ pub async fn run_scoring( app_state: Arc, cx: AsyncApp, ) -> anyhow::Result<()> { - run_prediction( - example, - Some(args.provider), - args.repetitions, - app_state, - cx, - ) - .await?; + run_prediction(example, args, app_state, cx).await?; let progress = Progress::global().start(Step::Score, &example.spec.name); progress.set_substatus("applying patches"); - let original_text = &example.buffer.as_ref().unwrap().content; + let original_text = &example.prompt_inputs.as_ref().unwrap().content; let expected_texts: Vec = example .spec .expected_patches diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 0dcee753407c8ea8debabab386704bec9046d4e0..98d5fcaad848920bce47c119bfb046c74e6188c1 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -204,7 +204,9 @@ fn assign_edit_prediction_provider( } else if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME && cx.has_flag::() { - edit_prediction::EditPredictionModel::Zeta2 + edit_prediction::EditPredictionModel::Zeta2 { + version: Default::default(), + } } else if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME && cx.has_flag::() { diff --git a/crates/zeta_prompt/Cargo.toml b/crates/zeta_prompt/Cargo.toml index c9b1e2d784d10ea2fd278f70ffdae2ef0981fce0..12b4371d58c41c3b38569e637cbe54c7ba27404d 100644 --- a/crates/zeta_prompt/Cargo.toml +++ b/crates/zeta_prompt/Cargo.toml @@ -12,4 +12,6 @@ workspace = true path = "src/zeta_prompt.rs" [dependencies] -serde.workspace = true \ No newline at end of file +anyhow.workspace = true +serde.workspace = true +strum.workspace = true diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index e6c6f56d3d8ef5631776841d15e7c9c623044f25..76fcd7818600c193b2a5b4d080144d5bae637e49 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -1,8 +1,10 @@ +use anyhow::Result; use serde::{Deserialize, Serialize}; use std::fmt::Write; use std::ops::Range; use std::path::Path; use std::sync::Arc; +use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr}; pub const CURSOR_MARKER: &str = "<|user_cursor|>"; @@ -16,6 +18,54 @@ pub struct ZetaPromptInput { pub related_files: Vec, } +#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, EnumIter, IntoStaticStr)] +#[allow(non_camel_case_types)] +pub enum ZetaVersion { + V0112_MiddleAtEnd, + #[default] + V0113_Ordered, +} + +impl std::fmt::Display for ZetaVersion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", <&'static str>::from(self)) + } +} + +impl ZetaVersion { + pub fn parse(version_string: &str) -> Result { + let mut results = ZetaVersion::iter().filter(|version| { + <&'static str>::from(version) + .to_lowercase() + .contains(&version_string.to_lowercase()) + }); + let Some(result) = results.next() else { + anyhow::bail!( + "`{version_string}` did not match any of:\n{}", + Self::options_as_string() + ); + }; + if results.next().is_some() { + anyhow::bail!( + "`{version_string}` matched more than one of:\n{}", + Self::options_as_string() + ); + } + Ok(result) + } + + fn options_as_string() -> String { + ZetaVersion::iter() + .map(|version| format!("- {}\n", <&'static str>::from(version))) + .collect::>() + .concat() + } + + pub fn default_as_string() -> String { + <&'static str>::from(Self::default()).to_string() + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(tag = "event")] pub enum Event { @@ -69,11 +119,20 @@ pub struct RelatedExcerpt { pub text: Arc, } -pub fn format_zeta_prompt(input: &ZetaPromptInput) -> String { +pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String { let mut prompt = String::new(); write_related_files(&mut prompt, &input.related_files); write_edit_history_section(&mut prompt, input); - write_cursor_excerpt_section(&mut prompt, input); + + match version { + ZetaVersion::V0112_MiddleAtEnd => { + v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, input); + } + ZetaVersion::V0113_Ordered => { + v0113_ordered::write_cursor_excerpt_section(&mut prompt, input) + } + } + prompt } @@ -100,31 +159,73 @@ fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) { } } -fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) { - let path_str = input.cursor_path.to_string_lossy(); - write!(prompt, "<|file_sep|>{}\n", path_str).ok(); +mod v0112_middle_at_end { + use super::*; - prompt.push_str("<|fim_prefix|>\n"); - prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]); + pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) { + let path_str = input.cursor_path.to_string_lossy(); + write!(prompt, "<|file_sep|>{}\n", path_str).ok(); - prompt.push_str("<|fim_suffix|>\n"); - prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]); - if !prompt.ends_with('\n') { - prompt.push('\n'); - } + prompt.push_str("<|fim_prefix|>\n"); + prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]); - prompt.push_str("<|fim_middle|>current\n"); - prompt.push_str( - &input.cursor_excerpt - [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt], - ); - prompt.push_str(CURSOR_MARKER); - prompt.push_str( - &input.cursor_excerpt[input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end], - ); - if !prompt.ends_with('\n') { - prompt.push('\n'); + prompt.push_str("<|fim_suffix|>\n"); + prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]); + if !prompt.ends_with('\n') { + prompt.push('\n'); + } + + prompt.push_str("<|fim_middle|>current\n"); + prompt.push_str( + &input.cursor_excerpt + [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt], + ); + prompt.push_str(CURSOR_MARKER); + prompt.push_str( + &input.cursor_excerpt + [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end], + ); + if !prompt.ends_with('\n') { + prompt.push('\n'); + } + + prompt.push_str("<|fim_middle|>updated\n"); } +} + +mod v0113_ordered { + use super::*; + + pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) { + let path_str = input.cursor_path.to_string_lossy(); + write!(prompt, "<|file_sep|>{}\n", path_str).ok(); - prompt.push_str("<|fim_middle|>updated\n"); + prompt.push_str("<|fim_prefix|>\n"); + prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]); + if !prompt.ends_with('\n') { + prompt.push('\n'); + } + + prompt.push_str("<|fim_middle|>current\n"); + prompt.push_str( + &input.cursor_excerpt + [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt], + ); + prompt.push_str(CURSOR_MARKER); + prompt.push_str( + &input.cursor_excerpt + [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end], + ); + if !prompt.ends_with('\n') { + prompt.push('\n'); + } + + prompt.push_str("<|fim_suffix|>\n"); + prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]); + if !prompt.ends_with('\n') { + prompt.push('\n'); + } + + prompt.push_str("<|fim_middle|>updated\n"); + } }