ep: Combine PredictionProvider and ZetaVersion (#46896)

Oleksiy Syvokon created

We can specify prompt version in the provider name itself, like this
`--provider zeta2:0113`.

This kind of tag will also be stored in the `provider` field of
jsonlines files.

This drops the `--version` parameter.


Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/format_prompt.rs |   4 
crates/edit_prediction_cli/src/main.rs          | 141 +++++++++++-------
crates/edit_prediction_cli/src/predict.rs       |  16 -
crates/zeta_prompt/src/zeta_prompt.rs           |  16 +
4 files changed, 105 insertions(+), 72 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -42,7 +42,7 @@ pub async fn run_format_prompt(
                 provider: args.provider,
             });
         }
-        PredictionProvider::Zeta2 => {
+        PredictionProvider::Zeta2(version) => {
             step_progress.set_substatus("formatting zeta2 prompt");
 
             let context_start = prompt_inputs.context_range.start;
@@ -59,7 +59,7 @@ pub async fn run_format_prompt(
                 events: prompt_inputs.edit_history.clone(),
                 related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
             };
-            let prompt = format_zeta_prompt(&input, args.version);
+            let prompt = format_zeta_prompt(&input, version);
             let expected_output = zeta2_output_for_patch(
                 &input,
                 &example

crates/edit_prediction_cli/src/main.rs 🔗

@@ -25,7 +25,7 @@ use gpui::{AppContext as _, Application, BackgroundExecutor};
 use zeta_prompt::ZetaVersion;
 
 use reqwest_client::ReqwestClient;
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
 use std::fmt::Display;
 use std::fs::{File, OpenOptions};
 use std::hash::{Hash, Hasher};
@@ -152,47 +152,19 @@ impl Display for Command {
             Command::ParseExample => write!(f, "parse-example"),
             Command::LoadProject => write!(f, "load-project"),
             Command::Context => write!(f, "context"),
-            Command::FormatPrompt(format_prompt_args) => write!(
-                f,
-                "format-prompt --prompt-format={}",
-                format_prompt_args
-                    .provider
-                    .to_possible_value()
-                    .unwrap()
-                    .get_name()
-            ),
-            Command::Predict(predict_args) => {
-                write!(
-                    f,
-                    "predict --provider={:?}",
-                    predict_args
-                        .provider
-                        .to_possible_value()
-                        .unwrap()
-                        .get_name()
-                )
+            Command::FormatPrompt(args) => {
+                write!(f, "format-prompt --provider={}", args.provider)
             }
-            Command::Score(predict_args) => {
-                write!(
-                    f,
-                    "score --provider={:?}",
-                    predict_args
-                        .provider
-                        .to_possible_value()
-                        .unwrap()
-                        .get_name()
-                )
+            Command::Predict(args) => {
+                write!(f, "predict --provider={}", args.provider)
+            }
+            Command::Score(args) => {
+                write!(f, "score --provider={}", args.provider)
             }
             Command::Distill => write!(f, "distill"),
-            Command::Eval(predict_args) => write!(
-                f,
-                "eval --provider={:?}",
-                predict_args
-                    .provider
-                    .to_possible_value()
-                    .unwrap()
-                    .get_name()
-            ),
+            Command::Eval(args) => {
+                write!(f, "eval --provider={}", args.provider)
+            }
             Command::Synthesize(args) => {
                 write!(f, "synthesize --repos {}", args.repos.join(" "))
             }
@@ -205,43 +177,96 @@ impl Display for Command {
 
 #[derive(Debug, Args, Clone)]
 struct FormatPromptArgs {
-    #[clap(long, short)]
+    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
     provider: PredictionProvider,
-    #[clap(
-        long,
-        short,
-        help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
-        value_parser = ZetaVersion::parse,
-        default_value_t = ZetaVersion::default(),
-    )]
-    version: ZetaVersion,
 }
 
 #[derive(Debug, Args, Clone)]
 struct PredictArgs {
-    #[clap(long, short)]
+    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
     provider: PredictionProvider,
     #[clap(long, default_value_t = 1)]
     repetitions: usize,
-    #[clap(
-        long,
-        short,
-        help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
-        value_parser = ZetaVersion::parse,
-    )]
-    version: ZetaVersion,
 }
 
-#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 enum PredictionProvider {
     Sweep,
     Mercury,
     Zeta1,
-    Zeta2,
+    Zeta2(ZetaVersion),
     Teacher,
     TeacherNonBatching,
 }
 
+impl Default for PredictionProvider {
+    fn default() -> Self {
+        PredictionProvider::Zeta2(ZetaVersion::default())
+    }
+}
+
+impl std::fmt::Display for PredictionProvider {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            PredictionProvider::Sweep => write!(f, "sweep"),
+            PredictionProvider::Mercury => write!(f, "mercury"),
+            PredictionProvider::Zeta1 => write!(f, "zeta1"),
+            PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
+            PredictionProvider::Teacher => write!(f, "teacher"),
+            PredictionProvider::TeacherNonBatching => write!(f, "teacher-non-batching"),
+        }
+    }
+}
+
+impl std::str::FromStr for PredictionProvider {
+    type Err = anyhow::Error;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let s_lower = s.to_lowercase();
+        match s_lower.as_str() {
+            "sweep" => Ok(PredictionProvider::Sweep),
+            "mercury" => Ok(PredictionProvider::Mercury),
+            "zeta1" => Ok(PredictionProvider::Zeta1),
+            // Handle both old format "zeta2" and new format with version
+            "zeta2" => Ok(PredictionProvider::Zeta2(ZetaVersion::default())),
+            "teacher" => Ok(PredictionProvider::Teacher),
+            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
+                Ok(PredictionProvider::TeacherNonBatching)
+            }
+            _ if s_lower.starts_with("zeta2:") => {
+                let version_str = &s[6..];
+                let version = ZetaVersion::parse(version_str)?;
+                Ok(PredictionProvider::Zeta2(version))
+            }
+            _ => anyhow::bail!(
+                "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
+                 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
+                 Available zeta versions:\n{}",
+                ZetaVersion::options_as_string()
+            ),
+        }
+    }
+}
+
+impl Serialize for PredictionProvider {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_str(&self.to_string())
+    }
+}
+
+impl<'de> Deserialize<'de> for PredictionProvider {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let s = String::deserialize(deserializer)?;
+        s.parse().map_err(serde::de::Error::custom)
+    }
+}
+
 #[derive(Debug, Args, Clone)]
 struct SynthesizeArgs {
     /// Repository URLs (git@github.com:owner/repo or https://...)

crates/edit_prediction_cli/src/predict.rs 🔗

@@ -31,7 +31,6 @@ pub async fn run_prediction(
 ) -> anyhow::Result<()> {
     let provider = args.provider;
     let repetition_count = args.repetitions;
-    let zeta_version = args.version;
 
     if let Some(existing_prediction) = example.predictions.first() {
         if existing_prediction.provider == provider {
@@ -51,10 +50,7 @@ pub async fn run_prediction(
 
         run_format_prompt(
             example,
-            &FormatPromptArgs {
-                provider,
-                version: args.version,
-            },
+            &FormatPromptArgs { provider },
             app_state.clone(),
             cx,
         )
@@ -70,7 +66,7 @@ pub async fn run_prediction(
 
     if matches!(
         provider,
-        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
+        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
     ) {
         step_progress.set_substatus("authenticating");
         static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
@@ -95,9 +91,9 @@ pub async fn run_prediction(
     ep_store.update(&mut cx, |store, _cx| {
         let model = match provider {
             PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
-            PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2 {
-                version: zeta_version,
-            },
+            PredictionProvider::Zeta2(version) => {
+                edit_prediction::EditPredictionModel::Zeta2 { version }
+            }
             PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
             PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
             PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
@@ -135,7 +131,7 @@ pub async fn run_prediction(
 
                         if let Some(prompt) = request.prompt {
                             fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
-                            if provider == PredictionProvider::Zeta2 {
+                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                 updated_example.prompt.get_or_insert(ExamplePrompt {
                                     input: prompt,
                                     expected_output: String::new(),

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -18,7 +18,19 @@ pub struct ZetaPromptInput {
     pub related_files: Vec<RelatedFile>,
 }
 
-#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, EnumIter, IntoStaticStr)]
+#[derive(
+    Default,
+    Clone,
+    Copy,
+    Debug,
+    PartialEq,
+    Eq,
+    Hash,
+    EnumIter,
+    IntoStaticStr,
+    Serialize,
+    Deserialize,
+)]
 #[allow(non_camel_case_types)]
 pub enum ZetaVersion {
     V0112_MiddleAtEnd,
@@ -54,7 +66,7 @@ impl ZetaVersion {
         Ok(result)
     }
 
-    fn options_as_string() -> String {
+    pub fn options_as_string() -> String {
         ZetaVersion::iter()
             .map(|version| format!("- {}\n", <&'static str>::from(version)))
             .collect::<Vec<_>>()