@@ -160,17 +160,20 @@ impl Display for Command {
Command::FormatPrompt(args) => {
write!(f, "format-prompt --provider={}", args.provider)
}
- Command::Predict(args) => {
- write!(f, "predict --provider={}", args.provider)
- }
+ Command::Predict(args) => match &args.provider {
+ Some(provider) => write!(f, "predict --provider={}", provider),
+ None => write!(f, "predict"),
+ },
Command::ParseOutput => write!(f, "parse-output"),
- Command::Score(args) => {
- write!(f, "score --provider={}", args.provider)
- }
+ Command::Score(args) => match &args.provider {
+ Some(provider) => write!(f, "score --provider={}", provider),
+ None => write!(f, "score"),
+ },
Command::Distill => write!(f, "distill"),
- Command::Eval(args) => {
- write!(f, "eval --provider={}", args.provider)
- }
+ Command::Eval(args) => match &args.provider {
+ Some(provider) => write!(f, "eval --provider={}", provider),
+ None => write!(f, "eval"),
+ },
Command::Synthesize(args) => {
write!(f, "synthesize --repos {}", args.repos.join(" "))
}
@@ -189,8 +192,8 @@ struct FormatPromptArgs {
#[derive(Debug, Args, Clone)]
struct PredictArgs {
- #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
- provider: PredictionProvider,
+ #[clap(long, short('p'))]
+ provider: Option<PredictionProvider>,
#[clap(long, default_value_t = 1)]
repetitions: usize,
}
@@ -519,7 +522,7 @@ fn main() {
match &command {
Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
- predict::sync_batches(&args.provider).await?;
+ predict::sync_batches(args.provider.as_ref()).await?;
}
_ => (),
}
@@ -698,7 +701,7 @@ fn main() {
match &command {
Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
- predict::sync_batches(&args.provider).await?;
+ predict::sync_batches(args.provider.as_ref()).await?;
}
_ => (),
}
@@ -31,21 +31,30 @@ pub async fn run_prediction(
example_progress: &ExampleProgress,
mut cx: AsyncApp,
) -> anyhow::Result<()> {
- let provider = args.provider;
let repetition_count = args.repetitions;
if let Some(existing_prediction) = example.predictions.first() {
- if existing_prediction.provider == provider {
- return Ok(());
- } else {
- example.predictions.clear();
+ let has_prediction = existing_prediction.actual_patch.is_some()
+ || !existing_prediction.actual_output.is_empty();
+ if has_prediction {
+ match args.provider {
+ None => return Ok(()),
+ Some(provider) if existing_prediction.provider == provider => return Ok(()),
+ Some(_) => example.predictions.clear(),
+ }
}
}
+ let Some(provider) = args.provider else {
+ anyhow::bail!(
+ "No existing predictions found. Use --provider to specify which model to use for prediction."
+ );
+ };
+
run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) =
- args.provider
+ provider
{
let _step_progress = example_progress.start(Step::Predict);
@@ -304,9 +313,9 @@ async fn predict_anthropic(
Ok(())
}
-pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
+pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
match provider {
- PredictionProvider::Teacher(..) => {
+ Some(PredictionProvider::Teacher(..)) => {
let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
.expect("Failed to create Anthropic client")