@@ -25,7 +25,7 @@ use gpui::{AppContext as _, Application, BackgroundExecutor};
use zeta_prompt::ZetaVersion;
use reqwest_client::ReqwestClient;
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
@@ -152,47 +152,19 @@ impl Display for Command {
Command::ParseExample => write!(f, "parse-example"),
Command::LoadProject => write!(f, "load-project"),
Command::Context => write!(f, "context"),
- Command::FormatPrompt(format_prompt_args) => write!(
- f,
- "format-prompt --prompt-format={}",
- format_prompt_args
- .provider
- .to_possible_value()
- .unwrap()
- .get_name()
- ),
- Command::Predict(predict_args) => {
- write!(
- f,
- "predict --provider={:?}",
- predict_args
- .provider
- .to_possible_value()
- .unwrap()
- .get_name()
- )
+ Command::FormatPrompt(args) => {
+ write!(f, "format-prompt --provider={}", args.provider)
}
- Command::Score(predict_args) => {
- write!(
- f,
- "score --provider={:?}",
- predict_args
- .provider
- .to_possible_value()
- .unwrap()
- .get_name()
- )
+ Command::Predict(args) => {
+ write!(f, "predict --provider={}", args.provider)
+ }
+ Command::Score(args) => {
+ write!(f, "score --provider={}", args.provider)
}
Command::Distill => write!(f, "distill"),
- Command::Eval(predict_args) => write!(
- f,
- "eval --provider={:?}",
- predict_args
- .provider
- .to_possible_value()
- .unwrap()
- .get_name()
- ),
+ Command::Eval(args) => {
+ write!(f, "eval --provider={}", args.provider)
+ }
Command::Synthesize(args) => {
write!(f, "synthesize --repos {}", args.repos.join(" "))
}
@@ -205,43 +177,96 @@ impl Display for Command {
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
- #[clap(long, short)]
+ #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
provider: PredictionProvider,
- #[clap(
- long,
- short,
- help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
- value_parser = ZetaVersion::parse,
- default_value_t = ZetaVersion::default(),
- )]
- version: ZetaVersion,
}
#[derive(Debug, Args, Clone)]
struct PredictArgs {
- #[clap(long, short)]
+ #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
provider: PredictionProvider,
#[clap(long, default_value_t = 1)]
repetitions: usize,
- #[clap(
- long,
- short,
- help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
- value_parser = ZetaVersion::parse,
- )]
- version: ZetaVersion,
}
-#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
Sweep,
Mercury,
Zeta1,
- Zeta2,
+ Zeta2(ZetaVersion),
Teacher,
TeacherNonBatching,
}
+impl Default for PredictionProvider {
+ fn default() -> Self {
+ PredictionProvider::Zeta2(ZetaVersion::default())
+ }
+}
+
+impl std::fmt::Display for PredictionProvider {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ PredictionProvider::Sweep => write!(f, "sweep"),
+ PredictionProvider::Mercury => write!(f, "mercury"),
+ PredictionProvider::Zeta1 => write!(f, "zeta1"),
+ PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
+ PredictionProvider::Teacher => write!(f, "teacher"),
+ PredictionProvider::TeacherNonBatching => write!(f, "teacher-non-batching"),
+ }
+ }
+}
+
+impl std::str::FromStr for PredictionProvider {
+ type Err = anyhow::Error;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ let s_lower = s.to_lowercase();
+ match s_lower.as_str() {
+ "sweep" => Ok(PredictionProvider::Sweep),
+ "mercury" => Ok(PredictionProvider::Mercury),
+ "zeta1" => Ok(PredictionProvider::Zeta1),
+ // Handle both old format "zeta2" and new format with version
+ "zeta2" => Ok(PredictionProvider::Zeta2(ZetaVersion::default())),
+ "teacher" => Ok(PredictionProvider::Teacher),
+ "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
+ Ok(PredictionProvider::TeacherNonBatching)
+ }
+ _ if s_lower.starts_with("zeta2:") => {
+ let version_str = &s[6..];
+ let version = ZetaVersion::parse(version_str)?;
+ Ok(PredictionProvider::Zeta2(version))
+ }
+ _ => anyhow::bail!(
+ "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
+ For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
+ Available zeta versions:\n{}",
+ ZetaVersion::options_as_string()
+ ),
+ }
+ }
+}
+
+impl Serialize for PredictionProvider {
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: Serializer,
+ {
+ serializer.serialize_str(&self.to_string())
+ }
+}
+
+impl<'de> Deserialize<'de> for PredictionProvider {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ let s = String::deserialize(deserializer)?;
+ s.parse().map_err(serde::de::Error::custom)
+ }
+}
+
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
/// Repository URLs (git@github.com:owner/repo or https://...)
@@ -31,7 +31,6 @@ pub async fn run_prediction(
) -> anyhow::Result<()> {
let provider = args.provider;
let repetition_count = args.repetitions;
- let zeta_version = args.version;
if let Some(existing_prediction) = example.predictions.first() {
if existing_prediction.provider == provider {
@@ -51,10 +50,7 @@ pub async fn run_prediction(
run_format_prompt(
example,
- &FormatPromptArgs {
- provider,
- version: args.version,
- },
+ &FormatPromptArgs { provider },
app_state.clone(),
cx,
)
@@ -70,7 +66,7 @@ pub async fn run_prediction(
if matches!(
provider,
- PredictionProvider::Zeta1 | PredictionProvider::Zeta2
+ PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
) {
step_progress.set_substatus("authenticating");
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
@@ -95,9 +91,9 @@ pub async fn run_prediction(
ep_store.update(&mut cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
- PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2 {
- version: zeta_version,
- },
+ PredictionProvider::Zeta2(version) => {
+ edit_prediction::EditPredictionModel::Zeta2 { version }
+ }
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
@@ -135,7 +131,7 @@ pub async fn run_prediction(
if let Some(prompt) = request.prompt {
fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
- if provider == PredictionProvider::Zeta2 {
+ if matches!(provider, PredictionProvider::Zeta2(_)) {
updated_example.prompt.get_or_insert(ExamplePrompt {
input: prompt,
expected_output: String::new(),
@@ -18,7 +18,19 @@ pub struct ZetaPromptInput {
pub related_files: Vec<RelatedFile>,
}
-#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, EnumIter, IntoStaticStr)]
+#[derive(
+ Default,
+ Clone,
+ Copy,
+ Debug,
+ PartialEq,
+ Eq,
+ Hash,
+ EnumIter,
+ IntoStaticStr,
+ Serialize,
+ Deserialize,
+)]
#[allow(non_camel_case_types)]
pub enum ZetaVersion {
V0112_MiddleAtEnd,
@@ -54,7 +66,7 @@ impl ZetaVersion {
Ok(result)
}
- fn options_as_string() -> String {
+ pub fn options_as_string() -> String {
ZetaVersion::iter()
.map(|version| format!("- {}\n", <&'static str>::from(version)))
.collect::<Vec<_>>()