From 60f4aa333be1b2661c1a60d5750701122c6c5d8c Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Fri, 12 Dec 2025 14:15:58 -0300 Subject: [PATCH] edit prediction cli: Improve error handling (#44718) We were panicking whenever something went wrong with an example in the CLI. This can be very disruptive when running many examples, and e.g. a single request fails. Instead, if running more than one example, errors will now be logged alongside instructions to explore and re-run the example by itself. CleanShot 2025-12-12 at 13 32 04@2x You can still opt in to stop as soon as an error occurs with the new `--failfast` argument. Release Notes: - N/A --- crates/edit_prediction_cli/src/distill.rs | 16 +- .../edit_prediction_cli/src/format_prompt.rs | 67 ++--- .../edit_prediction_cli/src/load_project.rs | 248 ++++++++---------- crates/edit_prediction_cli/src/main.rs | 243 ++++++++++++----- crates/edit_prediction_cli/src/paths.rs | 2 + crates/edit_prediction_cli/src/predict.rs | 120 ++++----- crates/edit_prediction_cli/src/progress.rs | 58 +++- .../src/retrieve_context.rs | 67 +++-- crates/edit_prediction_cli/src/score.rs | 5 +- 9 files changed, 478 insertions(+), 348 deletions(-) diff --git a/crates/edit_prediction_cli/src/distill.rs b/crates/edit_prediction_cli/src/distill.rs index 495b3cd88cbd05ad1917517580b913aacf4fb107..085c5f744a1837cbb97f4c33b6f89b6031088e2b 100644 --- a/crates/edit_prediction_cli/src/distill.rs +++ b/crates/edit_prediction_cli/src/distill.rs @@ -1,14 +1,22 @@ +use anyhow::{Result, anyhow}; use std::mem; use crate::example::Example; -pub async fn run_distill(example: &mut Example) { - let [prediction]: [_; 1] = mem::take(&mut example.predictions) - .try_into() - .expect("Run predict first with a single repetition"); +pub async fn run_distill(example: &mut Example) -> Result<()> { + let [prediction]: [_; 1] = + mem::take(&mut example.predictions) + .try_into() + .map_err(|preds: Vec<_>| { + anyhow!( + "Example has {} predictions, but it should have exactly 
one", + preds.len() + ) + })?; example.expected_patch = prediction.actual_patch; example.prompt = None; example.predictions = Vec::new(); example.score = Vec::new(); + Ok(()) } diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index 017e11a54c77e06bde7b74ed3f924692e33cd480..f8fd9b2023a84abcf59bcb5ba54d2d228a0c6484 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -6,6 +6,7 @@ use crate::{ progress::{Progress, Step}, retrieve_context::run_context_retrieval, }; +use anyhow::{Context as _, Result, ensure}; use edit_prediction::{ EditPredictionStore, zeta2::{zeta2_output_for_patch, zeta2_prompt_input}, @@ -19,8 +20,8 @@ pub async fn run_format_prompt( prompt_format: PromptFormat, app_state: Arc, mut cx: AsyncApp, -) { - run_context_retrieval(example, app_state.clone(), cx.clone()).await; +) -> Result<()> { + run_context_retrieval(example, app_state.clone(), cx.clone()).await?; let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name); @@ -34,29 +35,33 @@ pub async fn run_format_prompt( }); } PromptFormat::Zeta2 => { - run_load_project(example, app_state, cx.clone()).await; + run_load_project(example, app_state, cx.clone()).await?; - let ep_store = cx - .update(|cx| EditPredictionStore::try_global(cx).unwrap()) - .unwrap(); + let ep_store = cx.update(|cx| { + EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized") + })??; - let state = example.state.as_ref().unwrap(); - let snapshot = state - .buffer - .read_with(&cx, |buffer, _| buffer.snapshot()) - .unwrap(); + let state = example.state.as_ref().context("state must be set")?; + let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?; let project = state.project.clone(); - let (_, input) = ep_store - .update(&mut cx, |ep_store, _cx| { - zeta2_prompt_input( - &snapshot, - example.context.as_ref().unwrap().files.clone(), - 
ep_store.edit_history_for_project(&project), - example.cursor_path.clone(), - example.buffer.as_ref().unwrap().cursor_offset, - ) - }) - .unwrap(); + let (_, input) = ep_store.update(&mut cx, |ep_store, _cx| { + anyhow::Ok(zeta2_prompt_input( + &snapshot, + example + .context + .as_ref() + .context("context must be set")? + .files + .clone(), + ep_store.edit_history_for_project(&project), + example.cursor_path.clone(), + example + .buffer + .as_ref() + .context("buffer must be set")? + .cursor_offset, + )) + })??; let prompt = format_zeta_prompt(&input); let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone()); example.prompt = Some(ExamplePrompt { @@ -66,6 +71,7 @@ pub async fn run_format_prompt( }); } }; + Ok(()) } pub struct TeacherPrompt; @@ -91,7 +97,7 @@ impl TeacherPrompt { prompt } - pub fn parse(example: &Example, response: &str) -> String { + pub fn parse(example: &Example, response: &str) -> Result { // Ideally, we should always be able to find cursor position in the retrieved context. // In reality, sometimes we don't find it for these reasons: // 1. `example.cursor_position` contains _more_ context than included in the retrieved context @@ -102,7 +108,7 @@ impl TeacherPrompt { let cursor_file = &example .buffer .as_ref() - .expect("`buffer` should be filled in in the context collection step") + .context("`buffer` should be filled in in the context collection step")? 
.content; // Extract updated (new) editable region from the model response @@ -111,9 +117,10 @@ impl TeacherPrompt { // Reconstruct old editable region we sent to the model let old_editable_region = Self::format_editable_region(example); let old_editable_region = Self::extract_editable_region(&old_editable_region); - if !cursor_file.contains(&old_editable_region) { - panic!("Something's wrong: editable_region is not found in the cursor file") - } + ensure!( + cursor_file.contains(&old_editable_region), + "Something's wrong: editable_region is not found in the cursor file" + ); // Apply editable region to a larger context and compute diff. // This is needed to get a better context lines around the editable region @@ -128,7 +135,7 @@ impl TeacherPrompt { diff = diff, }; - diff + Ok(diff) } fn format_edit_history(edit_history: &str) -> String { @@ -152,9 +159,7 @@ impl TeacherPrompt { } fn format_context(example: &Example) -> String { - if example.context.is_none() { - panic!("Missing context retriever step"); - } + assert!(example.context.is_some(), "Missing context retriever step"); let mut prompt = String::new(); zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files); diff --git a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs index 4d98ae9f3b85f4e6253d9ead4d846ed3d9deee89..4517e6ccbebca76a7ba8ce73322d6467000fc189 100644 --- a/crates/edit_prediction_cli/src/load_project.rs +++ b/crates/edit_prediction_cli/src/load_project.rs @@ -4,7 +4,7 @@ use crate::{ paths::{REPOS_DIR, WORKTREES_DIR}, progress::{InfoStyle, Progress, Step, StepProgress}, }; -use anyhow::{Result, anyhow}; +use anyhow::{Context as _, Result}; use collections::HashMap; use edit_prediction::EditPredictionStore; use edit_prediction::udiff::OpenedBuffers; @@ -25,38 +25,38 @@ use std::{ use util::{paths::PathStyle, rel_path::RelPath}; use zeta_prompt::CURSOR_MARKER; -pub async fn run_load_project(example: &mut Example, 
app_state: Arc, mut cx: AsyncApp) { +pub async fn run_load_project( + example: &mut Example, + app_state: Arc, + mut cx: AsyncApp, +) -> Result<()> { if example.state.is_some() { - return; + return Ok(()); } let progress = Progress::global().start(Step::LoadProject, &example.name); - let project = setup_project(example, &app_state, &progress, &mut cx).await; - - let _open_buffers = apply_edit_history(example, &project, &mut cx) - .await - .unwrap(); - - let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await; - let (example_buffer, language_name) = buffer - .read_with(&cx, |buffer, _cx| { - let cursor_point = cursor_position.to_point(&buffer); - let language_name = buffer - .language() - .map(|l| l.name().to_string()) - .unwrap_or_else(|| "Unknown".to_string()); - ( - ExampleBuffer { - content: buffer.text(), - cursor_row: cursor_point.row, - cursor_column: cursor_point.column, - cursor_offset: cursor_position.to_offset(&buffer), - }, - language_name, - ) - }) - .unwrap(); + let project = setup_project(example, &app_state, &progress, &mut cx).await?; + + let _open_buffers = apply_edit_history(example, &project, &mut cx).await?; + + let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await?; + let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| { + let cursor_point = cursor_position.to_point(&buffer); + let language_name = buffer + .language() + .map(|l| l.name().to_string()) + .unwrap_or_else(|| "Unknown".to_string()); + ( + ExampleBuffer { + content: buffer.text(), + cursor_row: cursor_point.row, + cursor_column: cursor_point.column, + cursor_offset: cursor_position.to_offset(&buffer), + }, + language_name, + ) + })?; progress.set_info(language_name, InfoStyle::Normal); @@ -67,16 +67,15 @@ pub async fn run_load_project(example: &mut Example, app_state: Arc, cursor_position, _open_buffers, }); + Ok(()) } async fn cursor_position( example: &Example, project: &Entity, cx: &mut AsyncApp, -) -> 
(Entity, Anchor) { - let language_registry = project - .read_with(cx, |project, _| project.languages().clone()) - .unwrap(); +) -> Result<(Entity, Anchor)> { + let language_registry = project.read_with(cx, |project, _| project.languages().clone())?; let result = language_registry .load_language_for_file_path(&example.cursor_path) .await; @@ -84,17 +83,18 @@ async fn cursor_position( if let Err(error) = result && !error.is::() { - panic!("Failed to load language for file path: {}", error); + return Err(error); } - let worktree = project - .read_with(cx, |project, cx| { - project.visible_worktrees(cx).next().unwrap() - }) - .unwrap(); + let worktree = project.read_with(cx, |project, cx| { + project + .visible_worktrees(cx) + .next() + .context("No visible worktrees") + })??; let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix) - .unwrap() + .context("Failed to create RelPath")? .into_arc(); let cursor_buffer = project .update(cx, |project, cx| { @@ -105,15 +105,12 @@ async fn cursor_position( }, cx, ) - }) - .unwrap() - .await - .unwrap(); + })? 
+ .await?; let cursor_offset_within_excerpt = example .cursor_position .find(CURSOR_MARKER) - .ok_or_else(|| anyhow!("missing cursor marker")) - .unwrap(); + .context("missing cursor marker")?; let mut cursor_excerpt = example.cursor_position.clone(); cursor_excerpt.replace_range( cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()), @@ -123,22 +120,21 @@ async fn cursor_position( let text = buffer.text(); let mut matches = text.match_indices(&cursor_excerpt); - let (excerpt_offset, _) = matches.next().unwrap_or_else(|| { - panic!( + let (excerpt_offset, _) = matches.next().with_context(|| { + format!( "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.", example.name - ); - }); - assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name); - excerpt_offset - }).unwrap(); + ) + })?; + anyhow::ensure!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name); + Ok(excerpt_offset) + })??; let cursor_offset = excerpt_offset + cursor_offset_within_excerpt; - let cursor_anchor = cursor_buffer - .read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset)) - .unwrap(); + let cursor_anchor = + cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?; - (cursor_buffer, cursor_anchor) + Ok((cursor_buffer, cursor_anchor)) } async fn setup_project( @@ -146,67 +142,54 @@ async fn setup_project( app_state: &Arc, step_progress: &StepProgress, cx: &mut AsyncApp, -) -> Entity { +) -> Result> { let ep_store = cx - .update(|cx| EditPredictionStore::try_global(cx).unwrap()) - .unwrap(); + .update(|cx| EditPredictionStore::try_global(cx))? 
+ .context("Store should be initialized at init")?; - let worktree_path = setup_worktree(example, step_progress).await; + let worktree_path = setup_worktree(example, step_progress).await?; if let Some(project) = app_state.project_cache.get(&example.repository_url) { - ep_store - .update(cx, |ep_store, _| { - ep_store.clear_history_for_project(&project); - }) - .unwrap(); - let buffer_store = project - .read_with(cx, |project, _| project.buffer_store().clone()) - .unwrap(); - let buffers = buffer_store - .read_with(cx, |buffer_store, _| { - buffer_store.buffers().collect::>() - }) - .unwrap(); + ep_store.update(cx, |ep_store, _| { + ep_store.clear_history_for_project(&project); + })?; + let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?; + let buffers = buffer_store.read_with(cx, |buffer_store, _| { + buffer_store.buffers().collect::>() + })?; for buffer in buffers { buffer - .update(cx, |buffer, cx| buffer.reload(cx)) - .unwrap() + .update(cx, |buffer, cx| buffer.reload(cx))? .await .ok(); } - return project; + return Ok(project); } - let project = cx - .update(|cx| { - Project::local( - app_state.client.clone(), - app_state.node_runtime.clone(), - app_state.user_store.clone(), - app_state.languages.clone(), - app_state.fs.clone(), - None, - cx, - ) - }) - .unwrap(); + let project = cx.update(|cx| { + Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + cx, + ) + })?; project .update(cx, |project, cx| { project.disable_worktree_scanner(cx); project.create_worktree(&worktree_path, true, cx) - }) - .unwrap() - .await - .unwrap(); + })? 
+ .await?; app_state .project_cache .insert(example.repository_url.clone(), project.clone()); - let buffer_store = project - .read_with(cx, |project, _| project.buffer_store().clone()) - .unwrap(); + let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?; cx.subscribe(&buffer_store, { let project = project.clone(); move |_, event, cx| match event { @@ -215,15 +198,14 @@ async fn setup_project( } _ => {} } - }) - .unwrap() + })? .detach(); - project + Ok(project) } -async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> PathBuf { - let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name"); +async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result { + let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?; let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref()); let worktree_path = WORKTREES_DIR .join(repo_owner.as_ref()) @@ -232,14 +214,13 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path if !repo_dir.is_dir() { step_progress.set_substatus(format!("cloning {}", repo_name)); - fs::create_dir_all(&repo_dir).unwrap(); - run_git(&repo_dir, &["init"]).await.unwrap(); + fs::create_dir_all(&repo_dir)?; + run_git(&repo_dir, &["init"]).await?; run_git( &repo_dir, &["remote", "add", "origin", &example.repository_url], ) - .await - .unwrap(); + .await?; } // Resolve the example to a revision, fetching it if needed. @@ -259,34 +240,25 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path .await .is_err() { - run_git(&repo_dir, &["fetch", "origin"]).await.unwrap(); + run_git(&repo_dir, &["fetch", "origin"]).await?; } - let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]) - .await - .unwrap(); + let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?; revision }; // Create the worktree for this example if needed. 
step_progress.set_substatus("preparing worktree"); if worktree_path.is_dir() { - run_git(&worktree_path, &["clean", "--force", "-d"]) - .await - .unwrap(); - run_git(&worktree_path, &["reset", "--hard", "HEAD"]) - .await - .unwrap(); - run_git(&worktree_path, &["checkout", revision.as_str()]) - .await - .unwrap(); + run_git(&worktree_path, &["clean", "--force", "-d"]).await?; + run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?; + run_git(&worktree_path, &["checkout", revision.as_str()]).await?; } else { let worktree_path_string = worktree_path.to_string_lossy(); run_git( &repo_dir, &["branch", "-f", &example.name, revision.as_str()], ) - .await - .unwrap(); + .await?; run_git( &repo_dir, &[ @@ -297,8 +269,7 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path &example.name, ], ) - .await - .unwrap(); + .await?; } drop(repo_lock); @@ -309,30 +280,25 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path .current_dir(&worktree_path) .args(&["apply", "-"]) .stdin(std::process::Stdio::piped()) - .spawn() - .unwrap(); - - let mut stdin = apply_process.stdin.take().unwrap(); - stdin - .write_all(example.uncommitted_diff.as_bytes()) - .await - .unwrap(); - stdin.close().await.unwrap(); + .spawn()?; + + let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?; + stdin.write_all(example.uncommitted_diff.as_bytes()).await?; + stdin.close().await?; drop(stdin); - let apply_result = apply_process.output().await.unwrap(); - if !apply_result.status.success() { - panic!( - "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}", - apply_result.status, - String::from_utf8_lossy(&apply_result.stderr), - String::from_utf8_lossy(&apply_result.stdout), - ); - } + let apply_result = apply_process.output().await?; + anyhow::ensure!( + apply_result.status.success(), + "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}", + apply_result.status, 
+ String::from_utf8_lossy(&apply_result.stderr), + String::from_utf8_lossy(&apply_result.stdout), + ); } step_progress.clear_substatus(); - worktree_path + Ok(worktree_path) } async fn apply_edit_history( diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 075f8862e6f86276a0df550c6d27f8c15a5d1293..3b185103390016f60fc4f621f280d16a58c363e5 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -16,12 +16,14 @@ use edit_prediction::EditPredictionStore; use gpui::Application; use reqwest_client::ReqwestClient; use serde::{Deserialize, Serialize}; +use std::fmt::Display; use std::{path::PathBuf, sync::Arc}; use crate::distill::run_distill; use crate::example::{group_examples_by_repo, read_examples, write_examples}; use crate::format_prompt::run_format_prompt; use crate::load_project::run_load_project; +use crate::paths::FAILED_EXAMPLES_DIR; use crate::predict::run_prediction; use crate::progress::Progress; use crate::retrieve_context::run_context_retrieval; @@ -42,6 +44,8 @@ struct EpArgs { output: Option, #[arg(long, short, global = true)] in_place: bool, + #[arg(long, short, global = true)] + failfast: bool, } #[derive(Subcommand, Debug)] @@ -67,6 +71,58 @@ enum Command { Clean, } +impl Display for Command { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Command::ParseExample => write!(f, "parse-example"), + Command::LoadProject => write!(f, "load-project"), + Command::Context => write!(f, "context"), + Command::FormatPrompt(format_prompt_args) => write!( + f, + "format-prompt --prompt-format={}", + format_prompt_args + .prompt_format + .to_possible_value() + .unwrap() + .get_name() + ), + Command::Predict(predict_args) => { + write!( + f, + "predict --provider={:?}", + predict_args + .provider + .to_possible_value() + .unwrap() + .get_name() + ) + } + Command::Score(predict_args) => { + write!( + f, + "score --provider={:?}", + 
predict_args + .provider + .to_possible_value() + .unwrap() + .get_name() + ) + } + Command::Distill => write!(f, "distill"), + Command::Eval(predict_args) => write!( + f, + "eval --provider={:?}", + predict_args + .provider + .to_possible_value() + .unwrap() + .get_name() + ), + Command::Clean => write!(f, "clean"), + } + } +} + #[derive(Debug, Args)] struct FormatPromptArgs { #[clap(long)] @@ -145,71 +201,140 @@ fn main() { EditPredictionStore::global(&app_state.client, &app_state.user_store, cx); cx.spawn(async move |cx| { - if let Command::Predict(args) = &command { - predict::sync_batches(&args.provider).await - }; - - let total_examples = examples.len(); - Progress::global().set_total_examples(total_examples); - - let mut grouped_examples = group_examples_by_repo(&mut examples); - let example_batches = grouped_examples.chunks_mut(args.max_parallelism); - - for example_batch in example_batches { - let futures = example_batch.into_iter().map(|repo_examples| async { - for example in repo_examples.iter_mut() { - match &command { - Command::ParseExample => {} - Command::LoadProject => { - run_load_project(example, app_state.clone(), cx.clone()).await; - } - Command::Context => { - run_context_retrieval(example, app_state.clone(), cx.clone()).await; - } - Command::FormatPrompt(args) => { - run_format_prompt( - example, - args.prompt_format, - app_state.clone(), - cx.clone(), - ) - .await; - } - Command::Predict(args) => { - run_prediction( - example, - Some(args.provider), - args.repetitions, - app_state.clone(), - cx.clone(), - ) - .await; - } - Command::Distill => { - run_distill(example).await; - } - Command::Score(args) | Command::Eval(args) => { - run_scoring(example, &args, app_state.clone(), cx.clone()).await; + let result = async { + if let Command::Predict(args) = &command { + predict::sync_batches(&args.provider).await?; + } + + let total_examples = examples.len(); + Progress::global().set_total_examples(total_examples); + + let mut grouped_examples = 
group_examples_by_repo(&mut examples); + let example_batches = grouped_examples.chunks_mut(args.max_parallelism); + + for example_batch in example_batches { + let futures = example_batch.into_iter().map(|repo_examples| async { + for example in repo_examples.iter_mut() { + let result = async { + match &command { + Command::ParseExample => {} + Command::LoadProject => { + run_load_project(example, app_state.clone(), cx.clone()) + .await?; + } + Command::Context => { + run_context_retrieval( + example, + app_state.clone(), + cx.clone(), + ) + .await?; + } + Command::FormatPrompt(args) => { + run_format_prompt( + example, + args.prompt_format, + app_state.clone(), + cx.clone(), + ) + .await?; + } + Command::Predict(args) => { + run_prediction( + example, + Some(args.provider), + args.repetitions, + app_state.clone(), + cx.clone(), + ) + .await?; + } + Command::Distill => { + run_distill(example).await?; + } + Command::Score(args) | Command::Eval(args) => { + run_scoring(example, &args, app_state.clone(), cx.clone()) + .await?; + } + Command::Clean => { + unreachable!() + } + } + anyhow::Ok(()) } - Command::Clean => { - unreachable!() + .await; + + if let Err(e) = result { + Progress::global().increment_failed(); + let failed_example_path = + FAILED_EXAMPLES_DIR.join(format!("{}.json", example.name)); + app_state + .fs + .write( + &failed_example_path, + &serde_json::to_vec_pretty(&example).unwrap(), + ) + .await + .unwrap(); + let err_path = + FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example.name)); + app_state + .fs + .write(&err_path, e.to_string().as_bytes()) + .await + .unwrap(); + + let msg = format!( + indoc::indoc! 
{" + While processing {}: + + {:?} + + Written to: \x1b[36m{}\x1b[0m + + Explore this example data with: + fx \x1b[36m{}\x1b[0m + + Re-run this example with: + cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m + "}, + example.name, + e, + err_path.display(), + failed_example_path.display(), + command, + failed_example_path.display(), + ); + if args.failfast || total_examples == 1 { + Progress::global().finalize(); + panic!("{}", msg); + } else { + log::error!("{}", msg); + } } } - } - }); - futures::future::join_all(futures).await; - } - Progress::global().clear(); + }); + futures::future::join_all(futures).await; + } + Progress::global().finalize(); - if args.output.is_some() || !matches!(command, Command::Eval(_)) { - write_examples(&examples, output.as_ref()); + if args.output.is_some() || !matches!(command, Command::Eval(_)) { + write_examples(&examples, output.as_ref()); + } + + match &command { + Command::Predict(args) => predict::sync_batches(&args.provider).await?, + Command::Eval(_) => score::print_report(&examples), + _ => (), + }; + + anyhow::Ok(()) } + .await; - match &command { - Command::Predict(args) => predict::sync_batches(&args.provider).await, - Command::Eval(_) => score::print_report(&examples), - _ => (), - }; + if let Err(e) = result { + panic!("Fatal error: {:?}", e); + } let _ = cx.update(|cx| cx.quit()); }) diff --git a/crates/edit_prediction_cli/src/paths.rs b/crates/edit_prediction_cli/src/paths.rs index 0f470fae556b6d61739ab77083d7edbedf77ef89..e5d420d0e3dbeda9c50b8e5a3683238149dbc604 100644 --- a/crates/edit_prediction_cli/src/paths.rs +++ b/crates/edit_prediction_cli/src/paths.rs @@ -18,6 +18,8 @@ pub static RUN_DIR: LazyLock = LazyLock::new(|| { }); pub static LATEST_EXAMPLE_RUN_DIR: LazyLock = LazyLock::new(|| DATA_DIR.join("latest")); pub static LLM_CACHE_DB: LazyLock = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite")); +pub static FAILED_EXAMPLES_DIR: LazyLock = + LazyLock::new(|| ensure_dir(&RUN_DIR.join("failed"))); fn 
ensure_dir(path: &Path) -> PathBuf { std::fs::create_dir_all(path).expect("Failed to create directory"); diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 3f690266e3165b2d52f642457e7aebf959a40a03..3e6104e3a8afc3adc609df094a70fc34138c1619 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -9,6 +9,7 @@ use crate::{ progress::{InfoStyle, Progress, Step}, retrieve_context::run_context_retrieval, }; +use anyhow::Context as _; use edit_prediction::{DebugEvent, EditPredictionStore}; use futures::{FutureExt as _, StreamExt as _, future::Shared}; use gpui::{AppContext as _, AsyncApp, Task}; @@ -26,14 +27,14 @@ pub async fn run_prediction( repetition_count: usize, app_state: Arc, mut cx: AsyncApp, -) { +) -> anyhow::Result<()> { if !example.predictions.is_empty() { - return; + return Ok(()); } - let provider = provider.unwrap(); + let provider = provider.context("provider is required")?; - run_context_retrieval(example, app_state.clone(), cx.clone()).await; + run_context_retrieval(example, app_state.clone(), cx.clone()).await?; if matches!( provider, @@ -42,14 +43,14 @@ pub async fn run_prediction( let _step_progress = Progress::global().start(Step::Predict, &example.name); if example.prompt.is_none() { - run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await; + run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?; } let batched = matches!(provider, PredictionProvider::Teacher); return predict_anthropic(example, repetition_count, batched).await; } - run_load_project(example, app_state.clone(), cx.clone()).await; + run_load_project(example, app_state.clone(), cx.clone()).await?; let _step_progress = Progress::global().start(Step::Predict, &example.name); @@ -62,10 +63,9 @@ pub async fn run_prediction( .get_or_init(|| { let client = app_state.client.clone(); cx.spawn(async move |cx| { - client - 
.sign_in_with_optional_connect(true, cx) - .await - .unwrap(); + if let Err(e) = client.sign_in_with_optional_connect(true, cx).await { + eprintln!("Authentication failed: {}", e); + } }) .shared() }) @@ -73,33 +73,30 @@ pub async fn run_prediction( .await; } - let ep_store = cx - .update(|cx| EditPredictionStore::try_global(cx).unwrap()) - .unwrap(); - - ep_store - .update(&mut cx, |store, _cx| { - let model = match provider { - PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1, - PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2, - PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep, - PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury, - PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => { - unreachable!() - } - }; - store.set_edit_prediction_model(model); - }) - .unwrap(); - let state = example.state.as_ref().unwrap(); + let ep_store = cx.update(|cx| { + EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized") + })??; + + ep_store.update(&mut cx, |store, _cx| { + let model = match provider { + PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1, + PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2, + PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep, + PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury, + PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => { + unreachable!() + } + }; + store.set_edit_prediction_model(model); + })?; + let state = example.state.as_ref().context("state must be set")?; let run_dir = RUN_DIR.join(&example.name); let updated_example = Arc::new(Mutex::new(example.clone())); let current_run_ix = Arc::new(AtomicUsize::new(0)); - let mut debug_rx = ep_store - .update(&mut cx, |store, cx| store.debug_info(&state.project, cx)) - .unwrap(); + let mut debug_rx = + ep_store.update(&mut cx, 
|store, cx| store.debug_info(&state.project, cx))?; let debug_task = cx.background_spawn({ let updated_example = updated_example.clone(); let current_run_ix = current_run_ix.clone(); @@ -153,14 +150,14 @@ pub async fn run_prediction( run_dir.clone() }; - fs::create_dir_all(&run_dir).unwrap(); + fs::create_dir_all(&run_dir)?; if LATEST_EXAMPLE_RUN_DIR.is_symlink() { - fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap(); + fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?; } #[cfg(unix)] - std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap(); + std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?; #[cfg(windows)] - std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap(); + std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?; updated_example .lock() @@ -181,10 +178,8 @@ pub async fn run_prediction( cloud_llm_client::PredictEditsRequestTrigger::Cli, cx, ) - }) - .unwrap() - .await - .unwrap(); + })? + .await?; let actual_patch = prediction .and_then(|prediction| { @@ -213,20 +208,23 @@ pub async fn run_prediction( } } - ep_store - .update(&mut cx, |store, _| { - store.remove_project(&state.project); - }) - .unwrap(); - debug_task.await.unwrap(); + ep_store.update(&mut cx, |store, _| { + store.remove_project(&state.project); + })?; + debug_task.await?; *example = Arc::into_inner(updated_example) - .unwrap() + .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))? 
.into_inner() - .unwrap(); + .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?; + Ok(()) } -async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) { +async fn predict_anthropic( + example: &mut Example, + _repetition_count: usize, + batched: bool, +) -> anyhow::Result<()> { let llm_model_name = "claude-sonnet-4-5"; let max_tokens = 16384; let llm_client = if batched { @@ -234,12 +232,9 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc } else { AnthropicClient::plain() }; - let llm_client = llm_client.expect("Failed to create LLM client"); + let llm_client = llm_client.context("Failed to create LLM client")?; - let prompt = example - .prompt - .as_ref() - .unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name)); + let prompt = example.prompt.as_ref().context("Prompt is required")?; let messages = vec![anthropic::Message { role: anthropic::Role::User, @@ -251,11 +246,10 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc let Some(response) = llm_client .generate(llm_model_name, max_tokens, messages) - .await - .unwrap() + .await? 
else { // Request stashed for batched processing - return; + return Ok(()); }; let actual_output = response @@ -268,7 +262,7 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc .collect::>() .join("\n"); - let actual_patch = TeacherPrompt::parse(example, &actual_output); + let actual_patch = TeacherPrompt::parse(example, &actual_output)?; let prediction = ExamplePrediction { actual_patch, @@ -277,19 +271,21 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc }; example.predictions.push(prediction); + Ok(()) } -pub async fn sync_batches(provider: &PredictionProvider) { +pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> { match provider { PredictionProvider::Teacher => { let cache_path = crate::paths::LLM_CACHE_DB.as_ref(); let llm_client = - AnthropicClient::batch(cache_path).expect("Failed to create LLM client"); + AnthropicClient::batch(cache_path).context("Failed to create LLM client")?; llm_client .sync_batches() .await - .expect("Failed to sync batches"); + .context("Failed to sync batches")?; } _ => (), - } + }; + Ok(()) } diff --git a/crates/edit_prediction_cli/src/progress.rs b/crates/edit_prediction_cli/src/progress.rs index 8195485d70c70c0cbfb38e2de83a055598d5e4e5..ddc710f202cc98e5932c234cb6bebcc93b28171c 100644 --- a/crates/edit_prediction_cli/src/progress.rs +++ b/crates/edit_prediction_cli/src/progress.rs @@ -20,6 +20,7 @@ struct ProgressInner { max_example_name_len: usize, status_lines_displayed: usize, total_examples: usize, + failed_examples: usize, last_line_is_logging: bool, } @@ -78,7 +79,7 @@ impl Step { static GLOBAL: OnceLock> = OnceLock::new(); static LOGGER: ProgressLogger = ProgressLogger; -const RIGHT_MARGIN: usize = 4; +const MARGIN: usize = 4; const MAX_STATUS_LINES: usize = 10; impl Progress { @@ -95,6 +96,7 @@ impl Progress { max_example_name_len: 0, status_lines_displayed: 0, total_examples: 0, + failed_examples: 0, last_line_is_logging: false, 
}), }); @@ -110,6 +112,11 @@ impl Progress { inner.total_examples = total; } + pub fn increment_failed(&self) { + let mut inner = self.inner.lock().unwrap(); + inner.failed_examples += 1; + } + /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption. /// This should be used for any output that needs to appear above the status lines. fn log(&self, message: &str) { @@ -119,7 +126,7 @@ impl Progress { if !inner.last_line_is_logging { let reset = "\x1b[0m"; let dim = "\x1b[2m"; - let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN)); + let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN)); eprintln!("{dim}{divider}{reset}"); inner.last_line_is_logging = true; } @@ -180,7 +187,7 @@ impl Progress { if inner.last_line_is_logging { let reset = "\x1b[0m"; let dim = "\x1b[2m"; - let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN)); + let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN)); eprintln!("{dim}{divider}{reset}"); inner.last_line_is_logging = false; } @@ -229,7 +236,7 @@ impl Progress { let duration_with_margin = format!("{duration} "); let padding_needed = inner .terminal_width - .saturating_sub(RIGHT_MARGIN) + .saturating_sub(MARGIN) .saturating_sub(duration_with_margin.len()) .saturating_sub(strip_ansi_len(&prefix)); let padding = " ".repeat(padding_needed); @@ -263,20 +270,33 @@ impl Progress { // Build the done/in-progress/total label let done_count = inner.completed.len(); let in_progress_count = inner.in_progress.len(); + let failed_count = inner.failed_examples; + + let failed_label = if failed_count > 0 { + format!(" {} failed ", failed_count) + } else { + String::new() + }; + let range_label = format!( " {}/{}/{} ", done_count, in_progress_count, inner.total_examples ); - // Print a divider line with range label aligned with timestamps + // Print a divider line with failed count on left, range label on right + let failed_visible_len = 
strip_ansi_len(&failed_label); let range_visible_len = range_label.len(); - let left_divider_len = inner + let middle_divider_len = inner .terminal_width - .saturating_sub(RIGHT_MARGIN) + .saturating_sub(MARGIN * 2) + .saturating_sub(failed_visible_len) .saturating_sub(range_visible_len); - let left_divider = "─".repeat(left_divider_len); - let right_divider = "─".repeat(RIGHT_MARGIN); - eprintln!("{dim}{left_divider}{reset}{range_label}{dim}{right_divider}{reset}"); + let left_divider = "─".repeat(MARGIN); + let middle_divider = "─".repeat(middle_divider_len); + let right_divider = "─".repeat(MARGIN); + eprintln!( + "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}" + ); let mut tasks: Vec<_> = inner.in_progress.iter().collect(); tasks.sort_by_key(|(name, _)| *name); @@ -304,7 +324,7 @@ impl Progress { let duration_with_margin = format!("{elapsed} "); let padding_needed = inner .terminal_width - .saturating_sub(RIGHT_MARGIN) + .saturating_sub(MARGIN) .saturating_sub(duration_with_margin.len()) .saturating_sub(strip_ansi_len(&prefix)); let padding = " ".repeat(padding_needed); @@ -324,9 +344,23 @@ impl Progress { let _ = std::io::stderr().flush(); } - pub fn clear(&self) { + pub fn finalize(&self) { let mut inner = self.inner.lock().unwrap(); Self::clear_status_lines(&mut inner); + + // Print summary if there were failures + if inner.failed_examples > 0 { + let total_processed = inner.completed.len() + inner.failed_examples; + let percentage = if total_processed > 0 { + inner.failed_examples as f64 / total_processed as f64 * 100.0 + } else { + 0.0 + }; + eprintln!( + "\n{} of {} examples failed ({:.1}%)", + inner.failed_examples, total_processed, percentage + ); + } } } diff --git a/crates/edit_prediction_cli/src/retrieve_context.rs b/crates/edit_prediction_cli/src/retrieve_context.rs index c066cf3caa9ece27144222ef94e3ac72c2285be8..a07c7ec8752ff987b8783c4fa15904078bd5612d 100644 --- 
a/crates/edit_prediction_cli/src/retrieve_context.rs +++ b/crates/edit_prediction_cli/src/retrieve_context.rs @@ -4,6 +4,7 @@ use crate::{ load_project::run_load_project, progress::{InfoStyle, Progress, Step, StepProgress}, }; +use anyhow::Context as _; use collections::HashSet; use edit_prediction::{DebugEvent, EditPredictionStore}; use futures::{FutureExt as _, StreamExt as _, channel::mpsc}; @@ -17,12 +18,12 @@ pub async fn run_context_retrieval( example: &mut Example, app_state: Arc, mut cx: AsyncApp, -) { +) -> anyhow::Result<()> { if example.context.is_some() { - return; + return Ok(()); } - run_load_project(example, app_state.clone(), cx.clone()).await; + run_load_project(example, app_state.clone(), cx.clone()).await?; let step_progress: Arc = Progress::global() .start(Step::Context, &example.name) @@ -31,25 +32,21 @@ pub async fn run_context_retrieval( let state = example.state.as_ref().unwrap(); let project = state.project.clone(); - let _lsp_handle = project - .update(&mut cx, |project, cx| { - project.register_buffer_with_language_servers(&state.buffer, cx) - }) - .unwrap(); - wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await; - - let ep_store = cx - .update(|cx| EditPredictionStore::try_global(cx).unwrap()) - .unwrap(); - - let mut events = ep_store - .update(&mut cx, |store, cx| { - store.register_buffer(&state.buffer, &project, cx); - store.set_use_context(true); - store.refresh_context(&project, &state.buffer, state.cursor_position, cx); - store.debug_info(&project, cx) - }) - .unwrap(); + let _lsp_handle = project.update(&mut cx, |project, cx| { + project.register_buffer_with_language_servers(&state.buffer, cx) + })?; + wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?; + + let ep_store = cx.update(|cx| { + EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized") + })??; + + let mut events = ep_store.update(&mut cx, |store, cx| { + 
store.register_buffer(&state.buffer, &project, cx); + store.set_use_context(true); + store.refresh_context(&project, &state.buffer, state.cursor_position, cx); + store.debug_info(&project, cx) + })?; while let Some(event) = events.next().await { match event { @@ -60,9 +57,8 @@ pub async fn run_context_retrieval( } } - let context_files = ep_store - .update(&mut cx, |store, cx| store.context_for_project(&project, cx)) - .unwrap(); + let context_files = + ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx))?; let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum(); step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal); @@ -70,6 +66,7 @@ pub async fn run_context_retrieval( example.context = Some(ExampleContext { files: context_files, }); + Ok(()) } async fn wait_for_language_servers_to_start( @@ -77,10 +74,8 @@ async fn wait_for_language_servers_to_start( buffer: &Entity, step_progress: &Arc, cx: &mut AsyncApp, -) { - let lsp_store = project - .read_with(cx, |project, _| project.lsp_store()) - .unwrap(); +) -> anyhow::Result<()> { + let lsp_store = project.read_with(cx, |project, _| project.lsp_store())?; let (language_server_ids, mut starting_language_server_ids) = buffer .update(cx, |buffer, cx| { @@ -123,7 +118,7 @@ async fn wait_for_language_servers_to_start( } }, _ = timeout.clone().fuse() => { - panic!("LSP wait timed out after 5 minutes"); + return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes")); } } } @@ -132,8 +127,7 @@ async fn wait_for_language_servers_to_start( if !language_server_ids.is_empty() { project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) - .unwrap() + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? 
.detach(); } @@ -175,10 +169,8 @@ async fn wait_for_language_servers_to_start( ]; project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) - .unwrap() - .await - .unwrap(); + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? + .await?; let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter()); while !pending_language_server_ids.is_empty() { @@ -189,11 +181,12 @@ async fn wait_for_language_servers_to_start( } }, _ = timeout.clone().fuse() => { - panic!("LSP wait timed out after 5 minutes"); + return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes")); } } } drop(subscriptions); step_progress.clear_substatus(); + Ok(()) } diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index b87d8e4df24c8cb12676ed71fe1ea930a841791d..314d19b67259e6a4a0fcff932826325f4366ddde 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -15,7 +15,7 @@ pub async fn run_scoring( args: &PredictArgs, app_state: Arc, cx: AsyncApp, -) { +) -> anyhow::Result<()> { run_prediction( example, Some(args.provider), @@ -23,7 +23,7 @@ pub async fn run_scoring( app_state, cx, ) - .await; + .await?; let _progress = Progress::global().start(Step::Score, &example.name); @@ -43,6 +43,7 @@ pub async fn run_scoring( } example.score = scores; + Ok(()) } fn parse_patch(patch: &str) -> Vec> {