From a10fdfd2b8e3b050f1dab78fbb08a201d6ad4558 Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Thu, 15 Jan 2026 19:00:21 +0200 Subject: [PATCH] ep: Combine PredictionProvider and ZetaVersion (#46896) We can specify prompt version in the provider name itself, like this `--provider zeta2:0113`. This kind of tag will also be stored in the `provider` field of jsonlines files. This drops the `--version` parameter. Release Notes: - N/A --- .../edit_prediction_cli/src/format_prompt.rs | 4 +- crates/edit_prediction_cli/src/main.rs | 141 +++++++++++------- crates/edit_prediction_cli/src/predict.rs | 16 +- crates/zeta_prompt/src/zeta_prompt.rs | 16 +- 4 files changed, 105 insertions(+), 72 deletions(-) diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index a2c23bb37eb5119b50050a821ba564e09cf95b1b..a6ce738f3071e97c0f83bd6b17d65867449b4de7 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -42,7 +42,7 @@ pub async fn run_format_prompt( provider: args.provider, }); } - PredictionProvider::Zeta2 => { + PredictionProvider::Zeta2(version) => { step_progress.set_substatus("formatting zeta2 prompt"); let context_start = prompt_inputs.context_range.start; @@ -59,7 +59,7 @@ pub async fn run_format_prompt( events: prompt_inputs.edit_history.clone(), related_files: prompt_inputs.related_files.clone().unwrap_or_default(), }; - let prompt = format_zeta_prompt(&input, args.version); + let prompt = format_zeta_prompt(&input, version); let expected_output = zeta2_output_for_patch( &input, &example diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 1a65ba432549e6f1518ec953f77875d78f0abf9f..b8954f92745992839234cab73142278a176bc955 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -25,7 +25,7 @@ use gpui::{AppContext as _, Application, BackgroundExecutor}; use zeta_prompt::ZetaVersion; use reqwest_client::ReqwestClient; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::fmt::Display; use std::fs::{File, OpenOptions}; use std::hash::{Hash, Hasher}; @@ -152,47 +152,19 @@ impl Display for Command { Command::ParseExample => write!(f, "parse-example"), Command::LoadProject => write!(f, "load-project"), Command::Context => write!(f, "context"), - Command::FormatPrompt(format_prompt_args) => write!( - f, - "format-prompt --prompt-format={}", - format_prompt_args - .provider - .to_possible_value() - .unwrap() - .get_name() - ), - Command::Predict(predict_args) => { - write!( - f, - "predict --provider={:?}", - predict_args - .provider - .to_possible_value() - .unwrap() - .get_name() - ) + Command::FormatPrompt(args) => { + write!(f, "format-prompt --provider={}", args.provider) } - Command::Score(predict_args) => { - write!( - f, - "score --provider={:?}", - predict_args - .provider - .to_possible_value() - .unwrap() - .get_name() - ) + Command::Predict(args) => { + write!(f, "predict --provider={}", args.provider) + } + Command::Score(args) => { + write!(f, "score --provider={}", args.provider) } Command::Distill => write!(f, "distill"), - Command::Eval(predict_args) => write!( - f, - "eval --provider={:?}", - predict_args - .provider - .to_possible_value() - .unwrap() - .get_name() - ), + Command::Eval(args) => { + write!(f, "eval --provider={}", args.provider) + } Command::Synthesize(args) => { write!(f, "synthesize --repos {}", args.repos.join(" ")) } @@ -205,43 +177,96 @@ impl Display for Command { #[derive(Debug, Args, Clone)] struct FormatPromptArgs { - #[clap(long, short)] + #[clap(long, short('p'), default_value_t = PredictionProvider::default())] provider: PredictionProvider, - #[clap( - long, - short, - help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use", - value_parser = ZetaVersion::parse, - default_value_t = ZetaVersion::default(), - )] - version: ZetaVersion, } #[derive(Debug, Args, Clone)] struct PredictArgs { - #[clap(long, short)] + #[clap(long, short('p'), default_value_t = PredictionProvider::default())] provider: PredictionProvider, #[clap(long, default_value_t = 1)] repetitions: usize, - #[clap( - long, - short, - help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use", - value_parser = ZetaVersion::parse, - )] - version: ZetaVersion, } -#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] enum PredictionProvider { Sweep, Mercury, Zeta1, - Zeta2, + Zeta2(ZetaVersion), Teacher, TeacherNonBatching, } +impl Default for PredictionProvider { + fn default() -> Self { + PredictionProvider::Zeta2(ZetaVersion::default()) + } +} + +impl std::fmt::Display for PredictionProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PredictionProvider::Sweep => write!(f, "sweep"), + PredictionProvider::Mercury => write!(f, "mercury"), + PredictionProvider::Zeta1 => write!(f, "zeta1"), + PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"), + PredictionProvider::Teacher => write!(f, "teacher"), + PredictionProvider::TeacherNonBatching => write!(f, "teacher-non-batching"), + } + } +} + +impl std::str::FromStr for PredictionProvider { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let s_lower = s.to_lowercase(); + match s_lower.as_str() { + "sweep" => Ok(PredictionProvider::Sweep), + "mercury" => Ok(PredictionProvider::Mercury), + "zeta1" => Ok(PredictionProvider::Zeta1), + // Handle both old format "zeta2" and new format with version + "zeta2" => Ok(PredictionProvider::Zeta2(ZetaVersion::default())), + "teacher" => Ok(PredictionProvider::Teacher), + "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => { + Ok(PredictionProvider::TeacherNonBatching) + } + _ if s_lower.starts_with("zeta2:") => { + let version_str = &s[6..]; + let version = ZetaVersion::parse(version_str)?; + Ok(PredictionProvider::Zeta2(version)) + } + _ => anyhow::bail!( + "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:, teacher, teacher-non-batching\n\ + For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\ + Available zeta versions:\n{}", + ZetaVersion::options_as_string() + ), + } + } +} + +impl Serialize for PredictionProvider { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for PredictionProvider { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + s.parse().map_err(serde::de::Error::custom) + } +} + #[derive(Debug, Args, Clone)] struct SynthesizeArgs { /// Repository URLs (git@github.com:owner/repo or https://...) diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 25995ec960f1b73381a076aed5e27b7311be39a0..a5f92ba55fd83d8bb5979fe6b9d831f185dcd338 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -31,7 +31,6 @@ pub async fn run_prediction( ) -> anyhow::Result<()> { let provider = args.provider; let repetition_count = args.repetitions; - let zeta_version = args.version; if let Some(existing_prediction) = example.predictions.first() { if existing_prediction.provider == provider { @@ -51,10 +50,7 @@ pub async fn run_prediction( run_format_prompt( example, - &FormatPromptArgs { - provider, - version: args.version, - }, + &FormatPromptArgs { provider }, app_state.clone(), cx, ) @@ -70,7 +66,7 @@ pub async fn run_prediction( if matches!( provider, - PredictionProvider::Zeta1 | PredictionProvider::Zeta2 + PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_) ) { step_progress.set_substatus("authenticating"); static AUTHENTICATED: OnceLock>> = OnceLock::new(); @@ -95,9 +91,9 @@ pub async fn run_prediction( ep_store.update(&mut cx, |store, _cx| { let model = match provider { PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1, - PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2 { - version: zeta_version, - }, + PredictionProvider::Zeta2(version) => { + edit_prediction::EditPredictionModel::Zeta2 { version } + } PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep, PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury, PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => { @@ -135,7 +131,7 @@ pub async fn run_prediction( if let Some(prompt) = request.prompt { fs::write(run_dir.join("prediction_prompt.md"), &prompt)?; - if provider == PredictionProvider::Zeta2 { + if matches!(provider, PredictionProvider::Zeta2(_)) { updated_example.prompt.get_or_insert(ExamplePrompt { input: prompt, expected_output: String::new(), diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 76fcd7818600c193b2a5b4d080144d5bae637e49..35fcb0d02453567b60f0ed95292b8c418a0da40a 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -18,7 +18,19 @@ pub struct ZetaPromptInput { pub related_files: Vec, } -#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, EnumIter, IntoStaticStr)] +#[derive( + Default, + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + EnumIter, + IntoStaticStr, + Serialize, + Deserialize, +)] #[allow(non_camel_case_types)] pub enum ZetaVersion { V0112_MiddleAtEnd, @@ -54,7 +66,7 @@ impl ZetaVersion { Ok(result) } - fn options_as_string() -> String { + pub fn options_as_string() -> String { ZetaVersion::iter() .map(|version| format!("- {}\n", <&'static str>::from(version))) .collect::>()