diff --git a/Cargo.lock b/Cargo.lock index 2447303bacc666324a99c54247ab70f950d3bb0c..928e9f1a1db069d4e14cb80fe909aa22ac93e1ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5201,7 +5201,6 @@ dependencies = [ "wasmtime", "watch", "zeta_prompt", - "zlog", ] [[package]] diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index 61e55e09a3b0b46a7d6ad0338be3ab76c1e08401..811808c72304f4c11a9858e61395e46024b83f1e 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -56,7 +56,6 @@ watch.workspace = true edit_prediction = { workspace = true, features = ["cli-support"] } wasmtime.workspace = true zeta_prompt.workspace = true -zlog.workspace = true # Wasmtime is included as a dependency in order to enable the same # features that are enabled in Zed. diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index 2225f1d294144753408968c6f464988378e2691d..017e11a54c77e06bde7b74ed3f924692e33cd480 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -18,12 +18,11 @@ pub async fn run_format_prompt( example: &mut Example, prompt_format: PromptFormat, app_state: Arc, - progress: Arc, mut cx: AsyncApp, ) { - run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await; + run_context_retrieval(example, app_state.clone(), cx.clone()).await; - let _step_progress = progress.start(Step::FormatPrompt, &example.name); + let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name); match prompt_format { PromptFormat::Teacher => { @@ -35,7 +34,7 @@ pub async fn run_format_prompt( }); } PromptFormat::Zeta2 => { - run_load_project(example, app_state, progress.clone(), cx.clone()).await; + run_load_project(example, app_state, cx.clone()).await; let ep_store = cx .update(|cx| EditPredictionStore::try_global(cx).unwrap()) diff --git a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs index 895105966713f653a0ce8277387276a0ae40a4bc..4d98ae9f3b85f4e6253d9ead4d846ed3d9deee89 100644 --- a/crates/edit_prediction_cli/src/load_project.rs +++ b/crates/edit_prediction_cli/src/load_project.rs @@ -25,17 +25,12 @@ use std::{ use util::{paths::PathStyle, rel_path::RelPath}; use zeta_prompt::CURSOR_MARKER; -pub async fn run_load_project( - example: &mut Example, - app_state: Arc, - progress: Arc, - mut cx: AsyncApp, -) { +pub async fn run_load_project(example: &mut Example, app_state: Arc, mut cx: AsyncApp) { if example.state.is_some() { return; } - let progress = progress.start(Step::LoadProject, &example.name); + let progress = Progress::global().start(Step::LoadProject, &example.name); let project = setup_project(example, &app_state, &progress, &mut cx).await; @@ -149,7 +144,7 @@ async fn cursor_position( async fn setup_project( example: &mut Example, app_state: &Arc, - step_progress: &Arc, + step_progress: &StepProgress, cx: &mut AsyncApp, ) -> Entity { let ep_store = cx @@ -227,7 +222,7 @@ async fn setup_project( project } -async fn setup_worktree(example: &Example, step_progress: &Arc) -> PathBuf { +async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> PathBuf { let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name"); let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref()); let worktree_path = WORKTREES_DIR diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index b053af128c82c1aeefb35756ec28bc22a3ff2387..075f8862e6f86276a0df550c6d27f8c15a5d1293 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -32,7 +32,7 @@ use crate::score::run_scoring; struct EpArgs { #[arg(long, default_value_t = false)] printenv: bool, - #[clap(long, default_value_t = 10)] + #[clap(long, default_value_t = 10, global = true)] max_parallelism: usize, #[command(subcommand)] command: Option, @@ -112,8 +112,6 @@ impl EpArgs { } fn main() { - let _ = zlog::try_init(Some("error".into())); - zlog::init_output_stderr(); let args = EpArgs::parse(); if args.printenv { @@ -152,7 +150,7 @@ fn main() { }; let total_examples = examples.len(); - let progress = Progress::new(total_examples); + Progress::global().set_total_examples(total_examples); let mut grouped_examples = group_examples_by_repo(&mut examples); let example_batches = grouped_examples.chunks_mut(args.max_parallelism); @@ -163,29 +161,16 @@ fn main() { match &command { Command::ParseExample => {} Command::LoadProject => { - run_load_project( - example, - app_state.clone(), - progress.clone(), - cx.clone(), - ) - .await; + run_load_project(example, app_state.clone(), cx.clone()).await; } Command::Context => { - run_context_retrieval( - example, - app_state.clone(), - progress.clone(), - cx.clone(), - ) - .await; + run_context_retrieval(example, app_state.clone(), cx.clone()).await; } Command::FormatPrompt(args) => { run_format_prompt( example, args.prompt_format, app_state.clone(), - progress.clone(), cx.clone(), ) .await; @@ -196,7 +181,6 @@ fn main() { Some(args.provider), args.repetitions, app_state.clone(), - progress.clone(), cx.clone(), ) .await; @@ -205,14 +189,7 @@ fn main() { run_distill(example).await; } Command::Score(args) | Command::Eval(args) => { - run_scoring( - example, - &args, - app_state.clone(), - progress.clone(), - cx.clone(), - ) - .await; + run_scoring(example, &args, app_state.clone(), cx.clone()).await; } Command::Clean => { unreachable!() @@ -222,7 +199,7 @@ fn main() { }); futures::future::join_all(futures).await; } - progress.clear(); + Progress::global().clear(); if args.output.is_some() || !matches!(command, Command::Eval(_)) { write_examples(&examples, output.as_ref()); diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 14628a896273f7ff11166a1daac248598e198847..3f690266e3165b2d52f642457e7aebf959a40a03 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -25,7 +25,6 @@ pub async fn run_prediction( provider: Option, repetition_count: usize, app_state: Arc, - progress: Arc, mut cx: AsyncApp, ) { if !example.predictions.is_empty() { @@ -34,32 +33,25 @@ pub async fn run_prediction( let provider = provider.unwrap(); - run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await; + run_context_retrieval(example, app_state.clone(), cx.clone()).await; if matches!( provider, PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching ) { - let _step_progress = progress.start(Step::Predict, &example.name); + let _step_progress = Progress::global().start(Step::Predict, &example.name); if example.prompt.is_none() { - run_format_prompt( - example, - PromptFormat::Teacher, - app_state.clone(), - progress, - cx, - ) - .await; + run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await; } let batched = matches!(provider, PredictionProvider::Teacher); return predict_anthropic(example, repetition_count, batched).await; } - run_load_project(example, app_state.clone(), progress.clone(), cx.clone()).await; + run_load_project(example, app_state.clone(), cx.clone()).await; - let _step_progress = progress.start(Step::Predict, &example.name); + let _step_progress = Progress::global().start(Step::Predict, &example.name); if matches!( provider, diff --git a/crates/edit_prediction_cli/src/progress.rs b/crates/edit_prediction_cli/src/progress.rs index 5cd906d89a20813676b09af0d2cbeca532c5ba12..8195485d70c70c0cbfb38e2de83a055598d5e4e5 100644 --- a/crates/edit_prediction_cli/src/progress.rs +++ b/crates/edit_prediction_cli/src/progress.rs @@ -2,10 +2,12 @@ use std::{ borrow::Cow, collections::HashMap, io::{IsTerminal, Write}, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, OnceLock}, time::{Duration, Instant}, }; +use log::{Level, Log, Metadata, Record}; + pub struct Progress { inner: Mutex, } @@ -18,6 +20,7 @@ struct ProgressInner { max_example_name_len: usize, status_lines_displayed: usize, total_examples: usize, + last_line_is_logging: bool, } #[derive(Clone)] @@ -72,70 +75,114 @@ impl Step { } } +static GLOBAL: OnceLock> = OnceLock::new(); +static LOGGER: ProgressLogger = ProgressLogger; + const RIGHT_MARGIN: usize = 4; +const MAX_STATUS_LINES: usize = 10; impl Progress { - pub fn new(total_examples: usize) -> Arc { - Arc::new(Self { - inner: Mutex::new(ProgressInner { - completed: Vec::new(), - in_progress: HashMap::new(), - is_tty: std::io::stderr().is_terminal(), - terminal_width: get_terminal_width(), - max_example_name_len: 0, - status_lines_displayed: 0, - total_examples, - }), - }) + /// Returns the global Progress instance, initializing it if necessary. + pub fn global() -> Arc { + GLOBAL + .get_or_init(|| { + let progress = Arc::new(Self { + inner: Mutex::new(ProgressInner { + completed: Vec::new(), + in_progress: HashMap::new(), + is_tty: std::io::stderr().is_terminal(), + terminal_width: get_terminal_width(), + max_example_name_len: 0, + status_lines_displayed: 0, + total_examples: 0, + last_line_is_logging: false, + }), + }); + let _ = log::set_logger(&LOGGER); + log::set_max_level(log::LevelFilter::Error); + progress + }) + .clone() } - pub fn start(self: &Arc, step: Step, example_name: &str) -> Arc { - { - let mut inner = self.inner.lock().unwrap(); + pub fn set_total_examples(&self, total: usize) { + let mut inner = self.inner.lock().unwrap(); + inner.total_examples = total; + } - Self::clear_status_lines(&mut inner); + /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption. + /// This should be used for any output that needs to appear above the status lines. + fn log(&self, message: &str) { + let mut inner = self.inner.lock().unwrap(); + Self::clear_status_lines(&mut inner); - inner.max_example_name_len = inner.max_example_name_len.max(example_name.len()); + if !inner.last_line_is_logging { + let reset = "\x1b[0m"; + let dim = "\x1b[2m"; + let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN)); + eprintln!("{dim}{divider}{reset}"); + inner.last_line_is_logging = true; + } - inner.in_progress.insert( - example_name.to_string(), - InProgressTask { - step, - started_at: Instant::now(), - substatus: None, - info: None, - }, - ); + eprintln!("{}", message); + } - Self::print_status_lines(&mut inner); - } + pub fn start(self: &Arc, step: Step, example_name: &str) -> StepProgress { + let mut inner = self.inner.lock().unwrap(); + + Self::clear_status_lines(&mut inner); + + inner.max_example_name_len = inner.max_example_name_len.max(example_name.len()); + inner.in_progress.insert( + example_name.to_string(), + InProgressTask { + step, + started_at: Instant::now(), + substatus: None, + info: None, + }, + ); + + Self::print_status_lines(&mut inner); - Arc::new(StepProgress { + StepProgress { progress: self.clone(), step, example_name: example_name.to_string(), - }) + } } - pub fn finish(&self, step: Step, example_name: &str) { + fn finish(&self, step: Step, example_name: &str) { let mut inner = self.inner.lock().unwrap(); - let task = inner.in_progress.remove(example_name); - if let Some(task) = task { - if task.step == step { - inner.completed.push(CompletedTask { - step: task.step, - example_name: example_name.to_string(), - duration: task.started_at.elapsed(), - info: task.info, - }); + let Some(task) = inner.in_progress.remove(example_name) else { + return; + }; - Self::clear_status_lines(&mut inner); - Self::print_completed(&inner, inner.completed.last().unwrap()); - Self::print_status_lines(&mut inner); - } else { - inner.in_progress.insert(example_name.to_string(), task); - } + if task.step == step { + inner.completed.push(CompletedTask { + step: task.step, + example_name: example_name.to_string(), + duration: task.started_at.elapsed(), + info: task.info, + }); + + Self::clear_status_lines(&mut inner); + Self::print_logging_closing_divider(&mut inner); + Self::print_completed(&inner, inner.completed.last().unwrap()); + Self::print_status_lines(&mut inner); + } else { + inner.in_progress.insert(example_name.to_string(), task); + } + } + + fn print_logging_closing_divider(inner: &mut ProgressInner) { + if inner.last_line_is_logging { + let reset = "\x1b[0m"; + let dim = "\x1b[2m"; + let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN)); + eprintln!("{dim}{divider}{reset}"); + inner.last_line_is_logging = false; } } @@ -234,9 +281,10 @@ impl Progress { let mut tasks: Vec<_> = inner.in_progress.iter().collect(); tasks.sort_by_key(|(name, _)| *name); + let total_tasks = tasks.len(); let mut lines_printed = 0; - for (name, task) in tasks.iter() { + for (name, task) in tasks.iter().take(MAX_STATUS_LINES) { let elapsed = format_duration(task.started_at.elapsed()); let substatus_part = task .substatus @@ -265,6 +313,13 @@ impl Progress { lines_printed += 1; } + // Show "+N more" on its own line if there are more tasks + if total_tasks > MAX_STATUS_LINES { + let remaining = total_tasks - MAX_STATUS_LINES; + eprintln!("{:>12} +{remaining} more", ""); + lines_printed += 1; + } + inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line let _ = std::io::stderr().flush(); } @@ -314,6 +369,53 @@ impl Drop for StepProgress { } } +struct ProgressLogger; + +impl Log for ProgressLogger { + fn enabled(&self, metadata: &Metadata) -> bool { + metadata.level() <= Level::Info + } + + fn log(&self, record: &Record) { + if !self.enabled(record.metadata()) { + return; + } + + let level_color = match record.level() { + Level::Error => "\x1b[31m", + Level::Warn => "\x1b[33m", + Level::Info => "\x1b[32m", + Level::Debug => "\x1b[34m", + Level::Trace => "\x1b[35m", + }; + let reset = "\x1b[0m"; + let bold = "\x1b[1m"; + + let level_label = match record.level() { + Level::Error => "Error", + Level::Warn => "Warn", + Level::Info => "Info", + Level::Debug => "Debug", + Level::Trace => "Trace", + }; + + let message = format!( + "{bold}{level_color}{level_label:>12}{reset} {}", + record.args() + ); + + if let Some(progress) = GLOBAL.get() { + progress.log(&message); + } else { + eprintln!("{}", message); + } + } + + fn flush(&self) { + let _ = std::io::stderr().flush(); + } +} + #[cfg(unix)] fn get_terminal_width() -> usize { unsafe { diff --git a/crates/edit_prediction_cli/src/retrieve_context.rs b/crates/edit_prediction_cli/src/retrieve_context.rs index 83b5906e976ca3a1a6bdff6a96c36713eef08058..c066cf3caa9ece27144222ef94e3ac72c2285be8 100644 --- a/crates/edit_prediction_cli/src/retrieve_context.rs +++ b/crates/edit_prediction_cli/src/retrieve_context.rs @@ -16,16 +16,17 @@ use std::time::Duration; pub async fn run_context_retrieval( example: &mut Example, app_state: Arc, - progress: Arc, mut cx: AsyncApp, ) { if example.context.is_some() { return; } - run_load_project(example, app_state.clone(), progress.clone(), cx.clone()).await; + run_load_project(example, app_state.clone(), cx.clone()).await; - let step_progress = progress.start(Step::Context, &example.name); + let step_progress: Arc = Progress::global() + .start(Step::Context, &example.name) + .into(); let state = example.state.as_ref().unwrap(); let project = state.project.clone(); diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index 23086dcc6e9279820216961ef0fe9fc65c3ea3eb..b87d8e4df24c8cb12676ed71fe1ea930a841791d 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -14,7 +14,6 @@ pub async fn run_scoring( example: &mut Example, args: &PredictArgs, app_state: Arc, - progress: Arc, cx: AsyncApp, ) { run_prediction( @@ -22,12 +21,11 @@ pub async fn run_scoring( Some(args.provider), args.repetitions, app_state, - progress.clone(), cx, ) .await; - let _progress = progress.start(Step::Score, &example.name); + let _progress = Progress::global().start(Step::Score, &example.name); let expected_patch = parse_patch(&example.expected_patch);