ep: Make --provider optional, skip prediction when results exist (#47225)

Oleksiy Syvokon created

When `--provider` is not passed, `ep` now uses whichever provider is
recorded in the data: examples that already contain prediction results are
skipped, and examples without any predictions produce an error asking for an
explicit `--provider`. Passing a provider different from the recorded one
still clears the existing predictions and re-runs them.
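
A minimal sketch of the resolution rules the predict path now follows, under simplified assumptions: `Provider`, `Prediction`, and `Decision` below are hypothetical stand-ins, not the crate's real types, and the real check also considers `actual_patch`, not only `actual_output`.

```rust
// Sketch of the new --provider resolution, with hypothetical simplified types.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Provider {
    Teacher,
    Other,
}

struct Prediction {
    provider: Provider,
    actual_output: String,
}

#[derive(Debug, PartialEq)]
enum Decision {
    Skip,              // a usable prediction already exists; reuse it
    Rerun(Provider),   // a different provider was requested: clear and re-predict
    Predict(Provider), // no usable prediction yet: run with the given provider
    Error,             // no prediction and no --provider: nothing to fall back on
}

fn resolve(requested: Option<Provider>, existing: Option<&Prediction>) -> Decision {
    match (existing, requested) {
        // An existing prediction with real output is reused unless a
        // different provider is explicitly requested.
        (Some(p), None) if !p.actual_output.is_empty() => Decision::Skip,
        (Some(p), Some(want)) if !p.actual_output.is_empty() => {
            if p.provider == want { Decision::Skip } else { Decision::Rerun(want) }
        }
        // No usable prediction: a provider must be given explicitly.
        (_, Some(want)) => Decision::Predict(want),
        (_, None) => Decision::Error,
    }
}

fn main() {
    let cached = Prediction { provider: Provider::Teacher, actual_output: "patch".into() };
    assert_eq!(resolve(None, Some(&cached)), Decision::Skip);
    assert_eq!(resolve(Some(Provider::Other), Some(&cached)), Decision::Rerun(Provider::Other));
    assert_eq!(resolve(None, None), Decision::Error);
    println!("provider resolution behaves as described");
}
```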

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/main.rs    | 29 +++++++++++++-----------
crates/edit_prediction_cli/src/predict.rs | 25 ++++++++++++++------
crates/edit_prediction_cli/src/score.rs   |  6 ++++
3 files changed, 38 insertions(+), 22 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/main.rs

@@ -160,17 +160,20 @@ impl Display for Command {
             Command::FormatPrompt(args) => {
                 write!(f, "format-prompt --provider={}", args.provider)
             }
-            Command::Predict(args) => {
-                write!(f, "predict --provider={}", args.provider)
-            }
+            Command::Predict(args) => match &args.provider {
+                Some(provider) => write!(f, "predict --provider={}", provider),
+                None => write!(f, "predict"),
+            },
             Command::ParseOutput => write!(f, "parse-output"),
-            Command::Score(args) => {
-                write!(f, "score --provider={}", args.provider)
-            }
+            Command::Score(args) => match &args.provider {
+                Some(provider) => write!(f, "score --provider={}", provider),
+                None => write!(f, "score"),
+            },
             Command::Distill => write!(f, "distill"),
-            Command::Eval(args) => {
-                write!(f, "eval --provider={}", args.provider)
-            }
+            Command::Eval(args) => match &args.provider {
+                Some(provider) => write!(f, "eval --provider={}", provider),
+                None => write!(f, "eval"),
+            },
             Command::Synthesize(args) => {
                 write!(f, "synthesize --repos {}", args.repos.join(" "))
             }
@@ -189,8 +192,8 @@ struct FormatPromptArgs {
 
 #[derive(Debug, Args, Clone)]
 struct PredictArgs {
-    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
-    provider: PredictionProvider,
+    #[clap(long, short('p'))]
+    provider: Option<PredictionProvider>,
     #[clap(long, default_value_t = 1)]
     repetitions: usize,
 }
@@ -519,7 +522,7 @@ fn main() {
 
                 match &command {
                     Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
-                        predict::sync_batches(&args.provider).await?;
+                        predict::sync_batches(args.provider.as_ref()).await?;
                     }
                     _ => (),
                 }
@@ -698,7 +701,7 @@ fn main() {
 
                 match &command {
                     Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
-                        predict::sync_batches(&args.provider).await?;
+                        predict::sync_batches(args.provider.as_ref()).await?;
                     }
                     _ => (),
                 }

crates/edit_prediction_cli/src/predict.rs

@@ -31,21 +31,30 @@ pub async fn run_prediction(
     example_progress: &ExampleProgress,
     mut cx: AsyncApp,
 ) -> anyhow::Result<()> {
-    let provider = args.provider;
     let repetition_count = args.repetitions;
 
     if let Some(existing_prediction) = example.predictions.first() {
-        if existing_prediction.provider == provider {
-            return Ok(());
-        } else {
-            example.predictions.clear();
+        let has_prediction = existing_prediction.actual_patch.is_some()
+            || !existing_prediction.actual_output.is_empty();
+        if has_prediction {
+            match args.provider {
+                None => return Ok(()),
+                Some(provider) if existing_prediction.provider == provider => return Ok(()),
+                Some(_) => example.predictions.clear(),
+            }
         }
     }
 
+    let Some(provider) = args.provider else {
+        anyhow::bail!(
+            "No existing predictions found. Use --provider to specify which model to use for prediction."
+        );
+    };
+
     run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 
     if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) =
-        args.provider
+        provider
     {
         let _step_progress = example_progress.start(Step::Predict);
 
@@ -304,9 +313,9 @@ async fn predict_anthropic(
     Ok(())
 }
 
-pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
+pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
     match provider {
-        PredictionProvider::Teacher(..) => {
+        Some(PredictionProvider::Teacher(..)) => {
             let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                 AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                     .expect("Failed to create Anthropic client")

crates/edit_prediction_cli/src/score.rs

@@ -24,7 +24,11 @@ pub async fn run_scoring(
     let progress = example_progress.start(Step::Score);
 
     progress.set_substatus("applying patches");
-    let original_text = &example.prompt_inputs.as_ref().unwrap().content;
+    let original_text = &example
+        .prompt_inputs
+        .as_ref()
+        .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
+        .content;
     let expected_texts: Vec<String> = example
         .spec
         .expected_patches