progress.rs

  1use std::{
  2    borrow::Cow,
  3    collections::HashMap,
  4    io::{IsTerminal, Write},
  5    sync::{Arc, Mutex, OnceLock},
  6    time::{Duration, Instant},
  7};
  8
  9use log::{Level, Log, Metadata, Record};
 10
 11pub struct Progress {
 12    inner: Mutex<ProgressInner>,
 13}
 14
 15struct ProgressInner {
 16    completed: Vec<CompletedTask>,
 17    in_progress: HashMap<String, InProgressTask>,
 18    is_tty: bool,
 19    terminal_width: usize,
 20    max_example_name_len: usize,
 21    status_lines_displayed: usize,
 22    total_examples: usize,
 23    failed_examples: usize,
 24    last_line_is_logging: bool,
 25}
 26
 27#[derive(Clone)]
 28struct InProgressTask {
 29    step: Step,
 30    started_at: Instant,
 31    substatus: Option<String>,
 32    info: Option<(String, InfoStyle)>,
 33}
 34
 35struct CompletedTask {
 36    step: Step,
 37    example_name: String,
 38    duration: Duration,
 39    info: Option<(String, InfoStyle)>,
 40}
 41
 42#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 43pub enum Step {
 44    LoadProject,
 45    Context,
 46    FormatPrompt,
 47    Predict,
 48    Score,
 49}
 50
 51#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 52pub enum InfoStyle {
 53    Normal,
 54    Warning,
 55}
 56
 57impl Step {
 58    pub fn label(&self) -> &'static str {
 59        match self {
 60            Step::LoadProject => "Load",
 61            Step::Context => "Context",
 62            Step::FormatPrompt => "Format",
 63            Step::Predict => "Predict",
 64            Step::Score => "Score",
 65        }
 66    }
 67
 68    fn color_code(&self) -> &'static str {
 69        match self {
 70            Step::LoadProject => "\x1b[33m",
 71            Step::Context => "\x1b[35m",
 72            Step::FormatPrompt => "\x1b[34m",
 73            Step::Predict => "\x1b[32m",
 74            Step::Score => "\x1b[31m",
 75        }
 76    }
 77}
 78
 79static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
 80static LOGGER: ProgressLogger = ProgressLogger;
 81
 82const MARGIN: usize = 4;
 83const MAX_STATUS_LINES: usize = 10;
 84
 85impl Progress {
 86    /// Returns the global Progress instance, initializing it if necessary.
 87    pub fn global() -> Arc<Progress> {
 88        GLOBAL
 89            .get_or_init(|| {
 90                let progress = Arc::new(Self {
 91                    inner: Mutex::new(ProgressInner {
 92                        completed: Vec::new(),
 93                        in_progress: HashMap::new(),
 94                        is_tty: std::io::stderr().is_terminal(),
 95                        terminal_width: get_terminal_width(),
 96                        max_example_name_len: 0,
 97                        status_lines_displayed: 0,
 98                        total_examples: 0,
 99                        failed_examples: 0,
100                        last_line_is_logging: false,
101                    }),
102                });
103                let _ = log::set_logger(&LOGGER);
104                log::set_max_level(log::LevelFilter::Error);
105                progress
106            })
107            .clone()
108    }
109
110    pub fn set_total_examples(&self, total: usize) {
111        let mut inner = self.inner.lock().unwrap();
112        inner.total_examples = total;
113    }
114
115    pub fn increment_failed(&self) {
116        let mut inner = self.inner.lock().unwrap();
117        inner.failed_examples += 1;
118    }
119
120    /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
121    /// This should be used for any output that needs to appear above the status lines.
122    fn log(&self, message: &str) {
123        let mut inner = self.inner.lock().unwrap();
124        Self::clear_status_lines(&mut inner);
125
126        if !inner.last_line_is_logging {
127            let reset = "\x1b[0m";
128            let dim = "\x1b[2m";
129            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
130            eprintln!("{dim}{divider}{reset}");
131            inner.last_line_is_logging = true;
132        }
133
134        eprintln!("{}", message);
135    }
136
137    pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
138        let mut inner = self.inner.lock().unwrap();
139
140        Self::clear_status_lines(&mut inner);
141
142        inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
143        inner.in_progress.insert(
144            example_name.to_string(),
145            InProgressTask {
146                step,
147                started_at: Instant::now(),
148                substatus: None,
149                info: None,
150            },
151        );
152
153        Self::print_status_lines(&mut inner);
154
155        StepProgress {
156            progress: self.clone(),
157            step,
158            example_name: example_name.to_string(),
159        }
160    }
161
162    fn finish(&self, step: Step, example_name: &str) {
163        let mut inner = self.inner.lock().unwrap();
164
165        let Some(task) = inner.in_progress.remove(example_name) else {
166            return;
167        };
168
169        if task.step == step {
170            inner.completed.push(CompletedTask {
171                step: task.step,
172                example_name: example_name.to_string(),
173                duration: task.started_at.elapsed(),
174                info: task.info,
175            });
176
177            Self::clear_status_lines(&mut inner);
178            Self::print_logging_closing_divider(&mut inner);
179            Self::print_completed(&inner, inner.completed.last().unwrap());
180            Self::print_status_lines(&mut inner);
181        } else {
182            inner.in_progress.insert(example_name.to_string(), task);
183        }
184    }
185
186    fn print_logging_closing_divider(inner: &mut ProgressInner) {
187        if inner.last_line_is_logging {
188            let reset = "\x1b[0m";
189            let dim = "\x1b[2m";
190            let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
191            eprintln!("{dim}{divider}{reset}");
192            inner.last_line_is_logging = false;
193        }
194    }
195
196    fn clear_status_lines(inner: &mut ProgressInner) {
197        if inner.is_tty && inner.status_lines_displayed > 0 {
198            // Move up and clear each line we previously displayed
199            for _ in 0..inner.status_lines_displayed {
200                eprint!("\x1b[A\x1b[K");
201            }
202            let _ = std::io::stderr().flush();
203            inner.status_lines_displayed = 0;
204        }
205    }
206
207    fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
208        let duration = format_duration(task.duration);
209        let name_width = inner.max_example_name_len;
210
211        if inner.is_tty {
212            let reset = "\x1b[0m";
213            let bold = "\x1b[1m";
214            let dim = "\x1b[2m";
215
216            let yellow = "\x1b[33m";
217            let info_part = task
218                .info
219                .as_ref()
220                .map(|(s, style)| {
221                    if *style == InfoStyle::Warning {
222                        format!("{yellow}{s}{reset}")
223                    } else {
224                        s.to_string()
225                    }
226                })
227                .unwrap_or_default();
228
229            let prefix = format!(
230                "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}{reset} {info_part}",
231                color = task.step.color_code(),
232                label = task.step.label(),
233                name = task.example_name,
234            );
235
236            let duration_with_margin = format!("{duration} ");
237            let padding_needed = inner
238                .terminal_width
239                .saturating_sub(MARGIN)
240                .saturating_sub(duration_with_margin.len())
241                .saturating_sub(strip_ansi_len(&prefix));
242            let padding = " ".repeat(padding_needed);
243
244            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
245        } else {
246            let info_part = task
247                .info
248                .as_ref()
249                .map(|(s, _)| format!(" | {}", s))
250                .unwrap_or_default();
251
252            eprintln!(
253                "{label:>12} {name:<name_width$}{info_part} {duration}",
254                label = task.step.label(),
255                name = task.example_name,
256            );
257        }
258    }
259
260    fn print_status_lines(inner: &mut ProgressInner) {
261        if !inner.is_tty || inner.in_progress.is_empty() {
262            inner.status_lines_displayed = 0;
263            return;
264        }
265
266        let reset = "\x1b[0m";
267        let bold = "\x1b[1m";
268        let dim = "\x1b[2m";
269
270        // Build the done/in-progress/total label
271        let done_count = inner.completed.len();
272        let in_progress_count = inner.in_progress.len();
273        let failed_count = inner.failed_examples;
274
275        let failed_label = if failed_count > 0 {
276            format!(" {} failed ", failed_count)
277        } else {
278            String::new()
279        };
280
281        let range_label = format!(
282            " {}/{}/{} ",
283            done_count, in_progress_count, inner.total_examples
284        );
285
286        // Print a divider line with failed count on left, range label on right
287        let failed_visible_len = strip_ansi_len(&failed_label);
288        let range_visible_len = range_label.len();
289        let middle_divider_len = inner
290            .terminal_width
291            .saturating_sub(MARGIN * 2)
292            .saturating_sub(failed_visible_len)
293            .saturating_sub(range_visible_len);
294        let left_divider = "".repeat(MARGIN);
295        let middle_divider = "".repeat(middle_divider_len);
296        let right_divider = "".repeat(MARGIN);
297        eprintln!(
298            "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
299        );
300
301        let mut tasks: Vec<_> = inner.in_progress.iter().collect();
302        tasks.sort_by_key(|(name, _)| *name);
303
304        let total_tasks = tasks.len();
305        let mut lines_printed = 0;
306
307        for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
308            let elapsed = format_duration(task.started_at.elapsed());
309            let substatus_part = task
310                .substatus
311                .as_ref()
312                .map(|s| truncate_with_ellipsis(s, 30))
313                .unwrap_or_default();
314
315            let step_label = task.step.label();
316            let step_color = task.step.color_code();
317            let name_width = inner.max_example_name_len;
318
319            let prefix = format!(
320                "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}{reset} {substatus_part}",
321                name = name,
322            );
323
324            let duration_with_margin = format!("{elapsed} ");
325            let padding_needed = inner
326                .terminal_width
327                .saturating_sub(MARGIN)
328                .saturating_sub(duration_with_margin.len())
329                .saturating_sub(strip_ansi_len(&prefix));
330            let padding = " ".repeat(padding_needed);
331
332            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
333            lines_printed += 1;
334        }
335
336        // Show "+N more" on its own line if there are more tasks
337        if total_tasks > MAX_STATUS_LINES {
338            let remaining = total_tasks - MAX_STATUS_LINES;
339            eprintln!("{:>12} +{remaining} more", "");
340            lines_printed += 1;
341        }
342
343        inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
344        let _ = std::io::stderr().flush();
345    }
346
347    pub fn finalize(&self) {
348        let mut inner = self.inner.lock().unwrap();
349        Self::clear_status_lines(&mut inner);
350
351        // Print summary if there were failures
352        if inner.failed_examples > 0 {
353            let total_processed = inner.completed.len() + inner.failed_examples;
354            let percentage = if total_processed > 0 {
355                inner.failed_examples as f64 / total_processed as f64 * 100.0
356            } else {
357                0.0
358            };
359            eprintln!(
360                "\n{} of {} examples failed ({:.1}%)",
361                inner.failed_examples, total_processed, percentage
362            );
363        }
364    }
365}
366
367pub struct StepProgress {
368    progress: Arc<Progress>,
369    step: Step,
370    example_name: String,
371}
372
373impl StepProgress {
374    pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
375        let mut inner = self.progress.inner.lock().unwrap();
376        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
377            task.substatus = Some(substatus.into().into_owned());
378            Progress::clear_status_lines(&mut inner);
379            Progress::print_status_lines(&mut inner);
380        }
381    }
382
383    pub fn clear_substatus(&self) {
384        let mut inner = self.progress.inner.lock().unwrap();
385        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
386            task.substatus = None;
387            Progress::clear_status_lines(&mut inner);
388            Progress::print_status_lines(&mut inner);
389        }
390    }
391
392    pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
393        let mut inner = self.progress.inner.lock().unwrap();
394        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
395            task.info = Some((info.into(), style));
396        }
397    }
398}
399
400impl Drop for StepProgress {
401    fn drop(&mut self) {
402        self.progress.finish(self.step, &self.example_name);
403    }
404}
405
406struct ProgressLogger;
407
408impl Log for ProgressLogger {
409    fn enabled(&self, metadata: &Metadata) -> bool {
410        metadata.level() <= Level::Info
411    }
412
413    fn log(&self, record: &Record) {
414        if !self.enabled(record.metadata()) {
415            return;
416        }
417
418        let level_color = match record.level() {
419            Level::Error => "\x1b[31m",
420            Level::Warn => "\x1b[33m",
421            Level::Info => "\x1b[32m",
422            Level::Debug => "\x1b[34m",
423            Level::Trace => "\x1b[35m",
424        };
425        let reset = "\x1b[0m";
426        let bold = "\x1b[1m";
427
428        let level_label = match record.level() {
429            Level::Error => "Error",
430            Level::Warn => "Warn",
431            Level::Info => "Info",
432            Level::Debug => "Debug",
433            Level::Trace => "Trace",
434        };
435
436        let message = format!(
437            "{bold}{level_color}{level_label:>12}{reset} {}",
438            record.args()
439        );
440
441        if let Some(progress) = GLOBAL.get() {
442            progress.log(&message);
443        } else {
444            eprintln!("{}", message);
445        }
446    }
447
448    fn flush(&self) {
449        let _ = std::io::stderr().flush();
450    }
451}
452
453#[cfg(unix)]
454fn get_terminal_width() -> usize {
455    unsafe {
456        let mut winsize: libc::winsize = std::mem::zeroed();
457        if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
458            && winsize.ws_col > 0
459        {
460            winsize.ws_col as usize
461        } else {
462            80
463        }
464    }
465}
466
467#[cfg(not(unix))]
468fn get_terminal_width() -> usize {
469    80
470}
471
472fn strip_ansi_len(s: &str) -> usize {
473    let mut len = 0;
474    let mut in_escape = false;
475    for c in s.chars() {
476        if c == '\x1b' {
477            in_escape = true;
478        } else if in_escape {
479            if c == 'm' {
480                in_escape = false;
481            }
482        } else {
483            len += 1;
484        }
485    }
486    len
487}
488
/// Shortens `s` to at most `max_len` characters, replacing the tail with `…` when cut.
///
/// Lengths are counted in `char`s, so multi-byte UTF-8 text is never split mid-character.
/// When truncation happens, the result has exactly `max_len` characters including the
/// ellipsis; with `max_len == 0` the result is empty.
fn truncate_with_ellipsis(s: &str, max_len: usize) -> String {
    if s.chars().count() <= max_len {
        return s.to_string();
    }
    if max_len == 0 {
        return String::new();
    }
    // Keep one slot free for the ellipsis.
    let kept: String = s.chars().take(max_len - 1).collect();
    format!("{kept}…")
}
496
/// Formats a duration compactly: whole milliseconds below one second, seconds with one
/// decimal below one minute, and minutes with one decimal otherwise.
fn format_duration(duration: Duration) -> String {
    const MILLIS_PER_SECOND: f32 = 1_000.0;
    const MILLIS_PER_MINUTE: f32 = 60.0 * 1_000.0;

    match duration.as_millis() as f32 {
        ms if ms < MILLIS_PER_SECOND => format!("{ms}ms"),
        ms if ms < MILLIS_PER_MINUTE => format!("{:.1}s", ms / MILLIS_PER_SECOND),
        ms => format!("{:.1}m", ms / MILLIS_PER_MINUTE),
    }
}