Cargo.lock 🔗
@@ -5201,7 +5201,6 @@ dependencies = [
"wasmtime",
"watch",
"zeta_prompt",
- "zlog",
]
[[package]]
Agus Zubiaga created
- Limit status lines to 10 in case `max_parallelism` is specified with a
grater value
- Handle logging gracefully rather than writing over it when clearing
status lines
Release Notes:
- N/A
Cargo.lock | 1
crates/edit_prediction_cli/Cargo.toml | 1
crates/edit_prediction_cli/src/format_prompt.rs | 7
crates/edit_prediction_cli/src/load_project.rs | 13
crates/edit_prediction_cli/src/main.rs | 35 --
crates/edit_prediction_cli/src/predict.rs | 18 -
crates/edit_prediction_cli/src/progress.rs | 198 ++++++++++++---
crates/edit_prediction_cli/src/retrieve_context.rs | 7
crates/edit_prediction_cli/src/score.rs | 4
9 files changed, 173 insertions(+), 111 deletions(-)
@@ -5201,7 +5201,6 @@ dependencies = [
"wasmtime",
"watch",
"zeta_prompt",
- "zlog",
]
[[package]]
@@ -56,7 +56,6 @@ watch.workspace = true
edit_prediction = { workspace = true, features = ["cli-support"] }
wasmtime.workspace = true
zeta_prompt.workspace = true
-zlog.workspace = true
# Wasmtime is included as a dependency in order to enable the same
# features that are enabled in Zed.
@@ -18,12 +18,11 @@ pub async fn run_format_prompt(
example: &mut Example,
prompt_format: PromptFormat,
app_state: Arc<EpAppState>,
- progress: Arc<Progress>,
mut cx: AsyncApp,
) {
- run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await;
+ run_context_retrieval(example, app_state.clone(), cx.clone()).await;
- let _step_progress = progress.start(Step::FormatPrompt, &example.name);
+ let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name);
match prompt_format {
PromptFormat::Teacher => {
@@ -35,7 +34,7 @@ pub async fn run_format_prompt(
});
}
PromptFormat::Zeta2 => {
- run_load_project(example, app_state, progress.clone(), cx.clone()).await;
+ run_load_project(example, app_state, cx.clone()).await;
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
@@ -25,17 +25,12 @@ use std::{
use util::{paths::PathStyle, rel_path::RelPath};
use zeta_prompt::CURSOR_MARKER;
-pub async fn run_load_project(
- example: &mut Example,
- app_state: Arc<EpAppState>,
- progress: Arc<Progress>,
- mut cx: AsyncApp,
-) {
+pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
if example.state.is_some() {
return;
}
- let progress = progress.start(Step::LoadProject, &example.name);
+ let progress = Progress::global().start(Step::LoadProject, &example.name);
let project = setup_project(example, &app_state, &progress, &mut cx).await;
@@ -149,7 +144,7 @@ async fn cursor_position(
async fn setup_project(
example: &mut Example,
app_state: &Arc<EpAppState>,
- step_progress: &Arc<StepProgress>,
+ step_progress: &StepProgress,
cx: &mut AsyncApp,
) -> Entity<Project> {
let ep_store = cx
@@ -227,7 +222,7 @@ async fn setup_project(
project
}
-async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) -> PathBuf {
+async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> PathBuf {
let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name");
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
let worktree_path = WORKTREES_DIR
@@ -32,7 +32,7 @@ use crate::score::run_scoring;
struct EpArgs {
#[arg(long, default_value_t = false)]
printenv: bool,
- #[clap(long, default_value_t = 10)]
+ #[clap(long, default_value_t = 10, global = true)]
max_parallelism: usize,
#[command(subcommand)]
command: Option<Command>,
@@ -112,8 +112,6 @@ impl EpArgs {
}
fn main() {
- let _ = zlog::try_init(Some("error".into()));
- zlog::init_output_stderr();
let args = EpArgs::parse();
if args.printenv {
@@ -152,7 +150,7 @@ fn main() {
};
let total_examples = examples.len();
- let progress = Progress::new(total_examples);
+ Progress::global().set_total_examples(total_examples);
let mut grouped_examples = group_examples_by_repo(&mut examples);
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
@@ -163,29 +161,16 @@ fn main() {
match &command {
Command::ParseExample => {}
Command::LoadProject => {
- run_load_project(
- example,
- app_state.clone(),
- progress.clone(),
- cx.clone(),
- )
- .await;
+ run_load_project(example, app_state.clone(), cx.clone()).await;
}
Command::Context => {
- run_context_retrieval(
- example,
- app_state.clone(),
- progress.clone(),
- cx.clone(),
- )
- .await;
+ run_context_retrieval(example, app_state.clone(), cx.clone()).await;
}
Command::FormatPrompt(args) => {
run_format_prompt(
example,
args.prompt_format,
app_state.clone(),
- progress.clone(),
cx.clone(),
)
.await;
@@ -196,7 +181,6 @@ fn main() {
Some(args.provider),
args.repetitions,
app_state.clone(),
- progress.clone(),
cx.clone(),
)
.await;
@@ -205,14 +189,7 @@ fn main() {
run_distill(example).await;
}
Command::Score(args) | Command::Eval(args) => {
- run_scoring(
- example,
- &args,
- app_state.clone(),
- progress.clone(),
- cx.clone(),
- )
- .await;
+ run_scoring(example, &args, app_state.clone(), cx.clone()).await;
}
Command::Clean => {
unreachable!()
@@ -222,7 +199,7 @@ fn main() {
});
futures::future::join_all(futures).await;
}
- progress.clear();
+ Progress::global().clear();
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
write_examples(&examples, output.as_ref());
@@ -25,7 +25,6 @@ pub async fn run_prediction(
provider: Option<PredictionProvider>,
repetition_count: usize,
app_state: Arc<EpAppState>,
- progress: Arc<Progress>,
mut cx: AsyncApp,
) {
if !example.predictions.is_empty() {
@@ -34,32 +33,25 @@ pub async fn run_prediction(
let provider = provider.unwrap();
- run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await;
+ run_context_retrieval(example, app_state.clone(), cx.clone()).await;
if matches!(
provider,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
) {
- let _step_progress = progress.start(Step::Predict, &example.name);
+ let _step_progress = Progress::global().start(Step::Predict, &example.name);
if example.prompt.is_none() {
- run_format_prompt(
- example,
- PromptFormat::Teacher,
- app_state.clone(),
- progress,
- cx,
- )
- .await;
+ run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
}
let batched = matches!(provider, PredictionProvider::Teacher);
return predict_anthropic(example, repetition_count, batched).await;
}
- run_load_project(example, app_state.clone(), progress.clone(), cx.clone()).await;
+ run_load_project(example, app_state.clone(), cx.clone()).await;
- let _step_progress = progress.start(Step::Predict, &example.name);
+ let _step_progress = Progress::global().start(Step::Predict, &example.name);
if matches!(
provider,
@@ -2,10 +2,12 @@ use std::{
borrow::Cow,
collections::HashMap,
io::{IsTerminal, Write},
- sync::{Arc, Mutex},
+ sync::{Arc, Mutex, OnceLock},
time::{Duration, Instant},
};
+use log::{Level, Log, Metadata, Record};
+
pub struct Progress {
inner: Mutex<ProgressInner>,
}
@@ -18,6 +20,7 @@ struct ProgressInner {
max_example_name_len: usize,
status_lines_displayed: usize,
total_examples: usize,
+ last_line_is_logging: bool,
}
#[derive(Clone)]
@@ -72,70 +75,114 @@ impl Step {
}
}
+static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
+static LOGGER: ProgressLogger = ProgressLogger;
+
const RIGHT_MARGIN: usize = 4;
+const MAX_STATUS_LINES: usize = 10;
impl Progress {
- pub fn new(total_examples: usize) -> Arc<Self> {
- Arc::new(Self {
- inner: Mutex::new(ProgressInner {
- completed: Vec::new(),
- in_progress: HashMap::new(),
- is_tty: std::io::stderr().is_terminal(),
- terminal_width: get_terminal_width(),
- max_example_name_len: 0,
- status_lines_displayed: 0,
- total_examples,
- }),
- })
+ /// Returns the global Progress instance, initializing it if necessary.
+ pub fn global() -> Arc<Progress> {
+ GLOBAL
+ .get_or_init(|| {
+ let progress = Arc::new(Self {
+ inner: Mutex::new(ProgressInner {
+ completed: Vec::new(),
+ in_progress: HashMap::new(),
+ is_tty: std::io::stderr().is_terminal(),
+ terminal_width: get_terminal_width(),
+ max_example_name_len: 0,
+ status_lines_displayed: 0,
+ total_examples: 0,
+ last_line_is_logging: false,
+ }),
+ });
+ let _ = log::set_logger(&LOGGER);
+ log::set_max_level(log::LevelFilter::Error);
+ progress
+ })
+ .clone()
}
- pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> Arc<StepProgress> {
- {
- let mut inner = self.inner.lock().unwrap();
+ pub fn set_total_examples(&self, total: usize) {
+ let mut inner = self.inner.lock().unwrap();
+ inner.total_examples = total;
+ }
- Self::clear_status_lines(&mut inner);
+ /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
+ /// This should be used for any output that needs to appear above the status lines.
+ fn log(&self, message: &str) {
+ let mut inner = self.inner.lock().unwrap();
+ Self::clear_status_lines(&mut inner);
- inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
+ if !inner.last_line_is_logging {
+ let reset = "\x1b[0m";
+ let dim = "\x1b[2m";
+ let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN));
+ eprintln!("{dim}{divider}{reset}");
+ inner.last_line_is_logging = true;
+ }
- inner.in_progress.insert(
- example_name.to_string(),
- InProgressTask {
- step,
- started_at: Instant::now(),
- substatus: None,
- info: None,
- },
- );
+ eprintln!("{}", message);
+ }
- Self::print_status_lines(&mut inner);
- }
+ pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
+ let mut inner = self.inner.lock().unwrap();
+
+ Self::clear_status_lines(&mut inner);
+
+ inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
+ inner.in_progress.insert(
+ example_name.to_string(),
+ InProgressTask {
+ step,
+ started_at: Instant::now(),
+ substatus: None,
+ info: None,
+ },
+ );
+
+ Self::print_status_lines(&mut inner);
- Arc::new(StepProgress {
+ StepProgress {
progress: self.clone(),
step,
example_name: example_name.to_string(),
- })
+ }
}
- pub fn finish(&self, step: Step, example_name: &str) {
+ fn finish(&self, step: Step, example_name: &str) {
let mut inner = self.inner.lock().unwrap();
- let task = inner.in_progress.remove(example_name);
- if let Some(task) = task {
- if task.step == step {
- inner.completed.push(CompletedTask {
- step: task.step,
- example_name: example_name.to_string(),
- duration: task.started_at.elapsed(),
- info: task.info,
- });
+ let Some(task) = inner.in_progress.remove(example_name) else {
+ return;
+ };
- Self::clear_status_lines(&mut inner);
- Self::print_completed(&inner, inner.completed.last().unwrap());
- Self::print_status_lines(&mut inner);
- } else {
- inner.in_progress.insert(example_name.to_string(), task);
- }
+ if task.step == step {
+ inner.completed.push(CompletedTask {
+ step: task.step,
+ example_name: example_name.to_string(),
+ duration: task.started_at.elapsed(),
+ info: task.info,
+ });
+
+ Self::clear_status_lines(&mut inner);
+ Self::print_logging_closing_divider(&mut inner);
+ Self::print_completed(&inner, inner.completed.last().unwrap());
+ Self::print_status_lines(&mut inner);
+ } else {
+ inner.in_progress.insert(example_name.to_string(), task);
+ }
+ }
+
+ fn print_logging_closing_divider(inner: &mut ProgressInner) {
+ if inner.last_line_is_logging {
+ let reset = "\x1b[0m";
+ let dim = "\x1b[2m";
+ let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN));
+ eprintln!("{dim}{divider}{reset}");
+ inner.last_line_is_logging = false;
}
}
@@ -234,9 +281,10 @@ impl Progress {
let mut tasks: Vec<_> = inner.in_progress.iter().collect();
tasks.sort_by_key(|(name, _)| *name);
+ let total_tasks = tasks.len();
let mut lines_printed = 0;
- for (name, task) in tasks.iter() {
+ for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
let elapsed = format_duration(task.started_at.elapsed());
let substatus_part = task
.substatus
@@ -265,6 +313,13 @@ impl Progress {
lines_printed += 1;
}
+ // Show "+N more" on its own line if there are more tasks
+ if total_tasks > MAX_STATUS_LINES {
+ let remaining = total_tasks - MAX_STATUS_LINES;
+ eprintln!("{:>12} +{remaining} more", "");
+ lines_printed += 1;
+ }
+
inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
let _ = std::io::stderr().flush();
}
@@ -314,6 +369,53 @@ impl Drop for StepProgress {
}
}
+struct ProgressLogger;
+
+impl Log for ProgressLogger {
+ fn enabled(&self, metadata: &Metadata) -> bool {
+ metadata.level() <= Level::Info
+ }
+
+ fn log(&self, record: &Record) {
+ if !self.enabled(record.metadata()) {
+ return;
+ }
+
+ let level_color = match record.level() {
+ Level::Error => "\x1b[31m",
+ Level::Warn => "\x1b[33m",
+ Level::Info => "\x1b[32m",
+ Level::Debug => "\x1b[34m",
+ Level::Trace => "\x1b[35m",
+ };
+ let reset = "\x1b[0m";
+ let bold = "\x1b[1m";
+
+ let level_label = match record.level() {
+ Level::Error => "Error",
+ Level::Warn => "Warn",
+ Level::Info => "Info",
+ Level::Debug => "Debug",
+ Level::Trace => "Trace",
+ };
+
+ let message = format!(
+ "{bold}{level_color}{level_label:>12}{reset} {}",
+ record.args()
+ );
+
+ if let Some(progress) = GLOBAL.get() {
+ progress.log(&message);
+ } else {
+ eprintln!("{}", message);
+ }
+ }
+
+ fn flush(&self) {
+ let _ = std::io::stderr().flush();
+ }
+}
+
#[cfg(unix)]
fn get_terminal_width() -> usize {
unsafe {
@@ -16,16 +16,17 @@ use std::time::Duration;
pub async fn run_context_retrieval(
example: &mut Example,
app_state: Arc<EpAppState>,
- progress: Arc<Progress>,
mut cx: AsyncApp,
) {
if example.context.is_some() {
return;
}
- run_load_project(example, app_state.clone(), progress.clone(), cx.clone()).await;
+ run_load_project(example, app_state.clone(), cx.clone()).await;
- let step_progress = progress.start(Step::Context, &example.name);
+ let step_progress: Arc<StepProgress> = Progress::global()
+ .start(Step::Context, &example.name)
+ .into();
let state = example.state.as_ref().unwrap();
let project = state.project.clone();
@@ -14,7 +14,6 @@ pub async fn run_scoring(
example: &mut Example,
args: &PredictArgs,
app_state: Arc<EpAppState>,
- progress: Arc<Progress>,
cx: AsyncApp,
) {
run_prediction(
@@ -22,12 +21,11 @@ pub async fn run_scoring(
Some(args.provider),
args.repetitions,
app_state,
- progress.clone(),
cx,
)
.await;
- let _progress = progress.start(Step::Score, &example.name);
+ let _progress = Progress::global().start(Step::Score, &example.name);
let expected_patch = parse_patch(&example.expected_patch);