edit prediction cli: Progress output cleanup (#44708)

Agus Zubiaga created

- Limit status lines to 10 in case `max_parallelism` is specified with a
grater value
- Handle logging gracefully rather than writing over it when clearing
status lines

Release Notes:

- N/A

Change summary

Cargo.lock                                         |   1 
crates/edit_prediction_cli/Cargo.toml              |   1 
crates/edit_prediction_cli/src/format_prompt.rs    |   7 
crates/edit_prediction_cli/src/load_project.rs     |  13 
crates/edit_prediction_cli/src/main.rs             |  35 --
crates/edit_prediction_cli/src/predict.rs          |  18 -
crates/edit_prediction_cli/src/progress.rs         | 198 ++++++++++++---
crates/edit_prediction_cli/src/retrieve_context.rs |   7 
crates/edit_prediction_cli/src/score.rs            |   4 
9 files changed, 173 insertions(+), 111 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5201,7 +5201,6 @@ dependencies = [
  "wasmtime",
  "watch",
  "zeta_prompt",
- "zlog",
 ]
 
 [[package]]

crates/edit_prediction_cli/Cargo.toml 🔗

@@ -56,7 +56,6 @@ watch.workspace = true
 edit_prediction = { workspace = true, features = ["cli-support"] }
 wasmtime.workspace = true
 zeta_prompt.workspace = true
-zlog.workspace = true
 
 # Wasmtime is included as a dependency in order to enable the same
 # features that are enabled in Zed.

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -18,12 +18,11 @@ pub async fn run_format_prompt(
     example: &mut Example,
     prompt_format: PromptFormat,
     app_state: Arc<EpAppState>,
-    progress: Arc<Progress>,
     mut cx: AsyncApp,
 ) {
-    run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await;
+    run_context_retrieval(example, app_state.clone(), cx.clone()).await;
 
-    let _step_progress = progress.start(Step::FormatPrompt, &example.name);
+    let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name);
 
     match prompt_format {
         PromptFormat::Teacher => {
@@ -35,7 +34,7 @@ pub async fn run_format_prompt(
             });
         }
         PromptFormat::Zeta2 => {
-            run_load_project(example, app_state, progress.clone(), cx.clone()).await;
+            run_load_project(example, app_state, cx.clone()).await;
 
             let ep_store = cx
                 .update(|cx| EditPredictionStore::try_global(cx).unwrap())

crates/edit_prediction_cli/src/load_project.rs 🔗

@@ -25,17 +25,12 @@ use std::{
 use util::{paths::PathStyle, rel_path::RelPath};
 use zeta_prompt::CURSOR_MARKER;
 
-pub async fn run_load_project(
-    example: &mut Example,
-    app_state: Arc<EpAppState>,
-    progress: Arc<Progress>,
-    mut cx: AsyncApp,
-) {
+pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
     if example.state.is_some() {
         return;
     }
 
-    let progress = progress.start(Step::LoadProject, &example.name);
+    let progress = Progress::global().start(Step::LoadProject, &example.name);
 
     let project = setup_project(example, &app_state, &progress, &mut cx).await;
 
@@ -149,7 +144,7 @@ async fn cursor_position(
 async fn setup_project(
     example: &mut Example,
     app_state: &Arc<EpAppState>,
-    step_progress: &Arc<StepProgress>,
+    step_progress: &StepProgress,
     cx: &mut AsyncApp,
 ) -> Entity<Project> {
     let ep_store = cx
@@ -227,7 +222,7 @@ async fn setup_project(
     project
 }
 
-async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) -> PathBuf {
+async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> PathBuf {
     let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name");
     let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
     let worktree_path = WORKTREES_DIR

crates/edit_prediction_cli/src/main.rs 🔗

@@ -32,7 +32,7 @@ use crate::score::run_scoring;
 struct EpArgs {
     #[arg(long, default_value_t = false)]
     printenv: bool,
-    #[clap(long, default_value_t = 10)]
+    #[clap(long, default_value_t = 10, global = true)]
     max_parallelism: usize,
     #[command(subcommand)]
     command: Option<Command>,
@@ -112,8 +112,6 @@ impl EpArgs {
 }
 
 fn main() {
-    let _ = zlog::try_init(Some("error".into()));
-    zlog::init_output_stderr();
     let args = EpArgs::parse();
 
     if args.printenv {
@@ -152,7 +150,7 @@ fn main() {
             };
 
             let total_examples = examples.len();
-            let progress = Progress::new(total_examples);
+            Progress::global().set_total_examples(total_examples);
 
             let mut grouped_examples = group_examples_by_repo(&mut examples);
             let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
@@ -163,29 +161,16 @@ fn main() {
                         match &command {
                             Command::ParseExample => {}
                             Command::LoadProject => {
-                                run_load_project(
-                                    example,
-                                    app_state.clone(),
-                                    progress.clone(),
-                                    cx.clone(),
-                                )
-                                .await;
+                                run_load_project(example, app_state.clone(), cx.clone()).await;
                             }
                             Command::Context => {
-                                run_context_retrieval(
-                                    example,
-                                    app_state.clone(),
-                                    progress.clone(),
-                                    cx.clone(),
-                                )
-                                .await;
+                                run_context_retrieval(example, app_state.clone(), cx.clone()).await;
                             }
                             Command::FormatPrompt(args) => {
                                 run_format_prompt(
                                     example,
                                     args.prompt_format,
                                     app_state.clone(),
-                                    progress.clone(),
                                     cx.clone(),
                                 )
                                 .await;
@@ -196,7 +181,6 @@ fn main() {
                                     Some(args.provider),
                                     args.repetitions,
                                     app_state.clone(),
-                                    progress.clone(),
                                     cx.clone(),
                                 )
                                 .await;
@@ -205,14 +189,7 @@ fn main() {
                                 run_distill(example).await;
                             }
                             Command::Score(args) | Command::Eval(args) => {
-                                run_scoring(
-                                    example,
-                                    &args,
-                                    app_state.clone(),
-                                    progress.clone(),
-                                    cx.clone(),
-                                )
-                                .await;
+                                run_scoring(example, &args, app_state.clone(), cx.clone()).await;
                             }
                             Command::Clean => {
                                 unreachable!()
@@ -222,7 +199,7 @@ fn main() {
                 });
                 futures::future::join_all(futures).await;
             }
-            progress.clear();
+            Progress::global().clear();
 
             if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                 write_examples(&examples, output.as_ref());

crates/edit_prediction_cli/src/predict.rs 🔗

@@ -25,7 +25,6 @@ pub async fn run_prediction(
     provider: Option<PredictionProvider>,
     repetition_count: usize,
     app_state: Arc<EpAppState>,
-    progress: Arc<Progress>,
     mut cx: AsyncApp,
 ) {
     if !example.predictions.is_empty() {
@@ -34,32 +33,25 @@ pub async fn run_prediction(
 
     let provider = provider.unwrap();
 
-    run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await;
+    run_context_retrieval(example, app_state.clone(), cx.clone()).await;
 
     if matches!(
         provider,
         PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
     ) {
-        let _step_progress = progress.start(Step::Predict, &example.name);
+        let _step_progress = Progress::global().start(Step::Predict, &example.name);
 
         if example.prompt.is_none() {
-            run_format_prompt(
-                example,
-                PromptFormat::Teacher,
-                app_state.clone(),
-                progress,
-                cx,
-            )
-            .await;
+            run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
         }
 
         let batched = matches!(provider, PredictionProvider::Teacher);
         return predict_anthropic(example, repetition_count, batched).await;
     }
 
-    run_load_project(example, app_state.clone(), progress.clone(), cx.clone()).await;
+    run_load_project(example, app_state.clone(), cx.clone()).await;
 
-    let _step_progress = progress.start(Step::Predict, &example.name);
+    let _step_progress = Progress::global().start(Step::Predict, &example.name);
 
     if matches!(
         provider,

crates/edit_prediction_cli/src/progress.rs 🔗

@@ -2,10 +2,12 @@ use std::{
     borrow::Cow,
     collections::HashMap,
     io::{IsTerminal, Write},
-    sync::{Arc, Mutex},
+    sync::{Arc, Mutex, OnceLock},
     time::{Duration, Instant},
 };
 
+use log::{Level, Log, Metadata, Record};
+
 pub struct Progress {
     inner: Mutex<ProgressInner>,
 }
@@ -18,6 +20,7 @@ struct ProgressInner {
     max_example_name_len: usize,
     status_lines_displayed: usize,
     total_examples: usize,
+    last_line_is_logging: bool,
 }
 
 #[derive(Clone)]
@@ -72,70 +75,114 @@ impl Step {
     }
 }
 
+static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
+static LOGGER: ProgressLogger = ProgressLogger;
+
 const RIGHT_MARGIN: usize = 4;
+const MAX_STATUS_LINES: usize = 10;
 
 impl Progress {
-    pub fn new(total_examples: usize) -> Arc<Self> {
-        Arc::new(Self {
-            inner: Mutex::new(ProgressInner {
-                completed: Vec::new(),
-                in_progress: HashMap::new(),
-                is_tty: std::io::stderr().is_terminal(),
-                terminal_width: get_terminal_width(),
-                max_example_name_len: 0,
-                status_lines_displayed: 0,
-                total_examples,
-            }),
-        })
+    /// Returns the global Progress instance, initializing it if necessary.
+    pub fn global() -> Arc<Progress> {
+        GLOBAL
+            .get_or_init(|| {
+                let progress = Arc::new(Self {
+                    inner: Mutex::new(ProgressInner {
+                        completed: Vec::new(),
+                        in_progress: HashMap::new(),
+                        is_tty: std::io::stderr().is_terminal(),
+                        terminal_width: get_terminal_width(),
+                        max_example_name_len: 0,
+                        status_lines_displayed: 0,
+                        total_examples: 0,
+                        last_line_is_logging: false,
+                    }),
+                });
+                let _ = log::set_logger(&LOGGER);
+                log::set_max_level(log::LevelFilter::Error);
+                progress
+            })
+            .clone()
     }
 
-    pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> Arc<StepProgress> {
-        {
-            let mut inner = self.inner.lock().unwrap();
+    pub fn set_total_examples(&self, total: usize) {
+        let mut inner = self.inner.lock().unwrap();
+        inner.total_examples = total;
+    }
 
-            Self::clear_status_lines(&mut inner);
+    /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
+    /// This should be used for any output that needs to appear above the status lines.
+    fn log(&self, message: &str) {
+        let mut inner = self.inner.lock().unwrap();
+        Self::clear_status_lines(&mut inner);
 
-            inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
+        if !inner.last_line_is_logging {
+            let reset = "\x1b[0m";
+            let dim = "\x1b[2m";
+            let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN));
+            eprintln!("{dim}{divider}{reset}");
+            inner.last_line_is_logging = true;
+        }
 
-            inner.in_progress.insert(
-                example_name.to_string(),
-                InProgressTask {
-                    step,
-                    started_at: Instant::now(),
-                    substatus: None,
-                    info: None,
-                },
-            );
+        eprintln!("{}", message);
+    }
 
-            Self::print_status_lines(&mut inner);
-        }
+    pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
+        let mut inner = self.inner.lock().unwrap();
+
+        Self::clear_status_lines(&mut inner);
+
+        inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
+        inner.in_progress.insert(
+            example_name.to_string(),
+            InProgressTask {
+                step,
+                started_at: Instant::now(),
+                substatus: None,
+                info: None,
+            },
+        );
+
+        Self::print_status_lines(&mut inner);
 
-        Arc::new(StepProgress {
+        StepProgress {
             progress: self.clone(),
             step,
             example_name: example_name.to_string(),
-        })
+        }
     }
 
-    pub fn finish(&self, step: Step, example_name: &str) {
+    fn finish(&self, step: Step, example_name: &str) {
         let mut inner = self.inner.lock().unwrap();
 
-        let task = inner.in_progress.remove(example_name);
-        if let Some(task) = task {
-            if task.step == step {
-                inner.completed.push(CompletedTask {
-                    step: task.step,
-                    example_name: example_name.to_string(),
-                    duration: task.started_at.elapsed(),
-                    info: task.info,
-                });
+        let Some(task) = inner.in_progress.remove(example_name) else {
+            return;
+        };
 
-                Self::clear_status_lines(&mut inner);
-                Self::print_completed(&inner, inner.completed.last().unwrap());
-                Self::print_status_lines(&mut inner);
-            } else {
-                inner.in_progress.insert(example_name.to_string(), task);
-            }
+        if task.step == step {
+            inner.completed.push(CompletedTask {
+                step: task.step,
+                example_name: example_name.to_string(),
+                duration: task.started_at.elapsed(),
+                info: task.info,
+            });
+
+            Self::clear_status_lines(&mut inner);
+            Self::print_logging_closing_divider(&mut inner);
+            Self::print_completed(&inner, inner.completed.last().unwrap());
+            Self::print_status_lines(&mut inner);
+        } else {
+            inner.in_progress.insert(example_name.to_string(), task);
+        }
+    }
+
+    fn print_logging_closing_divider(inner: &mut ProgressInner) {
+        if inner.last_line_is_logging {
+            let reset = "\x1b[0m";
+            let dim = "\x1b[2m";
+            let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN));
+            eprintln!("{dim}{divider}{reset}");
+            inner.last_line_is_logging = false;
         }
     }
 
@@ -234,9 +281,10 @@ impl Progress {
         let mut tasks: Vec<_> = inner.in_progress.iter().collect();
         tasks.sort_by_key(|(name, _)| *name);
 
+        let total_tasks = tasks.len();
         let mut lines_printed = 0;
 
-        for (name, task) in tasks.iter() {
+        for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
             let elapsed = format_duration(task.started_at.elapsed());
             let substatus_part = task
                 .substatus
@@ -265,6 +313,13 @@ impl Progress {
             lines_printed += 1;
         }
 
+        // Show "+N more" on its own line if there are more tasks
+        if total_tasks > MAX_STATUS_LINES {
+            let remaining = total_tasks - MAX_STATUS_LINES;
+            eprintln!("{:>12} +{remaining} more", "");
+            lines_printed += 1;
+        }
+
         inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
         let _ = std::io::stderr().flush();
     }
@@ -314,6 +369,53 @@ impl Drop for StepProgress {
     }
 }
 
+struct ProgressLogger;
+
+impl Log for ProgressLogger {
+    fn enabled(&self, metadata: &Metadata) -> bool {
+        metadata.level() <= Level::Info
+    }
+
+    fn log(&self, record: &Record) {
+        if !self.enabled(record.metadata()) {
+            return;
+        }
+
+        let level_color = match record.level() {
+            Level::Error => "\x1b[31m",
+            Level::Warn => "\x1b[33m",
+            Level::Info => "\x1b[32m",
+            Level::Debug => "\x1b[34m",
+            Level::Trace => "\x1b[35m",
+        };
+        let reset = "\x1b[0m";
+        let bold = "\x1b[1m";
+
+        let level_label = match record.level() {
+            Level::Error => "Error",
+            Level::Warn => "Warn",
+            Level::Info => "Info",
+            Level::Debug => "Debug",
+            Level::Trace => "Trace",
+        };
+
+        let message = format!(
+            "{bold}{level_color}{level_label:>12}{reset} {}",
+            record.args()
+        );
+
+        if let Some(progress) = GLOBAL.get() {
+            progress.log(&message);
+        } else {
+            eprintln!("{}", message);
+        }
+    }
+
+    fn flush(&self) {
+        let _ = std::io::stderr().flush();
+    }
+}
+
 #[cfg(unix)]
 fn get_terminal_width() -> usize {
     unsafe {

crates/edit_prediction_cli/src/retrieve_context.rs 🔗

@@ -16,16 +16,17 @@ use std::time::Duration;
 pub async fn run_context_retrieval(
     example: &mut Example,
     app_state: Arc<EpAppState>,
-    progress: Arc<Progress>,
     mut cx: AsyncApp,
 ) {
     if example.context.is_some() {
         return;
     }
 
-    run_load_project(example, app_state.clone(), progress.clone(), cx.clone()).await;
+    run_load_project(example, app_state.clone(), cx.clone()).await;
 
-    let step_progress = progress.start(Step::Context, &example.name);
+    let step_progress: Arc<StepProgress> = Progress::global()
+        .start(Step::Context, &example.name)
+        .into();
 
     let state = example.state.as_ref().unwrap();
     let project = state.project.clone();

crates/edit_prediction_cli/src/score.rs 🔗

@@ -14,7 +14,6 @@ pub async fn run_scoring(
     example: &mut Example,
     args: &PredictArgs,
     app_state: Arc<EpAppState>,
-    progress: Arc<Progress>,
     cx: AsyncApp,
 ) {
     run_prediction(
@@ -22,12 +21,11 @@ pub async fn run_scoring(
         Some(args.provider),
         args.repetitions,
         app_state,
-        progress.clone(),
         cx,
     )
     .await;
 
-    let _progress = progress.start(Step::Score, &example.name);
+    let _progress = Progress::global().start(Step::Score, &example.name);
 
     let expected_patch = parse_patch(&example.expected_patch);