diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index c839be804fe9599f1b7a2b077218041ce58e238a..6d61401a798a8a5c465d65b0762a66d54046a7df 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -160,17 +160,20 @@ impl Display for Command { Command::FormatPrompt(args) => { write!(f, "format-prompt --provider={}", args.provider) } - Command::Predict(args) => { - write!(f, "predict --provider={}", args.provider) - } + Command::Predict(args) => match &args.provider { + Some(provider) => write!(f, "predict --provider={}", provider), + None => write!(f, "predict"), + }, Command::ParseOutput => write!(f, "parse-output"), - Command::Score(args) => { - write!(f, "score --provider={}", args.provider) - } + Command::Score(args) => match &args.provider { + Some(provider) => write!(f, "score --provider={}", provider), + None => write!(f, "score"), + }, Command::Distill => write!(f, "distill"), - Command::Eval(args) => { - write!(f, "eval --provider={}", args.provider) - } + Command::Eval(args) => match &args.provider { + Some(provider) => write!(f, "eval --provider={}", provider), + None => write!(f, "eval"), + }, Command::Synthesize(args) => { write!(f, "synthesize --repos {}", args.repos.join(" ")) } @@ -189,8 +192,8 @@ struct FormatPromptArgs { #[derive(Debug, Args, Clone)] struct PredictArgs { - #[clap(long, short('p'), default_value_t = PredictionProvider::default())] - provider: PredictionProvider, + #[clap(long, short('p'))] + provider: Option, #[clap(long, default_value_t = 1)] repetitions: usize, } @@ -519,7 +522,7 @@ fn main() { match &command { Command::Predict(args) | Command::Score(args) | Command::Eval(args) => { - predict::sync_batches(&args.provider).await?; + predict::sync_batches(args.provider.as_ref()).await?; } _ => (), } @@ -698,7 +701,7 @@ fn main() { match &command { Command::Predict(args) | Command::Score(args) | Command::Eval(args) => { - predict::sync_batches(&args.provider).await?; + predict::sync_batches(args.provider.as_ref()).await?; } _ => (), } diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index f0bc99c9dd855c39fccb3068f20150aef407e6b8..b73a8121e88d4ea4a7a2b2ded591fff42970c507 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -31,21 +31,30 @@ pub async fn run_prediction( example_progress: &ExampleProgress, mut cx: AsyncApp, ) -> anyhow::Result<()> { - let provider = args.provider; let repetition_count = args.repetitions; if let Some(existing_prediction) = example.predictions.first() { - if existing_prediction.provider == provider { - return Ok(()); - } else { - example.predictions.clear(); + let has_prediction = existing_prediction.actual_patch.is_some() + || !existing_prediction.actual_output.is_empty(); + if has_prediction { + match args.provider { + None => return Ok(()), + Some(provider) if existing_prediction.provider == provider => return Ok(()), + Some(_) => example.predictions.clear(), + } } } + let Some(provider) = args.provider else { + anyhow::bail!( + "No existing predictions found. Use --provider to specify which model to use for prediction." + ); + }; + run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?; if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) = - args.provider + provider { let _step_progress = example_progress.start(Step::Predict); @@ -304,9 +313,9 @@ async fn predict_anthropic( Ok(()) } -pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> { +pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> { match provider { - PredictionProvider::Teacher(..) => { + Some(PredictionProvider::Teacher(..)) => { let llm_client = ANTHROPIC_CLIENT.get_or_init(|| { AnthropicClient::batch(&crate::paths::LLM_CACHE_DB) .expect("Failed to create Anthropic client") diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index 3f7e0167ee667996fd9487e578f63680f7ca5803..cae85fbfa5fb5f9950aa5d3e11b90937634c1ece 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -24,7 +24,11 @@ pub async fn run_scoring( let progress = example_progress.start(Step::Score); progress.set_substatus("applying patches"); - let original_text = &example.prompt_inputs.as_ref().unwrap().content; + let original_text = &example + .prompt_inputs + .as_ref() + .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")? + .content; let expected_texts: Vec = example .spec .expected_patches