Detailed changes
@@ -1,14 +1,22 @@
+use anyhow::{Result, anyhow};
use std::mem;
use crate::example::Example;
-pub async fn run_distill(example: &mut Example) {
- let [prediction]: [_; 1] = mem::take(&mut example.predictions)
- .try_into()
- .expect("Run predict first with a single repetition");
+pub async fn run_distill(example: &mut Example) -> Result<()> {
+ let [prediction]: [_; 1] =
+ mem::take(&mut example.predictions)
+ .try_into()
+ .map_err(|preds: Vec<_>| {
+ anyhow!(
+ "Example has {} predictions, but it should have exactly one",
+ preds.len()
+ )
+ })?;
example.expected_patch = prediction.actual_patch;
example.prompt = None;
example.predictions = Vec::new();
example.score = Vec::new();
+ Ok(())
}
@@ -6,6 +6,7 @@ use crate::{
progress::{Progress, Step},
retrieve_context::run_context_retrieval,
};
+use anyhow::{Context as _, Result, ensure};
use edit_prediction::{
EditPredictionStore,
zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
@@ -19,8 +20,8 @@ pub async fn run_format_prompt(
prompt_format: PromptFormat,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
-) {
- run_context_retrieval(example, app_state.clone(), cx.clone()).await;
+) -> Result<()> {
+ run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name);
@@ -34,29 +35,33 @@ pub async fn run_format_prompt(
});
}
PromptFormat::Zeta2 => {
- run_load_project(example, app_state, cx.clone()).await;
+ run_load_project(example, app_state, cx.clone()).await?;
- let ep_store = cx
- .update(|cx| EditPredictionStore::try_global(cx).unwrap())
- .unwrap();
+ let ep_store = cx.update(|cx| {
+ EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
+ })??;
- let state = example.state.as_ref().unwrap();
- let snapshot = state
- .buffer
- .read_with(&cx, |buffer, _| buffer.snapshot())
- .unwrap();
+ let state = example.state.as_ref().context("state must be set")?;
+ let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
let project = state.project.clone();
- let (_, input) = ep_store
- .update(&mut cx, |ep_store, _cx| {
- zeta2_prompt_input(
- &snapshot,
- example.context.as_ref().unwrap().files.clone(),
- ep_store.edit_history_for_project(&project),
- example.cursor_path.clone(),
- example.buffer.as_ref().unwrap().cursor_offset,
- )
- })
- .unwrap();
+ let (_, input) = ep_store.update(&mut cx, |ep_store, _cx| {
+ anyhow::Ok(zeta2_prompt_input(
+ &snapshot,
+ example
+ .context
+ .as_ref()
+ .context("context must be set")?
+ .files
+ .clone(),
+ ep_store.edit_history_for_project(&project),
+ example.cursor_path.clone(),
+ example
+ .buffer
+ .as_ref()
+ .context("buffer must be set")?
+ .cursor_offset,
+ ))
+ })??;
let prompt = format_zeta_prompt(&input);
let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
example.prompt = Some(ExamplePrompt {
@@ -66,6 +71,7 @@ pub async fn run_format_prompt(
});
}
};
+ Ok(())
}
pub struct TeacherPrompt;
@@ -91,7 +97,7 @@ impl TeacherPrompt {
prompt
}
- pub fn parse(example: &Example, response: &str) -> String {
+ pub fn parse(example: &Example, response: &str) -> Result<String> {
// Ideally, we should always be able to find cursor position in the retrieved context.
// In reality, sometimes we don't find it for these reasons:
// 1. `example.cursor_position` contains _more_ context than included in the retrieved context
@@ -102,7 +108,7 @@ impl TeacherPrompt {
let cursor_file = &example
.buffer
.as_ref()
- .expect("`buffer` should be filled in in the context collection step")
+ .context("`buffer` should be filled in in the context collection step")?
.content;
// Extract updated (new) editable region from the model response
@@ -111,9 +117,10 @@ impl TeacherPrompt {
// Reconstruct old editable region we sent to the model
let old_editable_region = Self::format_editable_region(example);
let old_editable_region = Self::extract_editable_region(&old_editable_region);
- if !cursor_file.contains(&old_editable_region) {
- panic!("Something's wrong: editable_region is not found in the cursor file")
- }
+ ensure!(
+ cursor_file.contains(&old_editable_region),
+ "Something's wrong: editable_region is not found in the cursor file"
+ );
// Apply editable region to a larger context and compute diff.
// This is needed to get a better context lines around the editable region
@@ -128,7 +135,7 @@ impl TeacherPrompt {
diff = diff,
};
- diff
+ Ok(diff)
}
fn format_edit_history(edit_history: &str) -> String {
@@ -152,9 +159,7 @@ impl TeacherPrompt {
}
fn format_context(example: &Example) -> String {
- if example.context.is_none() {
- panic!("Missing context retriever step");
- }
+ assert!(example.context.is_some(), "Missing context retriever step");
let mut prompt = String::new();
zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
@@ -4,7 +4,7 @@ use crate::{
paths::{REPOS_DIR, WORKTREES_DIR},
progress::{InfoStyle, Progress, Step, StepProgress},
};
-use anyhow::{Result, anyhow};
+use anyhow::{Context as _, Result};
use collections::HashMap;
use edit_prediction::EditPredictionStore;
use edit_prediction::udiff::OpenedBuffers;
@@ -25,38 +25,38 @@ use std::{
use util::{paths::PathStyle, rel_path::RelPath};
use zeta_prompt::CURSOR_MARKER;
-pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
+pub async fn run_load_project(
+ example: &mut Example,
+ app_state: Arc<EpAppState>,
+ mut cx: AsyncApp,
+) -> Result<()> {
if example.state.is_some() {
- return;
+ return Ok(());
}
let progress = Progress::global().start(Step::LoadProject, &example.name);
- let project = setup_project(example, &app_state, &progress, &mut cx).await;
-
- let _open_buffers = apply_edit_history(example, &project, &mut cx)
- .await
- .unwrap();
-
- let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
- let (example_buffer, language_name) = buffer
- .read_with(&cx, |buffer, _cx| {
- let cursor_point = cursor_position.to_point(&buffer);
- let language_name = buffer
- .language()
- .map(|l| l.name().to_string())
- .unwrap_or_else(|| "Unknown".to_string());
- (
- ExampleBuffer {
- content: buffer.text(),
- cursor_row: cursor_point.row,
- cursor_column: cursor_point.column,
- cursor_offset: cursor_position.to_offset(&buffer),
- },
- language_name,
- )
- })
- .unwrap();
+ let project = setup_project(example, &app_state, &progress, &mut cx).await?;
+
+ let _open_buffers = apply_edit_history(example, &project, &mut cx).await?;
+
+ let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await?;
+ let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
+ let cursor_point = cursor_position.to_point(&buffer);
+ let language_name = buffer
+ .language()
+ .map(|l| l.name().to_string())
+ .unwrap_or_else(|| "Unknown".to_string());
+ (
+ ExampleBuffer {
+ content: buffer.text(),
+ cursor_row: cursor_point.row,
+ cursor_column: cursor_point.column,
+ cursor_offset: cursor_position.to_offset(&buffer),
+ },
+ language_name,
+ )
+ })?;
progress.set_info(language_name, InfoStyle::Normal);
@@ -67,16 +67,15 @@ pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>,
cursor_position,
_open_buffers,
});
+ Ok(())
}
async fn cursor_position(
example: &Example,
project: &Entity<Project>,
cx: &mut AsyncApp,
-) -> (Entity<Buffer>, Anchor) {
- let language_registry = project
- .read_with(cx, |project, _| project.languages().clone())
- .unwrap();
+) -> Result<(Entity<Buffer>, Anchor)> {
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
let result = language_registry
.load_language_for_file_path(&example.cursor_path)
.await;
@@ -84,17 +83,18 @@ async fn cursor_position(
if let Err(error) = result
&& !error.is::<LanguageNotFound>()
{
- panic!("Failed to load language for file path: {}", error);
+ return Err(error);
}
- let worktree = project
- .read_with(cx, |project, cx| {
- project.visible_worktrees(cx).next().unwrap()
- })
- .unwrap();
+ let worktree = project.read_with(cx, |project, cx| {
+ project
+ .visible_worktrees(cx)
+ .next()
+ .context("No visible worktrees")
+ })??;
let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
- .unwrap()
+ .context("Failed to create RelPath")?
.into_arc();
let cursor_buffer = project
.update(cx, |project, cx| {
@@ -105,15 +105,12 @@ async fn cursor_position(
},
cx,
)
- })
- .unwrap()
- .await
- .unwrap();
+ })?
+ .await?;
let cursor_offset_within_excerpt = example
.cursor_position
.find(CURSOR_MARKER)
- .ok_or_else(|| anyhow!("missing cursor marker"))
- .unwrap();
+ .context("missing cursor marker")?;
let mut cursor_excerpt = example.cursor_position.clone();
cursor_excerpt.replace_range(
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
@@ -123,22 +120,21 @@ async fn cursor_position(
let text = buffer.text();
let mut matches = text.match_indices(&cursor_excerpt);
- let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
- panic!(
+ let (excerpt_offset, _) = matches.next().with_context(|| {
+ format!(
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
example.name
- );
- });
- assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
- excerpt_offset
- }).unwrap();
+ )
+ })?;
+ anyhow::ensure!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
+ Ok(excerpt_offset)
+ })??;
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
- let cursor_anchor = cursor_buffer
- .read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
- .unwrap();
+ let cursor_anchor =
+ cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
- (cursor_buffer, cursor_anchor)
+ Ok((cursor_buffer, cursor_anchor))
}
async fn setup_project(
@@ -146,67 +142,54 @@ async fn setup_project(
app_state: &Arc<EpAppState>,
step_progress: &StepProgress,
cx: &mut AsyncApp,
-) -> Entity<Project> {
+) -> Result<Entity<Project>> {
let ep_store = cx
- .update(|cx| EditPredictionStore::try_global(cx).unwrap())
- .unwrap();
+ .update(|cx| EditPredictionStore::try_global(cx))?
+ .context("Store should be initialized at init")?;
- let worktree_path = setup_worktree(example, step_progress).await;
+ let worktree_path = setup_worktree(example, step_progress).await?;
if let Some(project) = app_state.project_cache.get(&example.repository_url) {
- ep_store
- .update(cx, |ep_store, _| {
- ep_store.clear_history_for_project(&project);
- })
- .unwrap();
- let buffer_store = project
- .read_with(cx, |project, _| project.buffer_store().clone())
- .unwrap();
- let buffers = buffer_store
- .read_with(cx, |buffer_store, _| {
- buffer_store.buffers().collect::<Vec<_>>()
- })
- .unwrap();
+ ep_store.update(cx, |ep_store, _| {
+ ep_store.clear_history_for_project(&project);
+ })?;
+ let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
+ let buffers = buffer_store.read_with(cx, |buffer_store, _| {
+ buffer_store.buffers().collect::<Vec<_>>()
+ })?;
for buffer in buffers {
buffer
- .update(cx, |buffer, cx| buffer.reload(cx))
- .unwrap()
+ .update(cx, |buffer, cx| buffer.reload(cx))?
.await
.ok();
}
- return project;
+ return Ok(project);
}
- let project = cx
- .update(|cx| {
- Project::local(
- app_state.client.clone(),
- app_state.node_runtime.clone(),
- app_state.user_store.clone(),
- app_state.languages.clone(),
- app_state.fs.clone(),
- None,
- cx,
- )
- })
- .unwrap();
+ let project = cx.update(|cx| {
+ Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ app_state.fs.clone(),
+ None,
+ cx,
+ )
+ })?;
project
.update(cx, |project, cx| {
project.disable_worktree_scanner(cx);
project.create_worktree(&worktree_path, true, cx)
- })
- .unwrap()
- .await
- .unwrap();
+ })?
+ .await?;
app_state
.project_cache
.insert(example.repository_url.clone(), project.clone());
- let buffer_store = project
- .read_with(cx, |project, _| project.buffer_store().clone())
- .unwrap();
+ let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
cx.subscribe(&buffer_store, {
let project = project.clone();
move |_, event, cx| match event {
@@ -215,15 +198,14 @@ async fn setup_project(
}
_ => {}
}
- })
- .unwrap()
+ })?
.detach();
- project
+ Ok(project)
}
-async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> PathBuf {
- let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name");
+async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result<PathBuf> {
+ let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?;
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
let worktree_path = WORKTREES_DIR
.join(repo_owner.as_ref())
@@ -232,14 +214,13 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path
if !repo_dir.is_dir() {
step_progress.set_substatus(format!("cloning {}", repo_name));
- fs::create_dir_all(&repo_dir).unwrap();
- run_git(&repo_dir, &["init"]).await.unwrap();
+ fs::create_dir_all(&repo_dir)?;
+ run_git(&repo_dir, &["init"]).await?;
run_git(
&repo_dir,
&["remote", "add", "origin", &example.repository_url],
)
- .await
- .unwrap();
+ .await?;
}
// Resolve the example to a revision, fetching it if needed.
@@ -259,34 +240,25 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path
.await
.is_err()
{
- run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
+ run_git(&repo_dir, &["fetch", "origin"]).await?;
}
- let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
- .await
- .unwrap();
+ let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
revision
};
// Create the worktree for this example if needed.
step_progress.set_substatus("preparing worktree");
if worktree_path.is_dir() {
- run_git(&worktree_path, &["clean", "--force", "-d"])
- .await
- .unwrap();
- run_git(&worktree_path, &["reset", "--hard", "HEAD"])
- .await
- .unwrap();
- run_git(&worktree_path, &["checkout", revision.as_str()])
- .await
- .unwrap();
+ run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
+ run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
+ run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
} else {
let worktree_path_string = worktree_path.to_string_lossy();
run_git(
&repo_dir,
&["branch", "-f", &example.name, revision.as_str()],
)
- .await
- .unwrap();
+ .await?;
run_git(
&repo_dir,
&[
@@ -297,8 +269,7 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path
&example.name,
],
)
- .await
- .unwrap();
+ .await?;
}
drop(repo_lock);
@@ -309,30 +280,25 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path
.current_dir(&worktree_path)
.args(&["apply", "-"])
.stdin(std::process::Stdio::piped())
- .spawn()
- .unwrap();
-
- let mut stdin = apply_process.stdin.take().unwrap();
- stdin
- .write_all(example.uncommitted_diff.as_bytes())
- .await
- .unwrap();
- stdin.close().await.unwrap();
+ .spawn()?;
+
+ let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
+ stdin.write_all(example.uncommitted_diff.as_bytes()).await?;
+ stdin.close().await?;
drop(stdin);
- let apply_result = apply_process.output().await.unwrap();
- if !apply_result.status.success() {
- panic!(
- "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
- apply_result.status,
- String::from_utf8_lossy(&apply_result.stderr),
- String::from_utf8_lossy(&apply_result.stdout),
- );
- }
+ let apply_result = apply_process.output().await?;
+ anyhow::ensure!(
+ apply_result.status.success(),
+ "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
+ apply_result.status,
+ String::from_utf8_lossy(&apply_result.stderr),
+ String::from_utf8_lossy(&apply_result.stdout),
+ );
}
step_progress.clear_substatus();
- worktree_path
+ Ok(worktree_path)
}
async fn apply_edit_history(
@@ -16,12 +16,14 @@ use edit_prediction::EditPredictionStore;
use gpui::Application;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
+use std::fmt::Display;
use std::{path::PathBuf, sync::Arc};
use crate::distill::run_distill;
use crate::example::{group_examples_by_repo, read_examples, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
+use crate::paths::FAILED_EXAMPLES_DIR;
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
@@ -42,6 +44,8 @@ struct EpArgs {
output: Option<PathBuf>,
#[arg(long, short, global = true)]
in_place: bool,
+ #[arg(long, short, global = true)]
+ failfast: bool,
}
#[derive(Subcommand, Debug)]
@@ -67,6 +71,58 @@ enum Command {
Clean,
}
+impl Display for Command {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Command::ParseExample => write!(f, "parse-example"),
+ Command::LoadProject => write!(f, "load-project"),
+ Command::Context => write!(f, "context"),
+ Command::FormatPrompt(format_prompt_args) => write!(
+ f,
+ "format-prompt --prompt-format={}",
+ format_prompt_args
+ .prompt_format
+ .to_possible_value()
+ .unwrap()
+ .get_name()
+ ),
+ Command::Predict(predict_args) => {
+ write!(
+ f,
+ "predict --provider={}",
+ predict_args
+ .provider
+ .to_possible_value()
+ .unwrap()
+ .get_name()
+ )
+ }
+ Command::Score(predict_args) => {
+ write!(
+ f,
+ "score --provider={}",
+ predict_args
+ .provider
+ .to_possible_value()
+ .unwrap()
+ .get_name()
+ )
+ }
+ Command::Distill => write!(f, "distill"),
+ Command::Eval(predict_args) => write!(
+ f,
+ "eval --provider={}",
+ predict_args
+ .provider
+ .to_possible_value()
+ .unwrap()
+ .get_name()
+ ),
+ Command::Clean => write!(f, "clean"),
+ }
+ }
+}
+
#[derive(Debug, Args)]
struct FormatPromptArgs {
#[clap(long)]
@@ -145,71 +201,140 @@ fn main() {
EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
cx.spawn(async move |cx| {
- if let Command::Predict(args) = &command {
- predict::sync_batches(&args.provider).await
- };
-
- let total_examples = examples.len();
- Progress::global().set_total_examples(total_examples);
-
- let mut grouped_examples = group_examples_by_repo(&mut examples);
- let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
-
- for example_batch in example_batches {
- let futures = example_batch.into_iter().map(|repo_examples| async {
- for example in repo_examples.iter_mut() {
- match &command {
- Command::ParseExample => {}
- Command::LoadProject => {
- run_load_project(example, app_state.clone(), cx.clone()).await;
- }
- Command::Context => {
- run_context_retrieval(example, app_state.clone(), cx.clone()).await;
- }
- Command::FormatPrompt(args) => {
- run_format_prompt(
- example,
- args.prompt_format,
- app_state.clone(),
- cx.clone(),
- )
- .await;
- }
- Command::Predict(args) => {
- run_prediction(
- example,
- Some(args.provider),
- args.repetitions,
- app_state.clone(),
- cx.clone(),
- )
- .await;
- }
- Command::Distill => {
- run_distill(example).await;
- }
- Command::Score(args) | Command::Eval(args) => {
- run_scoring(example, &args, app_state.clone(), cx.clone()).await;
+ let result = async {
+ if let Command::Predict(args) = &command {
+ predict::sync_batches(&args.provider).await?;
+ }
+
+ let total_examples = examples.len();
+ Progress::global().set_total_examples(total_examples);
+
+ let mut grouped_examples = group_examples_by_repo(&mut examples);
+ let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
+
+ for example_batch in example_batches {
+ let futures = example_batch.into_iter().map(|repo_examples| async {
+ for example in repo_examples.iter_mut() {
+ let result = async {
+ match &command {
+ Command::ParseExample => {}
+ Command::LoadProject => {
+ run_load_project(example, app_state.clone(), cx.clone())
+ .await?;
+ }
+ Command::Context => {
+ run_context_retrieval(
+ example,
+ app_state.clone(),
+ cx.clone(),
+ )
+ .await?;
+ }
+ Command::FormatPrompt(args) => {
+ run_format_prompt(
+ example,
+ args.prompt_format,
+ app_state.clone(),
+ cx.clone(),
+ )
+ .await?;
+ }
+ Command::Predict(args) => {
+ run_prediction(
+ example,
+ Some(args.provider),
+ args.repetitions,
+ app_state.clone(),
+ cx.clone(),
+ )
+ .await?;
+ }
+ Command::Distill => {
+ run_distill(example).await?;
+ }
+ Command::Score(args) | Command::Eval(args) => {
+ run_scoring(example, &args, app_state.clone(), cx.clone())
+ .await?;
+ }
+ Command::Clean => {
+ unreachable!()
+ }
+ }
+ anyhow::Ok(())
}
- Command::Clean => {
- unreachable!()
+ .await;
+
+ if let Err(e) = result {
+ Progress::global().increment_failed();
+ let failed_example_path =
+ FAILED_EXAMPLES_DIR.join(format!("{}.json", example.name));
+ app_state
+ .fs
+ .write(
+ &failed_example_path,
+ &serde_json::to_vec_pretty(&example).unwrap(),
+ )
+ .await
+ .unwrap();
+ let err_path =
+ FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example.name));
+ app_state
+ .fs
+ .write(&err_path, e.to_string().as_bytes())
+ .await
+ .unwrap();
+
+ let msg = format!(
+ indoc::indoc! {"
+ While processing {}:
+
+ {:?}
+
+ Written to: \x1b[36m{}\x1b[0m
+
+ Explore this example data with:
+ fx \x1b[36m{}\x1b[0m
+
+ Re-run this example with:
+ cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
+ "},
+ example.name,
+ e,
+ err_path.display(),
+ failed_example_path.display(),
+ command,
+ failed_example_path.display(),
+ );
+ if args.failfast || total_examples == 1 {
+ Progress::global().finalize();
+ panic!("{}", msg);
+ } else {
+ log::error!("{}", msg);
+ }
}
}
- }
- });
- futures::future::join_all(futures).await;
- }
- Progress::global().clear();
+ });
+ futures::future::join_all(futures).await;
+ }
+ Progress::global().finalize();
- if args.output.is_some() || !matches!(command, Command::Eval(_)) {
- write_examples(&examples, output.as_ref());
+ if args.output.is_some() || !matches!(command, Command::Eval(_)) {
+ write_examples(&examples, output.as_ref());
+ }
+
+ match &command {
+ Command::Predict(args) => predict::sync_batches(&args.provider).await?,
+ Command::Eval(_) => score::print_report(&examples),
+ _ => (),
+ };
+
+ anyhow::Ok(())
}
+ .await;
- match &command {
- Command::Predict(args) => predict::sync_batches(&args.provider).await,
- Command::Eval(_) => score::print_report(&examples),
- _ => (),
- };
+ if let Err(e) = result {
+ panic!("Fatal error: {:?}", e);
+ }
let _ = cx.update(|cx| cx.quit());
})
@@ -18,6 +18,8 @@ pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
});
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
+pub static FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
+ LazyLock::new(|| ensure_dir(&RUN_DIR.join("failed")));
fn ensure_dir(path: &Path) -> PathBuf {
std::fs::create_dir_all(path).expect("Failed to create directory");
@@ -9,6 +9,7 @@ use crate::{
progress::{InfoStyle, Progress, Step},
retrieve_context::run_context_retrieval,
};
+use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
@@ -26,14 +27,14 @@ pub async fn run_prediction(
repetition_count: usize,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
-) {
+) -> anyhow::Result<()> {
if !example.predictions.is_empty() {
- return;
+ return Ok(());
}
- let provider = provider.unwrap();
+ let provider = provider.context("provider is required")?;
- run_context_retrieval(example, app_state.clone(), cx.clone()).await;
+ run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
if matches!(
provider,
@@ -42,14 +43,14 @@ pub async fn run_prediction(
let _step_progress = Progress::global().start(Step::Predict, &example.name);
if example.prompt.is_none() {
- run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
+ run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
}
let batched = matches!(provider, PredictionProvider::Teacher);
return predict_anthropic(example, repetition_count, batched).await;
}
- run_load_project(example, app_state.clone(), cx.clone()).await;
+ run_load_project(example, app_state.clone(), cx.clone()).await?;
let _step_progress = Progress::global().start(Step::Predict, &example.name);
@@ -62,10 +63,9 @@ pub async fn run_prediction(
.get_or_init(|| {
let client = app_state.client.clone();
cx.spawn(async move |cx| {
- client
- .sign_in_with_optional_connect(true, cx)
- .await
- .unwrap();
+ if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
+ eprintln!("Authentication failed: {}", e);
+ }
})
.shared()
})
@@ -73,33 +73,30 @@ pub async fn run_prediction(
.await;
}
- let ep_store = cx
- .update(|cx| EditPredictionStore::try_global(cx).unwrap())
- .unwrap();
-
- ep_store
- .update(&mut cx, |store, _cx| {
- let model = match provider {
- PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
- PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
- PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
- PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
- PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
- unreachable!()
- }
- };
- store.set_edit_prediction_model(model);
- })
- .unwrap();
- let state = example.state.as_ref().unwrap();
+ let ep_store = cx.update(|cx| {
+ EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
+ })??;
+
+ ep_store.update(&mut cx, |store, _cx| {
+ let model = match provider {
+ PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
+ PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
+ PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
+ PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
+ PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
+ unreachable!()
+ }
+ };
+ store.set_edit_prediction_model(model);
+ })?;
+ let state = example.state.as_ref().context("state must be set")?;
let run_dir = RUN_DIR.join(&example.name);
let updated_example = Arc::new(Mutex::new(example.clone()));
let current_run_ix = Arc::new(AtomicUsize::new(0));
- let mut debug_rx = ep_store
- .update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
- .unwrap();
+ let mut debug_rx =
+ ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))?;
let debug_task = cx.background_spawn({
let updated_example = updated_example.clone();
let current_run_ix = current_run_ix.clone();
@@ -153,14 +150,14 @@ pub async fn run_prediction(
run_dir.clone()
};
- fs::create_dir_all(&run_dir).unwrap();
+ fs::create_dir_all(&run_dir)?;
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
- fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
+ fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
}
#[cfg(unix)]
- std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
+ std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
#[cfg(windows)]
- std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
+ std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
updated_example
.lock()
@@ -181,10 +178,8 @@ pub async fn run_prediction(
cloud_llm_client::PredictEditsRequestTrigger::Cli,
cx,
)
- })
- .unwrap()
- .await
- .unwrap();
+ })?
+ .await?;
let actual_patch = prediction
.and_then(|prediction| {
@@ -213,20 +208,23 @@ pub async fn run_prediction(
}
}
- ep_store
- .update(&mut cx, |store, _| {
- store.remove_project(&state.project);
- })
- .unwrap();
- debug_task.await.unwrap();
+ ep_store.update(&mut cx, |store, _| {
+ store.remove_project(&state.project);
+ })?;
+ debug_task.await?;
*example = Arc::into_inner(updated_example)
- .unwrap()
+ .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
.into_inner()
- .unwrap();
+ .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
+ Ok(())
}
-async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
+async fn predict_anthropic(
+ example: &mut Example,
+ _repetition_count: usize,
+ batched: bool,
+) -> anyhow::Result<()> {
let llm_model_name = "claude-sonnet-4-5";
let max_tokens = 16384;
let llm_client = if batched {
@@ -234,12 +232,9 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
} else {
AnthropicClient::plain()
};
- let llm_client = llm_client.expect("Failed to create LLM client");
+ let llm_client = llm_client.context("Failed to create LLM client")?;
- let prompt = example
- .prompt
- .as_ref()
- .unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name));
+ let prompt = example.prompt.as_ref().context("Prompt is required")?;
let messages = vec![anthropic::Message {
role: anthropic::Role::User,
@@ -251,11 +246,10 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
let Some(response) = llm_client
.generate(llm_model_name, max_tokens, messages)
- .await
- .unwrap()
+ .await?
else {
// Request stashed for batched processing
- return;
+ return Ok(());
};
let actual_output = response
@@ -268,7 +262,7 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
.collect::<Vec<String>>()
.join("\n");
- let actual_patch = TeacherPrompt::parse(example, &actual_output);
+ let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
let prediction = ExamplePrediction {
actual_patch,
@@ -277,19 +271,21 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
};
example.predictions.push(prediction);
+ Ok(())
}
-pub async fn sync_batches(provider: &PredictionProvider) {
+pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
match provider {
PredictionProvider::Teacher => {
let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
let llm_client =
- AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
+ AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
llm_client
.sync_batches()
.await
- .expect("Failed to sync batches");
+ .context("Failed to sync batches")?;
}
_ => (),
- }
+ };
+ Ok(())
}
@@ -20,6 +20,7 @@ struct ProgressInner {
max_example_name_len: usize,
status_lines_displayed: usize,
total_examples: usize,
+ failed_examples: usize,
last_line_is_logging: bool,
}
@@ -78,7 +79,7 @@ impl Step {
static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
static LOGGER: ProgressLogger = ProgressLogger;
-const RIGHT_MARGIN: usize = 4;
+const MARGIN: usize = 4;
const MAX_STATUS_LINES: usize = 10;
impl Progress {
@@ -95,6 +96,7 @@ impl Progress {
max_example_name_len: 0,
status_lines_displayed: 0,
total_examples: 0,
+ failed_examples: 0,
last_line_is_logging: false,
}),
});
@@ -110,6 +112,11 @@ impl Progress {
inner.total_examples = total;
}
+ pub fn increment_failed(&self) {
+ let mut inner = self.inner.lock().unwrap();
+ inner.failed_examples += 1;
+ }
+
/// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
/// This should be used for any output that needs to appear above the status lines.
fn log(&self, message: &str) {
@@ -119,7 +126,7 @@ impl Progress {
if !inner.last_line_is_logging {
let reset = "\x1b[0m";
let dim = "\x1b[2m";
- let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN));
+ let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
eprintln!("{dim}{divider}{reset}");
inner.last_line_is_logging = true;
}
@@ -180,7 +187,7 @@ impl Progress {
if inner.last_line_is_logging {
let reset = "\x1b[0m";
let dim = "\x1b[2m";
- let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN));
+ let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
eprintln!("{dim}{divider}{reset}");
inner.last_line_is_logging = false;
}
@@ -229,7 +236,7 @@ impl Progress {
let duration_with_margin = format!("{duration} ");
let padding_needed = inner
.terminal_width
- .saturating_sub(RIGHT_MARGIN)
+ .saturating_sub(MARGIN)
.saturating_sub(duration_with_margin.len())
.saturating_sub(strip_ansi_len(&prefix));
let padding = " ".repeat(padding_needed);
@@ -263,20 +270,33 @@ impl Progress {
// Build the done/in-progress/total label
let done_count = inner.completed.len();
let in_progress_count = inner.in_progress.len();
+ let failed_count = inner.failed_examples;
+
+ let failed_label = if failed_count > 0 {
+ format!(" {} failed ", failed_count)
+ } else {
+ String::new()
+ };
+
let range_label = format!(
" {}/{}/{} ",
done_count, in_progress_count, inner.total_examples
);
- // Print a divider line with range label aligned with timestamps
+ // Print a divider line with failed count on left, range label on right
+ let failed_visible_len = strip_ansi_len(&failed_label);
let range_visible_len = range_label.len();
- let left_divider_len = inner
+ let middle_divider_len = inner
.terminal_width
- .saturating_sub(RIGHT_MARGIN)
+ .saturating_sub(MARGIN * 2)
+ .saturating_sub(failed_visible_len)
.saturating_sub(range_visible_len);
- let left_divider = "─".repeat(left_divider_len);
- let right_divider = "─".repeat(RIGHT_MARGIN);
- eprintln!("{dim}{left_divider}{reset}{range_label}{dim}{right_divider}{reset}");
+ let left_divider = "─".repeat(MARGIN);
+ let middle_divider = "─".repeat(middle_divider_len);
+ let right_divider = "─".repeat(MARGIN);
+ eprintln!(
+ "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
+ );
let mut tasks: Vec<_> = inner.in_progress.iter().collect();
tasks.sort_by_key(|(name, _)| *name);
@@ -304,7 +324,7 @@ impl Progress {
let duration_with_margin = format!("{elapsed} ");
let padding_needed = inner
.terminal_width
- .saturating_sub(RIGHT_MARGIN)
+ .saturating_sub(MARGIN)
.saturating_sub(duration_with_margin.len())
.saturating_sub(strip_ansi_len(&prefix));
let padding = " ".repeat(padding_needed);
@@ -324,9 +344,23 @@ impl Progress {
let _ = std::io::stderr().flush();
}
- pub fn clear(&self) {
+ pub fn finalize(&self) {
let mut inner = self.inner.lock().unwrap();
Self::clear_status_lines(&mut inner);
+
+ // Print summary if there were failures
+ if inner.failed_examples > 0 {
+ let total_processed = inner.completed.len() + inner.failed_examples;
+ let percentage = if total_processed > 0 {
+ inner.failed_examples as f64 / total_processed as f64 * 100.0
+ } else {
+ 0.0
+ };
+ eprintln!(
+ "\n{} of {} examples failed ({:.1}%)",
+ inner.failed_examples, total_processed, percentage
+ );
+ }
}
}
@@ -4,6 +4,7 @@ use crate::{
load_project::run_load_project,
progress::{InfoStyle, Progress, Step, StepProgress},
};
+use anyhow::Context as _;
use collections::HashSet;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
@@ -17,12 +18,12 @@ pub async fn run_context_retrieval(
example: &mut Example,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
-) {
+) -> anyhow::Result<()> {
if example.context.is_some() {
- return;
+ return Ok(());
}
- run_load_project(example, app_state.clone(), cx.clone()).await;
+ run_load_project(example, app_state.clone(), cx.clone()).await?;
let step_progress: Arc<StepProgress> = Progress::global()
.start(Step::Context, &example.name)
@@ -31,25 +32,21 @@ pub async fn run_context_retrieval(
let state = example.state.as_ref().unwrap();
let project = state.project.clone();
- let _lsp_handle = project
- .update(&mut cx, |project, cx| {
- project.register_buffer_with_language_servers(&state.buffer, cx)
- })
- .unwrap();
- wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await;
-
- let ep_store = cx
- .update(|cx| EditPredictionStore::try_global(cx).unwrap())
- .unwrap();
-
- let mut events = ep_store
- .update(&mut cx, |store, cx| {
- store.register_buffer(&state.buffer, &project, cx);
- store.set_use_context(true);
- store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
- store.debug_info(&project, cx)
- })
- .unwrap();
+ let _lsp_handle = project.update(&mut cx, |project, cx| {
+ project.register_buffer_with_language_servers(&state.buffer, cx)
+ })?;
+ wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
+
+ let ep_store = cx.update(|cx| {
+ EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
+ })??;
+
+ let mut events = ep_store.update(&mut cx, |store, cx| {
+ store.register_buffer(&state.buffer, &project, cx);
+ store.set_use_context(true);
+ store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
+ store.debug_info(&project, cx)
+ })?;
while let Some(event) = events.next().await {
match event {
@@ -60,9 +57,8 @@ pub async fn run_context_retrieval(
}
}
- let context_files = ep_store
- .update(&mut cx, |store, cx| store.context_for_project(&project, cx))
- .unwrap();
+ let context_files =
+ ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx))?;
let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
@@ -70,6 +66,7 @@ pub async fn run_context_retrieval(
example.context = Some(ExampleContext {
files: context_files,
});
+ Ok(())
}
async fn wait_for_language_servers_to_start(
@@ -77,10 +74,8 @@ async fn wait_for_language_servers_to_start(
buffer: &Entity<Buffer>,
step_progress: &Arc<StepProgress>,
cx: &mut AsyncApp,
-) {
- let lsp_store = project
- .read_with(cx, |project, _| project.lsp_store())
- .unwrap();
+) -> anyhow::Result<()> {
+ let lsp_store = project.read_with(cx, |project, _| project.lsp_store())?;
let (language_server_ids, mut starting_language_server_ids) = buffer
.update(cx, |buffer, cx| {
@@ -123,7 +118,7 @@ async fn wait_for_language_servers_to_start(
}
},
_ = timeout.clone().fuse() => {
- panic!("LSP wait timed out after 5 minutes");
+ return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
}
}
}
@@ -132,8 +127,7 @@ async fn wait_for_language_servers_to_start(
if !language_server_ids.is_empty() {
project
- .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
- .unwrap()
+ .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.detach();
}
@@ -175,10 +169,8 @@ async fn wait_for_language_servers_to_start(
];
project
- .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
- .unwrap()
- .await
- .unwrap();
+ .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
+ .await?;
let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
while !pending_language_server_ids.is_empty() {
@@ -189,11 +181,12 @@ async fn wait_for_language_servers_to_start(
}
},
_ = timeout.clone().fuse() => {
- panic!("LSP wait timed out after 5 minutes");
+ return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
}
}
}
drop(subscriptions);
step_progress.clear_substatus();
+ Ok(())
}
@@ -15,7 +15,7 @@ pub async fn run_scoring(
args: &PredictArgs,
app_state: Arc<EpAppState>,
cx: AsyncApp,
-) {
+) -> anyhow::Result<()> {
run_prediction(
example,
Some(args.provider),
@@ -23,7 +23,7 @@ pub async fn run_scoring(
app_state,
cx,
)
- .await;
+ .await?;
let _progress = Progress::global().start(Step::Score, &example.name);
@@ -43,6 +43,7 @@ pub async fn run_scoring(
}
example.score = scores;
+ Ok(())
}
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {